def plot_wave_coherence(wave1,
                        wave2,
                        sample_times,
                        min_freq=1,
                        max_freq=256,
                        sig=False,
                        ax=None,
                        title="Wavelet Coherence",
                        plot_arrows=True,
                        plot_coi=True,
                        plot_period=False,
                        resolution=12,
                        all_arrows=True,
                        quiv_x=5,
                        quiv_y=24,
                        block=None):
    """
    Calculate and plot wavelet coherence between wave1 and wave2 using pycwt.

    TODO test sig on a large dataset (the Monte Carlo step can be slow).

    Parameters
    ----------
    wave1 : np.ndarray
        The values of the first waveform.
    wave2 : np.ndarray
        The values of the second waveform.
    sample_times : np.ndarray
        The times at which waveform samples occur. The mean spacing is used
        as the sampling interval, so samples are assumed evenly spaced.
    min_freq : float
        Lowest frequency of interest. Its reciprocal is used as the largest
        wavelet scale.
    max_freq : float
        Highest frequency of interest. Its reciprocal is used as the smallest
        wavelet scale, which is never allowed below two sample intervals.
    sig : bool, default False
        Whether the significance of the coherence should be calculated
        (Monte Carlo in pycwt) and drawn as a black contour.
    ax : plt.Axes, default None
        Optional axes to plot into; a new figure is created when None.
    title : str, default "Wavelet Coherence"
        Optional title for the graph.
    plot_arrows : bool, default True
        Should phase arrows be plotted.
    plot_coi : bool, default True
        Should the cone of influence be plotted.
    plot_period : bool
        Should the y-axis be in period or in frequency (Hz).
    resolution : int
        Number of wavelet scales per octave (pycwt uses dj = 1 / resolution).
    all_arrows : bool
        Plot phase arrows on the whole grid (True) or only where WCT > 0.5.
    quiv_x : float
        Spacing of the arrow grid along time, in seconds.
    quiv_y : float
        Approximate number of arrows distributed across the frequency axis.
    block : [int, int]
        Restricts the x-limits. NOTE(review): block[0] is used directly as a
        sample index, while block[1] is multiplied by the sampling rate
        (i.e. treated as seconds). Confirm which is intended.

    Returns
    -------
    tuple : (fig, result)
        fig is the matplotlib Figure created here, or None when ``ax`` was
        supplied. result is a list [WCT, aWCT, coi, freq, sig]:
        WCT - 2D numpy array with coherence values, clipped to [0, 1]
        aWCT - 2D numpy array with same shape as WCT indicating phase angles
        coi - 1D numpy array with the cone of influence for each time,
              expressed as a period (1 / coi gives frequency)
        freq - 1D numpy array with the frequencies wavelets were calculated at
        sig - significance thresholds returned by pycwt, one per scale when
              ``sig`` is True (presumably only a placeholder otherwise)

    Raises
    ------
    ValueError
        If min_freq and max_freq (after clamping to two samples) leave no
        wavelet scales to compute.
    """
    t = np.asarray(sample_times)
    dt = np.mean(np.diff(t))

    # Wavelet scales behave (approximately) like periods: the smallest scale
    # comes from max_freq and the largest from min_freq. Scales shorter than
    # two samples carry no information, hence the clamp.
    dj = resolution
    s0 = max(1.0 / max_freq, 2 * dt)
    max_scale = 1.0 / min_freq
    # Number of sub-octave steps above s0. Builtin int is used because
    # np.int was removed in NumPy 1.24.
    J = dj * int(np.round(np.log2(max_scale / np.abs(s0))))
    if J < 0:
        raise ValueError(
            "No wavelet scales between min_freq={} and max_freq={} "
            "(dt={})".format(min_freq, max_freq, dt))

    # Do the actual calculation
    print("Calculating coherence...")
    start_time = time.time()
    # The returned significance array is kept separate from the boolean
    # ``sig`` flag so the flag can still be tested later on.
    WCT, aWCT, coi, freq, sig_level = wavelet.wct(
        wave1,
        wave2,
        dt,  # Fixed params
        dj=(1.0 / dj),
        s0=s0,
        J=J,
        sig=sig,
        normalize=True,
    )
    print("Time Taken: %s s" % (time.time() - start_time))
    if np.max(WCT) > 1 or np.min(WCT) < 0:
        print('WCT was out of range: min {},max {}'.format(
            np.min(WCT), np.max(WCT)))
        WCT = np.clip(WCT, 0, 1)

    # y-axis in log2(period) or log2(frequency)
    y_vals = np.log2(1 / freq) if plot_period else np.log2(freq)

    # Calculates the phase between both time series. The phase arrows in the
    # cross wavelet power spectrum rotate clockwise with 'north' origin.
    # The relative phase relationship convention is the same as adopted
    # by Torrence and Webster (1999), where in phase signals point
    # upwards (N), anti-phase signals point downwards (S). If X leads Y,
    # arrows point to the right (E) and if X lags Y, arrow points to the
    # left (W).
    angle = 0.5 * np.pi - aWCT
    u, v = np.cos(angle), np.sin(angle)

    if ax is None:
        fig, ax = plt.subplots()
    else:
        fig = None

    extent_corr = [t.min(), t.max(), min(y_vals), max(y_vals)]

    # Fill the plot with the magnitude squared coherence values
    # That is, MSC = abs(Pxy) ^ 2 / (Pxx * Pyy)
    im = NonUniformImage(ax, interpolation='bilinear', extent=extent_corr)
    if plot_period:
        im.set_data(t, y_vals, WCT)
    else:
        # In frequency mode the row order is reversed so that y increases,
        # which NonUniformImage requires.
        im.set_data(t, y_vals[::-1], WCT[::-1, :])
    # add_image works on all Matplotlib versions; ax.images.append does not
    # (Axes.images became a read-only ArtistList).
    ax.add_image(im)

    # Plot the cone of influence - periods greater than
    # those are subject to edge effects.
    if plot_coi:
        coi_y = np.log2(coi) if plot_period else np.log2(1 / coi)
        ax.plot(t, coi_y, 'w--', linewidth=2)

    # Plot the significance level contour: coherence is significant where
    # its ratio to the per-scale threshold exceeds 1.
    if sig:
        sig_ratio = WCT / sig_level[:, None]
        ax.contour(t, y_vals, sig_ratio, [-99, 1], colors='k', linewidths=2)

    # Add limits, titles, etc.
    ax.set_ylim(min(y_vals), max(y_vals))
    if block:
        ax.set_xlim(t[block[0]], t[int(block[1] * 1 / dt)])
    else:
        ax.set_xlim(t.min(), t.max())

    # Plot the arrows on the plot
    if plot_arrows:
        # Strides are clamped to at least 1: a zero step would make the
        # slices below raise ValueError.
        x_res = max(1, int(1 / dt * quiv_x))
        y_res = max(1, int(np.floor(len(y_vals) / quiv_y)))
        t_sub = t[::x_res]
        y_sub = y_vals[::y_res]
        u_sub = u[::y_res, ::x_res]
        v_sub = v[::y_res, ::x_res]
        quiver_kwargs = dict(
            units='height',
            angles='uv',
            pivot='mid',
            linewidth=1,
            edgecolor='k',
            scale=30,
            headwidth=10,
            headlength=10,
            headaxislength=5,
            minshaft=2,
        )
        if all_arrows:
            ax.quiver(t_sub, y_sub, u_sub, v_sub, **quiver_kwargs)
        else:
            # Only keep grid points where coherence is high.
            rows, cols = np.nonzero(WCT[::y_res, ::x_res] > 0.5)
            ax.quiver(t_sub[cols], y_sub[rows], u_sub[rows, cols],
                      v_sub[rows, cols], **quiver_kwargs)

    # Attach the colorbar to the figure that owns ``ax`` (not whatever
    # figure happens to be current).
    ax.figure.colorbar(im, ax=ax)

    # TODO improve ticks (currently fixed values)
    if plot_period:
        tick_values = [0.004, 0.008, 0.016, 0.032, 0.064, 0.125, 0.25, 0.5, 1]
        ax.set_ylabel("Period")
    else:
        tick_values = [64, 32, 16, 8, 4, 2, 1]
        ax.set_ylabel("Frequency (Hz)")
    y_ticks = [np.log2(x) for x in tick_values]
    y_labels = [str(x) for x in np.round(np.exp2(y_ticks), 3)]
    # Set ticks on ``ax`` itself; plt.yticks targets the current axes.
    ax.set_yticks(y_ticks)
    ax.set_yticklabels(y_labels)
    ax.set_title(title)
    ax.set_xlabel("Time (s)")

    return (fig, [WCT, aWCT, coi, freq, sig_level])
