Exemple #1
0
def test_ssq_cwt():
    """Smoke-test `ssq_cwt` / `issq_cwt` over several wavelets and options.

    Each configuration is only required to run without raising; no numeric
    output is checked. A failing option is re-raised with the offending
    `name=value` pair in the message.
    """
    x = np.random.randn(64)
    for wavelet in ('morlet', ('morlet', {'mu': 20}), 'bump'):
        Tx, *_ = ssq_cwt(x, wavelet)
        issq_cwt(Tx, wavelet)

    kw = dict(x=x, wavelet='morlet')
    params = dict(
        squeezing=('lebesgue', ),
        scales=('linear', 'log:minimal', 'linear:naive',
                np.power(2**(1 / 16), np.arange(1, 32))),
        difftype=('phase', 'numeric'),
        padtype=('zero', 'replicate'),
        mapkind=('energy', 'peak'),
    )

    # vary one keyword at a time on top of the baseline `kw`
    for name, values in params.items():
        for value in values:
            try:
                ssq_cwt(**kw, **{name: value})
            except Exception as e:
                # chain explicitly so the original error is reported as the
                # direct cause of this one
                raise Exception(f"{name}={value} failed with:\n{e}") from e

    # `wavelet` is still bound to the last value of the first loop ('bump')
    ssq_cwt(x, wavelet, fs=2, difftype='numeric', difforder=2)
    ssq_cwt(x, wavelet, fs=2, difftype='numeric', difforder=1)
Exemple #2
0
def test_ssq_cwt():
    """Run `ssq_cwt` / `issq_cwt` on CPU across wavelets and keyword options.

    Nothing numeric is asserted; a configuration passes if it does not raise.
    When an option fails, its `name=value` pair is printed and the original
    exception propagates unchanged.
    """
    os.environ['SSQ_GPU'] = '0'  # in case concurrent tests set it to '1'
    np.random.seed(5)
    x = np.random.randn(64)
    for wavelet in ('morlet', ('morlet', {'mu': 20}), 'bump'):
        Tx, *_ = ssq_cwt(x, wavelet)
        issq_cwt(Tx, wavelet)

    kw = dict(x=x, wavelet='morlet')
    custom_scales = np.power(2**(1 / 8), np.arange(1, 32))
    params = dict(
        squeezing=('lebesgue', ),
        scales=('linear', 'log:minimal', 'linear:naive', custom_scales),
        difftype=('phase', 'numeric'),
        padtype=('zero', 'replicate'),
        maprange=('maximal', 'energy', 'peak', (1, 32)),
    )

    for name, values in params.items():
        for value in values:
            try:
                call_kw = {**kw, name: value, 'get_w': 1}
                # these two `maprange` values are exercised with log scales
                if name == 'maprange' and value in ('maximal', (1, 32)):
                    call_kw['scales'] = 'log'
                ssq_cwt(**call_kw)
            except BaseException:
                print(f"\n{name}={value} failed\n")
                raise

    # `wavelet` still holds the last value of the first loop ('bump')
    for difforder in (2, 1):
        ssq_cwt(x, wavelet, fs=2, difftype='numeric', difforder=difforder,
                get_w=1)
Exemple #3
0
def test_ssq_cwt():
    """Check `ssq_cwt` output agrees between `get_w=1` and `get_w=0`, and,
    when `CAN_GPU` is set, between CPU and GPU execution.

    Runs once per dtype; GPU comparisons use a dtype-dependent `atol`.
    `SSQ_GPU` is set back to '0' after the loop completes.
    """
    tsigs = TestSignals(N=256)
    kw = dict(astensor=False)

    def run(sig, dtype, get_w):
        # first output of `ssq_cwt` under the current `SSQ_GPU` setting
        return ssq_cwt(sig, _wavelet(dtype=dtype), **kw, get_w=get_w)[0]

    for dtype in ('float64', 'float32'):
        gpu_atol = 2e-4 if dtype != 'float64' else 1e-8
        x = tsigs.par_lchirp()[0].astype(dtype)

        os.environ['SSQ_GPU'] = '0'
        Tx00 = run(x, dtype, 1)
        Tx01 = run(x, dtype, 0)
        if CAN_GPU:
            os.environ['SSQ_GPU'] = '1'
            Tx10 = run(x, dtype, 1)
            Tx11 = run(x, dtype, 0)

        # mean absolute differences are only used in failure messages
        adiff0001 = np.abs(Tx00 - Tx01).mean()
        assert np.allclose(Tx00, Tx01), (dtype, adiff0001)
        if not CAN_GPU:
            continue
        adiff0010 = np.abs(Tx00 - Tx10).mean()
        adiff0011 = np.abs(Tx00 - Tx11).mean()
        assert np.allclose(Tx00, Tx10, atol=gpu_atol), (dtype, adiff0010)
        assert np.allclose(Tx00, Tx11, atol=gpu_atol), (dtype, adiff0011)
    os.environ['SSQ_GPU'] = '0'
