def test_ssq_cwt():
    """Smoke-test `ssq_cwt` / `issq_cwt` across wavelets and keyword options.

    Each option value in `params` is passed to `ssq_cwt` one at a time; on
    failure the exception is re-raised with the offending `name=value` in the
    message, chained to the original so its traceback is preserved.
    """
    np.random.seed(5)  # deterministic input so failures are reproducible
    x = np.random.randn(64)
    for wavelet in ('morlet', ('morlet', {'mu': 20}), 'bump'):
        Tx, *_ = ssq_cwt(x, wavelet)
        issq_cwt(Tx, wavelet)

    kw = dict(x=x, wavelet='morlet')
    params = dict(
        squeezing=('lebesgue', ),
        scales=('linear', 'log:minimal', 'linear:naive',
                np.power(2**(1 / 16), np.arange(1, 32))),
        difftype=('phase', 'numeric'),
        padtype=('zero', 'replicate'),
        mapkind=('energy', 'peak'),
    )
    for name, values in params.items():
        for value in values:
            try:
                _ = ssq_cwt(**kw, **{name: value})
            except Exception as e:
                raise Exception(f"{name}={value} failed with:\n{e}") from e

    # `wavelet` is the last value from the wavelet loop above ('bump')
    _ = ssq_cwt(x, wavelet, fs=2, difftype='numeric', difforder=2)
    _ = ssq_cwt(x, wavelet, fs=2, difftype='numeric', difforder=1)
def test_ssq_cwt():
    """Exercise `ssq_cwt` on CPU over several wavelets, then over each value of
    a set of keyword options taken one at a time.

    A failing `name=value` combination is printed before its exception
    propagates unchanged.
    """
    os.environ['SSQ_GPU'] = '0'  # in case concurrent tests set it to '1'
    np.random.seed(5)
    x = np.random.randn(64)

    for wavelet in ('morlet', ('morlet', {'mu': 20}), 'bump'):
        Tx, *_ = ssq_cwt(x, wavelet)
        issq_cwt(Tx, wavelet)

    base_kw = dict(x=x, wavelet='morlet')
    options = dict(
        squeezing=('lebesgue', ),
        scales=('linear', 'log:minimal', 'linear:naive',
                np.power(2**(1 / 8), np.arange(1, 32))),
        difftype=('phase', 'numeric'),
        padtype=('zero', 'replicate'),
        maprange=('maximal', 'energy', 'peak', (1, 32)),
    )
    for name, choices in options.items():
        for value in choices:
            # 'maximal' and explicit-range `maprange` are run with log scales
            needs_log = name == 'maprange' and value in ('maximal', (1, 32))
            extra = dict(scales='log') if needs_log else {}
            try:
                _ = ssq_cwt(**base_kw, **{name: value}, **extra, get_w=1)
            except BaseException:
                print(f"\n{name}={value} failed\n")
                raise

    _ = ssq_cwt(x, wavelet, fs=2, difftype='numeric', difforder=2, get_w=1)
    _ = ssq_cwt(x, wavelet, fs=2, difftype='numeric', difforder=1, get_w=1)
def test_ssq_cwt():
    """Check that `ssq_cwt` outputs agree between `get_w=1` and `get_w=0`, and
    between CPU and GPU when `CAN_GPU`, for float64 and float32 inputs.

    `SSQ_GPU` is reset to '0' in a `finally`, so a failing assertion cannot
    leave later tests running on the GPU.
    """
    N = 256
    tsigs = TestSignals(N=N)
    kw = dict(astensor=False)
    try:
        for dtype in ('float64', 'float32'):
            gpu_atol = 1e-8 if dtype == 'float64' else 2e-4
            x = tsigs.par_lchirp()[0].astype(dtype)

            os.environ['SSQ_GPU'] = '0'
            Tx00 = ssq_cwt(x, _wavelet(dtype=dtype), **kw, get_w=1)[0]
            Tx01 = ssq_cwt(x, _wavelet(dtype=dtype), **kw, get_w=0)[0]
            if CAN_GPU:
                os.environ['SSQ_GPU'] = '1'
                Tx10 = ssq_cwt(x, _wavelet(dtype=dtype), **kw, get_w=1)[0]
                Tx11 = ssq_cwt(x, _wavelet(dtype=dtype), **kw, get_w=0)[0]

            adiff0001 = np.abs(Tx00 - Tx01).mean()
            assert np.allclose(Tx00, Tx01), (dtype, adiff0001)
            if CAN_GPU:
                adiff0010 = np.abs(Tx00 - Tx10).mean()
                adiff0011 = np.abs(Tx00 - Tx11).mean()
                assert np.allclose(Tx00, Tx10, atol=gpu_atol), (dtype, adiff0010)
                assert np.allclose(Tx00, Tx11, atol=gpu_atol), (dtype, adiff0011)
    finally:
        os.environ['SSQ_GPU'] = '0'
