def _calc_P_Q(A):
  """Build the Pade numerator/denominator pair for a matrix exponential.

  The approximant order is picked from the 1-norm of ``A`` and dispatched
  through ``lax.switch``; the matrix is divided by ``2**n_squarings`` before
  the approximant is evaluated.

  Args:
    A: square 2D matrix of dtype float32, float64, complex64 or complex128.

  Returns:
    A tuple ``(P, Q, n_squarings)`` where ``P = U + V`` is the numerator,
    ``Q = -U + V`` is the denominator and ``n_squarings`` is the scaling
    exponent that was applied to ``A``.

  Raises:
    ValueError: if ``A`` is not a square matrix.
    TypeError: if the dtype of ``A`` is not supported.
  """
  A = jnp.asarray(A)
  if A.ndim != 2 or A.shape[0] != A.shape[1]:
    raise ValueError('expected A to be a square matrix')
  A_L1 = np_linalg.norm(A, 1)

  is_double = A.dtype == 'float64' or A.dtype == 'complex128'
  is_single = A.dtype == 'float32' or A.dtype == 'complex64'
  if not (is_double or is_single):
    raise TypeError(f"A.dtype={A.dtype} is not supported.")

  # Per-precision constants: scaling threshold, order-selection thresholds
  # and the approximants those thresholds choose between.
  if is_double:
    maxnorm = 5.371920351148152
    thresholds = [
        1.495585217958292e-002,
        2.539398330063230e-001,
        9.504178996162932e-001,
        2.097847961257068e+000,
    ]
    approximants = [_pade3, _pade5, _pade7, _pade9, _pade13]
  else:
    maxnorm = 3.925724783138660
    thresholds = [4.258730016922831e-001, 1.880152677804762e+000]
    approximants = [_pade3, _pade5, _pade7]

  n_squarings = jnp.maximum(0, jnp.floor(jnp.log2(A_L1 / maxnorm)))
  A = A / 2**n_squarings.astype(A.dtype)
  conds = jnp.array(thresholds, dtype=A_L1.dtype)
  order_index = jnp.digitize(A_L1, conds)
  U, V = lax.switch(order_index, approximants, A)

  numerator = U + V      # p_m(A)
  denominator = -U + V   # q_m(A)
  return numerator, denominator, n_squarings
def _calc_P_Q(A):
  """Build the Pade numerator/denominator pair for a matrix exponential.

  Every candidate approximant is evaluated and the one matching the 1-norm
  of ``A`` is picked with ``jnp.select``. Only the highest-order approximant
  is evaluated on the scaled matrix ``A / 2**n_squarings``.

  Args:
    A: square 2D matrix of dtype float32, float64, complex64 or complex128.

  Returns:
    A tuple ``(P, Q, n_squarings)`` where ``P = U + V`` is the numerator,
    ``Q = -U + V`` is the denominator and ``n_squarings`` is the scaling
    exponent.

  Raises:
    ValueError: if ``A`` is not a square matrix.
    TypeError: if the dtype of ``A`` is not supported.
  """
  A = jnp.asarray(A)
  if A.ndim != 2 or A.shape[0] != A.shape[1]:
    raise ValueError('expected A to be a square matrix')
  A_L1 = np_linalg.norm(A, 1)

  is_double = A.dtype == 'float64' or A.dtype == 'complex128'
  is_single = A.dtype == 'float32' or A.dtype == 'complex64'
  if not (is_double or is_single):
    raise TypeError("A.dtype={} is not supported.".format(A.dtype))

  if is_double:
    maxnorm = 5.371920351148152
    thresholds = [1.495585217958292e-002, 2.539398330063230e-001,
                  9.504178996162932e-001, 2.097847961257068e+000]
    low_order = (_pade3, _pade5, _pade7, _pade9)
    high_order = _pade13
  else:
    maxnorm = 3.925724783138660
    thresholds = [4.258730016922831e-001, 1.880152677804762e+000]
    low_order = (_pade3, _pade5)
    high_order = _pade7

  # The low-order approximants are evaluated on the unscaled matrix.
  low_results = [pade(A) for pade in low_order]

  n_squarings = jnp.maximum(0, jnp.floor(jnp.log2(A_L1 / maxnorm)))
  A = A / 2**n_squarings
  U_high, V_high = high_order(A)

  conds = jnp.array(thresholds)
  U = jnp.select((A_L1 < conds), tuple(uv[0] for uv in low_results), U_high)
  V = jnp.select((A_L1 < conds), tuple(uv[1] for uv in low_results), V_high)

  numerator = U + V      # p_m(A)
  denominator = -U + V   # q_m(A)
  return numerator, denominator, n_squarings
def _update_T_Z(m, T, Z):
  """Apply one 2x2 rotation ``G`` to rows/columns ``m-1`` and ``m`` of T and Z.

  The rotation is built from an eigenvalue of the 2x2 block of ``T`` at
  ``(m-1, m-1)`` and the entry ``T[m, m-1]``. All slicing uses fixed-size
  dynamic slices plus masks, so ``m`` does not have to be static.

  NOTE: ``N`` is a free variable taken from the enclosing scope; it is used
  as the length of the row/column masks.

  Args:
    m: index of the second row/column of the 2x2 block being rotated.
    T: matrix updated from both sides.
    Z: matrix whose columns ``m-1`` and ``m`` are updated from the right.

  Returns:
    The updated ``(T, Z)`` pair.
  """
  block = lax.dynamic_slice(T, (m - 1, m - 1), (2, 2))
  mu = np_linalg.eigvals(block) - T[m, m]
  sub_diag = T[m, m - 1]
  r = np_linalg.norm(jnp.array([mu[0], sub_diag])).astype(T.dtype)
  c = mu[0] / r
  s = sub_diag / r
  G = jnp.array([[c.conj(), s], [-s, c]], dtype=T.dtype)
  G_H = G.conj().T

  # T[m-1:m+1, m-1:] = G @ T[m-1:m+1, m-1:]
  row_pair = lax.dynamic_slice_in_dim(T, m - 1, 2, axis=0)
  keep_cols = jnp.arange(N) >= m - 1
  rotated_rows = G @ jnp.where(keep_cols, row_pair, 0)
  row_pair_new = jnp.where(~keep_cols, row_pair, rotated_rows)
  T = lax.dynamic_update_slice_in_dim(T, row_pair_new, m - 1, axis=0)

  # T[:m+1, m-1:m+1] = T[:m+1, m-1:m+1] @ G.conj().T
  col_pair = lax.dynamic_slice_in_dim(T, m - 1, 2, axis=1)
  keep_rows = jnp.arange(N)[:, jnp.newaxis] < m + 1
  rotated_cols = jnp.where(keep_rows, col_pair, 0) @ G_H
  col_pair_new = jnp.where(~keep_rows, col_pair, rotated_cols)
  T = lax.dynamic_update_slice_in_dim(T, col_pair_new, m - 1, axis=1)

  # Z[:, m-1:m+1] = Z[:, m-1:m+1] @ G.conj().T
  z_pair = lax.dynamic_slice_in_dim(Z, m - 1, 2, axis=1)
  Z = lax.dynamic_update_slice_in_dim(Z, z_pair @ G_H, m - 1, axis=1)
  return T, Z
def roots(p, *, strip_zeros=True):
  """Return the roots of the polynomial with coefficients ``p``.

  Ported from
  https://github.com/numpy/numpy/blob/v1.17.0/numpy/lib/polynomial.py#L168-L251

  Args:
    p: rank-1 array of polynomial coefficients (scalars are promoted to 1D).
    strip_zeros: when True, leading and trailing zero coefficients are
      stripped and the trailing ones are reported as roots at zero. When
      False, the coefficients are used as given; this is unsafe because
      leading zeros are not removed.

  Returns:
    An array containing the roots.

  Raises:
    ValueError: if ``p`` has more than one dimension.
  """
  p = atleast_1d(p)
  if p.ndim != 1:
    raise ValueError("Input must be a rank-1 array.")

  # strip_zeros=False is unsafe because leading zeros aren't removed
  if not strip_zeros:
    return _roots_no_zeros(p) if p.size > 1 else array([])

  if all(p == 0):
    return array([])

  # factor out trivial roots
  first, last = _nonzero_range(p)
  # number of trailing zeros = number of roots at 0
  n_zero_roots = p.size - last
  # strip leading and trailing zeros
  p = p[first:last]
  zero_roots = zeros(n_zero_roots, p.dtype)

  if p.size < 2:
    return zero_roots
  # combine roots and zero roots
  return hstack((_roots_no_zeros(p), zero_roots))