Exemple #4
0
def test_dtype():
    """Ensure `cwt` and `ssq_cwt` compute at appropriate precision depending
    on `Wavelet.dtype`, returning float32 & complex64 arrays for single precision.

    `ssq_freqs` is expected to be float64 regardless of the wavelet's dtype.
    """
    os.environ['SSQ_GPU'] = '0'
    wav32, wav64 = Wavelet(dtype='float32'), Wavelet(dtype='float64')
    x = np.random.randn(256)
    outs32 = ssq_cwt(x, wav32)
    outs64 = ssq_cwt(x, wav64)
    outs32_o2 = ssq_cwt(x, wav32, order=2)

    names = ('Tx', 'Wx', 'ssq_freqs', 'scales', 'w', 'dWx')
    single = (np.float32, np.complex64)
    double = (np.float64, np.complex128)

    # (label used in failure messages, outputs, allowed dtypes)
    cases = (
        ("float32", outs32, single),
        ("float32", outs32_o2, single),
        ("float64", outs64, double),
    )
    for label, outs, allowed in cases:
        for k, v in zip(names, outs):
            if k == 'ssq_freqs':
                assert v.dtype == np.float64, (label, k, v.dtype)
                continue
            assert v.dtype in allowed, (label, k, v.dtype)
Exemple #5
0
def time_ssq_cwt(x, dtype, scales, cache_wavelet, ssq_freqs):
    """Time `ssq_cwt` on `x` with a `Wavelet` of the given `dtype`.

    When `cache_wavelet` is truthy, three warmup calls are made first (each
    followed by garbage collection) before the timed call. Returns whatever
    `timeit` returns for the timed call.
    """
    kw = dict(wavelet=Wavelet(dtype=dtype), scales=scales,
              ssq_freqs=ssq_freqs)

    if cache_wavelet:
        for _ in range(3):  # warmup run
            ssq_cwt(x, cache_wavelet=True, **kw)  # result dropped right away
            gc.collect()

    def timed_call():
        return ssq_cwt(x, cache_wavelet=cache_wavelet, **kw)

    return timeit(timed_call)
def test_component_inversion():
    """Recover an amplitude-modulated exponential chirp from heavy noise by
    inverting `ssq_cwt` over a hand-picked linear frequency band.

    Asserts that the MAD/RMS error against the clean signal stays within
    fixed limits, both in time and in the magnitude spectrum.
    """
    N, noise_var = 2048, 6

    # exponential chirp sampled at N points
    ts = np.linspace(0, 10, N, False)
    x = np.cos(2 * np.pi * np.exp(ts / 3))

    x *= (1 + .3 * cos_f([1], N))  # amplitude modulation
    clean = x.copy()
    np.random.seed(4)
    x += np.sqrt(noise_var) * np.random.randn(len(x))

    wavelet = ('morlet', {'mu': 4.5})
    Tx = ssq_cwt(x, wavelet, scales='log:maximal', nv=32, t=ts)[0]

    # hand-coded, subject to failure
    bw, slope, offset = .035, .45, .45
    Cs, freqband = lin_band(Tx, slope, offset, bw, norm=(0, 2e-1))

    xrec = issq_cwt(Tx, wavelet, Cs, freqband)[0]

    clean_spec = np.abs(np.fft.rfft(clean))
    xrec_spec = np.abs(np.fft.rfft(xrec))

    err_sig = mad_rms(clean, xrec)
    err_spc = mad_rms(clean_spec, xrec_spec)
    print("signal   MAD/RMS: %.6f" % err_sig)
    print("spectrum MAD/RMS: %.6f" % err_spc)
    assert err_sig <= .42, f"{err_sig} > .42"
    assert err_spc <= .14, f"{err_spc} > .14"
def tf_transforms(x,
                  t,
                  wavelet='morlet',
                  window=None,
                  padtype='wrap',
                  penalty=.5,
                  n_ridges=2,
                  cwt_bw=15,
                  stft_bw=15,
                  ssq_cwt_bw=4,
                  ssq_stft_bw=4):
    """Compute CWT, STFT and their synchrosqueezed versions of `x`, extract
    ridges from each, and visualize all four.

    `t` is passed to `ssq_cwt`; the STFT sampling rate is derived from the
    spacing of its first two entries. `penalty` and `n_ridges` are forwarded
    to every `extract_ridges` call, each `*_bw` to the matching transform.
    Returns nothing.
    """
    fs = 1 / (t[1] - t[0])
    Twx, Wx, ssq_freqs_c, scales, *_ = ssq_cwt(x, wavelet, t=t,
                                               padtype=padtype)
    Tsx, Sx, ssq_freqs_s, Sfs, *_ = ssq_stft(x, window, fs=fs,
                                             padtype=padtype, flipud=1)
    # reverse the STFT and its frequencies along the first axis
    Sx = Sx[::-1]
    Sfs = Sfs[::-1]

    shared = dict(penalty=penalty, n_ridges=n_ridges)
    cwt_ridges = extract_ridges(Wx, scales, bw=cwt_bw, transform='cwt',
                                **shared)
    ssq_cwt_ridges = extract_ridges(Twx, ssq_freqs_c, bw=ssq_cwt_bw,
                                    transform='cwt', **shared)
    stft_ridges = extract_ridges(Sx, Sfs, bw=stft_bw, transform='stft',
                                 **shared)
    ssq_stft_ridges = extract_ridges(Tsx, ssq_freqs_s, bw=ssq_stft_bw,
                                     transform='stft', **shared)

    # (transform output, ridges, frequency axis, ssq flag, kind, show_x)
    panels = (
        (Wx, cwt_ridges, scales, 0, 'cwt', 1),
        (Twx, ssq_cwt_ridges, ssq_freqs_c, 1, 'cwt', 0),
        (Sx, stft_ridges, Sfs, 0, 'stft', 0),
        (Tsx, ssq_stft_ridges, ssq_freqs_s, 1, 'stft', 0),
    )
    for Tf, ridges, freqs, ssq, transform, show_x in panels:
        viz(x, Tf, ridges, freqs, ssq=ssq, transform=transform,
            show_x=show_x)
