def plot_wave_coherence(wave1, wave2, sample_times, min_freq=1, max_freq=256,
                        sig=False, ax=None, title="Wavelet Coherence",
                        plot_arrows=True, plot_coi=True, plot_period=False,
                        resolution=12, all_arrows=True, quiv_x=5, quiv_y=24,
                        block=None):
    """
    Calculate and plot wavelet coherence between wave1 and wave2 using pycwt.

    TODO add parameters to control arrow appearance.
    TODO test out sig on a large dataset.

    Parameters
    ----------
    wave1 : np.ndarray
        The values of the first waveform.
    wave2 : np.ndarray
        The values of the second waveform.
    sample_times : np.ndarray
        The times at which waveform samples occur. The sampling interval is
        taken as the mean difference, so samples are assumed evenly spaced.
    min_freq : float
        Approximate lowest frequency (Hz) to analyse: the largest wavelet
        scale is set to 1 / min_freq seconds.
    max_freq : float
        Approximate highest frequency (Hz) to analyse: the smallest wavelet
        scale is set to 1 / max_freq seconds, but never below two samples.
    sig : bool, default False
        Optional. Should significance of waveform coherence be calculated.
        When True a black contour marks where coherence exceeds the
        per-scale significance threshold.
    ax : plt.Axes, default None
        Optional Axes object to plot into. A new figure is created if None.
    title : str, default "Wavelet Coherence"
        Optional title for the graph.
    plot_arrows : bool, default True
        Should phase arrows be plotted.
    plot_coi : bool, default True
        Should the cone of influence be plotted.
    plot_period : bool
        Should the y-axis be in period (True) or in frequency (Hz, False).
    resolution : int
        Number of wavelet scales per octave.
    all_arrows : bool
        If True, phase arrows are plotted on a uniform grid; otherwise only
        where coherence > 0.5.
    quiv_x : float
        Spacing of phase arrows along the time axis, in seconds.
    quiv_y : float
        Approximate number of phase arrows across the scale axis.
    block : [int, int]
        Restricts the x-axis view. NOTE(review): block[0] is used as a sample
        index while block[1] is treated as seconds (multiplied by the sample
        rate) - confirm which is intended.

    Returns
    -------
    tuple : (fig, result)
        Where fig is a matplotlib Figure (None if ax was supplied) and
        result is a list consisting of WCT, aWCT, coi, freq, sig
        WCT - 2D numpy array with coherence values, clipped to [0, 1]
        aWCT - 2D numpy array with same shape as WCT holding phase angles
        coi - 1D numpy array with the cone of influence (as a period) for
              each time
        freq - 1D numpy array with the frequencies wavelets were calculated at
        sig - 1D numpy array from pycwt with one significance threshold per
              scale (only meaningful when sig=True)

    Raises
    ------
    ValueError
        If fewer than two samples are given, or min_freq / max_freq give an
        empty range of scales for this sampling rate.
    """
    t = np.asarray(sample_times)
    if t.size < 2:
        raise ValueError("sample_times must contain at least two samples")
    dt = np.mean(np.diff(t))

    # Wavelet scales (in seconds) span [1 / max_freq, 1 / min_freq] with
    # `resolution` sub-octaves per octave; the smallest scale is limited to
    # two samples.
    dj = resolution
    s0 = max(1.0 / max_freq, 2 * dt)
    max_scale = 1.0 / min_freq
    # Builtin int: the np.int alias was removed in NumPy 1.24.
    J = dj * int(np.round(np.log2(max_scale / np.abs(s0))))
    if J < 0:
        raise ValueError(
            "min_freq and max_freq give an empty scale range for this "
            "sampling rate")

    # Do the actual calculation
    print("Calculating coherence...")
    start_time = time.time()
    WCT, aWCT, coi, freq, sig_levels = wavelet.wct(
        wave1, wave2, dt,
        # Fixed params
        dj=(1.0 / dj), s0=s0, J=J, sig=sig, normalize=True)
    print("Time Taken: %s s" % (time.time() - start_time))
    if np.max(WCT) > 1 or np.min(WCT) < 0:
        print('WCT was out of range: min {},max {}'.format(
            np.min(WCT), np.max(WCT)))
        WCT = np.clip(WCT, 0, 1)

    # Convert frequency to period if necessary
    y_vals = np.log2(1 / freq) if plot_period else np.log2(freq)

    # Phase arrows rotate clockwise with 'north' origin (Torrence and
    # Webster, 1999): in-phase points up (N), anti-phase down (S); if X leads
    # Y arrows point right (E), if X lags Y they point left (W).
    angle = 0.5 * np.pi - aWCT
    u, v = np.cos(angle), np.sin(angle)

    if ax is None:
        fig, ax = plt.subplots()
    else:
        fig = None

    # Set the x and y axes of the plot
    extent_corr = [t.min(), t.max(), 0, max(y_vals)]

    # Fill the plot with the magnitude squared coherence values.
    # TODO this assumes the samples are linearly spaced.
    im = NonUniformImage(ax, interpolation='bilinear', extent=extent_corr)
    if plot_period:
        im.set_data(t, y_vals, WCT)
    else:
        # log2(freq) decreases with scale index; NonUniformImage needs
        # increasing coordinates, so flip rows.
        im.set_data(t, y_vals[::-1], WCT[::-1, :])
    # ax.images.append was removed in Matplotlib 3.7.
    ax.add_image(im)

    # Cone of influence: periods beyond it are subject to edge effects.
    if plot_coi:
        coi_y = np.log2(coi) if plot_period else np.log2(1 / coi)
        ax.plot(t, coi_y, 'w--', linewidth=2)

    # Significance contour. Use the boolean flag here: sig_levels is an
    # array and its truth value is ambiguous.
    if sig:
        # One threshold per scale; coherence is significant where ratio > 1.
        sig_ratio = WCT / sig_levels[:, None]
        ax.contour(t, y_vals, sig_ratio, [-99, 1], colors='k', linewidths=2)

    # Add limits, titles, etc.
    ax.set_ylim(min(y_vals), max(y_vals))
    if block:
        end_index = min(int(block[1] * 1 / dt), len(t) - 1)
        ax.set_xlim(t[block[0]], t[end_index])
    else:
        ax.set_xlim(t.min(), t.max())

    if plot_arrows:
        # Step between arrows in samples / scales; at least 1 so slicing works.
        x_res = max(1, int(1 / dt * quiv_x))
        y_res = max(1, int(np.floor(len(y_vals) / quiv_y)))
        t_grid = t[::x_res]
        y_grid = y_vals[::y_res]
        u_grid = u[::y_res, ::x_res]
        v_grid = v[::y_res, ::x_res]
        quiver_style = dict(
            units='height', angles='uv', pivot='mid', linewidth=1,
            edgecolor='k', scale=30, headwidth=10, headlength=10,
            headaxislength=5, minshaft=2,
        )
        if all_arrows:
            ax.quiver(t_grid, y_grid, u_grid, v_grid, **quiver_style)
        else:
            # Only grid points with high coherence.
            rows, cols = np.nonzero(WCT[::y_res, ::x_res] > 0.5)
            ax.quiver(t_grid[cols], y_grid[rows],
                      u_grid[rows, cols], v_grid[rows, cols], **quiver_style)

    # Add the colorbar to the figure
    if fig is not None:
        fig.colorbar(im)
    else:
        plt.colorbar(im, ax=ax, use_gridspec=True)

    # TODO improve ticks
    if plot_period:
        tick_values = [0.004, 0.008, 0.016, 0.032, 0.064, 0.125, 0.25, 0.5, 1]
        ax.set_ylabel("Period")
    else:
        tick_values = [64, 32, 16, 8, 4, 2, 1]
        ax.set_ylabel("Frequency (Hz)")
    y_ticks = [np.log2(x) for x in tick_values]
    y_labels = [str(x) for x in np.round(np.exp2(y_ticks), 3)]
    # Apply to `ax` explicitly: plt.yticks acts on the *current* Axes, which
    # is not necessarily the one passed in.
    ax.set_yticks(y_ticks)
    ax.set_yticklabels(y_labels)
    ax.set_title(title)
    ax.set_xlabel("Time (s)")

    return (fig, [WCT, aWCT, coi, freq, sig_levels])