def body(k, state):
  """Perform step ``k`` of an in-place LU factorization with row pivoting.

  NOTE: ``m`` and ``n`` are free variables taken from the enclosing scope and
  are used as the lengths of the row and column index vectors.

  Args:
    k: index of the column being eliminated.
    state: tuple ``(pivot, perm, a)`` carried between steps.

  Returns:
    The updated ``(pivot, perm, a)`` tuple.
  """
  pivot, perm, a = state
  rows = jnp.arange(m)
  cols = jnp.arange(n)

  column = a[:, k]
  if jnp.issubdtype(a.dtype, jnp.complexfloating):
    # For complex entries the pivot magnitude is |re| + |im|.
    magnitude = jnp.abs(jnp.real(column)) + jnp.abs(jnp.imag(column))
  else:
    magnitude = jnp.abs(column)
  # Only rows at or below the diagonal are pivot candidates.
  i = jnp.argmax(jnp.where(rows >= k, magnitude, -jnp.inf))
  pivot = ops.index_update(pivot, ops.index[k], i)

  # Swap rows k and i of the matrix and of the permutation.
  a = ops.index_update(a, ops.index[[k, i], ], a[[i, k], ])
  perm = ops.index_update(perm, ops.index[[i, k], ], perm[[k, i], ])

  # a[k+1:, k] /= a[k, k], adapted for loop-invariant shapes
  diag_entry = a[k, k]
  scaled_column = jnp.where(rows > k, a[:, k] / diag_entry, a[:, k])
  a = ops.index_update(a, ops.index[:, k], scaled_column)

  # a[k+1:, k+1:] -= jnp.outer(a[k+1:, k], a[k, k+1:])
  trailing = (rows[:, None] > k) & (cols > k)
  rank_one = jnp.outer(a[:, k], a[k, :])
  a = a - jnp.where(trailing, rank_one, jnp.array(0, dtype=a.dtype))
  return pivot, perm, a
def _roots_no_zeros(p):
  """Return polynomial roots as the eigenvalues of the companion matrix.

  Args:
    p: rank-1 coefficient array; ``p[0]`` is used as the divisor when the
      first row of the companion matrix is filled in.

  Returns:
    The eigenvalues of the companion matrix, or an empty complex array when
    ``p`` has fewer than two coefficients.
  """
  n_coeffs = p.size
  if n_coeffs < 2:
    return array([], dtype=dtypes._to_complex_dtype(p.dtype))
  # Ones on the first sub-diagonal, normalized coefficients in the first row.
  sub_diagonal = ones((n_coeffs - 2, ), p.dtype)
  companion = diag(sub_diagonal, -1)
  companion = companion.at[0, :].set(-p[1:] / p[0])
  return linalg.eigvals(companion)
def _slice(operand, start_indices, dynamic_slice_sizes, static_slice_sizes,
           fill_value=0):
  """Dynamic slice with a static (padded) shape and a dynamic true size.

  Unlike ``lax.dynamic_slice``, start indices are not clamped: the operand is
  padded first so the slice always lies in bounds, and every element outside
  the true slice extent is replaced by ``fill_value``.

  Args:
    operand: the array to slice.
    start_indices: the offset of the start of the slice.
    dynamic_slice_sizes: the true (unpadded) size of the slice.
    static_slice_sizes: the padded size of the slice, which must be known at
      compile time. The static size must be larger than the dynamic size.
    fill_value: value with which to replace masked-out elements.

  Returns:
    An array with static shape ``static_slice_sizes``, padded from its true
    (dynamic) size ``dynamic_slice_sizes``.
  """
  # Pad the high side of every dimension by the static slice size so that
  # the dynamic_slice below is guaranteed to fall entirely in bounds.
  pad_config = [(0, size, 0) for size in static_slice_sizes]
  pad_value = jnp.array(0, operand.dtype)
  padded_operand = lax.pad(operand, pad_value, pad_config)

  starts = tuple(jnp.int32(start) for start in start_indices)
  window = lax.dynamic_slice(padded_operand, starts, static_slice_sizes)
  return _mask(window, dynamic_slice_sizes, fill_value)