Exemple #8
0
 def plot_tf_wsst(self, TRAI, save_path=None):
     """Plot a synchrosqueezed-CWT contour map of the waveform record
     `self.data_tra[int(TRAI - 1)]`.

     The signal and time axis are decoded from the record according to
     `self.device` ('vallen' or 'stream'). If `save_path` is given, the axis
     ticks and margins are removed and the figure is saved there as
     '<TRAI>.jpg'.
     """
     rec = self.data_tra[int(TRAI - 1)]
     if self.device == 'vallen':
         sig = np.multiply(array.array('h', bytes(rec[-2])), rec[-3] * 1000)
         t_end = pow(rec[-5], -1) * (rec[-4] - 1) * pow(10, 6)
         time = np.linspace(0, t_end, rec[-4])
     elif self.device == 'stream':
         sig = np.multiply(array.array('h', bytes(rec[-1])), rec[-2])
         t_end = pow(rec[3], -1) * (rec[4] - 1) * pow(10, 6)
         time = np.linspace(0, t_end, rec[4])

     Twxo, _, ssq_freqs, *_ = ssq_cwt(sig,
                                      wavelet='morlet',
                                      scales='log-piecewise',
                                      fs=rec[3],
                                      t=time)
     plt.figure(figsize=(5.12, 5.12))
     freq_axis = ssq_freqs * 1000
     plt.contourf(time, freq_axis, abs(Twxo), cmap='jet')
     plt.ylim(min(freq_axis), 1000)
     plt.xlabel(r'Time (μs)')
     plt.ylabel(r'Frequency (kHz)')
     if not save_path:
         return

     # borderless image: no ticks, no padding around the axes
     axes = plt.gca()
     axes.xaxis.set_major_locator(plt.NullLocator())
     axes.yaxis.set_major_locator(plt.NullLocator())
     plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0,
                         wspace=0)
     plt.margins(0, 0)
     plt.savefig(os.path.join(save_path, '%i.jpg' % TRAI), pad_inches=0)
Exemple #9
0
def test_higher_order():
    """`cwt` & `ssq_cwt` CPU & GPU outputs agreement.

    Skipped (returns immediately) when `CAN_GPU` is falsy. Compares the mean
    absolute difference of `Tx` and `Wx` between `SSQ_GPU='0'` and
    `SSQ_GPU='1'` against a fixed threshold for each dtype.
    """
    if not CAN_GPU:
        return

    tsigs = TestSignals(N=256)
    x = tsigs.par_lchirp()[0]
    x += x[::-1]

    kw = dict(order=range(3), astensor=False)
    th = 1e-6  # less should be possible for float64, but didn't investigate
    try:
        for dtype in ('float32', 'float64'):
            os.environ['SSQ_GPU'] = '0'
            Tx0, Wx0, *_ = ssq_cwt(x, _wavelet(dtype=dtype), **kw)
            os.environ['SSQ_GPU'] = '1'
            Tx1, Wx1, *_ = ssq_cwt(x, _wavelet(dtype=dtype), **kw)

            adiff_Tx = np.abs(Tx0 - Tx1).mean()
            adiff_Wx = np.abs(Wx0 - Wx1).mean()
            # report the measured difference, not just the threshold
            assert adiff_Tx < th, (dtype, adiff_Tx, th)
            assert adiff_Wx < th, (dtype, adiff_Wx, th)
    finally:
        # restore CPU mode even when an assertion fails mid-loop
        os.environ['SSQ_GPU'] = '0'