def test_wct(lfp1, lfp2, sig=True):  # python CWT
    """
    Plot pycwt's wavelet coherence between two LFP signals (debug helper).

    Parameters
    ----------
    lfp1, lfp2 : LFP-like objects
        Must provide ``get_samples()``. ``lfp1`` must also provide
        ``get_sampling_rate()`` and ``get_timestamp()``.
    sig : bool, default True
        Whether pycwt should compute Monte Carlo significance thresholds.
        When computed, they are drawn as a black contour.

    Notes
    -----
    Shows the figure and then terminates the interpreter with exit status -1,
    so this function never returns.
    """
    import sys

    import pycwt as wavelet

    dt = 1 / lfp1.get_sampling_rate()
    # Keep the returned thresholds apart from the boolean ``sig`` flag.
    WCT, aWCT, coi, freq, sig_level = wavelet.wct(
        lfp1.get_samples(), lfp2.get_samples(), dt, sig=sig)
    _, ax = plt.subplots()
    t = lfp1.get_timestamp()
    ax.contourf(t, freq, WCT, 6, extend='both', cmap="viridis")

    # Without thresholds there is nothing meaningful to divide by, so the
    # significance contour is drawn only when it was computed.
    if sig:
        sig95 = WCT / sig_level[:, None]
        ax.contour(t, freq, sig95, [-99, 1], colors='k', linewidths=2)

    # coi is a period; convert it to frequency to match this y-axis, and
    # clip it so huge edge values do not stretch the autoscaled limits.
    # Frequencies below the curve are affected by edge effects.
    f_low = np.min(freq)
    coi_freq = np.clip(1 / coi, f_low, np.max(freq))
    ax.fill(np.concatenate([t, t[-1:] + dt, t[:1] - dt]),
            np.concatenate([coi_freq, [f_low], [f_low]]),
            'k', alpha=0.3, hatch='x')

    ax.set_title('Wavelet Power Spectrum')
    ax.set_ylabel('Freq (Hz)')
    ax.set_xlabel('Time (s)')

    plt.show()
    # sys.exit instead of the interactive-only ``exit`` builtin from site.
    sys.exit(-1)
