def __init__(self, dataset, bw_method=None, weights=None):
  """Set up a Gaussian KDE over the columns of ``dataset``.

  Args:
    dataset: array-like, promoted to at least 2-D; its shape is read as
      ``(d, n)`` -- ``d`` dimensions by ``n`` points. Complex data is rejected.
    bw_method: ``None`` or ``'scott'`` gives ``neff ** (-1 / (d + 4))``;
      ``'silverman'`` gives ``(neff * (d + 2) / 4) ** (-1 / (d + 4))``; a
      non-string scalar is used as the factor directly; a callable is called
      with ``self`` (after ``dataset``, ``weights`` and ``neff`` are set) and
      its return value is used as the factor.
    weights: optional 1-D array-like of length ``n``; normalized to sum to 1.
      When omitted, every point gets weight ``1 / n``.

  Raises:
    NotImplementedError: if ``dataset`` is complex.
    ValueError: for a ``dataset`` with at most one element, malformed
      ``weights``, or an unrecognised ``bw_method``.

  Stores ``dataset``, ``weights``, ``neff`` (``1 / sum(weights**2)``),
  ``covariance`` (weighted data covariance times factor**2) and ``inv_cov``
  through ``self._setattr``.
  """
  _check_arraylike("gaussian_kde", dataset)
  samples = jnp.atleast_2d(dataset)
  if jnp.issubdtype(lax.dtype(samples), jnp.complexfloating):
    raise NotImplementedError(
        "gaussian_kde does not support complex data")
  if samples.size <= 1:
    raise ValueError("`dataset` input should have multiple elements.")
  ndim, npoints = samples.shape

  if weights is None:
    samples, = _promote_dtypes_inexact(samples)
    w = jnp.full(npoints, 1.0 / npoints, dtype=samples.dtype)
  else:
    _check_arraylike("gaussian_kde", weights)
    samples, w = _promote_dtypes_inexact(samples, weights)
    w = jnp.atleast_1d(w)
    w = w / jnp.sum(w)
    if w.ndim != 1:
      raise ValueError("`weights` input should be one-dimensional.")
    if len(w) != npoints:
      raise ValueError("`weights` input should be of length n")

  self._setattr("dataset", samples)
  self._setattr("weights", w)
  neff = self._setattr("neff", 1 / jnp.sum(w**2))

  # Bandwidth factor. This must come after the attributes above are stored,
  # because a callable bw_method receives `self`.
  method = "scott" if bw_method is None else bw_method
  exponent = -1. / (ndim + 4)
  if method == "scott":
    factor = jnp.power(neff, exponent)
  elif method == "silverman":
    factor = jnp.power(neff * (ndim + 2) / 4.0, exponent)
  elif not isinstance(method, str) and jnp.isscalar(method):
    factor = method
  elif callable(method):
    factor = method(self)
  else:
    raise ValueError(
        "`bw_method` should be 'scott', 'silverman', a scalar, or a callable."
    )

  # Rows of `samples` are variables, so this is the (d, d) weighted covariance.
  data_cov = jnp.atleast_2d(
      jnp.cov(samples, rowvar=1, bias=False, aweights=w))
  data_inv_cov = jnp.linalg.inv(data_cov)
  factor_sq = factor**2
  self._setattr("covariance", data_cov * factor_sq)
  self._setattr("inv_cov", data_inv_cov / factor_sq)
