def odd_ext(x, n, axis=-1):
  """Pad ``x`` on both ends of ``axis`` by odd (point-symmetric) reflection.

  The samples added on the left are ``2 * x[0] - x[i]`` and those added on
  the right are ``2 * x[-1] - x[-1 - i]`` for ``i = 1 .. n``. In other
  words, the signal is reflected through its end points. This helper used
  to be part of ``scipy.signal.signaltools`` but is no longer exposed there.

  Args:
    x: input array.
    n: number of samples to add at each end. Values below 1 return ``x``
      itself, unchanged.
    axis: the axis to extend.

  Returns:
    For ``n >= 1``, an array whose ``axis`` dimension is
    ``x.shape[axis] + 2 * n``.

  Raises:
    ValueError: if ``n`` exceeds ``x.shape[axis] - 1``.
  """
  if n < 1:
    return x
  if x.shape[axis] - 1 < n:
    raise ValueError(
        f"The extension length n ({n}) is too big. "
        f"It must not exceed x.shape[axis]-1, which is {x.shape[axis] - 1}."
    )

  def take(start, stop):
    # Static slice [start, stop) of x along the extension axis.
    return lax.slice_in_dim(x, start, stop, axis=axis)

  def reversed_take(start, stop):
    return jnp.flip(take(start, stop), axis=axis)

  head = 2 * take(0, 1) - reversed_take(1, n + 1)
  tail = 2 * take(-1, None) - reversed_take(-(n + 1), -1)
  return jnp.concatenate((head, x, tail), axis=axis)
def __getitem__(self, key):
  """Concatenate the items of an index expression along an axis.

  ``key`` may begin with a string directive:
    * ``"r"`` / ``"c"``: after concatenation, a 1-D result gets a new axis
      inserted at position 0 or 1 respectively (via ``expand_dims``).
    * ``"axis,ndmin,trans1d[,...]"``: comma-separated integers overriding
      ``self.axis``, ``self.ndmin`` and ``self.trans1d``. Missing trailing
      fields keep their defaults. Fields after the third are ignored.
  The remaining items are slices (expanded with
  ``_make_1d_grid_from_slice``) or array-likes.

  Raises:
    ValueError: if the directive's fields cannot be parsed as integers, or
      if a string appears anywhere other than the first position.
  """
  if not isinstance(key, tuple):
    key = (key,)

  # [axis, ndmin, trans1d, matrix]; matrix == -1 means "no matrix directive".
  params = [self.axis, self.ndmin, self.trans1d, -1]

  if isinstance(key[0], str):
    # split off the directive
    directive, *key = key  # pytype: disable=bad-unpacking
    # check two special cases: matrix directives
    if directive == "r":
      params[-1] = 0
    elif directive == "c":
      params[-1] = 1
    else:
      vec = directive.split(",")
      k = len(vec)
      if k < 4:
        vec += params[k:]
      else:
        # Ignore everything after the first three comma-separated ints and
        # keep the default matrix flag. Slice (not index) params so that a
        # list is appended; appending the bare int would raise TypeError.
        vec = vec[:3] + params[-1:]
      try:
        params = list(map(int, vec))
      except ValueError as err:
        raise ValueError(
            f"could not understand directive {directive!r}"
        ) from err

  axis, ndmin, trans1d, matrix = params

  output = []
  for item in key:
    if isinstance(item, slice):
      newobj = _make_1d_grid_from_slice(item, op_name=self.op_name)
    elif isinstance(item, str):
      raise ValueError("string directive must be placed at the beginning")
    else:
      newobj = item

    newobj = array(newobj, copy=False, ndmin=ndmin)

    if trans1d != -1 and ndmin - np.ndim(item) > 0:
      shape_obj = list(range(ndmin))
      # Calculate number of left shifts, with overflow protection by mod
      num_lshifts = ndmin - abs(ndmin + trans1d + 1) % ndmin
      shape_obj = tuple(shape_obj[num_lshifts:] + shape_obj[:num_lshifts])

      newobj = transpose(newobj, shape_obj)

    output.append(newobj)

  res = concatenate(tuple(output), axis=axis)

  if matrix != -1 and res.ndim == 1:
    # insert 2nd dim at axis 0 or 1
    res = expand_dims(res, matrix)

  return res
