def __getitem__(self, key):
  """Build an array by concatenating the entries of ``key``.

  ``key`` may be a single entry or a tuple of entries. If the first entry is
  a string it is treated as a directive and is not itself concatenated:

  * ``"r"`` / ``"c"``: if the concatenated result is 1-D, a new axis is
    inserted at position 0 / 1 respectively (via ``expand_dims``).
  * anything else: comma-separated integers overriding ``self.axis``,
    ``self.ndmin`` and ``self.trans1d`` in that order. Fields that are not
    given keep their defaults; fields after the third are ignored.

  Slice entries are expanded with ``_make_1d_grid_from_slice``. Every entry is
  then converted with ``array(..., ndmin=ndmin)``. When ``trans1d`` is not -1,
  entries whose original rank is below ``ndmin`` have their axes rotated
  according to ``trans1d`` before concatenation along ``axis``.

  Raises:
    ValueError: if the directive cannot be parsed as integers, or if a
      string entry appears anywhere other than the first position.
  """
  if not isinstance(key, tuple):
    key = (key,)

  # [axis, ndmin, trans1d, matrix]; matrix == -1 means "no matrix directive".
  params = [self.axis, self.ndmin, self.trans1d, -1]

  if isinstance(key[0], str):
    # split off the directive
    directive, *key = key  # pytype: disable=bad-unpacking
    # check two special cases: matrix directives
    if directive == "r":
      params[-1] = 0
    elif directive == "c":
      params[-1] = 1
    else:
      vec = directive.split(",")
      k = len(vec)
      if k < 4:
        # Fill the fields the directive did not specify from the defaults.
        vec += params[k:]
      else:
        # Ignore everything after the first three comma-separated ints.
        # The matrix default is appended as a one-element list; adding the
        # bare int to a list would raise TypeError.
        vec = vec[:3] + [params[-1]]
      try:
        params = list(map(int, vec))
      except ValueError as err:
        raise ValueError(
            f"could not understand directive {directive!r}"
        ) from err

  axis, ndmin, trans1d, matrix = params

  output = []
  for item in key:
    if isinstance(item, slice):
      newobj = _make_1d_grid_from_slice(item, op_name=self.op_name)
    elif isinstance(item, str):
      raise ValueError("string directive must be placed at the beginning")
    else:
      newobj = item

    newobj = array(newobj, copy=False, ndmin=ndmin)

    # Only entries that were upgraded in rank by `ndmin` are rotated.
    if trans1d != -1 and ndmin - np.ndim(item) > 0:
      shape_obj = list(range(ndmin))
      # Calculate number of left shifts, with overflow protection by mod
      num_lshifts = ndmin - abs(ndmin + trans1d + 1) % ndmin
      shape_obj = tuple(shape_obj[num_lshifts:] + shape_obj[:num_lshifts])

      newobj = transpose(newobj, shape_obj)

    output.append(newobj)

  res = concatenate(tuple(output), axis=axis)

  if matrix != -1 and res.ndim == 1:
    # insert 2nd dim at axis 0 or 1
    res = expand_dims(res, matrix)

  return res
