def logpdf(self, x):
  """Evaluate the log of the kernel density estimate at ``x``.

  Args:
    x: array-like query points; normalized by ``self._reshape_points``.

  Returns:
    The first column of ``_gaussian_kernel_eval``'s result, evaluated with a
    leading ``True`` flag (presumably log-space evaluation; ``evaluate``
    passes ``False``). Presumably one value per query point.
  """
  _check_arraylike("logpdf", x)
  query = self._reshape_points(x)
  log_density = _gaussian_kernel_eval(
      True, self.dataset.T, self.weights[:, None], query.T, self.inv_cov)
  return log_density[:, 0]
def evaluate(self, points):
  """Evaluate the kernel density estimate at ``points``.

  Args:
    points: array-like query points; normalized by ``self._reshape_points``.

  Returns:
    The first column of ``_gaussian_kernel_eval``'s result, evaluated with a
    leading ``False`` flag (presumably linear-space evaluation; ``logpdf``
    passes ``True``). Presumably one value per query point.
  """
  _check_arraylike("evaluate", points)
  reshaped = self._reshape_points(points)
  density = _gaussian_kernel_eval(
      False, self.dataset.T, self.weights[:, None], reshaped.T, self.inv_cov)
  return density[:, 0]
def _segment_update(name: str,
                    data: Array,
                    segment_ids: Array,
                    scatter_op: Callable,
                    num_segments: Optional[int] = None,
                    indices_are_sorted: bool = False,
                    unique_indices: bool = False,
                    bucket_size: Optional[int] = None,
                    reducer: Optional[Callable] = None,
                    mode: Optional[lax.GatherScatterMode] = None) -> Array:
  """Scatter ``data`` into ``num_segments`` segments with ``scatter_op``.

  Args:
    name: name of the calling public op (e.g. ``"segment_sum"``); used in
      argument-checking and error messages.
    data: values to scatter; the leading axis is indexed by ``segment_ids``.
    segment_ids: segment index for each leading-axis entry of ``data``.
    scatter_op: scatter primitive used to combine values within a segment.
    num_segments: number of output segments. Defaults to
      ``max(segment_ids) + 1`` and must be concrete and non-negative.
    indices_are_sorted, unique_indices: hints forwarded to the scatter.
    bucket_size: when given, entries are scattered into
      ``ceil(len(segment_ids) / bucket_size)`` separate buckets which are then
      combined with ``reducer`` (for numerical stability).
    reducer: reduction over the bucket axis; required when more than one
      bucket is used.
    mode: scatter mode; defaults to ``lax.GatherScatterMode.FILL_OR_DROP``.

  Returns:
    An array of shape ``(num_segments,) + data.shape[1:]`` and ``data``'s
    dtype; segments that receive no data keep the identity of ``scatter_op``.

  Raises:
    ValueError: if ``num_segments`` is negative.
  """
  jnp._check_arraylike(name, data, segment_ids)
  mode = lax.GatherScatterMode.FILL_OR_DROP if mode is None else mode
  data = jnp.asarray(data)
  segment_ids = jnp.asarray(segment_ids)
  dtype = data.dtype
  if num_segments is None:
    num_segments = jnp.max(segment_ids) + 1
  # Report the actual calling op rather than always "segment_sum".
  num_segments = core.concrete_or_error(
      int, num_segments, f"{name}() `num_segments` argument.")
  if num_segments < 0:
    raise ValueError("num_segments must be non-negative.")

  num_buckets = (1 if bucket_size is None
                 else util.ceil_of_ratio(segment_ids.size, bucket_size))
  # `num_buckets` is 0 for empty `segment_ids`; the single-bucket path then
  # simply returns the identity-filled output.
  if num_buckets <= 1:
    out = jnp.full((num_segments,) + data.shape[1:],
                   _get_identity(scatter_op, dtype), dtype=dtype)
    return _scatter_update(out, segment_ids, data, scatter_op,
                           indices_are_sorted, unique_indices,
                           normalize_indices=False, mode=mode)

  # Bucketize indices and perform segment_update on each bucket to improve
  # numerical stability for operations like product and sum.
  assert reducer is not None
  out = jnp.full((num_buckets, num_segments) + data.shape[1:],
                 _get_identity(scatter_op, dtype), dtype=dtype)
  # Entry i goes to bucket i // bucket_size and segment segment_ids[i].
  out = _scatter_update(
      out, np.index_exp[lax.div(jnp.arange(segment_ids.shape[0]), bucket_size),
                        segment_ids[None, :]],
      data, scatter_op, indices_are_sorted, unique_indices,
      normalize_indices=False, mode=mode)
  return reducer(out, axis=0).astype(dtype)