def test_wct(lfp1, lfp2, sig=True):
    """
    Interactive debugging helper: plot the wavelet coherence of two LFPs.

    Parameters
    ----------
    lfp1, lfp2 : LFP-like objects
        Must provide get_sampling_rate() and get_samples(); lfp1 must also
        provide get_timestamp(). Only lfp1's sampling rate and timestamps are
        used, so both are presumably sampled identically - confirm.
    sig : bool, default True
        Run pycwt's significance test and draw the significance contour.

    Notes
    -----
    Blocks in plt.show() and then terminates the interpreter with exit
    status -1.
    """
    import sys

    # python CWT
    import pycwt as wavelet

    dt = 1 / lfp1.get_sampling_rate()
    WCT, aWCT, coi, freq, sig_levels = wavelet.wct(
        lfp1.get_samples(), lfp2.get_samples(), dt, sig=sig)

    _, ax = plt.subplots()
    t = lfp1.get_timestamp()
    ax.contourf(t, freq, WCT, 6, extend='both', cmap="viridis")

    # Without the significance test there are no thresholds to divide by.
    if sig:
        # One threshold per scale; coherence is significant where ratio > 1.
        sig95 = WCT / sig_levels[:, None]
        ax.contour(t, freq, sig95, [-99, 1], colors='k', linewidths=2)

    # The cone of influence is a period, but this y-axis is frequency, so it
    # is inverted (and capped at the top frequency so the fill does not
    # stretch the axis).
    coi_freq = np.minimum(1 / coi, freq.max())
    ax.fill(np.concatenate([t, t[-1:] + dt, t[-1:] + dt,
                            t[:1] - dt, t[:1] - dt]),
            np.concatenate([coi_freq, [1e-9], freq[-1:], freq[-1:], [1e-9]]),
            'k', alpha=0.3, hatch='x')
    ax.set_title('Wavelet Power Spectrum')
    ax.set_ylabel('Freq (Hz)')
    ax.set_xlabel('Time (s)')
    plt.show()
    # sys.exit does not depend on the `site` module (unlike bare exit).
    sys.exit(-1)