def detrend(data, axis=-1, type='linear', bp=0, overwrite_data=None):
  """Remove a constant or (piecewise-)linear trend from ``data`` along ``axis``.

  LAX-backend port of ``scipy.signal.detrend``.

  Args:
    data: array-like input; promoted to an inexact dtype.
    axis: axis along which the trend is removed.
    type: ``'linear'`` (or ``'l'``) subtracts a least-squares line fitted
      separately on each segment delimited by ``bp``. ``'constant'`` (or
      ``'c'``) subtracts the mean along ``axis``. The single-letter aliases
      are accepted for compatibility with SciPy.
    bp: breakpoint index or sequence of indices along ``axis``. It is
      processed with NumPy, so it must be concrete. Only used for a linear
      trend.
    overwrite_data: unsupported; must be ``None``.

  Returns:
    The detrended array, with the same shape as ``data``.

  Raises:
    NotImplementedError: if ``overwrite_data`` is given.
    ValueError: for an unknown ``type`` or out-of-range breakpoints.
  """
  if overwrite_data is not None:
    raise NotImplementedError("overwrite_data argument not implemented.")
  if type not in ['linear', 'l', 'constant', 'c']:
    raise ValueError("Trend type must be 'linear' or 'constant'.")
  data, = _promote_dtypes_inexact(jnp.asarray(data))
  if type in ['constant', 'c']:
    return data - data.mean(axis, keepdims=True)

  N = data.shape[axis]
  # bp is static, so we use np operations to avoid pushing to device.
  # 0 and N are always included so the segments tile the whole axis.
  bp = np.sort(np.unique(np.r_[0, bp, N]))
  if bp[0] < 0 or bp[-1] > N:
    raise ValueError(
        "Breakpoints must be non-negative and less than length of data along given axis."
    )
  # Work on a 2-D view: the detrend axis first, all other axes flattened.
  data = jnp.moveaxis(data, axis, 0)
  shape = data.shape
  data = data.reshape(N, -1)
  for start, stop in zip(bp[:-1], bp[1:]):
    npts = stop - start
    # Design matrix with columns [1, t], where t = 1/npts, ..., 1.
    A = jnp.vstack([
        jnp.ones(npts, dtype=data.dtype),
        jnp.arange(1, npts + 1, dtype=data.dtype) / npts
    ]).T
    seg = slice(start, stop)
    coef, *_ = linalg.lstsq(A, data[seg])
    # Subtract the fitted line from this segment only.
    data = data.at[seg].add(
        -jnp.matmul(A, coef, precision=lax.Precision.HIGHEST))
  return jnp.moveaxis(data.reshape(shape), 0, axis)
def _gaussian_kernel_eval(in_log, points, values, xi, precision):
  """Evaluate a weighted sum of Gaussian kernels at every row of ``xi``.

  Each training point ``points[j]`` carries a weight ``values[j]``. The
  kernel is a Gaussian with covariance ``inv(precision)``, evaluated in
  coordinates whitened by the lower Cholesky factor of ``precision``.

  Args:
    in_log: static flag. If true, return the log of the sum, computed as
      logsumexp(log(values) + log-kernel). Otherwise return the plain
      weighted sum.
    points: training points, shape ``(n, d)``.
    values: per-point weights, with a leading axis of length ``n``.
    xi: query points, shape ``(m, d)``.
    precision: ``(d, d)`` precision (inverse covariance) matrix.

  Returns:
    One reduced value per row of ``xi``.

  Raises:
    ValueError: if ``xi`` or ``precision`` does not match ``d``.
  """
  points, values, xi, precision = _promote_dtypes_inexact(
      points, values, xi, precision)
  dim = points.shape[1]
  if xi.shape[1] != dim:
    raise ValueError("points and xi must have same trailing dim")
  if precision.shape != (dim, dim):
    raise ValueError("precision matrix must match data dims")

  chol = linalg.cholesky(precision, lower=True)
  train = jnp.dot(points, chol)
  test = jnp.dot(xi, chol)
  # log normalizer = 0.5 * log det(precision) - (d / 2) * log(2 pi)
  log_norm = (jnp.sum(jnp.log(jnp.diag(chol)))
              - 0.5 * dim * jnp.log(2 * np.pi))

  def log_weight(x_test, x_train):
    return log_norm - 0.5 * jnp.sum(jnp.square(x_train - x_test))

  if in_log:
    def pair_term(x_test, x_train, y_train):
      return jnp.log(y_train) + log_weight(x_test, x_train)
    reduce = special.logsumexp
  else:
    def pair_term(x_test, x_train, y_train):
      return y_train * jnp.exp(log_weight(x_test, x_train))
    reduce = jnp.sum

  # One query point against all training points.
  over_train = vmap(pair_term, in_axes=(None, 0, 0))

  def eval_one(x):
    return reduce(over_train(x, train, values), axis=0)

  return vmap(eval_one)(test)