def __init__(self, dataset, bw_method=None, weights=None):
  """Build a Gaussian kernel density estimate from ``dataset``.

  Args:
    dataset: array-like samples, promoted to at least 2-D and unpacked as
      ``d, n = dataset.shape`` (``d`` variables as rows, ``n`` samples).
    bw_method: bandwidth rule. ``"scott"`` (the default when None) or
      ``"silverman"``, a non-string scalar used directly as the bandwidth
      factor, or a callable invoked with ``self`` that returns the factor.
    weights: optional 1-D array-like of length ``n``, normalized to sum to 1.
      Uniform weights ``1 / n`` are used when omitted.

  Raises:
    NotImplementedError: if ``dataset`` is complex.
    ValueError: if ``dataset`` has fewer than two elements, ``weights`` is not
      1-D of length ``n``, or ``bw_method`` is not recognized.
  """
  _check_arraylike("gaussian_kde", dataset)
  dataset = jnp.atleast_2d(dataset)
  if jnp.issubdtype(lax.dtype(dataset), jnp.complexfloating):
    raise NotImplementedError(
        "gaussian_kde does not support complex data")
  if not dataset.size > 1:
    raise ValueError("`dataset` input should have multiple elements.")
  d, n = dataset.shape

  if weights is not None:
    _check_arraylike("gaussian_kde", weights)
    dataset, weights = _promote_dtypes_inexact(dataset, weights)
    weights = jnp.atleast_1d(weights)
    # Validate the shape before doing arithmetic on the weights.
    if weights.ndim != 1:
      raise ValueError("`weights` input should be one-dimensional.")
    if len(weights) != n:
      raise ValueError("`weights` input should be of length n")
    weights = weights / jnp.sum(weights)
  else:
    dataset, = _promote_dtypes_inexact(dataset)
    weights = jnp.full(n, 1.0 / n, dtype=dataset.dtype)

  self._setattr("dataset", dataset)
  self._setattr("weights", weights)
  # Effective sample size; computed locally instead of relying on the return
  # value of `_setattr`.
  neff = 1 / jnp.sum(weights**2)
  self._setattr("neff", neff)

  bw_error = ("`bw_method` should be 'scott', 'silverman', a scalar, "
              "or a callable.")
  if bw_method is None:
    bw_method = "scott"
  # Dispatch on strings first so an array-valued `bw_method` is never
  # compared against a string.
  if isinstance(bw_method, str):
    if bw_method == "scott":
      factor = jnp.power(neff, -1. / (d + 4))
    elif bw_method == "silverman":
      factor = jnp.power(neff * (d + 2) / 4.0, -1. / (d + 4))
    else:
      raise ValueError(bw_error)
  elif jnp.isscalar(bw_method):
    factor = bw_method
  elif callable(bw_method):
    factor = bw_method(self)
  else:
    raise ValueError(bw_error)

  data_covariance = jnp.atleast_2d(
      jnp.cov(dataset, rowvar=1, bias=False, aweights=weights))
  data_inv_cov = jnp.linalg.inv(data_covariance)
  # The kernel covariance is the data covariance scaled by factor**2.
  covariance = data_covariance * factor**2
  inv_cov = data_inv_cov / factor**2
  self._setattr("covariance", covariance)
  self._setattr("inv_cov", inv_cov)