def cross_wavelet(signal_1, signal_2, period, mother='morlet', sig=False):
    """
    Compute the cross-wavelet transform and wavelet coherence of two signals.

    Both signals are z-scored before analysis. Times, periods and the cone of
    influence are divided by 60 on return (minutes, assuming `period` is in
    seconds - confirm).

    Parameters
    ----------
    signal_1, signal_2 : np.ndarray
        Equally sampled signals of the same length.
    period : float
        Sampling interval, passed to pycwt as ``dt``.
    mother : str, default 'morlet'
        Mother wavelet name passed to pycwt.
    sig : bool, default False
        Run pycwt's significance test for the coherence (slow with the
        fine dj=1/100 scale spacing used here).

    Returns
    -------
    W12, WCT, aWCT, cor_period, corr_coi, cor_sig, idx, t1
        W12 - cross-wavelet transform.
        WCT, aWCT - wavelet coherence and its phase angle.
        cor_period - Fourier period of each scale (/60).
        corr_coi - cone of influence (/60).
        cor_sig - |WCT| divided by the per-scale significance threshold
                  (> 1 is significant); all-NaN when sig is False.
        idx - find_closest(cor_period, corr_coi.max()), computed before the
              /60 conversion.
        t1 - sample times (/60).
    """
    signal_1 = (signal_1 - signal_1.mean()) / signal_1.std()  # Normalizing
    signal_2 = (signal_2 - signal_2.mean()) / signal_2.std()  # Normalizing

    W12, _, _, _ = wavelet.xwt(signal_1, signal_2, period, dj=1 / 100, s0=-1,
                               J=-1, significance_level=0.95, wavelet=mother,
                               normalize=True)

    WCT, aWCT, corr_coi, freq, sig_levels = wavelet.wct(
        signal_1, signal_2, period, dj=1 / 100, s0=-1, J=-1, sig=sig,
        significance_level=0.95, wavelet=mother, normalize=True)
    if sig:
        # One threshold per scale, broadcast over time.
        cor_sig = np.abs(WCT) / sig_levels[:, None]
    else:
        # No thresholds were computed; dividing by the placeholder would
        # yield inf/NaN that looks "significant" everywhere.
        cor_sig = np.full(WCT.shape, np.nan)
    cor_period = 1 / freq

    # Sample k occurs at k * period (linspace(0, period * N, N) would stretch
    # the axis by N / (N - 1)).
    t1 = np.arange(signal_1.size) * period
    idx = find_closest(cor_period, corr_coi.max())

    # Into minutes (non in-place so an integer t1 is also handled).
    t1 = t1 / 60
    cor_period = cor_period / 60
    corr_coi = corr_coi / 60
    return W12, WCT, aWCT, cor_period, corr_coi, cor_sig, idx, t1