def _fft_core(func_name, fft_type, a, s, axes, norm):
  """Shared implementation behind the ``jax.numpy.fft.*`` transforms.

  Validates ``s`` and ``axes``, moves the transformed axes to the innermost
  positions, crops and zero-pads the input to the requested lengths, calls
  ``lax.fft`` and multiplies the result by the factor from ``_fft_norm``.

  Args:
    func_name: name of the public function; used in error messages and
      forwarded to ``_fft_norm``.
    fft_type: an ``xla_client.FftType`` value forwarded to ``lax.fft``.
    a: input array.
    s: optional sequence of integer lengths, one per transformed axis.
    axes: optional sequence of axes to transform. When None, the trailing
      ``len(s)`` axes are used (or all axes if ``s`` is also None).
    norm: normalization mode forwarded to ``_fft_norm``.

  Returns:
    The transformed array, with the transformed axes moved back to the
    positions given in ``axes``.

  Raises:
    ValueError: if ``s`` has a negative entry, ``s`` and ``axes`` differ in
      length, ``axes`` contains repeats, or more than three axes are given.
  """
  full_name = "jax.numpy.fft." + func_name

  if s is not None:
    # operator.index rejects non-integer lengths.
    s = tuple(map(operator.index, s))
    if np.any(np.less(s, 0)):
      raise ValueError("Shape should be non-negative.")

  if s is not None and axes is not None and len(s) != len(axes):
    # Same error as numpy.
    raise ValueError("Shape and axes have different lengths.")

  # Remember what the caller passed so the axes can be restored at the end.
  orig_axes = axes
  if axes is None:
    if s is None:
      axes = range(a.ndim)
    else:
      axes = range(a.ndim - len(s), a.ndim)

  if len(axes) != len(set(axes)):
    raise ValueError(
        f"{full_name} does not support repeated axes. Got axes {axes}.")

  if len(axes) > 3:
    # XLA does not support FFTs over more than 3 dimensions
    raise ValueError(
        "%s only supports 1D, 2D, and 3D FFTs. "
        "Got axes %s with input rank %s." % (full_name, orig_axes, a.ndim))

  # XLA only supports FFTs over the innermost axes, so rearrange if necessary.
  if orig_axes is not None:
    axes = tuple(range(a.ndim - len(axes), a.ndim))
    a = jnp.moveaxis(a, orig_axes, axes)

  if s is not None:
    a = jnp.asarray(a)
    # Target input shape: the current shape with the requested lengths
    # substituted on the transformed axes.
    in_s = list(a.shape)
    for axis, x in safe_zip(axes, s):
      in_s[axis] = x
    if fft_type == xla_client.FftType.IRFFT:
      # For IRFFT, s gives the output length; the input's last axis is
      # cropped/padded to s[-1] // 2 + 1 entries.
      in_s[-1] = (in_s[-1] // 2 + 1)
    # Cropping
    a = a[tuple(map(slice, in_s))]
    # Padding
    a = jnp.pad(a, [(0, x - y) for x, y in zip(in_s, a.shape)])
  else:
    # No explicit lengths: derive them from the input shape.
    if fft_type == xla_client.FftType.IRFFT:
      s = [a.shape[axis] for axis in axes[:-1]]
      if axes:
        s += [max(0, 2 * (a.shape[axes[-1]] - 1))]
    else:
      s = [a.shape[axis] for axis in axes]
  transformed = lax.fft(a, fft_type, tuple(s))
  transformed *= _fft_norm(jnp.array(s, dtype=transformed.dtype), func_name, norm)
  if orig_axes is not None:
    # Undo the earlier moveaxis.
    transformed = jnp.moveaxis(transformed, axes, orig_axes)
  return transformed
def __getitem__(self, key):
  """Concatenate the items of ``key``, optionally steered by a directive.

  Args:
    key: a single item or a tuple of items. Slices are expanded with
      ``_make_1d_grid_from_slice``; other items are converted with
      ``array``. If the first item is a string it is treated as a directive:
      ``"r"`` or ``"c"`` request that a 1D result gets an extra dimension
      inserted at axis 0 or 1 respectively; otherwise it is a comma-separated
      list of up to three integers ``axis,ndmin,trans1d`` overriding the
      instance defaults. Fields after the third are ignored.

  Returns:
    The concatenation of all converted items along ``axis``.

  Raises:
    ValueError: if the directive cannot be parsed as integers, or if a
      string appears anywhere but the first position.
  """
  if not isinstance(key, tuple):
    key = (key,)

  # Defaults: axis, ndmin, trans1d, and the matrix flag (-1 means "off").
  params = [self.axis, self.ndmin, self.trans1d, -1]

  if isinstance(key[0], str):
    # split off the directive
    directive, *key = key  # pytype: disable=bad-unpacking
    # check two special cases: matrix directives
    if directive == "r":
      params[-1] = 0
    elif directive == "c":
      params[-1] = 1
    else:
      vec = directive.split(",")
      k = len(vec)
      if k < 4:
        # Fill the unspecified trailing fields from the defaults.
        vec += params[k:]
      else:
        # ignore everything after the first three comma-separated ints.
        # The default matrix flag must be wrapped in a list: concatenating
        # a list with a bare int raises TypeError.
        vec = vec[:3] + [params[-1]]
      try:
        params = list(map(int, vec))
      except ValueError as err:
        raise ValueError(
            f"could not understand directive {directive!r}"
        ) from err

  axis, ndmin, trans1d, matrix = params

  output = []
  for item in key:
    if isinstance(item, slice):
      newobj = _make_1d_grid_from_slice(item, op_name=self.op_name)
    elif isinstance(item, str):
      raise ValueError("string directive must be placed at the beginning")
    else:
      newobj = item

    newobj = array(newobj, copy=False, ndmin=ndmin)

    if trans1d != -1 and ndmin - np.ndim(item) > 0:
      # The item was promoted to ndmin dimensions; rotate its axes so the
      # original dimensions land where trans1d asks for them.
      shape_obj = list(range(ndmin))
      # Calculate number of left shifts, with overflow protection by mod
      num_lshifts = ndmin - abs(ndmin + trans1d + 1) % ndmin
      shape_obj = tuple(shape_obj[num_lshifts:] + shape_obj[:num_lshifts])
      newobj = transpose(newobj, shape_obj)

    output.append(newobj)

  res = concatenate(tuple(output), axis=axis)

  if matrix != -1 and res.ndim == 1:
    # insert 2nd dim at axis 0 or 1
    res = expand_dims(res, matrix)

  return res
def _eval_expint_k(A, B, x):
  """Evaluate ``exp(x) / x * (1 + R(1/x) / x)`` with ``R = polyval(A) / polyval(B)``.

  Helper function for all subsequent intervals.

  Args:
    A: numerator polynomial coefficients, highest power first.
    B: denominator polynomial coefficients, highest power first.
    x: array at which to evaluate; its dtype is used for the coefficients.

  Returns:
    An array with the same dtype as ``x``.
  """
  numer = jnp.array(A, dtype=x.dtype)
  denom = jnp.array(B, dtype=x.dtype)
  one = _constant_like(x, 1.0)
  w = one / x
  ratio = jnp.polyval(numer, w) / jnp.polyval(denom, w)
  correction = w * ratio + one
  return jnp.exp(x) * w * correction
def _all_gather_via_psum(x, *, all_gather_dimension, axis_name,
                         axis_index_groups, axis_size):
  """Implement all-gather by expanding every leaf and summing with ``psum``.

  Args:
    x: pytree of arrays to gather.
    all_gather_dimension: forwarded to ``_expand`` as its first argument.
    axis_name: name of the mapped axis to gather over.
    axis_index_groups: optional list of index groups; when given, the
      position of this participant inside its group is used as the index.
    axis_size: forwarded to ``_expand`` as its second argument.

  Returns:
    The ``psum`` over ``axis_name`` of the expanded leaves.
  """
  position = axis_index(axis_name)
  if axis_index_groups is not None:
    # Map the global axis index to the index within its own group.
    flat_groups = np.array(axis_index_groups).flatten()
    group_size = len(axis_index_groups[0])
    index_in_group = flat_groups.argsort() % group_size
    position = lax_numpy.array(index_in_group)[position]

  def expand_leaf(leaf):
    return _expand(all_gather_dimension, axis_size, position, leaf)

  expanded = tree_util.tree_map(expand_leaf, x)
  return psum(expanded, axis_name, axis_index_groups=axis_index_groups)
def _slogdet_lu(a):
  """Compute the sign and log-magnitude of a determinant from an LU factorization.

  Args:
    a: array whose last two dimensions form the matrix passed to
      ``lax_linalg.lu``.

  Returns:
    A tuple ``(sign, logdet)``. ``sign`` is 0 and ``logdet`` is ``-inf``
    when any diagonal entry of the LU factor is exactly zero.
  """
  dtype = lax.dtype(a)
  lu, pivot, _ = lax_linalg.lu(a)
  lu_diag = jnp.diagonal(lu, axis1=-2, axis2=-1)
  singular = jnp.any(lu_diag == jnp.array(0, dtype=dtype), axis=-1)

  # Every pivot entry that differs from its own position is one row swap.
  positions = lax.expand_dims(jnp.arange(a.shape[-1]), range(pivot.ndim - 1))
  n_flips = jnp.count_nonzero(pivot != positions, axis=-1)

  if jnp.iscomplexobj(a):
    phase = jnp.prod(lu_diag / jnp.abs(lu_diag), axis=-1)
  else:
    phase = jnp.array(1, dtype=dtype)
    # Each negative diagonal entry contributes one more sign flip.
    n_flips = n_flips + jnp.count_nonzero(lu_diag < 0, axis=-1)

  flip_sign = jnp.array(-2 * (n_flips % 2) + 1, dtype=dtype)
  sign = jnp.where(singular, jnp.array(0, dtype=dtype), phase * flip_sign)
  log_abs = jnp.sum(jnp.log(jnp.abs(lu_diag)), axis=-1)
  logdet = jnp.where(singular, jnp.array(-jnp.inf, dtype=dtype), log_abs)
  return sign, jnp.real(logdet)
def setxor1d(ar1, ar2, assume_unique=False):
  """Return the sorted values that occur in exactly one of the two inputs.

  Both inputs must be concrete (not traced) values.

  Args:
    ar1: first input array; it is flattened.
    ar2: second input array; it is flattened.
    assume_unique: if True, skip de-duplicating each input.

  Returns:
    A sorted 1D array of the values present in only one input.
  """
  _check_arraylike("setxor1d", ar1, ar2)
  operands = [
      core.concrete_or_error(None, operand, "The error arose in setxor1d()")
      for operand in (ar1, ar2)
  ]
  operands = [ravel(operand) for operand in operands]
  if not assume_unique:
    operands = [unique(operand) for operand in operands]

  merged = concatenate(tuple(operands))
  if merged.size == 0:
    return merged

  merged = sort(merged)
  # boundary[i] is True when merged[i-1] != merged[i]; the two sentinels make
  # the first and last element comparable to a "different" neighbour.
  edge = array([True])
  boundary = concatenate((edge, merged[1:] != merged[:-1], array([True])))
  # An element is kept when it differs from both neighbours.
  differs_from_next = boundary[1:]
  differs_from_prev = boundary[:-1]
  return merged[differs_from_next & differs_from_prev]
def slogdet(a):
  """Compute the sign and natural log of the absolute determinant.

  Args:
    a: array of shape ``[..., n, n]``.

  Returns:
    A tuple ``(sign, logdet)``. ``sign`` is 0 and ``logdet`` is ``-inf``
    when any diagonal entry of the LU factor is exactly zero.

  Raises:
    ValueError: if ``a`` does not have shape ``[..., n, n]``.
  """
  a = _promote_arg_dtypes(jnp.asarray(a))
  shape = jnp.shape(a)
  if len(shape) < 2 or shape[-1] != shape[-2]:
    raise ValueError(
        "Argument to slogdet() must have shape [..., n, n], got {}".format(shape))
  dtype = lax.dtype(a)

  lu, pivot, _ = lax_linalg.lu(a)
  lu_diag = jnp.diagonal(lu, axis1=-2, axis2=-1)
  singular = jnp.any(lu_diag == jnp.array(0, dtype=dtype), axis=-1)

  # Every pivot entry that differs from its own position is one row swap.
  n_flips = jnp.count_nonzero(pivot != jnp.arange(shape[-1]), axis=-1)
  if jnp.iscomplexobj(a):
    phase = jnp.prod(lu_diag / jnp.abs(lu_diag), axis=-1)
  else:
    phase = jnp.array(1, dtype=dtype)
    # Each negative diagonal entry contributes one more sign flip.
    n_flips = n_flips + jnp.count_nonzero(lu_diag < 0, axis=-1)

  flip_sign = jnp.array(-2 * (n_flips % 2) + 1, dtype=dtype)
  sign = jnp.where(singular, jnp.array(0, dtype=dtype), phase * flip_sign)
  log_abs = jnp.sum(jnp.log(jnp.abs(lu_diag)), axis=-1)
  logdet = jnp.where(singular, jnp.array(-jnp.inf, dtype=dtype), log_abs)
  return sign, jnp.real(logdet)
def _triage_segments(window, nperseg, input_length):
  """Parse the window and nperseg arguments for the spectral helpers.

  This is a helper function, not meant to be called externally.

  Parameters
  ----------
  window : string, tuple, or ndarray
    If window is specified by a string or tuple and nperseg is not
    specified, nperseg is set to the default of 256 and a window of that
    length is returned. If instead the window is array_like and nperseg is
    not specified, then nperseg is set to the length of the window. A
    ValueError is raised if the user supplies both an array_like window and
    a value for nperseg but nperseg does not equal the length of the window.
  nperseg : int
    Length of each segment.
  input_length : int
    Length of input signal, i.e. x.shape[-1]. Used to test for errors.

  Returns
  -------
  win : ndarray
    The window. If the function was called with a string or tuple then this
    holds the actual array used as a window.
  nperseg : int
    Length of each segment. If window is str or tuple, nperseg is set to
    256 (or clipped to ``input_length`` with a warning). If window is
    array_like, nperseg is set to the length of the window.

  Raises
  ------
  ValueError
    If an array_like window is not 1-D, is longer than the input signal, or
    has a length different from an explicitly given nperseg.
  """
  # parse window; if array like, then set nperseg = win.shape
  if isinstance(window, (str, tuple)):
    # if nperseg not specified, change to default
    if nperseg is None:
      nperseg = 256
    if nperseg > input_length:
      # Report the value that is actually going to be used.
      warnings.warn(f'nperseg = {nperseg} is greater than input length '
                    f' = {input_length}, using nperseg = {input_length}')
      nperseg = input_length
    win = jnp.array(osp_signal.get_window(window, nperseg))
  else:
    win = jnp.asarray(window)
    if len(win.shape) != 1:
      raise ValueError('window must be 1-D')
    if input_length < win.shape[-1]:
      raise ValueError('window is longer than input signal')
    if nperseg is None:
      nperseg = win.shape[0]
    elif nperseg != win.shape[0]:
      raise ValueError("value specified for nperseg is different"
                       " from length of window")
  return win, nperseg
def _lu(a, permute_l):
  """Split the packed result of ``lax_linalg.lu`` into P, L and U factors.

  Args:
    a: 2D input matrix.
    permute_l: if True, return ``(P @ L, U)`` instead of ``(P, L, U)``.

  Returns:
    ``(P @ L, U)`` when ``permute_l`` is true, otherwise ``(P, L, U)``.
  """
  a = np_linalg._promote_arg_dtypes(jnp.asarray(a))
  packed, pivots, permutation = lax_linalg.lu(a)
  dtype = lax.dtype(a)
  n_rows, n_cols = jnp.shape(a)
  inner = min(n_rows, n_cols)

  # Permutation matrix: entry (i, j) is 1 where permutation[j] == i.
  row_ids = jnp.arange(n_rows)[:, None]
  perm_matrix = jnp.real(jnp.array(permutation == row_ids, dtype=dtype))

  # L is the strict lower triangle plus a unit diagonal; U the upper triangle.
  lower = jnp.tril(packed, -1)[:, :inner] + jnp.eye(n_rows, inner, dtype=dtype)
  upper = jnp.triu(packed)[:inner, :]

  if permute_l:
    return jnp.matmul(perm_matrix, lower), upper
  return perm_matrix, lower, upper
def pinv(a, rcond=None):
  """Compute the pseudo-inverse of a matrix from its SVD.

  Uses same algorithm as
  https://github.com/numpy/numpy/blob/v1.17.0/numpy/linalg/linalg.py#L1890-L1979

  Args:
    a: input array whose last two dimensions form the matrix.
    rcond: relative cutoff for small singular values. Defaults to
      ``10 * max(rows, cols) * eps`` for the dtype of ``a``.

  Returns:
    The pseudo-inverse, with the same dtype as ``a``.
  """
  a = jnp.conj(a)
  if rcond is None:
    largest_dim = max(a.shape[-2:])
    rcond = 10. * largest_dim * jnp.array(jnp.finfo(a.dtype).eps)
  rcond = jnp.asarray(rcond)
  u, s, vh = svd(a, full_matrices=False)

  # Singular values less than or equal to ``rcond * largest_singular_value``
  # are set to zero. Replacing them with inf makes the division below
  # produce exactly zero for those directions.
  n_extra_dims = s.ndim - rcond.ndim - 1
  rcond = lax.expand_dims(rcond[..., jnp.newaxis], range(n_extra_dims))
  largest = jnp.amax(s, axis=-1, keepdims=True, initial=-jnp.inf)
  cutoff = rcond * largest
  s = jnp.where(s > cutoff, s, jnp.inf).astype(u.dtype)

  scaled_uT = jnp.divide(_T(u), s[..., jnp.newaxis])
  result = jnp.matmul(_T(vh), scaled_uT)
  return lax.convert_element_type(result, a.dtype)
def _lu(a, permute_l):
  """Split the packed result of ``lax_linalg.lu`` into P, L and U factors.

  Args:
    a: 2D input matrix; it is promoted to an inexact dtype.
    permute_l: if True, return ``(P @ L, U)`` instead of ``(P, L, U)``.

  Returns:
    ``(P @ L, U)`` when ``permute_l`` is true, otherwise ``(P, L, U)``.
  """
  a, = _promote_dtypes_inexact(jnp.asarray(a))
  packed, _, permutation = lax_linalg.lu(a)
  dtype = lax.dtype(a)
  n_rows, n_cols = jnp.shape(a)
  inner = min(n_rows, n_cols)

  # Permutation matrix: entry (i, j) is 1 where permutation[j] == i.
  row_ids = jnp.arange(n_rows, dtype=permutation.dtype)[:, None]
  is_selected = permutation[None, :] == row_ids
  perm_matrix = jnp.real(jnp.array(is_selected, dtype=dtype))

  # L is the strict lower triangle plus a unit diagonal; U the upper triangle.
  lower = jnp.tril(packed, -1)[:, :inner] + jnp.eye(n_rows, inner, dtype=dtype)
  upper = jnp.triu(packed)[:inner, :]

  if permute_l:
    return jnp.matmul(perm_matrix, lower), upper
  return perm_matrix, lower, upper
def sph_harm(m: jnp.ndarray,
             n: jnp.ndarray,
             theta: jnp.ndarray,
             phi: jnp.ndarray,
             n_max: Optional[int] = None) -> jnp.ndarray:
  r"""Computes the spherical harmonics.

  Compared with SciPy, the JAX version takes one extra argument, `n_max`,
  the maximum value in `n`.

  The spherical harmonic of degree `n` and order `m` is
  :math:`Y_n^m(\theta, \phi) = N_n^m * P_n^m(\cos \phi) * \exp(i m \theta)`,
  where :math:`N_n^m = \sqrt{\frac{\left(2n+1\right) \left(n-m\right)!}
  {4 \pi \left(n+m\right)!}}` is the normalization factor and :math:`\phi`
  and :math:`\theta` are the colatitude and longitude, respectively.
  :math:`N_n^m` is chosen so that the spherical harmonics form a set of
  orthonormal basis functions of :math:`L^2(S^2)`.

  Args:
    m: The order of the harmonic; must have `|m| <= n`. Return values for
      `|m| > n` are undefined.
    n: The degree of the harmonic; must have `n >= 0`. The standard notation
      for degree in descriptions of spherical harmonics is `l (lower case L)`.
      We use `n` here to be consistent with `scipy.special.sph_harm`. Return
      values for `n < 0` are undefined.
    theta: The azimuthal (longitudinal) coordinate; must be in [0, 2*pi].
    phi: The polar (colatitudinal) coordinate; must be in [0, pi].
    n_max: The maximum degree `max(n)`. If the supplied `n_max` is not the
      true maximum value of `n`, the results are clipped to `n_max`. For
      example, `sph_harm(m=jnp.array([2]), n=jnp.array([10]), theta, phi,
      n_max=6)` actually returns
      `sph_harm(m=jnp.array([2]), n=jnp.array([6]), theta, phi, n_max=6)`.

  Returns:
    A 1D array containing the spherical harmonics at (m, n, theta, phi).
  """
  # A scalar colatitude is wrapped into a length-1 array.
  phi = jnp.array([phi]) if jnp.isscalar(phi) else phi

  degree_bound = jnp.max(n) if n_max is None else n_max
  # The bound must be a concrete Python int, even under transformations.
  degree_bound = core.concrete_or_error(
      int, degree_bound,
      'The `n_max` argument of `jnp.scipy.special.sph_harm` must '
      'be statically specified to use `sph_harm` within JAX transformations.')

  return _sph_harm(m, n, theta, phi, degree_bound)
def roots(p, *, strip_zeros=True):
  """Return the roots of the polynomial with coefficients ``p``.

  Args:
    p: rank-1 array-like of polynomial coefficients; it is promoted to an
      inexact dtype.
    strip_zeros: when True, leading zero coefficients are removed before the
      roots are computed, which requires a concrete (non-traced) count of
      leading zeros. When False, the count is passed on to
      ``_roots_with_zeros`` instead.

  Returns:
    An array of roots; an empty complex array when ``p`` has fewer than two
    coefficients.

  Raises:
    ValueError: if ``p`` has more than one dimension.
  """
  _check_arraylike("roots", p)
  p = atleast_1d(*_promote_dtypes_inexact(p))
  if p.ndim != 1:
    raise ValueError("Input must be a rank-1 array.")
  if p.size < 2:
    return array([], dtype=dtypes._to_complex_dtype(p.dtype))

  # argmin over the boolean mask gives the first non-zero position; an
  # all-zero input is treated as consisting entirely of leading zeros.
  is_zero = p == 0
  num_leading_zeros = _where(all(is_zero), len(p), argmin(p == 0))

  if not strip_zeros:
    return _roots_with_zeros(p, num_leading_zeros)

  num_leading_zeros = core.concrete_or_error(
      int, num_leading_zeros,
      "The error occurred in the jnp.roots() function. To use this within a "
      "JIT-compiled context, pass strip_zeros=False, but be aware that leading zeros "
      "will be result in some returned roots being set to NaN.")
  return _roots_no_zeros(p[num_leading_zeros:])
def _update_slice(operand, update, start_indices, update_dims):
  """Write a padded update into ``operand`` without letting padding leak in.

  Similar to ``lax.dynamic_update_slice``, but only the elements of
  ``update`` that lie inside the rectangle given by ``update_dims`` overwrite
  values of ``operand``; the rest keep the existing values.

  Args:
    operand: the array to update.
    update: the padded array to write.
    start_indices: the offset at which to write ``update``.
    update_dims: the true dimensions of the padded update ``update``. Only
      values inside the rectangle given by ``update_dims`` are overwritten.

  Returns:
    An array with the shape of the original ``operand``.
  """
  original_shape = operand.shape
  # Pad the high side of each dimension by the update size so the window
  # read and written below always fits.
  pad_config = [(0, size, 0) for size in update.shape]
  padded = lax.pad(operand, jnp.array(0, operand.dtype), pad_config)
  starts = tuple(jnp.int32(start) for start in start_indices)

  # Read the current window and use it as the fill for the masked-out part
  # of the update, so those positions are rewritten with their old values.
  current = lax.dynamic_slice(padded, starts, update.shape)
  merged = _mask(update, update_dims, current)
  padded = lax.dynamic_update_slice(padded, merged, starts)

  # Drop the padding again.
  return lax.slice(padded, [0] * padded.ndim, original_shape)
def logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False):
  """Compute ``log(sum(b * exp(a)))`` with the maximum subtracted first.

  Args:
    a: input array.
    axis: axis or axes to reduce over; forwarded to ``_reduction_dims``.
    b: optional scaling factors for ``exp(a)``. Entries of ``a`` where
      ``b == 0`` are replaced by ``-inf`` before the maximum is taken.
    keepdims: whether the reduced axes are kept in the result.
    return_sign: if True, return ``(out, sign)`` instead of ``out``.

  Returns:
    ``out``, or ``(out, sign)`` when ``return_sign`` is true. When ``b`` is
    given, the result is real and ``return_sign`` is false, entries with a
    negative sum are returned as NaN.
  """
  if b is not None:
    a, b = _promote_args_inexact("logsumexp", a, b)
    a = jnp.where(b != 0, a, -jnp.inf)
  else:
    a, = _promote_args_inexact("logsumexp", a)
  pos_dims, dims = _reduction_dims(a, axis)
  amax = jnp.max(a, axis=dims, keepdims=keepdims)
  # A non-finite maximum is replaced by 0 so the subtraction below does not
  # produce inf - inf; the shift is excluded from differentiation.
  amax = lax.stop_gradient(
      lax.select(jnp.isfinite(amax), amax, lax.full_like(amax, 0)))
  amax_with_dims = amax if keepdims else lax.expand_dims(amax, pos_dims)
  # fast path if the result cannot be negative.
  if b is None and not np.issubdtype(a.dtype, np.complexfloating):
    out = lax.add(
        lax.log(
            jnp.sum(lax.exp(lax.sub(a, amax_with_dims)),
                    axis=dims, keepdims=keepdims)),
        amax)
    # sign is NaN where out is NaN, 0 where out is -inf, and 1 otherwise.
    sign = jnp.where(jnp.isnan(out), out, 1.0)
    sign = jnp.where(jnp.isneginf(out), 0.0, sign).astype(out.dtype)
  else:
    expsub = lax.exp(lax.sub(a, amax_with_dims))
    if b is not None:
      expsub = lax.mul(expsub, b)
    sumexp = jnp.sum(expsub, axis=dims, keepdims=keepdims)

    sign = lax.stop_gradient(jnp.sign(sumexp))
    if np.issubdtype(sumexp.dtype, np.complexfloating):
      if return_sign:
        sumexp = sign * sumexp
      out = lax.add(lax.log(sumexp), amax)
    else:
      out = lax.add(lax.log(lax.abs(sumexp)), amax)
  if return_sign:
    return (out, sign)
  if b is not None:
    if not np.issubdtype(out.dtype, np.complexfloating):
      # Use jnp.array(nan) to avoid false positives in debug_nans
      # (see https://github.com/google/jax/issues/7634)
      out = jnp.where(sign < 0, jnp.array(np.nan, dtype=out.dtype), out)
  return out