def _segment_update(name: str,
                    data: Array,
                    segment_ids: Array,
                    scatter_op: Callable,
                    num_segments: Optional[int] = None,
                    indices_are_sorted: bool = False,
                    unique_indices: bool = False,
                    bucket_size: Optional[int] = None,
                    reducer: Optional[Callable] = None) -> Array:
  """Scatter ``data`` into ``num_segments`` segments with ``scatter_op``.

  Args:
    name: name of the calling public op (e.g. ``"segment_sum"``); used in
      argument-checking and error messages.
    data: values to scatter; the leading axis is indexed by ``segment_ids``.
    segment_ids: segment index for each leading-axis entry of ``data``.
    scatter_op: scatter primitive used to combine values within a segment.
    num_segments: number of output segments. Defaults to
      ``max(segment_ids) + 1`` and must be concrete and non-negative.
    indices_are_sorted, unique_indices: hints forwarded to the scatter.
    bucket_size: when given, ``data``/``segment_ids`` are split into
      ``ceil(len(segment_ids) / bucket_size)`` chunks that are reduced
      separately and then combined with ``reducer``.
    reducer: reduction over the stacked per-chunk results; required when more
      than one chunk is used.

  Returns:
    An array of shape ``(num_segments,) + data.shape[1:]`` and ``data``'s
    dtype; segments that receive no data keep the identity of ``scatter_op``.

  Raises:
    ValueError: if ``num_segments`` is negative.
  """
  jnp._check_arraylike(name, data, segment_ids)
  data = jnp.asarray(data)
  segment_ids = jnp.asarray(segment_ids)
  dtype = data.dtype
  if num_segments is None:
    num_segments = jnp.max(segment_ids) + 1
  # Report the actual calling op rather than always "segment_sum".
  num_segments = core.concrete_or_error(
      int, num_segments, f"{name}() `num_segments` argument.")
  if num_segments < 0:
    raise ValueError("num_segments must be non-negative.")

  num_buckets = (1 if bucket_size is None
                 else util.ceil_of_ratio(segment_ids.size, bucket_size))
  # `num_buckets` is 0 for empty `segment_ids`; `array_split` cannot handle
  # that, so take the single-bucket path, which returns identity values.
  if num_buckets <= 1:
    # Only allocated where it is used.
    out = jnp.full((num_segments,) + data.shape[1:],
                   _get_identity(scatter_op, dtype), dtype=dtype)
    return _scatter_update(out, segment_ids, data, scatter_op,
                           indices_are_sorted, unique_indices,
                           normalize_indices=False)

  # Bucketize indices and perform segment_update on each bucket to improve
  # numerical stability for operations like product and sum.
  assert reducer is not None
  # Recursive calls omit `bucket_size`, so each one takes the single-bucket
  # path above.
  partials = [
      _segment_update(name, sub_data, sub_segment_ids, scatter_op,
                      num_segments, indices_are_sorted, unique_indices)
      for sub_data, sub_segment_ids in zip(
          jnp.array_split(data, num_buckets),
          jnp.array_split(segment_ids, num_buckets))
  ]
  return reducer(jnp.stack(partials), axis=0).astype(dtype)