def setxor1d(ar1, ar2, assume_unique=False):
  """Return the sorted values that occur in exactly one of two arrays.

  Both inputs are flattened. They are passed through
  ``core.concrete_or_error``, so traced (abstract) values are rejected.
  This is needed because the output size depends on the data.

  Args:
    ar1: first input array.
    ar2: second input array.
    assume_unique: if True, the inputs are taken to be already
      deduplicated and the ``unique`` step is skipped.

  Returns:
    A sorted 1-D array of the elements present in exactly one input.
    If both inputs are empty, the (empty) concatenation is returned.
  """
  _check_arraylike("setxor1d", ar1, ar2)
  ar1 = core.concrete_or_error(None, ar1, "The error arose in setxor1d()")
  ar2 = core.concrete_or_error(None, ar2, "The error arose in setxor1d()")
  first, second = ravel(ar1), ravel(ar2)
  if not assume_unique:
    first, second = unique(first), unique(second)

  merged = concatenate((first, second))
  if merged.size == 0:
    return merged
  merged = sort(merged)

  # After sorting, a value present in both (deduplicated) inputs shows up
  # twice in a row. Keep only the positions that differ from both
  # neighbours. The leading/trailing True stand in for the missing
  # neighbour at each end.
  differs = concatenate(
      (array([True]), merged[1:] != merged[:-1], array([True])))
  keep = differs[1:] & differs[:-1]
  return merged[keep]
def union1d(ar1, ar2, *, size=None, fill_value=None):
  """Return the unique values found in either of two arrays.

  Args:
    ar1: first input array.
    ar2: second input array. Both are flattened
      (``concatenate(..., axis=None)``) before ``unique`` is applied.
    size: optional output length forwarded to ``unique``. It must be a
      concrete integer. When it is None, ``ar1`` and ``ar2`` must instead be
      concrete, because the output length then depends on the data.
    fill_value: forwarded unchanged to ``unique``. Presumably this is the
      padding value used when ``size`` exceeds the number of unique
      elements; confirm against ``unique``.

  Returns:
    ``unique`` applied to the flattened concatenation of the inputs.
  """
  _check_arraylike("union1d", ar1, ar2)
  context = "The error arose in union1d()"
  if size is not None:
    size = core.concrete_or_error(operator.index, size, context)
  else:
    ar1 = core.concrete_or_error(None, ar1, context)
    ar2 = core.concrete_or_error(None, ar2, context)
  combined = concatenate((ar1, ar2), axis=None)
  return unique(combined, size=size, fill_value=fill_value)
def _intersect1d_sorted_mask(ar1, ar2, return_indices=False):
  """Jit-compatible helper for intersect1d.

  The two inputs are concatenated and sorted. Each position whose value
  equals the next one is then flagged. When each input is individually
  unique (assumed here, not checked), the flagged values are exactly those
  common to both.

  Args:
    ar1: presumably a 1-D array; confirm against callers.
    ar2: presumably a 1-D array; confirm against callers.
    return_indices: if True, also return the positions, within
      ``concatenate((ar1, ar2))``, of the sorted elements.

  Returns:
    ``(aux, mask)``, or ``(aux, mask, indices)`` when ``return_indices``
    is True. Here ``aux`` is the sorted concatenation and ``mask[i]`` is
    ``aux[i] == aux[i + 1]``.
  """
  combined = concatenate((ar1, ar2))
  if not return_indices:
    ordered = sort(combined)
    return ordered, ordered[1:] == ordered[:-1]

  # Carry each element's original position through the sort.
  positions = lax.broadcasted_iota(np.int64, np.shape(combined), dimension=0)
  ordered, order_idx = lax.sort_key_val(combined, positions)
  return ordered, ordered[1:] == ordered[:-1], order_idx
def polyint(p, m=1, k=None):
  """Return the ``m``-th antiderivative of the polynomial ``p``.

  Coefficients are ordered from the highest degree to the lowest, following
  the numpy.polyint convention that the divisor computation below relies on.

  Args:
    p: polynomial coefficients.
    m: order of integration. It must be a concrete, non-negative integer.
    k: integration constant(s). Use None for all zeros, a scalar for the
      same constant at every order, or a length-``m`` array.

  Returns:
    The coefficients of the integrated polynomial (length ``len(p) + m``).
    When ``m == 0``, returns ``p`` itself after dtype promotion.

  Raises:
    ValueError: if ``m`` is negative or ``k`` has an incompatible shape.
  """
  m = core.concrete_or_error(operator.index, m, "'m' argument of jnp.polyint")
  k = 0 if k is None else k
  _check_arraylike("polyint", p, k)
  p, k = _promote_dtypes_inexact(p, k)
  if m < 0:
    raise ValueError("Order of integral must be positive (see polyder)")
  k = atleast_1d(k)
  if len(k) == 1:
    k = full((m,), k[0])
  if k.shape != (m,):
    raise ValueError("k must be a scalar or a rank-1 array of length 1 or m.")
  if m == 0:
    return p
  # Row j of `grid` holds the divisor that the (j+1)-th integration applies
  # to each output coefficient. Entries are clamped at 1 so the integration
  # constants are not rescaled. The entries are small integers, but their
  # column-wise product grows factorially. Cast to the inexact coefficient
  # dtype *before* multiplying so the product cannot silently overflow int32.
  grid = (arange(len(p) + m, 0, -1)[np.newaxis, :] - 1
          - arange(m)[:, np.newaxis])
  coeff = maximum(1, grid).astype(p.dtype).prod(0)
  return true_divide(concatenate((p, k)), coeff)