def logpdf(x, mean, cov, allow_singular=None):
  """Log-density of a (possibly batched) multivariate normal distribution.

  Args:
    x: evaluation points. When ``mean`` is non-scalar, the last axis indexes
      the components.
    mean: scalar (univariate case) or array whose last axis has length ``n``.
    cov: scalar variance, or covariance matrices of shape ``(..., n, n)``.
    allow_singular: unsupported; must be ``None``.

  Returns:
    Log-density, reduced over the component axis.

  Raises:
    NotImplementedError: if ``allow_singular`` is given.
    ValueError: if a matrix ``cov`` does not end in ``(n, n)``.
  """
  if allow_singular is not None:
    raise NotImplementedError(
        "allow_singular argument of multivariate_normal.logpdf")
  x, mean, cov = _promote_dtypes_inexact(x, mean, cov)
  log_two_pi = np.log(2 * np.pi)

  if not mean.shape:
    # Univariate: cov is a variance.
    return (-1 / 2 * jnp.square(x - mean) / cov
            - 1 / 2 * (log_two_pi + jnp.log(cov)))

  n = mean.shape[-1]
  if not np.shape(cov):
    # Isotropic: scalar variance shared by every component.
    diff = x - mean
    sq_norm = jnp.einsum('...i,...i->...', diff, diff)
    return -1 / 2 * sq_norm / cov - n / 2 * (log_two_pi + jnp.log(cov))

  if cov.ndim < 2 or cov.shape[-2:] != (n, n):
    raise ValueError(
        "multivariate_normal.logpdf got incompatible shapes")
  chol = lax.linalg.cholesky(cov)
  # z solves z @ chol.T = x - mean, so |z|^2 is the squared Mahalanobis
  # distance.
  z = lax.linalg.triangular_solve(chol, x - mean, lower=True,
                                  transpose_a=True)
  # 0.5 * log det(cov), one value per (possibly batched) matrix.
  half_logdet = jnp.log(chol.diagonal(axis1=-1, axis2=-2)).sum(-1)
  return (-1 / 2 * jnp.einsum('...i,...i->...', z, z)
          - n / 2 * log_two_pi
          - half_logdet)
def __init__(self, points, values, method="linear", bounds_error=False,
             fill_value=nan):
  """Store a regular grid and the values sampled on it.

  Args:
    points: sequence of array-likes, one per grid axis. Beyond being
      array-like they are not validated (see TODO below).
    values: array of samples; it needs at least ``len(points)`` dimensions.
      Promoted to an inexact dtype.
    method: ``"linear"`` or ``"nearest"``.
    bounds_error: must be falsy; a truthy value raises NotImplementedError.
    fill_value: ``None``, or an array-like castable (``same_kind``) to the
      dtype of ``values``. It is stored for use at evaluation time,
      presumably for points outside the grid as in SciPy.

  Raises:
    ValueError: for an unknown ``method``, too many grid axes, or an
      incompatible ``fill_value``.
    NotImplementedError: if ``bounds_error`` is truthy.
  """
  if method not in ("linear", "nearest"):
    raise ValueError(f"method {method!r} is not defined")
  self.method = method
  self.bounds_error = bounds_error
  if self.bounds_error:
    raise NotImplementedError(
        "`bounds_error` takes no effect under JIT")

  _check_arraylike("RegularGridInterpolator", values)
  n_grid_axes = len(points)
  if n_grid_axes > values.ndim:
    raise ValueError(
        f"there are {n_grid_axes} point arrays, but values has {values.ndim} dimensions")
  values, = _promote_dtypes_inexact(values)

  if fill_value is not None:
    _check_arraylike("RegularGridInterpolator", fill_value)
    fill_value = asarray(fill_value)
    compatible = can_cast(fill_value.dtype, values.dtype, casting='same_kind')
    if not compatible:
      raise ValueError(
          "fill_value must be either 'None' or of a type compatible with values")
  self.fill_value = fill_value

  # TODO: assert sanity of `points` similar to SciPy but in a JIT-able way
  _check_arraylike("RegularGridInterpolator", *points)
  self.grid = tuple(map(asarray, points))
  self.values = values