def _overlap_and_add(x, step_size):
  """Utility function compatible with tf.signal.overlap_and_add.

  Args:
    x: An array with `(..., frames, frame_length)`-shape.
    step_size: A positive integer giving the offset between the starts of
      consecutive frames. Values greater than or equal to `frame_length` are
      accepted; the frames then do not overlap.

  Returns:
    An array with `(..., output_size)`-shape containing overlapped signal,
    where `output_size = step_size * (frames - 1) + frame_length`.

  Raises:
    ValueError: if `x` has fewer than two dimensions or `step_size` is not
      positive.
  """
  _check_arraylike("_overlap_and_add", x)
  x = jnp.asarray(x)
  step_size = jax.core.concrete_or_error(int, step_size,
                                         "step_size for overlap_and_add")
  if step_size < 1:
    raise ValueError('step_size must be a positive integer.')
  if x.ndim < 2:
    raise ValueError('Input must have (..., frames, frame_length) shape.')

  *batch_shape, nframes, segment_len = x.shape
  flat_batchsize = np.prod(batch_shape, dtype=np.int64)
  x = x.reshape((flat_batchsize, nframes, segment_len))
  output_size = step_size * (nframes - 1) + segment_len
  nstep_per_segment = 1 + (segment_len - 1) // step_size

  # Here, we use shorter notation for axes.
  # B: batch_size, N: nframes, S: nstep_per_segment,
  # T: step_size (segment_len padded to S * T, then split into S chunks)

  padded_segment_len = nstep_per_segment * step_size
  x = jnp.pad(x, ((0, 0), (0, 0), (0, padded_segment_len - segment_len)))
  x = x.reshape((flat_batchsize, nframes, nstep_per_segment, step_size))

  # For obtaining shifted signals, this routine reinterprets the flattened
  # array with a shrunk axis. With appropriate truncation/padding, this
  # operation pushes the last padded elements of the previous row to the head
  # of the current row, so chunk `s` of frame `r` lands at offset `(r + s) * T`
  # of row `s`. That needs at least N + S rows, i.e. S zero rows of padding
  # (padding by N rows only works when N >= S).
  # See implementation of `overlap_and_add` in Tensorflow for details.
  x = x.transpose((0, 2, 1, 3))  # x: (B, S, N, T)
  x = jnp.pad(x, ((0, 0), (0, 0), (0, nstep_per_segment), (0, 0)))
  # x: (B, S, N + S, T)
  shrunk = x.shape[2] - 1
  x = x.reshape((flat_batchsize, -1))
  x = x[:, :(nstep_per_segment * shrunk * step_size)]
  x = x.reshape((flat_batchsize, nstep_per_segment, shrunk * step_size))

  # Finally, sum shifted segments, and truncate results to the output_size.
  x = x.sum(axis=1)[:, :output_size]
  return x.reshape(tuple(batch_shape) + (-1,))
def __init__(self,
             points,
             values,
             method="linear",
             bounds_error=False,
             fill_value=nan):
  """Store a regular-grid interpolation problem.

  Args:
    points: sequence of array-like grid coordinates, one per dimension; there
      may not be more of them than ``values`` has dimensions.
    values: array-like data on the grid; promoted to an inexact dtype.
    method: ``"linear"`` or ``"nearest"``.
    bounds_error: must be falsy; a truthy value raises NotImplementedError.
    fill_value: ``None`` or an array-like whose dtype can be cast to the
      promoted ``values`` dtype under ``'same_kind'`` casting (default
      ``nan``).

  Raises:
    ValueError: for an unknown ``method``, too many point arrays, or an
      incompatible ``fill_value``.
    NotImplementedError: if ``bounds_error`` is truthy.
  """
  supported_methods = ("linear", "nearest")
  if method not in supported_methods:
    raise ValueError(f"method {method!r} is not defined")
  self.method = method
  self.bounds_error = bounds_error
  if self.bounds_error:
    raise NotImplementedError("`bounds_error` takes no effect under JIT")

  _check_arraylike("RegularGridInterpolator", values)
  n_point_arrays = len(points)
  if n_point_arrays > values.ndim:
    raise ValueError(
        f"there are {len(points)} point arrays, but values has {values.ndim} dimensions")
  (values,) = _promote_dtypes_inexact(values)

  if fill_value is not None:
    _check_arraylike("RegularGridInterpolator", fill_value)
    fill_value = asarray(fill_value)
    castable = can_cast(fill_value.dtype, values.dtype, casting='same_kind')
    if not castable:
      raise ValueError(
          "fill_value must be either 'None' or of a type compatible with values")
  self.fill_value = fill_value

  # TODO: assert sanity of `points` similar to SciPy but in a JIT-able way
  _check_arraylike("RegularGridInterpolator", *points)
  self.grid = tuple(map(asarray, points))
  self.values = values
