def _unique_sorted_mask(ar, axis): aux = moveaxis(ar, axis, 0) if np.issubdtype(aux.dtype, np.complexfloating): # Work around issue in sorting of complex numbers with Nan only in the # imaginary component. This can be removed if sorting in this situation # is fixed to match numpy. aux = where(isnan(aux), _lax_const(aux, np.nan), aux) size, *out_shape = aux.shape if _prod(out_shape) == 0: size = 1 perm = zeros(1, dtype=int) else: perm = lexsort(aux.reshape(size, _prod(out_shape)).T[::-1]) aux = aux[perm] if aux.size: if dtypes.issubdtype(aux.dtype, np.inexact): # This is appropriate for both float and complex due to the documented behavior of np.unique: # See https://github.com/numpy/numpy/blob/v1.22.0/numpy/lib/arraysetops.py#L212-L220 neq = lambda x, y: lax.ne(x, y) & ~(isnan(x) & isnan(y)) else: neq = lax.ne mask = ones(size, dtype=bool).at[1:].set( any(neq(aux[1:], aux[:-1]), tuple(range(1, aux.ndim)))) else: mask = zeros(size, dtype=bool) return aux, mask, perm
def roots(p, *, strip_zeros=True): # ported from https://github.com/numpy/numpy/blob/v1.17.0/numpy/lib/polynomial.py#L168-L251 p = atleast_1d(p) if p.ndim != 1: raise ValueError("Input must be a rank-1 array.") # strip_zeros=False is unsafe because leading zeros aren't removed if not strip_zeros: if p.size > 1: return _roots_no_zeros(p) else: return array([]) if all(p == 0): return array([]) # factor out trivial roots start, end = _nonzero_range(p) # number of trailing zeros = number of roots at 0 trailing_zeros = p.size - end # strip leading and trailing zeros p = p[start:end] if p.size < 2: return zeros(trailing_zeros, p.dtype) else: roots = _roots_no_zeros(p) # combine roots and zero roots roots = hstack((roots, zeros(trailing_zeros, p.dtype))) return roots
def _lu_blocked(a, block_size=128): """Blocked LU decomposition, as an unrolled loop.""" m, n = a.shape r = min(m, n) pivot = jnp.zeros((r, ), dtype=jnp.int32) perm = jnp.arange(m, dtype=jnp.int32) for k in range(0, r, block_size): b = min(r - k, block_size) block_pivot, block_perm, lu_block = _lu_unblocked(a[k:, k:k + b]) pivot = ops.index_update(pivot, ops.index[k:k + b], block_pivot + k) perm = ops.index_update(perm, ops.index[k:], perm[block_perm + k]) a = ops.index_update(a, ops.index[k:, :], a[block_perm + k, :]) a = ops.index_update(a, ops.index[k:, k:k + b], lu_block) if k + b < n: a = ops.index_update( a, ops.index[k:k + b, k + b:], triangular_solve(a[k:k + b, k:k + b], a[k:k + b, k + b:], left_side=True, lower=True, unit_diagonal=True)) a = ops.index_add( a, ops.index[k + b:, k + b:], -lax.dot(a[k + b:, k:k + b], a[k:k + b, k + b:], precision=lax.Precision.HIGHEST)) return a, pivot, perm
def fftfreq(n, d=1.0): if isinstance(n, (list, tuple)): raise ValueError( "The n argument of jax.numpy.fft.fftfreq only takes an int. " "Got n = %s." % list(n)) elif isinstance(d, (list, tuple)): raise ValueError( "The d argument of jax.numpy.fft.fftfreq only takes a single value. " "Got d = %s." % list(d)) k = jnp.zeros(n) if n % 2 == 0: # k[0: n // 2 - 1] = jnp.arange(0, n // 2 - 1) k = k.at[0: n // 2].set( jnp.arange(0, n // 2)) # k[n // 2:] = jnp.arange(-n // 2, -1) k = k.at[n // 2:].set( jnp.arange(-n // 2, 0)) else: # k[0: (n - 1) // 2] = jnp.arange(0, (n - 1) // 2) k = k.at[0: (n - 1) // 2 + 1].set(jnp.arange(0, (n - 1) // 2 + 1)) # k[(n - 1) // 2 + 1:] = jnp.arange(-(n - 1) // 2, -1) k = k.at[(n - 1) // 2 + 1:].set(jnp.arange(-(n - 1) // 2, 0)) return k / (d * n)
def segment_sum(data, segment_ids, num_segments=None, indices_are_sorted=False, unique_indices=False, bucket_size=None): # TODO(zhangqiaorjc): use non-None default. """Computes the sum within segments of an array. Similar to TensorFlow's segment_sum: https://www.tensorflow.org/api_docs/python/tf/math/segment_sum Args: data: an array with the values to be summed. segment_ids: an array with integer dtype that indicates the segments of `data` (along its leading axis) to be summed. Values can be repeated and need not be sorted. Values outside of the range [0, num_segments) are dropped and do not contribute to the sum. num_segments: optional, an int with nonnegative value indicating the number of segments. The default is set to be the minimum number of segments that would support all indices in ``segment_ids``, calculated as ``max(segment_ids) + 1``. Since `num_segments` determines the size of the output, a static value must be provided to use ``segment_sum`` in a ``jit``-compiled function. indices_are_sorted: whether ``segment_ids`` is known to be sorted. unique_indices: whether `segment_ids` is known to be free of duplicates. bucket_size: size of bucket to group indices into. ``segment_sum`` is performed on each bucket separately to improve numerical stability of addition. Default ``None`` means no bucketing. Returns: An array with shape :code:`(num_segments,) + data.shape[1:]` representing the segment sums. """ if num_segments is None: num_segments = jnp.max(segment_ids) + 1 num_segments = int(num_segments) out = jnp.zeros((num_segments, ) + data.shape[1:], dtype=data.dtype) num_buckets = 1 if bucket_size is None \ else util.ceil_of_ratio(segment_ids.size, bucket_size) if num_buckets == 1: return _scatter_update(out, segment_ids, data, lax.scatter_add, indices_are_sorted, unique_indices, normalize_indices=False) # Bucketize indices and perform segment_sum on each bucket to improve # numerical stability. outs = [] for sub_data, sub_segment_ids in zip( jnp.array_split(data, num_buckets), jnp.array_split(segment_ids, num_buckets)): outs.append( segment_sum(sub_data, sub_segment_ids, num_segments, indices_are_sorted, unique_indices)) return jnp.sum(jnp.stack(outs), axis=0)
def _sturm(alpha, beta_sq, pivmin, alpha0_perturbation, x): """Implements the Sturm sequence recurrence.""" n = alpha.shape[0] zeros = jnp.zeros(x.shape, dtype=jnp.int32) ones = jnp.ones(x.shape, dtype=jnp.int32) # The first step in the Sturm sequence recurrence # requires special care if x is equal to alpha[0]. def sturm_step0(): q = alpha[0] - x count = jnp.where(q < 0, ones, zeros) q = jnp.where(alpha[0] == x, alpha0_perturbation, q) return q, count # Subsequent steps all take this form: def sturm_step(i, q, count): q = alpha[i] - beta_sq[i - 1] / q - x count = jnp.where(q <= pivmin, count + 1, count) q = jnp.where(q <= pivmin, jnp.minimum(q, -pivmin), q) return q, count # The first step initializes q and count. q, count = sturm_step0() # Peel off ((n-1) % blocksize) steps from the main loop, so we can run # the bulk of the iterations unrolled by a factor of blocksize. blocksize = 16 i = 1 peel = (n - 1) % blocksize unroll_cnt = peel def unrolled_steps(args): start, q, count = args for j in range(unroll_cnt): q, count = sturm_step(start + j, q, count) return start + unroll_cnt, q, count i, q, count = unrolled_steps((i, q, count)) # Run the remaining steps of the Sturm sequence using a partially # unrolled while loop. unroll_cnt = blocksize def cond(iqc): i, q, count = iqc return jnp.less(i, n) _, _, count = lax.while_loop(cond, unrolled_steps, (i, q, count)) return count
def polydiv(u, v, *, trim_leading_zeros=False): _check_arraylike("polydiv", u, v) u, v = _promote_dtypes_inexact(u, v) m = len(u) - 1 n = len(v) - 1 scale = 1. / v[0] q = zeros(max(m - n + 1, 1), dtype=u.dtype) # force same dtype for k in range(0, m - n + 1): d = scale * u[k] q = q.at[k].set(d) u = u.at[k:k + n + 1].add(-d * v) if trim_leading_zeros: # use the square root of finfo(dtype) to approximate the absolute tolerance used in numpy return q, trim_zeros_tol(u, tol=sqrt(finfo(u.dtype).eps), trim='f') else: return q, u
def _gen_recurrence_mask( l_max: int, is_normalized: bool = True) -> Tuple[jnp.ndarray, jnp.ndarray]: """Generates mask for recurrence relation on the remaining entries. The remaining entries are with respect to the diagonal and offdiagonal entries. Args: l_max: see `gen_normalized_legendre`. is_normalized: True if the recurrence mask is used by normalized associated Legendre functions. Returns: Arrays representing the mask used by the recurrence relations. """ # Computes all coefficients. m_mat, l_mat = jnp.mgrid[:l_max + 1, :l_max + 1] if is_normalized: c0 = l_mat * l_mat c1 = m_mat * m_mat c2 = 2.0 * l_mat c3 = (l_mat - 1.0) * (l_mat - 1.0) d0 = jnp.sqrt((4.0 * c0 - 1.0) / (c0 - c1)) d1 = jnp.sqrt(((c2 + 1.0) * (c3 - c1)) / ((c2 - 3.0) * (c0 - c1))) else: d0 = (2.0 * l_mat - 1.0) / (l_mat - m_mat) d1 = (l_mat + m_mat - 1.0) / (l_mat - m_mat) d0_mask_indices = jnp.triu_indices(l_max + 1, 1) d1_mask_indices = jnp.triu_indices(l_max + 1, 2) d_zeros = jnp.zeros((l_max + 1, l_max + 1)) d0_mask = d_zeros.at[d0_mask_indices].set(d0[d0_mask_indices]) d1_mask = d_zeros.at[d1_mask_indices].set(d1[d1_mask_indices]) # Creates a 3D mask that contains 1s on the diagonal plane and 0s elsewhere. # i = jnp.arange(l_max + 1)[:, None, None] # j = jnp.arange(l_max + 1)[None, :, None] # k = jnp.arange(l_max + 1)[None, None, :] i, j, k = jnp.ogrid[:l_max + 1, :l_max + 1, :l_max + 1] mask = 1.0 * (i + j - k == 0) d0_mask_3d = jnp.einsum('jk,ijk->ijk', d0_mask, mask) d1_mask_3d = jnp.einsum('jk,ijk->ijk', d1_mask, mask) return (d0_mask_3d, d1_mask_3d)
def _find_indices(self, xi): # find relevant edges between which xi are situated indices = [] # compute distance to lower edge in unity units norm_distances = [] # check for out of bounds xi out_of_bounds = zeros((xi.shape[1], ), dtype=bool) # iterate through dimensions for x, g in zip(xi, self.grid): i = searchsorted(g, x) - 1 i = where(i < 0, 0, i) i = where(i > g.size - 2, g.size - 2, i) indices.append(i) norm_distances.append((x - g[i]) / (g[i + 1] - g[i])) if not self.bounds_error: out_of_bounds += x < g[0] out_of_bounds += x > g[-1] return indices, norm_distances, out_of_bounds
def block_diag(*arrs): if len(arrs) == 0: arrs = [jnp.zeros((1, 0))] arrs = jnp._promote_dtypes(*arrs) bad_shapes = [i for i, a in enumerate(arrs) if jnp.ndim(a) > 2] if bad_shapes: raise ValueError("Arguments to jax.scipy.linalg.block_diag must have at " "most 2 dimensions, got {} at argument {}." .format(arrs[bad_shapes[0]], bad_shapes[0])) arrs = [jnp.atleast_2d(a) for a in arrs] acc = arrs[0] dtype = lax.dtype(acc) for a in arrs[1:]: _, c = a.shape a = lax.pad(a, dtype.type(0), ((0, 0, 0), (acc.shape[-1], 0, 0))) acc = lax.pad(acc, dtype.type(0), ((0, 0, 0), (0, c, 0))) acc = lax.concatenate([acc, a], dimension=0) return acc
def _lu_unblocked(a): """Unblocked LU decomposition, as a rolled loop.""" m, n = a.shape def body(k, state): pivot, perm, a = state m_idx = jnp.arange(m) n_idx = jnp.arange(n) if jnp.issubdtype(a.dtype, jnp.complexfloating): t = a[:, k] magnitude = jnp.abs(jnp.real(t)) + jnp.abs(jnp.imag(t)) else: magnitude = jnp.abs(a[:, k]) i = jnp.argmax(jnp.where(m_idx >= k, magnitude, -jnp.inf)) pivot = ops.index_update(pivot, ops.index[k], i) a = ops.index_update(a, ops.index[[k, i], ], a[[i, k], ]) perm = ops.index_update(perm, ops.index[[i, k], ], perm[[k, i], ]) # a[k+1:, k] /= a[k, k], adapted for loop-invariant shapes x = a[k, k] a = ops.index_update(a, ops.index[:, k], jnp.where(m_idx > k, a[:, k] / x, a[:, k])) # a[k+1:, k+1:] -= jnp.outer(a[k+1:, k], a[k, k+1:]) a = a - jnp.where( (m_idx[:, None] > k) & (n_idx > k), jnp.outer(a[:, k], a[k, :]), jnp.array(0, dtype=a.dtype)) return pivot, perm, a pivot = jnp.zeros((min(m, n), ), dtype=jnp.int32) perm = jnp.arange(m, dtype=jnp.int32) if m == 0 and n == 0: # If the array is empty, the loop body never executes but tracing it to a # jaxpr fails because the indexing cannot succeed. return (pivot, perm, a) return lax.fori_loop(0, min(m, n), body, (pivot, perm, a))
def _spectral_helper(x, y, fs=1.0, window='hann', nperseg=None, noverlap=None, nfft=None, detrend_type='constant', return_onesided=True, scaling='density', axis=-1, mode='psd', boundary=None, padded=False): """LAX-backend implementation of `scipy.signal._spectral_helper`. Unlike the original helper function, `y` can be None for explicitly indicating auto-spectral (non cross-spectral) computation. In addition to this, `detrend` argument is renamed to `detrend_type` for avoiding internal name overlap. """ if mode not in ('psd', 'stft'): raise ValueError(f"Unknown value for mode {mode}, " "must be one of: ('psd', 'stft')") def make_pad(mode, **kwargs): def pad(x, n, axis=-1): pad_width = [(0, 0) for unused_n in range(x.ndim)] pad_width[axis] = (n, n) return jnp.pad(x, pad_width, mode, **kwargs) return pad boundary_funcs = { 'even': make_pad('reflect'), 'odd': odd_ext, 'constant': make_pad('edge'), 'zeros': make_pad('constant', constant_values=0.0), None: lambda x, *args, **kwargs: x } # Check/ normalize inputs if boundary not in boundary_funcs: raise ValueError(f"Unknown boundary option '{boundary}', " f"must be one of: {list(boundary_funcs.keys())}") axis = jax.core.concrete_or_error(operator.index, axis, "axis of windowed-FFT") axis = canonicalize_axis(axis, x.ndim) if nperseg is not None: # if specified by user nperseg = jax.core.concrete_or_error(int, nperseg, "nperseg of windowed-FFT") if nperseg < 1: raise ValueError('nperseg must be a positive integer') # parse window; if array like, then set nperseg = win.shape win, nperseg = signal_helper._triage_segments(window, nperseg, input_length=x.shape[axis]) if noverlap is None: noverlap = nperseg // 2 else: noverlap = jax.core.concrete_or_error(int, noverlap, "noverlap of windowed-FFT") if nfft is None: nfft = nperseg else: nfft = jax.core.concrete_or_error(int, nfft, "nfft of windowed-FFT") _check_arraylike("_spectral_helper", x) x = jnp.asarray(x) if y is None: outdtype = jax.dtypes.canonicalize_dtype( np.result_type(x, np.complex64)) else: _check_arraylike("_spectral_helper", y) y = jnp.asarray(y) outdtype = jax.dtypes.canonicalize_dtype( np.result_type(x, y, np.complex64)) if mode != 'psd': raise ValueError( "two-argument mode is available only when mode=='psd'") if x.ndim != y.ndim: raise ValueError( "two-arguments must have the same rank ({x.ndim} vs {y.ndim})." ) # Check if we can broadcast the outer axes together try: outershape = jnp.broadcast_shapes(tuple_delete(x.shape, axis), tuple_delete(y.shape, axis)) except ValueError as e: raise ValueError('x and y cannot be broadcast together.') from e # Special cases for size == 0 if y is None: if x.size == 0: return jnp.zeros(x.shape), jnp.zeros(x.shape), jnp.zeros(x.shape) else: if x.size == 0 or y.size == 0: outshape = tuple_insert(outershape, min([x.shape[axis], y.shape[axis]]), axis) emptyout = jnp.zeros(outshape) return emptyout, emptyout, emptyout # Move time-axis to the end if x.ndim > 1: if axis != -1: x = jnp.moveaxis(x, axis, -1) if y is not None and y.ndim > 1: y = jnp.moveaxis(y, axis, -1) # Check if x and y are the same length, zero-pad if necessary if y is not None: if x.shape[-1] != y.shape[-1]: if x.shape[-1] < y.shape[-1]: pad_shape = list(x.shape) pad_shape[-1] = y.shape[-1] - x.shape[-1] x = jnp.concatenate((x, jnp.zeros(pad_shape)), -1) else: pad_shape = list(y.shape) pad_shape[-1] = x.shape[-1] - y.shape[-1] y = jnp.concatenate((y, jnp.zeros(pad_shape)), -1) if nfft < nperseg: raise ValueError('nfft must be greater than or equal to nperseg.') if noverlap >= nperseg: raise ValueError('noverlap must be less than nperseg.') nstep = nperseg - noverlap # Apply paddings if boundary is not None: ext_func = boundary_funcs[boundary] x = ext_func(x, nperseg // 2, axis=-1) if y is not None: y = ext_func(y, nperseg // 2, axis=-1) if padded: # Pad to integer number of windowed segments # I.e make x.shape[-1] = nperseg + (nseg-1)*nstep, with integer nseg nadd = (-(x.shape[-1] - nperseg) % nstep) % nperseg zeros_shape = list(x.shape[:-1]) + [nadd] x = jnp.concatenate((x, jnp.zeros(zeros_shape)), axis=-1) if y is not None: zeros_shape = list(y.shape[:-1]) + [nadd] y = jnp.concatenate((y, jnp.zeros(zeros_shape)), axis=-1) # Handle detrending and window functions if not detrend_type: def detrend_func(d): return d elif not hasattr(detrend_type, '__call__'): def detrend_func(d): return detrend(d, type=detrend_type, axis=-1) elif axis != -1: # Wrap this function so that it receives a shape that it could # reasonably expect to receive. def detrend_func(d): d = jnp.moveaxis(d, axis, -1) d = detrend_type(d) return jnp.moveaxis(d, -1, axis) else: detrend_func = detrend_type if np.result_type(win, np.complex64) != outdtype: win = win.astype(outdtype) # Determine scale if scaling == 'density': scale = 1.0 / (fs * (win * win).sum()) elif scaling == 'spectrum': scale = 1.0 / win.sum()**2 else: raise ValueError(f'Unknown scaling: {scaling}') if mode == 'stft': scale = jnp.sqrt(scale) # Determine onesided/ two-sided if return_onesided: sides = 'onesided' if jnp.iscomplexobj(x) or jnp.iscomplexobj(y): sides = 'twosided' warnings.warn('Input data is complex, switching to ' 'return_onesided=False') else: sides = 'twosided' if sides == 'twosided': freqs = jax.numpy.fft.fftfreq(nfft, 1 / fs) elif sides == 'onesided': freqs = jax.numpy.fft.rfftfreq(nfft, 1 / fs) # Perform the windowed FFTs result = _fft_helper(x, win, detrend_func, nperseg, noverlap, nfft, sides) if y is not None: # All the same operations on the y data result_y = _fft_helper(y, win, detrend_func, nperseg, noverlap, nfft, sides) result = jnp.conjugate(result) * result_y elif mode == 'psd': result = jnp.conjugate(result) * result result *= scale if sides == 'onesided' and mode == 'psd': end = None if nfft % 2 else -1 result = result.at[..., 1:end].mul(2) time = jnp.arange(nperseg / 2, x.shape[-1] - nperseg / 2 + 1, nperseg - noverlap) / fs if boundary is not None: time -= (nperseg / 2) / fs result = result.astype(outdtype) # All imaginary parts are zero anyways if y is None and mode != 'stft': result = result.real # Move frequency axis back to axis where the data came from result = jnp.moveaxis(result, -1, axis) return freqs, time, result
def _gen_associated_legendre(l_max: int, x: jnp.ndarray, is_normalized: bool) -> jnp.ndarray: r"""Computes associated Legendre functions (ALFs) of the first kind. The ALFs of the first kind are used in spherical harmonics. The spherical harmonic of degree `l` and order `m` can be written as `Y_l^m(θ, φ) = N_l^m * P_l^m(cos(θ)) * exp(i m φ)`, where `N_l^m` is the normalization factor and θ and φ are the colatitude and longitude, repectively. `N_l^m` is chosen in the way that the spherical harmonics form a set of orthonormal basis function of L^2(S^2). For the computational efficiency of spherical harmonics transform, the normalization factor is used in the computation of the ALFs. In addition, normalizing `P_l^m` avoids overflow/underflow and achieves better numerical stability. Three recurrence relations are used in the computation. Args: l_max: The maximum degree of the associated Legendre function. Both the degrees and orders are `[0, 1, 2, ..., l_max]`. x: A vector of type `float32`, `float64` containing the sampled points in spherical coordinates, at which the ALFs are computed; `x` is essentially `cos(θ)`. For the numerical integration used by the spherical harmonics transforms, `x` contains the quadrature points in the interval of `[-1, 1]`. There are several approaches to provide the quadrature points: Gauss-Legendre method (`scipy.special.roots_legendre`), Gauss-Chebyshev method (`scipy.special.roots_chebyu`), and Driscoll & Healy method (Driscoll, James R., and Dennis M. Healy. "Computing Fourier transforms and convolutions on the 2-sphere." Advances in applied mathematics 15, no. 2 (1994): 202-250.). The Gauss-Legendre quadrature points are nearly equal-spaced along θ and provide exact discrete orthogonality, (P^m)^T W P_m = I, where `T` represents the transpose operation, `W` is a diagonal matrix containing the quadrature weights, and `I` is the identity matrix. The Gauss-Chebyshev points are equally spaced, which only provide approximate discrete orthogonality. The Driscoll & Healy qudarture points are equally spaced and provide the exact discrete orthogonality. The number of sampling points is required to be twice as the number of frequency points (modes) in the Driscoll & Healy approach, which enables FFT and achieves a fast spherical harmonics transform. is_normalized: True if the associated Legendre functions are normalized. With normalization, `N_l^m` is applied such that the spherical harmonics form a set of orthonormal basis functions of L^2(S^2). Returns: The 3D array of shape `(l_max + 1, l_max + 1, len(x))` containing the values of the ALFs at `x`; the dimensions in the sequence of order, degree, and evalution points. """ p = jnp.zeros((l_max + 1, l_max + 1, x.shape[0])) a_idx = jnp.arange(1, l_max + 1) b_idx = jnp.arange(l_max) if is_normalized: initial_value = 0.5 / jnp.sqrt(jnp.pi) # The initial value p(0,0). f_a = jnp.cumprod(-1 * jnp.sqrt(1.0 + 0.5 / a_idx)) f_b = jnp.sqrt(2.0 * b_idx + 3.0) else: initial_value = 1.0 # The initial value p(0,0). f_a = jnp.cumprod(1.0 - 2.0 * a_idx) f_b = 2.0 * b_idx + 1.0 p = p.at[(0, 0)].set(initial_value) # Compute the diagonal entries p(l,l) with recurrence. y = jnp.cumprod(jnp.broadcast_to(jnp.sqrt(1.0 - x * x), (l_max, x.shape[0])), axis=0) p_diag = initial_value * jnp.einsum('i,ij->ij', f_a, y) diag_indices = jnp.diag_indices(l_max + 1) p = p.at[(diag_indices[0][1:], diag_indices[1][1:])].set(p_diag) # Compute the off-diagonal entries with recurrence. p_offdiag = jnp.einsum('ij,ij->ij', jnp.einsum('i,j->ij', f_b, x), p[jnp.diag_indices(l_max)]) offdiag_indices = (diag_indices[0][:l_max], diag_indices[1][:l_max] + 1) p = p.at[offdiag_indices].set(p_offdiag) # Compute the remaining entries with recurrence. d0_mask_3d, d1_mask_3d = _gen_recurrence_mask(l_max, is_normalized=is_normalized) def body_fun(i, p_val): coeff_0 = d0_mask_3d[i] coeff_1 = d1_mask_3d[i] h = (jnp.einsum( 'ij,ijk->ijk', coeff_0, jnp.einsum('ijk,k->ijk', jnp.roll(p_val, shift=1, axis=1), x)) - jnp.einsum('ij,ijk->ijk', coeff_1, jnp.roll( p_val, shift=2, axis=1))) p_val = p_val + h return p_val if l_max > 1: p = lax.fori_loop(lower=2, upper=l_max + 1, body_fun=body_fun, init_val=p) return p
def _gen_derivatives(p: jnp.ndarray, x: jnp.ndarray, is_normalized: bool) -> jnp.ndarray: """Generates derivatives of associated Legendre functions of the first kind. Args: p: The 3D array containing the values of associated Legendre functions; the dimensions are in the sequence of order (m), degree (l), and evalution points. x: A vector of type `float32` or `float64` containing the sampled points. is_normalized: True if the associated Legendre functions are normalized. Returns: The 3D array representing the derivatives of associated Legendre functions of the first kind. """ num_m, num_l, num_x = p.shape # p_{l-1}^m. p_m_lm1 = jnp.pad(p, ((0, 0), (1, 0), (0, 0)))[:, :num_l, :] # p_{l-1}^{m+2}. p_mp2_lm1 = jnp.pad(p_m_lm1, ((0, 2), (0, 0), (0, 0)))[2:num_m + 2, :, :] # p_{l-1}^{m-2}. p_mm2_lm1 = jnp.pad(p_m_lm1, ((2, 0), (0, 0), (0, 0)))[:num_m, :, :] # Derivative computation requires negative orders. if is_normalized: raise NotImplementedError( 'Negative orders for normalization is not implemented yet.') else: if num_l > 1: l_vec = jnp.arange(1, num_l - 1) p_p1 = p[1, 1:num_l - 1, :] coeff = -1.0 / ((l_vec + 1) * l_vec) update_p_p1 = jnp.einsum('i,ij->ij', coeff, p_p1) p_mm2_lm1 = p_mm2_lm1.at[ops.index[1, 2:num_l, :]].set(update_p_p1) if num_l > 2: l_vec = jnp.arange(2, num_l - 1) p_p2 = p[2, 2:num_l - 1, :] coeff = 1.0 / ((l_vec + 2) * (l_vec + 1) * l_vec) update_p_p2 = jnp.einsum('i,ij->ij', coeff, p_p2) p_mm2_lm1 = p_mm2_lm1.at[ops.index[0, 3:num_l, :]].set(update_p_p2) m_mat, l_mat = jnp.mgrid[:num_m, :num_l] coeff_zeros = jnp.zeros((num_m, num_l)) upper_0_indices = jnp.triu_indices(num_m, 0, num_l) zero_vec = jnp.zeros((num_l, )) a0 = -0.5 / (m_mat - 1.0) a0_masked = coeff_zeros.at[upper_0_indices].set(a0[upper_0_indices]) a0_masked = a0_masked.at[1, :].set(zero_vec) b0 = l_mat + m_mat c0 = a0 * (b0 - 2.0) * (b0 - 1.0) c0_masked = coeff_zeros.at[upper_0_indices].set(c0[upper_0_indices]) c0_masked = c0_masked.at[1, :].set(zero_vec) # p_l^{m-1}. p_mm1_l = (jnp.einsum('ij,ijk->ijk', a0_masked, p_m_lm1) + jnp.einsum('ij,ijk->ijk', c0_masked, p_mm2_lm1)) d0 = -0.5 / (m_mat + 1.0) d0_masked = coeff_zeros.at[upper_0_indices].set(d0[upper_0_indices]) e0 = d0 * b0 * (b0 + 1.0) e0_masked = coeff_zeros.at[upper_0_indices].set(e0[upper_0_indices]) # p_l^{m+1}. p_mp1_l = (jnp.einsum('ij,ijk->ijk', d0_masked, p_mp2_lm1) + jnp.einsum('ij,ijk->ijk', e0_masked, p_m_lm1)) f0 = b0 * (l_mat - m_mat + 1.0) / 2.0 f0_masked = coeff_zeros.at[upper_0_indices].set(f0[upper_0_indices]) p_derivative = jnp.einsum('ij,ijk->ijk', f0_masked, p_mm1_l) - 0.5 * p_mp1_l # Special treatment of the singularity at m = 1. if num_m > 1: l_vec = jnp.arange(num_l) g0 = jnp.einsum('i,ij->ij', (l_vec + 1) * l_vec, p[0, :, :]) if num_l > 2: g0 = g0 - p[2, :, :] p_derivative_m0 = jnp.einsum('j,ij->ij', 0.5 / jnp.sqrt(1 - x * x), g0) p_derivative = p_derivative.at[1, :, :].set(p_derivative_m0) p_derivative = p_derivative.at[1, 0, :].set(jnp.zeros((num_x, ))) return p_derivative
def _unique(ar, axis, return_index=False, return_inverse=False, return_counts=False, size=None, fill_value=None, return_true_size=False): """ Find the unique elements of an array along a particular axis. """ if ar.shape[axis] == 0 and size and fill_value is None: raise ValueError( "jnp.unique: for zero-sized input with nonzero size argument, fill_value must be specified" ) aux, mask, perm = _unique_sorted_mask(ar, axis) if size is None: ind = core.concrete_or_error( None, mask, "The error arose in jnp.unique(). " + UNIQUE_SIZE_HINT) else: ind = nonzero(mask, size=size)[0] result = aux[ind] if aux.size else aux if fill_value is not None: fill_value = asarray(fill_value, dtype=result.dtype) if size is not None and fill_value is not None: if result.shape[0]: valid = lax.expand_dims( arange(size) < mask.sum(), tuple(range(1, result.ndim))) result = where(valid, result, fill_value) else: result = full_like(result, fill_value, shape=(size, *result.shape[1:])) result = moveaxis(result, 0, axis) ret = (result, ) if return_index: if aux.size: ret += (perm[ind], ) else: ret += (perm, ) if return_inverse: if aux.size: imask = cumsum(mask) - 1 inv_idx = zeros(mask.shape, dtype=dtypes.canonicalize_dtype(dtypes.int_)) inv_idx = inv_idx.at[perm].set(imask) else: inv_idx = zeros(ar.shape[axis], dtype=int) ret += (inv_idx, ) if return_counts: if aux.size: if size is None: idx = append(nonzero(mask)[0], mask.size) else: idx = nonzero(mask, size=size + 1)[0] idx = idx.at[1:].set(where(idx[1:], idx[1:], mask.size)) ret += (diff(idx), ) elif ar.shape[axis]: ret += (array([ar.shape[axis]], dtype=dtypes.canonicalize_dtype(dtypes.int_)), ) else: ret += (empty(0, dtype=int), ) if return_true_size: # Useful for internal uses of unique(). ret += (mask.sum(), ) return ret[0] if len(ret) == 1 else ret