Exemple #10
0
def test_ssq_cwt_batched():
    """Ensure batched (2D `x`) inputs output same as if samples fed separately,
    and agreement between CPU & GPU.

    `SSQ_GPU` is reset to '0' on exit so GPU mode does not leak into tests
    that run afterwards.
    """
    np.random.seed(0)
    x = np.random.randn(4, 256)
    kw = dict(astensor=False)

    try:
        for dtype in ('float64', 'float32'):
            os.environ['SSQ_GPU'] = '0'
            Tx0, Wx0, *_ = ssq_cwt(x, _wavelet(dtype=dtype), **kw)

            # reference: transform each sample of the batch on its own
            Tx00 = np.zeros(Tx0.shape, dtype=Tx0.dtype)
            Wx00 = Tx00.copy()
            for i, _x in enumerate(x):
                out = ssq_cwt(_x, _wavelet(dtype=dtype), **kw)
                Tx00[i], Wx00[i] = out[0], out[1]

            if CAN_GPU:
                os.environ['SSQ_GPU'] = '1'
                Tx1, Wx1, *_ = ssq_cwt(x, _wavelet(dtype=dtype), **kw)

            atol = 1e-12 if dtype == 'float64' else 1e-2
            adiff_Tx000 = np.abs(Tx00 - Tx0).mean()
            adiff_Wx000 = np.abs(Wx00 - Wx0).mean()
            assert np.allclose(Wx00, Wx0), (dtype, adiff_Wx000)
            assert np.allclose(Tx00, Tx0), (dtype, adiff_Tx000)
            if CAN_GPU:
                adiff_Tx01 = np.abs(Tx0 - Tx1).mean()
                adiff_Wx01 = np.abs(Wx0 - Wx1).mean()
                assert np.allclose(Wx0, Wx1, atol=atol), (dtype, adiff_Wx01)
                assert np.allclose(Tx0, Tx1, atol=atol), (dtype, adiff_Tx01)

                # didn't investigate float32, and `allclose` threshold is
                # pretty bad, so check MAE
                if dtype == 'float32':
                    assert adiff_Tx01 < 1e-6, (dtype, adiff_Tx01)
    finally:
        os.environ['SSQ_GPU'] = '0'
def test_ssq_cwt():
    """Round-trip every signal in `test_fns` through `ssq_cwt` / `issq_cwt`
    with 'log' and 'linear' scales, asserting the MAD/RMS reconstruction
    error is below `th`. Prints all errors once every case has passed.
    """
    errs = []
    for fn in test_fns:
        x, ts = fn(2048)
        for scales in ('log', 'linear'):
            # 'linear' default can't handle low frequencies for large N
            if (fn.__name__, scales) == ('fast_transitions', 'linear'):
                continue

            Tx = ssq_cwt(x, wavelet, scales=scales, nv=32, t=ts)[0]
            xrec = issq_cwt(Tx, wavelet)

            err = round(mad_rms(x, xrec), 5)
            errs.append(err)
            assert err < th, (err, fn.__name__, scales)
    print("\nssq_cwt PASSED\nerrs:", ', '.join(str(e) for e in errs))
def test_cwt_log_piecewise():
    """With 'log-piecewise' scales and the 'gmw' wavelet, both `issq_cwt`
    and `icwt` must reconstruct a 1024-sample `echirp` signal with MAD/RMS
    error below .02.
    """
    wavelet = 'gmw'
    x, ts = echirp(1024)

    outs = ssq_cwt(x, wavelet, scales='log-piecewise', t=ts,
                   preserve_transform=True)
    # third output is the ssq frequencies, unused here
    Tx, Wx, _, scales = outs[:4]

    reconstructions = (
        issq_cwt(Tx, wavelet),
        icwt(Wx, wavelet, scales=scales),
    )
    err_ssq_cwt, err_cwt = (round(mad_rms(x, xrec), 5)
                            for xrec in reconstructions)
    assert err_ssq_cwt < .02, err_ssq_cwt
    assert err_cwt < .02, err_cwt
def test_ssq_cwt():
    """Round-trip every signal in `test_fns` through `ssq_cwt` / `issq_cwt`
    for 'log', 'log-piecewise' and 'linear' scales, optionally visualizing
    each case, and assert the MAD/RMS reconstruction error is below `th`.

    For the `low_freqs` signal, 'linear' is skipped and the remaining scale
    types get the ':maximal' suffix.
    """
    errs = []
    for fn in test_fns:
        x, ts = fn(2048)
        is_low_freqs = fn.__name__ == 'low_freqs'
        for scaletype in ('log', 'log-piecewise', 'linear'):
            if is_low_freqs and scaletype == 'linear':
                # 'linear' default can't handle low frequencies for large N
                # 'log-piecewise' maps it too sparsely
                continue
            scales = f'{scaletype}:maximal' if is_low_freqs else scaletype

            Tx = ssq_cwt(x, wavelet, scales=scales, nv=32, t=ts)[0]
            xrec = issq_cwt(Tx, wavelet)

            err = round(mad_rms(x, xrec), 5)
            errs.append(err)
            title = "abs(SSQ_CWT) | {}, scales='{}'".format(
                fn.__qualname__, scales)
            _maybe_viz(Tx, x, xrec, title, err)
            assert err < th, (err, fn.__name__, scales)
    print("\nssq_cwt PASSED\nerrs:", ', '.join(str(e) for e in errs))