def _ndim_coords_from_arrays(points, ndim=None):
  """Convert a tuple of coordinate arrays to a (..., ndim)-shaped array.

  Args:
    points: either a tuple of coordinate arrays (broadcast together and
      stacked along a new last axis as ``float``), or a single array-like. A
      1-tuple is unwrapped first.
    ndim: for 1-D array input, the size of the trailing axis after reshaping
      to ``(-1, ndim)``; when None, 1-D input becomes ``(-1, 1)``.

  Returns:
    The coordinates as an array whose last axis indexes dimensions.

  Raises:
    ValueError: if the broadcast coordinate arrays differ in shape.
  """
  # A 1-tuple is the bare argument it wraps.
  if isinstance(points, tuple) and len(points) == 1:
    points = points[0]

  if not isinstance(points, tuple):
    _check_arraylike("_ndim_coords_from_arrays", points)
    arr = asarray(points)  # SciPy: asanyarray(points)
    if arr.ndim == 1:
      arr = arr.reshape(-1, 1 if ndim is None else ndim)
    return arr

  coords = broadcast_arrays(*points)
  reference_shape = coords[0].shape
  if any(c.shape != reference_shape for c in coords[1:]):
    raise ValueError("coordinate arrays do not have the same shape")
  stacked = empty(reference_shape + (len(points),), dtype=float)
  for dim_idx, coord in enumerate(coords):
    stacked = stacked.at[..., dim_idx].set(coord)
  return stacked