def test_dtype():
    """Ensure `cwt` and `ssq_cwt` compute at appropriate precision depending on
    `Wavelet.dtype`, returning float32 & complex64 arrays for single precision.

    `ssq_freqs` is expected to be float64 regardless of wavelet precision.
    Assertion messages are labelled with the precision being checked.
    """
    os.environ['SSQ_GPU'] = '0'
    wav32, wav64 = Wavelet(dtype='float32'), Wavelet(dtype='float64')
    x = np.random.randn(256)
    outs32 = ssq_cwt(x, wav32)
    outs64 = ssq_cwt(x, wav64)
    outs32_o2 = ssq_cwt(x, wav32, order=2)

    names = ('Tx', 'Wx', 'ssq_freqs', 'scales', 'w', 'dWx')
    # (label, outputs, dtypes allowed for every output except `ssq_freqs`)
    cases = (
        ("float32", outs32, (np.float32, np.complex64)),
        ("float32", outs32_o2, (np.float32, np.complex64)),
        ("float64", outs64, (np.float64, np.complex128)),
    )
    for label, outs, allowed in cases:
        for k, v in zip(names, outs):
            if k == 'ssq_freqs':
                assert v.dtype == np.float64, (label, k, v.dtype)
                continue
            assert v.dtype in allowed, (label, k, v.dtype)
def time_ssq_cwt(x, dtype, scales, cache_wavelet, ssq_freqs):
    """Time `ssq_cwt` on `x` at precision `dtype` using the `timeit` helper in scope.

    When `cache_wavelet` is truthy, three untimed warm-up calls with
    `cache_wavelet=True` run first, followed by a garbage-collection pass.
    NOTE(review): `timeit` is called with a single callable; presumably a
    project helper rather than stdlib `timeit.timeit` — confirm.
    """
    kw = dict(wavelet=Wavelet(dtype=dtype), scales=scales, ssq_freqs=ssq_freqs)

    if cache_wavelet:
        for _ in range(3):  # warmup run
            warm = ssq_cwt(x, cache_wavelet=True, **kw)
        del warm
        gc.collect()
    return timeit(lambda: ssq_cwt(x, cache_wavelet=cache_wavelet, **kw))
def test_component_inversion():
    """Recover one component of a noisy, amplitude-modulated chirp.

    The clean signal is cos(2*pi*exp(t/3)) on t in [0, 10) with N samples,
    amplitude-modulated and then corrupted by Gaussian noise of variance
    `noise_var`. A linear band selected by `lin_band` is inverted with
    `issq_cwt`, and the result is compared to the clean signal in time and in
    |rFFT| magnitude via `mad_rms`.
    """
    N, noise_var = 2048, 6

    ts = np.linspace(0, 10, N, False)
    x = np.cos(2 * np.pi * np.exp(ts / 3))
    x *= (1 + .3 * cos_f([1], N))  # amplitude modulation
    xo = x.copy()  # clean reference

    np.random.seed(4)
    x += np.sqrt(noise_var) * np.random.randn(len(x))

    wavelet = ('morlet', {'mu': 4.5})
    Tx, *_ = ssq_cwt(x, wavelet, scales='log:maximal', nv=32, t=ts)

    # hand-coded, subject to failure
    bw, slope, offset = .035, .45, .45
    Cs, freqband = lin_band(Tx, slope, offset, bw, norm=(0, 2e-1))
    xrec = issq_cwt(Tx, wavelet, Cs, freqband)[0]

    def magnitude_spectrum(sig):
        return np.abs(np.fft.rfft(sig))

    err_sig = mad_rms(xo, xrec)
    err_spc = mad_rms(magnitude_spectrum(xo), magnitude_spectrum(xrec))
    print("signal MAD/RMS: %.6f" % err_sig)
    print("spectrum MAD/RMS: %.6f" % err_spc)
    assert err_sig <= .42, f"{err_sig} > .42"
    assert err_spc <= .14, f"{err_spc} > .14"