Exemple #14
0
def swt(time, sig, sampleRate, wavelet='morlet', scales='log-piecewise'):
    '''
    Plot the synchrosqueezed CWT of `sig` as a filled contour map on a new
    figure.

    :param time: time axis; passed to `ssq_cwt` as `t` and used as the
        contour x-coordinates
    :param sig: signal to transform
    :param sampleRate: passed to `ssq_cwt` as `fs`
    :param wavelet: wavelet passed to `ssq_cwt`
    :param scales: scales passed to `ssq_cwt`
    :return: None
    '''
    Twxo, _, ssq_freqs, *_ = ssq_cwt(sig,
                                     wavelet=wavelet,
                                     scales=scales,
                                     fs=sampleRate,
                                     t=time)
    freq_axis = ssq_freqs * 1000
    # square root of the magnitude is what gets drawn
    magnitude = pow(abs(Twxo), 0.5)

    plt.figure(figsize=(5.12, 5.12))
    plt.contourf(time, freq_axis, magnitude, cmap='cubehelix_r')
    plt.ylim(min(freq_axis), 1000)
    plt.xlabel(r'Time (μs)')
    plt.ylabel(r'Frequency (kHz)')
Exemple #15
0
# Script fragment: add noise to an amplitude-modulated signal, synchrosqueeze
# it, and reconstruct one component from a linear frequency band.
# `x`, `ts`, `N` and `noise_var` are defined before this fragment.
x *= (1 + .3 * cos_f([1], N))  # amplitude modulation
xo = x.copy()  # clean copy, plotted against the reconstruction further down
np.random.seed(4)  # fixed seed so the added noise is reproducible
x += np.sqrt(noise_var) * np.random.randn(len(x))

#### Show signal & its global spectrum #######################################
axf = np.abs(rfft(x))

plot(xo)
scat(xo, s=8, show=1)
plot(x)
scat(x, s=8, show=1)
plot(axf, show=1)
#%%# Synchrosqueeze ##########################################################
# the same wavelet / nv / scales settings are used for both transforms
kw = dict(wavelet=('morlet', {'mu': 4.5}), nv=32, scales='log')
Tx, *_ = ssq_cwt(x, t=ts, **kw)
Wx, *_ = cwt(x, t=ts, **kw)
#%%# Visualize ###############################################################
pkw = dict(abs=1, w=.86, h=.9, aspect='auto', cmap='bone')
_Tx = np.pad(Tx, [[4, 4]])  # improve display of top- & bottom-most freqs
imshow(Wx, **pkw)
imshow(np.flipud(_Tx), norm=(0, 2e-1), **pkw)
#%%# Estimate inversion ridge ###############################################
# hand-picked band parameters passed to `lin_band`
bw, slope, offset = .035, .45, .45
Cs, freqband = lin_band(Tx, slope, offset, bw, norm=(0, 2e-1))
#%%###########################################################################
# invert only the selected band; the first returned item is plotted
xrec = issq_cwt(Tx, kw['wavelet'], Cs, freqband)[0]
plot(xo)
plot(xrec, show=1)

