def detrend(data, axis=-1, type='linear', bp=0, overwrite_data=None):
  """Remove a linear or constant trend from ``data`` along ``axis``.

  LAX-backend counterpart of ``scipy.signal.detrend``.

  Args:
    data: array-like input; it is promoted to an inexact dtype first.
    axis: axis along which the trend is removed.
    type: ``'linear'`` (or ``'l'``) subtracts a least-squares line fitted
      independently on every segment delimited by ``bp``; ``'constant'``
      (or ``'c'``) subtracts the mean along ``axis``.
    bp: breakpoint index or sequence of indices along ``axis``; only used for
      linear detrending. Must be static (concrete) values.
    overwrite_data: unsupported; must be ``None``.

  Returns:
    An array with the same shape as ``data`` with the trend removed.

  Raises:
    NotImplementedError: if ``overwrite_data`` is not ``None``.
    ValueError: for an unknown ``type`` or breakpoints outside ``[0, N]``.
  """
  if overwrite_data is not None:
    raise NotImplementedError("overwrite_data argument not implemented.")
  # Accept the same short aliases as scipy.signal.detrend.
  if type not in ('constant', 'c', 'linear', 'l'):
    raise ValueError("Trend type must be 'linear' or 'constant'.")
  data, = _promote_dtypes_inexact(jnp.asarray(data))
  if type in ('constant', 'c'):
    return data - data.mean(axis, keepdims=True)

  N = data.shape[axis]
  # bp is static, so we use np operations to avoid pushing to device.
  bp = np.sort(np.unique(np.r_[0, bp, N]))
  if bp[0] < 0 or bp[-1] > N:
    raise ValueError(
        "Breakpoints must be non-negative and less than length of data along given axis."
    )
  # Work on a 2-D view: time along axis 0, every other axis flattened.
  data = jnp.moveaxis(data, axis, 0)
  shape = data.shape
  data = data.reshape(N, -1)
  for start, stop in zip(bp[:-1], bp[1:]):
    npts = stop - start
    # Design matrix [1, t] with t = 1/npts, ..., 1 for this segment.
    A = jnp.vstack([
        jnp.ones(npts, dtype=data.dtype),
        jnp.arange(1, npts + 1, dtype=data.dtype) / npts,
    ]).T
    sl = slice(start, stop)
    coef, *_ = linalg.lstsq(A, data[sl])
    data = data.at[sl].add(
        -jnp.matmul(A, coef, precision=lax.Precision.HIGHEST))
  return jnp.moveaxis(data.reshape(shape), 0, axis)