def _expint1(x):
  """Evaluate ``x * P(x) / Q(x) + euler_gamma + log(x)`` for ``0 < x <= 2``.

  ``P`` and ``Q`` are fixed polynomials whose coefficients (highest power
  first) are listed below.

  Args:
    x: array of evaluation points; its dtype is used for the coefficients.

  Returns:
    An array with the same shape as ``x``.
  """
  numer_coeffs = jnp.array([
      -5.350447357812542947283e0,
      2.185049168816613393830e2,
      -4.176572384826693777058e3,
      5.541176756393557601232e4,
      -3.313381331178144034309e5,
      1.592627163384945414220e6,
  ], dtype=x.dtype)
  denom_coeffs = jnp.array([
      1.0,
      -5.250547959112862969197e1,
      1.259616186786790571525e3,
      -1.756549581973534652631e4,
      1.493062117002725991967e5,
      -7.294949239640527645655e5,
      1.592627163384945429726e6,
  ], dtype=x.dtype)
  ratio = jnp.polyval(numer_coeffs, x) / jnp.polyval(denom_coeffs, x)
  return x * ratio + jnp.euler_gamma + jnp.log(x)
def _lstsq(a, b, rcond, *, numpy_resid=False):
  """Solve a linear least-squares problem through the SVD of ``a``.

  Args:
    a: 2D coefficient matrix of shape ``(m, n)``.
    b: right-hand side, 1D of length ``m`` or 2D with ``m`` rows.
    rcond: relative cutoff for small singular values. ``None`` selects
      ``eps * max(n, m)``; negative values select ``eps``.
    numpy_resid: if True, return empty residuals when ``rank < n`` or
      ``m <= n``. By default full residuals are always returned so the
      function can be compiled.

  Returns:
    A tuple ``(x, resid, rank, s)``: the solution, the squared residual
    norms, the number of singular values above the cutoff, and the singular
    values of ``a``.

  Raises:
    ValueError: if the leading dimensions of ``a`` and ``b`` differ.
    TypeError: if ``a`` is not 2D or ``b`` is not 1D or 2D.
  """
  # TODO: add lstsq to lax_linalg and implement this function via those wrappers.
  # TODO: add custom jvp rule for more robust lstsq differentiation
  a, b = _promote_dtypes_inexact(a, b)
  if a.shape[0] != b.shape[0]:
    raise ValueError("Leading dimensions of input arrays must match")
  b_was_vector = b.ndim == 1
  if b_was_vector:
    b = b[:, None]
  if a.ndim != 2:
    raise TypeError(
        f"{a.ndim}-dimensional array given. Array must be two-dimensional")
  if b.ndim != 2:
    raise TypeError(
        f"{b.ndim}-dimensional array given. Array must be one or two-dimensional")

  m, n = a.shape
  eps = jnp.finfo(a.dtype).eps
  if rcond is None:
    rcond = eps * max(n, m)
  else:
    rcond = jnp.where(rcond < 0, eps, rcond)

  def matmul_highest(lhs, rhs):
    return jnp.matmul(lhs, rhs, precision=lax.Precision.HIGHEST)

  u, s, vt = svd(a, full_matrices=False)
  significant = s >= jnp.array(rcond, dtype=s.dtype) * s[0]
  rank = significant.sum()
  # Replace dropped singular values by 1 before inverting so the division
  # stays finite, then zero out the corresponding inverses.
  guarded_s = jnp.where(significant, s, 1).astype(a.dtype)
  inv_s = jnp.where(significant, 1 / guarded_s, 0)[:, jnp.newaxis]
  projected = matmul_highest(u.conj().T, b)
  x = matmul_highest(vt.conj().T, inv_s * projected)

  # Numpy returns empty residuals in some cases. To allow compilation, we
  # default to returning full residuals in all cases.
  if numpy_resid and (rank < n or m <= n):
    resid = jnp.asarray([])
  else:
    fitted = matmul_highest(a, x)
    resid = norm(b - fitted, axis=0) ** 2

  if b_was_vector:
    x = x.ravel()
  return x, resid, rank, s