def tf_transforms(x, t, wavelet='morlet', window=None, padtype='wrap',
                  penalty=.5, n_ridges=2, cwt_bw=15, stft_bw=15,
                  ssq_cwt_bw=4, ssq_stft_bw=4):
    """Compute CWT, STFT and their synchrosqueezed versions of `x`, extract
    `n_ridges` ridges from each, and visualize the four results.

    `t` gives the sample times; the STFT sampling rate is 1 / (t[1] - t[0]).
    Each `*_bw` sets the `bw` passed to `extract_ridges` for that transform.
    """
    fs = 1 / (t[1] - t[0])
    Twx, Wx, ssq_freqs_c, scales, *_ = ssq_cwt(x, wavelet, t=t, padtype=padtype)
    Tsx, Sx, ssq_freqs_s, Sfs, *_ = ssq_stft(x, window, fs=fs,
                                             padtype=padtype, flipud=1)
    # reverse the frequency axis of the STFT and of its frequency vector
    Sx, Sfs = Sx[::-1], Sfs[::-1]

    # (transform, ssq flag, coefficients, scales-or-freqs, ridge bandwidth)
    panels = (
        ('cwt', 0, Wx, scales, cwt_bw),
        ('cwt', 1, Twx, ssq_freqs_c, ssq_cwt_bw),
        ('stft', 0, Sx, Sfs, stft_bw),
        ('stft', 1, Tsx, ssq_freqs_s, ssq_stft_bw),
    )
    all_ridges = [
        extract_ridges(coefs, axis, bw=bw, penalty=penalty,
                       n_ridges=n_ridges, transform=kind)
        for kind, _, coefs, axis, bw in panels
    ]
    for idx, ((kind, ssq, coefs, axis, _), ridges) in enumerate(
            zip(panels, all_ridges)):
        # only the first panel shows the input signal
        viz(x, coefs, ridges, axis, ssq=ssq, transform=kind,
            show_x=1 if idx == 0 else 0)
def plot_tf_wsst(self, TRAI, save_path=None):
    """Plot the synchrosqueezed CWT magnitude of waveform `TRAI` as a contour map.

    :param TRAI: 1-based position of the waveform in `self.data_tra`.
    :param save_path: optional directory; when given, ticks and margins are
        stripped and the figure is saved there as '<TRAI>.jpg'.
    :raises ValueError: if `self.device` is neither 'vallen' nor 'stream'.
    """
    i = self.data_tra[int(TRAI - 1)]
    # Raw samples are int16 ('h') bytes; the scale factor, sample rate and
    # sample count sit at device-specific positions of the record. The time
    # axis is built in microseconds (factor 10**6).
    if self.device == 'vallen':
        sig = np.multiply(array.array('h', bytes(i[-2])), i[-3] * 1000)
        time = np.linspace(0, pow(i[-5], -1) * (i[-4] - 1) * pow(10, 6), i[-4])
    elif self.device == 'stream':
        sig = np.multiply(array.array('h', bytes(i[-1])), i[-2])
        time = np.linspace(0, pow(i[3], -1) * (i[4] - 1) * pow(10, 6), i[4])
    else:
        raise ValueError(
            "Unsupported device %r; expected 'vallen' or 'stream'" % (self.device,))

    # NOTE(review): `fs=i[3]` is passed for both devices although the 'vallen'
    # time axis uses i[-5] as the sample rate; `t` is passed too — confirm
    # which of the two `ssq_cwt` honors.
    Twxo, _, ssq_freqs, *_ = ssq_cwt(sig, wavelet='morlet', scales='log-piecewise',
                                     fs=i[3], t=time)

    plt.figure(figsize=(5.12, 5.12))
    plt.contourf(time, ssq_freqs * 1000, abs(Twxo), cmap='jet')
    plt.ylim(min(ssq_freqs * 1000), 1000)
    plt.xlabel(r'Time (μs)')
    plt.ylabel(r'Frequency (kHz)')
    if save_path:
        # strip ticks and margins so the saved image contains only the map
        plt.gca().xaxis.set_major_locator(plt.NullLocator())
        plt.gca().yaxis.set_major_locator(plt.NullLocator())
        plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
        plt.margins(0, 0)
        plt.savefig(os.path.join(save_path, '%i.jpg' % TRAI), pad_inches=0)
def test_higher_order():
    """`cwt` & `ssq_cwt` CPU & GPU outputs agreement for `order=range(3)`.

    Returns immediately when no GPU is available. `SSQ_GPU` is reset to '0'
    in a `finally`, so a failing assertion cannot leak GPU mode into later
    tests.
    """
    if not CAN_GPU:
        return
    tsigs = TestSignals(N=256)
    x = tsigs.par_lchirp()[0]
    x += x[::-1]

    kw = dict(order=range(3), astensor=False)
    th = 1e-6  # less should be possible for float64, but didn't investigate
    try:
        for dtype in ('float32', 'float64'):
            os.environ['SSQ_GPU'] = '0'
            Tx0, Wx0, *_ = ssq_cwt(x, _wavelet(dtype=dtype), **kw)
            os.environ['SSQ_GPU'] = '1'
            Tx1, Wx1, *_ = ssq_cwt(x, _wavelet(dtype=dtype), **kw)

            adiff_Tx = np.abs(Tx0 - Tx1).mean()
            adiff_Wx = np.abs(Wx0 - Wx1).mean()
            # include the measured error so failures are diagnosable
            assert adiff_Tx < th, (dtype, adiff_Tx, th)
            assert adiff_Wx < th, (dtype, adiff_Wx, th)
    finally:
        os.environ['SSQ_GPU'] = '0'
