def err_fix(x, wavelet, a0):
    """Implement the corrective term in Eq. 4.66 of [1].

    NOTE(review): the original author marked this as "primitive code,
    doesn't work"; treat the result as experimental.

    `x` is the *original* (padded) signal, so this step must be done in the
    forward CWT, with its result passed on to the inverse CWT.

    # Arguments:
        x: np.ndarray. Padded input signal.
        wavelet: wavelet specification, passed to `Wavelet` and `adm_cwt`.
        a0: float. Multiplies the frequency grid used as integration upper
            limits, and divides the final result (via `a0 * Cpsi`).

    # Returns:
        corr: array of length `len(x)`; `ifftshift(ifft(fft(x) * phi_w**2))`
            normalized by `a0 * Cpsi`.

    # References:
        1. Mallat, S., Wavelet Tour of Signal Processing 3rd ed.
    """
    N = len(x)
    # positive radian frequencies 2*pi*k/N, k = 1 .. N//2
    xi = (2 * np.pi / N) * np.arange(1, N // 2 + 1)
    psihfn = Wavelet(wavelet)

    def _integrand(w):
        # conj(psih(w)) * psih(w) / w
        return np.conj(psihfn(w)) * psihfn(w) / w

    # Integrate from 0 to w, w spanning the same spectrum as psih.
    # This could be sped up using the brick-wall behavior (stop after the
    # first zero), or by computing fewer points and interpolating linearly.
    Cpsi_w = [quadgk(_integrand, 0., w)[0] for w in a0 * xi]
    Cpsi_w.insert(0, 0)  # integral from 0 to 0 = 0 (DC bin)
    # Analytic wavelet: the remaining (negative-frequency) bins are zero.
    # Pad to exactly N so the product with fft(x) below has matching length
    # for both even and odd N (the old `N // 2 - 1` padding was one short
    # for odd N).
    Cpsi_w.extend([0] * (N - len(Cpsi_w)))

    # integral from 0 to inf
    Cpsi = adm_cwt(wavelet)
    # subtract from the 0..inf integral to obtain the w..inf integral
    phi_w = Cpsi - np.array(Cpsi_w)

    # convolution theorem with x
    corr = ifftshift(ifft(fft(x) * phi_w**2))
    corr /= (a0 * Cpsi)  # normalize
    return corr
def synsq_adm(wavelet_type, opts=None):
    r"""Calculate the synchrosqueezing admissibility constant, the term
    R_\psi in Eq. 3 of [1].

    Note, here we multiply R_\psi by the inverse of log(2)/nv (found in
    Alg. 1 of [1]). Uses numerical integration.

    # Arguments:
        wavelet_type: str. See `wfiltfn`.
        opts: dict, optional. Options. See `wfiltfn`. Defaults to `{}`.

    # Returns:
        Css: the constant integral(conj(psih(w)) / w, w=0..inf), scaled by
            2 * log(2) / sqrt(2 * pi).

    # References:
        1. G. Thakur, E. Brevdo, N.-S. Fučkar, and H.-T. Wu,
        "The Synchrosqueezing algorithm for time-varying spectral analysis:
        robustness properties and new paleoclimate applications",
        Signal Processing, 93:1079-1094, 2013.
    """
    if opts is None:
        # fresh dict per call; avoids a shared mutable default
        opts = {}
    psihfn = wfiltfn(wavelet_type, opts)

    # The integrand must be passed as a callable; the previous code passed an
    # already-evaluated expression of an unbound variable.
    # NOTE(review): elsewhere in this module quadpy is said not to handle
    # np.inf limits -- confirm `quadgk` accepts an infinite upper bound here.
    result = quadgk(lambda w: np.conj(psihfn(w)) / w, 0, np.inf)
    # quadgk may return (value, error_estimate); keep only the value
    Css = result[0] if isinstance(result, tuple) else result

    # Normalization constant, due to logarithmic scaling
    # in wavelet transform
    return Css / np.sqrt(2 * np.pi) * 2 * np.log(2)
def stft_inv(Sx, opts=None):
    """Inverse short-time Fourier transform.

    Very closely based on Steven Schimel's stft.m and istft.m from his
    SPHSC 503: Speech Signal Processing course at Univ. Washington.
    Adapted for use with Synchrosqueezing Toolbox.

    # Arguments:
        Sx: np.ndarray. Short-time Fourier transform of a signal
            (see `stft_fwd`); rows are frequencies (the inverse FFT is taken
            over axis 0), columns are time frames.
        opts: dict, optional. Options (the caller's dict is not modified):
            'type': str. Window type, see `stft_fwd` and `wfiltfn`. Required
                for the final normalization; without it a KeyError is raised.
            'winlen': int. Window length; default round(Sx.shape[1] / 16).
            'overlap': int. Window overlap; default winlen - 1. Must be
                smaller than winlen.
            'rpadded': bool. Whether `Sx` is already padded; default False.
            Others; see `stft_fwd` and source code.

    # Returns:
        x: the signal, as reconstructed from `Sx`.

    # Raises:
        ValueError: if 'overlap' >= window length.
    """
    def _unbuffer(xbuf, w, o):
        # Undo 'buffering' by overlap-add: column k of `xbuf` is added into
        # the output starting at sample k * (w - o).
        skip = w - o
        if skip < 1:
            raise ValueError("'overlap' must be smaller than the window length")
        n_rows, n_cols = xbuf.shape
        L = (n_cols - 1) * skip + n_rows
        y = np.zeros(L, dtype=xbuf.dtype)
        for k in range(n_cols):
            start = k * skip
            y[start:start + n_rows] += xbuf[:, k]
        return y

    def _process_opts(opts, Sx):
        # Fill option defaults; opts['type'] overrides the default
        # hamming window.
        opts.setdefault('winlen', int(np.round(Sx.shape[1] / 16)))
        opts.setdefault('overlap', opts['winlen'] - 1)
        opts.setdefault('rpadded', False)
        if 'type' in opts:
            A = wfiltfn(opts['type'], opts)
            window = A(np.linspace(-1, 1, opts['winlen']))
        else:
            window = np.hamming(opts['winlen'])
        return opts, window

    Sx = np.asarray(Sx)
    # copy so neither the caller's dict nor a shared default gets mutated
    opts = dict(opts) if opts else {}
    opts, window = _process_opts(opts, Sx)
    window = np.asarray(window).ravel()
    # window = window / norm(window, 2) --> Unit norm
    n_win = len(window)

    # find length of padding, similar to outputs of `padsignal`
    n = Sx.shape[1]
    n1 = n_win - 1
    if opts['rpadded']:
        # padding already present: unpadded length is n_win shorter
        n = n - n_win
    else:
        # add STFT padding; keep Sx's (typically complex) dtype so the
        # imaginary part is not discarded
        new_n1 = max((n1 - 1) // 2, 0)
        Sxp = np.zeros((Sx.shape[0], new_n1 + n), dtype=Sx.dtype)
        Sxp[:, new_n1:new_n1 + n] = Sx
        Sx = Sxp

    # regenerate the full spectrum 0...2pi (minus zero Hz value): append the
    # conjugates of rows floor((n_win + 1) / 2) - 1 down to 1 (0-based)
    mirror = np.arange((n_win + 1) // 2 - 1, 0, -1)
    Sx = np.vstack([Sx, np.conj(Sx[mirror, :])])

    # take the inverse fft over the columns
    xbuf = np.real(np.fft.ifft(Sx, axis=0))

    # apply the window to each column (not in-place: window may be complex)
    xbuf = xbuf * window.reshape(-1, 1)

    # overlap-add the columns
    x = _unbuffer(xbuf, n_win, opts['overlap'])

    # keep the unpadded part only
    x = x[n1:n1 + n]

    # compute L2-norm of window to normalize STFT with
    windowfunc = wfiltfn(opts['type'], opts, derivative=False)
    inf_lim = 1000  # quadpy can't handle np.inf limits
    result = quadgk(lambda t: windowfunc(t)**2, -inf_lim, inf_lim)
    # quadgk may return (value, error_estimate); keep only the value
    C = result[0] if isinstance(result, tuple) else result
    # `quadgk` is a bit inaccurate with the 'bump' function,
    # this scales it correctly
    if opts['type'] == 'bump':
        C *= 0.8675

    return x * (2 / (np.pi * C))
def synsq_stft_inv(Tx, fs, opts, Cs=None, freqband=None):
    """Inverse STFT synchrosqueezing transform of `Tx` with associated
    frequencies in `fs` and curve bands in time-frequency plane specified
    by `Cs` and `freqband`. This implements Eq. 5 of [1].

    # Arguments:
        Tx: np.ndarray. Synchrosqueeze-transformed `x`, shape
            (len(fs), n_times). It is not modified.
            NOTE(review): previously documented as coming from `synsq_cwt`;
            confirm the matching STFT forward function.
        fs: np.ndarray. Frequencies associated with rows of Tx.
        opts: dict. Options:
            'type': type of window used in the forward transform (required).
            Other window options ('mu', 's') should also match those used
            in the forward transform.
        Cs: np.ndarray, optional. Curve centerpoints (1-based row indices),
            shape (n_times, n_curves); a 1-D array is treated as one curve.
            Entries < 1 mean "no curve at that time". Default: all ones.
        freqband: scalar or np.ndarray, optional. Curve half-bandwidths
            (in rows), broadcast against `Cs`. Default: Tx.shape[0].

    # Returns:
        x: np.ndarray, shape (n_curves + 1, n_times). Row k is the component
            reconstructed around curve k; the last row is the remainder.
    """
    Tx = np.asarray(Tx)
    n_times = Tx.shape[1]

    if Cs is None:
        Cs = np.ones((n_times, 1))
    Cs = np.asarray(Cs)
    if Cs.ndim == 1:
        Cs = Cs.reshape(-1, 1)

    if freqband is None:
        freqband = Tx.shape[0]
    freqband = np.asarray(freqband)
    if freqband.ndim == 1:
        freqband = freqband.reshape(-1, 1)
    # scalar / single-column band applies to every (time, curve) pair
    freqband = np.broadcast_to(freqband, Cs.shape)

    windowfunc = Wavelet((opts['type'], opts))
    inf_lim = 1000  # quadpy can't handle np.inf limits
    result = quadgk(lambda x: windowfunc(x)**2, -inf_lim, inf_lim)
    # quadgk may return (value, error_estimate); keep only the value
    C = result[0] if isinstance(result, tuple) else result
    if opts['type'] == 'bump':
        C *= 0.8675
    norm = 1 / (np.pi * C)

    # Invert Tx around curve masks in the time-frequency plane to recover
    # individual components; last one is the remaining signal.
    # Integration over all frequencies recovers original signal.
    # Factor of 2 is because real parts contain half the energy.
    n_curves = Cs.shape[1]
    x = np.zeros((Cs.shape[0], n_curves + 1))
    # copy: zeroing must neither touch the caller's Tx nor hide values from
    # masks of later (possibly overlapping) curves
    TxRemainder = Tx.copy()
    for n in range(n_curves):
        curve = Cs[:, n]
        upper = np.clip(curve + freqband[:, n], 1, len(fs)).astype(int)
        lower = np.clip(curve - freqband[:, n], 1, len(fs)).astype(int)
        # Cs < 1 means no curve at that time; lower > upper makes the
        # row range below empty, removing such points from the inversion
        no_curve = curve < 1
        upper[no_curve] = 1
        lower[no_curve] = 2

        TxMask = np.zeros_like(Tx)
        for m in range(n_times):
            # 1-based inclusive rows lower..upper -> 0-based slice
            rows = slice(lower[m] - 1, upper[m])
            TxMask[rows, m] = Tx[rows, m]
            TxRemainder[rows, m] = 0
        x[:, n] = norm * np.sum(np.real(TxMask), axis=0)
    x[:, -1] = norm * np.sum(np.real(TxRemainder), axis=0)
    return x.T