def _spectral_helper(x, y,
                     fs=1.0, window='hann', nperseg=None, noverlap=None,
                     nfft=None, detrend_type='constant', return_onesided=True,
                     scaling='density', axis=-1, mode='psd', boundary=None,
                     padded=False):
  """LAX-backend implementation of `scipy.signal._spectral_helper`.

  Unlike the original helper function, `y` can be None for explicitly
  indicating auto-spectral (non cross-spectral) computation.  In addition to
  this, `detrend` argument is renamed to `detrend_type` for avoiding internal
  name overlap.

  Returns:
    A tuple ``(freqs, time, result)``: the FFT sample frequencies, the
    segment times, and the (cross-)spectral result whose last axis is moved
    back to ``axis``.

  Raises:
    ValueError: for an unknown ``mode``, ``boundary`` or ``scaling``, invalid
      ``nperseg``/``noverlap``/``nfft``, or a ``y`` that cannot be combined
      with ``x``.
  """
  if mode not in ('psd', 'stft'):
    raise ValueError(f"Unknown value for mode {mode}, "
                     "must be one of: ('psd', 'stft')")

  def make_pad(mode, **kwargs):
    def pad(x, n, axis=-1):
      pad_width = [(0, 0) for unused_n in range(x.ndim)]
      pad_width[axis] = (n, n)
      return jnp.pad(x, pad_width, mode, **kwargs)
    return pad

  # 'even' uses jnp.pad's 'reflect' mode; 'constant' uses 'edge' (repeat the
  # edge value).
  boundary_funcs = {
    'even': make_pad('reflect'),
    'odd': odd_ext,
    'constant': make_pad('edge'),
    'zeros': make_pad('constant', constant_values=0.0),
    None: lambda x, *args, **kwargs: x
  }

  # Check/ normalize inputs
  if boundary not in boundary_funcs:
    raise ValueError(f"Unknown boundary option '{boundary}', "
                     f"must be one of: {list(boundary_funcs.keys())}")

  # Validate `x` before `x.ndim` / `x.shape` are used below, so non-array
  # input is rejected with the standard error.
  _check_arraylike("_spectral_helper", x)
  x = jnp.asarray(x)

  axis = jax.core.concrete_or_error(operator.index, axis,
                                    "axis of windowed-FFT")
  # Keep the axis as given by the caller: the detrend wrapper below uses it
  # the way SciPy does (before canonicalization, -1 means "already last").
  user_axis = axis
  axis = canonicalize_axis(axis, x.ndim)

  if nperseg is not None:  # if specified by user
    nperseg = jax.core.concrete_or_error(int, nperseg,
                                         "nperseg of windowed-FFT")
    if nperseg < 1:
      raise ValueError('nperseg must be a positive integer')
  # parse window; if array like, then set nperseg = win.shape
  win, nperseg = signal_helper._triage_segments(
      window, nperseg, input_length=x.shape[axis])

  if noverlap is None:
    noverlap = nperseg // 2
  else:
    noverlap = jax.core.concrete_or_error(int, noverlap,
                                          "noverlap of windowed-FFT")
  if nfft is None:
    nfft = nperseg
  else:
    nfft = jax.core.concrete_or_error(int, nfft, "nfft of windowed-FFT")

  if y is None:
    outdtype = jax.dtypes.canonicalize_dtype(np.result_type(x, np.complex64))
  else:
    _check_arraylike("_spectral_helper", y)
    y = jnp.asarray(y)
    outdtype = jax.dtypes.canonicalize_dtype(
        np.result_type(x, y, np.complex64))
    if mode != 'psd':
      raise ValueError("two-argument mode is available only when mode=='psd'")
    if x.ndim != y.ndim:
      raise ValueError(
          f"two-arguments must have the same rank ({x.ndim} vs {y.ndim}).")
    # Check if we can broadcast the outer axes together
    try:
      outershape = jnp.broadcast_shapes(tuple_delete(x.shape, axis),
                                        tuple_delete(y.shape, axis))
    except ValueError as e:
      raise ValueError('x and y cannot be broadcast together.') from e

  # Special cases for size == 0
  if y is None:
    if x.size == 0:
      return jnp.zeros(x.shape), jnp.zeros(x.shape), jnp.zeros(x.shape)
  else:
    if x.size == 0 or y.size == 0:
      outshape = tuple_insert(
          outershape, min([x.shape[axis], y.shape[axis]]), axis)
      emptyout = jnp.zeros(outshape)
      return emptyout, emptyout, emptyout

  # Move time-axis to the end. `axis` is canonical (non-negative), so this is
  # a no-op when it already is the last axis.
  if x.ndim > 1:
    x = jnp.moveaxis(x, axis, -1)
    if y is not None and y.ndim > 1:
      y = jnp.moveaxis(y, axis, -1)

  # Check if x and y are the same length, zero-pad if necessary
  if y is not None:
    if x.shape[-1] != y.shape[-1]:
      if x.shape[-1] < y.shape[-1]:
        pad_shape = list(x.shape)
        pad_shape[-1] = y.shape[-1] - x.shape[-1]
        x = jnp.concatenate((x, jnp.zeros(pad_shape)), -1)
      else:
        pad_shape = list(y.shape)
        pad_shape[-1] = x.shape[-1] - y.shape[-1]
        y = jnp.concatenate((y, jnp.zeros(pad_shape)), -1)

  if nfft < nperseg:
    raise ValueError('nfft must be greater than or equal to nperseg.')
  if noverlap >= nperseg:
    raise ValueError('noverlap must be less than nperseg.')
  nstep = nperseg - noverlap

  # Apply paddings
  if boundary is not None:
    ext_func = boundary_funcs[boundary]
    x = ext_func(x, nperseg // 2, axis=-1)
    if y is not None:
      y = ext_func(y, nperseg // 2, axis=-1)

  if padded:
    # Pad to integer number of windowed segments
    # I.e make x.shape[-1] = nperseg + (nseg-1)*nstep, with integer nseg
    nadd = (-(x.shape[-1] - nperseg) % nstep) % nperseg
    zeros_shape = list(x.shape[:-1]) + [nadd]
    x = jnp.concatenate((x, jnp.zeros(zeros_shape)), axis=-1)
    if y is not None:
      zeros_shape = list(y.shape[:-1]) + [nadd]
      y = jnp.concatenate((y, jnp.zeros(zeros_shape)), axis=-1)

  # Handle detrending and window functions
  if not detrend_type:
    def detrend_func(d):
      return d
  elif not hasattr(detrend_type, '__call__'):
    def detrend_func(d):
      return detrend(d, type=detrend_type, axis=-1)
  elif user_axis != -1:
    # Give a user-supplied callable its data with the samples on the axis the
    # caller named (as SciPy does). Presumably the segment array passed in
    # holds each segment's samples on its last axis; move that axis to
    # `user_axis`, detrend, then move it back.
    def detrend_func(d):
      d = jnp.moveaxis(d, -1, user_axis)
      d = detrend_type(d)
      return jnp.moveaxis(d, user_axis, -1)
  else:
    detrend_func = detrend_type

  if np.result_type(win, np.complex64) != outdtype:
    win = win.astype(outdtype)

  # Determine scale
  if scaling == 'density':
    scale = 1.0 / (fs * (win * win).sum())
  elif scaling == 'spectrum':
    scale = 1.0 / win.sum()**2
  else:
    raise ValueError(f'Unknown scaling: {scaling}')
  if mode == 'stft':
    scale = jnp.sqrt(scale)

  # Determine onesided/ two-sided
  if return_onesided:
    sides = 'onesided'
    if jnp.iscomplexobj(x) or jnp.iscomplexobj(y):
      sides = 'twosided'
      warnings.warn('Input data is complex, switching to '
                    'return_onesided=False')
  else:
    sides = 'twosided'

  if sides == 'twosided':
    freqs = jax.numpy.fft.fftfreq(nfft, 1 / fs)
  elif sides == 'onesided':
    freqs = jax.numpy.fft.rfftfreq(nfft, 1 / fs)

  # Perform the windowed FFTs
  result = _fft_helper(x, win, detrend_func, nperseg, noverlap, nfft, sides)

  if y is not None:
    # All the same operations on the y data
    result_y = _fft_helper(y, win, detrend_func, nperseg, noverlap, nfft,
                           sides)
    result = jnp.conjugate(result) * result_y
  elif mode == 'psd':
    result = jnp.conjugate(result) * result

  result *= scale

  if sides == 'onesided' and mode == 'psd':
    # Double every bin except DC (and Nyquist when nfft is even).
    end = None if nfft % 2 else -1
    result = result.at[..., 1:end].mul(2)

  time = jnp.arange(nperseg / 2, x.shape[-1] - nperseg / 2 + 1, nstep) / fs
  if boundary is not None:
    time -= (nperseg / 2) / fs

  result = result.astype(outdtype)

  # All imaginary parts are zero anyways
  if y is None and mode != 'stft':
    result = result.real

  # Move frequency axis back to axis where the data came from
  result = jnp.moveaxis(result, -1, axis)

  return freqs, time, result
def istft(Zxx, fs=1.0, window='hann', nperseg=None, noverlap=None, nfft=None,
          input_onesided=True, boundary=True, time_axis=-1, freq_axis=-2):
  """Inverse short-time Fourier transform (LAX backend of `scipy.signal.istft`).

  Args:
    Zxx: STFT array with at least two dimensions.
    fs: sampling frequency; the returned times are sample indices / `fs`.
    window: a window name or tuple passed to `osp_signal.get_window`
      (presumably `scipy.signal`), or a 1-D array of length `nperseg`.
    nperseg: samples per segment. Defaults to the frequency-axis length of
      `Zxx`, or `2 * (len - 1)` for one-sided input.
    noverlap: overlapping samples between segments; defaults to
      `nperseg // 2`. An explicit `0` means no overlap.
    nfft: FFT length; must be at least `nperseg`.
    input_onesided: if true, invert with `irfft`; otherwise with `ifft`.
    boundary: if true, trim `nperseg // 2` samples from both ends.
    time_axis, freq_axis: axes of `Zxx` holding segments and frequencies.

  Returns:
    A tuple `(time, x)` of sample times and the reconstructed signal.

  Raises:
    ValueError: for rank < 2 input, identical time/frequency axes, invalid
      `nperseg`/`nfft`/`noverlap`, or a window of the wrong shape.
  """
  # Input validation
  _check_arraylike("istft", Zxx)
  if Zxx.ndim < 2:
    raise ValueError('Input stft must be at least 2d!')
  freq_axis = canonicalize_axis(freq_axis, Zxx.ndim)
  time_axis = canonicalize_axis(time_axis, Zxx.ndim)
  if freq_axis == time_axis:
    raise ValueError('Must specify differing time and frequency axes!')

  Zxx = jnp.asarray(Zxx, dtype=jax.dtypes.canonicalize_dtype(
      np.result_type(Zxx, np.complex64)))

  n_default = (2 * (Zxx.shape[freq_axis] - 1) if input_onesided
               else Zxx.shape[freq_axis])

  # Compare against None explicitly: `nperseg or n_default` would silently
  # turn an explicit (invalid) 0 into the default instead of rejecting it.
  if nperseg is None:
    nperseg = n_default
  nperseg = jax.core.concrete_or_error(int, nperseg,
                                       "nperseg: segment length of STFT")
  if nperseg < 1:
    raise ValueError('nperseg must be a positive integer')

  if nfft is None:
    nfft = n_default
    if input_onesided and nperseg == n_default + 1:
      nfft += 1  # Odd nperseg, no FFT padding
  else:
    nfft = jax.core.concrete_or_error(int, nfft, "nfft of STFT")
  if nfft < nperseg:
    raise ValueError(
        f'FFT length ({nfft}) must be longer than nperseg ({nperseg}).')

  # Likewise, an explicit `noverlap=0` must not fall back to the default.
  if noverlap is None:
    noverlap = nperseg // 2
  noverlap = jax.core.concrete_or_error(int, noverlap, "noverlap of STFT")
  if noverlap >= nperseg:
    raise ValueError('noverlap must be less than nperseg.')
  nstep = nperseg - noverlap

  # Rearrange axes if necessary
  if time_axis != Zxx.ndim - 1 or freq_axis != Zxx.ndim - 2:
    outer_idxs = tuple(
        idx for idx in range(Zxx.ndim) if idx not in {time_axis, freq_axis})
    Zxx = jnp.transpose(Zxx, outer_idxs + (freq_axis, time_axis))

  # Perform IFFT
  ifunc = jax.numpy.fft.irfft if input_onesided else jax.numpy.fft.ifft
  # xsubs: [..., T, N], N is the number of frames, T is the frame length.
  xsubs = ifunc(Zxx, axis=-2, n=nfft)[..., :nperseg, :]

  # Get window as array
  if isinstance(window, (str, tuple)):
    win = osp_signal.get_window(window, nperseg)
    win = jnp.asarray(win)
  else:
    win = jnp.asarray(window)
    if len(win.shape) != 1:
      raise ValueError('window must be 1-D')
    if win.shape[0] != nperseg:
      raise ValueError('window must have length of {0}'.format(nperseg))
  win = win.astype(xsubs.dtype)

  xsubs *= win.sum()  # This takes care of the 'spectrum' scaling

  # make win broadcastable over xsubs
  win = win.reshape((1, ) * (xsubs.ndim - 2) + win.shape + (1, ))
  x = _overlap_and_add((xsubs * win).swapaxes(-2, -1), nstep)
  win_squared = jnp.repeat((win * win), xsubs.shape[-1], axis=-1)
  norm = _overlap_and_add(win_squared.swapaxes(-2, -1), nstep)

  # Remove extension points. Use an explicit stop index: when
  # `nperseg // 2 == 0`, the slice `0:-0` would be empty.
  if boundary:
    trim = nperseg // 2
    x = x[..., trim:x.shape[-1] - trim]
    norm = norm[..., trim:norm.shape[-1] - trim]
  # Normalize by the summed squared window, avoiding division by ~0.
  x /= jnp.where(norm > 1e-10, norm, 1.0)

  # Put axes back
  if x.ndim > 1:
    if time_axis != Zxx.ndim - 1:
      if freq_axis < time_axis:
        time_axis -= 1
      x = jnp.moveaxis(x, -1, time_axis)

  # NOTE(review): like SciPy, this uses `x.shape[0]`; for batched output that
  # is the leading axis, not necessarily the time axis — confirm intended.
  time = jnp.arange(x.shape[0]) / fs
  return time, x