def test_ssq_cwt_batched():
    """Ensure batched (2D `x`) inputs output same as if samples fed separately,
    and agreement between CPU & GPU.

    `SSQ_GPU` is always reset to '0' on exit, including after the last dtype
    and on assertion failure.
    """
    np.random.seed(0)
    x = np.random.randn(4, 256)
    kw = dict(astensor=False)

    try:
        for dtype in ('float64', 'float32'):
            os.environ['SSQ_GPU'] = '0'
            Tx0, Wx0, *_ = ssq_cwt(x, _wavelet(dtype=dtype), **kw)

            # reference: transform each sample on its own; allocate each
            # buffer from its own batched counterpart's shape & dtype
            Tx00 = np.zeros(Tx0.shape, dtype=Tx0.dtype)
            Wx00 = np.zeros(Wx0.shape, dtype=Wx0.dtype)
            for i, _x in enumerate(x):
                out = ssq_cwt(_x, _wavelet(dtype=dtype), **kw)
                Tx00[i], Wx00[i] = out[0], out[1]

            if CAN_GPU:
                os.environ['SSQ_GPU'] = '1'
                Tx1, Wx1, *_ = ssq_cwt(x, _wavelet(dtype=dtype), **kw)

            adiff_Tx000 = np.abs(Tx00 - Tx0).mean()
            adiff_Wx000 = np.abs(Wx00 - Wx0).mean()
            assert np.allclose(Wx00, Wx0), (dtype, adiff_Wx000)
            assert np.allclose(Tx00, Tx0), (dtype, adiff_Tx000)
            if CAN_GPU:
                atol = 1e-12 if dtype == 'float64' else 1e-2
                adiff_Tx01 = np.abs(Tx0 - Tx1).mean()
                adiff_Wx01 = np.abs(Wx0 - Wx1).mean()
                assert np.allclose(Wx0, Wx1, atol=atol), (dtype, adiff_Wx01)
                assert np.allclose(Tx0, Tx1, atol=atol), (dtype, adiff_Tx01)
                # didn't investigate float32, and `allclose` threshold is
                # pretty bad, so check MAE
                if dtype == 'float32':
                    assert adiff_Tx01 < 1e-6, (dtype, adiff_Tx01)
    finally:
        os.environ['SSQ_GPU'] = '0'
def test_ssq_cwt():
    """Round-trip each signal in `test_fns` through `ssq_cwt` -> `issq_cwt`
    with 'log' and 'linear' scales, asserting the rounded MAD/RMS
    reconstruction error stays below `th`.
    """
    errs = []
    for fn in test_fns:
        x, ts = fn(2048)
        # 'linear' default can't handle low frequencies for large N
        if fn.__name__ == 'fast_transitions':
            scale_types = ('log',)
        else:
            scale_types = ('log', 'linear')

        for scales in scale_types:
            Tx, *_ = ssq_cwt(x, wavelet, scales=scales, nv=32, t=ts)
            err = round(mad_rms(x, issq_cwt(Tx, wavelet)), 5)
            errs.append(err)
            assert err < th, (err, fn.__name__, scales)
    print("\nssq_cwt PASSED\nerrs:", ', '.join(str(e) for e in errs))
def test_cwt_log_piecewise():
    """Check that both `issq_cwt` and `icwt` invert accurately when the
    forward transform uses `scales='log-piecewise'`, on the `echirp` signal.

    The CWT `Wx` and its `scales` come from the same `ssq_cwt` call
    (`preserve_transform=True`); both reconstructions must have rounded
    MAD/RMS error below .02.
    """
    x, ts = echirp(1024)
    wavelet = 'gmw'
    Tx, Wx, ssq_freqs, scales, *_ = ssq_cwt(x, wavelet, scales='log-piecewise',
                                           t=ts, preserve_transform=True)
    # invert with the same wavelet used for the forward transform
    xrec_ssq_cwt = issq_cwt(Tx, wavelet)
    xrec_cwt = icwt(Wx, wavelet, scales=scales)

    err_ssq_cwt = round(mad_rms(x, xrec_ssq_cwt), 5)
    err_cwt = round(mad_rms(x, xrec_cwt), 5)
    assert err_ssq_cwt < .02, err_ssq_cwt
    assert err_cwt < .02, err_cwt