def logpdf(x, alpha):
  """Log of the Dirichlet probability density at ``x``.

  Args:
    x: evaluation point(s); the leading axis indexes the categories. It may
      hold as many entries as ``alpha``, or one fewer. In the latter case
      the last coordinate is filled in as ``1 - x.sum(0)``. Any further axes
      are broadcast against ``alpha``.
    alpha: one-dimensional concentration parameters.

  Returns:
    The log-density, or ``-inf`` wherever ``_is_simplex(x)`` is False.

  Raises:
    ValueError: if ``alpha`` is not one-dimensional, or if the leading
      dimension of ``x`` is neither ``len(alpha)`` nor ``len(alpha) - 1``.
  """
  x, alpha = _promote_dtypes_inexact(x, alpha)
  if alpha.ndim != 1:
    raise ValueError(
        f"`alpha` must be one-dimensional; got alpha.shape={alpha.shape}"
    )
  num_categories = alpha.shape[0]
  if x.shape[0] not in (num_categories, num_categories - 1):
    raise ValueError(
        "`x` must have either the same number of entries as `alpha` "
        f"or one entry fewer; got x.shape={x.shape}, alpha.shape={alpha.shape}"
    )

  one = lax._const(x, 1)
  if x.shape[0] != num_categories:
    # Recover the implicit last coordinate from the others.
    remainder = lax.sub(one, x.sum(0, keepdims=True))
    x = jnp.concatenate([x, remainder], axis=0)

  log_norm = jnp.sum(gammaln(alpha)) - gammaln(jnp.sum(alpha))
  if x.ndim > 1:
    # Append singleton axes so alpha broadcasts over x's trailing axes.
    trailing = (1,) * (x.ndim - 1)
    alpha = lax.broadcast_in_dim(alpha, alpha.shape + trailing, (0,))

  weighted = xlogy(lax.sub(alpha, one), x)
  log_probs = lax.sub(jnp.sum(weighted, axis=0), log_norm)
  return jnp.where(_is_simplex(x), log_probs, -jnp.inf)