wavelet='morlet', normalize=True) cross_power = np.abs(W12)**2 cross_sig = np.ones([1, n]) * signif[:, None] cross_sig = cross_power / cross_sig # Power is significant where ratio > 1 cross_period = 1/freq '''Calculate the wavelet coherence (WTC). The WTC finds regions in time frequency space where the two time seris co-vary, but do not necessarily have high power.''' WCT, aWCT, corr_coi, freq, sig = wavelet.wct(s1, s2, dt, dj=1/12, s0=-1, J=-1, significance_level=0.8646, wavelet='morlet', normalize=True, cache=True) cor_sig = np.ones([1, n]) * sig[:, None] cor_sig = np.abs(WCT) / cor_sig # Power is significant where ratio > 1 cor_period = 1 / freq ''' Calculates the phase between both time series. The phase arrows in the cross wavelet power spectrum rotate clockwise with 'north' origin. The relative phase relationship convention is the same as adopted by Torrence and Webster (1999), where in phase signals point upwards (N), anti-phase signals point downwards (S). If X leads Y,
################################ # basic parameters fs = 1 / dt dj = 1 / 12 s0 = -1 J = -1 sig = False wvn = 'morlet' # wavelet cross spectrm WCT, aWCT, coi, freq, sig = pycwt.wct(ndata, data, 1 / fs, dj=dj, s0=s0, J=J, sig=sig, wavelet=wvn, normalize=True) # unwrap the phase phase = np.unwrap( aWCT, axis=-1) # axis=0, upwrap along time; axis=-1, unwrap along frequency delta_t = phase / (2 * np.pi * freq[:, None] ) # normalize phase by (2*pi*frequency) plot_wxt = True if plot_wxt: plt.imshow(delta_t, cmap='jet',
def wxs_allfreq(cur, ref, allfreq, para, dj=1 / 12, s0=-1, J=-1, sig=False,
                wvn='morlet', unwrapflag=False):
    """
    Compute dv/v in the time-frequency domain from the wavelet cross spectrum
    (wxs) for all frequencies in a range of interest.

    Parameters
    --------------
    :type cur: :class:`~numpy.ndarray`
    :param cur: 1d array. Current cross-correlation measurement.
    :type ref: :class:`~numpy.ndarray`
    :param ref: 1d array. The reference trace.
    :type allfreq: bool
    :param allfreq: if True return per-frequency results, otherwise return
        the mean dv/v and error over the frequency band.
    :type para: dict
    :param para: must contain 'twin' ([earliest time, latest time] window),
        'freq' ([lowest, highest] frequency band) and 'dt' (sampling
        interval). NOTE(review): np.arange(tmin, tmax, dt) is regressed
        against each row of the wavelet output, so it presumably must have
        the same length as cur/ref - confirm.
    :params dj, s0, J, sig, wvn: passed to 'pycwt.wct'
    :param unwrapflag: True - unwrap phase delays. Default is False

    Returns
    --------------
    allfreq False: (mean dv/v in %, mean error in %)
    allfreq True: (frequencies in band, dv/v in %, error in %)
    Entries are NaN if the time window has 2 points or fewer, and 0 for
    frequencies whose phase delay is identically zero.

    Raises
    --------------
    ValueError: if the upper band limit exceeds the highest wavelet frequency
    or is not above the lower limit.

    Originally written by Tim Clements (1 March, 2019)
    Modified by Congcong Yuan (30 June, 2019) based on (Mao et al. 2019).
    """
    # common variables
    twin = para['twin']
    freq_band = para['freq']
    dt = para['dt']
    tmin, tmax = np.min(twin), np.max(twin)
    fmin, fmax = np.min(freq_band), np.max(freq_band)
    tvec = np.arange(tmin, tmax, dt)

    # perform cross coherent analysis, modified from function 'wavelet.cwt'
    WCT, aWCT, coi, freq, sig = pycwt.wct(cur, ref, dt, dj=dj, s0=s0, J=J,
                                          sig=sig, wavelet=wvn,
                                          normalize=True)

    # axis=-1 unwraps along the last axis of aWCT
    phase = np.unwrap(aWCT, axis=-1) if unwrapflag else aWCT

    # convert phase delay to time delay: normalize by 2*pi*frequency
    delta_t = phase / (2 * np.pi * freq[:, None])

    # keep only the frequency band of interest
    if fmax > np.max(freq) or fmax <= fmin:
        raise ValueError('Abort: input frequency out of limits!')
    freq_indin = np.where((freq >= fmin) & (freq <= fmax))[0]

    # initialize arrays for dv/v measurements
    dvv = np.zeros(freq_indin.shape)
    err = np.zeros(freq_indin.shape)

    # The window length does not depend on frequency, so check it once.
    if len(tvec) > 2:
        # loop through freq for linear regression
        for ii, ifreq in enumerate(freq_indin):
            if not np.any(delta_t[ifreq]):
                continue
            # Weight by inverse coherence; zero coherence gives inf, which
            # is replaced by 1.
            with np.errstate(divide='ignore'):
                w = 1 / WCT[ifreq]
            w[~np.isfinite(w)] = 1.0
            m, em = linear_regression(tvec, delta_t[ifreq], w,
                                      intercept_origin=True)
            dvv[ii], err[ii] = -m, em
    else:
        print('not enough points to estimate dv/v')
        dvv[:] = np.nan
        err[:] = np.nan

    # No trailing `del`: w/m/em are unbound when no regression ran, which
    # used to raise UnboundLocalError; locals are freed on return anyway.
    if not allfreq:
        return np.mean(dvv) * 100, np.mean(err) * 100
    return freq[freq_indin], dvv * 100, err * 100