def _fft_core(func_name, fft_type, a, s, axes, norm):
  """Shared implementation behind the ``jax.numpy.fft`` transforms.

  Validates ``s``/``axes``, moves the transformed axes to the innermost
  positions (XLA only supports FFTs over the innermost axes), crops or
  zero-pads the input to the requested lengths, calls ``lax.fft`` and scales
  the result via ``_fft_norm``.

  Args:
    func_name: public function name; used in error messages and passed to
      ``_fft_norm``.
    fft_type: an ``xla_client.FftType`` value.
    a: input. NOTE(review): ``a.ndim``/``a.shape`` are read before any
      ``jnp.asarray`` call, so callers presumably pass an array — confirm.
    s: optional sequence of transform lengths, one per transformed axis.
    axes: optional sequence of axes to transform. When ``None``, the last
      ``len(s)`` axes are used, or all axes if ``s`` is also ``None``.
    norm: normalization mode forwarded to ``_fft_norm``.

  Returns:
    The transformed array. When ``axes`` was given, the transformed axes are
    moved back to their original positions.

  Raises:
    ValueError: for negative entries in ``s``, ``s``/``axes`` length
      mismatch, repeated axes, or more than three transformed axes.
  """
  full_name = "jax.numpy.fft." + func_name
  if s is not None:
    s = tuple(map(operator.index, s))
    if np.any(np.less(s, 0)):
      raise ValueError("Shape should be non-negative.")
  if s is not None and axes is not None and len(s) != len(axes):
    # Same error as numpy.
    raise ValueError("Shape and axes have different lengths.")
  orig_axes = axes
  if axes is None:
    if s is None:
      axes = range(a.ndim)
    else:
      axes = range(a.ndim - len(s), a.ndim)
  # NOTE(review): duplicates are detected on the raw values, so e.g. ``0`` and
  # ``-a.ndim`` naming the same axis are not caught by this check.
  if len(axes) != len(set(axes)):
    raise ValueError(
        f"{full_name} does not support repeated axes. Got axes {axes}.")
  if len(axes) > 3:
    # XLA does not support FFTs over more than 3 dimensions
    raise ValueError("%s only supports 1D, 2D, and 3D FFTs. "
                     "Got axes %s with input rank %s." % (full_name, orig_axes, a.ndim))
  # XLA only supports FFTs over the innermost axes, so rearrange if necessary.
  if orig_axes is not None:
    axes = tuple(range(a.ndim - len(axes), a.ndim))
    a = jnp.moveaxis(a, orig_axes, axes)
  if s is not None:
    a = jnp.asarray(a)
    in_s = list(a.shape)
    for axis, x in safe_zip(axes, s):
      in_s[axis] = x
    if fft_type == xla_client.FftType.IRFFT:
      # For IRFFT, ``s`` is the real output length; the complex input's last
      # axis is sized ``n // 2 + 1`` for an output length ``n``.
      in_s[-1] = (in_s[-1] // 2 + 1)
    # Cropping
    a = a[tuple(map(slice, in_s))]
    # Padding
    a = jnp.pad(a, [(0, x - y) for x, y in zip(in_s, a.shape)])
  else:
    if fft_type == xla_client.FftType.IRFFT:
      # Default IRFFT output length along the last axis is 2 * (m - 1) for m
      # complex input entries (clamped at 0); other axes keep their size.
      s = [a.shape[axis] for axis in axes[:-1]]
      if axes:
        s += [max(0, 2 * (a.shape[axes[-1]] - 1))]
    else:
      s = [a.shape[axis] for axis in axes]
  transformed = lax.fft(a, fft_type, tuple(s))
  transformed *= _fft_norm(jnp.array(s, dtype=transformed.dtype), func_name, norm)
  if orig_axes is not None:
    # Undo the earlier move of the transformed axes to the innermost positions.
    transformed = jnp.moveaxis(transformed, axes, orig_axes)
  return transformed
def _unique_sorted_mask(ar, axis):
  """Sort the slices of ``ar`` along ``axis`` and flag the start of each run.

  Helper used by ``_unique``. It moves ``axis`` to the front, sorts the
  leading-axis slices lexicographically, and builds a boolean mask that is
  ``True`` at position 0 and at every slice that differs from its
  predecessor. NaNs compare equal to each other for inexact dtypes.

  Args:
    ar: input array.
    axis: axis along which slices are compared.

  Returns:
    ``(aux, mask, perm)``:
    - ``aux``: the sorted array, with ``axis`` moved to position 0.
    - ``mask``: a boolean vector over ``aux``'s leading axis.
    - ``perm``: the permutation that was applied to the leading axis.

    When ``aux`` has no elements, ``mask`` is all ``False``; callers such as
    ``_unique`` special-case ``aux.size == 0``.
  """
  aux = moveaxis(ar, axis, 0)
  if np.issubdtype(aux.dtype, np.complexfloating):
    # Work around issue in sorting of complex numbers with Nan only in the
    # imaginary component. This can be removed if sorting in this situation
    # is fixed to match numpy.
    aux = where(isnan(aux), _lax_const(aux, np.nan), aux)
  size, *out_shape = aux.shape
  if _prod(out_shape) == 0:
    # Every slice is empty and therefore equal: keep a single representative
    # (index 0) instead of sorting.
    size = 1
    perm = zeros(1, dtype=int)
  else:
    # Flatten each slice into a row; lexsort's last key is primary, so the
    # reversed transpose sorts rows lexicographically by their first column.
    perm = lexsort(aux.reshape(size, _prod(out_shape)).T[::-1])
  aux = aux[perm]
  if aux.size:
    if dtypes.issubdtype(aux.dtype, np.inexact):
      # This is appropriate for both float and complex due to the documented behavior of np.unique:
      # See https://github.com/numpy/numpy/blob/v1.22.0/numpy/lib/arraysetops.py#L212-L220
      neq = lambda x, y: lax.ne(x, y) & ~(isnan(x) & isnan(y))
    else:
      neq = lax.ne
    # A slice starts a new run if any of its elements differs from the
    # previous slice; the first slice always starts a run.
    mask = ones(size, dtype=bool).at[1:].set(
        any(neq(aux[1:], aux[:-1]), tuple(range(1, aux.ndim))))
  else:
    mask = zeros(size, dtype=bool)
  return aux, mask, perm
def norm(x, ord=None, axis: Union[None, Tuple[int, ...], int] = None,
         keepdims=False):
  """Matrix or vector norm, following ``numpy.linalg.norm``.

  Args:
    x: input array; promoted with ``_promote_arg_dtypes``.
    ord: order of the norm.
      - Vector norms (one reduced axis) accept ``None``/``2``, ``inf``,
        ``-inf``, ``0``, ``1``, or any other number ``p``, which computes
        ``sum(|x|**p) ** (1/p)``.
      - Matrix norms (two reduced axes) accept ``None``, ``'fro'``/``'f'``,
        ``'nuc'``, ``1``, ``-1``, ``2``, ``-2``, ``inf`` and ``-inf``.
    axis: ``None``, an int, or a tuple of ints selecting the reduced axes.
      ``None`` means all axes. With ``ord=None`` that is the 2-norm of the
      flattened input, for inputs of any rank.
    keepdims: if true, reduced axes are kept with size one.

  Returns:
    The norm, with the reduced axes removed unless ``keepdims`` is true.

  Raises:
    ValueError: for an ``ord`` that is invalid for the number of reduced
      axes, or when zero or more than two axes are reduced.
  """
  x = _promote_arg_dtypes(jnp.asarray(x))
  x_shape = jnp.shape(x)
  ndim = len(x_shape)

  if axis is None:
    # NumPy has an undocumented behavior that admits arbitrary rank inputs if
    # `ord` is None: https://github.com/numpy/numpy/issues/14215
    if ord is None:
      return jnp.sqrt(
          jnp.sum(jnp.real(x * jnp.conj(x)), keepdims=keepdims))
    axis = tuple(range(ndim))
  elif isinstance(axis, tuple):
    axis = tuple(canonicalize_axis(ax, ndim) for ax in axis)
  else:
    axis = (canonicalize_axis(axis, ndim),)

  num_axes = len(axis)
  if num_axes == 1:
    if ord is None or ord == 2:
      return jnp.sqrt(
          jnp.sum(jnp.real(x * jnp.conj(x)), axis=axis, keepdims=keepdims))
    elif ord == jnp.inf:
      return jnp.amax(jnp.abs(x), axis=axis, keepdims=keepdims)
    elif ord == -jnp.inf:
      return jnp.amin(jnp.abs(x), axis=axis, keepdims=keepdims)
    elif ord == 0:
      return jnp.sum(x != 0, dtype=jnp.finfo(lax.dtype(x)).dtype,
                     axis=axis, keepdims=keepdims)
    elif ord == 1:
      # Numpy has a special case for ord == 1 as an optimization. We don't
      # really need the optimization (XLA could do it for us), but the Numpy
      # code has slightly different type promotion semantics, so we need a
      # special case too.
      return jnp.sum(jnp.abs(x), axis=axis, keepdims=keepdims)
    elif isinstance(ord, str):
      # String orders only make sense for matrix norms. Without this check
      # they would reach the generic p-norm branch below and fail with an
      # unhelpful conversion error.
      msg = f"Invalid order '{ord}' for vector norm."
      if ord == "inf":
        msg += " Use 'jax.numpy.inf' instead."
      elif ord == "-inf":
        msg += " Use '-jax.numpy.inf' instead."
      raise ValueError(msg)
    else:
      abs_x = jnp.abs(x)
      ord = lax._const(abs_x, ord)
      out = jnp.sum(abs_x ** ord, axis=axis, keepdims=keepdims)
      return jnp.power(out, 1. / ord)

  elif num_axes == 2:
    row_axis, col_axis = cast(Tuple[int, ...], axis)
    if ord is None or ord in ('f', 'fro'):
      return jnp.sqrt(
          jnp.sum(jnp.real(x * jnp.conj(x)), axis=axis, keepdims=keepdims))
    elif ord == 1:
      # Reducing row_axis first removes it, so a later col_axis shifts down.
      if not keepdims and col_axis > row_axis:
        col_axis -= 1
      return jnp.amax(jnp.sum(jnp.abs(x), axis=row_axis, keepdims=keepdims),
                      axis=col_axis, keepdims=keepdims)
    elif ord == -1:
      if not keepdims and col_axis > row_axis:
        col_axis -= 1
      return jnp.amin(jnp.sum(jnp.abs(x), axis=row_axis, keepdims=keepdims),
                      axis=col_axis, keepdims=keepdims)
    elif ord == jnp.inf:
      # Reducing col_axis first removes it, so a later row_axis shifts down.
      if not keepdims and row_axis > col_axis:
        row_axis -= 1
      return jnp.amax(jnp.sum(jnp.abs(x), axis=col_axis, keepdims=keepdims),
                      axis=row_axis, keepdims=keepdims)
    elif ord == -jnp.inf:
      if not keepdims and row_axis > col_axis:
        row_axis -= 1
      return jnp.amin(jnp.sum(jnp.abs(x), axis=col_axis, keepdims=keepdims),
                      axis=row_axis, keepdims=keepdims)
    elif ord in ('nuc', 2, -2):
      # Singular-value based norms: move the matrix axes last for svd.
      x = jnp.moveaxis(x, axis, (-2, -1))
      if ord == 2:
        reducer = jnp.amax
      elif ord == -2:
        reducer = jnp.amin
      else:
        reducer = jnp.sum
      y = reducer(svd(x, compute_uv=False), axis=-1)
      if keepdims:
        result_shape = list(x_shape)
        result_shape[axis[0]] = 1
        result_shape[axis[1]] = 1
        y = jnp.reshape(y, result_shape)
      return y
    else:
      raise ValueError("Invalid order '{}' for matrix norm.".format(ord))
  else:
    raise ValueError(
        "Invalid axis values ({}) for jnp.linalg.norm.".format(axis))
def conv_general_dilated_patches(
    lhs: lax.Array,
    filter_shape: Sequence[int],
    window_strides: Sequence[int],
    padding: Union[str, Sequence[Tuple[int, int]]],
    lhs_dilation: Optional[Sequence[int]] = None,
    rhs_dilation: Optional[Sequence[int]] = None,
    dimension_numbers: Optional[lax.ConvGeneralDilatedDimensionNumbers] = None,
    precision: Optional[lax.PrecisionType] = None,
    preferred_element_type: Optional[DType] = None,
) -> lax.Array:
  """Extract patches subject to the receptive field of `conv_general_dilated`.

  Runs the input through a convolution with given parameters. The kernel of
  the convolution is constructed so that the output channel dimension `"C"`
  contains flattened image patches. A single `"C"` dimension therefore
  represents, for example, the three dimensions `"chw"` collapsed together.
  The order of these dimensions is
  `"c" + ''.join(c for c in rhs_spec if c not in 'OI')`, where
  `rhs_spec == dimension_numbers[1]`. The size of this `"C"` dimension is
  the size of each patch, i.e.
  `np.prod(filter_shape) * lhs.shape[lhs_spec.index('C')]`, where
  `lhs_spec == dimension_numbers[0]`.

  Docstring below adapted from `jax.lax.conv_general_dilated`.

  See Also:
    https://www.tensorflow.org/xla/operation_semantics#conv_convolution

  Args:
    lhs: a rank `n+2` dimensional input array.
    filter_shape: a sequence of `n` integers, representing the receptive window
      spatial shape in the order as specified in
      `rhs_spec = dimension_numbers[1]`.
    window_strides: a sequence of `n` integers, representing the inter-window
      strides.
    padding: either the string `'SAME'`, the string `'VALID'`, or a sequence of
      `n` `(low, high)` integer pairs that give the padding to apply before and
      after each spatial dimension.
    lhs_dilation: `None`, or a sequence of `n` integers, giving the
      dilation factor to apply in each spatial dimension of `lhs`. LHS dilation
      is also known as transposed convolution.
    rhs_dilation: `None`, or a sequence of `n` integers, giving the
      dilation factor to apply in each spatial dimension of `rhs`. RHS dilation
      is also known as atrous convolution.
    dimension_numbers: either `None`, or a 3-tuple
      `(lhs_spec, rhs_spec, out_spec)`, where each element is a string
      of length `n+2`. `None` defaults to
      `("NCHWD...", "OIHWD...", "NCHWD...")`.
    precision: Optional. Either ``None``, which means the default precision for
      the backend, or a ``lax.Precision`` enum value (``Precision.DEFAULT``,
      ``Precision.HIGH`` or ``Precision.HIGHEST``). It is applied to `lhs`;
      the internally built identity kernel always uses ``Precision.DEFAULT``.
    preferred_element_type: Optional. Either ``None``, which means the default
      accumulation type for the input types, or a datatype, indicating to
      accumulate results to and return a result with that datatype.

  Returns:
    A rank `n+2` array containing the flattened image patches in the output
    channel (`"C"`) dimension. For example, if
    `dimension_numbers = ("NcHW", "OIwh", "CNHW")`, the output has dimension
    numbers `"CNHW" = "{cwh}NHW"`, with the size of dimension `"C"` equal to
    the size of each patch
    (`np.prod(filter_shape) * lhs.shape[lhs_spec.index('C')]`).
  """
  filter_shape = tuple(filter_shape)
  dimension_numbers = lax.conv_dimension_numbers(
      lhs.shape, (1, 1) + filter_shape, dimension_numbers)

  lhs_spec, rhs_spec, out_spec = dimension_numbers

  spatial_size = prod(filter_shape)
  n_channels = lhs.shape[lhs_spec[1]]

  # Move separate `lhs` spatial locations into separate `rhs` channels: kernel
  # output channel `s` is a one-hot filter selecting flattened window
  # position `s`.
  rhs = jnp.eye(spatial_size, dtype=lhs.dtype).reshape(
      (spatial_size, 1) + filter_shape)
  # One copy of the one-hot bank per input channel. With
  # `feature_group_count=n_channels`, each copy only sees its own channel.
  rhs = jnp.tile(rhs, (n_channels,) + (1,) * (rhs.ndim - 1))
  rhs = jnp.moveaxis(rhs, (0, 1), (rhs_spec[0], rhs_spec[1]))

  out = lax.conv_general_dilated(
      lhs=lhs,
      rhs=rhs,
      window_strides=window_strides,
      padding=padding,
      lhs_dilation=lhs_dilation,
      rhs_dilation=rhs_dilation,
      dimension_numbers=dimension_numbers,
      precision=None if precision is None else (precision,
                                                lax.Precision.DEFAULT),
      feature_group_count=n_channels,
      preferred_element_type=preferred_element_type)
  return out
def conv_general_dilated_local(
    lhs: jnp.ndarray,
    rhs: jnp.ndarray,
    window_strides: Sequence[int],
    padding: Union[str, Sequence[Tuple[int, int]]],
    filter_shape: Sequence[int],
    lhs_dilation: Optional[Sequence[int]] = None,
    rhs_dilation: Optional[Sequence[int]] = None,
    dimension_numbers: Optional[
        convolution.ConvGeneralDilatedDimensionNumbers] = None,
    precision: lax.PrecisionLike = None) -> jnp.ndarray:
  """General n-dimensional unshared convolution operator with optional dilation.

  Also known as a locally connected layer. The operation is equivalent to a
  convolution with a separate (unshared) `rhs` kernel used at each output
  spatial location.

  Docstring below adapted from `jax.lax.conv_general_dilated`.

  See Also:
    https://www.tensorflow.org/xla/operation_semantics#conv_convolution

  Args:
    lhs: a rank `n+2` dimensional input array.
    rhs: a rank `n+2` dimensional array of kernel weights. Unlike in regular
      CNNs, its spatial coordinates (`H`, `W`, ...) correspond to output spatial
      locations. Input spatial locations are fused with the input channel
      locations in the single `I` dimension, in the order of
      `"C" + ''.join(c for c in rhs_spec if c not in 'OI')`, where
      `rhs_spec = dimension_numbers[1]`. For example, if `rhs_spec == "WHIO"`,
      the unfolded kernel shape is
      `"[output W][output H]{I[receptive window W][receptive window H]}O"`.
    window_strides: a sequence of `n` integers, representing the inter-window
      strides.
    padding: either the string `'SAME'`, the string `'VALID'`, or a sequence of
      `n` `(low, high)` integer pairs that give the padding to apply before and
      after each spatial dimension.
    filter_shape: a sequence of `n` integers, representing the receptive window
      spatial shape in the order as specified in
      `rhs_spec = dimension_numbers[1]`.
    lhs_dilation: `None`, or a sequence of `n` integers, giving the
      dilation factor to apply in each spatial dimension of `lhs`. LHS dilation
      is also known as transposed convolution.
    rhs_dilation: `None`, or a sequence of `n` integers, giving the
      dilation factor to apply in each input spatial dimension of `rhs`.
      RHS dilation is also known as atrous convolution.
    dimension_numbers: either `None`, a `ConvDimensionNumbers` object, or
      a 3-tuple `(lhs_spec, rhs_spec, out_spec)`, where each element is a
      string of length `n+2`.
    precision: Optional. Either ``None``, which means the default precision for
      the backend, a ``lax.Precision`` enum value (``Precision.DEFAULT``,
      ``Precision.HIGH`` or ``Precision.HIGHEST``) or a tuple of two
      ``lax.Precision`` enums indicating precision of ``lhs`` and ``rhs``.

  Returns:
    An array containing the unshared convolution result.

  In the string case of `dimension_numbers`, each character identifies by
  position:

  - the batch dimensions in `lhs`, `rhs`, and the output with the character
    'N',
  - the feature dimensions in `lhs` and the output with the character 'C',
  - the input and output feature dimensions in rhs with the characters 'I'
    and 'O' respectively, and
  - spatial dimension correspondences between `lhs`, `rhs`, and the output using
    any distinct characters.

  For example, to indicate dimension numbers consistent with the `conv` function
  with two spatial dimensions, one could use `('NCHW', 'OIHW', 'NCHW')`. As
  another example, to indicate dimension numbers consistent with the TensorFlow
  Conv2D operation, one could use `('NHWC', 'HWIO', 'NHWC')`. When using the
  latter form of convolution dimension specification, window strides are
  associated with spatial dimension character labels according to the order in
  which the labels appear in the `rhs_spec` string, so that `window_strides[0]`
  is matched with the dimension corresponding to the first character
  appearing in rhs_spec that is not `'I'` or `'O'`.

  If `dimension_numbers` is `None`, the default is `('NCHW', 'OIHW', 'NCHW')`
  (for a 2D convolution).
  """
  # Patch extraction only involves `lhs`, so it receives the lhs component of
  # a two-element precision tuple (or the single precision value as-is).
  c_precision = lax.canonicalize_precision(precision)
  lhs_precision = type_cast(
      Optional[lax.PrecisionType],
      (c_precision[0]
       if (isinstance(c_precision, tuple) and len(c_precision) == 2)
       else c_precision))

  patches = conv_general_dilated_patches(
      lhs=lhs,
      filter_shape=filter_shape,
      window_strides=window_strides,
      padding=padding,
      lhs_dilation=lhs_dilation,
      rhs_dilation=rhs_dilation,
      dimension_numbers=dimension_numbers,
      precision=lhs_precision
  )

  lhs_spec, rhs_spec, out_spec = convolution.conv_dimension_numbers(
      lhs.shape, (1, 1) + tuple(filter_shape), dimension_numbers)

  # Contract the flattened-patch channel of `patches` with the `I` dim of `rhs`.
  lhs_c_dims, rhs_c_dims = [out_spec[1]], [rhs_spec[1]]

  # Output spatial dims of `patches` pair up, in spec order, with the spatial
  # dims of `rhs`; they are batch dims so that each location uses its own kernel.
  lhs_b_dims = out_spec[2:]
  rhs_b_dims = rhs_spec[2:]

  # List the lhs batch dims in increasing order and permute the rhs batch dims
  # identically so the pairing is preserved.
  rhs_b_dims = [
      rhs_b_dims[i]
      for i in sorted(range(len(rhs_b_dims)), key=lambda k: lhs_b_dims[k])
  ]
  lhs_b_dims = sorted(lhs_b_dims)

  dn = ((lhs_c_dims, rhs_c_dims), (lhs_b_dims, rhs_b_dims))
  out = lax.dot_general(patches, rhs, dimension_numbers=dn, precision=precision)
  # dot_general yields (batch dims..., lhs free dim N, rhs free dim O); move N
  # and O to where `out_spec` expects them.
  out = jnp.moveaxis(out, (-2, -1), (out_spec[0], out_spec[1]))
  return out
def detrend_func(d):
  """Apply ``detrend_type`` to ``d`` with axis ``axis`` temporarily moved last.

  ``axis`` and ``detrend_type`` are free names resolved from the enclosing
  scope. The detrended result has its last axis moved back to position
  ``axis``.

  NOTE(review): scipy's equivalent wrapper in ``_spectral_helper`` moves the
  *last* axis to ``axis`` before detrending (the reverse of this). Confirm
  which direction is intended here.
  """
  moved = jnp.moveaxis(d, axis, -1)
  return jnp.moveaxis(detrend_type(moved), -1, axis)
def _spectral_helper(x, y, fs=1.0, window='hann', nperseg=None, noverlap=None,
                     nfft=None, detrend_type='constant', return_onesided=True,
                     scaling='density', axis=-1, mode='psd', boundary=None,
                     padded=False):
  """LAX-backend implementation of `scipy.signal._spectral_helper`.

  Unlike the original helper function, `y` can be None for explicitly
  indicating auto-spectral (non cross-spectral) computation. In addition,
  the `detrend` argument is renamed to `detrend_type` to avoid an internal
  name overlap.

  Returns:
    ``(freqs, time, result)``: the sample frequencies, the segment times, and
    the (cross-)spectral result with its frequency axis placed at ``axis``.
    For empty inputs, three zero arrays are returned instead.

  Raises:
    ValueError: for an unknown ``mode``, ``boundary`` or ``scaling``, invalid
      segment parameters, or ``x``/``y`` that are incompatible.
  """
  if mode not in ('psd', 'stft'):
    raise ValueError(f"Unknown value for mode {mode}, "
                     "must be one of: ('psd', 'stft')")

  def make_pad(pad_mode, **kwargs):
    def pad(x, n, axis=-1):
      pad_width = [(0, 0) for unused_n in range(x.ndim)]
      pad_width[axis] = (n, n)
      return jnp.pad(x, pad_width, pad_mode, **kwargs)
    return pad

  boundary_funcs = {
      'even': make_pad('reflect'),
      'odd': odd_ext,
      'constant': make_pad('edge'),
      'zeros': make_pad('constant', constant_values=0.0),
      None: lambda x, *args, **kwargs: x
  }

  # Check/ normalize inputs
  if boundary not in boundary_funcs:
    raise ValueError(f"Unknown boundary option '{boundary}', "
                     f"must be one of: {list(boundary_funcs.keys())}")

  # Convert `x` before `x.ndim` / `x.shape` are read below, so that
  # non-array inputs hit the intended argument check.
  _check_arraylike("_spectral_helper", x)
  x = jnp.asarray(x)

  # Keep the caller's (possibly negative) axis: the detrend wrapper below
  # mirrors scipy, which decides on and uses the un-normalized value.
  user_axis = jax.core.concrete_or_error(operator.index, axis,
                                         "axis of windowed-FFT")
  axis = canonicalize_axis(user_axis, x.ndim)

  if nperseg is not None:  # if specified by user
    nperseg = jax.core.concrete_or_error(int, nperseg,
                                         "nperseg of windowed-FFT")
    if nperseg < 1:
      raise ValueError('nperseg must be a positive integer')
  # parse window; if array like, then set nperseg = win.shape
  win, nperseg = signal_helper._triage_segments(
      window, nperseg, input_length=x.shape[axis])

  if noverlap is None:
    noverlap = nperseg // 2
  else:
    noverlap = jax.core.concrete_or_error(int, noverlap,
                                          "noverlap of windowed-FFT")
  if nfft is None:
    nfft = nperseg
  else:
    nfft = jax.core.concrete_or_error(int, nfft, "nfft of windowed-FFT")

  if y is None:
    outdtype = jax.dtypes.canonicalize_dtype(
        np.result_type(x, np.complex64))
  else:
    _check_arraylike("_spectral_helper", y)
    y = jnp.asarray(y)
    outdtype = jax.dtypes.canonicalize_dtype(
        np.result_type(x, y, np.complex64))
    if mode != 'psd':
      raise ValueError(
          "two-argument mode is available only when mode=='psd'")
    if x.ndim != y.ndim:
      raise ValueError(
          f"two-arguments must have the same rank ({x.ndim} vs {y.ndim}).")
    # Check if we can broadcast the outer axes together
    try:
      outershape = jnp.broadcast_shapes(tuple_delete(x.shape, axis),
                                        tuple_delete(y.shape, axis))
    except ValueError as e:
      raise ValueError('x and y cannot be broadcast together.') from e

  # Special cases for size == 0
  if y is None:
    if x.size == 0:
      return jnp.zeros(x.shape), jnp.zeros(x.shape), jnp.zeros(x.shape)
  else:
    if x.size == 0 or y.size == 0:
      outshape = tuple_insert(
          outershape, min([x.shape[axis], y.shape[axis]]), axis)
      emptyout = jnp.zeros(outshape)
      return emptyout, emptyout, emptyout

  # Move time-axis to the end
  if x.ndim > 1:
    if axis != -1:
      x = jnp.moveaxis(x, axis, -1)
      if y is not None and y.ndim > 1:
        y = jnp.moveaxis(y, axis, -1)

  # Check if x and y are the same length, zero-pad if necessary
  if y is not None:
    if x.shape[-1] != y.shape[-1]:
      if x.shape[-1] < y.shape[-1]:
        pad_shape = list(x.shape)
        pad_shape[-1] = y.shape[-1] - x.shape[-1]
        x = jnp.concatenate((x, jnp.zeros(pad_shape)), -1)
      else:
        pad_shape = list(y.shape)
        pad_shape[-1] = x.shape[-1] - y.shape[-1]
        y = jnp.concatenate((y, jnp.zeros(pad_shape)), -1)

  if nfft < nperseg:
    raise ValueError('nfft must be greater than or equal to nperseg.')
  if noverlap >= nperseg:
    raise ValueError('noverlap must be less than nperseg.')
  nstep = nperseg - noverlap

  # Apply paddings
  if boundary is not None:
    ext_func = boundary_funcs[boundary]
    x = ext_func(x, nperseg // 2, axis=-1)
    if y is not None:
      y = ext_func(y, nperseg // 2, axis=-1)

  if padded:
    # Pad to integer number of windowed segments
    # I.e make x.shape[-1] = nperseg + (nseg-1)*nstep, with integer nseg
    nadd = (-(x.shape[-1] - nperseg) % nstep) % nperseg
    zeros_shape = list(x.shape[:-1]) + [nadd]
    x = jnp.concatenate((x, jnp.zeros(zeros_shape)), axis=-1)
    if y is not None:
      zeros_shape = list(y.shape[:-1]) + [nadd]
      y = jnp.concatenate((y, jnp.zeros(zeros_shape)), axis=-1)

  # Handle detrending and window functions
  if not detrend_type:
    def detrend_func(d):
      return d
  elif not callable(detrend_type):
    def detrend_func(d):
      return detrend(d, type=detrend_type, axis=-1)
  elif user_axis != -1:
    # Wrap this function so that it receives a shape that it could
    # reasonably expect to receive. The data was moved above so that time is
    # last (the segments passed in presumably keep time samples last). Move
    # that axis to the caller's `axis`, as scipy does, then move it back.
    def detrend_func(d):
      d = jnp.moveaxis(d, -1, user_axis)
      d = detrend_type(d)
      return jnp.moveaxis(d, user_axis, -1)
  else:
    detrend_func = detrend_type

  if np.result_type(win, np.complex64) != outdtype:
    win = win.astype(outdtype)

  # Determine scale
  if scaling == 'density':
    scale = 1.0 / (fs * (win * win).sum())
  elif scaling == 'spectrum':
    scale = 1.0 / win.sum()**2
  else:
    raise ValueError(f'Unknown scaling: {scaling}')
  if mode == 'stft':
    scale = jnp.sqrt(scale)

  # Determine onesided/ two-sided
  if return_onesided:
    sides = 'onesided'
    if jnp.iscomplexobj(x) or jnp.iscomplexobj(y):
      sides = 'twosided'
      warnings.warn('Input data is complex, switching to '
                    'return_onesided=False')
  else:
    sides = 'twosided'

  if sides == 'twosided':
    freqs = jax.numpy.fft.fftfreq(nfft, 1 / fs)
  elif sides == 'onesided':
    freqs = jax.numpy.fft.rfftfreq(nfft, 1 / fs)

  # Perform the windowed FFTs
  result = _fft_helper(x, win, detrend_func, nperseg, noverlap, nfft, sides)

  if y is not None:
    # All the same operations on the y data
    result_y = _fft_helper(y, win, detrend_func, nperseg, noverlap, nfft,
                           sides)
    result = jnp.conjugate(result) * result_y
  elif mode == 'psd':
    result = jnp.conjugate(result) * result

  result *= scale

  if sides == 'onesided' and mode == 'psd':
    # Double every bin except DC (and Nyquist when nfft is even) to account
    # for the discarded negative frequencies.
    end = None if nfft % 2 else -1
    result = result.at[..., 1:end].mul(2)

  time = jnp.arange(nperseg / 2, x.shape[-1] - nperseg / 2 + 1,
                    nperseg - noverlap) / fs
  if boundary is not None:
    time -= (nperseg / 2) / fs

  result = result.astype(outdtype)

  # All imaginary parts are zero anyways
  if y is None and mode != 'stft':
    result = result.real

  # Move frequency axis back to axis where the data came from
  result = jnp.moveaxis(result, -1, axis)

  return freqs, time, result
def istft(Zxx, fs=1.0, window='hann', nperseg=None, noverlap=None, nfft=None,
          input_onesided=True, boundary=True, time_axis=-1, freq_axis=-2):
  """Inverse short-time Fourier transform (LAX-backend ``scipy.signal.istft``).

  Args:
    Zxx: STFT array, at least 2-D, with frequencies along ``freq_axis`` and
      segments along ``time_axis``.
    fs: sampling frequency.
    window: a window name/tuple passed to ``osp_signal.get_window``, or a 1-D
      array of length ``nperseg``.
    nperseg: segment length. Defaults to ``n_default``, which is
      ``2 * (n_freq - 1)`` for one-sided input and ``n_freq`` otherwise.
      Must be positive.
    noverlap: number of samples shared by consecutive segments. Defaults to
      ``nperseg // 2``. ``0`` (no overlap) is honoured.
    nfft: FFT length. Defaults to ``n_default``, plus one when the input is
      one-sided and ``nperseg == n_default + 1``.
    input_onesided: whether ``Zxx`` holds a one-sided spectrum (inverted with
      ``irfft``) rather than a two-sided one (``ifft``).
    boundary: if true, trim ``nperseg // 2`` samples ("extension points")
      from both ends of the reconstruction.
    time_axis: axis of ``Zxx`` holding the segments.
    freq_axis: axis of ``Zxx`` holding the frequencies.

  Returns:
    ``(time, x)``: sample times and the reconstructed signal.

  Raises:
    ValueError: for inputs with fewer than two dimensions, identical
      time/frequency axes, invalid ``nperseg``/``nfft``/``noverlap``, or a
      malformed ``window`` array.
  """
  # Input validation
  _check_arraylike("istft", Zxx)
  if Zxx.ndim < 2:
    raise ValueError('Input stft must be at least 2d!')
  freq_axis = canonicalize_axis(freq_axis, Zxx.ndim)
  time_axis = canonicalize_axis(time_axis, Zxx.ndim)
  if freq_axis == time_axis:
    raise ValueError('Must specify differing time and frequency axes!')

  Zxx = jnp.asarray(Zxx, dtype=jax.dtypes.canonicalize_dtype(
      np.result_type(Zxx, np.complex64)))

  n_default = (2 * (Zxx.shape[freq_axis] - 1) if input_onesided
               else Zxx.shape[freq_axis])

  # Compare against None rather than relying on truthiness. A truthiness
  # check would silently replace an explicit nperseg=0 with the default
  # instead of rejecting it, and would turn a valid noverlap=0 into
  # nperseg // 2.
  if nperseg is None:
    nperseg = n_default
  nperseg = jax.core.concrete_or_error(int, nperseg,
                                       "nperseg: segment length of STFT")
  if nperseg < 1:
    raise ValueError('nperseg must be a positive integer')

  if nfft is None:
    nfft = n_default
    if input_onesided and nperseg == n_default + 1:
      nfft += 1  # Odd nperseg, no FFT padding
  else:
    nfft = jax.core.concrete_or_error(int, nfft, "nfft of STFT")
  if nfft < nperseg:
    raise ValueError(
        f'FFT length ({nfft}) must be longer than nperseg ({nperseg}).')

  if noverlap is None:
    noverlap = nperseg // 2
  noverlap = jax.core.concrete_or_error(int, noverlap, "noverlap of STFT")
  if noverlap >= nperseg:
    raise ValueError('noverlap must be less than nperseg.')
  nstep = nperseg - noverlap

  # Rearrange axes if necessary
  if time_axis != Zxx.ndim - 1 or freq_axis != Zxx.ndim - 2:
    outer_idxs = tuple(
        idx for idx in range(Zxx.ndim) if idx not in {time_axis, freq_axis})
    Zxx = jnp.transpose(Zxx, outer_idxs + (freq_axis, time_axis))

  # Perform IFFT
  ifunc = jax.numpy.fft.irfft if input_onesided else jax.numpy.fft.ifft
  # xsubs: [..., T, N], N is the number of frames, T is the frame length.
  xsubs = ifunc(Zxx, axis=-2, n=nfft)[..., :nperseg, :]

  # Get window as array
  if isinstance(window, (str, tuple)):
    win = osp_signal.get_window(window, nperseg)
    win = jnp.asarray(win)
  else:
    win = jnp.asarray(window)
    if len(win.shape) != 1:
      raise ValueError('window must be 1-D')
    if win.shape[0] != nperseg:
      raise ValueError('window must have length of {0}'.format(nperseg))
  win = win.astype(xsubs.dtype)

  xsubs *= win.sum()  # This takes care of the 'spectrum' scaling

  # make win broadcastable over xsubs
  win = win.reshape((1,) * (xsubs.ndim - 2) + win.shape + (1,))
  x = _overlap_and_add((xsubs * win).swapaxes(-2, -1), nstep)
  win_squared = jnp.repeat((win * win), xsubs.shape[-1], axis=-1)
  # Named win_norm to avoid shadowing the module-level `norm`.
  win_norm = _overlap_and_add(win_squared.swapaxes(-2, -1), nstep)

  # Remove extension points. Use explicit stop indices: `k:-k` with k == 0
  # (nperseg == 1) would produce an empty slice.
  if boundary:
    trim = nperseg // 2
    x = x[..., trim:x.shape[-1] - trim]
    win_norm = win_norm[..., trim:win_norm.shape[-1] - trim]
  x /= jnp.where(win_norm > 1e-10, win_norm, 1.0)

  # Put axes back
  if x.ndim > 1:
    if time_axis != Zxx.ndim - 1:
      if freq_axis < time_axis:
        time_axis -= 1
      x = jnp.moveaxis(x, -1, time_axis)

  # NOTE(review): mirrors scipy.signal.istft, which also uses x.shape[0]; for
  # inputs with more than two dimensions this is not the time-axis length.
  time = jnp.arange(x.shape[0]) / fs
  return time, x
def _unique(ar, axis, return_index=False, return_inverse=False,
            return_counts=False, size=None, fill_value=None,
            return_true_size=False):
  """Find the unique elements of an array along a particular axis.

  Args:
    ar: input array.
    axis: axis along which slices are compared.
    return_index: also return, for each unique slice, its position along
      ``axis`` in ``ar`` (taken from the sorting permutation).
    return_inverse: also return, for each position of ``ar`` along ``axis``,
      the index of its unique slice.
    return_counts: also return the number of occurrences of each unique slice.
    size: if given, the number of returned unique slices is fixed to ``size``.
      If not given, the uniqueness mask must be concrete (see the
      ``concrete_or_error`` call).
    fill_value: value used for positions beyond the true number of unique
      slices when ``size`` is given. Required when ``ar`` has zero length
      along ``axis`` and ``size`` is nonzero.
    return_true_size: also return the number of unique slices found
      (``mask.sum()``); intended for internal callers.

  Returns:
    The unique slices (with ``axis`` restored) alone, or a tuple with the
    unique slices first, followed by whichever of index, inverse, counts and
    true size were requested, in that order.

  Raises:
    ValueError: for zero-length input along ``axis`` with a nonzero ``size``
      and no ``fill_value``.
  """
  if ar.shape[axis] == 0 and size and fill_value is None:
    raise ValueError(
        "jnp.unique: for zero-sized input with nonzero size argument, fill_value must be specified"
    )

  aux, mask, perm = _unique_sorted_mask(ar, axis)
  if size is None:
    # Without a static size the boolean mask itself is used as the index, so
    # it must be concrete.
    ind = core.concrete_or_error(
        None, mask, "The error arose in jnp.unique(). " + UNIQUE_SIZE_HINT)
  else:
    # Fixed-size index vector. NOTE(review): padded entries presumably use
    # nonzero's default fill index 0, i.e. repeat the first unique slice
    # unless fill_value is given below — confirm.
    ind = nonzero(mask, size=size)[0]
  result = aux[ind] if aux.size else aux
  if fill_value is not None:
    fill_value = asarray(fill_value, dtype=result.dtype)
  if size is not None and fill_value is not None:
    if result.shape[0]:
      # Positions at or beyond the true unique count get fill_value.
      valid = lax.expand_dims(
          arange(size) < mask.sum(), tuple(range(1, result.ndim)))
      result = where(valid, result, fill_value)
    else:
      result = full_like(result, fill_value, shape=(size, *result.shape[1:]))
  result = moveaxis(result, 0, axis)

  ret = (result, )
  if return_index:
    if aux.size:
      ret += (perm[ind], )
    else:
      ret += (perm, )
  if return_inverse:
    if aux.size:
      # Running count of run starts gives each sorted position its unique
      # index; scatter through perm to return to the original order.
      imask = cumsum(mask) - 1
      inv_idx = zeros(mask.shape, dtype=dtypes.canonicalize_dtype(dtypes.int_))
      inv_idx = inv_idx.at[perm].set(imask)
    else:
      inv_idx = zeros(ar.shape[axis], dtype=int)
    ret += (inv_idx, )
  if return_counts:
    if aux.size:
      if size is None:
        idx = append(nonzero(mask)[0], mask.size)
      else:
        # idx[0] is always 0 (mask[0] is True), so zeros in idx[1:] are
        # padding; map them to mask.size so their differences, and therefore
        # the padded counts, are 0.
        idx = nonzero(mask, size=size + 1)[0]
        idx = idx.at[1:].set(where(idx[1:], idx[1:], mask.size))
      ret += (diff(idx), )
    elif ar.shape[axis]:
      # Non-empty axis but empty slices: everything is one unique slice.
      ret += (array([ar.shape[axis]], dtype=dtypes.canonicalize_dtype(dtypes.int_)), )
    else:
      ret += (empty(0, dtype=int), )
  if return_true_size:
    # Useful for internal uses of unique().
    ret += (mask.sum(), )
  return ret[0] if len(ret) == 1 else ret