def istft(Zxx, fs=1.0, window='hann', nperseg=None, noverlap=None, nfft=None,
          input_onesided=True, boundary=True, time_axis=-1, freq_axis=-2):
  """Perform the inverse Short-Time Fourier Transform.

  Each frame of ``Zxx`` is inverse-FFT'd, truncated to ``nperseg`` samples,
  scaled by the window sum and multiplied by the window. The frames are then
  overlap-added with hop ``nperseg - noverlap``. The result is divided by the
  overlap-added squared window wherever that normalization exceeds 1e-10.

  Args:
    Zxx: STFT array; must have at least 2 dimensions.
    fs: sampling frequency; the returned time vector is divided by it.
    window: either a ``str``/``tuple`` window spec passed to
      ``osp_signal.get_window``, or an explicit 1-D array of length
      ``nperseg``.
    nperseg: samples per segment. ``None`` selects ``2 * (F - 1)`` when
      ``input_onesided`` is true, else ``F``, where
      ``F = Zxx.shape[freq_axis]``. Must be >= 1.
    noverlap: samples of overlap between segments. ``None`` selects
      ``nperseg // 2``; an explicit ``0`` means no overlap. Must be less
      than ``nperseg``.
    nfft: FFT length. ``None`` selects the same default as ``nperseg``, plus
      one when ``input_onesided`` and ``nperseg`` equals that default + 1.
      Must be >= ``nperseg``.
    input_onesided: if true, ``irfft`` is used; otherwise ``ifft``.
    boundary: if true, ``nperseg // 2`` samples are removed from each end of
      the reconstructed signal.
    time_axis: axis of ``Zxx`` that indexes segments.
    freq_axis: axis of ``Zxx`` that indexes frequencies.

  ``nperseg``, ``noverlap`` and ``nfft`` are passed through
  ``jax.core.concrete_or_error``, so they must be concrete (non-traced) values.

  Returns:
    ``(time, x)``: a time vector and the reconstructed signal.

  Raises:
    ValueError: if ``Zxx`` has fewer than 2 dims, the two axes coincide,
      ``nperseg < 1``, ``nfft < nperseg``, ``noverlap >= nperseg``, or an
      explicit ``window`` is not 1-D of length ``nperseg``.
  """
  # Input validation
  _check_arraylike("istft", Zxx)
  if Zxx.ndim < 2:
    raise ValueError('Input stft must be at least 2d!')
  freq_axis = canonicalize_axis(freq_axis, Zxx.ndim)
  time_axis = canonicalize_axis(time_axis, Zxx.ndim)
  if freq_axis == time_axis:
    raise ValueError('Must specify differing time and frequency axes!')

  # Promote to at least complex64, then canonicalize for JAX's dtype config.
  Zxx = jnp.asarray(Zxx, dtype=jax.dtypes.canonicalize_dtype(
      np.result_type(Zxx, np.complex64)))

  n_default = (2 * (Zxx.shape[freq_axis] - 1) if input_onesided
               else Zxx.shape[freq_axis])

  # Test for None rather than truthiness so an explicit 0 is validated
  # instead of being silently replaced by the default.
  nperseg = jax.core.concrete_or_error(
      int, n_default if nperseg is None else nperseg,
      "nperseg: segment length of STFT")
  if nperseg < 1:
    raise ValueError('nperseg must be a positive integer')

  if nfft is None:
    nfft = n_default
    if input_onesided and nperseg == n_default + 1:
      nfft += 1  # Odd nperseg, no FFT padding
  else:
    nfft = jax.core.concrete_or_error(int, nfft, "nfft of STFT")
  if nfft < nperseg:
    raise ValueError(
        f'FFT length ({nfft}) must be longer than nperseg ({nperseg}).')

  # An explicit noverlap=0 (no overlap) must be honored, not replaced by the
  # nperseg // 2 default.
  noverlap = jax.core.concrete_or_error(
      int, nperseg // 2 if noverlap is None else noverlap,
      "noverlap of STFT")
  if noverlap >= nperseg:
    raise ValueError('noverlap must be less than nperseg.')
  nstep = nperseg - noverlap

  # Rearrange axes if necessary so that frequency is at -2 and time at -1.
  if time_axis != Zxx.ndim - 1 or freq_axis != Zxx.ndim - 2:
    outer_idxs = tuple(
        idx for idx in range(Zxx.ndim) if idx not in {time_axis, freq_axis})
    Zxx = jnp.transpose(Zxx, outer_idxs + (freq_axis, time_axis))

  # Perform IFFT
  ifunc = jax.numpy.fft.irfft if input_onesided else jax.numpy.fft.ifft
  # xsubs: [..., T, N], N is the number of frames, T is the frame length.
  xsubs = ifunc(Zxx, axis=-2, n=nfft)[..., :nperseg, :]

  # Get window as array
  if isinstance(window, (str, tuple)):
    win = osp_signal.get_window(window, nperseg)
    win = jnp.asarray(win)
  else:
    win = jnp.asarray(window)
    if len(win.shape) != 1:
      raise ValueError('window must be 1-D')
    if win.shape[0] != nperseg:
      raise ValueError('window must have length of {0}'.format(nperseg))
  win = win.astype(xsubs.dtype)

  xsubs *= win.sum()  # This takes care of the 'spectrum' scaling

  # make win broadcastable over xsubs
  win = win.reshape((1, ) * (xsubs.ndim - 2) + win.shape + (1, ))
  x = _overlap_and_add((xsubs * win).swapaxes(-2, -1), nstep)
  win_squared = jnp.repeat((win * win), xsubs.shape[-1], axis=-1)
  norm = _overlap_and_add(win_squared.swapaxes(-2, -1), nstep)

  # Remove extension points. An explicit stop index is used because with
  # nperseg == 1 the slice `pad:-pad` would be `0:0` and drop everything.
  if boundary:
    pad = nperseg // 2
    x = x[..., pad:x.shape[-1] - pad]
    norm = norm[..., pad:norm.shape[-1] - pad]
  # Avoid dividing by (near-)zero normalization values.
  x /= jnp.where(norm > 1e-10, norm, 1.0)

  # Put axes back
  if x.ndim > 1:
    if time_axis != Zxx.ndim - 1:
      if freq_axis < time_axis:
        time_axis -= 1
      x = jnp.moveaxis(x, -1, time_axis)

  # NOTE(review): the time vector is sized from x.shape[0], which is the
  # signal length only when time ends up on axis 0 (always true for 1-D
  # output). This appears to follow scipy.signal.istft; confirm before
  # changing, since it affects parity with SciPy.
  time = jnp.arange(x.shape[0]) / fs
  return time, x