# the PPF using chi2.ppf gives Z2(95%)=5.991. To ensure similar significance # intervals as in Grinsted et al. (2004), one has to use confidence of 86.46%. W12, cross_coi, freq, signif = wavelet.xwt(s1, s2, dt, dj=1/12, s0=-1, J=-1, significance_level=0.8646, wavelet='morlet', normalize=True) cross_power = np.abs(W12)**2 cross_sig = np.ones([1, n]) * signif[:, None] cross_sig = cross_power / cross_sig # Power is significant where ratio > 1 cross_period = 1/freq # Calculate the wavelet coherence (WTC). The WTC finds regions in time # frequency space where the two time seris co-vary, but do not necessarily have # high power. WCT, aWCT, corr_coi, freq, sig = wavelet.wct(s1, s2, dt, dj=1/12, s0=-1, J=-1, significance_level=0.8646, wavelet='morlet', normalize=True, cache=True) cor_sig = np.ones([1, n]) * sig[:, None] cor_sig = np.abs(WCT) / cor_sig # Power is significant where ratio > 1 cor_period = 1 / freq # Calculates the phase between both time series. The phase arrows in the # cross wavelet power spectrum rotate clockwise with 'north' origin. # The relative phase relationship convention is the same as adopted # by Torrence and Webster (1999), where in phase signals point # upwards (N), anti-phase signals point downwards (S). If X leads Y, # arrows point to the right (E) and if X lags Y, arrow points to the # left (W). angle = 0.5 * np.pi - aWCT u, v = np.cos(angle), np.sin(angle)
def calc_wave_coherence(wave1, wave2, sample_times, min_freq=1, max_freq=128,
                        sig=False, resolution=12):
    """
    Calculate wavelet coherence between wave1 and wave2 using pycwt.

    Parameters
    ----------
    wave1 : np.ndarray
        The values of the first waveform.
    wave2 : np.ndarray
        The values of the second waveform.
    sample_times : np.ndarray
        The times at which waveform samples occur. The sampling interval is
        taken as the mean difference, so samples are assumed evenly spaced.
    min_freq : float
        Approximate lowest frequency (Hz): the largest wavelet scale is
        1 / min_freq seconds.
    max_freq : float
        Approximate highest frequency (Hz): the smallest wavelet scale is
        1 / max_freq seconds, but never below two samples.
    sig : bool, default False
        Optional. Should significance of waveform coherence be calculated.
    resolution : int
        Number of wavelet scales per octave.

    Returns
    -------
    WCT, t, freq, coi, sig, aWCT
        WCT - 2D numpy array with coherence values, clipped to [0, 1]
        t - 1D numpy array with sample_times
        freq - 1D numpy array with the frequencies wavelets were calculated at
        coi - 1D numpy array with the cone of influence (as a period) for
              each time
        sig - 1D numpy array from pycwt with one significance threshold per
              scale (only meaningful when sig=True)
        aWCT - 2D numpy array with same shape as WCT holding phase angles

    Raises
    ------
    ValueError
        If fewer than two samples are given, or min_freq / max_freq give an
        empty range of scales for this sampling rate.
    """
    t = np.asarray(sample_times)
    if t.size < 2:
        raise ValueError("sample_times must contain at least two samples")
    dt = np.mean(np.diff(t))

    # Scales (seconds) span [1 / max_freq, 1 / min_freq] with `resolution`
    # sub-octaves per octave; the smallest scale is limited to two samples.
    dj = resolution
    s0 = 1 / max_freq
    if s0 < (2 * dt):
        s0 = 2 * dt
    max_J = 1 / min_freq
    # Builtin int: the np.int alias was removed in NumPy 1.24.
    J = dj * int(np.round(np.log2(max_J / np.abs(s0))))
    if J < 0:
        raise ValueError(
            "min_freq and max_freq give an empty scale range for this "
            "sampling rate")

    # Do the actual calculation
    print("Calculating coherence...")
    start_time = time.time()
    WCT, aWCT, coi, freq, sig = wavelet.wct(
        wave1, wave2, dt,
        # Fixed params
        dj=(1.0 / dj), s0=s0, J=J, sig=sig, normalize=True)
    print("Time Taken: %s s" % (time.time() - start_time))
    if np.max(WCT) > 1 or np.min(WCT) < 0:
        print('WCT was out of range: min {},max {}'.format(
            np.min(WCT), np.max(WCT)))
        WCT = np.clip(WCT, 0, 1)

    return WCT, t, freq, coi, sig, aWCT
