def tf_transforms(x,
                  t,
                  wavelet='morlet',
                  window=None,
                  padtype='wrap',
                  penalty=.5,
                  n_ridges=2,
                  cwt_bw=15,
                  stft_bw=15,
                  ssq_cwt_bw=4,
                  ssq_stft_bw=4):
    """Transform `x` four ways, extract ridges from each, and plot them.

    `ssq_cwt` and `ssq_stft` are run on `x`; ridges are then extracted from
    the CWT, the synchrosqueezed CWT, the STFT and the synchrosqueezed STFT
    (in that order), and each result is handed to `viz`. Nothing is returned.

    Arguments:
        x: signal passed to both transforms and to every `viz` call.
        t: time vector; forwarded to `ssq_cwt` as `t`, and its first two
           entries give the STFT sampling rate `1 / (t[1] - t[0])`
           (presumably `t` is uniformly spaced - confirm with callers).
        wavelet: forwarded to `ssq_cwt`.
        window: forwarded to `ssq_stft`.
        padtype: forwarded to both transforms.
        penalty, n_ridges: forwarded to every `extract_ridges` call.
        cwt_bw, ssq_cwt_bw, stft_bw, ssq_stft_bw: `bw` argument of the
           `extract_ridges` call for the corresponding transform.
    """
    # Sampling rate is derived from the spacing of the first two time points.
    fs = 1 / (t[1] - t[0])

    Twx, Wx, ssq_freqs_c, scales, *_ = ssq_cwt(x, wavelet, t=t,
                                               padtype=padtype)
    Tsx, Sx, ssq_freqs_s, Sfs, *_ = ssq_stft(x, window, fs=fs,
                                             padtype=padtype, flipud=1)
    # The plain STFT and its frequencies are reversed along the first axis
    # before use; the synchrosqueezed outputs are left as returned.
    Sx = Sx[::-1]
    Sfs = Sfs[::-1]

    # One row per plot:
    # (transform array, scales/frequencies, bandwidth, transform name, ssq)
    panels = (
        (Wx, scales, cwt_bw, 'cwt', 0),
        (Twx, ssq_freqs_c, ssq_cwt_bw, 'cwt', 1),
        (Sx, Sfs, stft_bw, 'stft', 0),
        (Tsx, ssq_freqs_s, ssq_stft_bw, 'stft', 1),
    )

    # Every ridge set is extracted before anything is drawn.
    ridges = [
        extract_ridges(tf, freqs, bw=bw, penalty=penalty, n_ridges=n_ridges,
                       transform=kind)
        for tf, freqs, bw, kind, _ in panels
    ]

    for idx, (panel, ridge) in enumerate(zip(panels, ridges)):
        tf, freqs, _, kind, ssq = panel
        # Only the first plot is asked to show `x` as well.
        viz(x, tf, ridge, freqs, ssq=ssq, transform=kind,
            show_x=int(idx == 0))
Esempio n. 2
0
def test_parallel():
    """Check that ridge extraction agrees with and without `parallel`.

    For each signal length, random noise is transformed with `cwt` and the
    ridges extracted with `parallel=False` are compared against those
    extracted with `parallel=True`.
    """
    signal_lengths = (255, 512)
    for N in signal_lengths:
        noise = np.random.randn(N)
        Wx, scales = cwt(noise)

        serial_out = extract_ridges(Wx, scales, parallel=False)
        parallel_out = extract_ridges(Wx, scales, parallel=True)

        # Absolute difference; its maximum is only reported on failure.
        abs_err = np.abs(serial_out - parallel_out)
        assert np.allclose(serial_out, parallel_out), (
            "N=%s, Max err: %s" % (N, abs_err.max()))
Esempio n. 3
0
def test_basic():
    """Ridge check on a 3x3 matrix, modelled on a similar MATLAB example:
    https://www.mathworks.com/help/wavelet/ref/wsstridge.html#bu6we25-penalty

    The first value returned by `extract_ridges` must equal `[[2, 2, 2]]`.
    """
    tf_matrix = np.array([[1, 4, 4], [2, 2, 2], [5, 5, 4]])
    freqs = np.exp([1, 2, 3])
    expected = np.array([[2, 2, 2]])

    # Only the first returned item is checked; the rest is discarded.
    found, *_ = extract_ridges(tf_matrix, freqs, penalty=2.0,
                               get_params=True)
    assert np.allclose(found, expected)
        Tsx,
        ssq_stft_ridges,
        ssq_freqs_s,
        ssq=1,
        transform='stft',
        show_x=0)


#%%# Basic example ###########################################################
# Ridge example adapted from a similar one in the MATLAB documentation:
# https://www.mathworks.com/help/wavelet/ref/wsstridge.html#bu6we25-penalty
# A 3x3 input matrix and a three-element vector exp([1, 2, 3]) that is passed
# as the second argument of `extract_ridges`.
test_matrix = np.array([[1, 4, 4], [2, 2, 2], [5, 5, 4]])
fs_test = np.exp([1, 2, 3])

# With `get_params=True` only the first returned item (the ridge indices) is
# kept here; everything else is discarded via `*_`.
ridge_idxs, *_ = extract_ridges(test_matrix,
                                fs_test,
                                penalty=2.0,
                                get_params=True)
print("Ridge follows indexes:", ridge_idxs)
# Expected result: a single ridge with index 2 at each of the three positions.
assert np.allclose(ridge_idxs, np.array([[2, 2, 2]]))

#%%# sin + cos ###############################################################
# Two-component test signal: a sine completing `f1` cycles plus a cosine
# completing `f2` cycles over the interval t in [0, 1].
N, f1, f2 = 513, 5, 20
padtype = 'wrap'
# Passed to `tf_transforms` below in place of its default penalty.
penalty = 20

# endpoint=True includes t = 1, so the N = 513 samples are spaced 1/512 apart.
t = np.linspace(0, 1, N, endpoint=True)
x1 = np.sin(2 * np.pi * f1 * t)
x2 = np.cos(2 * np.pi * f2 * t)
x = x1 + x2

# Run the transforms and ridge extraction on the combined signal.
tf_transforms(x, t, padtype=padtype, penalty=penalty)
Esempio n. 5
0
def test_ridge_extraction():
    """Smoke-run `extract_ridges` both ways (kept for @jit coverage).

    The return values are not inspected; the test only exercises the
    `parallel=False` and `parallel=True` paths on a random-noise CWT.
    """
    Wx, scales = cwt(np.random.randn(128))
    for use_parallel in (False, True):
        extract_ridges(Wx, scales, transform='cwt', parallel=use_parallel)