def poly(seq_of_zeros):
  """Return the coefficients of the monic polynomial with the given roots.

  Args:
    seq_of_zeros: 1D array-like of roots, or a non-empty square 2D array, in
      which case its eigenvalues are used as the roots.

  Returns:
    A 1D array of coefficients (highest power first), or a 0-D one for
    empty input.

  Raises:
    ValueError: if the input is neither 1D nor a non-empty square 2D array.
  """
  _check_arraylike('poly', seq_of_zeros)
  seq_of_zeros, = _promote_dtypes_inexact(seq_of_zeros)
  seq_of_zeros = atleast_1d(seq_of_zeros)

  shape = seq_of_zeros.shape
  is_square_matrix = len(shape) == 2 and shape[0] == shape[1] and shape[0] != 0
  if is_square_matrix:
    # import at runtime to avoid circular import
    from jax._src.numpy import linalg
    seq_of_zeros = linalg.eigvals(seq_of_zeros)

  if seq_of_zeros.ndim != 1:
    raise ValueError("input must be 1d or non-empty square 2d array.")

  dt = seq_of_zeros.dtype
  n_roots = len(seq_of_zeros)
  if n_roots == 0:
    return ones((), dtype=dt)

  # Multiply the linear factors (x - r_k) together one at a time.
  coeffs = ones((1, ), dtype=dt)
  for k in range(n_roots):
    factor = array([1, -seq_of_zeros[k]], dtype=dt)
    coeffs = convolve(coeffs, factor, mode='full')
  return coeffs