def cross_wavelet(self, signal_1, signal_2, mother='morlet', plot=True,
                  sig=False):
    """
    Plot the cross-wavelet power, wavelet coherence and coherence phase of
    two signals sampled every ``self.period``.

    Both signals are z-scored first. Times and periods are divided by 60
    for display (minutes, assuming self.period is in seconds - confirm).
    Three figures are drawn, each shown with plt.show().

    Parameters
    ----------
    signal_1, signal_2 : np.ndarray
        Equally sampled signals of the same length.
    mother : str, default 'morlet'
        Mother wavelet name passed to pycwt.
    plot : bool, default True
        NOTE(review): currently unused - the figures are always drawn.
    sig : bool, default False
        Run pycwt's significance test for the coherence and draw its
        contour (slow with the fine dj=1/100 spacing used here).
    """
    signal_1 = (signal_1 - signal_1.mean()) / signal_1.std()  # Normalizing
    signal_2 = (signal_2 - signal_2.mean()) / signal_2.std()  # Normalizing

    W12, cross_coi, freq, signif = wavelet.xwt(
        signal_1, signal_2, self.period, dj=1/100, s0=-1, J=-1,
        significance_level=0.95, wavelet=mother, normalize=True)
    cross_power = np.abs(W12) ** 2
    # One threshold per scale; power is significant where ratio > 1.
    cross_sig = cross_power / signif[:, None]
    cross_period = 1 / freq

    WCT, aWCT, corr_coi, freq, sig_levels = wavelet.wct(
        signal_1, signal_2, self.period, dj=1/100, s0=-1, J=-1, sig=sig,
        significance_level=0.95, wavelet=mother, normalize=True)
    cor_period = 1 / freq

    # Sample k occurs at k * period (linspace(0, period * N, N) stretched
    # the axis by N / (N - 1)).
    t1 = np.arange(signal_1.size) * self.period

    # indices for stuff
    idx = self.find_closest(cor_period, corr_coi.max())

    # Into minutes
    t1 = t1 / 60
    cross_period = cross_period / 60
    cor_period = cor_period / 60
    cross_coi = cross_coi / 60
    corr_coi = corr_coi / 60
    # Padding for the COI polygon must be in the same units as t1.
    pad = self.period / 60

    def shade_coi(target, coi, periods):
        # Hatch the region outside the cone of influence on `target`
        # (an Axes or the pyplot module).
        target.fill(
            np.concatenate([t1, t1[-1:] + pad, t1[-1:] + pad,
                            t1[:1] - pad, t1[:1] - pad]),
            np.concatenate([coi, [1e-9], periods[-1:], periods[-1:], [1e-9]]),
            'k', alpha=0.3, hatch='x')

    extent_cross = [t1.min(), t1.max(), 0, max(cross_period)]
    extent_corr = [t1.min(), t1.max(), 0, max(cor_period)]

    # Cross-wavelet power. ax.images.append was removed in Matplotlib 3.7,
    # so images are registered with add_image.
    fig1, ax1 = plt.subplots(nrows=1, ncols=1, sharex=True, sharey=True,
                             figsize=(12, 12))
    im1 = NonUniformImage(ax1, interpolation='nearest', extent=extent_cross)
    im1.set_cmap('cubehelix')
    im1.set_data(t1, cross_period[:idx], cross_power[:idx, :])
    ax1.add_image(im1)
    ax1.contour(t1, cross_period[:idx], cross_sig[:idx, :], [-99, 1],
                colors='k', linewidths=2, extent=extent_cross)
    shade_coi(ax1, cross_coi, cross_period)
    ax1.set_title('Cross-Wavelet')
    ax1.set_ylim(([min(cross_period), cross_period[idx]]))
    ax1.set_xlim(t1.min(), t1.max())

    # Wavelet coherence (log10).
    fig2, ax2 = plt.subplots(nrows=1, ncols=1, sharex=True, sharey=True,
                             figsize=(12, 12))
    fig2.subplots_adjust(right=0.8)
    cbar_ax_1 = fig2.add_axes([0.85, 0.05, 0.05, 0.35])
    im2 = NonUniformImage(ax2, interpolation='nearest', extent=extent_corr)
    im2.set_cmap('cubehelix')
    im2.set_data(t1, cor_period[:idx], np.log10(WCT[:idx, :]))
    ax2.add_image(im2)
    if sig:
        # Only meaningful when thresholds were computed; otherwise pycwt's
        # placeholder would turn this into a division by zero.
        cor_sig = np.abs(WCT) / sig_levels[:, None]
        ax2.contour(t1, cor_period[:idx], cor_sig[:idx, :], [-99, 1],
                    colors='k', linewidths=2, extent=extent_corr)
    shade_coi(ax2, corr_coi, cor_period)
    ax2.set_title('Cross-Correlation')
    ax2.set_ylim(([min(cor_period), cor_period[idx]]))
    ax2.set_xlim(t1.min(), t1.max())
    fig2.colorbar(im2, cax=cbar_ax_1)
    plt.show()

    # Coherence phase in degrees.
    # NOTE(review): imshow maps rows linearly onto [0, max(cor_period)],
    # while the scales are log-spaced - the y placement is approximate.
    plt.figure(figsize=(12, 12))
    im3 = plt.imshow(np.rad2deg(aWCT), origin='lower',
                     interpolation='nearest', cmap='seismic',
                     extent=extent_corr)
    shade_coi(plt, corr_coi, cor_period)
    plt.ylim(([min(cor_period), cor_period[idx]]))
    plt.xlim(t1.min(), t1.max())
    plt.colorbar(im3)
    plt.show()
    return