def _cofactor_solve(a, b):
  """Equivalent to det(a)*solve(a, b) for nonsingular mat.

  Intermediate function used for jvp and vjp of det.
  This function borrows heavily from jax.numpy.linalg.solve and
  jax.numpy.linalg.slogdet to compute the gradient of the determinant
  in a way that is well defined even for low rank matrices.

  This function handles two different cases:
  * rank(a) == n or n-1
  * rank(a) < n-1

  For rank n-1 matrices, the gradient of the determinant is a rank 1 matrix.
  Rather than computing det(a)*solve(a, b), which would return NaN, we work
  directly with the LU decomposition. If a = p @ l @ u, then
  det(a)*solve(a, b) =
  prod(diag(u)) * u^-1 @ l^-1 @ p^-1 b =
  prod(diag(u)) * triangular_solve(u, solve(p @ l, b))
  (up to the sign of the permutation p, which the code folds into the
  running product of diag(u)).
  If a is rank n-1, then the lower right corner of u will be zero and the
  triangular_solve will fail.
  Let x = solve(p @ l, b) and y = det(a)*solve(a, b).
  Then y_{n} = x_{n} / u_{nn} * prod_{i=1...n}(u_{ii}) =
  x_{n} * prod_{i=1...n-1}(u_{ii})
  So by replacing the lower-right corner of u with prod_{i=1...n-1}(u_{ii})^-1
  we can avoid the triangular_solve failing.
  To correctly compute the rest of y_{i} for i != n, we simply multiply
  x_{i} by det(a) for all i != n, which will be zero if rank(a) = n-1.

  For the second case, a check is done on the matrix to see if `solve`
  returns NaN or Inf, and gives a matrix of zeros as a result, as the
  gradient of the determinant of a matrix with rank less than n-1 is 0.
  This will still return the correct value for rank n-1 matrices, as the check
  is applied *after* the lower right corner of u has been updated.

  Args:
    a: A square matrix or batch of matrices, possibly singular.
    b: A matrix, or batch of matrices of the same dimension as a.

  Returns:
    det(a) and cofactor(a)^T*b, aka adjugate(a)*b. For 1x1 matrices this is
    simply (a[..., 0, 0], b).

  Raises:
    ValueError: if a is not [..., m, m] or the trailing two dimensions of b
      do not match those of a.
  """
  a = _promote_arg_dtypes(jnp.asarray(a))
  b = _promote_arg_dtypes(jnp.asarray(b))
  a_shape = jnp.shape(a)
  b_shape = jnp.shape(b)
  a_ndims = len(a_shape)
  if not (a_ndims >= 2 and a_shape[-1] == a_shape[-2]
          and b_shape[-2:] == a_shape[-2:]):
    msg = ("The arguments to _cofactor_solve must have shapes "
           "a=[..., m, m] and b=[..., m, m]; got a={} and b={}")
    raise ValueError(msg.format(a_shape, b_shape))
  # The adjugate of a 1x1 matrix is [[1]], so adjugate(a) @ b == b.
  if a_shape[-1] == 1:
    return a[..., 0, 0], b
  # lu contains u in the upper triangular matrix and l in the strict lower
  # triangular matrix.
  # The diagonal of l is set to ones without loss of generality.
  lu, pivots, permutation = lax_linalg.lu(a)
  dtype = lax.dtype(a)
  # Broadcast a's LU factors and b to a common batch shape.
  batch_dims = lax.broadcast_shapes(lu.shape[:-2], b.shape[:-2])
  x = jnp.broadcast_to(b, batch_dims + b.shape[-2:])
  lu = jnp.broadcast_to(lu, batch_dims + lu.shape[-2:])
  # Compute (partial) determinant, ignoring last diagonal of LU
  diag = jnp.diagonal(lu, axis1=-2, axis2=-1)
  # Each pivot that differs from its own row index is one row swap; an odd
  # number of swaps flips the sign of the determinant.
  parity = jnp.count_nonzero(pivots != jnp.arange(a_shape[-1]), axis=-1)
  sign = jnp.asarray(-2 * (parity % 2) + 1, dtype=dtype)
  # partial_det[:, -1] contains the full determinant and
  # partial_det[:, -2] contains det(u) / u_{nn}.
  # (Both include the permutation sign.)
  partial_det = jnp.cumprod(diag, axis=-1) * sign[..., None]
  # Replace u_{nn} so the final upper-triangular solve stays finite when
  # rank(a) == n-1 (see docstring).
  lu = lu.at[..., -1, -1].set(1.0 / partial_det[..., -2])
  permutation = jnp.broadcast_to(permutation, batch_dims + (a_shape[-1], ))
  # Open-mesh index grids over the batch dimensions, used below to apply the
  # row permutation independently to every batch element.
  iotas = jnp.ix_(*(lax.iota(jnp.int32, b) for b in batch_dims + (1, )))
  # filter out any matrices that are not full rank
  # Solve the modified u against a vector of ones; any NaN/Inf in the result
  # is treated as rank(a) < n-1, and d becomes a boolean mask broadcast to
  # x's shape.
  d = jnp.ones(x.shape[:-1], x.dtype)
  d = lax_linalg.triangular_solve(lu, d, left_side=True, lower=False)
  d = jnp.any(jnp.logical_or(jnp.isnan(d), jnp.isinf(d)), axis=-1)
  d = jnp.tile(d[..., None, None], d.ndim * (1, ) + x.shape[-2:])
  x = jnp.where(d, jnp.zeros_like(x), x)  # first filter
  # Permute the rows of x, then solve with the unit lower-triangular l.
  x = x[iotas[:-1] + (permutation, slice(None))]
  x = lax_linalg.triangular_solve(lu, x, left_side=True, lower=True,
                                  unit_diagonal=True)
  # Scale every row except the last by det(a); the last row is handled by
  # the modified u_{nn} in the following upper-triangular solve.
  x = jnp.concatenate((x[..., :-1, :] * partial_det[..., -1, None, None],
                       x[..., -1:, :]), axis=-2)
  x = lax_linalg.triangular_solve(lu, x, left_side=True, lower=False)
  x = jnp.where(d, jnp.zeros_like(x), x)  # second filter

  return partial_det[..., -1], x