def test_ssq_cwt():
    """Round-trip every signal in `test_fns` through `ssq_cwt` -> `issq_cwt`
    for 'log', 'log-piecewise' and 'linear' scales, optionally visualizing
    each result and asserting the rounded MAD/RMS error is below `th`.

    For the `low_freqs` signal, 'linear' is skipped and the remaining scale
    types get a ':maximal' suffix.
    """
    errs = []
    for fn in test_fns:
        x, ts = fn(2048)
        low_freqs = fn.__name__ == 'low_freqs'
        for scaletype in ('log', 'log-piecewise', 'linear'):
            if low_freqs and scaletype == 'linear':
                # 'linear' default can't handle low frequencies for large N
                continue
            # NOTE(review): 'log-piecewise' reportedly maps low frequencies too
            # sparsely, yet it is not skipped here — it only gets ':maximal'
            # like 'log'. Confirm intent.
            scales = f'{scaletype}:maximal' if low_freqs else scaletype

            Tx, *_ = ssq_cwt(x, wavelet, scales=scales, nv=32, t=ts)
            xrec = issq_cwt(Tx, wavelet)
            err = round(mad_rms(x, xrec), 5)
            errs.append(err)

            title = "abs(SSQ_CWT) | {}, scales='{}'".format(
                fn.__qualname__, scales)
            _maybe_viz(Tx, x, xrec, title, err)
            assert err < th, (err, fn.__name__, scales)
    print("\nssq_cwt PASSED\nerrs:", ', '.join(map(str, errs)))
def swt(time, sig, sampleRate, wavelet='morlet', scales='log-piecewise'):
    '''
    Plot the square-root magnitude of the synchrosqueezed CWT of `sig` as a
    filled contour map, frequency axis in kHz capped at 1000.

    :param time: sample times of `sig` (same length). Also passed to `ssq_cwt`
        as `t`; the x-axis is labelled in μs, so presumably given in μs —
        confirm with callers.
    :param sig: 1-D signal to transform.
    :param sampleRate: sampling rate, passed to `ssq_cwt` as `fs`.
    :param wavelet: wavelet passed to `ssq_cwt`.
    :param scales: scale specification passed to `ssq_cwt`.
    :return: the created matplotlib Figure (previous callers ignoring the
        return value are unaffected).
    '''
    Twxo, _, ssq_freqs, *_ = ssq_cwt(sig, wavelet=wavelet, scales=scales,
                                     fs=sampleRate, t=time)
    fig = plt.figure(figsize=(5.12, 5.12))
    # sqrt of |Tx| compresses dynamic range; freqs * 1000 for the kHz axis
    plt.contourf(time, ssq_freqs * 1000, pow(abs(Twxo), 0.5), cmap='cubehelix_r')
    plt.ylim(min(ssq_freqs * 1000), 1000)
    plt.xlabel(r'Time (μs)')
    plt.ylabel(r'Frequency (kHz)')
    return fig