def _expn1(n, x):
  """Series evaluation used for the exponential integral En.

  Args:
    n: order; passed as the upper bound of a ``lax.fori_loop`` and to
      ``gammaln``.
    x: evaluation point(s); converted with ``jnp.array``.

  Returns:
    ``(-x)**(n - 1) * psi / exp(gammaln(n)) - ans`` where ``psi`` and ``ans``
    are the accumulated values described in the comments below.
  """
  # exponential integral En
  _c = _constant_like
  x = jnp.array(x)
  MACHEP = jnp.finfo(x.dtype).eps

  zero = _c(x, 0.0)
  one = _c(x, 1.0)
  # psi = -euler_gamma - log(x) + sum_{i=1}^{n-1} 1/i
  psi = -jnp.euler_gamma - jnp.log(x)
  psi = lax.fori_loop(_c(n, 1), n, lambda i, psi: psi + one / i, psi)
  # n1 replaces n == 1 by 2 so that one / (one - n1) below never divides by
  # zero; that branch of the where is discarded for n == 1 anyway.
  n1 = jnp.where(n == _c(n, 1), one + one, n)
  init = dict(
      x=x,
      z=-x,
      xk=zero,
      yk=one,
      pk=one - n,
      ans=jnp.where(n == _c(n, 1), zero, one / (one - n1)),
      t=jnp.inf,
  )

  def body(d):
    # yk accumulates z**k / k!; the term with pk == 0 is skipped.
    d["xk"] += one
    d["yk"] *= d["z"] / d["xk"]
    d["pk"] += one
    d["ans"] += jnp.where(d["pk"] != zero, d["yk"] / d["pk"], zero)
    # t is the size of the running term relative to the accumulated sum.
    d["t"] = jnp.where(d["ans"] != zero, abs(d["yk"] / d["ans"]), one)
    return d

  def cond(d):
    # Iterate only for x > 0 and until the relative term drops to eps.
    return (d["x"] > _c(d["x"], 0.0)) & (d["t"] > MACHEP)

  d = lax.while_loop(cond, body, init)
  t = n
  r = n - _c(n, 1)
  return d["z"]**r * psi / jnp.exp(gammaln(t)) - d["ans"]