def eigh_tridiagonal(d, e, *, eigvals_only=False, select='a',
                     select_range=None, tol=None):
  """Eigenvalues of a real-symmetric or Hermitian tridiagonal matrix.

  The eigenvalues are located by Sturm-sequence bisection, run in parallel
  for every requested eigenvalue. Only eigenvalues are supported.

  Args:
    d: the main diagonal, of length ``n``.
    e: the off-diagonal, with the same dtype as ``d`` and length ``n - 1``.
      The Gershgorin row sums below assume this length.
    eigvals_only: must be True; computing eigenvectors is not implemented.
    select: ``'a'`` for all eigenvalues, or ``'i'`` for the inclusive index
      range given by ``select_range``. ``'v'`` is not implemented.
    select_range: ``(lo, hi)`` eigenvalue indices, used when
      ``select == 'i'``. Both must lie in ``[0, n - 1]`` with
      ``lo <= hi``.
    tol: optional absolute tolerance. The bisection stops once every
      bracket is narrower than ``max(tol, eps * ||T||)``, or after
      ``nmant + 1`` iterations.

  Returns:
    The selected eigenvalues in ascending order, as a real array.

  Raises:
    NotImplementedError: if ``eigvals_only`` is False or ``select == 'v'``.
    TypeError: if ``d`` and ``e`` differ in dtype or use an unsupported one.
    ValueError: for an unknown ``select`` or an invalid ``select_range``.
  """
  if not eigvals_only:
    raise NotImplementedError(
        "Calculation of eigenvectors is not implemented")

  def _sturm(alpha, beta_sq, pivmin, alpha0_perturbation, x):
    """Implements the Sturm sequence recurrence.

    Returns, for each entry of ``x``, the number of non-positive pivots
    (``q <= pivmin``, with a strict ``< 0`` test on the first step), which
    is the eigenvalue count the bisection below compares with its targets.
    """
    n = alpha.shape[0]
    zeros = jnp.zeros(x.shape, dtype=jnp.int32)
    ones = jnp.ones(x.shape, dtype=jnp.int32)

    # The first step in the Sturm sequence recurrence
    # requires special care if x is equal to alpha[0].
    def sturm_step0():
      q = alpha[0] - x
      count = jnp.where(q < 0, ones, zeros)
      q = jnp.where(alpha[0] == x, alpha0_perturbation, q)
      return q, count

    # Subsequent steps all take this form:
    def sturm_step(i, q, count):
      q = alpha[i] - beta_sq[i - 1] / q - x
      count = jnp.where(q <= pivmin, count + 1, count)
      q = jnp.where(q <= pivmin, jnp.minimum(q, -pivmin), q)
      return q, count

    # The first step initializes q and count.
    q, count = sturm_step0()

    # Peel off ((n-1) % blocksize) steps from the main loop, so we can run
    # the bulk of the iterations unrolled by a factor of blocksize.
    blocksize = 16
    i = 1
    peel = (n - 1) % blocksize
    unroll_cnt = peel

    def unrolled_steps(args):
      start, q, count = args
      for j in range(unroll_cnt):
        q, count = sturm_step(start + j, q, count)
      return start + unroll_cnt, q, count

    i, q, count = unrolled_steps((i, q, count))

    # Run the remaining steps of the Sturm sequence using a partially
    # unrolled while loop. `unrolled_steps` reads `unroll_cnt` when it is
    # traced, so rebinding it here makes the loop body unroll `blocksize`
    # steps per iteration.
    unroll_cnt = blocksize

    def cond(iqc):
      i, q, count = iqc
      return jnp.less(i, n)

    _, _, count = lax.while_loop(cond, unrolled_steps, (i, q, count))
    return count

  alpha = jnp.asarray(d)
  beta = jnp.asarray(e)
  supported_dtypes = (jnp.float32, jnp.float64, jnp.complex64, jnp.complex128)
  if alpha.dtype != beta.dtype:
    raise TypeError(
        "diagonal and off-diagonal values must have same dtype, "
        f"got {alpha.dtype} and {beta.dtype}")
  if alpha.dtype not in supported_dtypes or beta.dtype not in supported_dtypes:
    raise TypeError(
        "Only float32, float64, complex64 and complex128 inputs are supported "
        "as inputs to jax.scipy.linalg.eigh_tridiagonal, got "
        f"{alpha.dtype} and {beta.dtype}")
  n = alpha.shape[0]
  if n <= 1:
    return jnp.real(alpha)

  if jnp.issubdtype(alpha.dtype, jnp.complexfloating):
    # Hermitian case: only the real part of the diagonal and |e|^2 enter
    # the Sturm recurrence.
    alpha = jnp.real(alpha)
    beta_sq = jnp.real(beta * jnp.conj(beta))
    beta_abs = jnp.sqrt(beta_sq)
  else:
    beta_abs = jnp.abs(beta)
    beta_sq = jnp.square(beta)

  # Estimate the largest and smallest eigenvalues of T using the Gershgorin
  # circle theorem.
  off_diag_abs_row_sum = jnp.concatenate(
      [beta_abs[:1], beta_abs[:-1] + beta_abs[1:], beta_abs[-1:]], axis=0)
  lambda_est_max = jnp.amax(alpha + off_diag_abs_row_sum)
  lambda_est_min = jnp.amin(alpha - off_diag_abs_row_sum)
  # Upper bound on 2-norm of T.
  t_norm = jnp.maximum(jnp.abs(lambda_est_min), jnp.abs(lambda_est_max))

  # Compute the smallest allowed pivot in the Sturm sequence to avoid
  # overflow.
  finfo = np.finfo(alpha.dtype)
  one = np.ones([], dtype=alpha.dtype)
  safemin = np.maximum(one / finfo.max, (one + finfo.eps) * finfo.tiny)
  pivmin = safemin * jnp.maximum(1, jnp.amax(beta_sq))
  alpha0_perturbation = jnp.square(finfo.eps * beta_abs[0])
  abs_tol = finfo.eps * t_norm
  if tol is not None:
    abs_tol = jnp.maximum(tol, abs_tol)
  # In the worst case, when the absolute tolerance is eps*lambda_est_max and
  # lambda_est_max = -lambda_est_min, we have to take as many bisection steps
  # as there are bits in the mantissa plus 1.
  max_it = finfo.nmant + 1

  # Determine the indices of the desired eigenvalues, based on select and
  # select_range.
  if select == 'a':
    target_counts = jnp.arange(n, dtype=jnp.int32)
  elif select == 'i':
    if select_range[0] > select_range[1]:
      raise ValueError('Got empty index range in select_range.')
    # Out-of-range targets would make the bisection silently converge to the
    # ends of the Gershgorin interval instead of to an eigenvalue.
    if select_range[0] < 0 or select_range[1] >= n:
      raise ValueError(
          f"select_range must lie within [0, {n - 1}], got {select_range}.")
    target_counts = jnp.arange(select_range[0], select_range[1] + 1,
                               dtype=jnp.int32)
  elif select == 'v':
    # TODO(phawkins): requires dynamic shape support.
    raise NotImplementedError("eigh_tridiagonal(..., select='v') is not "
                              "implemented")
  else:
    raise ValueError("'select' must have a value in {'a', 'i', 'v'}.")

  # Run binary search for all desired eigenvalues in parallel, starting from
  # an interval slightly wider than the estimated
  # [lambda_est_min, lambda_est_max].
  fudge = 2.1  # Widen the starting (Gershgorin) interval a bit.
  norm_slack = jnp.array(n, alpha.dtype) * fudge * finfo.eps * t_norm
  lower = lambda_est_min - norm_slack - 2 * fudge * pivmin
  upper = lambda_est_max + norm_slack + fudge * pivmin

  # Pre-broadcast the scalars used in the Sturm sequence for improved
  # performance.
  target_shape = jnp.shape(target_counts)
  lower = jnp.broadcast_to(lower, shape=target_shape)
  upper = jnp.broadcast_to(upper, shape=target_shape)
  mid = 0.5 * (upper + lower)
  pivmin = jnp.broadcast_to(pivmin, target_shape)
  alpha0_perturbation = jnp.broadcast_to(alpha0_perturbation, target_shape)

  # Start parallel binary searches.
  def cond(args):
    i, lower, _, upper = args
    return jnp.logical_and(
        jnp.less(i, max_it),
        jnp.less(abs_tol, jnp.amax(upper - lower)))

  def body(args):
    i, lower, mid, upper = args
    counts = _sturm(alpha, beta_sq, pivmin, alpha0_perturbation, mid)
    # If at most k eigenvalues lie below mid, eigenvalue k is >= mid.
    lower = jnp.where(counts <= target_counts, mid, lower)
    upper = jnp.where(counts > target_counts, mid, upper)
    mid = 0.5 * (lower + upper)
    return i + 1, lower, mid, upper

  _, _, mid, _ = lax.while_loop(cond, body, (0, lower, mid, upper))
  return mid
