def fftfreq(n, d=1.0):
    """Return the sample frequencies of an ``n``-point discrete Fourier transform.

    Args:
      n: window length; a single integer (lists and tuples are rejected).
      d: sample spacing; a single value (lists and tuples are rejected).

    Returns:
      A length-``n`` array with the non-negative bin indices first and the
      negative ones last, each divided by ``d * n``.

    Raises:
      ValueError: if ``n`` or ``d`` is a list or a tuple.
    """
    if isinstance(n, (list, tuple)):
        raise ValueError(
            "The n argument of jax.numpy.fft.fftfreq only takes an int. "
            "Got n = %s." % list(n))
    if isinstance(d, (list, tuple)):
        raise ValueError(
            "The d argument of jax.numpy.fft.fftfreq only takes a single value. "
            "Got d = %s." % list(d))

    bins = jnp.zeros(n)
    # `split` is the number of non-negative bins; the rest are negative.
    if n % 2 == 0:
        split = n // 2
        negative_bins = jnp.arange(-n // 2, 0)
    else:
        split = (n - 1) // 2 + 1
        negative_bins = jnp.arange(-(n - 1) // 2, 0)

    bins = bins.at[0:split].set(jnp.arange(0, split))
    bins = bins.at[split:].set(negative_bins)
    return bins / (d * n)
def body(k, state):
    """Run elimination step ``k`` of an LU decomposition with partial pivoting.

    ``m`` and ``n`` (the row and column counts of ``a``) are not defined here;
    they are read from the enclosing scope.

    Args:
      k: index of the column being eliminated.
      state: tuple ``(pivot, perm, a)`` holding the pivot rows chosen so far,
        the running row permutation and the partially factored matrix.

    Returns:
      The updated ``(pivot, perm, a)`` tuple, with unchanged shapes.
    """
    pivot, perm, a = state
    m_idx = jnp.arange(m)
    n_idx = jnp.arange(n)

    # Complex entries are ranked by |re| + |im| rather than by modulus.
    if jnp.issubdtype(a.dtype, jnp.complexfloating):
        t = a[:, k]
        magnitude = jnp.abs(jnp.real(t)) + jnp.abs(jnp.imag(t))
    else:
        magnitude = jnp.abs(a[:, k])

    # Only rows at or below the diagonal are pivot candidates.
    i = jnp.argmax(jnp.where(m_idx >= k, magnitude, -jnp.inf))
    pivot = pivot.at[k].set(i.astype(pivot.dtype))

    # Swap rows k and i of the matrix, and the matching permutation entries.
    a = a.at[[k, i],].set(a[[i, k],])
    perm = perm.at[[i, k],].set(perm[[k, i],])

    # a[k+1:, k] /= a[k, k], adapted for loop-invariant shapes. A zero pivot
    # leaves the column untouched instead of filling it with 0/0 NaNs.
    x = a[k, k]
    a = a.at[:, k].set(
        jnp.where((m_idx > k) & (x != 0), a[:, k] / x, a[:, k]))

    # a[k+1:, k+1:] -= jnp.outer(a[k+1:, k], a[k, k+1:])
    a = a - jnp.where(
        (m_idx[:, None] > k) & (n_idx[None, :] > k),
        jnp.outer(a[:, k], a[k, :]),
        jnp.array(0, dtype=a.dtype))
    return pivot, perm, a
def setdiff1d(ar1, ar2, assume_unique=False, *, size=None, fill_value=None):
    """Return the values of ``ar1`` that do not occur in ``ar2``.

    Args:
      ar1: array whose values are filtered.
      ar2: array of values to remove.
      assume_unique: when true, ``ar1`` is not passed through ``unique``.
      size: optional static output length. When omitted, ``ar1`` must be
        concrete and the result has a data-dependent length.
      fill_value: value used for output slots beyond the number of retained
        elements when ``size`` is given; defaults to 0.

    Returns:
      The retained values; padded with ``fill_value`` to length ``size`` when
      ``size`` is given.
    """
    _check_arraylike("setdiff1d", ar1, ar2)
    context = "The error arose in setdiff1d()"
    if size is None:
        ar1 = core.concrete_or_error(None, ar1, context)
    else:
        size = core.concrete_or_error(operator.index, size, context)
    ar1 = asarray(ar1)
    if fill_value is None:
        fill_value = 0
    fill_value = asarray(fill_value, dtype=ar1.dtype)

    if ar1.size == 0:
        return full_like(ar1, fill_value, shape=size or 0)

    if not assume_unique:
        ar1 = unique(ar1, size=size and ar1.size)
    keep = in1d(ar1, ar2, invert=True)

    if size is None:
        return ar1[keep]

    if not assume_unique:
        # Clear the mask at locations corresponding to unique() padding.
        n_unique = ar1.size + 1 - (ar1 == ar1[0]).sum()
        keep = where(arange(ar1.size) < n_unique, keep, False)
    retained = ar1[where(keep, size=size)]
    return where(arange(size) < keep.sum(), retained, fill_value)
def _update_T_Z(m, T, Z):
    """Apply one 2x2 transformation ``G`` to ``T`` and accumulate it into ``Z``.

    ``G`` is built from an eigenvalue of the 2x2 block of ``T`` starting at
    ``(m - 1, m - 1)`` and from ``T[m, m - 1]``. ``N`` is read from the
    enclosing scope (presumably the side length of ``T`` -- confirm).

    Args:
      m: index of the lower row/column of the 2x2 block.
      T: matrix updated from the left by ``G`` and from the right by ``G^H``.
      Z: matrix whose columns ``m - 1`` and ``m`` are multiplied by ``G^H``.

    Returns:
      The updated ``(T, Z)`` pair.
    """
    corner = T[m, m]
    subdiag = T[m, m - 1]
    mu = np_linalg.eigvals(
        lax.dynamic_slice(T, (m - 1, m - 1), (2, 2))) - corner
    r = np_linalg.norm(jnp.array([mu[0], subdiag])).astype(T.dtype)
    c = mu[0] / r
    s = subdiag / r
    G = jnp.array([[c.conj(), s], [-s, c]], dtype=T.dtype)
    G_H = G.conj().T
    positions = jnp.arange(N)

    # T[m-1:m+1, m-1:] = G @ T[m-1:m+1, m-1:]
    row_pair = lax.dynamic_slice_in_dim(T, m - 1, 2, axis=0)
    active_cols = positions >= m - 1
    rotated_rows = G @ jnp.where(active_cols, row_pair, 0)
    T = lax.dynamic_update_slice_in_dim(
        T, jnp.where(active_cols, rotated_rows, row_pair), m - 1, axis=0)

    # T[:m+1, m-1:m+1] = T[:m+1, m-1:m+1] @ G.conj().T
    col_pair = lax.dynamic_slice_in_dim(T, m - 1, 2, axis=1)
    active_rows = positions[:, jnp.newaxis] < m + 1
    rotated_cols = jnp.where(active_rows, col_pair, 0) @ G_H
    T = lax.dynamic_update_slice_in_dim(
        T, jnp.where(active_rows, rotated_cols, col_pair), m - 1, axis=1)

    # Z[:, m-1:m+1] = Z[:, m-1:m+1] @ G.conj().T
    z_pair = lax.dynamic_slice_in_dim(Z, m - 1, 2, axis=1)
    Z = lax.dynamic_update_slice_in_dim(Z, z_pair @ G_H, m - 1, axis=1)
    return T, Z
def polyder(p, m=1):
    """Return the coefficients of the ``m``-th derivative of polynomial ``p``.

    Args:
      p: polynomial coefficients, highest power first.
      m: order of differentiation; a concrete, non-negative integer.

    Returns:
      The coefficient array of the derivative (``p`` itself, promoted to an
      inexact dtype, when ``m == 0``).

    Raises:
      ValueError: if ``m`` is negative.
    """
    _check_arraylike("polyder", p)
    m = core.concrete_or_error(operator.index, m, "'m' argument of jnp.polyder")
    p, = _promote_dtypes_inexact(p)
    if m < 0:
        raise ValueError("Order of derivative must be positive")
    if m == 0:
        return p

    # Row r holds (power - r) for every surviving term; the product over the
    # rows is the factor picked up by differentiating m times.
    powers = arange(len(p), m, -1)[np.newaxis, :] - 1
    steps = arange(m)[:, np.newaxis]
    factors = (powers - steps).prod(0)
    return p[:-m] * factors
def _inner_loop(i, p_F_minden):
    """Fill entry ``F[i-1, i+p-1]`` and track the smallest denominator seen.

    ``T`` and ``N`` are read from the enclosing scope.

    Args:
      i: one-based row index of the entry being computed.
      p_F_minden: tuple ``(p, F, minden)`` -- the column offset, the matrix
        being filled, and the running minimum of ``|den|``.

    Returns:
      ``(p, F, minden)`` with ``F`` and ``minden`` updated.
    """
    p, F, minden = p_F_minden
    row = i - 1
    col = i + p - 1

    numerator = T[row, col] * (F[col, col] - F[row, row])

    # Only indices strictly between the row and the column contribute.
    positions = jnp.arange(N)
    between = (positions >= i) & (positions < col)
    t_row = jnp.where(between, T[row], 0)
    f_col = jnp.where(between, F[:, col], 0)
    f_row = jnp.where(between, F[row], 0)
    t_col = jnp.where(between, T[:, col], 0)
    numerator = numerator + (t_row @ f_col - f_row @ t_col)

    den = T[col, col] - T[row, row]
    # A zero denominator leaves the numerator undivided.
    entry = jnp.where(den != 0, numerator / den, numerator)
    F = F.at[row, col].set(entry)
    minden = jnp.minimum(minden, jnp.abs(den))
    return p, F, minden
def detrend(data, axis=-1, type='linear', bp=0, overwrite_data=None):
    """Remove a constant or piecewise-linear trend from ``data`` along ``axis``.

    Args:
      data: input array.
      axis: axis along which to detrend.
      type: ``'constant'`` subtracts the mean; ``'linear'`` subtracts a
        least-squares line fitted on each segment.
      bp: static breakpoint(s) delimiting the linear segments.
      overwrite_data: unsupported; must be ``None``.

    Returns:
      The detrended array, with the shape of ``data`` and an inexact dtype.

    Raises:
      NotImplementedError: if ``overwrite_data`` is given.
      ValueError: for an unknown ``type`` or out-of-range breakpoints.
    """
    if overwrite_data is not None:
        raise NotImplementedError("overwrite_data argument not implemented.")
    if type not in ('constant', 'linear'):
        raise ValueError("Trend type must be 'linear' or 'constant'.")

    data, = _promote_dtypes_inexact(jnp.asarray(data))
    if type == 'constant':
        return data - data.mean(axis, keepdims=True)

    N = data.shape[axis]
    # bp is static, so we use np operations to avoid pushing to device.
    bp = np.sort(np.unique(np.r_[0, bp, N]))
    if bp[0] < 0 or bp[-1] > N:
        raise ValueError(
            "Breakpoints must be non-negative and less than length of data along given axis."
        )

    # Work on a 2-D view with the detrended axis first.
    data = jnp.moveaxis(data, axis, 0)
    moved_shape = data.shape
    data = data.reshape(N, -1)
    for lo, hi in zip(bp[:-1], bp[1:]):
        n_points = hi - lo
        intercept = jnp.ones(n_points, dtype=data.dtype)
        ramp = jnp.arange(1, n_points + 1, dtype=data.dtype) / n_points
        design = jnp.vstack([intercept, ramp]).T
        segment = slice(lo, hi)
        coef, *_ = linalg.lstsq(design, data[segment])
        fitted = jnp.matmul(design, coef, precision=lax.Precision.HIGHEST)
        data = data.at[segment].add(-fitted)
    return jnp.moveaxis(data.reshape(moved_shape), 0, axis)
def _lu_blocked(a, block_size=128):
    """Blocked LU decomposition, as an unrolled loop.

    Each panel of ``block_size`` columns is factored with ``_lu_unblocked``;
    the resulting row permutation is applied to the remaining rows and the
    trailing submatrix is updated with a triangular solve and a matrix product.

    Args:
      a: two-dimensional matrix of shape ``(m, n)``.
      block_size: number of columns factored per panel.

    Returns:
      ``(lu, pivot, perm)``: the packed factors, the pivot row chosen for each
      of the ``min(m, n)`` columns, and the row permutation of length ``m``.
    """
    m, n = a.shape
    r = min(m, n)
    pivot = jnp.zeros((r,), dtype=jnp.int32)
    perm = jnp.arange(m, dtype=jnp.int32)
    for k in range(0, r, block_size):
        b = min(r - k, block_size)
        block_pivot, block_perm, lu_block = _lu_unblocked(a[k:, k:k + b])

        # Panel results are relative to row k; shift them to global rows.
        pivot = pivot.at[k:k + b].set(block_pivot + k)
        perm = perm.at[k:].set(perm[block_perm + k])
        a = a.at[k:, :].set(a[block_perm + k, :])
        a = a.at[k:, k:k + b].set(lu_block)

        if k + b < n:
            # Solve for the U block to the right of the panel ...
            a = a.at[k:k + b, k + b:].set(
                triangular_solve(a[k:k + b, k:k + b], a[k:k + b, k + b:],
                                 left_side=True, lower=True,
                                 unit_diagonal=True))
            # ... and subtract its contribution from the trailing submatrix.
            a = a.at[k + b:, k + b:].add(
                -lax.dot(a[k + b:, k:k + b], a[k:k + b, k + b:],
                         precision=lax.Precision.HIGHEST))
    return a, pivot, perm
def polyint(p, m=1, k=None):
    """Return the coefficients of the ``m``-th antiderivative of polynomial ``p``.

    Args:
      p: polynomial coefficients, highest power first.
      m: order of integration; a concrete, non-negative integer.
      k: integration constant(s): a scalar, or a rank-1 array of length 1 or
        ``m``. Defaults to 0.

    Returns:
      The coefficient array of the antiderivative (``p`` itself, promoted to
      an inexact dtype, when ``m == 0``).

    Raises:
      ValueError: if ``m`` is negative or ``k`` has an unsupported shape.
    """
    m = core.concrete_or_error(operator.index, m, "'m' argument of jnp.polyint")
    if k is None:
        k = 0
    _check_arraylike("polyint", p, k)
    p, k = _promote_dtypes_inexact(p, k)
    if m < 0:
        raise ValueError("Order of integral must be positive (see polyder)")

    k = atleast_1d(k)
    if len(k) == 1:
        # A single constant is reused for every integration step.
        k = full((m,), k[0])
    if k.shape != (m,):
        raise ValueError("k must be a scalar or a rank-1 array of length 1 or m.")
    if m == 0:
        return p

    # Clamping at 1 keeps the divisor of the appended constants harmless.
    powers = arange(len(p) + m, 0, -1)[np.newaxis, :] - 1
    steps = arange(m)[:, np.newaxis]
    divisors = maximum(1, powers - steps).prod(0)
    return true_divide(concatenate((p, k)), divisors)
def rfftfreq(n, d=1.0):
    """Return the non-negative sample frequencies of an ``n``-point transform.

    Args:
      n: window length; a single integer (lists and tuples are rejected).
      d: sample spacing; a single value (lists and tuples are rejected).

    Returns:
      An array of the bin indices ``0, 1, ...`` divided by ``d * n``.

    Raises:
      ValueError: if ``n`` or ``d`` is a list or a tuple.
    """
    if isinstance(n, (list, tuple)):
        raise ValueError(
            "The n argument of jax.numpy.fft.rfftfreq only takes an int. "
            "Got n = %s." % list(n))
    if isinstance(d, (list, tuple)):
        raise ValueError(
            "The d argument of jax.numpy.fft.rfftfreq only takes a single value. "
            "Got d = %s." % list(d))

    num_bins = n // 2 + 1 if n % 2 == 0 else (n - 1) // 2 + 1
    bins = jnp.arange(0, num_bins)
    return bins / (d * n)
def _segment_update(name: str,
                    data: Array,
                    segment_ids: Array,
                    scatter_op: Callable,
                    num_segments: Optional[int] = None,
                    indices_are_sorted: bool = False,
                    unique_indices: bool = False,
                    bucket_size: Optional[int] = None,
                    reducer: Optional[Callable] = None,
                    mode: Optional[lax.GatherScatterMode] = None) -> Array:
    """Scatter ``data`` into per-segment slots using ``scatter_op``.

    Args:
      name: name of the public function on whose behalf this runs; used in
        error messages.
      data: values to combine; the leading axis is indexed by ``segment_ids``.
      segment_ids: segment index of each leading-axis entry of ``data``.
      scatter_op: scatter primitive that combines values landing in a segment.
      num_segments: number of output segments; defaults to
        ``max(segment_ids) + 1``, which must be concrete.
      indices_are_sorted: forwarded to the scatter.
      unique_indices: forwarded to the scatter.
      bucket_size: when given, entries are first combined within buckets of
        this many consecutive entries and the buckets are merged by ``reducer``.
      reducer: reduction applied over the bucket axis; required when more than
        one bucket is used.
      mode: gather/scatter mode; defaults to ``FILL_OR_DROP``.

    Returns:
      An array of shape ``(num_segments,) + data.shape[1:]`` with the dtype of
      ``data``.

    Raises:
      ValueError: if ``num_segments`` is negative.
    """
    jnp._check_arraylike(name, data, segment_ids)
    mode = lax.GatherScatterMode.FILL_OR_DROP if mode is None else mode
    data = jnp.asarray(data)
    segment_ids = jnp.asarray(segment_ids)
    dtype = data.dtype
    if num_segments is None:
        num_segments = jnp.max(segment_ids) + 1
    # Report the calling function, not a hard-coded "segment_sum()".
    num_segments = core.concrete_or_error(
        int, num_segments, f"{name}() `num_segments` argument.")
    if num_segments < 0:
        raise ValueError("num_segments must be non-negative.")

    num_buckets = 1 if bucket_size is None \
        else util.ceil_of_ratio(segment_ids.size, bucket_size)
    if num_buckets == 1:
        out = jnp.full((num_segments,) + data.shape[1:],
                       _get_identity(scatter_op, dtype), dtype=dtype)
        return _scatter_update(
            out, segment_ids, data, scatter_op, indices_are_sorted,
            unique_indices, normalize_indices=False, mode=mode)

    # Bucketize indices and perform segment_update on each bucket to improve
    # numerical stability for operations like product and sum.
    assert reducer is not None
    out = jnp.full((num_buckets, num_segments) + data.shape[1:],
                   _get_identity(scatter_op, dtype), dtype=dtype)
    out = _scatter_update(
        out,
        np.index_exp[lax.div(jnp.arange(segment_ids.shape[0]), bucket_size),
                     segment_ids[None, :]],
        data, scatter_op, indices_are_sorted,
        unique_indices, normalize_indices=False, mode=mode)
    return reducer(out, axis=0).astype(dtype)
def _roots_with_zeros(p, num_leading_zeros):
    """Compute the roots of ``p`` when it may start with zero coefficients.

    Args:
      p: polynomial coefficients, possibly with leading zeros.
      num_leading_zeros: number of leading zero coefficients in ``p``.

    Returns:
      An array of roots in which zero roots are sorted to the end and the last
      ``num_leading_zeros`` entries are replaced by complex NaN.
    """
    # Avoid lapack errors when p is all zero.
    p = _where(len(p) == num_leading_zeros, 1.0, p)

    # Roll any leading zeros to the end & compute the roots.
    shifted = roll(p, -num_leading_zeros)
    roots = _roots_no_zeros(shifted)

    # Sort zero roots to the end.
    _, roots = lax.sort_key_val(roots == 0, roots)

    # Set roots associated with num_leading_zeros to NaN.
    num_valid = roots.size - num_leading_zeros
    return _where(arange(roots.size) < num_valid, roots,
                  complex(np.nan, np.nan))
def multigammaln(a, d):
    """Return the log of the multivariate gamma function of dimension ``d``.

    Computes ``d * (d - 1) / 4 * log(pi) + sum_{j=0}^{d-1} gammaln(a - j / 2)``.

    Args:
      a: array of values at which to evaluate.
      d: dimension; must be a concrete integer because it sets the number of
        summed terms.

    Returns:
      An array with the shape of ``a`` and an inexact dtype.
    """
    d = core.concrete_or_error(int, d, "d argument of multigammaln")
    # Keep the static integer `d` for sizing the arange below; the promoted
    # copy is only used for arithmetic.
    a, d_arr = _promote_args_inexact("multigammaln", a, d)

    constant = lax.mul(
        lax.mul(lax.mul(_constant_like(a, 0.25), d_arr),
                lax.sub(d_arr, _constant_like(a, 1))),
        lax.log(_constant_like(a, np.pi)))

    offsets = lax.div(jnp.arange(d, dtype=d_arr.dtype), _constant_like(a, 2))
    # Broadcast explicitly so no implicit rank promotion is needed.
    offsets = jnp.reshape(offsets, (1,) * jnp.ndim(a) + (d,))
    res = jnp.sum(gammaln(jnp.expand_dims(a, axis=-1) - offsets), axis=-1)
    return res + constant
def multigammaln(a, d):
    """Return the log of the multivariate gamma function of dimension ``d``.

    Computes ``d * (d - 1) / 4 * log(pi) + sum_{j=0}^{d-1} gammaln(a - j / 2)``.

    Args:
      a: array of values at which to evaluate.
      d: dimension; it sets the number of summed terms, so it has to be a
        static value.

    Returns:
      An array with the shape of ``a`` and an inexact dtype.
    """
    a, = _promote_args_inexact("multigammaln", a)
    a_dtype = lax.dtype(a)
    # The converted copy is only used for arithmetic; the arange below is
    # sized from the caller's original `d`.
    d_arr = lax.convert_element_type(d, a_dtype)

    constant = lax.mul(
        lax.mul(lax.mul(_constant_like(a, 0.25), d_arr),
                lax.sub(d_arr, _constant_like(a, 1))),
        lax.log(_constant_like(a, np.pi)))

    offsets = lax.div(jnp.arange(d, dtype=a_dtype), _constant_like(a, 2))
    # Broadcast explicitly so no implicit rank promotion is needed.
    offsets = jnp.reshape(offsets, (1,) * jnp.ndim(a) + offsets.shape)
    res = jnp.sum(gammaln(jnp.expand_dims(a, axis=-1) - offsets), axis=-1)
    return res + constant
def _make_1d_grid_from_slice(s: slice, op_name: str):
    """Turn a slice into a one-dimensional grid of values.

    Args:
      s: slice whose ``start``, ``stop`` and ``step`` must be concrete. A
        falsy ``start`` is treated as 0 and a falsy ``step`` as 1.
      op_name: name of the ``jnp`` operation, used in error messages.

    Returns:
      ``linspace(start, stop, int(abs(step)))`` when ``step`` is complex (its
      magnitude is then a number of points), otherwise
      ``arange(start, stop, step)``.
    """
    start = core.concrete_or_error(None, s.start,
                                   f"slice start of jnp.{op_name}")
    if not start:
        start = 0
    stop = core.concrete_or_error(None, s.stop,
                                  f"slice stop of jnp.{op_name}")
    step = core.concrete_or_error(None, s.step,
                                  f"slice step of jnp.{op_name}")
    if not step:
        step = 1

    if np.iscomplex(step):
        return linspace(start, stop, int(abs(step)))
    return arange(start, stop, step)
def _squaring(R, n_squarings):
    """Square ``R`` repeatedly, ``n_squarings`` times, to undo scaling.

    The scan always runs 16 steps so the loop has a static length; steps whose
    index is not below ``n_squarings`` leave the carry unchanged.

    Args:
      R: matrix to square.
      n_squarings: number of squarings to apply (at most 16 take effect).

    Returns:
      The repeatedly squared matrix.
    """
    def _square(mat):
        return _precise_dot(mat, mat)

    def _keep(mat):
        return mat

    def _step(carry, index):
        carry = lax.cond(index < n_squarings, _square, _keep, carry)
        return carry, None

    squared, _ = lax.scan(_step, R, jnp.arange(16))
    return squared
def multigammaln(a, d):
    """Return the log of the multivariate gamma function of dimension ``d``.

    Computes ``d * (d - 1) / 4 * log(pi) + sum_{j=0}^{d-1} gammaln(a - j / 2)``.

    Args:
      a: array of values at which to evaluate.
      d: dimension; must be a concrete integer.

    Returns:
      An array with the shape of ``a`` and an inexact dtype.
    """
    d = core.concrete_or_error(int, d, "d argument of multigammaln")
    a, d_ = _promote_args_inexact("multigammaln", a, d)

    quarter_d = lax.mul(lax._const(a, 0.25), d_)
    d_minus_one = lax.sub(d_, lax._const(a, 1))
    log_pi = lax.log(lax._const(a, np.pi))
    constant = lax.mul(lax.mul(quarter_d, d_minus_one), log_pi)

    # Offsets j / 2, given leading singleton axes so they line up with `a`.
    half_steps = lax.div(jnp.arange(d, dtype=d_.dtype), lax._const(a, 2))
    half_steps = jnp.expand_dims(half_steps, axis=tuple(range(a.ndim)))
    terms = gammaln(jnp.expand_dims(a, axis=-1) - half_steps)
    return jnp.sum(terms, axis=-1) + constant
def _lu(a, permute_l):
    """Compute the pivoted LU factorisation ``a = p @ l @ u`` of a matrix.

    Args:
      a: two-dimensional matrix of shape ``(m, n)``.
      permute_l: when true, return ``(p @ l, u)`` instead of ``(p, l, u)``.

    Returns:
      ``(p, l, u)`` or ``(p @ l, u)``, where ``l`` has a unit diagonal and
      ``min(m, n)`` columns and ``u`` has ``min(m, n)`` rows.
    """
    a = np_linalg._promote_arg_dtypes(jnp.asarray(a))
    lu, _, permutation = lax_linalg.lu(a)
    dtype = lax.dtype(a)
    m, n = jnp.shape(a)

    # p[i, j] is 1 where permutation[j] == i. Both operands are made rank 2
    # and share the permutation's dtype, so no implicit rank promotion or
    # dtype promotion is involved.
    row_ids = jnp.arange(m, dtype=permutation.dtype)
    p = jnp.real(
        jnp.array(permutation[None, :] == row_ids[:, None], dtype=dtype))

    k = min(m, n)
    l = jnp.tril(lu, -1)[:, :k] + jnp.eye(m, k, dtype=dtype)
    u = jnp.triu(lu)[:k, :]
    if permute_l:
        return jnp.matmul(p, l), u
    return p, l, u
def _lu(a, permute_l):
    """Compute the pivoted LU factorisation ``a = p @ l @ u`` of a matrix.

    Args:
      a: two-dimensional matrix of shape ``(m, n)``.
      permute_l: when true, return ``(p @ l, u)`` instead of ``(p, l, u)``.

    Returns:
      ``(p, l, u)`` or ``(p @ l, u)``, where ``l`` has a unit diagonal and
      ``min(m, n)`` columns and ``u`` has ``min(m, n)`` rows.
    """
    a, = _promote_dtypes_inexact(jnp.asarray(a))
    packed, _, permutation = lax_linalg.lu(a)
    dtype = lax.dtype(a)
    num_rows, num_cols = jnp.shape(a)

    # perm_matrix[i, j] is 1 where permutation[j] == i.
    row_ids = jnp.arange(num_rows, dtype=permutation.dtype)
    matches = permutation[None, :] == row_ids[:, None]
    perm_matrix = jnp.real(jnp.array(matches, dtype=dtype))

    rank = min(num_rows, num_cols)
    lower = jnp.tril(packed, -1)[:, :rank] + jnp.eye(num_rows, rank,
                                                      dtype=dtype)
    upper = jnp.triu(packed)[:rank, :]
    if not permute_l:
        return perm_matrix, lower, upper
    return jnp.matmul(perm_matrix, lower), upper
def _slogdet_lu(a): dtype = lax.dtype(a) lu, pivot, _ = lax_linalg.lu(a) diag = jnp.diagonal(lu, axis1=-2, axis2=-1) is_zero = jnp.any(diag == jnp.array(0, dtype=dtype), axis=-1) iota = lax.expand_dims(jnp.arange(a.shape[-1]), range(pivot.ndim - 1)) parity = jnp.count_nonzero(pivot != iota, axis=-1) if jnp.iscomplexobj(a): sign = jnp.prod(diag / jnp.abs(diag), axis=-1) else: sign = jnp.array(1, dtype=dtype) parity = parity + jnp.count_nonzero(diag < 0, axis=-1) sign = jnp.where(is_zero, jnp.array(0, dtype=dtype), sign * jnp.array(-2 * (parity % 2) + 1, dtype=dtype)) logdet = jnp.where(is_zero, jnp.array(-jnp.inf, dtype=dtype), jnp.sum(jnp.log(jnp.abs(diag)), axis=-1)) return sign, jnp.real(logdet)
def _sph_harm(m: jnp.ndarray,
              n: jnp.ndarray,
              theta: jnp.ndarray,
              phi: jnp.ndarray,
              n_max: int) -> jnp.ndarray:
    """Computes the spherical harmonics.

    Args:
      m: orders; negative orders are derived from the positive ones by
        conjugation and a ``(-1)**|m|`` factor.
      n: degrees, one per evaluation point.
      theta: angles multiplied by ``|m|`` in the complex exponential.
      phi: angles whose cosine is fed to the associated Legendre functions.
      n_max: maximum degree passed to ``_gen_associated_legendre``.

    Returns:
      The complex harmonics, one value per evaluation point.
    """
    abs_order = abs(m)
    cos_colatitude = jnp.cos(phi)
    legendre = _gen_associated_legendre(n_max, cos_colatitude, True)
    # Pick, for every point, the entry for its own (|m|, n) pair.
    legendre_val = legendre[abs_order, n, jnp.arange(len(n))]

    angle = abs_order * theta
    vandermonde = lax.complex(jnp.cos(angle), jnp.sin(angle))
    real_part = legendre_val * jnp.real(vandermonde)
    imag_part = legendre_val * jnp.imag(vandermonde)
    harmonics = lax.complex(real_part, imag_part)

    # Negative order.
    negative_order = (-1.0)**abs_order * jnp.conjugate(harmonics)
    return jnp.where(m < 0, negative_order, harmonics)
def _median_bias(n): """ Returns the bias of the median of a set of periodograms relative to the mean. See Appendix B from [1]_ for details. Args: n : int Numbers of periodograms being averaged. Returns: bias : float Calculated bias. References: .. [1] B. Allen, W.G. Anderson, P.R. Brady, D.A. Brown, J.D.E. Creighton. "FINDCHIRP: an algorithm for detection of gravitational waves from inspiraling compact binaries", Physical Review D 85, 2012, :arxiv:`gr-qc/0509116` """ ii_2 = jnp.arange(2., n, 2) return 1 + jnp.sum(1. / (ii_2 + 1) - 1. / ii_2)
def slogdet(a):
    """Compute the sign and natural log of the absolute value of ``det(a)``.

    Args:
      a: array of shape ``[..., n, n]``.

    Returns:
      ``(sign, logabsdet)``. For a zero on the diagonal of ``u`` the sign is 0
      and the log-magnitude is ``-inf``.

    Raises:
      ValueError: if ``a`` is not a (batch of) square matrix.
    """
    a = _promote_arg_dtypes(jnp.asarray(a))
    dtype = lax.dtype(a)
    a_shape = jnp.shape(a)
    if len(a_shape) < 2 or a_shape[-1] != a_shape[-2]:
        msg = "Argument to slogdet() must have shape [..., n, n], got {}"
        raise ValueError(msg.format(a_shape))

    lu, pivot, _ = lax_linalg.lu(a)
    diag = jnp.diagonal(lu, axis1=-2, axis2=-1)
    is_zero = jnp.any(diag == jnp.array(0, dtype=dtype), axis=-1)

    # Every pivot that differs from its own row index is one row swap. The
    # row indices get explicit leading singleton axes so batched pivots are
    # compared without implicit rank promotion.
    n = a_shape[-1]
    iota = jnp.reshape(jnp.arange(n, dtype=pivot.dtype),
                       (1,) * (pivot.ndim - 1) + (n,))
    parity = jnp.count_nonzero(pivot != iota, axis=-1)

    if jnp.iscomplexobj(a):
        sign = jnp.prod(diag / jnp.abs(diag), axis=-1)
    else:
        # Real case: each negative diagonal entry flips the sign once more.
        sign = jnp.array(1, dtype=dtype)
        parity = parity + jnp.count_nonzero(diag < 0, axis=-1)

    sign = jnp.where(
        is_zero,
        jnp.array(0, dtype=dtype),
        sign * jnp.array(-2 * (parity % 2) + 1, dtype=dtype))
    logdet = jnp.where(
        is_zero,
        jnp.array(-jnp.inf, dtype=dtype),
        jnp.sum(jnp.log(jnp.abs(diag)), axis=-1))
    return sign, jnp.real(logdet)
def _lu_unblocked(a): """Unblocked LU decomposition, as a rolled loop.""" m, n = a.shape def body(k, state): pivot, perm, a = state m_idx = jnp.arange(m) n_idx = jnp.arange(n) if jnp.issubdtype(a.dtype, jnp.complexfloating): t = a[:, k] magnitude = jnp.abs(jnp.real(t)) + jnp.abs(jnp.imag(t)) else: magnitude = jnp.abs(a[:, k]) i = jnp.argmax(jnp.where(m_idx >= k, magnitude, -jnp.inf)) pivot = ops.index_update(pivot, ops.index[k], i) a = ops.index_update(a, ops.index[[k, i], ], a[[i, k], ]) perm = ops.index_update(perm, ops.index[[i, k], ], perm[[k, i], ]) # a[k+1:, k] /= a[k, k], adapted for loop-invariant shapes x = a[k, k] a = ops.index_update(a, ops.index[:, k], jnp.where(m_idx > k, a[:, k] / x, a[:, k])) # a[k+1:, k+1:] -= jnp.outer(a[k+1:, k], a[k, k+1:]) a = a - jnp.where( (m_idx[:, None] > k) & (n_idx > k), jnp.outer(a[:, k], a[k, :]), jnp.array(0, dtype=a.dtype)) return pivot, perm, a pivot = jnp.zeros((min(m, n), ), dtype=jnp.int32) perm = jnp.arange(m, dtype=jnp.int32) if m == 0 and n == 0: # If the array is empty, the loop body never executes but tracing it to a # jaxpr fails because the indexing cannot succeed. return (pivot, perm, a) return lax.fori_loop(0, min(m, n), body, (pivot, perm, a))
def _cofactor_solve(a, b):
    """Equivalent to det(a)*solve(a, b) for nonsingular mat.

    Intermediate function used for jvp and vjp of det.
    This function borrows heavily from jax.numpy.linalg.solve and
    jax.numpy.linalg.slogdet to compute the gradient of the determinant
    in a way that is well defined even for low rank matrices.

    This function handles two different cases:
    * rank(a) == n or n-1
    * rank(a) < n-1

    For rank n-1 matrices, the gradient of the determinant is a rank 1 matrix.
    Rather than computing det(a)*solve(a, b), which would return NaN, we work
    directly with the LU decomposition. If a = p @ l @ u, then
    det(a)*solve(a, b) =
    prod(diag(u)) * u^-1 @ l^-1 @ p^-1 b =
    prod(diag(u)) * triangular_solve(u, solve(p @ l, b))
    If a is rank n-1, then the lower right corner of u will be zero and the
    triangular_solve will fail.
    Let x = solve(p @ l, b) and y = det(a)*solve(a, b).
    Then
    y_{n} =
    x_{n} / u_{nn} * prod_{i=1...n}(u_{ii}) =
    x_{n} * prod_{i=1...n-1}(u_{ii})
    So by replacing the lower-right corner of u with prod_{i=1...n-1}(u_{ii})^-1
    we can avoid the triangular_solve failing.
    To correctly compute the rest of y_{i} for i != n, we simply multiply
    x_{i} by det(a) for all i != n, which will be zero if rank(a) = n-1.

    For the second case, a check is done on the matrix to see if `solve`
    returns NaN or Inf, and gives a matrix of zeros as a result, as the
    gradient of the determinant of a matrix with rank less than n-1 is 0.
    This will still return the correct value for rank n-1 matrices, as the check
    is applied *after* the lower right corner of u has been updated.

    Args:
      a: A square matrix or batch of matrices, possibly singular.
      b: A matrix, or batch of matrices of the same dimension as a.

    Returns:
      det(a) and cofactor(a)^T*b, aka adjugate(a)*b

    Raises:
      ValueError: if a is not a (batch of) square matrix or the trailing two
        dimensions of b differ from those of a.
    """
    a = _promote_arg_dtypes(jnp.asarray(a))
    b = _promote_arg_dtypes(jnp.asarray(b))
    a_shape = jnp.shape(a)
    b_shape = jnp.shape(b)
    a_ndims = len(a_shape)
    if not (a_ndims >= 2 and a_shape[-1] == a_shape[-2]
            and b_shape[-2:] == a_shape[-2:]):
        msg = ("The arguments to _cofactor_solve must have shapes "
               "a=[..., m, m] and b=[..., m, m]; got a={} and b={}")
        raise ValueError(msg.format(a_shape, b_shape))
    # A 1x1 matrix is its own determinant and its adjugate is 1, so b is
    # returned unchanged.
    if a_shape[-1] == 1:
        return a[..., 0, 0], b
    # lu contains u in the upper triangular matrix and l in the strict lower
    # triangular matrix.
    # The diagonal of l is set to ones without loss of generality.
    lu, pivots, permutation = lax_linalg.lu(a)
    dtype = lax.dtype(a)
    # Broadcast the batch dimensions of the factors and of b against each
    # other so both can be indexed with the same leading axes.
    batch_dims = lax.broadcast_shapes(lu.shape[:-2], b.shape[:-2])
    x = jnp.broadcast_to(b, batch_dims + b.shape[-2:])
    lu = jnp.broadcast_to(lu, batch_dims + lu.shape[-2:])
    # Compute (partial) determinant, ignoring last diagonal of LU
    diag = jnp.diagonal(lu, axis1=-2, axis2=-1)
    # Each pivot that differs from its own row index counts as one row swap.
    parity = jnp.count_nonzero(pivots != jnp.arange(a_shape[-1]), axis=-1)
    sign = jnp.asarray(-2 * (parity % 2) + 1, dtype=dtype)
    # partial_det[:, -1] contains the full determinant and
    # partial_det[:, -2] contains det(u) / u_{nn}.
    partial_det = jnp.cumprod(diag, axis=-1) * sign[..., None]
    # Replace u_{nn} so the final triangular solve stays finite for rank n-1
    # inputs (see the docstring).
    lu = lu.at[..., -1, -1].set(1.0 / partial_det[..., -2])
    permutation = jnp.broadcast_to(permutation, batch_dims + (a_shape[-1], ))
    # Open-mesh index arrays, one per batch dimension. The loop variable `b`
    # is local to the generator expression and does not rebind the argument.
    iotas = jnp.ix_(*(lax.iota(jnp.int32, b) for b in batch_dims + (1, )))
    # filter out any matrices that are not full rank
    d = jnp.ones(x.shape[:-1], x.dtype)
    d = lax_linalg.triangular_solve(lu, d, left_side=True, lower=False)
    d = jnp.any(jnp.logical_or(jnp.isnan(d), jnp.isinf(d)), axis=-1)
    # Expand the per-matrix flag to the shape of x.
    d = jnp.tile(d[..., None, None], d.ndim * (1, ) + x.shape[-2:])
    x = jnp.where(d, jnp.zeros_like(x), x)  # first filter
    # Apply the row permutation to x within every batch element.
    x = x[iotas[:-1] + (permutation, slice(None))]
    x = lax_linalg.triangular_solve(lu, x, left_side=True, lower=True,
                                    unit_diagonal=True)
    # Scale every row except the last by the full determinant; the last row is
    # handled through the modified u_{nn}.
    x = jnp.concatenate((x[..., :-1, :] * partial_det[..., -1, None, None],
                         x[..., -1:, :]), axis=-2)
    x = lax_linalg.triangular_solve(lu, x, left_side=True, lower=False)
    x = jnp.where(d, jnp.zeros_like(x), x)  # second filter
    return partial_det[..., -1], x
def _process_axis_index(self, frame):
    """Return a ``BatchTracer`` wrapping the indices ``0 .. frame.size - 1``.

    The trailing ``0`` is passed as the third ``BatchTracer`` argument
    (presumably the batch dimension -- confirm against ``batching``).
    """
    axis_indices = lax_numpy.arange(frame.size, dtype=np.int32)
    return batching.BatchTracer(self, axis_indices, 0)
def _gen_derivatives(p: jnp.ndarray, x: jnp.ndarray, is_normalized: bool) -> jnp.ndarray: """Generates derivatives of associated Legendre functions of the first kind. Args: p: The 3D array containing the values of associated Legendre functions; the dimensions are in the sequence of order (m), degree (l), and evalution points. x: A vector of type `float32` or `float64` containing the sampled points. is_normalized: True if the associated Legendre functions are normalized. Returns: The 3D array representing the derivatives of associated Legendre functions of the first kind. """ num_m, num_l, num_x = p.shape # p_{l-1}^m. p_m_lm1 = jnp.pad(p, ((0, 0), (1, 0), (0, 0)))[:, :num_l, :] # p_{l-1}^{m+2}. p_mp2_lm1 = jnp.pad(p_m_lm1, ((0, 2), (0, 0), (0, 0)))[2:num_m + 2, :, :] # p_{l-1}^{m-2}. p_mm2_lm1 = jnp.pad(p_m_lm1, ((2, 0), (0, 0), (0, 0)))[:num_m, :, :] # Derivative computation requires negative orders. if is_normalized: raise NotImplementedError( 'Negative orders for normalization is not implemented yet.') else: if num_l > 1: l_vec = jnp.arange(1, num_l - 1) p_p1 = p[1, 1:num_l - 1, :] coeff = -1.0 / ((l_vec + 1) * l_vec) update_p_p1 = jnp.einsum('i,ij->ij', coeff, p_p1) p_mm2_lm1 = p_mm2_lm1.at[ops.index[1, 2:num_l, :]].set(update_p_p1) if num_l > 2: l_vec = jnp.arange(2, num_l - 1) p_p2 = p[2, 2:num_l - 1, :] coeff = 1.0 / ((l_vec + 2) * (l_vec + 1) * l_vec) update_p_p2 = jnp.einsum('i,ij->ij', coeff, p_p2) p_mm2_lm1 = p_mm2_lm1.at[ops.index[0, 3:num_l, :]].set(update_p_p2) m_mat, l_mat = jnp.mgrid[:num_m, :num_l] coeff_zeros = jnp.zeros((num_m, num_l)) upper_0_indices = jnp.triu_indices(num_m, 0, num_l) zero_vec = jnp.zeros((num_l, )) a0 = -0.5 / (m_mat - 1.0) a0_masked = coeff_zeros.at[upper_0_indices].set(a0[upper_0_indices]) a0_masked = a0_masked.at[1, :].set(zero_vec) b0 = l_mat + m_mat c0 = a0 * (b0 - 2.0) * (b0 - 1.0) c0_masked = coeff_zeros.at[upper_0_indices].set(c0[upper_0_indices]) c0_masked = c0_masked.at[1, :].set(zero_vec) # p_l^{m-1}. 
p_mm1_l = (jnp.einsum('ij,ijk->ijk', a0_masked, p_m_lm1) + jnp.einsum('ij,ijk->ijk', c0_masked, p_mm2_lm1)) d0 = -0.5 / (m_mat + 1.0) d0_masked = coeff_zeros.at[upper_0_indices].set(d0[upper_0_indices]) e0 = d0 * b0 * (b0 + 1.0) e0_masked = coeff_zeros.at[upper_0_indices].set(e0[upper_0_indices]) # p_l^{m+1}. p_mp1_l = (jnp.einsum('ij,ijk->ijk', d0_masked, p_mp2_lm1) + jnp.einsum('ij,ijk->ijk', e0_masked, p_m_lm1)) f0 = b0 * (l_mat - m_mat + 1.0) / 2.0 f0_masked = coeff_zeros.at[upper_0_indices].set(f0[upper_0_indices]) p_derivative = jnp.einsum('ij,ijk->ijk', f0_masked, p_mm1_l) - 0.5 * p_mp1_l # Special treatment of the singularity at m = 1. if num_m > 1: l_vec = jnp.arange(num_l) g0 = jnp.einsum('i,ij->ij', (l_vec + 1) * l_vec, p[0, :, :]) if num_l > 2: g0 = g0 - p[2, :, :] p_derivative_m0 = jnp.einsum('j,ij->ij', 0.5 / jnp.sqrt(1 - x * x), g0) p_derivative = p_derivative.at[1, :, :].set(p_derivative_m0) p_derivative = p_derivative.at[1, 0, :].set(jnp.zeros((num_x, ))) return p_derivative
def eigh_tridiagonal(d, e, *, eigvals_only=False, select='a',
                     select_range=None, tol=None):
    """Compute eigenvalues of a tridiagonal matrix by parallel bisection.

    The matrix T has diagonal `d` and off-diagonal `e`. Eigenvalues are located
    by bisecting an interval derived from the Gershgorin circle theorem, using
    Sturm sequence counts to decide which half to keep.

    Args:
      d: diagonal values.
      e: off-diagonal values; must have the same dtype as `d`.
      eigvals_only: must be True; eigenvectors are not implemented.
      select: 'a' for all eigenvalues, 'i' for the eigenvalues with indices
        `select_range[0]` through `select_range[1]` inclusive. 'v' is not
        implemented.
      select_range: pair of indices, read only when `select == 'i'`.
      tol: optional absolute tolerance; the tolerance actually used is never
        smaller than `eps * t_norm`.

    Returns:
      The requested eigenvalues (the midpoints of the final bisection
      intervals). For inputs of length <= 1 the real part of `d` is returned.

    Raises:
      NotImplementedError: if `eigvals_only` is false or `select == 'v'`.
      TypeError: if the dtypes of `d` and `e` differ or are unsupported.
      ValueError: for an empty index range or an unknown `select`.
    """
    if not eigvals_only:
        raise NotImplementedError(
            "Calculation of eigenvectors is not implemented")

    def _sturm(alpha, beta_sq, pivmin, alpha0_perturbation, x):
        """Implements the Sturm sequence recurrence."""
        n = alpha.shape[0]
        zeros = jnp.zeros(x.shape, dtype=jnp.int32)
        ones = jnp.ones(x.shape, dtype=jnp.int32)

        # The first step in the Sturm sequence recurrence
        # requires special care if x is equal to alpha[0].
        def sturm_step0():
            q = alpha[0] - x
            count = jnp.where(q < 0, ones, zeros)
            q = jnp.where(alpha[0] == x, alpha0_perturbation, q)
            return q, count

        # Subsequent steps all take this form:
        def sturm_step(i, q, count):
            q = alpha[i] - beta_sq[i - 1] / q - x
            count = jnp.where(q <= pivmin, count + 1, count)
            # Keep q away from zero so the next division cannot overflow.
            q = jnp.where(q <= pivmin, jnp.minimum(q, -pivmin), q)
            return q, count

        # The first step initializes q and count.
        q, count = sturm_step0()

        # Peel off ((n-1) % blocksize) steps from the main loop, so we can run
        # the bulk of the iterations unrolled by a factor of blocksize.
        blocksize = 16
        i = 1
        peel = (n - 1) % blocksize
        # NOTE: `unrolled_steps` reads `unroll_cnt` from this scope each time
        # it runs, so rebinding it below changes how many steps are unrolled.
        # The order of these statements matters.
        unroll_cnt = peel

        def unrolled_steps(args):
            start, q, count = args
            for j in range(unroll_cnt):
                q, count = sturm_step(start + j, q, count)
            return start + unroll_cnt, q, count

        i, q, count = unrolled_steps((i, q, count))

        # Run the remaining steps of the Sturm sequence using a partially
        # unrolled while loop.
        unroll_cnt = blocksize

        def cond(iqc):
            i, q, count = iqc
            return jnp.less(i, n)

        _, _, count = lax.while_loop(cond, unrolled_steps, (i, q, count))
        return count

    alpha = jnp.asarray(d)
    beta = jnp.asarray(e)
    # NOTE(review): complex dtypes are accepted here although the error
    # message below only mentions float32 and float64.
    supported_dtypes = (jnp.float32, jnp.float64, jnp.complex64,
                        jnp.complex128)
    if alpha.dtype != beta.dtype:
        raise TypeError(
            "diagonal and off-diagonal values must have same dtype, "
            f"got {alpha.dtype} and {beta.dtype}")
    if (alpha.dtype not in supported_dtypes
            or beta.dtype not in supported_dtypes):
        raise TypeError(
            "Only float32 and float64 inputs are supported as inputs "
            "to jax.scipy.linalg.eigh_tridiagonal, got "
            f"{alpha.dtype} and {beta.dtype}")
    n = alpha.shape[0]
    if n <= 1:
        return jnp.real(alpha)

    # Only the squared magnitude of the off-diagonal enters the recurrence.
    if jnp.issubdtype(alpha.dtype, jnp.complexfloating):
        alpha = jnp.real(alpha)
        beta_sq = jnp.real(beta * jnp.conj(beta))
        beta_abs = jnp.sqrt(beta_sq)
    else:
        beta_abs = jnp.abs(beta)
        beta_sq = jnp.square(beta)

    # Estimate the largest and smallest eigenvalues of T using the Gershgorin
    # circle theorem.
    off_diag_abs_row_sum = jnp.concatenate(
        [beta_abs[:1], beta_abs[:-1] + beta_abs[1:], beta_abs[-1:]], axis=0)
    lambda_est_max = jnp.amax(alpha + off_diag_abs_row_sum)
    lambda_est_min = jnp.amin(alpha - off_diag_abs_row_sum)
    # Upper bound on 2-norm of T.
    t_norm = jnp.maximum(jnp.abs(lambda_est_min), jnp.abs(lambda_est_max))

    # Compute the smallest allowed pivot in the Sturm sequence to avoid
    # overflow.
    finfo = np.finfo(alpha.dtype)
    one = np.ones([], dtype=alpha.dtype)
    safemin = np.maximum(one / finfo.max, (one + finfo.eps) * finfo.tiny)
    pivmin = safemin * jnp.maximum(1, jnp.amax(beta_sq))
    alpha0_perturbation = jnp.square(finfo.eps * beta_abs[0])
    abs_tol = finfo.eps * t_norm
    if tol is not None:
        abs_tol = jnp.maximum(tol, abs_tol)

    # In the worst case, when the absolute tolerance is eps*lambda_est_max and
    # lambda_est_max = -lambda_est_min, we have to take as many bisection steps
    # as there are bits in the mantissa plus 1.
    # The proof is left as an exercise to the reader.
    max_it = finfo.nmant + 1

    # Determine the indices of the desired eigenvalues, based on select and
    # select_range.
    if select == 'a':
        target_counts = jnp.arange(n, dtype=jnp.int32)
    elif select == 'i':
        if select_range[0] > select_range[1]:
            raise ValueError('Got empty index range in select_range.')
        target_counts = jnp.arange(select_range[0], select_range[1] + 1,
                                   dtype=jnp.int32)
    elif select == 'v':
        # TODO(phawkins): requires dynamic shape support.
        raise NotImplementedError("eigh_tridiagonal(..., select='v') is not "
                                  "implemented")
    else:
        raise ValueError("'select must have a value in {'a', 'i', 'v'}.")

    # Run binary search for all desired eigenvalues in parallel, starting from
    # the interval lightly wider than the estimated
    # [lambda_est_min, lambda_est_max].
    fudge = 2.1  # We widen starting interval the Gershgorin interval a bit.
    norm_slack = jnp.array(n, alpha.dtype) * fudge * finfo.eps * t_norm
    lower = lambda_est_min - norm_slack - 2 * fudge * pivmin
    upper = lambda_est_max + norm_slack + fudge * pivmin

    # Pre-broadcast the scalars used in the Sturm sequence for improved
    # performance.
    target_shape = jnp.shape(target_counts)
    lower = jnp.broadcast_to(lower, shape=target_shape)
    upper = jnp.broadcast_to(upper, shape=target_shape)
    mid = 0.5 * (upper + lower)
    pivmin = jnp.broadcast_to(pivmin, target_shape)
    alpha0_perturbation = jnp.broadcast_to(alpha0_perturbation, target_shape)

    # Start parallel binary searches.
    def cond(args):
        # Continue while the iteration budget lasts and the widest interval
        # is still larger than the tolerance.
        i, lower, _, upper = args
        return jnp.logical_and(
            jnp.less(i, max_it),
            jnp.less(abs_tol, jnp.amax(upper - lower)))

    def body(args):
        i, lower, mid, upper = args
        counts = _sturm(alpha, beta_sq, pivmin, alpha0_perturbation, mid)
        # Keep the half of each interval selected by comparing the Sturm
        # count at the midpoint with the target index.
        lower = jnp.where(counts <= target_counts, mid, lower)
        upper = jnp.where(counts > target_counts, mid, upper)
        mid = 0.5 * (lower + upper)
        return i + 1, lower, mid, upper

    _, _, mid, _ = lax.while_loop(cond, body, (0, lower, mid, upper))
    return mid
def _gen_associated_legendre(l_max: int,
                             x: jnp.ndarray,
                             is_normalized: bool) -> jnp.ndarray:
    r"""Computes associated Legendre functions (ALFs) of the first kind.

    The ALFs of the first kind are used in spherical harmonics. The spherical
    harmonic of degree `l` and order `m` can be written as
    `Y_l^m(θ, φ) = N_l^m * P_l^m(cos(θ)) * exp(i m φ)`, where `N_l^m` is the
    normalization factor and θ and φ are the colatitude and longitude,
    respectively. `N_l^m` is chosen in the way that the spherical harmonics form
    a set of orthonormal basis function of L^2(S^2). For the computational
    efficiency of spherical harmonics transform, the normalization factor is
    used in the computation of the ALFs. In addition, normalizing `P_l^m`
    avoids overflow/underflow and achieves better numerical stability. Three
    recurrence relations are used in the computation.

    Args:
      l_max: The maximum degree of the associated Legendre function. Both the
        degrees and orders are `[0, 1, 2, ..., l_max]`.
      x: A vector of type `float32`, `float64` containing the sampled points in
        spherical coordinates, at which the ALFs are computed; `x` is
        essentially `cos(θ)`. For the numerical integration used by the
        spherical harmonics transforms, `x` contains the quadrature points in
        the interval of `[-1, 1]`. There are several approaches to provide the
        quadrature points: Gauss-Legendre method
        (`scipy.special.roots_legendre`), Gauss-Chebyshev method
        (`scipy.special.roots_chebyu`), and Driscoll & Healy method (Driscoll,
        James R., and Dennis M. Healy. "Computing Fourier transforms and
        convolutions on the 2-sphere." Advances in applied mathematics 15,
        no. 2 (1994): 202-250.). The Gauss-Legendre quadrature points are
        nearly equal-spaced along θ and provide exact discrete orthogonality,
        (P^m)^T W P_m = I, where `T` represents the transpose operation, `W` is
        a diagonal matrix containing the quadrature weights, and `I` is the
        identity matrix. The Gauss-Chebyshev points are equally spaced, which
        only provide approximate discrete orthogonality. The Driscoll & Healy
        quadrature points are equally spaced and provide the exact discrete
        orthogonality. The number of sampling points is required to be twice as
        the number of frequency points (modes) in the Driscoll & Healy
        approach, which enables FFT and achieves a fast spherical harmonics
        transform.
      is_normalized: True if the associated Legendre functions are normalized.
        With normalization, `N_l^m` is applied such that the spherical
        harmonics form a set of orthonormal basis functions of L^2(S^2).

    Returns:
      The 3D array of shape `(l_max + 1, l_max + 1, len(x))` containing the
      values of the ALFs at `x`; the dimensions in the sequence of order,
      degree, and evaluation points.
    """
    p = jnp.zeros((l_max + 1, l_max + 1, x.shape[0]))

    # f_a scales the diagonal entries p(l,l); f_b scales the entries directly
    # next to the diagonal.
    a_idx = jnp.arange(1, l_max + 1)
    b_idx = jnp.arange(l_max)
    if is_normalized:
        initial_value = 0.5 / jnp.sqrt(jnp.pi)  # The initial value p(0,0).
        f_a = jnp.cumprod(-1 * jnp.sqrt(1.0 + 0.5 / a_idx))
        f_b = jnp.sqrt(2.0 * b_idx + 3.0)
    else:
        initial_value = 1.0  # The initial value p(0,0).
        f_a = jnp.cumprod(1.0 - 2.0 * a_idx)
        f_b = 2.0 * b_idx + 1.0

    p = p.at[(0, 0)].set(initial_value)

    # Compute the diagonal entries p(l,l) with recurrence.
    # Row l - 1 of y holds sqrt(1 - x^2) raised to the power l.
    y = jnp.cumprod(
        jnp.broadcast_to(jnp.sqrt(1.0 - x * x), (l_max, x.shape[0])),
        axis=0)
    p_diag = initial_value * jnp.einsum('i,ij->ij', f_a, y)
    diag_indices = jnp.diag_indices(l_max + 1)
    p = p.at[(diag_indices[0][1:], diag_indices[1][1:])].set(p_diag)

    # Compute the off-diagonal entries with recurrence.
    # Entry (l, l + 1) is f_b[l] * x times the diagonal entry (l, l).
    p_offdiag = jnp.einsum('ij,ij->ij',
                           jnp.einsum('i,j->ij', f_b, x),
                           p[jnp.diag_indices(l_max)])
    offdiag_indices = (diag_indices[0][:l_max], diag_indices[1][:l_max] + 1)
    p = p.at[offdiag_indices].set(p_offdiag)

    # Compute the remaining entries with recurrence.
    d0_mask_3d, d1_mask_3d = _gen_recurrence_mask(
        l_max, is_normalized=is_normalized)

    def body_fun(i, p_val):
        # Combine the entries one and two degrees lower (rolled along the
        # degree axis), weighted by the mask coefficients for step i.
        coeff_0 = d0_mask_3d[i]
        coeff_1 = d1_mask_3d[i]
        h = (jnp.einsum('ij,ijk->ijk',
                        coeff_0,
                        jnp.einsum(
                            'ijk,k->ijk', jnp.roll(p_val, shift=1, axis=1), x))
             - jnp.einsum('ij,ijk->ijk',
                          coeff_1,
                          jnp.roll(p_val, shift=2, axis=1)))
        p_val = p_val + h
        return p_val

    # Degrees 0 and 1 are fully covered by the diagonal and off-diagonal
    # entries above, so the loop is only needed from degree 2 on.
    if l_max > 1:
        p = lax.fori_loop(lower=2, upper=l_max + 1, body_fun=body_fun,
                          init_val=p)

    return p
def _spectral_helper(x, y, fs=1.0, window='hann', nperseg=None, noverlap=None,
                     nfft=None, detrend_type='constant', return_onesided=True,
                     scaling='density', axis=-1, mode='psd', boundary=None,
                     padded=False):
    """LAX-backend implementation of `scipy.signal._spectral_helper`.

    Unlike the original helper function, `y` can be None for explicitly
    indicating auto-spectral (non cross-spectral) computation. In addition to
    this, `detrend` argument is renamed to `detrend_type` for avoiding internal
    name overlap.

    Args:
      x: Array-like input. Its `ndim` and `shape` attributes are read before it
        is validated and converted with `jnp.asarray`.
      y: Optional second array-like input for cross-spectral computation, or
        None for the auto-spectral case. Only allowed when `mode == 'psd'`.
      fs: Sampling frequency. `1 / fs` is used as the sample spacing of the
        frequency grid, the time vector is divided by `fs`, and `fs` enters the
        'density' scale factor.
      window: Window specification, forwarded together with `nperseg` to
        `signal_helper._triage_segments`, which returns the window array and
        the segment length actually used.
      nperseg: Segment length. If given it must be a concrete integer >= 1.
      noverlap: Number of overlapping samples between segments. Defaults to
        `nperseg // 2`; must be smaller than `nperseg`.
      nfft: FFT length. Defaults to `nperseg`; must not be smaller than it.
      detrend_type: Falsy disables detrending; a non-callable value is passed
        as `type=` to `detrend`; a callable is applied to the segmented data.
      return_onesided: Request a one-sided result. Complex input switches to a
        two-sided result and emits a warning.
      scaling: Either 'density' or 'spectrum'.
      axis: Axis of `x` (and `y`) holding the samples. Must be concrete.
      mode: Either 'psd' or 'stft'.
      boundary: One of 'even', 'odd', 'constant', 'zeros' or None. When not
        None, the inputs are extended by `nperseg // 2` samples on both ends.
      padded: If true, zero-pad the end of the inputs so that an integer number
        of segments fits.

    Returns:
      A tuple `(freqs, time, result)`: the frequency grid, the segment times,
      and the spectral result with the frequency axis moved back to `axis`.
      For empty input three zero arrays are returned instead.

    Raises:
      ValueError: for an unknown `mode`, `boundary` or `scaling`; a
        non-positive `nperseg`; `nfft < nperseg`; `noverlap >= nperseg`; a
        second input with `mode != 'psd'`; inputs of different rank; or inputs
        whose non-sample axes cannot be broadcast together.
    """
    if mode not in ('psd', 'stft'):
        raise ValueError(f"Unknown value for mode {mode}, "
                         "must be one of: ('psd', 'stft')")

    def make_pad(mode, **kwargs):
        # Build an extension function that pads `n` samples on both ends of
        # `axis` using the given `jnp.pad` mode.
        def pad(x, n, axis=-1):
            pad_width = [(0, 0)] * x.ndim
            pad_width[axis] = (n, n)
            return jnp.pad(x, pad_width, mode, **kwargs)
        return pad

    boundary_funcs = {
        'even': make_pad('reflect'),
        'odd': odd_ext,
        'constant': make_pad('edge'),
        'zeros': make_pad('constant', constant_values=0.0),
        None: lambda x, *args, **kwargs: x
    }

    # Check/ normalize inputs
    if boundary not in boundary_funcs:
        raise ValueError(f"Unknown boundary option '{boundary}', "
                         f"must be one of: {list(boundary_funcs.keys())}")

    axis = jax.core.concrete_or_error(operator.index, axis,
                                      "axis of windowed-FFT")
    axis = canonicalize_axis(axis, x.ndim)

    if nperseg is not None:  # if specified by user
        nperseg = jax.core.concrete_or_error(int, nperseg,
                                             "nperseg of windowed-FFT")
        if nperseg < 1:
            raise ValueError('nperseg must be a positive integer')

    # parse window; if array like, then set nperseg = win.shape
    win, nperseg = signal_helper._triage_segments(
        window, nperseg, input_length=x.shape[axis])

    if noverlap is None:
        noverlap = nperseg // 2
    else:
        noverlap = jax.core.concrete_or_error(int, noverlap,
                                              "noverlap of windowed-FFT")

    if nfft is None:
        nfft = nperseg
    else:
        nfft = jax.core.concrete_or_error(int, nfft, "nfft of windowed-FFT")

    _check_arraylike("_spectral_helper", x)
    x = jnp.asarray(x)

    if y is None:
        outdtype = jax.dtypes.canonicalize_dtype(
            np.result_type(x, np.complex64))
    else:
        _check_arraylike("_spectral_helper", y)
        y = jnp.asarray(y)
        outdtype = jax.dtypes.canonicalize_dtype(
            np.result_type(x, y, np.complex64))
        if mode != 'psd':
            raise ValueError(
                "two-argument mode is available only when mode=='psd'")
        if x.ndim != y.ndim:
            # f-string so that the actual ranks appear in the message.
            raise ValueError(
                f"two-arguments must have the same rank ({x.ndim} vs {y.ndim})."
            )

        # Check if we can broadcast the outer axes together
        try:
            outershape = jnp.broadcast_shapes(tuple_delete(x.shape, axis),
                                              tuple_delete(y.shape, axis))
        except ValueError as e:
            raise ValueError('x and y cannot be broadcast together.') from e

    # Special cases for size == 0
    if y is None:
        if x.size == 0:
            return jnp.zeros(x.shape), jnp.zeros(x.shape), jnp.zeros(x.shape)
    elif x.size == 0 or y.size == 0:
        outshape = tuple_insert(outershape,
                                min([x.shape[axis], y.shape[axis]]), axis)
        emptyout = jnp.zeros(outshape)
        return emptyout, emptyout, emptyout

    # Move time-axis to the end
    # NOTE(review): if canonicalize_axis always returns a non-negative index,
    # `axis != -1` is always true here and below -- confirm.
    if x.ndim > 1 and axis != -1:
        x = jnp.moveaxis(x, axis, -1)
        if y is not None and y.ndim > 1:
            y = jnp.moveaxis(y, axis, -1)

    # Check if x and y are the same length, zero-pad if necessary
    if y is not None and x.shape[-1] != y.shape[-1]:
        if x.shape[-1] < y.shape[-1]:
            pad_shape = list(x.shape)
            pad_shape[-1] = y.shape[-1] - x.shape[-1]
            x = jnp.concatenate((x, jnp.zeros(pad_shape)), -1)
        else:
            pad_shape = list(y.shape)
            pad_shape[-1] = x.shape[-1] - y.shape[-1]
            y = jnp.concatenate((y, jnp.zeros(pad_shape)), -1)

    if nfft < nperseg:
        raise ValueError('nfft must be greater than or equal to nperseg.')
    if noverlap >= nperseg:
        raise ValueError('noverlap must be less than nperseg.')
    nstep = nperseg - noverlap

    # Apply paddings
    if boundary is not None:
        ext_func = boundary_funcs[boundary]
        x = ext_func(x, nperseg // 2, axis=-1)
        if y is not None:
            y = ext_func(y, nperseg // 2, axis=-1)

    if padded:
        # Pad to integer number of windowed segments
        # I.e make x.shape[-1] = nperseg + (nseg-1)*nstep, with integer nseg
        nadd = (-(x.shape[-1] - nperseg) % nstep) % nperseg
        zeros_shape = list(x.shape[:-1]) + [nadd]
        x = jnp.concatenate((x, jnp.zeros(zeros_shape)), axis=-1)
        if y is not None:
            zeros_shape = list(y.shape[:-1]) + [nadd]
            y = jnp.concatenate((y, jnp.zeros(zeros_shape)), axis=-1)

    # Handle detrending and window functions
    if not detrend_type:
        def detrend_func(d):
            return d
    elif not callable(detrend_type):
        def detrend_func(d):
            return detrend(d, type=detrend_type, axis=-1)
    elif axis != -1:
        # Wrap this function so that it receives a shape that it could
        # reasonably expect to receive.
        # NOTE(review): SciPy's helper moves the last axis to `axis` before
        # calling the user function and back afterwards; the order here is
        # the reverse -- confirm which is intended.
        def detrend_func(d):
            d = jnp.moveaxis(d, axis, -1)
            d = detrend_type(d)
            return jnp.moveaxis(d, -1, axis)
    else:
        detrend_func = detrend_type

    if np.result_type(win, np.complex64) != outdtype:
        win = win.astype(outdtype)

    # Determine scale
    if scaling == 'density':
        scale = 1.0 / (fs * (win * win).sum())
    elif scaling == 'spectrum':
        scale = 1.0 / win.sum()**2
    else:
        raise ValueError(f'Unknown scaling: {scaling}')
    if mode == 'stft':
        scale = jnp.sqrt(scale)

    # Determine onesided/ two-sided
    if return_onesided:
        sides = 'onesided'
        if jnp.iscomplexobj(x) or jnp.iscomplexobj(y):
            sides = 'twosided'
            warnings.warn('Input data is complex, switching to '
                          'return_onesided=False')
    else:
        sides = 'twosided'

    # `sides` is always one of the two values assigned above.
    if sides == 'twosided':
        freqs = jax.numpy.fft.fftfreq(nfft, 1 / fs)
    else:
        freqs = jax.numpy.fft.rfftfreq(nfft, 1 / fs)

    # Perform the windowed FFTs
    result = _fft_helper(x, win, detrend_func, nperseg, noverlap, nfft, sides)

    if y is not None:
        # All the same operations on the y data
        result_y = _fft_helper(y, win, detrend_func, nperseg, noverlap, nfft,
                               sides)
        result = jnp.conjugate(result) * result_y
    elif mode == 'psd':
        result = jnp.conjugate(result) * result

    result *= scale

    if sides == 'onesided' and mode == 'psd':
        # Double every bin except the first and, for even nfft, the last one.
        end = None if nfft % 2 else -1
        result = result.at[..., 1:end].mul(2)

    time = jnp.arange(nperseg / 2, x.shape[-1] - nperseg / 2 + 1, nstep) / fs
    if boundary is not None:
        # Compensate for the nperseg // 2 samples prepended by the extension.
        time -= (nperseg / 2) / fs

    result = result.astype(outdtype)

    # All imaginary parts are zero anyways
    if y is None and mode != 'stft':
        result = result.real

    # Move frequency axis back to axis where the data came from
    result = jnp.moveaxis(result, -1, axis)

    return freqs, time, result