def all_gather(x, axis_name, *, axis_index_groups=None):
  """Gather values of x across all replicas.

  If ``x`` is a pytree then the result is equivalent to mapping this function
  to each leaf in the tree.

  This is equivalent to, but faster than, all_to_all(broadcast(x)).

  Args:
    x: array(s) with a mapped axis named ``axis_name``.
    axis_name: hashable Python object used to name a pmapped axis (see the
      :func:`jax.pmap` documentation for more details).
    axis_index_groups: optional list of lists containing axis indices (e.g.
      for an axis of size 4, [[0, 1], [2, 3]] would run all gather over the
      first two and last two replicas). Groups must cover all axis indices
      exactly once, and all groups must be the same size.

  Returns:
    Array(s) representing the result of an all-gather along the axis
    ``axis_name``. Shapes are the same as ``x.shape``, but with a leading
    dimension of the axis_size.

  For example, with 4 XLA devices available:

  >>> x = np.arange(4)
  >>> y = jax.pmap(lambda x: jax.lax.all_gather(x, 'i'), axis_name='i')(x)
  >>> print(y)
  [[0 1 2 3]
   [0 1 2 3]
   [0 1 2 3]
   [0 1 2 3]]

  An example of using axis_index_groups, groups split by even & odd device
  ids:

  >>> x = np.arange(16).reshape(4, 4)
  >>> print(x)
  [[ 0.  1.  2.  3.]
   [ 4.  5.  6.  7.]
   [ 8.  9. 10. 11.]
   [12. 13. 14. 15.]]
  >>> y = jax.pmap(lambda x: jax.lax.all_gather(
  ...     x, 'i', axis_index_groups=[[0, 2], [3, 1]]))(x)
  >>> print(y)
  [[[ 0.  1.  2.  3.]
    [ 8.  9. 10. 11.]]
   [[12. 13. 14. 15.]
    [ 4.  5.  6.  7.]]
   [[ 0.  1.  2.  3.]
    [ 8.  9. 10. 11.]]
   [[12. 13. 14. 15.]
    [ 4.  5.  6.  7.]]
  """
  position = axis_index(axis_name)
  if axis_index_groups is not None:
    # Map the global axis index to the index within its own group.
    flat_groups = np.array(axis_index_groups).flatten()
    group_size = len(axis_index_groups[0])
    index_in_group = flat_groups.argsort() % group_size
    position = lax_numpy.array(index_in_group)[position]

  # Summing 1 over the (grouped) axis yields the number of participants.
  axis_size = psum(1, axis_name, axis_index_groups=axis_index_groups)
  return _allgather(x, 0, axis_size, position, axis_name, axis_index_groups)
def eigh_tridiagonal(d, e, *, eigvals_only=False, select='a',
                     select_range=None, tol=None):
  """Eigenvalues of a tridiagonal matrix via parallel bisection.

  The eigenvalue bracket comes from the Gershgorin circle theorem and is
  refined by bisection, using a Sturm sequence count at each midpoint.

  Args:
    d: diagonal values.
    e: off-diagonal values; must have the same dtype as ``d``.
    eigvals_only: must be True. The default (False) raises
      NotImplementedError because eigenvectors are not implemented.
    select: ``'a'`` for all eigenvalues or ``'i'`` for an index range;
      ``'v'`` is not implemented.
    select_range: for ``select='i'``, an inclusive ``(first, last)`` pair of
      eigenvalue indices.
    tol: optional absolute tolerance; the effective tolerance is the larger
      of ``tol`` and ``eps * t_norm``.

  Returns:
    An array with one eigenvalue estimate per requested index. For inputs
    with at most one diagonal element, the real part of ``d``.

  Raises:
    NotImplementedError: if ``eigvals_only`` is false or ``select == 'v'``.
    TypeError: if ``d`` and ``e`` differ in dtype or the dtype is not
      float32, float64, complex64 or complex128.
    ValueError: if ``select`` is unknown or the index range is empty.
  """
  if not eigvals_only:
    raise NotImplementedError(
        "Calculation of eigenvectors is not implemented")

  def _sturm(alpha, beta_sq, pivmin, alpha0_perturbation, x):
    """Implements the Sturm sequence recurrence."""
    n = alpha.shape[0]
    zeros = jnp.zeros(x.shape, dtype=jnp.int32)
    ones = jnp.ones(x.shape, dtype=jnp.int32)

    # The first step in the Sturm sequence recurrence
    # requires special care if x is equal to alpha[0].
    def sturm_step0():
      q = alpha[0] - x
      count = jnp.where(q < 0, ones, zeros)
      q = jnp.where(alpha[0] == x, alpha0_perturbation, q)
      return q, count

    # Subsequent steps all take this form:
    def sturm_step(i, q, count):
      q = alpha[i] - beta_sq[i - 1] / q - x
      count = jnp.where(q <= pivmin, count + 1, count)
      q = jnp.where(q <= pivmin, jnp.minimum(q, -pivmin), q)
      return q, count

    # The first step initializes q and count.
    q, count = sturm_step0()

    # Peel off ((n-1) % blocksize) steps from the main loop, so we can run
    # the bulk of the iterations unrolled by a factor of blocksize.
    blocksize = 16
    i = 1
    peel = (n - 1) % blocksize
    unroll_cnt = peel

    # NOTE: unroll_cnt is read from the enclosing scope when this function
    # runs, so rebinding it below changes how many steps each call performs.
    def unrolled_steps(args):
      start, q, count = args
      for j in range(unroll_cnt):
        q, count = sturm_step(start + j, q, count)
      return start + unroll_cnt, q, count

    i, q, count = unrolled_steps((i, q, count))

    # Run the remaining steps of the Sturm sequence using a partially
    # unrolled while loop.
    unroll_cnt = blocksize

    def cond(iqc):
      i, q, count = iqc
      return jnp.less(i, n)

    _, _, count = lax.while_loop(cond, unrolled_steps, (i, q, count))
    return count

  alpha = jnp.asarray(d)
  beta = jnp.asarray(e)
  supported_dtypes = (jnp.float32, jnp.float64, jnp.complex64, jnp.complex128)
  if alpha.dtype != beta.dtype:
    raise TypeError(
        "diagonal and off-diagonal values must have same dtype, "
        f"got {alpha.dtype} and {beta.dtype}")
  if alpha.dtype not in supported_dtypes or beta.dtype not in supported_dtypes:
    raise TypeError(
        "Only float32 and float64 inputs are supported as inputs "
        "to jax.scipy.linalg.eigh_tridiagonal, got "
        f"{alpha.dtype} and {beta.dtype}")
  n = alpha.shape[0]
  if n <= 1:
    return jnp.real(alpha)

  # Only the real diagonal and the squared magnitude of the off-diagonal
  # enter the recurrence.
  if jnp.issubdtype(alpha.dtype, jnp.complexfloating):
    alpha = jnp.real(alpha)
    beta_sq = jnp.real(beta * jnp.conj(beta))
    beta_abs = jnp.sqrt(beta_sq)
  else:
    beta_abs = jnp.abs(beta)
    beta_sq = jnp.square(beta)

  # Estimate the largest and smallest eigenvalues of T using the Gershgorin
  # circle theorem.
  off_diag_abs_row_sum = jnp.concatenate(
      [beta_abs[:1], beta_abs[:-1] + beta_abs[1:], beta_abs[-1:]], axis=0)
  lambda_est_max = jnp.amax(alpha + off_diag_abs_row_sum)
  lambda_est_min = jnp.amin(alpha - off_diag_abs_row_sum)
  # Upper bound on 2-norm of T.
  t_norm = jnp.maximum(jnp.abs(lambda_est_min), jnp.abs(lambda_est_max))

  # Compute the smallest allowed pivot in the Sturm sequence to avoid
  # overflow.
  finfo = np.finfo(alpha.dtype)
  one = np.ones([], dtype=alpha.dtype)
  safemin = np.maximum(one / finfo.max, (one + finfo.eps) * finfo.tiny)
  pivmin = safemin * jnp.maximum(1, jnp.amax(beta_sq))
  alpha0_perturbation = jnp.square(finfo.eps * beta_abs[0])
  abs_tol = finfo.eps * t_norm
  if tol is not None:
    abs_tol = jnp.maximum(tol, abs_tol)

  # In the worst case, when the absolute tolerance is eps*lambda_est_max and
  # lambda_est_max = -lambda_est_min, we have to take as many bisection steps
  # as there are bits in the mantissa plus 1.
  # The proof is left as an exercise to the reader.
  max_it = finfo.nmant + 1

  # Determine the indices of the desired eigenvalues, based on select and
  # select_range.
  if select == 'a':
    target_counts = jnp.arange(n, dtype=jnp.int32)
  elif select == 'i':
    if select_range[0] > select_range[1]:
      raise ValueError('Got empty index range in select_range.')
    target_counts = jnp.arange(select_range[0], select_range[1] + 1,
                               dtype=jnp.int32)
  elif select == 'v':
    # TODO(phawkins): requires dynamic shape support.
    raise NotImplementedError("eigh_tridiagonal(..., select='v') is not "
                              "implemented")
  else:
    raise ValueError("'select must have a value in {'a', 'i', 'v'}.")

  # Run binary search for all desired eigenvalues in parallel, starting from
  # an interval slightly wider than the estimated
  # [lambda_est_min, lambda_est_max].
  fudge = 2.1  # We widen the starting Gershgorin interval a bit.
  norm_slack = jnp.array(n, alpha.dtype) * fudge * finfo.eps * t_norm
  lower = lambda_est_min - norm_slack - 2 * fudge * pivmin
  upper = lambda_est_max + norm_slack + fudge * pivmin

  # Pre-broadcast the scalars used in the Sturm sequence for improved
  # performance.
  target_shape = jnp.shape(target_counts)
  lower = jnp.broadcast_to(lower, shape=target_shape)
  upper = jnp.broadcast_to(upper, shape=target_shape)
  mid = 0.5 * (upper + lower)
  pivmin = jnp.broadcast_to(pivmin, target_shape)
  alpha0_perturbation = jnp.broadcast_to(alpha0_perturbation, target_shape)

  # Start parallel binary searches.
  def cond(args):
    i, lower, _, upper = args
    # Stop after max_it steps or once every bracket is within abs_tol.
    return jnp.logical_and(jnp.less(i, max_it),
                           jnp.less(abs_tol, jnp.amax(upper - lower)))

  def body(args):
    i, lower, mid, upper = args
    counts = _sturm(alpha, beta_sq, pivmin, alpha0_perturbation, mid)
    # Move whichever end of each bracket keeps the target count inside it.
    lower = jnp.where(counts <= target_counts, mid, lower)
    upper = jnp.where(counts > target_counts, mid, upper)
    mid = 0.5 * (lower + upper)
    return i + 1, lower, mid, upper

  _, _, mid, _ = lax.while_loop(cond, body, (0, lower, mid, upper))
  return mid