def _spectral_helper(x, y, fs=1.0, window='hann', nperseg=None, noverlap=None,
                     nfft=None, detrend_type='constant', return_onesided=True,
                     scaling='density', axis=-1, mode='psd', boundary=None,
                     padded=False):
  """LAX-backend implementation of `scipy.signal._spectral_helper`.

  Unlike the original helper function, `y` can be None for explicitly
  indicating auto-spectral (non cross-spectral) computation. In addition to
  this, `detrend` argument is renamed to `detrend_type` for avoiding internal
  name overlap.

  Returns:
    A tuple ``(freqs, time, result)``. Here ``freqs`` comes from
    ``fftfreq``/``rfftfreq``, ``time`` holds the segment centers divided by
    ``fs``, and ``result`` is the windowed-FFT product, with the frequency
    axis moved back to ``axis``. For empty inputs, three zero arrays are
    returned instead.

  Raises:
    ValueError: on invalid ``mode``, ``boundary``, ``scaling``, segment
      parameters, or incompatible ``x``/``y`` shapes.
  """
  if mode not in ('psd', 'stft'):
    raise ValueError(f"Unknown value for mode {mode}, "
                     "must be one of: ('psd', 'stft')")

  def make_pad(mode, **kwargs):
    def pad(x, n, axis=-1):
      pad_width = [(0, 0) for unused_n in range(x.ndim)]
      pad_width[axis] = (n, n)
      return jnp.pad(x, pad_width, mode, **kwargs)
    return pad

  boundary_funcs = {
      'even': make_pad('reflect'),
      'odd': odd_ext,
      'constant': make_pad('edge'),
      'zeros': make_pad('constant', constant_values=0.0),
      None: lambda x, *args, **kwargs: x
  }

  # Check/ normalize inputs
  if boundary not in boundary_funcs:
    raise ValueError(f"Unknown boundary option '{boundary}', "
                     f"must be one of: {list(boundary_funcs.keys())}")

  axis = jax.core.concrete_or_error(operator.index, axis,
                                    "axis of windowed-FFT")
  # Keep the caller's (possibly negative) axis: scipy decides whether to
  # wrap a callable detrend by comparing this value against -1.
  requested_axis = axis
  axis = canonicalize_axis(axis, x.ndim)

  if nperseg is not None:  # if specified by user
    nperseg = jax.core.concrete_or_error(int, nperseg,
                                         "nperseg of windowed-FFT")
    if nperseg < 1:
      raise ValueError('nperseg must be a positive integer')
  # parse window; if array like, then set nperseg = win.shape
  win, nperseg = signal_helper._triage_segments(
      window, nperseg, input_length=x.shape[axis])

  if noverlap is None:
    noverlap = nperseg // 2
  else:
    noverlap = jax.core.concrete_or_error(int, noverlap,
                                          "noverlap of windowed-FFT")
  if nfft is None:
    nfft = nperseg
  else:
    nfft = jax.core.concrete_or_error(int, nfft, "nfft of windowed-FFT")

  _check_arraylike("_spectral_helper", x)
  x = jnp.asarray(x)

  if y is None:
    outdtype = jax.dtypes.canonicalize_dtype(np.result_type(x, np.complex64))
  else:
    _check_arraylike("_spectral_helper", y)
    y = jnp.asarray(y)
    outdtype = jax.dtypes.canonicalize_dtype(
        np.result_type(x, y, np.complex64))
    if mode != 'psd':
      raise ValueError(
          "two-argument mode is available only when mode=='psd'")
    if x.ndim != y.ndim:
      raise ValueError(
          f"two-arguments must have the same rank ({x.ndim} vs {y.ndim}).")
    # Check if we can broadcast the outer axes together
    try:
      outershape = jnp.broadcast_shapes(tuple_delete(x.shape, axis),
                                        tuple_delete(y.shape, axis))
    except ValueError as e:
      raise ValueError('x and y cannot be broadcast together.') from e

  # Special cases for size == 0
  if y is None:
    if x.size == 0:
      return jnp.zeros(x.shape), jnp.zeros(x.shape), jnp.zeros(x.shape)
  else:
    if x.size == 0 or y.size == 0:
      outshape = tuple_insert(
          outershape, min([x.shape[axis], y.shape[axis]]), axis)
      emptyout = jnp.zeros(outshape)
      return emptyout, emptyout, emptyout

  # Move time-axis to the end
  if x.ndim > 1:
    if axis != -1:
      x = jnp.moveaxis(x, axis, -1)
      if y is not None and y.ndim > 1:
        y = jnp.moveaxis(y, axis, -1)

  # Check if x and y are the same length, zero-pad if necessary
  if y is not None:
    if x.shape[-1] != y.shape[-1]:
      if x.shape[-1] < y.shape[-1]:
        pad_shape = list(x.shape)
        pad_shape[-1] = y.shape[-1] - x.shape[-1]
        x = jnp.concatenate((x, jnp.zeros(pad_shape)), -1)
      else:
        pad_shape = list(y.shape)
        pad_shape[-1] = x.shape[-1] - y.shape[-1]
        y = jnp.concatenate((y, jnp.zeros(pad_shape)), -1)

  if nfft < nperseg:
    raise ValueError('nfft must be greater than or equal to nperseg.')
  if noverlap >= nperseg:
    raise ValueError('noverlap must be less than nperseg.')
  nstep = nperseg - noverlap

  # Apply paddings
  if boundary is not None:
    ext_func = boundary_funcs[boundary]
    x = ext_func(x, nperseg // 2, axis=-1)
    if y is not None:
      y = ext_func(y, nperseg // 2, axis=-1)

  if padded:
    # Pad to integer number of windowed segments
    # I.e make x.shape[-1] = nperseg + (nseg-1)*nstep, with integer nseg
    nadd = (-(x.shape[-1] - nperseg) % nstep) % nperseg
    zeros_shape = list(x.shape[:-1]) + [nadd]
    x = jnp.concatenate((x, jnp.zeros(zeros_shape)), axis=-1)
    if y is not None:
      zeros_shape = list(y.shape[:-1]) + [nadd]
      y = jnp.concatenate((y, jnp.zeros(zeros_shape)), axis=-1)

  # Handle detrending and window functions
  if not detrend_type:
    def detrend_func(d):
      return d
  elif not hasattr(detrend_type, '__call__'):
    def detrend_func(d):
      return detrend(d, type=detrend_type, axis=-1)
  elif requested_axis != -1:
    # Wrap this function so that it receives a shape that it could
    # reasonably expect to receive. The segments arrive with their samples
    # on the last axis, so move that axis to where the caller's time axis
    # was, apply the user function, and then move it back (as scipy does).
    def detrend_func(d):
      d = jnp.moveaxis(d, -1, requested_axis)
      d = detrend_type(d)
      return jnp.moveaxis(d, requested_axis, -1)
  else:
    detrend_func = detrend_type

  if np.result_type(win, np.complex64) != outdtype:
    win = win.astype(outdtype)

  # Determine scale
  if scaling == 'density':
    scale = 1.0 / (fs * (win * win).sum())
  elif scaling == 'spectrum':
    scale = 1.0 / win.sum()**2
  else:
    raise ValueError(f'Unknown scaling: {scaling}')
  if mode == 'stft':
    scale = jnp.sqrt(scale)

  # Determine onesided/ two-sided
  if return_onesided:
    sides = 'onesided'
    if jnp.iscomplexobj(x) or jnp.iscomplexobj(y):
      sides = 'twosided'
      warnings.warn('Input data is complex, switching to '
                    'return_onesided=False')
  else:
    sides = 'twosided'

  if sides == 'twosided':
    freqs = jax.numpy.fft.fftfreq(nfft, 1/fs)
  elif sides == 'onesided':
    freqs = jax.numpy.fft.rfftfreq(nfft, 1/fs)

  # Perform the windowed FFTs
  result = _fft_helper(x, win, detrend_func, nperseg, noverlap, nfft, sides)

  if y is not None:
    # All the same operations on the y data
    result_y = _fft_helper(y, win, detrend_func, nperseg, noverlap, nfft,
                           sides)
    result = jnp.conjugate(result) * result_y
  elif mode == 'psd':
    result = jnp.conjugate(result) * result

  result *= scale

  if sides == 'onesided' and mode == 'psd':
    # Double every bin except DC (and Nyquist, when nfft is even) to account
    # for the discarded negative frequencies.
    end = None if nfft % 2 else -1
    result = result.at[..., 1:end].mul(2)

  time = jnp.arange(nperseg / 2, x.shape[-1] - nperseg / 2 + 1,
                    nperseg - noverlap) / fs
  if boundary is not None:
    time -= (nperseg / 2) / fs

  result = result.astype(outdtype)

  # All imaginary parts are zero anyways
  if y is None and mode != 'stft':
    result = result.real

  # Move frequency axis back to axis where the data came from
  result = jnp.moveaxis(result, -1, axis)

  return freqs, time, result