示例#3
0
def cross_wavelet(signal_1, signal_2, period, mother='morlet', sig=False):
    """
    Compute the cross-wavelet transform and wavelet coherence of two signals.

    Both signals are z-scored before the transforms.

    Parameters
    ----------
    signal_1, signal_2 : np.ndarray
        Evenly sampled signals of equal length.
    period : float
        Sampling interval. Presumably in seconds, since time-like outputs are
        divided by 60 to give minutes.
    mother : str, default 'morlet'
        Mother wavelet passed to pycwt.
    sig : bool, default False
        Whether pycwt should compute Monte Carlo significance for the
        coherence. Without it there are no thresholds and ``cor_sig`` is
        all NaN.

    Returns
    -------
    W12 : np.ndarray
        Cross-wavelet transform.
    WCT, aWCT : np.ndarray
        Wavelet coherence and its phase angles.
    cor_period : np.ndarray
        Period of each coherence scale, divided by 60.
    corr_coi : np.ndarray
        Cone of influence, divided by 60.
    cor_sig : np.ndarray
        |WCT| divided by the per-scale threshold (significant where > 1), or
        all NaN when ``sig`` is False.
    idx : int
        Result of ``find_closest(cor_period, corr_coi.max())``.
    t1 : np.ndarray
        Sample times, divided by 60.
    """
    signal_1 = (signal_1 - signal_1.mean()) / signal_1.std()  # Normalizing
    signal_2 = (signal_2 - signal_2.mean()) / signal_2.std()  # Normalizing

    # Only the cross-wavelet transform itself is needed by callers.
    W12, _, _, _ = wavelet.xwt(signal_1,
                               signal_2,
                               period,
                               dj=1 / 100,
                               s0=-1,
                               J=-1,
                               significance_level=0.95,
                               wavelet=mother,
                               normalize=True)

    WCT, aWCT, corr_coi, freq, sig_level = wavelet.wct(signal_1,
                                                       signal_2,
                                                       period,
                                                       dj=1 / 100,
                                                       s0=-1,
                                                       J=-1,
                                                       sig=sig,
                                                       significance_level=0.95,
                                                       wavelet=mother,
                                                       normalize=True)

    if sig:
        cor_sig = np.abs(WCT) / (np.ones([1, signal_1.size]) *
                                 sig_level[:, None])
    else:
        # No per-scale thresholds were computed; dividing by pycwt's
        # placeholder would only produce inf values.
        cor_sig = np.full(WCT.shape, np.nan)
    cor_period = 1 / freq

    # One sample every ``period``. The previous np.linspace(0, period * N, N)
    # stretched the spacing to period * N / (N - 1). A float dtype keeps the
    # in-place division below valid when ``period`` is an int.
    t1 = np.arange(signal_1.size, dtype=float) * period
    idx = find_closest(cor_period, corr_coi.max())

    # Into minutes
    t1 /= 60
    cor_period /= 60
    corr_coi /= 60

    return W12, WCT, aWCT, cor_period, corr_coi, cor_sig, idx, t1
                                           wavelet='morlet',
                                           normalize=True)

