def tf_transforms(x, t, wavelet='morlet', window=None, padtype='wrap',
                  penalty=.5, n_ridges=2, cwt_bw=15, stft_bw=15,
                  ssq_cwt_bw=4, ssq_stft_bw=4):
    """Compute CWT / STFT and their synchrosqueezed versions, extract
    ridges from each of the four results, and plot them via `viz`.

    Parameters
    ----------
    x : input signal, forwarded to `ssq_cwt`, `ssq_stft` and `viz`.
    t : time vector; passed to `ssq_cwt` as `t`, and its first two samples
        set the sampling rate `fs = 1 / (t[1] - t[0])` given to `ssq_stft`.
    wavelet : forwarded positionally to `ssq_cwt`.
    window : forwarded positionally to `ssq_stft`.
    padtype : forwarded to both transforms.
    penalty, n_ridges : forwarded to every `extract_ridges` call.
    cwt_bw, stft_bw, ssq_cwt_bw, ssq_stft_bw : `bw` argument used for the
        respective transform's `extract_ridges` call.

    Returns
    -------
    None; the function is called for its plotting side effects.
    """
    sampling_rate = 1 / (t[1] - t[0])

    Twx, Wx, ssq_freqs_c, scales, *_ = ssq_cwt(
        x, wavelet, t=t, padtype=padtype)
    Tsx, Sx, ssq_freqs_s, Sfs, *_ = ssq_stft(
        x, window, fs=sampling_rate, padtype=padtype, flipud=1)

    # Reverse the plain STFT and its frequency vector along the first axis.
    # NOTE(review): presumably this matches the ordering produced by
    # `flipud=1` for the synchrosqueezed output -- confirm against `ssq_stft`.
    Sx = Sx[::-1]
    Sfs = Sfs[::-1]

    # One row per transform:
    # (matrix, scales-or-frequencies, bw, transform name, ssq flag, show_x)
    specs = [
        (Wx, scales, cwt_bw, 'cwt', 0, 1),
        (Twx, ssq_freqs_c, ssq_cwt_bw, 'cwt', 1, 0),
        (Sx, Sfs, stft_bw, 'stft', 0, 0),
        (Tsx, ssq_freqs_s, ssq_stft_bw, 'stft', 1, 0),
    ]

    # Extract all ridges first, then plot, so the sequence of calls is
    # four extractions followed by four plots.
    all_ridges = [
        extract_ridges(Tf, freqs, bw=bw, penalty=penalty, n_ridges=n_ridges,
                       transform=name)
        for Tf, freqs, bw, name, _, _ in specs
    ]

    for (Tf, freqs, _, name, is_ssq, show_x), ridges in zip(specs, all_ridges):
        viz(x, Tf, ridges, freqs, ssq=is_ssq, transform=name, show_x=show_x)
def test_parallel():
    """Check that `extract_ridges` returns matching output for
    `parallel=False` and `parallel=True`, for signal lengths 255 and 512.
    """
    for N in (255, 512):
        Wx, scales = cwt(np.random.randn(N))

        # Run serially first, then in parallel, keyed by the flag value.
        results = {
            flag: extract_ridges(Wx, scales, parallel=flag)
            for flag in (False, True)
        }
        serial, parallel = results[False], results[True]

        # Computed up front so the failure message can report the max error.
        adiff = np.abs(serial - parallel)
        assert np.allclose(serial, parallel), (
            "N=%s, Max err: %s" % (N, adiff.max()))
def test_basic():
    """Ridge extraction on a small hand-written 3x3 matrix.

    Modeled on a similar example from the MATLAB documentation:
    https://www.mathworks.com/help/wavelet/ref/wsstridge.html#bu6we25-penalty

    With `penalty=2.0`, the returned ridge indices must equal `[[2, 2, 2]]`.
    """
    freqs = np.exp([1, 2, 3])
    matrix = np.array([[1, 4, 4],
                       [2, 2, 2],
                       [5, 5, 4]])
    expected = np.array([[2, 2, 2]])

    # `get_params=True` yields several outputs; only the first is checked.
    ridge_idxs, *_ = extract_ridges(matrix, freqs, penalty=2.0,
                                    get_params=True)
    assert np.allclose(ridge_idxs, expected)
Tsx, ssq_stft_ridges, ssq_freqs_s, ssq=1, transform='stft', show_x=0) #%%# Basic example ########################################################### # Example ridge from similar example as can be found at MATLAB: # https://www.mathworks.com/help/wavelet/ref/wsstridge.html#bu6we25-penalty test_matrix = np.array([[1, 4, 4], [2, 2, 2], [5, 5, 4]]) fs_test = np.exp([1, 2, 3]) ridge_idxs, *_ = extract_ridges(test_matrix, fs_test, penalty=2.0, get_params=True) print("Ridge follows indexes:", ridge_idxs) assert np.allclose(ridge_idxs, np.array([[2, 2, 2]])) #%%# sin + cos ############################################################### N, f1, f2 = 513, 5, 20 padtype = 'wrap' penalty = 20 t = np.linspace(0, 1, N, endpoint=True) x1 = np.sin(2 * np.pi * f1 * t) x2 = np.cos(2 * np.pi * f2 * t) x = x1 + x2 tf_transforms(x, t, padtype=padtype, penalty=penalty)
def test_ridge_extraction():
    """Smoke test kept for @jit coverage: run `extract_ridges` on the CWT
    of a 128-sample random signal, once serially and once in parallel.
    """
    Wx, scales = cwt(np.random.randn(128))
    # Results are intentionally discarded; only successful execution matters.
    for use_parallel in (False, True):
        extract_ridges(Wx, scales, transform='cwt', parallel=use_parallel)