# magnitude spectrum of the clean signal
axof = np.abs(rfft(xo))
Exemple #16
0
    def plot_wave_TRAI(self,
                       fig,
                       k,
                       data_pri,
                       show_features=False,
                       valid=False,
                       cwt=False):
        """Plot the waveform whose TRAI equals `k` on `fig` and dump it to a
        text file.

        :param fig: figure to draw on; subplots are added to it
        :param k: TRAI of the waveform to look up in `self.data_tra`
        :param data_pri: feature table searched for the row matching the
            waveform when `show_features` is set
        :param show_features: print the waveform's feature row to stdout
        :param valid: forwarded to `self.cal_wave`
        :param cwt: also draw a synchrosqueezed-CWT contour map next to the
            waveform
        :return: an error string when the TRAI cannot be found or does not
            match the record; otherwise None
        """
        # Waveform with specific TRAI
        try:
            # NOTE(review): the device is compared against 'VALLEN' here but
            # against 'vallen' further down — confirm which spelling is used.
            if self.device == 'VALLEN':
                i = self.data_tra[k - 1]
            else:
                i = self.data_tra[k - self.data_tra[0][-1]]
        except IndexError:
            return str('Error: TRAI %d can not be found in data!' % k)
        if i[-1] != k:
            return str(
                'Error: TRAI %d in data_tra is inconsistent with %d by input!'
                % (i[-1], k))
        time, sig = self.cal_wave(i, valid=valid)

        # Strip trailing zeros. `tail` defaults to None (keep everything) so
        # an empty or all-zero signal no longer leaves it unbound.
        tail = None
        for tmp_tail, s in enumerate(sig[::-1]):
            if s != 0:
                tail = -tmp_tail if tmp_tail > 0 else None
                break
        time, sig = time[:tail], sig[:tail]

        if cwt:
            fig.subplots_adjust(left=0.076,
                                bottom=0.205,
                                right=0.984,
                                top=0.927,
                                hspace=0.2,
                                wspace=0.26)
            fig.text(0.47,
                     0.25,
                     self.status,
                     fontdict={
                         'family': 'Arial',
                         'fontweight': 'bold',
                         'fontsize': 12
                     },
                     horizontalalignment="right")
            # right panel: time-frequency map
            ax = fig.add_subplot(1, 2, 2)
            ax.cla()
            Twxo, Wxo, ssq_freqs, *_ = ssq_cwt(sig,
                                               wavelet='morlet',
                                               scales='log-piecewise',
                                               fs=i[3],
                                               t=time)
            ax.contourf(time,
                        ssq_freqs * 1000,
                        pow(abs(Twxo), 0.5),
                        cmap='cubehelix_r')
            plot_norm(ax,
                      'Time (μs)',
                      'Frequency (kHz)',
                      y_lim=[min(ssq_freqs * 1000), 1000],
                      legend=False)
            # left panel: waveform; `ax` refers to it from here on
            ax = fig.add_subplot(1, 2, 1)
            ax.cla()
            ax.plot(time, sig, lw=1)
        else:
            fig.subplots_adjust(left=0.115, bottom=0.17, right=0.975, top=0.95)
            fig.text(0.96,
                     0.2,
                     self.status,
                     fontdict={
                         'family': 'Arial',
                         'fontweight': 'bold',
                         'fontsize': 12
                     },
                     horizontalalignment="right")
            ax = fig.add_subplot()
            ax.cla()
            ax.plot(time, sig, lw=1)

        if self.device == 'vallen':
            if show_features:
                try:
                    string = data_pri[np.where(data_pri[:, -1] == i[-1])][0]
                except IndexError:
                    return str('Error: TRAI %d can not be found in data!' % k)
                print("=" * 23 + " Waveform information " + "=" * 23)
                # third list: rounding digits per field, 0 means print as int
                for info, value, r in zip([
                        'SetID', 'Time', 'Chan', 'Thr', 'Amp', 'RiseT', 'Dur',
                        'Eny', 'RMS', 'Counts', 'TRAI'
                ], [j for j in string], [0, 8, 0, 8, 8, 2, 2, 8, 8, 0, 0]):
                    if r == 0:
                        print('%s: %d' % (info, int(value)))
                    else:
                        print('%s: %s' % (info, round(value, r)))
            # threshold lines at +/- |i[2]|
            ax.axhline(abs(i[2]), 0, sig.shape[0], linewidth=1, color="black")
            ax.axhline(-abs(i[2]), 0, sig.shape[0], linewidth=1, color="black")
        elif self.device == 'pac':
            if show_features:
                # time, channel_num, sample_interval, points_num, dataset, hit_num
                # ID, Time(s), Chan, Thr(μV), Thr(dB), Amp(μV), Amp(dB), RiseT(s), Dur(s), Eny(aJ), RMS(μV), Frequency(Hz), Counts
                string = data_pri[np.where(data_pri[:, 0] == i[-1])][0]
                print("=" * 23 + " Waveform information " + "=" * 23)
                for info, value, r in zip([
                        'Hit number', 'Time', 'Chan', 'Thr', 'Amp', 'RiseT',
                        'Dur', 'Eny', 'RMS', 'Counts'
                ], [
                        j for j in string[np.array(
                            [0, 1, 2, 3, 5, 7, 8, 9, 10, 12])]
                ], [0, 7, 0, 8, 8, 7, 7, 8, 8, 0]):
                    if r == 0:
                        print('%s: %d' % (info, int(value)))
                    else:
                        print('%s: %s' % (info, round(value, r)))
            # threshold lines at +/- |self.thr_μV|
            ax.axhline(abs(self.thr_μV),
                       0,
                       sig.shape[0],
                       linewidth=1,
                       color="black")
            ax.axhline(-abs(self.thr_μV),
                       0,
                       sig.shape[0],
                       linewidth=1,
                       color="black")
        plot_norm(ax, 'Time (μs)', 'Amplitude (μV)', legend=False, grid=True)

        # ======================== redraw and refresh the canvas ========================
        fig.canvas.draw()
        fig.canvas.flush_events()

        # dump the trimmed waveform as "<output>/<status>-<TRAI>.txt"
        with open(
                '/'.join([self.output, self.status]) + '-%d' % i[-1] + '.txt',
                'w') as f:
            f.write('Time, Signal\n')
            # separate index name so the `k` parameter is not overwritten
            for idx in range(sig.shape[0]):
                f.write("{}, {}\n".format(time[idx], sig[idx]))