cross_power = np.abs(W12) ** 2
# Cross power is significant where its ratio to the per-scale threshold > 1.
cross_sig = cross_power / (np.ones([1, n]) * signif[:, None])
cross_period = 1 / freq


# Calculate the wavelet coherence (WTC). The WTC finds regions in
# time-frequency space where the two time series co-vary, but do not
# necessarily have high power.
WCT, aWCT, corr_coi, freq, sig = wavelet.wct(
    s1, s2, dt,
    dj=1 / 12, s0=-1, J=-1,
    significance_level=0.8646,
    wavelet='morlet', normalize=True,
    cache=True,
)

# Coherence is significant where its ratio to the per-scale threshold > 1.
cor_sig = np.abs(WCT) / (np.ones([1, n]) * sig[:, None])
cor_period = 1 / freq



''' Calculates the phase between both time series. 
The phase arrows in the cross wavelet power spectrum rotate clockwise
with 'north' origin.
The relative phase relationship convention is the same as adopted
by Torrence and Webster (1999), where in phase signals point
upwards (N), anti-phase signals point downwards (S). If X leads Y,
示例#5
0
################################

# Basic wavelet parameters (see pycwt.wct).
fs = 1 / dt
dj, s0, J = 1 / 12, -1, -1
sig = False
wvn = 'morlet'

# Wavelet cross spectrum.
WCT, aWCT, coi, freq, sig = pycwt.wct(
    ndata, data, 1 / fs,
    dj=dj, s0=s0, J=J, sig=sig, wavelet=wvn, normalize=True,
)

# Unwrap the phase along the last axis. Rows of aWCT correspond to
# frequencies (note the freq[:, None] broadcast below), so axis=-1 unwraps
# each frequency's phase along time; axis=0 would unwrap across frequency.
phase = np.unwrap(aWCT, axis=-1)
# Convert phase to time delay by normalizing with 2*pi*frequency.
delta_t = phase / (2 * np.pi * freq[:, None])
plot_wxt = True
if plot_wxt:
    plt.imshow(delta_t,
               cmap='jet',
示例#6
0
def wxs_allfreq(cur,
                ref,
                allfreq,
                para,
                dj=1 / 12,
                s0=-1,
                J=-1,
                sig=False,
                wvn='morlet',
                unwrapflag=False):
    """
    Compute dv/v from the wavelet cross spectrum (wxs) for each frequency in a band.

    For every wavelet frequency inside the band, the coherence phase is
    converted to a time delay and regressed (through the origin, weighted by
    1/WCT) against time. dv/v is minus the slope.

    Parameters
    --------------
    :type cur: :class:`~numpy.ndarray`
    :param cur: 1d array. Current cross-correlation trace.
    :type ref: :class:`~numpy.ndarray`
    :param ref: 1d array. The reference trace.
    :type allfreq: bool
    :param allfreq: if True return per-frequency results, otherwise their mean.
    :type para: dict
    :param para: must provide 'twin' ([earliest, latest] time window),
        'freq' ([lowest, highest] frequency band) and 'dt' (sampling interval).
    :params dj, s0, J, sig, wvn: passed through to ``pycwt.wct``.
    :param unwrapflag: True - unwrap phase delays along time. Default is False.

    :returns: if ``allfreq`` is False, ``(mean dv/v, mean error)`` in percent;
        otherwise ``(frequencies, dv/v array, error array)`` with the arrays
        in percent. Frequencies whose delays are all zero are left at 0, and
        entries are NaN when the time window has too few points.
    :raises ValueError: if the band's upper edge exceeds the highest wavelet
        frequency or the band is empty (fmax <= fmin).

    Originally written by Tim Clements (1 March, 2019)
    Modified by Congcong Yuan (30 June, 2019) based on (Mao et al. 2019).
    """
    # common variables
    twin = para['twin']
    freq_band = para['freq']
    dt = para['dt']
    tmin, tmax = np.min(twin), np.max(twin)
    fmin, fmax = np.min(freq_band), np.max(freq_band)
    # NOTE(review): the regression assumes len(tvec) == len(cur) so time
    # and delay samples line up. Confirm with callers.
    tvec = np.arange(tmin, tmax, dt)

    # perform cross coherent analysis, modified from function 'wavelet.cwt'
    WCT, aWCT, _coi, freq, _sig = pycwt.wct(cur,
                                            ref,
                                            dt,
                                            dj=dj,
                                            s0=s0,
                                            J=J,
                                            sig=sig,
                                            wavelet=wvn,
                                            normalize=True)

    # Rows of aWCT are frequencies (see the freq[:, None] broadcast below),
    # so axis=-1 unwraps each frequency's phase along time.
    phase = np.unwrap(aWCT, axis=-1) if unwrapflag else aWCT

    # convert phase delay to time delay: normalize phase by (2*pi*frequency)
    delta_t = phase / (2 * np.pi * freq[:, None])

    # keep only frequencies inside the requested band
    if fmax > np.max(freq) or fmax <= fmin:
        raise ValueError('Abort: input frequency out of limits!')
    freq_indin = np.where((freq >= fmin) & (freq <= fmax))[0]

    # initialize arrays for dv/v measurements
    dvv, err = np.zeros(freq_indin.shape), np.zeros(freq_indin.shape)

    # loop through freq for linear regression
    for ii, ifreq in enumerate(freq_indin):
        if len(tvec) > 2:
            if not np.any(delta_t[ifreq]):
                continue
            # Weight by inverse coherence; zero coherence gives inf, which
            # is replaced by a neutral weight of 1.
            with np.errstate(divide='ignore'):
                w = 1 / WCT[ifreq]
            w[~np.isfinite(w)] = 1.0

            m, em = linear_regression(tvec,
                                      delta_t[ifreq],
                                      w,
                                      intercept_origin=True)
            dvv[ii], err[ii] = -m, em
        else:
            print('not enough points to estimate dv/v')
            dvv[ii], err[ii] = np.nan, np.nan

    # The former `del tvec, w, m, em` raised UnboundLocalError whenever no
    # regression ran (empty band, all-zero delays, too few points); locals
    # are released on return anyway.
    if not allfreq:
        return np.mean(dvv) * 100, np.mean(err) * 100
    return freq[freq_indin], dvv * 100, err * 100
示例#7
0
# With two degrees of freedom, chi2.ppf gives Z2(95%) = 5.991. To obtain
# significance intervals similar to Grinsted et al. (2004), a confidence
# level of 86.46% has to be used instead.
W12, cross_coi, freq, signif = wavelet.xwt(
    s1, s2, dt,
    dj=1 / 12, s0=-1, J=-1,
    significance_level=0.8646,
    wavelet='morlet', normalize=True,
)

cross_power = np.abs(W12) ** 2
# Cross power is significant where its ratio to the per-scale threshold > 1.
cross_sig = cross_power / (np.ones([1, n]) * signif[:, None])
cross_period = 1 / freq

# Wavelet coherence (WTC): finds regions in time-frequency space where the
# two time series co-vary, even if they do not have high power there.
WCT, aWCT, corr_coi, freq, sig = wavelet.wct(
    s1, s2, dt,
    dj=1 / 12, s0=-1, J=-1,
    significance_level=0.8646,
    wavelet='morlet', normalize=True,
    cache=True,
)

# Coherence is significant where its ratio to the per-scale threshold > 1.
cor_sig = np.abs(WCT) / (np.ones([1, n]) * sig[:, None])
cor_period = 1 / freq

# Phase arrows rotate clockwise with a 'north' origin, following Torrence
# and Webster (1999): in-phase points up (N), anti-phase down (S), X leading
# Y points right (E) and X lagging Y points left (W).
angle = 0.5 * np.pi - aWCT
u = np.cos(angle)
v = np.sin(angle)
def calc_wave_coherence(wave1,
                        wave2,
                        sample_times,
                        min_freq=1,
                        max_freq=128,
                        sig=False,
                        resolution=12):
    """
    Calculate wavelet coherence between wave1 and wave2 using pycwt.

    Parameters
    ----------
    wave1 : np.ndarray
        The values of the first waveform.
    wave2 : np.ndarray
        The values of the second waveform.
    sample_times : np.ndarray
        The times at which waveform samples occur. The mean spacing is used
        as the sampling interval, so samples are assumed evenly spaced.
    min_freq : float
        Lowest frequency of interest. Its reciprocal is used as the largest
        wavelet scale.
    max_freq : float
        Highest frequency of interest. Its reciprocal is used as the smallest
        wavelet scale, which is never allowed below two sample intervals.
    sig : bool, default False
        Optional. Should significance of waveform coherence be calculated.
    resolution : int
        Number of wavelet scales per octave (pycwt uses dj = 1 / resolution).

    Returns
    -------
    WCT, t, freq, coi, sig, aWCT
        WCT - 2D numpy array with coherence values, clipped to [0, 1]
        t - 1D numpy array with sample_times
        freq - 1D numpy array with the frequencies wavelets were calculated at
        coi - 1D numpy array with the cone of influence for each time,
              expressed as a period (1 / coi gives frequency)
        sig - significance thresholds returned by pycwt, one per scale when
              ``sig`` is True (presumably only a placeholder otherwise)
        aWCT - 2D numpy array with same shape as WCT indicating phase angles

    Raises
    ------
    ValueError
        If min_freq and max_freq (after clamping to two samples) leave no
        wavelet scales to compute.
    """
    t = np.asarray(sample_times)
    dt = np.mean(np.diff(t))

    # Scales behave like periods: smallest from max_freq, largest from
    # min_freq, never below two samples.
    dj = resolution
    s0 = max(1 / max_freq, 2 * dt)
    max_J = 1 / min_freq
    # Builtin int: np.int was removed in NumPy 1.24.
    J = dj * int(np.round(np.log2(max_J / np.abs(s0))))
    if J < 0:
        raise ValueError(
            "No wavelet scales between min_freq={} and max_freq={} "
            "(dt={})".format(min_freq, max_freq, dt))

    # Do the actual calculation
    print("Calculating coherence...")
    start_time = time.time()
    WCT, aWCT, coi, freq, sig = wavelet.wct(
        wave1,
        wave2,
        dt,  # Fixed params
        dj=(1.0 / dj),
        s0=s0,
        J=J,
        sig=sig,
        normalize=True)
    print("Time Taken: %s s" % (time.time() - start_time))
    if np.max(WCT) > 1 or np.min(WCT) < 0:
        print('WCT was out of range: min {},max {}'.format(
            np.min(WCT), np.max(WCT)))
        WCT = np.clip(WCT, 0, 1)

    return WCT, t, freq, coi, sig, aWCT
示例#9
0
    def cross_wavelet(self, signal_1, signal_2, mother='morlet', plot=True,
                      sig=False):
        """
        Cross-wavelet transform and wavelet coherence of two signals, with plots.

        Both signals are z-scored first. ``self.period`` is the sampling
        interval; time-like outputs are divided by 60 (presumably seconds to
        minutes).

        Parameters
        ----------
        signal_1, signal_2 : np.ndarray
            Evenly sampled signals of equal length.
        mother : str, default 'morlet'
            Mother wavelet passed to pycwt.
        plot : bool, default True
            Draw the cross-wavelet, coherence and phase figures.
        sig : bool, default False
            Whether pycwt should compute Monte Carlo significance for the
            coherence. The coherence significance contour is drawn only when
            it was computed.

        Returns
        -------
        tuple
            (W12, WCT, aWCT, cor_period, corr_coi, cor_sig, idx, t1).
            cor_sig is None when ``sig`` is False, and the time-like arrays
            are divided by 60.
        """
        signal_1 = (signal_1 - signal_1.mean()) / signal_1.std()    # Normalizing
        signal_2 = (signal_2 - signal_2.mean()) / signal_2.std()    # Normalizing

        W12, cross_coi, freq, signif = wavelet.xwt(
            signal_1, signal_2, self.period, dj=1 / 100, s0=-1, J=-1,
            significance_level=0.95, wavelet=mother, normalize=True)

        cross_power = np.abs(W12) ** 2
        cross_sig = cross_power / (np.ones([1, signal_1.size]) *
                                   signif[:, None])
        cross_period = 1 / freq

        WCT, aWCT, corr_coi, freq, sig_level = wavelet.wct(
            signal_1, signal_2, self.period, dj=1 / 100, s0=-1, J=-1,
            sig=sig, significance_level=0.95, wavelet=mother,
            normalize=True)
        cor_period = 1 / freq

        # Without Monte Carlo thresholds the ratio would be all inf.
        cor_sig = None
        if sig:
            cor_sig = np.abs(WCT) / (np.ones([1, signal_1.size]) *
                                     sig_level[:, None])

        # One sample every self.period. np.linspace(0, period * N, N) would
        # stretch the spacing to period * N / (N - 1). A float dtype keeps
        # the in-place division below valid for integer periods.
        t1 = np.arange(signal_1.size, dtype=float) * self.period

        # Index used to truncate the period axis in all plots.
        idx = self.find_closest(cor_period, corr_coi.max())

        # Into minutes (the COI padding must use the same unit as t1)
        t1 /= 60
        cross_period /= 60
        cor_period /= 60
        cross_coi /= 60
        corr_coi /= 60
        dt_min = self.period / 60

        result = (W12, WCT, aWCT, cor_period, corr_coi, cor_sig, idx, t1)
        if not plot:
            return result

        def _shade_coi(target, coi, periods):
            # Hatch the region outside the cone of influence (edge effects).
            target.fill(
                np.concatenate([t1, t1[-1:] + dt_min, t1[-1:] + dt_min,
                                t1[:1] - dt_min, t1[:1] - dt_min]),
                np.concatenate([coi, [1e-9], periods[-1:], periods[-1:],
                                [1e-9]]),
                'k', alpha=0.3, hatch='x')

        extent_cross = [t1.min(), t1.max(), 0, max(cross_period)]
        extent_corr = [t1.min(), t1.max(), 0, max(cor_period)]

        fig1, ax1 = plt.subplots(nrows=1, ncols=1, sharex=True, sharey=True,
                                 figsize=(12, 12))
        im1 = NonUniformImage(ax1, interpolation='nearest', extent=extent_cross)
        im1.set_cmap('cubehelix')
        im1.set_data(t1, cross_period[:idx], cross_power[:idx, :])
        # add_image instead of ax.images.append, which modern Matplotlib
        # no longer supports.
        ax1.add_image(im1)
        ax1.contour(t1, cross_period[:idx], cross_sig[:idx, :], [-99, 1],
                    colors='k', linewidths=2, extent=extent_cross)
        _shade_coi(ax1, cross_coi, cross_period)
        ax1.set_title('Cross-Wavelet')
        ax1.set_ylim([min(cross_period), cross_period[idx]])
        ax1.set_xlim(t1.min(), t1.max())

        fig2, ax2 = plt.subplots(nrows=1, ncols=1, sharex=True, sharey=True,
                                 figsize=(12, 12))
        fig2.subplots_adjust(right=0.8)
        cbar_ax_1 = fig2.add_axes([0.85, 0.05, 0.05, 0.35])
        im2 = NonUniformImage(ax2, interpolation='nearest', extent=extent_corr)
        im2.set_cmap('cubehelix')
        im2.set_data(t1, cor_period[:idx], np.log10(WCT[:idx, :]))
        ax2.add_image(im2)
        if cor_sig is not None:
            ax2.contour(t1, cor_period[:idx], cor_sig[:idx, :], [-99, 1],
                        colors='k', linewidths=2, extent=extent_corr)
        _shade_coi(ax2, corr_coi, cor_period)
        ax2.set_title('Cross-Correlation')
        ax2.set_ylim([min(cor_period), cor_period[idx]])
        ax2.set_xlim(t1.min(), t1.max())
        fig2.colorbar(im2, cax=cbar_ax_1)

        plt.show()

        plt.figure(figsize=(12, 12))
        im3 = plt.imshow(np.rad2deg(aWCT), origin='lower',
                         interpolation='nearest', cmap='seismic',
                         extent=extent_corr)
        _shade_coi(plt.gca(), corr_coi, cor_period)
        plt.ylim([min(cor_period), cor_period[idx]])
        plt.xlim(t1.min(), t1.max())
        plt.colorbar(im3)
        plt.show()

        return result