x *= (1 + .3 * cos_f([1], N)) # amplitude modulation xo = x.copy() np.random.seed(4) x += np.sqrt(noise_var) * np.random.randn(len(x)) #### Show signal & its global spectrum ####################################### axf = np.abs(rfft(x)) plot(xo) scat(xo, s=8, show=1) plot(x) scat(x, s=8, show=1) plot(axf, show=1) #%%# Synchrosqueeze ########################################################## kw = dict(wavelet=('morlet', {'mu': 4.5}), nv=32, scales='log') Tx, *_ = ssq_cwt(x, t=ts, **kw) Wx, *_ = cwt(x, t=ts, **kw) #%%# Visualize ############################################################### pkw = dict(abs=1, w=.86, h=.9, aspect='auto', cmap='bone') _Tx = np.pad(Tx, [[4, 4]]) # improve display of top- & bottom-most freqs imshow(Wx, **pkw) imshow(np.flipud(_Tx), norm=(0, 2e-1), **pkw) #%%# Estimate inversion ridge ############################################### bw, slope, offset = .035, .45, .45 Cs, freqband = lin_band(Tx, slope, offset, bw, norm=(0, 2e-1)) #%%########################################################################### xrec = issq_cwt(Tx, kw['wavelet'], Cs, freqband)[0] plot(xo) plot(xrec, show=1) axof = np.abs(rfft(xo))
def plot_wave_TRAI(self, fig, k, data_pri, show_features=False, valid=False, cwt=False):
    """Plot the waveform with TRAI `k` on `fig` and dump it to a text file.

    :param fig: matplotlib Figure to draw on; its canvas is redrawn at the end.
    :param k: TRAI of the waveform to plot.
    :param data_pri: feature table searched for the row matching this TRAI
        (vallen: last column; pac: first column) when `show_features` is set.
    :param show_features: print that row's features to stdout.
    :param valid: forwarded to `self.cal_wave`.
    :param cwt: draw the synchrosqueezed CWT in a right panel next to the
        waveform (left panel) instead of the waveform alone.
    :return: an error string if the TRAI cannot be found or is inconsistent,
        otherwise None. On success the waveform is written as 'Time, Signal'
        rows to '<self.output>/<self.status>-<TRAI>.txt'.
    """
    def _print_features(names, values, digits):
        # digits == 0 -> print as int; otherwise round to that many digits
        print("=" * 23 + " Waveform information " + "=" * 23)
        for info, value, r in zip(names, values, digits):
            if r == 0:
                print('%s: %d' % (info, int(value)))
            else:
                print('%s: %s' % (info, round(value, r)))

    # Waveform with specific TRAI
    try:
        # NOTE(review): compares against 'VALLEN' while the checks below use
        # lowercase 'vallen'; if `self.device` is lowercase, the offset-based
        # lookup in `else` is always used. Confirm intent.
        if self.device == 'VALLEN':
            i = self.data_tra[k - 1]
        else:
            i = self.data_tra[k - self.data_tra[0][-1]]
    except IndexError:
        return 'Error: TRAI %d can not be found in data!' % k
    if i[-1] != k:
        return 'Error: TRAI %d in data_tra is inconsistent with %d by input!' % (i[-1], k)

    time, sig = self.cal_wave(i, valid=valid)
    # Trim trailing zeros; an all-zero signal is kept whole (previously
    # `tail` was left unbound in that case).
    tail = None
    for tmp_tail, s in enumerate(sig[::-1]):
        if s != 0:
            tail = -tmp_tail if tmp_tail > 0 else None
            break
    time, sig = time[:tail], sig[:tail]

    status_font = {'family': 'Arial', 'fontweight': 'bold', 'fontsize': 12}
    if cwt:
        fig.subplots_adjust(left=0.076, bottom=0.205, right=0.984, top=0.927, hspace=0.2, wspace=0.26)
        fig.text(0.47, 0.25, self.status, fontdict=status_font, horizontalalignment="right")
        # right panel: sqrt-magnitude synchrosqueezed CWT, frequencies in kHz
        ax = fig.add_subplot(1, 2, 2)
        ax.cla()
        Twxo, _, ssq_freqs, *_ = ssq_cwt(sig, wavelet='morlet', scales='log-piecewise', fs=i[3], t=time)
        ax.contourf(time, ssq_freqs * 1000, pow(abs(Twxo), 0.5), cmap='cubehelix_r')
        plot_norm(ax, 'Time (μs)', 'Frequency (kHz)', y_lim=[min(ssq_freqs * 1000), 1000], legend=False)
        # left panel: waveform
        ax = fig.add_subplot(1, 2, 1)
    else:
        fig.subplots_adjust(left=0.115, bottom=0.17, right=0.975, top=0.95)
        fig.text(0.96, 0.2, self.status, fontdict=status_font, horizontalalignment="right")
        ax = fig.add_subplot()
    ax.cla()
    ax.plot(time, sig, lw=1)

    # NOTE(review): axhline's xmin/xmax are axes fractions (0-1); passing
    # sig.shape[0] as xmax looks unintended — confirm.
    if self.device == 'vallen':
        if show_features:
            try:
                string = data_pri[np.where(data_pri[:, -1] == i[-1])][0]
            except IndexError:
                return 'Error: TRAI %d can not be found in data!' % k
            _print_features(
                ['SetID', 'Time', 'Chan', 'Thr', 'Amp', 'RiseT', 'Dur', 'Eny', 'RMS', 'Counts', 'TRAI'],
                string,
                [0, 8, 0, 8, 8, 2, 2, 8, 8, 0, 0])
        # threshold lines at +/- i[2]
        ax.axhline(abs(i[2]), 0, sig.shape[0], linewidth=1, color="black")
        ax.axhline(-abs(i[2]), 0, sig.shape[0], linewidth=1, color="black")
    elif self.device == 'pac':
        if show_features:
            # time, channel_num, sample_interval, points_num, dataset, hit_num
            # ID, Time(s), Chan, Thr(μV), Thr(dB), Amp(μV), Amp(dB), RiseT(s), Dur(s), Eny(aJ), RMS(μV), Frequency(Hz), Counts
            try:
                string = data_pri[np.where(data_pri[:, 0] == i[-1])][0]
            except IndexError:
                # same handling as the vallen branch instead of crashing
                return 'Error: TRAI %d can not be found in data!' % k
            _print_features(
                ['Hit number', 'Time', 'Chan', 'Thr', 'Amp', 'RiseT', 'Dur', 'Eny', 'RMS', 'Counts'],
                string[np.array([0, 1, 2, 3, 5, 7, 8, 9, 10, 12])],
                [0, 7, 0, 8, 8, 7, 7, 8, 8, 0])
        # threshold lines at +/- self.thr_μV
        ax.axhline(abs(self.thr_μV), 0, sig.shape[0], linewidth=1, color="black")
        ax.axhline(-abs(self.thr_μV), 0, sig.shape[0], linewidth=1, color="black")
    plot_norm(ax, 'Time (μs)', 'Amplitude (μV)', legend=False, grid=True)

    # Redraw and refresh the figure
    fig.canvas.draw()
    fig.canvas.flush_events()

    out_path = '/'.join([self.output, self.status]) + '-%d' % i[-1] + '.txt'
    with open(out_path, 'w') as f:
        f.write('Time, Signal\n')
        f.writelines("{}, {}\n".format(time[n], sig[n]) for n in range(sig.shape[0]))