def ppf(q, loc=0, scale=1):
  """Return ``special.ndtri(q) * scale + loc`` as a floating-point array.

  Args:
    q: value(s) passed to ``special.ndtri``.
    loc: offset added to the result.
    scale: factor the ``ndtri`` value is multiplied by.

  Returns:
    An array in the default floating-point dtype: float64 when 64-bit mode
    is enabled, float32 otherwise.
  """
  # Requesting the Python ``float`` type instead of the literal 'float64'
  # yields the same dtype in either precision mode, but does not trigger
  # the "requested dtype float64 ... will be truncated" warning when 64-bit
  # mode is disabled.
  return jnp.array(special.ndtri(q) * scale + loc, float)
def _eigh_work(H, n, termination_size=256): """ The main work loop performing the symmetric eigendecomposition of H. Each step recursively computes a projector into the space of eigenvalues above jnp.mean(jnp.diag(H)). The result of the projections into and out of that space, along with the isometries accomplishing these, are then computed. This is performed recursively until the projections have size 1, and thus store an eigenvalue of the original input; the corresponding isometry is the related eigenvector. The results are then composed. This function cannot be Jitted because the internal split_spectrum cannot be. Args: H: The Hermitian input. n: The true (dynamic) shape of H. Returns: H, V: The result of the projection. """ # We turn what was originally a recursive algorithm into an iterative # algorithm with an explicit stack. N, _ = H.shape n = jnp.asarray(n, jnp.int32) agenda = Stack.create( N + 1, _Subproblem(jnp.array(0, jnp.int32), jnp.array(0, jnp.int32))) agenda = agenda.push(_Subproblem(offset=jnp.int32(0), size=n)) # eigenvectors is the array in which we build the output eigenvectors. # We initialize it with the identity matrix so the initial matrix # multiplications in_split_spectrum_jittable are the identity. eigenvectors = jnp.eye(N, dtype=H.dtype) # blocks is an array representing a stack of Hermitian matrix blocks that we # need to recursively decompose. Subproblems are different sizes, so the stack # of blocks is ragged. Subproblems are left-aligned (i.e. starting at the 0th # column). Here is an ASCII art picture of three blocks A, B, C, embedded # in the larger `blocks` workspace (represented with trailing dots). # # A A A . . . # A A A . . . # A A A . . . # B B . . . . # B B . . . . # C C C C . . # C C C C . . # C C C C . . # C C C C . . # # Each step of the algorithm subdivides a block into two subblocks whose # sizes sum to the original block size. 
We overwrite the original block with # those two subblocks so we don't need any additional scratch space. # # At termination, "blocks" will contain 1x1 blocks (i.e., the eigenvalues) in # its first column. blocks = H def base_case(B, offset, b, agenda, blocks, eigenvectors): # Base case: for blocks under a minimum size, we cutoff the recursion # and call the TPU Jacobi eigendecomposition implementation. The Jacobi # algorithm works well for small matrices but scales poorly, so the two # complement each other well. H = _slice(blocks, (offset, 0), (b, b), (B, B)) V = _slice(eigenvectors, (0, offset), (n, b), (N, B)) # We replace the masked-out part of the matrix with the identity matrix. # We know that the TPU Jacobi eigh implementation will not alter the order # of the eigenvalues, so we know the eigendecomposition of the original # matrix is in the top-left corner of the eigendecomposition of the padded # matrix. # It is very important that the underlying eigh implementation does not sort # the eigenvalues for this reason! This is currently not true of JAX's CPU # and GPU eigendecompositions, and for those platforms this algorithm will # only do the right thing if termination_size == 1. H = _mask(H, (b, b), jnp.eye(B, dtype=H.dtype)) eig_vecs, eig_vals = lax.linalg.eigh(H, sort_eigenvalues=False) eig_vecs = _mask(eig_vecs, (b, b)) eig_vals = _mask(eig_vals, (b,)) eig_vecs = jnp.dot(V, eig_vecs) blocks = _update_slice(blocks, eig_vals[:, None], (offset, 0), (b, b)) eigenvectors = _update_slice(eigenvectors, eig_vecs, (0, offset), (n, b)) return agenda, blocks, eigenvectors def recursive_case(B, offset, b, agenda, blocks, eigenvectors): # The recursive case of the algorithm, specialized to a static block size # of B. H = _slice(blocks, (offset, 0), (b, b), (B, B)) V = _slice(eigenvectors, (0, offset), (n, b), (N, B)) split_point = jnp.nanmedian(_mask(jnp.diag(H), (b,), jnp.nan)) # TODO: Improve this? 
H_minus, V_minus, H_plus, V_plus, rank = split_spectrum(H, b, split_point, V0=V) blocks = _update_slice(blocks, H_minus, (offset, 0), (rank, rank)) blocks = _update_slice(blocks, H_plus, (offset + rank, 0), (b - rank, b - rank)) eigenvectors = _update_slice(eigenvectors, V_minus, (0, offset), (n, rank)) eigenvectors = _update_slice(eigenvectors, V_plus, (0, offset + rank), (n, b - rank)) agenda = agenda.push(_Subproblem(offset + rank, (b - rank))) agenda = agenda.push(_Subproblem(offset, rank)) return agenda, blocks, eigenvectors def loop_cond(state): agenda, _, _ = state return ~agenda.empty() # It would be wasteful to perform all computation padded up to the original # matrix size. Instead, we form buckets of padded sizes e.g., # [256, 512, 1024, ..., N], aiming for a balance between compilation time # and runtime. cutoff = min(N, termination_size) buckets = [cutoff] branches = [partial(base_case, cutoff)] i = cutoff while i < N: i = min(2 * i, N) buckets.append(i) branches.append(partial(recursive_case, i)) buckets = jnp.array(buckets) def loop_body(state): agenda, blocks, eigenvectors = state (offset, b), agenda = agenda.pop() which = jnp.where(buckets < b, jnp.iinfo(jnp.int32).max, buckets) choice = jnp.argmin(which) return lax.switch(choice, branches, offset, b, agenda, blocks, eigenvectors) _, blocks, eigenvectors = lax.while_loop( loop_cond, loop_body, (agenda, blocks, eigenvectors)) return blocks[:, 0], eigenvectors