def _convolve_nd(in1, in2, mode, *, precision):
  """N-dimensional convolution of ``in1`` and ``in2`` via ``lax.conv_general_dilated``.

  Args:
    in1, in2: non-empty arrays of equal rank. One must be at least as large
      as the other along every axis.
    mode: ``'full'``, ``'same'`` or ``'valid'`` (SciPy naming).
    precision: forwarded to ``lax.conv_general_dilated``.

  Returns:
    The convolution. In ``'same'`` mode its shape equals the shape of the
    ``in1`` that was passed in.

  Raises:
    ValueError: for a bad mode, a rank mismatch, an empty input, or when
      neither input dominates the other in every dimension.
  """
  if mode not in ["full", "same", "valid"]:
    raise ValueError("mode must be one of ['full', 'same', 'valid']")
  if in1.ndim != in2.ndim:
    raise ValueError("in1 and in2 must have the same number of dimensions")
  if in1.size == 0 or in2.size == 0:
    raise ValueError(f"zero-size arrays not supported in convolutions, got shapes {in1.shape} and {in2.shape}.")
  in1, in2 = _promote_dtypes_inexact(in1, in2)

  dim_pairs = list(zip(in1.shape, in2.shape))
  in1_dominates = all(a >= b for a, b in dim_pairs)
  in2_dominates = all(a <= b for a, b in dim_pairs)
  if not in1_dominates and not in2_dominates:
    raise ValueError("One input must be smaller than the other in every dimension.")

  # Shape of in2 as passed in. The 'same' padding uses it so that the output
  # comes out the size of the original in1 whether or not we swap.
  ref_shape = in2.shape
  # Put the larger operand on the lhs and the smaller one in the kernel slot
  # (equal shapes also take this branch).
  if in2_dominates:
    in1, in2 = in2, in1
  kernel_shape = in2.shape
  # conv_general_dilated computes a correlation; flipping the kernel makes it
  # a convolution.
  kernel = jnp.flip(in2)

  if mode == 'valid':
    padding = [(0, 0)] * len(kernel_shape)
  elif mode == 'same':
    padding = [(k - 1 - (r - 1) // 2, k - r + (r - 1) // 2)
               for k, r in zip(kernel_shape, ref_shape)]
  else:  # 'full'
    padding = [(k - 1, k - 1) for k in kernel_shape]

  window_strides = (1,) * len(kernel_shape)
  # Singleton batch/feature dims in front of the spatial dims.
  out = lax.conv_general_dilated(in1[None, None], kernel[None, None],
                                 window_strides, padding, precision=precision)
  return out[0, 0]
def logpdf(x, mean, cov):
  """Log-density of a (possibly batched) multivariate normal distribution.

  Args:
    x: evaluation points. When ``mean`` is non-scalar, the last axis indexes
      the components.
    mean: scalar (univariate case) or array whose last axis has length ``n``.
    cov: scalar variance, or covariance matrices of shape ``(..., n, n)``.

  Returns:
    Log-density, reduced over the component axis (one value per batch
    element when ``cov`` is batched).

  Raises:
    ValueError: if a matrix ``cov`` does not end in ``(n, n)``.
  """
  x, mean, cov = _promote_dtypes_inexact(x, mean, cov)
  if not mean.shape:
    return (-1 / 2 * jnp.square(x - mean) / cov
            - 1 / 2 * (np.log(2 * np.pi) + jnp.log(cov)))
  n = mean.shape[-1]
  if not np.shape(cov):
    y = x - mean
    return (-1 / 2 * jnp.einsum('...i,...i->...', y, y) / cov
            - n / 2 * (np.log(2 * np.pi) + jnp.log(cov)))
  if cov.ndim < 2 or cov.shape[-2:] != (n, n):
    raise ValueError(
        "multivariate_normal.logpdf got incompatible shapes")
  # Assumes `cholesky` yields the lower factor (it is solved with lower=True).
  L = cholesky(cov)
  # y solves y @ L.T = x - mean, so |y|^2 is the squared Mahalanobis distance.
  y = triangular_solve(L, x - mean, lower=True, transpose_a=True)
  # 0.5 * log det(cov) = sum(log(diag(L))). Take the diagonal over the
  # trailing two axes so that a batched cov (..., n, n) gets one term per
  # matrix. The default diagonal() would use the two leading axes, and a bare
  # .sum() would collapse the batch.
  return (-1 / 2 * jnp.einsum('...i,...i->...', y, y)
          - n / 2 * np.log(2 * np.pi)
          - jnp.log(L.diagonal(axis1=-1, axis2=-2)).sum(-1))
def logpdf(x, alpha):
  """Dirichlet log-density of ``x`` with concentration parameters ``alpha``.

  Args:
    x: array whose leading axis has length K (= ``len(alpha)``) or K - 1. In
      the K - 1 case the missing last coordinate is filled in as
      ``1 - sum(x)``. Any further axes are treated as a batch.
    alpha: 1-D concentration parameters of length K.

  Returns:
    Log-density per batch element, or ``-inf`` where ``_is_simplex(x)`` is
    false.

  Raises:
    ValueError: if ``alpha`` is not 1-D, or if the leading length of ``x`` is
      neither K nor K - 1.
  """
  x, alpha = _promote_dtypes_inexact(x, alpha)
  if alpha.ndim != 1:
    raise ValueError(
      f"`alpha` must be one-dimensional; got alpha.shape={alpha.shape}"
    )
  k = alpha.shape[0]
  if x.shape[0] not in (k, k - 1):
    raise ValueError(
      "`x` must have either the same number of entries as `alpha` "
      f"or one entry fewer; got x.shape={x.shape}, alpha.shape={alpha.shape}"
    )
  one = lax._const(x, 1)
  if x.shape[0] == k - 1:
    last = lax.sub(one, x.sum(0, keepdims=True))
    x = jnp.concatenate([x, last], axis=0)

  # log B(alpha) = sum(lgamma(alpha_i)) - lgamma(sum(alpha_i))
  log_beta = jnp.sum(gammaln(alpha)) - gammaln(jnp.sum(alpha))
  if x.ndim > 1:
    # Give alpha trailing singleton axes so it lines up with x's batch axes.
    alpha = lax.broadcast_in_dim(alpha, alpha.shape + (1,) * (x.ndim - 1), (0,))
  log_kernel = jnp.sum(xlogy(lax.sub(alpha, one), x), axis=0)
  log_probs = lax.sub(log_kernel, log_beta)
  return jnp.where(_is_simplex(x), log_probs, -jnp.inf)
def _spectral_helper(x, y, fs=1.0, window='hann', nperseg=None,
                     noverlap=None, nfft=None, detrend_type='constant',
                     return_onesided=True, scaling='density', axis=-1,
                     mode='psd', boundary=None, padded=False):
  """LAX-backend implementation of `scipy.signal._spectral_helper`.

  Unlike the original helper function, `y` can be None for explicitly
  indicating auto-spectral (non cross-spectral) computation. In addition to
  this, `detrend` argument is renamed to `detrend_type` for avoiding internal
  name overlap.

  As in SciPy, a callable ``detrend_type`` is called directly when
  ``axis == -1``. For any other ``axis`` (taken exactly as the caller passed
  it), it receives the segmented data with the within-segment time axis moved
  to that ``axis`` position.

  Returns:
    ``(freqs, time, result)``: sample frequencies, segment times, and the
    spectral result with its frequency axis moved back to ``axis``.

  Raises:
    ValueError: for an invalid ``mode``, ``boundary``, ``scaling`` or segment
      parameters, or for incompatible ``x``/``y`` shapes.
  """
  if mode not in ('psd', 'stft'):
    raise ValueError(f"Unknown value for mode {mode}, "
                     "must be one of: ('psd', 'stft')")

  def make_pad(pad_mode, **kwargs):
    # Pads n samples on both ends of `axis` using jnp.pad mode `pad_mode`.
    def pad(x, n, axis=-1):
      pad_width = [(0, 0) for unused_n in range(x.ndim)]
      pad_width[axis] = (n, n)
      return jnp.pad(x, pad_width, pad_mode, **kwargs)
    return pad

  boundary_funcs = {
    'even': make_pad('reflect'),
    'odd': odd_ext,
    'constant': make_pad('edge'),
    'zeros': make_pad('constant', constant_values=0.0),
    None: lambda x, *args, **kwargs: x
  }

  # Check/ normalize inputs
  if boundary not in boundary_funcs:
    raise ValueError(f"Unknown boundary option '{boundary}', "
                     f"must be one of: {list(boundary_funcs.keys())}")

  axis = jax.core.concrete_or_error(operator.index, axis, "axis of windowed-FFT")
  # The axis exactly as given (possibly negative); the callable-detrend
  # wrapper below is keyed on it, as in SciPy.
  user_axis = axis
  axis = canonicalize_axis(axis, x.ndim)

  if y is None:
    _check_arraylike('spectral_helper', x)
    x, = _promote_dtypes_inexact(x)
    outershape = tuple_delete(x.shape, axis)
  else:
    if mode != 'psd':
      raise ValueError("two-argument mode is available only when mode=='psd'")
    _check_arraylike('spectral_helper', x, y)
    x, y = _promote_dtypes_inexact(x, y)
    if x.ndim != y.ndim:
      raise ValueError(
          f"two-arguments must have the same rank ({x.ndim} vs {y.ndim}).")
    # Check if we can broadcast the outer axes together
    try:
      outershape = jnp.broadcast_shapes(tuple_delete(x.shape, axis),
                                        tuple_delete(y.shape, axis))
    except ValueError as err:
      raise ValueError('x and y cannot be broadcast together.') from err

  result_dtype = dtypes._to_complex_dtype(x.dtype)
  freq_dtype = np.finfo(result_dtype).dtype

  if nperseg is not None:  # if specified by user
    nperseg = jax.core.concrete_or_error(int, nperseg, "nperseg of windowed-FFT")
    if nperseg < 1:
      raise ValueError('nperseg must be a positive integer')
  # parse window; if array like, then set nperseg = win.shape
  win, nperseg = signal_helper._triage_segments(
      window, nperseg, input_length=x.shape[axis], dtype=result_dtype)

  if noverlap is None:
    noverlap = nperseg // 2
  else:
    noverlap = jax.core.concrete_or_error(int, noverlap, "noverlap of windowed-FFT")
  if nfft is None:
    nfft = nperseg
  else:
    nfft = jax.core.concrete_or_error(int, nfft, "nfft of windowed-FFT")

  # Special cases for size == 0
  if y is None:
    if x.size == 0:
      return (jnp.zeros(x.shape, freq_dtype), jnp.zeros(x.shape, freq_dtype),
              jnp.zeros(x.shape, result_dtype))
  elif x.size == 0 or y.size == 0:
    shape = tuple_insert(outershape, min([x.shape[axis], y.shape[axis]]), axis)
    return (jnp.zeros(shape, freq_dtype), jnp.zeros(shape, freq_dtype),
            jnp.zeros(shape, result_dtype))

  # Move time-axis to the end
  x = jnp.moveaxis(x, axis, -1)
  if y is not None and y.ndim > 1:
    y = jnp.moveaxis(y, axis, -1)

  # Check if x and y are the same length, zero-pad if necessary
  if y is not None and x.shape[-1] != y.shape[-1]:
    if x.shape[-1] < y.shape[-1]:
      pad_shape = list(x.shape)
      pad_shape[-1] = y.shape[-1] - x.shape[-1]
      x = jnp.concatenate((x, jnp.zeros_like(x, shape=pad_shape)), -1)
    else:
      pad_shape = list(y.shape)
      pad_shape[-1] = x.shape[-1] - y.shape[-1]
      y = jnp.concatenate((y, jnp.zeros_like(y, shape=pad_shape)), -1)

  if nfft < nperseg:
    raise ValueError('nfft must be greater than or equal to nperseg.')
  if noverlap >= nperseg:
    raise ValueError('noverlap must be less than nperseg.')
  nstep = nperseg - noverlap

  # Apply paddings
  if boundary is not None:
    ext_func = boundary_funcs[boundary]
    x = ext_func(x, nperseg // 2, axis=-1)
    if y is not None:
      y = ext_func(y, nperseg // 2, axis=-1)

  if padded:
    # Pad to integer number of windowed segments
    # I.e make x.shape[-1] = nperseg + (nseg-1)*nstep, with integer nseg
    nadd = (-(x.shape[-1] - nperseg) % nstep) % nperseg
    x = jnp.concatenate(
        (x, jnp.zeros_like(x, shape=(*x.shape[:-1], nadd))), axis=-1)
    if y is not None:
      y = jnp.concatenate(
          (y, jnp.zeros_like(y, shape=(*y.shape[:-1], nadd))), axis=-1)

  # Handle detrending and window functions. Segments are detrended along
  # their last axis (see the string case's axis=-1); this assumes _fft_helper
  # passes segments with within-segment time last.
  if not detrend_type:
    detrend_func = lambda d: d
  elif not callable(detrend_type):
    detrend_func = partial(detrend, type=detrend_type, axis=-1)
  elif user_axis != -1:
    # Wrap this function so that it receives a shape that it could
    # reasonably expect to receive: move the within-segment time axis to the
    # caller's `axis`, detrend, then move it back (same order as SciPy).
    # The comparison uses `user_axis` because the canonicalized `axis` is
    # never -1, which would make this branch unconditional.
    def detrend_func(d):
      d = jnp.moveaxis(d, -1, user_axis)
      d = detrend_type(d)
      return jnp.moveaxis(d, user_axis, -1)
  else:
    detrend_func = detrend_type

  # Determine scale
  if scaling == 'density':
    scale = 1.0 / (fs * (win * win).sum())
  elif scaling == 'spectrum':
    scale = 1.0 / win.sum()**2
  else:
    raise ValueError(f'Unknown scaling: {scaling}')
  if mode == 'stft':
    scale = jnp.sqrt(scale)

  # Determine onesided/ two-sided
  if return_onesided:
    sides = 'onesided'
    if jnp.iscomplexobj(x) or jnp.iscomplexobj(y):
      sides = 'twosided'
      warnings.warn('Input data is complex, switching to '
                    'return_onesided=False')
  else:
    sides = 'twosided'

  if sides == 'twosided':
    freqs = jax.numpy.fft.fftfreq(nfft, 1 / fs).astype(freq_dtype)
  else:
    freqs = jax.numpy.fft.rfftfreq(nfft, 1 / fs).astype(freq_dtype)

  # Perform the windowed FFTs
  result = _fft_helper(x.astype(result_dtype), win, detrend_func,
                       nperseg, noverlap, nfft, sides)

  if y is not None:
    # All the same operations on the y data
    result_y = _fft_helper(y.astype(result_dtype), win, detrend_func,
                           nperseg, noverlap, nfft, sides)
    result = jnp.conjugate(result) * result_y
  elif mode == 'psd':
    result = jnp.conjugate(result) * result

  result *= scale

  if sides == 'onesided' and mode == 'psd':
    # Double every bin except DC (and Nyquist when nfft is even) to account
    # for the discarded negative frequencies.
    end = None if nfft % 2 else -1
    result = result.at[..., 1:end].mul(2)

  time = jnp.arange(nperseg / 2, x.shape[-1] - nperseg / 2 + 1,
                    nstep, dtype=freq_dtype) / fs
  if boundary is not None:
    time -= (nperseg / 2) / fs

  result = result.astype(result_dtype)

  # All imaginary parts are zero anyways
  if y is None and mode != 'stft':
    result = result.real

  # Move frequency axis back to axis where the data came from
  result = jnp.moveaxis(result, -1, axis)

  return freqs, time, result