weight='bold', fontsize=14, xy=(.85, .93), xycoords='axes fraction') plt.show() else: plt.xlim(0, len(row) // 4) plt.show() #%% row_anim(Wxz, idxs, scales) #%%## Superimpose #### row_anim(Wxz, idxs, scales, superimposed=True) #%%## Synchrosqueeze Tx, fs, *_ = ssq_cwt(x, wavelet, t=_t(0, 1, N)) #%% imshow(Tx, abs=1, title="abs(SSWT)", yticks=fs, show=1) #%%# Damped pendulum example ################################################ N, w0 = 4096, 25 t = _t(0, 6, N) s = np.exp(-t) * np.cos(w0 * t) w = np.linspace(-40, 40, N) S = (1 + 1j * w) / ((1 + 1j * w)**2 + w0**2) #%%# Plot #### plot(s, title="s(t)", show=1) plot(w, np.abs(S), title="abs(FT(s(t)))", show=1)
def plot_filtering(self, TRAI, N, CutoffFreq, btype, valid=False, originWave=False, filteredWave=True):
    """
    Signal Filtering: zero-phase Butterworth filtering (`filtfilt`) of one
    waveform, with optional waveform + synchrosqueezed-CWT plots before and
    after filtering.

    :param TRAI: 1-based position of the waveform in `self.data_tra`
    :param N: Order of filter
    :param CutoffFreq: Cutoff frequency in kHz, Wn = 2 * cutoff frequency / sampling frequency;
        a scalar or 1-element sequence if btype in ['lowpass', 'highpass'], else a 2-element sequence
    :param btype: Filter Types, {'lowpass', 'highpass', 'bandpass', 'bandstop'}
    :param valid: Whether to truncate the waveform according to the threshold
    :param originWave: Whether to display the original waveform
    :param filteredWave: Whether to display the filtered waveform
    :return: the filtered signal
    """
    tmp = self.data_tra[int(TRAI - 1)]
    if TRAI != tmp[-1]:
        print('Error: TRAI is incorrect!')
    time, sig = self.cal_wave(tmp, valid=valid)
    fs = tmp[3]  # sampling frequency of this record

    # Accept a scalar or a sequence; previously a scalar crashed in `map`
    # and lowpass/highpass labels formatted the whole list with '%d'.
    cutoffs = np.atleast_1d(CutoffFreq)
    # kHz -> Hz, then normalize by the Nyquist frequency (fs / 2)
    b, a = butter(N, [2 * c * 1e3 / fs for c in cutoffs], btype)
    sig_filter = filtfilt(b, a, sig)

    def _plot_wave_and_wsst(wave, plot_kw, norm_kw):
        # left: waveform; right: |SSQ_CWT| contour map with frequencies in kHz
        fig = plt.figure(figsize=(9.2, 3))
        ax = fig.add_subplot(1, 2, 1)
        ax.plot(time, wave, lw=1, **plot_kw)
        plot_norm(ax, 'Time (μs)', 'Amplitude (μV)', **norm_kw)
        ax = fig.add_subplot(1, 2, 2)
        Twxo, _, ssq_freqs, *_ = ssq_cwt(wave, wavelet='morlet', scales='log-piecewise', fs=fs, t=time)
        plt.contourf(time, ssq_freqs * 1000, abs(Twxo), cmap='jet')
        plot_norm(ax, r'Time (μs)', r'Frequency (kHz)', y_lim=[min(ssq_freqs * 1000), 1000], legend=False)

    if originWave:
        _plot_wave_and_wsst(sig, dict(color='blue'),
                            dict(title='TRAI: %d' % TRAI, legend=False, grid=True))
    if filteredWave:
        if btype in ['lowpass', 'highpass']:
            label = 'Frequency %s %d kHz' % ('<' if btype == 'lowpass' else '>', cutoffs[0])
        elif btype == 'bandpass':
            label = '%d kHz < Frequency < %d kHz' % (cutoffs[0], cutoffs[1])
        else:
            label = 'Frequency < %d kHz or > %d kHz' % (cutoffs[0], cutoffs[1])
        _plot_wave_and_wsst(sig_filter, dict(color='gray', label=label),
                            dict(title='TRAI: %d (%s)' % (TRAI, btype), grid=True,
                                 frameon=False, legend_loc='upper right'))
    return sig_filter
weight='bold', fontsize=14, xy=(.85, .93), xycoords='axes fraction') plt.show() else: plt.xlim(0, len(row) // 4) plt.show() #%% row_anim(Wxz, idxs, scales) #%%## Superimpose #### row_anim(Wxz, idxs, scales, superposed=True) #%%## Synchrosqueeze Tx, _, ssq_freqs, *_ = ssq_cwt(x, wavelet, t=_t(0, 1, N)) #%% imshow(Tx, abs=1, title="abs(SSWT)", yticks=ssq_freqs, show=1) #%%# Damped pendulum example ################################################ N, w0 = 4096, 25 t = _t(0, 6, N) s = np.exp(-t) * np.cos(w0 * t) w = np.linspace(-40, 40, N) S = (1 + 1j * w) / ((1 + 1j * w)**2 + w0**2) #%%# Plot #### plot(s, title="s(t)", show=1) plot(w, np.abs(S), title="abs(FT(s(t)))", show=1)
# downsampling factor for higher scales (used only if `scaletype='log-piecewise'`) downsample = 4 # show this many of lowest-frequency wavelets show_last = 20 #%%## Make scales ############################################################ # `cwt` uses `p2up`'d N internally M = p2up(N)[0] wavelet = Wavelet(wavelet, N=M) min_scale, max_scale = cwt_scalebounds(wavelet, N=len(x), preset=preset) scales = make_scales(N, min_scale, max_scale, nv=nv, scaletype=scaletype, wavelet=wavelet, downsample=downsample) #%%# Visualize scales ######################################################## viz(wavelet, scales, scaletype, show_last, nv) wavelet.viz('filterbank', scales=scales) #%%# Show applied ############################################################ Tx, Wx, ssq_freqs, scales, *_ = ssq_cwt(x, wavelet, scales=scales, padtype=padtype) imshow(Wx, abs=1, title="abs(CWT)") imshow(Tx, abs=1, title="abs(SSQ_CWT)")
# -*- coding: utf-8 -*-
"""Experimental feature example: synchrosqueeze a test signal with `ssq_cwt`,
then re-squeeze its CWT with `experimental.phase_ssqueeze` and report how far
the two results differ."""
if __name__ != '__main__':
    raise Exception("ran example file as non-main")

import numpy as np
from ssqueezepy import TestSignals, ssq_cwt, Wavelet
from ssqueezepy.visuals import imshow
from ssqueezepy.experimental import phase_ssqueeze

#%%
# test signal made symmetric by adding its time-reversal
x = TestSignals(N=2048).par_lchirp()[0]
x += x[::-1]
wavelet = Wavelet()

# reference result; the CWT and scales are kept for the second squeeze
Tx0, Wx, _, scales, *_ = ssq_cwt(x, wavelet, get_dWx=1)
Tx1, *_ = phase_ssqueeze(Wx, wavelet=wavelet, scales=scales, flipud=1)

abs_diff = np.abs(Tx0 - Tx1)
print(abs_diff.mean(), abs_diff.max(), abs_diff.sum())

#%%
# main difference near boundaries; see `help(trigdiff)` w/ `rpadded=False`
imshow(Tx1, abs=1)
return { num: '', f'{num}-cwt': time_cwt(x, dtype, scales, cache_wavelet), f'{num}-stft': time_stft(x, dtype, n_fft), f'{num}-ssq_cwt': time_ssq_cwt(x, dtype, scales, cache_wavelet, ssq_freqs), f'{num}-ssq_stft': time_ssq_stft(x, dtype, n_fft) } #%%# Setup ################################################################### # warmup x = np.random.randn(1000) for dtype in ('float32', 'float64'): wavelet = Wavelet(dtype=dtype) _ = ssq_cwt(x, wavelet, cache_wavelet=False) _ = ssq_stft(x, dtype=dtype) del _, wavelet #%%# Prepare reusable parameters such that STFT & CWT output shapes match #### N0, N1 = 10000, 160000 # selected such that CWT pad length ratios are same n_rows = 300 n_fft = n_rows * 2 - 2 wavelet = Wavelet() scales = process_scales('log-piecewise', N1, wavelet=wavelet)[:n_rows] ssq_freqs = _compute_associated_frequencies(scales, N1, wavelet, 'log-piecewise', maprange='peak',