Exemple #17
0
                         weight='bold',
                         fontsize=14,
                         xy=(.85, .93),
                         xycoords='axes fraction')
            plt.show()
        else:
            plt.xlim(0, len(row) // 4)
    plt.show()


#%%
# animate rows of `Wxz` (all three names are defined before this fragment)
row_anim(Wxz, idxs, scales)
#%%## Superimpose ####
row_anim(Wxz, idxs, scales, superimposed=True)
#%%## Synchrosqueeze
# the second output of `ssq_cwt` is bound to `fs` and used as y-ticks below
Tx, fs, *_ = ssq_cwt(x, wavelet, t=_t(0, 1, N))
#%%
imshow(Tx, abs=1, title="abs(SSWT)", yticks=fs, show=1)

#%%# Damped pendulum example ################################################
N, w0 = 4096, 25
t = _t(0, 6, N)
# decaying cosine with angular frequency `w0`
s = np.exp(-t) * np.cos(w0 * t)

w = np.linspace(-40, 40, N)
# closed-form expression over `w`, plotted below as "abs(FT(s(t)))"
S = (1 + 1j * w) / ((1 + 1j * w)**2 + w0**2)

#%%# Plot ####
plot(s, title="s(t)", show=1)
plot(w, np.abs(S), title="abs(FT(s(t)))", show=1)
Exemple #18
0
    def plot_filtering(self,
                       TRAI,
                       N,
                       CutoffFreq,
                       btype,
                       valid=False,
                       originWave=False,
                       filteredWave=True):
        """
        Signal Filtering

        Applies a Butterworth filter forwards and backwards (`filtfilt`) to
        the waveform of the given TRAI and optionally plots the original
        and/or filtered waveform next to its synchrosqueezed-CWT contour map.

        :param TRAI:
        :param N: Order of filter
        :param CutoffFreq: Cutoff frequency, Wn = 2 * cutoff frequency / sampling frequency, len(Wn) = 2 if btype in ['bandpass', 'bandstop'] else 1
        :param btype: Filter Types, {'lowpass', 'highpass', 'bandpass', 'bandstop'}
        :param originWave: Whether to display the original waveform
        :param filteredWave: Whether to display the filtered waveform
        :param valid: Whether to truncate the waveform according to the threshold
        :return: the filtered signal
        """
        tmp = self.data_tra[int(TRAI - 1)]
        if TRAI != tmp[-1]:
            # only reported; processing continues with the record found
            print('Error: TRAI is incorrect!')
        time, sig = self.cal_wave(tmp, valid=valid)
        # normalise each cutoff: 2 * f * 1e3 / tmp[3]
        b, a = butter(N, list(map(lambda x: 2 * x * 1e3 / tmp[3], CutoffFreq)),
                      btype)
        sig_filter = filtfilt(b, a, sig)

        if originWave:
            fig = plt.figure(figsize=(9.2, 3))
            ax = fig.add_subplot(1, 2, 1)
            ax.plot(time, sig, lw=1, color='blue')
            plot_norm(ax,
                      'Time (μs)',
                      'Amplitude (μV)',
                      title='TRAI: %d' % TRAI,
                      legend=False,
                      grid=True)
            ax = fig.add_subplot(1, 2, 2)
            Twxo, Wxo, ssq_freqs, *_ = ssq_cwt(sig,
                                               wavelet='morlet',
                                               scales='log-piecewise',
                                               fs=tmp[3],
                                               t=time)
            plt.contourf(time, ssq_freqs * 1000, abs(Twxo), cmap='jet')
            plot_norm(ax,
                      r'Time (μs)',
                      r'Frequency (kHz)',
                      y_lim=[min(ssq_freqs * 1000), 1000],
                      legend=False)

        if filteredWave:
            if btype in ['lowpass', 'highpass']:
                # `CutoffFreq` is a sequence (it is iterated above), so format
                # its single element; `%d` on the sequence itself raises
                # TypeError for a list.
                label = 'Frequency %s %d kHz' % ('<' if btype == 'lowpass' else
                                                 '>', CutoffFreq[0])
            elif btype == 'bandpass':
                label = '%d kHz < Frequency < %d kHz' % (CutoffFreq[0],
                                                         CutoffFreq[1])
            else:
                label = 'Frequency < %d kHz or > %d kHz' % (CutoffFreq[0],
                                                            CutoffFreq[1])
            fig = plt.figure(figsize=(9.2, 3))
            ax = fig.add_subplot(1, 2, 1)
            ax.plot(time, sig_filter, lw=1, color='gray', label=label)
            plot_norm(ax,
                      'Time (μs)',
                      'Amplitude (μV)',
                      title='TRAI: %d (%s)' % (TRAI, btype),
                      grid=True,
                      frameon=False,
                      legend_loc='upper right')
            ax = fig.add_subplot(1, 2, 2)
            Twxo, Wxo, ssq_freqs, *_ = ssq_cwt(sig_filter,
                                               wavelet='morlet',
                                               scales='log-piecewise',
                                               fs=tmp[3],
                                               t=time)
            plt.contourf(time, ssq_freqs * 1000, abs(Twxo), cmap='jet')
            plot_norm(ax,
                      r'Time (μs)',
                      r'Frequency (kHz)',
                      y_lim=[min(ssq_freqs * 1000), 1000],
                      legend=False)

        return sig_filter
                         weight='bold',
                         fontsize=14,
                         xy=(.85, .93),
                         xycoords='axes fraction')
            plt.show()
        else:
            plt.xlim(0, len(row) // 4)
    plt.show()


#%%
# animate rows of `Wxz` (all three names are defined before this fragment)
row_anim(Wxz, idxs, scales)
#%%## Superimpose ####
row_anim(Wxz, idxs, scales, superposed=True)
#%%## Synchrosqueeze
# keep the first and third outputs; the third is used as y-ticks below
Tx, _, ssq_freqs, *_ = ssq_cwt(x, wavelet, t=_t(0, 1, N))
#%%
imshow(Tx, abs=1, title="abs(SSWT)", yticks=ssq_freqs, show=1)

#%%# Damped pendulum example ################################################
N, w0 = 4096, 25
t = _t(0, 6, N)
# decaying cosine with angular frequency `w0`
s = np.exp(-t) * np.cos(w0 * t)

w = np.linspace(-40, 40, N)
# closed-form expression over `w`, plotted below as "abs(FT(s(t)))"
S = (1 + 1j * w) / ((1 + 1j * w)**2 + w0**2)

#%%# Plot ####
plot(s, title="s(t)", show=1)
plot(w, np.abs(S), title="abs(FT(s(t)))", show=1)
Exemple #20
0
# Script fragment: build CWT scales explicitly, visualize them, then apply
# them via `ssq_cwt`. `x`, `N`, `wavelet`, `preset`, `nv`, `scaletype` and
# `padtype` are defined before this fragment.
# downsampling factor for higher scales (used only if `scaletype='log-piecewise'`)
downsample = 4
# show this many of lowest-frequency wavelets
show_last = 20

#%%## Make scales ############################################################
# `cwt` uses `p2up`'d N internally
M = p2up(N)[0]
wavelet = Wavelet(wavelet, N=M)

# scale bounds for this wavelet and signal length, then the scales themselves
min_scale, max_scale = cwt_scalebounds(wavelet, N=len(x), preset=preset)
scales = make_scales(N,
                     min_scale,
                     max_scale,
                     nv=nv,
                     scaletype=scaletype,
                     wavelet=wavelet,
                     downsample=downsample)

#%%# Visualize scales ########################################################
viz(wavelet, scales, scaletype, show_last, nv)
wavelet.viz('filterbank', scales=scales)

#%%# Show applied ############################################################
# pass the precomputed scales; `scales` is rebound to the returned value
Tx, Wx, ssq_freqs, scales, *_ = ssq_cwt(x,
                                        wavelet,
                                        scales=scales,
                                        padtype=padtype)
imshow(Wx, abs=1, title="abs(CWT)")
imshow(Tx, abs=1, title="abs(SSQ_CWT)")
Exemple #21
0
# -*- coding: utf-8 -*-
"""Experimental feature example."""
# refuse to run when imported rather than executed directly
if __name__ != '__main__':
    raise Exception("ran example file as non-main")

import numpy as np
from ssqueezepy import TestSignals, ssq_cwt, Wavelet
from ssqueezepy.visuals import imshow
from ssqueezepy.experimental import phase_ssqueeze

#%%
# test signal summed with its own time reversal
x = TestSignals(N=2048).par_lchirp()[0]
x += x[::-1]
wavelet = Wavelet()

# compare `ssq_cwt` output with `phase_ssqueeze` applied to the same `Wx`
Tx0, Wx, _, scales, *_ = ssq_cwt(x, wavelet, get_dWx=1)
Tx1, *_ = phase_ssqueeze(Wx, wavelet=wavelet, scales=scales, flipud=1)

# mean, max and sum of the absolute difference between the two results
adiff = np.abs(Tx0 - Tx1)
print(adiff.mean(), adiff.max(), adiff.sum())
#%%
# main difference near boundaries; see `help(trigdiff)` w/ `rpadded=False`
imshow(Tx1, abs=1)
Exemple #22
0
    return {
        num: '',
        f'{num}-cwt': time_cwt(x, dtype, scales, cache_wavelet),
        f'{num}-stft': time_stft(x, dtype, n_fft),
        f'{num}-ssq_cwt': time_ssq_cwt(x, dtype, scales, cache_wavelet,
                                       ssq_freqs),
        f'{num}-ssq_stft': time_ssq_stft(x, dtype, n_fft)
    }


#%%# Setup ###################################################################
# warmup
# one call of each transform per dtype before any timing is done
x = np.random.randn(1000)
for dtype in ('float32', 'float64'):
    wavelet = Wavelet(dtype=dtype)
    _ = ssq_cwt(x, wavelet, cache_wavelet=False)
    _ = ssq_stft(x, dtype=dtype)
# drop the warmup outputs and wavelet
del _, wavelet

#%%# Prepare reusable parameters such that STFT & CWT output shapes match ####
N0, N1 = 10000, 160000  # selected such that CWT pad length ratios are same
n_rows = 300
n_fft = n_rows * 2 - 2

wavelet = Wavelet()
# keep only the first `n_rows` scales
scales = process_scales('log-piecewise', N1, wavelet=wavelet)[:n_rows]
ssq_freqs = _compute_associated_frequencies(scales,
                                            N1,
                                            wavelet,
                                            'log-piecewise',
                                            maprange='peak',