Example #1
0
def test_multi_dim():
    """Check the transform of a 2-D input against a single-row transform.

    The transform of a (10, 100) array is expected to have shape
    (n_scales, 10, 100), and its slice for the first row must agree with
    the transform of that row computed on its own to 13 decimals.
    """
    samples = np.random.random((10, 100))

    stacked = WaveletAnalysis(samples, frequency=True)
    n_scales = len(stacked.scales)
    assert stacked.wavelet_transform.shape == (n_scales, 10, 100)

    single = WaveletAnalysis(samples[0], frequency=True)
    assert single.wavelet_transform.shape == (n_scales, 100)

    npt.assert_array_almost_equal(
        stacked.wavelet_transform[:, 0, :],
        single.wavelet_transform[:, :],
        decimal=13,
    )
Example #2
0
def test_reconstruction_freq():
    """Reconstruct the input from its wavelet transform (frequency mode).

    In principle the input data can be recovered from the wavelet
    transform.  The check here is that the absolute value of the mean
    difference between the stored data and the reconstruction is
    below 0.02.
    """
    analysis = WaveletAnalysis(anomaly_sst, frequency=True)
    rebuilt = analysis.reconstruction()

    residual = analysis.data - rebuilt
    assert np.abs(residual.mean()) < 0.02
Example #3
0
def test_reconstruction_freq():
    """Verify that reconstruction from the frequency-mode transform is close.

    The wavelet transform should, in principle, allow the input data to
    be rebuilt.  This test passes when the mean of (data - reconstruction)
    has an absolute value smaller than 0.02.
    """
    analysis = WaveletAnalysis(anomaly_sst, frequency=True)
    rebuilt = analysis.reconstruction()

    mean_error = (analysis.data - rebuilt).mean()
    assert np.abs(mean_error) < 0.02
Example #4
0
def test_multi_dim_axis_nd_time():
    """Check a 4-D transform taken along ``axis=2`` against a 1-D transform.

    The transform of a (3, 4, 100, 5) array is expected to have shape
    (n_scales, 3, 4, 100, 5); the slice belonging to the series
    ``data[0, 0, :, 0]`` must match the transform of that series alone
    to 13 decimals.  Shapes are printed to aid debugging.
    """
    volume = np.random.random((3, 4, 100, 5))

    full = WaveletAnalysis(volume, frequency=False, axis=2)
    n_scales = len(full.scales)
    print(full.wavelet_transform.shape)
    print(n_scales)
    assert full.wavelet_transform.shape == (n_scales, 3, 4, 100, 5)

    series = WaveletAnalysis(volume[0, 0, :, 0], frequency=False)
    print(series.wavelet_transform.shape)
    assert series.wavelet_transform.shape == (n_scales, 100)

    npt.assert_array_almost_equal(
        full.wavelet_transform[:, 0, 0, :, 0],
        series.wavelet_transform[:, :],
        decimal=13,
    )
Example #5
0
def cwt(attr, old, new):
    """Recompute the wavelet power image and publish it to ``wvt_ds``.

    The three arguments are accepted but not used by the body.
    NOTE(review): the ``(attr, old, new)`` signature looks like a Bokeh
    ``on_change`` callback -- confirm against the caller.

    The function reads the widgets ``wvtTimestep``, ``wvtWidth`` and
    ``wvtFreqstep``, the data source ``ds_analysis`` and ``sampling_dt``,
    all of which are defined outside this block.  While it runs,
    ``calculateCWT.button_type`` is set to "warning"; it is switched to
    "success" once the data source and plot ranges have been updated.
    """
    calculateCWT.button_type = "warning"

    # Sub-sample the signal by the requested step before transforming.
    # The builtins int()/float() are used because the ``np.int`` and
    # ``np.float`` aliases were removed in NumPy 1.24.
    tStep = int(wvtTimestep.value)
    yWvt = np.array(ds_analysis.data['y'])[0::tStep]
    wa = WaveletAnalysis(yWvt,
                         wavelet=Morlet(w0=float(wvtWidth.value)),
                         dt=sampling_dt * tStep,
                         dj=float(wvtFreqstep.value))
    wvt = wa.time
    wvfreq = wa.fourier_frequencies
    # Flip vertically so the row order of the image is reversed.
    wvpower = np.flipud(wa.wavelet_power)

    wvt_ds.data = {
        'image': [wvpower],
        'dw': [wvpower.shape[1]],
        'dh': [wvpower.shape[0]],
        'wv_time': [wvt],
        'wv_freq': [wvfreq]
    }
    # Original author's note: need to set ranges before creating image,
    # and not using Range1d?
    p_wavelet.x_range.end = wvpower.shape[1]
    p_wavelet.y_range.end = wvpower.shape[0]

    calculateCWT.button_type = "success"
Example #6
0
def compare_morlet(N=2000):
    """Plot three transforms of the same random signal for comparison.

    Compares the scipy morlet with this package's Morlet (same, but with
    the correct argument order), and adds the scipy Ricker as a third
    panel.  Returns the matplotlib figure holding the three contour plots.
    """
    signal = np.random.random(N)
    analysis = WaveletAnalysis(signal, wavelet='ricker')
    widths = analysis.scales[::-1]

    transform = wavelets.cwt
    # (panel title, coefficients) in plotting order, top to bottom.
    panels = (
        ('Scipy morlet', transform(signal, scipy.signal.morlet, widths)),
        ('My morlet', transform(signal, wavelets.Morlet(), widths)),
        ('Scipy Ricker', transform(signal, scipy.signal.ricker, widths)),
    )

    sample_index = np.indices(signal.shape)
    T, S = np.meshgrid(sample_index, widths)

    fig, axes = plt.subplots(nrows=3)
    for axis, (title, coefficients) in zip(axes, panels):
        axis.set_title(title)
        axis.contourf(T, S, coefficients, 100)

    fig.tight_layout()

    return fig
Example #7
0
def analyse_song():
    """Compute the wavelet transform of a song and save a power plot.

    Reads 'alarma.wav', analyses the first second of its first channel
    and writes a log-frequency contour plot of the wavelet power to
    'alarma_wavelet.png'.
    """
    sample_rate, samples = wavfile.read('alarma.wav')

    stride = 1     # keep every ``stride``-th sample
    seconds = 1    # length of the excerpt to analyse, in seconds

    # time step is inverse sample rate * stride
    dt = stride / sample_rate
    n_samples = sample_rate * seconds

    # sub-sampled excerpt of a single channel
    excerpt = samples[:n_samples:stride, 0]

    analysis = WaveletAnalysis(excerpt, dt=dt)

    fig, ax = plt.subplots()
    T, periods = np.meshgrid(analysis.time, analysis.fourier_periods)
    frequencies = 1 / periods
    ax.contourf(T, frequencies, analysis.wavelet_power, 100)
    ax.set_yscale('log')

    ax.set_ylabel('frequency (Hz)')
    ax.set_xlabel('time (s)')

    ax.set_ylim(100, 10000)

    fig.savefig('alarma_wavelet.png')
Example #8
0
    def _calculate_wavelet(self):
        """Run one WaveletAnalysis per column of ``self.force``.

        Stores three results on the instance:

        * ``_wavelet_frequency`` -- the Fourier frequencies of the first
          analysis, divided by ``_normalized_frequency_by`` and restricted
          to the closed band between ``frequency_min`` and ``frequency_max``;
        * ``_wavelet_power`` -- a list with the band-limited power array
          of every analysis;
        * ``_wavelet_dominant_frequencies`` -- the stacked arrays of the
          band frequency with the largest power along axis 0.
        """
        from wavelets import WaveletAnalysis

        analyses = [
            WaveletAnalysis(node_force,
                            time=self._time,
                            dt=self.sampling_period)
            for node_force in self.force.T
        ]

        scaled_freqs = (analyses[0].fourier_frequencies /
                        self._normalized_frequency_by)
        # (f - min) * (f - max) <= 0 holds exactly when f lies between the
        # two bounds, inclusive.
        in_band = numpy.multiply(scaled_freqs - self.frequency_min,
                                 scaled_freqs - self.frequency_max) <= 0
        self._wavelet_frequency = scaled_freqs[in_band]

        dominant = []
        self._wavelet_power = []
        for analysis in analyses:
            band_power = analysis.wavelet_power[in_band]
            peak_index = numpy.argmax(band_power, axis=0)
            self._wavelet_power.append(band_power)
            dominant.append(self._wavelet_frequency[peak_index])

        # One row per analysis (i.e. per column of self.force).
        self._wavelet_dominant_frequencies = numpy.stack(dominant)
def paul_wav(time, flux):
    """Compute and plot the Paul wavelet power using Aaron O'Leary's package.

    The Paul wavelet is used in computing the Gradient of the Power
    Spectrum: the power summed along axis 1 is handed to ``GPS`` and also
    plotted against scale, followed by a 2-D contour map of the power.

    Args:
        time (List): Time values from processed data file.
        flux (List): Flux values from processed data file.
    """
    # Sampling step taken from the first two time stamps.
    step = time[1] - time[0]

    analysis = WaveletAnalysis(data=flux, time=time, wavelet=Paul(), dt=step)
    power = analysis.wavelet_power
    scales = analysis.scales
    periods = analysis.fourier_periods
    frequencies = analysis.fourier_frequencies
    t_axis = analysis.time

    GPS(time, frequencies, periods, np.sum(power, axis=1))

    # Summed power on a 1-D grid of scales.
    plt.plot(scales, np.sum(power, axis=1))
    plt.show()

    # Wavelet power on a 2-D (time, scale) map with a log scale axis.
    fig, ax = plt.subplots()
    T, S = np.meshgrid(t_axis, scales)
    ax.contourf(T, S, power, 100)
    ax.set_yscale('log')
    plt.show()
Example #10
0
def test_fourier_frequencies():
    """Round-trip the ``fourier_frequencies`` / ``fourier_periods`` setters.

    After assigning either attribute, both must read back consistently,
    with the periods equal to the reciprocal of the frequencies.
    """
    # Just some signal, no special meaning
    step = .1
    t = np.arange(5000) * step
    signal = np.cos(1. * t) + np.cos(2. * t) + np.cos(3. * t)

    analysis = WaveletAnalysis(signal,
                               dt=step,
                               wavelet=wavelets.Morlet(),
                               unbias=False)
    target = np.linspace(1., 100., 100)

    # Assign the frequencies, then read back both views.
    analysis.fourier_frequencies = target
    npt.assert_array_almost_equal(analysis.fourier_frequencies, target)
    npt.assert_array_almost_equal(analysis.fourier_periods, 1. / target)

    # Assign the periods, then read back both views again.
    analysis.fourier_periods = 1. / target
    npt.assert_array_almost_equal(analysis.fourier_frequencies, target)
    npt.assert_array_almost_equal(analysis.fourier_periods, 1. / target)
def get_wavelet(channel, time, im_size=[1024, 1024]):
    """Return the wavelet power of ``channel`` as a DataFrame.

    The frame has one column per entry of ``time / 1000`` and is indexed
    by the Fourier frequencies of the analysis (index name 'frequency').

    ``im_size`` is accepted but not used by the current implementation.
    """
    analysis = WaveletAnalysis(channel, time=time)

    table = pd.DataFrame(data=analysis.wavelet_power, columns=time / 1000)
    table.insert(0, 'frequency', analysis.fourier_frequencies)
    return table.set_index('frequency')
Example #12
0
def MyWavelets(data, MyWidths):
    """Return the Morlet wavelet power of ``data``, min-max scaled per frame.

    Parameters
    ----------
    data : sequence of numbers
        Samples to analyse; every element is converted with ``float``.
    MyWidths :
        Accepted for interface compatibility; not used by the computation.

    Returns
    -------
    list
        One ``minmax_scale``-d array per row of the transposed power
        spectrum.  (The power array is indexed scale-first elsewhere in
        this package's usage, so each row of the transpose presumably
        holds one time sample -- confirm against WaveletAnalysis.)
    """
    # Convert the (possibly integer) samples to a float array.  The builtin
    # ``float`` dtype is used because the ``np.float`` alias was removed in
    # NumPy 1.24.
    sig = np.array([float(sample) for sample in data], dtype=float)

    wa = WaveletAnalysis(sig, wavelet=Morlet())

    # Normalise every frame (row of the transposed power spectrum) to the
    # 0-1 range independently.
    return [minmax_scale(frame) for frame in wa.wavelet_power.T]
Example #13
0
def test_fourier_frequencies():
    """Setting frequencies or periods must keep both attributes in sync.

    ``fourier_periods`` is expected to be the reciprocal of
    ``fourier_frequencies`` whichever of the two was assigned last.
    """
    # Just some signal, no special meaning
    step = .1
    t = np.arange(5000) * step
    signal = np.cos(1. * t) + np.cos(2. * t) + np.cos(3. * t)

    analysis = WaveletAnalysis(signal, dt=step,
                               wavelet=wavelets.Morlet(), unbias=False)

    wanted_freqs = np.linspace(1., 100., 100)
    wanted_periods = 1. / wanted_freqs

    # Set frequencies and check that both views match on retrieval
    analysis.fourier_frequencies = wanted_freqs
    npt.assert_array_almost_equal(analysis.fourier_frequencies, wanted_freqs)
    npt.assert_array_almost_equal(analysis.fourier_periods, 1. / wanted_freqs)

    # Set periods and re-check
    analysis.fourier_periods = wanted_periods
    npt.assert_array_almost_equal(analysis.fourier_frequencies, wanted_freqs)
    npt.assert_array_almost_equal(analysis.fourier_periods, 1. / wanted_freqs)
Example #14
0
def test_var_freq():
    """The wavelet transform conserves total energy, i.e. variance.

    The variance of the data should therefore equal the variance of the
    wavelet.  For the frequency representation this test requires
    ``1 - data_variance / wavelet_variance`` to be less than 0.01 (1%).

    N.B. the outcome depends on the input data: with e.g.
    np.random.random as input the variance difference is larger.
    """
    analysis = WaveletAnalysis(anomaly_sst, frequency=True)
    variance_ratio = analysis.data_variance / analysis.wavelet_variance
    assert_less(1 - variance_ratio, 0.01)
def aren_wavelet(x, filename, fs):
    """Plot the continuous wavelet power of ``x`` and save it as a PDF.

    Does the same continuous wavelet transform as ``mlpy_cwt``, but with
    the ``wavelets`` package.

    Parameters
    ----------
    x : array_like
        Signal to transform.
    filename : str
        Output path without extension; '.pdf' is appended.
    fs : number
        Sampling frequency; the analysis time step is ``1 / fs``.
    """
    from wavelets import WaveletAnalysis

    wa = WaveletAnalysis(x, dt=1 / fs)

    # Earlier revisions also computed a reconstruction and an mlpy-based
    # frequency list that were never used; they are dropped so the function
    # no longer needs mlpy and skips the unused reconstruction.
    T, S = np.meshgrid(wa.time, wa.scales)

    fig, ax = plt.subplots()
    ax.contourf(T, S, wa.wavelet_power)
    fig.savefig(filename + '.pdf')
Example #16
0
def compare_cwt():
    """Compare Scipy's cwt (direct convolution) with this package's cwt
    (fft convolution).

    Both are run on the same random signal with the scipy morlet wavelet
    and must agree to 13 decimals.
    """
    signal = np.random.random(2000)
    # The analysis object is only used to obtain a set of scales.
    analysis = WaveletAnalysis(signal, wavelet=wavelets.Ricker())
    widths = analysis.scales[::-1]

    mother = scipy.signal.morlet
    direct = scipy.signal.cwt(signal, mother, widths)
    via_fft = wavelets.cwt(signal, mother, widths)

    npt.assert_array_almost_equal(direct, via_fft, decimal=13)
Example #17
0
def Plot_cwt():
    """Run the initial wavelet transform and draw it as an image.

    Reads the widgets ``wvtTimestep``, ``wvtWidth`` and ``wvtFreqstep``,
    the data source ``ds_analysis`` and ``sampling_dt`` (all defined
    outside this block), stores the flipped wavelet power in ``wvt_ds``,
    relabels both axes of ``p_wavelet`` and adds the image glyph.
    ``calculateCWT.button_type`` is "warning" while the function runs and
    "success" once it has finished.
    """
    calculateCWT.button_type = "warning"

    # Sub-sample the signal by the requested step before transforming.
    # The builtins int()/float() are used because the ``np.int`` and
    # ``np.float`` aliases were removed in NumPy 1.24.
    tStep = int(wvtTimestep.value)
    yWvt = np.array(ds_analysis.data['y'])[0::tStep]
    wa = WaveletAnalysis(yWvt,
                         wavelet=Morlet(w0=float(wvtWidth.value)),
                         dt=sampling_dt * tStep,
                         dj=float(wvtFreqstep.value))
    wvt = wa.time
    wvfreq = wa.fourier_frequencies
    wvpower = np.flipud(wa.wavelet_power)

    # Plot wavelet transform as image
    wvt_ds.data = {
        'image': [wvpower],
        'dw': [wvpower.shape[1]],
        'dh': [wvpower.shape[0]],
        'wv_time': [wvt],
        'wv_freq': [wvfreq]
    }
    # Original author's note: need to set ranges before creating image,
    # and not using Range1d?
    p_wavelet.x_range.end = wvpower.shape[1]
    p_wavelet.y_range.end = wvpower.shape[0]

    # The image is drawn in pixel coordinates, so tick labels are
    # overridden with the corresponding time values.
    p_wavelet.xaxis.major_label_overrides = {
        i: str(round_sig(wvt[i], 2))
        for i in range(p_wavelet.x_range.start, p_wavelet.x_range.end)
    }

    # Rows were flipped above, so the frequency axis is flipped to match.
    # NOTE(review): the 1550.0e-9 / 2.0 factor presumably converts a
    # frequency into a velocity for a 1550 nm source -- confirm.
    flipped_freq = np.array(np.flipud(wvfreq))
    p_wavelet.yaxis.major_label_overrides = {
        i: str(np.round(flipped_freq[i] * 1550.0e-9 / 2.0, 1))
        for i in range(p_wavelet.y_range.start, p_wavelet.y_range.end)
    }

    p_wavelet.image(image='image',
                    source=wvt_ds,
                    x=0,
                    y=0,
                    dw='dw',
                    dh='dh',
                    palette=viridis(200))
    calculateCWT.button_type = "success"
def simple_wavelet_transform(signal,
                             sampling_rate,
                             scaling_factor=0.25,
                             wave_lowpass=None,
                             wave_highpass=None):
    """
    Simple wavelet transformation of signal

    Parameters
    ----------
    signal : (N,1) array_like
        Signal to be transformed
    sampling_rate : int
        Sampling rate of signal; the analysis time step is
        ``1 / sampling_rate``
    scaling_factor : float, optional
        Passed to WaveletAnalysis as ``dj``; determines the amount of
        frequencies M in the output, by default 0.25
    wave_lowpass : float, optional
        Keep only frequencies strictly below this value. ``None``
        (the default) applies no upper bound.
    wave_highpass : float, optional
        Keep only frequencies strictly above this value. ``None``
        (the default) applies no lower bound.

    Returns
    -------
    wavelet_power : array_like
        Wavelet power; its first axis is indexed by frequency and is
        reduced to the kept frequencies when a bound is given
    wavelet_frequencies : array_like
        Frequencies corresponding to the first axis of wavelet_power
    wavelet_obj : object
        The WaveletAnalysis object (not filtered)
    """
    wavelet_obj = WaveletAnalysis(signal,
                                  dt=1 / sampling_rate,
                                  dj=scaling_factor)
    wavelet_power = wavelet_obj.wavelet_power
    wavelet_frequencies = wavelet_obj.fourier_frequencies

    # Build the frequency mask from whichever bounds were supplied.  Each
    # bound is handled on its own so that passing only one of them no
    # longer compares the frequencies against None.
    keep = None
    if wave_lowpass is not None:
        keep = wavelet_frequencies < wave_lowpass
    if wave_highpass is not None:
        above = wavelet_frequencies > wave_highpass
        keep = above if keep is None else keep & above

    if keep is not None:
        wavelet_power = wavelet_power[keep, :]
        wavelet_frequencies = wavelet_frequencies[keep]

    return (wavelet_power, wavelet_frequencies, wavelet_obj)
Example #19
0
 def clean_tranform(self, s_data_subject):
     """Wavelet-analyse one subject's data with a Ricker wavelet.

     Uses a time step of 1 / 128 and returns the analysis object together
     with the result of ``self.get_max_freq`` applied to its wavelet power.
     """
     analysis = WaveletAnalysis(data=s_data_subject,
                                wavelet=Ricker(),
                                dt=1 / 128)
     return analysis, self.get_max_freq(analysis.wavelet_power)
Example #20
0
def defringeflat(flat_file,
                 wbin=10,
                 start_col=10,
                 end_col=980,
                 clip1=0,
                 diagnostic=True,
                 save_to_path=None,
                 filename=None):
    """
    Remove the fringe pattern from a flat-field image using the method
    described in Rojo and Harrington (2006).

    The image is processed in horizontal bands of ``wbin + 1`` rows.  For
    every band the column-wise median profile is computed, a continuum
    (30-pixel running mean passed through a cubic spline) is subtracted,
    and the wavelet reconstruction of the residual between ``start_col``
    and ``end_col`` is subtracted from every row of the band.  Pixels at
    or below a low-count threshold derived from the image histogram keep
    their original values.

    Parameters
    ----------
    flat_file       :   str
                        path of the original flat FITS file

    Optional Parameters
    -------------------
    wbin            :   int
                        the bin width (in rows) used to calculate each
                        enhance row
                        Default is 10

    start_col       :   int
                        starting column number for the wavelet analysis
                        Default is 10

    end_col         :   int
                        ending column number for the wavelet analysis
                        Default is 980

    clip1           :   int
                        only columns ``0 .. 1024 - clip1`` of each band
                        are processed
                        Default is 0

    diagnostic      :   boolean
                        output the diagnostic plots (only when exactly
                        ``True``)
                        Default is True

    save_to_path    :   str
                        directory under which an ``images`` folder with
                        the diagnostic plots is created; required when
                        ``diagnostic`` is True

    filename        :   str
                        label inserted into the diagnostic plot file names

    Returns
    -------
    defringe file   :   fits.PrimaryHDU
                        defringed flat data with the header of the
                        original file

    Raises
    ------
    TypeError
        If ``diagnostic`` is True and ``save_to_path`` is None.
    """
    # The detector size is fixed by this routine.
    n_pix = 1024
    n_col = n_pix - clip1

    # Read the image once and close the file straight away.  The pixel
    # array is copied so that it stays valid after the file is closed.
    with fits.open(flat_file, ignore_missing_end=True) as hdulist:
        flat_data = np.array(hdulist[0].data)
        flat_header = hdulist[0].header

    # Use the data to figure out the values to mask through the image
    # (low counts/order edges): among histogram bins between the 10th and
    # 30th percentile, take the least populated one as the threshold.
    pixels = flat_data.flatten()
    hist, bins = np.histogram(pixels, bins=int(np.sqrt(len(pixels))))
    bins = bins[0:-1]
    index1 = np.where((bins > np.percentile(pixels, 10))
                      & (bins < np.percentile(pixels, 30)))
    candidate_counts = hist[index1]
    if candidate_counts.size:
        lowval = bins[index1][np.where(
            candidate_counts == np.min(candidate_counts))]
        if len(lowval) >= 2:
            lowval = np.min(lowval)
    else:
        lowval = 0  # no histogram bins fall inside the percentile window

    # initial flat plot
    if diagnostic is True:
        if save_to_path is None:
            raise TypeError(
                "save_to_path must be given when diagnostic is True")

        # Save the images to a separate folder
        save_to_image_path = save_to_path + '/images/'
        if not os.path.exists(save_to_image_path):
            os.makedirs(save_to_image_path)

        fig = plt.figure(figsize=(8, 8))
        fig.suptitle("original flat", fontsize=12)
        gs = gridspec.GridSpec(2, 1, height_ratios=[6, 1])
        ax0 = plt.subplot(gs[0])
        # Create an ImageNormalize object
        norm = ImageNormalize(flat_data, interval=ZScaleInterval())
        ax0.imshow(flat_data,
                   cmap='gray',
                   norm=norm,
                   origin='lower',
                   aspect='auto')
        ax0.set_ylabel("Row number")
        ax1 = plt.subplot(gs[1], sharex=ax0)
        ax1.plot(flat_data[60, :],
                 'k-',
                 alpha=0.5,
                 label='60th row profile')
        ax1.set_ylabel("Amp (DN)")
        ax1.set_xlabel("Column number")
        plt.legend()
        plt.savefig(save_to_image_path +
                    "defringeflat_{}_0_original_flat.png".format(filename),
                    bbox_inches='tight')
        plt.close()

    defringe_data = np.array(flat_data, dtype=float)

    for k in np.arange(0, n_pix - wbin, wbin):
        row_lo = k
        row_hi = k + wbin + 1

        # Pixels at or below the low-count threshold are restored to
        # their original values after the fringe has been subtracted.
        baddata = np.where(flat_data[row_lo:row_hi, 0:n_col] <= lowval)

        # extract the patch from the image
        flat_patch = np.array(flat_data[row_lo:row_hi, 0:n_col])

        # median average the selected region in the order
        flat_patch_median = np.ma.median(flat_patch, axis=0)

        # continuum fit: smooth the profile, then pass it through a
        # cubic spline evaluated on the column grid
        smoothed = sp.ndimage.uniform_filter1d(flat_patch_median, 30)
        splinefit = sp.interpolate.interp1d(np.arange(len(smoothed)),
                                            smoothed,
                                            kind='cubic')
        cont_fit = splinefit(np.arange(0, n_col))

        # use wavelets package: WaveletAnalysis
        enhance_row = flat_patch_median - cont_fit

        dt = 0.1
        wa = WaveletAnalysis(enhance_row[start_col:end_col], dt=dt)
        # reconstruction of the residual profile, used as fringe model
        rx = wa.reconstruction()

        # reconstruct the fringe image: the same fringe profile for every
        # row of the band, zero outside the analysed columns
        reconstruct_image = np.zeros(defringe_data[row_lo:row_hi,
                                                   0:n_col].shape)
        reconstruct_image[:, start_col:end_col] = rx

        defringe_data[row_lo:row_hi, 0:n_col] -= reconstruct_image

        # Restore the edges/masked out data from the original image
        defringe_data[row_lo:row_hi,
                      0:n_col][baddata] = flat_data[row_lo:row_hi,
                                                    0:n_col][baddata]

        # diagnostic plots
        if diagnostic is True:
            print("Generating diagnostic plots")
            # middle cut plot
            fig = plt.figure(figsize=(10, 6))
            fig.suptitle("middle cut at row {}".format(k + wbin // 2),
                         fontsize=12)
            ax1 = fig.add_subplot(2, 1, 1)

            norm = ImageNormalize(flat_patch, interval=ZScaleInterval())
            ax1.imshow(flat_patch,
                       cmap='gray',
                       norm=norm,
                       origin='lower',
                       aspect='auto')
            ax1.set_ylabel("Row number")
            ax2 = fig.add_subplot(2, 1, 2, sharex=ax1)
            ax2.plot(flat_patch[wbin // 2, :], 'k-', alpha=0.5)
            ax2.set_ylabel("Amp (DN)")
            ax2.set_xlabel("Column number")

            plt.tight_layout()
            plt.subplots_adjust(top=0.85, hspace=0.5)
            plt.savefig(
                save_to_image_path +
                'defringeflat_{}_flat_start_row_{}_middle_profile.png'.format(
                    filename, k),
                bbox_inches='tight')
            plt.close()

            # continuum fit plot
            fig = plt.figure(figsize=(10, 6))
            fig.suptitle("continuum fit row {}-{}".format(k, k + wbin),
                         fontsize=12)
            gs = gridspec.GridSpec(2, 1, height_ratios=[3, 1])
            ax0 = plt.subplot(gs[0])
            ax0.plot(flat_patch_median,
                     'k-',
                     alpha=0.5,
                     label='mean average patch')
            ax0.plot(cont_fit, 'r-', alpha=0.5, label='continuum fit')
            ax0.set_ylabel("Amp (DN)")
            plt.legend()
            ax1 = plt.subplot(gs[1])
            ax1.plot(flat_patch_median - cont_fit,
                     'k-',
                     alpha=0.5,
                     label='residual')
            ax1.set_ylabel("Amp (DN)")
            ax1.set_xlabel("Column number")
            plt.legend()

            plt.tight_layout()
            plt.subplots_adjust(top=0.85, hspace=0.5)
            plt.savefig(
                save_to_image_path +
                "defringeflat_{}_start_row_{}_continuum_fit.png".format(
                    filename, k),
                bbox_inches='tight')
            plt.close()

            # enhance row vs. reconstructed wavelet plot
            try:
                fig = plt.figure(figsize=(10, 6))
                fig.suptitle(
                    "reconstruct fringe comparison row {}-{}".format(
                        k, k + wbin),
                    fontsize=10)
                ax1 = fig.add_subplot(3, 1, 1)

                ax1.set_title('enhance_row start row')
                ax1.plot(enhance_row,
                         'k-',
                         alpha=0.5,
                         label="enhance_row start row {}".format(k))
                ax1.set_ylabel("Amp (DN)")

                ax2 = fig.add_subplot(3, 1, 2, sharex=ax1)
                ax2.set_title('reconstructed fringe pattern')
                ax2.plot(rx,
                         'k-',
                         alpha=0.5,
                         label='reconstructed fringe pattern')
                ax2.set_ylabel("Amp (DN)")

                ax3 = fig.add_subplot(3, 1, 3, sharex=ax1)
                ax3.set_title('residual')
                ax3.plot(enhance_row[start_col:end_col] - rx,
                         'k-',
                         alpha=0.5,
                         label='residual')
                ax3.set_ylabel("Amp (DN)")
                ax3.set_xlabel("Column number")
                plt.tight_layout()
                plt.subplots_adjust(top=0.85, hspace=0.5)
                plt.savefig(
                    save_to_image_path +
                    "defringeflat_{}_start_row_{}_reconstruct_profile.png".
                    format(filename, k),
                    bbox_inches='tight')
                plt.close()
            except RuntimeError:
                # Best effort: this comparison plot is optional, so a
                # plotting failure is reported and processing continues.
                print("CANNOT GENERATE THE PLOT "
                      "defringeflat_{}_start_row_{}_reconstruct_profile.png".
                      format(filename, k))

            # reconstruct image comparison plot
            fig = plt.figure(figsize=(10, 6))
            fig.suptitle("reconstructed image row {}-{}".format(k, k + wbin),
                         fontsize=12)

            ax1 = fig.add_subplot(3, 1, 1)
            ax1.set_title('raw flat image')
            norm = ImageNormalize(flat_patch, interval=ZScaleInterval())
            ax1.imshow(flat_patch,
                       cmap='gray',
                       norm=norm,
                       origin='lower',
                       label='raw flat image',
                       aspect='auto')
            ax1.set_ylabel("Row number")

            ax2 = fig.add_subplot(3, 1, 2, sharex=ax1)
            ax2.set_title('reconstructed fringe image')
            norm = ImageNormalize(reconstruct_image, interval=ZScaleInterval())
            ax2.imshow(reconstruct_image,
                       cmap='gray',
                       norm=norm,
                       origin='lower',
                       label='reconstructed fringe image',
                       aspect='auto')
            ax2.set_ylabel("Row number")

            ax3 = fig.add_subplot(3, 1, 3, sharex=ax1)
            ax3.set_title('residual')
            norm = ImageNormalize(flat_patch - reconstruct_image,
                                  interval=ZScaleInterval())
            ax3.imshow(flat_patch - reconstruct_image,
                       norm=norm,
                       origin='lower',
                       cmap='gray',
                       label='residual',
                       aspect='auto')
            ax3.set_ylabel("Row number")
            ax3.set_xlabel("Column number")
            plt.tight_layout()
            plt.subplots_adjust(top=0.85, hspace=0.5)
            plt.savefig(
                save_to_image_path +
                "defringeflat_{}_start_row_{}_reconstruct_image.png".format(
                    filename, k),
                bbox_inches='tight')
            plt.close()

            # middle residual comparison plot
            fig = plt.figure(figsize=(10, 6))
            fig.suptitle("middle row comparison row {}-{}".format(
                k, k + wbin),
                         fontsize=12)

            # Integer division keeps the row label consistent with the
            # row index that is actually plotted.
            ax1 = fig.add_subplot(3, 1, 1)
            ax1.plot(flat_patch[wbin // 2, :],
                     'k-',
                     alpha=0.5,
                     label='original flat row {}'.format(k + wbin // 2))
            ax1.set_ylabel("Amp (DN)")
            plt.legend()

            ax2 = fig.add_subplot(3, 1, 2, sharex=ax1)
            ax2.plot(flat_patch[wbin // 2, :] -
                     reconstruct_image[wbin // 2, :],
                     'k-',
                     alpha=0.5,
                     label='defringed flat row {}'.format(k + wbin // 2))
            ax2.set_ylabel("Amp (DN)")
            plt.legend()

            ax3 = fig.add_subplot(3, 1, 3, sharex=ax1)
            ax3.plot(reconstruct_image[wbin // 2, :],
                     'k-',
                     alpha=0.5,
                     label='difference')
            ax3.set_ylabel("Amp (DN)")
            ax3.set_xlabel("Column number")
            plt.legend()

            plt.tight_layout()
            plt.subplots_adjust(top=0.85, hspace=0.5)
            plt.savefig(
                save_to_image_path +
                "defringeflat_{}_start_row_{}_defringe_middle_profile.png".
                format(filename, k),
                bbox_inches='tight')
            plt.close()

    # final diagnostic plot
    if diagnostic is True:
        fig = plt.figure(figsize=(8, 8))
        fig.suptitle("defringed flat", fontsize=12)
        gs = gridspec.GridSpec(2, 1, height_ratios=[6, 1])
        ax0 = plt.subplot(gs[0])
        norm = ImageNormalize(defringe_data, interval=ZScaleInterval())
        ax0.imshow(defringe_data,
                   cmap='gray',
                   norm=norm,
                   origin='lower',
                   aspect='auto')
        ax0.set_ylabel("Row number")
        ax1 = plt.subplot(gs[1], sharex=ax0)
        ax1.plot(defringe_data[60, :],
                 'k-',
                 alpha=0.5,
                 label='60th row profile')
        ax1.set_ylabel("Amp (DN)")
        ax1.set_xlabel("Column number")
        plt.legend()

        plt.tight_layout()
        plt.subplots_adjust(top=0.85, hspace=0.5)
        plt.savefig(save_to_image_path +
                    "defringeflat_{}_0_defringe_flat.png".format(filename),
                    bbox_inches='tight')
        plt.close()

    hdu = fits.PrimaryHDU(data=defringe_data)
    hdu.header = flat_header
    return hdu
Example #21
0
def find_peaks_cwt(bins,
                   width,
                   snr,
                   min_length_percentage=0.4,
                   peak_range=(0.5, np.inf),
                   peak_separation=0,
                   gap_scale=0.05,
                   gap_thresh=2,
                   noise_window=1.0):
    """
    Find peaks in APT spectra through CWT, ridge lines and filtering.

    The log-scaled histogram is transformed with a Ricker wavelet, ridge
    lines are traced through the scales by ``find_ridge_lines`` and the
    resulting candidates are filtered in several passes: a range / ridge
    length / SNR cut, an SNR cut at the scale predicted by a linear fit of
    scale against sqrt(mass), and an optional minimum-separation cut.

    Args:
        (REQUIRED)
        bins: mass histogram bins (need linear scaling!); read through the
            fields 'height' and 'edge'
        width: the width of the bins
        snr: required signal-to-noise ratio for significant peaks

        (OPTIONAL)
        min_length_percentage: minimum fraction of scales where the peak
            should exist
        peak_range: looks for peaks in this range (default 0.5 to inf, so
            peaks before 0.5 are removed)
        peak_separation: an optional distance between peaks (weaker peaks
            closer than this to a stronger one are removed)
        gap_scale: step of the scale exponent; scales are 2**e with e
            running from log2(4 * width) up to log2(2) in gap_scale steps
        gap_thresh: number of scale levels a peak may be missing before the
            ridge line is ended (depends on the amount of scales used)
        noise_window: window on real scale to compute the noise

    Returns:
        Tuple with
        - list of masses (bin centers)
        - list of scales
        - 2D array of ridges (see find_ridge_lines)
        - 1D list of information on peaks (see find_ridge_lines)
    """

    # the ricker (mexican hat) wavelet gives good results
    wavelet = Ricker()

    # exponents (base 2) of the smallest and largest scale
    min_scale = np.log2(4 * width)
    max_scale = np.log2(2)  # this is 2 amu at the moment

    # The CWT performs better on the logarithm of the height for typical
    # spectra. The builtin ``float`` replaces the ``np.float`` alias, which
    # was removed in NumPy 1.24; the +1 avoids log(0) for empty bins.
    height_log = bins['height'].astype(float)
    height_log += 1
    height_log = np.log(height_log)

    # initialize the wavelet library
    wav = WaveletAnalysis(height_log, wavelet=wavelet, dt=width)
    wav.time = bins['edge'] + width / 2  # use mass to charge as the time scale
    scales = wav.scales = 2**(np.arange(min_scale, max_scale,
                                        gap_scale))  # build a set of scales
    # apply the continuous wavelet transform and keep the real part
    # (in theory there should be no imaginary part)
    cwt = np.real(wav.wavelet_transform)

    # ridge is a 2D structured array (scale x mass); the fields used below are
    # 'from' (column of the previous point on the line, -1 at the end),
    # 'max' and 'cwt'.
    # peak_info is a 1D structured array unpacked below as
    # (row, col, max, max_row, loc, length, noise).
    # called with: cwt array, bin width, left and right window for ridge
    # lines, scale levels that may be skipped before a ridge line ends,
    # window to check noise at the lowest scale level
    ridge, peak_info = find_ridge_lines(cwt, width, scales / 2, scales / 2,
                                        gap_thresh, noise_window)
    #TODO: limiting what ridge saves is saving a LOT OF MEMORY

    # first pass: correct maxima for snr and delete any that do not qualify
    for ind, (row, col, ridge_max, _, loc, length,
              noise) in np.ndenumerate(peak_info):
        # delete peaks that:
        # - are not in parameter peak_range or
        # - dont appear as a single line for enough scales or
        # - have too low snr (noise comes from wavelet estimate)
        if not (peak_range[0] <= wav.time[loc] <=
                peak_range[1]) or length < min_length_percentage * len(
                    scales) or ridge_max / noise < snr:
            delete_ridge(ridge, row, col)
            peak_info['row'][ind] = -1

    peak_info = peak_info[peak_info['row'] != -1]

    # find a guess of the maximum row that does not include the asymmetric behavior
    scale_row = np.zeros(shape=peak_info.shape, dtype=np.int32)
    # NOTE(review): int32 storage truncates whatever ridge['cwt'] holds;
    # confirm that is intended if those values are floats.
    scale_row_strength = np.zeros(shape=peak_info.shape, dtype=np.int32)
    for ind, (row, col, ridge_max, _, loc, length,
              noise) in np.ndenumerate(peak_info):
        # traversing down the ridge
        while ridge['from'][row, col] != -1:
            # stop at first point that is near the start location and not
            # at the original maximum
            if abs(wav.time[loc] -
                   wav.time[col]) < 10 * width and not np.isclose(
                       ridge['max'][row, col], ridge_max):
                scale_row[ind] = row
                scale_row_strength[ind] = ridge['cwt'][row, col]
                break

            col = ridge['from'][row, col]
            row = row - 1

    # delete all that do not have the expected range by estimating their
    # uncertainty behaviour. Only the degree-1 term is fitted (deg=[1]), so
    # coeff[0] is the (zero) constant and coeff[1] the slope.
    coeff = nppoly.polyfit(np.sqrt(wav.time[peak_info['loc']]),
                           scales[scale_row], [1],
                           w=peak_info['max'])

    # resulting linear fit: expected scale for every mass bin
    y_scale_fit = coeff[1] * np.sqrt(wav.time) + coeff[0]

    # second pass: find the cwt coefficient at the nearest scale level and
    # compare to snr
    for ind, (row, col, ridge_max, _, loc, length,
              noise) in np.ndenumerate(peak_info):
        exp_scale = y_scale_fit[loc]

        strength = 0
        trow = row
        tcol = col
        # if the expected row is never reached (peak only exists on higher
        # scales) -> strength stays 0; after this loop peaks that dont exist
        # at the expected scale have 0 and the others carry their intensity
        while ridge['from'][trow, tcol] != -1:
            if scales[trow] < exp_scale:
                # if the first row on the traverse down is already below the
                # scale -> strength stays 0
                # if exp_scale > max_scale the fit is bad and we keep everything
                if trow != row or exp_scale > max_scale:
                    strength = cwt[trow, tcol]
                break

            tcol = ridge['from'][trow, tcol]
            trow = trow - 1

        # kick out peaks that dont have enough signal at the expected scale
        # level at this point in the mass spectrum
        if strength / noise < snr:
            # delete if not significant at the expected scale level
            delete_ridge(ridge, row, col)
            peak_info['row'][ind] = -1
        elif strength > scale_row_strength[ind]:
            # optimize if previous scale estimation step was too aggressive
            peak_info[ind]['max_row'] = trow
            peak_info[ind]['loc'] = tcol
        else:
            # else just accept the previous estimate as actual max row
            peak_info[ind]['max_row'] = scale_row[ind]

    # remove all filtered peaks
    peak_info = peak_info[peak_info['row'] != -1]

    # third pass (optional): delete not well separated peaks, visiting the
    # strongest peaks first so that they win over their neighbours
    max_ridges = np.sort(peak_info, order="max")[::-1]
    for rdg in max_ridges:
        # find peaks near others
        bef, idx, aft = np.searchsorted(peak_info['loc'], [
            rdg['loc'] - peak_separation / width, rdg['loc'],
            rdg['loc'] + peak_separation / width
        ])

        # skip when nothing is in range or this peak was already removed
        if bef == aft or idx == peak_info.shape[
                0] or peak_info[idx]['loc'] != rdg['loc']:
            continue

        peak_info['row'][bef:idx] = -1
        peak_info['row'][idx + 1:aft] = -1
        peak_info = peak_info[peak_info['row'] != -1]

    # masses (center of the bins), scales (defined above), ridge information
    # and the surviving peaks including their related ridge
    return wav.time, wav.scales, ridge, peak_info
Example #22
0
def defringetelluric():

# Create the output directory if needed; exist_ok avoids the race between
# checking for the directory and creating it.
os.makedirs(save_to_path, exist_ok=True)

############################################
print(tell_data_name)

tell_sp = nspf.Spectrum(name=tell_data_name, order=order, path=tell_path)

clickpoints = []
def onclick(event):
	"""Record a click's x position; after the second click shade the span and close.

	Each click is appended to the module-level ``clickpoints`` list as a
	one-element list and marked with a dashed red vertical line.
	"""
	global clickpoints
	print(event)
	clickpoints.append([event.xdata])
	print(clickpoints)
	plt.axvline(event.xdata, c='r', ls='--')
	plt.draw()

	# Nothing more to do until two clicks have been collected.
	if len(clickpoints) != 2:
		return

	print('Closing Figure')
	first_x, second_x = clickpoints[0][0], clickpoints[1][0]
	plt.axvspan(first_x, second_x, color='0.5', alpha=0.5, zorder=-100)
	plt.draw()
	# Leave the shaded selection visible for a moment before closing.
	plt.pause(1)
	plt.close('all')


### Draw the figure with the power spectrum
# The user marks the band to filter with two clicks; the clicks are collected
# by the `onclick` callback connected below, which closes the figure after
# the second one.
cal1        = tell_sp.flux#[pixel_range_start:pixel_range_end]
xdim        = len(cal1)#[pixel_range_start:pixel_range_end])
nfil        = xdim//2 + 1
# Smooth the continuum
cal1smooth  = sp.ndimage.median_filter(cal1, size=30)
# Do the FFT of the continuum-subtracted flux
cal1fft     = fftpack.rfft(cal1-cal1smooth)
# Squared magnitude of the first nfil coefficients, normalized to a peak of 1
yp          = abs(cal1fft[0:nfil])**2
yp          = yp / np.max(yp)

fig, ax1 = plt.subplots(figsize=(12,6))
cid = fig.canvas.mpl_connect('button_press_event', onclick)
freq        = np.arange(nfil)
yp[0:3]     = 0 # Fix for very low order noise
ax1.plot(freq, yp)
#ax1.axvline(f_high*2, c='r', ls='--')
#ax1.axvline(f_low*2, c='r', ls='--')
ax1.set_ylabel('Power Spectrum')
ax1.set_xlabel('1 / (1024 pix)')
ax1.set_title('Select the range you would like to filter out')
ax1.set_xlim(0, np.max(freq))
plt.show()
plt.close('all')
# Band edges are half of the clicked x positions.
# NOTE(review): np.max/np.min raise on an empty list, i.e. if the figure is
# closed without any click.
f_high   = np.max(clickpoints)/2
f_low    = np.min(clickpoints)/2


# Wavelet branch: subtract the wavelet reconstruction of the
# continuum-subtracted flux from the flux and store it as newSpectrum.
if method == 'wavelet':
	#### Wavelets
	from wavelets import WaveletAnalysis

	xdim      = len(tell_sp.flux)#[pixel_range_start:pixel_range_end])
	cal1      = tell_sp.flux#[pixel_range_start:pixel_range_end]

	# Smooth the continuum: 30-sample running mean, then evaluated back on
	# the pixel grid through a cubic interpolator
	smoothed    = sp.ndimage.uniform_filter1d(cal1, 30)
	splinefit   = sp.interpolate.interp1d(np.arange(len(smoothed)), smoothed, kind='cubic')
	cal1smooth  = splinefit(np.arange(0, len(cal1))) #smoothed

	# use wavelets package: WaveletAnalysis
	enhance_row = cal1 - cal1smooth
	#print(enhance_row)

	dt     = 0.1
	wa     = WaveletAnalysis(enhance_row, dt=dt, axis=0)
	# wavelet power spectrum
	power  = wa.wavelet_power

	# scales (not used further in this branch)
	scales = wa.scales
	# associated time vector (not used further in this branch)
	t      = wa.time
	# reconstruction of the original data
	rx     = wa.reconstruction()

	defringe_data    = np.array(cal1.data, dtype=float)

	# reconstruct the fringe image
	#reconstruct_image      = np.zeros(defringe_data.shape)
	reconstruct_image      = np.real(rx)

	# defringed spectrum = flux minus the real part of the reconstruction
	defringe_data -= reconstruct_image
	newSpectrum   = defringe_data

	if PLOT:

		fig = plt.figure(figsize=(12,6))
		ax1 = plt.subplot2grid((3, 1), (0, 0))
		ax2 = plt.subplot2grid((3, 1), (1, 0), rowspan=2)
		
		#freq        = np.arange(nfil)
		ax1.plot(power**2)
		ax1.axvline(f_high*2, c='r', ls='--')
		ax1.axvline(f_low*2, c='r', ls='--')
		ax1.set_ylabel('Power Spectrum')
		ax1.set_xlabel('1 / (1024 pix)')
		#ax1.set_xlim(0, np.max(freq))
	
		# The last 23 pixels are left out of the plot; the defringed curve is
		# shifted up by half its median so both curves stay visible.
		ax2.plot(cal1[0:-23], label='original', alpha=0.5, lw=1, c='b')
		ax2.plot(newSpectrum[0:-23]+0.5*np.median(newSpectrum[0:-23]), label='defringed', alpha=0.8, lw=1, c='r')
		ax2.legend()
		ax2.set_ylabel('Flux')
		ax2.set_xlabel('Pixel')
		ax2.set_xlim(0, len(cal1[0:-23]))
		plt.tight_layout()
		plt.savefig(save_to_path+"defringed_spectrum.png", bbox_inches='tight')
		plt.show()



# Hanning-notch branch: build a band-stop response that is 0 between f_low
# and f_high, turn it into a Hanning-windowed convolution kernel and apply it
# to the continuum-subtracted flux; the continuum is added back afterwards.
if method == 'hanningnotch':

	## REDSPEC version
	cal1        = tell_sp.flux#[pixel_range_start:pixel_range_end]
	xdim        = len(cal1)#[pixel_range_start:pixel_range_end])
	nfil        = xdim//2 + 1

	#print(xdim, nfil//2+1)
	freq        = np.arange(nfil//2+1) / (nfil / float(xdim))
	# builtin float: the np.float alias was removed in NumPy 1.24
	fil         = np.zeros(len(freq), dtype=float)
	# pass everything outside the selected band
	fil[np.where((freq < f_low) | (freq > f_high))] = 1.
	# mirror the response, transform it, center it and taper it
	fil         = np.append(fil, np.flip(fil[1:],axis=0))
	fil         = np.real(np.fft.ifft(fil))
	fil         = np.roll(fil, nfil//2)
	fil         = fil*np.hanning(nfil)

	# Smooth the continuum
	#smoothed    = sp.ndimage.uniform_filter1d(cal1, 30)
	#splinefit   = sp.interpolate.interp1d(np.arange(len(smoothed)), smoothed, kind='cubic')
	#cal1smooth  = splinefit(np.arange(0, len(cal1))) #smoothed
	cal1smooth  = sp.ndimage.median_filter(cal1, size=30)

	"""
	plt.figure()
	plt.plot(abs(np.real(fftpack.fft(cal1orig-cal1smooth)))**2, c='k', lw=0.5)
	plt.plot(abs(np.real(fftpack.fft( sp.ndimage.convolve(cal1orig-cal1smooth, fil, mode='wrap') ) ))**2, c='r', lw=0.5)
	plt.ylim(0,25000)
	#plt.xlim(0,800)
	plt.show()
	#sys.exit()

	plt.figure(figsize=(10,6))
	plt.plot(cal1-cal1smooth, lw=0.5, c='k')
	plt.plot(sp.ndimage.convolve(cal1-cal1smooth, fil, mode='wrap'), lw=0.5, c='r')
	plt.plot(sp.ndimage.median_filter(cal1-cal1smooth, 10), lw=0.5, c='m')
	plt.show(block=False)
	#sys.exit()

	plt.figure()
	plt.plot( (cal1-cal1smooth)-sp.ndimage.convolve(cal1-cal1smooth, fil, mode='wrap'), lw=0.5, c='k')
	plt.show(block=True)
	#sys.exit()
	"""
	newSpectrum       = sp.ndimage.convolve(cal1-cal1smooth, fil, mode='wrap') + cal1smooth

	if PLOT:

		# Do the FFT
		cal1fft     = fftpack.rfft(cal1-cal1smooth)
		yp          = abs(cal1fft[0:nfil])**2
		yp          = yp / np.max(yp)
		yp[0:3]     = 0 # Fix for very low order noise

		fig = plt.figure(figsize=(12,6))
		ax1 = plt.subplot2grid((3, 1), (0, 0))
		ax2 = plt.subplot2grid((3, 1), (1, 0), rowspan=2)
		
		freq        = np.arange(nfil)
		ax1.plot(freq, yp)
		ax1.axvline(f_high*2, c='r', ls='--')
		ax1.axvline(f_low*2, c='r', ls='--')
		ax1.set_ylabel('Power Spectrum')
		ax1.set_xlabel('1 / (1024 pix)')
		ax1.set_xlim(0, np.max(freq))
	
		# The last 23 pixels are left out of the plot; the defringed curve is
		# shifted up by half its median so both curves stay visible.
		ax2.plot(cal1[0:-23], label='original', alpha=0.5, lw=1, c='b')
		ax2.plot(newSpectrum[0:-23]+0.5*np.median(newSpectrum[0:-23]), label='defringed', alpha=0.8, lw=1, c='r')
		ax2.legend()
		ax2.set_ylabel('Flux')
		ax2.set_xlabel('Pixel')
		ax2.set_xlim(0, len(cal1[0:-23]))
		plt.tight_layout()
		plt.savefig(save_to_path+"defringed_spectrum.png", bbox_inches='tight')
		plt.show()



# Flat-filter branch: zero the FFT coefficients selected by the band mask and
# transform back to obtain newSpectrum.
if method == 'flatfilter':

	cal1     = tell_sp.flux#[pixel_range_start:pixel_range_end]

	# NOTE(review): W comes from fftfreq (complex-FFT ordering) but indexes
	# the output of fftpack.rfft, which uses a packed real layout; confirm
	# the mask lines up with the intended band (fftpack.rfftfreq may be what
	# is wanted here).
	W        = fftpack.fftfreq(cal1.size, d=1./1024)
	fftval   = fftpack.rfft(cal1.astype(float))
	fftval[np.where((W > f_low) & (W < f_high))] = 0

	newSpectrum   = fftpack.irfft(fftval) 
	
	if PLOT: 

		xdim       = len(cal1)#[pixel_range_start:pixel_range_end])
		nfil       = xdim//2 + 1

		freq       = np.arange(nfil)

		# Smooth the continuum (used only for the diagnostic power spectrum)
		smoothed   = sp.ndimage.uniform_filter1d(cal1, 30)
		splinefit  = sp.interpolate.interp1d(np.arange(len(smoothed)), smoothed, kind='cubic')
		cal1smooth = splinefit(np.arange(0, len(cal1))) #smoothed

		# Do the FFT
		cal1fft    = fftpack.rfft(cal1-cal1smooth)
		yp         = abs(cal1fft[0:nfil])**2 # Power
		yp         = yp / np.max(yp)

		fig = plt.figure(figsize=(12,6))
		ax1 = plt.subplot2grid((3, 1), (0, 0))
		ax2 = plt.subplot2grid((3, 1), (1, 0), rowspan=2)

		ax1.plot(freq, yp)
		# Unlike the other branches, the band edges are drawn at f, not f*2.
		ax1.axvline(f_high, c='r', ls='--')
		ax1.axvline(f_low, c='r', ls='--')
		ax1.set_ylabel('Power Spectrum')
		ax1.set_xlabel('1 / (1024 pix)')
		ax1.set_xlim(0, np.max(freq))
	
		# The last 23 pixels are left out of the plot; the defringed curve is
		# shifted up by half its median so both curves stay visible.
		ax2.plot(cal1[0:-23], label='original', alpha=0.5, lw=1, c='b')
		ax2.plot(newSpectrum[0:-23]+0.5*np.median(newSpectrum[0:-23]), label='defringed', alpha=0.8, lw=1, c='r')
		ax2.legend()
		ax2.set_ylabel('Flux')
		ax2.set_xlabel('Pixel')
		ax2.set_xlim(0, len(cal1[0:-23]))
		plt.tight_layout()
		plt.savefig(save_to_path+"defringed_spectrum.png", bbox_inches='tight')
		#plt.show()


# Write the result next to a copy of the raw spectrum: the defringed flux
# replaces the data of HDU 1 and the raw flux is appended as a new last HDU.
fullpath  = tell_path + '/' + tell_data_name + '_' + str(order) + '_all.fits'
save_name = save_to_path + '%s_defringe_%s_all.fits'%(tell_data_name, order)

# The context manager closes the input file once the output has been written;
# the original never closed the handle returned by fits.open.
with fits.open(fullpath) as hdulist:
	hdulist.append(fits.PrimaryHDU())

	hdulist[-1].data = tell_sp.flux
	hdulist[1].data  = newSpectrum

	hdulist[-1].header['COMMENT']  = 'Raw Extracted Spectrum'
	hdulist[1].header['COMMENT']   = 'Defringed Spectrum'
	try:
		hdulist.writeto(save_name, overwrite=True)
	except FileNotFoundError:
		# kept from the original: retry once without the overwrite flag
		hdulist.writeto(save_name)
Example #23
0
    fs = 240//decimation

    cardSig = 10
    numElec = 10
    elec = slice(lenSig*numElec,lenSig*numElec+lenSig)

    frameSize = 0.2

    p300 = np.where(y==1)[0]
    nonp300 = np.where(y==-1)[0][:cardSig]

    for numSig in range(cardSig):
        signal = X[p300[numSig],elec]

        dt = 1.0/fs
        wa = WaveletAnalysis(signal, dt=dt)

        # wavelet power spectrum
        power = wa.wavelet_power
        # scales
        scales = wa.fourier_frequencies
        # associated time vector
        t = wa.time
        print wa.fourier_frequencies
        # reconstruction of the original data
        rx = wa.reconstruction()

        fig, ax = plt.subplots()
        T, S = np.meshgrid(t, scales)
        ax.contourf(T, S, power, 100)
        print(S)
Example #24
0
def test_power_bias():
    """See if the global wavelet spectrum is biased or not.

    Wavelet transform a signal of 3 distinct Fourier frequencies.

    The power spectrum should contain peaks at the frequencies, all
    of which should be the same height.

    The biased, unbiased and cone-of-influence-masked global spectra are
    plotted against Fourier period and the figure is saved to
    'tests/test_power_bias.png'.

    Returns:
        The matplotlib figure that was drawn.
    """
    dt = 0.1
    x = np.arange(5000) * dt

    T1 = 20 * dt
    T2 = 100 * dt
    T3 = 500 * dt

    w1 = 2 * np.pi / T1
    w2 = 2 * np.pi / T2
    w3 = 2 * np.pi / T3

    signal = np.cos(w1 * x) + np.cos(w2 * x) + np.cos(w3 * x)

    wa = WaveletAnalysis(signal, dt=dt,
                         wavelet=wavelets.Morlet(), unbias=False)

    # Read the spectrum after each setting change: biased, then unbiased,
    # then unbiased with the cone of influence masked.
    power_biased = wa.global_wavelet_spectrum
    wa.unbias = True
    power = wa.global_wavelet_spectrum
    wa.mask_coi = True
    power_coi = wa.global_wavelet_spectrum

    periods = wa.fourier_periods

    fig, ax = plt.subplots(nrows=2)

    ax_transform = ax[0]
    fig_info = (r"Wavelet transform of "
                r"$cos(2 \pi / {T1}) + cos(2 \pi / {T2}) + cos(2 \pi / {T3})$")
    ax_transform.set_title(fig_info.format(T1=T1, T2=T2, T3=T3))
    X, Y = np.meshgrid(wa.time, wa.fourier_periods)
    ax_transform.set_xlabel('time')
    ax_transform.set_ylabel('fourier period')
    ax_transform.set_ylim(10 * dt, 1000 * dt)
    ax_transform.set_yscale('log')
    ax_transform.contourf(X, Y, wa.wavelet_power, 100)

    # shade the region between the edge and coi
    C, S = wa.coi
    F = wa.fourier_period(S)
    f_max = F.max()
    ax_transform.fill_between(x=C, y1=F, y2=f_max, color='gray', alpha=0.3)

    ax_power = ax[1]
    ax_power.set_title('Global wavelet spectrum '
                       '(estimator for power spectrum)')
    ax_power.plot(periods, power, 'k', label=r'unbiased all domain')
    ax_power.plot(periods, power_coi, 'g', label=r'unbiased coi only')
    ax_power.set_xscale('log')
    ax_power.set_xlim(10 * dt, wa.time.max())
    ax_power.set_xlabel('fourier period')
    ax_power.set_ylabel(r'power / $\sigma^2$  (bias corrected)')

    # The biased spectrum gets its own y axis on the right.
    ax_power_bi = ax_power.twinx()
    ax_power_bi.plot(periods, power_biased, 'r', label='biased all domain')
    ax_power_bi.set_xlim(10 * dt, wa.time.max())
    ax_power_bi.set_ylabel(r'power / $\sigma^2$  (bias uncorrected)')
    # Colour the right-hand tick labels red. tick_params keeps the labels
    # tied to the automatic ticks; set_yticklabels(get_yticks()) froze the
    # label text while the tick positions could still change.
    ax_power_bi.tick_params(axis='y', labelcolor='r')

    label = "T={0}"
    for T in (T1, T2, T3):
        ax_power.axvline(T)
        ax_power.annotate(label.format(T), (T, 1))

    ax_power.legend(fontsize='x-small', loc='lower right')
    ax_power_bi.legend(fontsize='x-small', loc='upper right')

    fig.tight_layout()
    fig.savefig('tests/test_power_bias.png')

    return fig
Example #25
0
    Value at frequency = 0 should be 0.
    """
    npt.assert_almost_equal(wavelets.DOG(m=2)(0), 0.867, 3)
    npt.assert_almost_equal(wavelets.DOG(m=6)(0), 0.884, 3)
    npt.assert_almost_equal(wavelets.DOG(m=2).frequency(0), 0, 6)
    npt.assert_almost_equal(wavelets.DOG(m=6).frequency(0), 0, 6)


# Module-level sample data shared by the tests below (presumably the NINO3
# sea-surface-temperature series, going by the file name).
# The first three lines of the file are skipped; column 0 is used as the time
# axis and column 2 as the analysed series.
test_data = np.loadtxt('tests/nino3data.asc', skiprows=3)

nino_time = test_data[:, 0]
# mean spacing between consecutive time stamps, used as the sample spacing
nino_dt = np.diff(nino_time).mean()
anomaly_sst = test_data[:, 2]

# shared analysis object built once at import time
wa = WaveletAnalysis(anomaly_sst, time=nino_time, dt=nino_dt)


def test_N():
    """The analysis object reports as many samples as the input series has."""
    expected_samples = anomaly_sst.size
    assert_equal(expected_samples, wa.N)


def compare_cwt():
    """Compare the output of Scipy's cwt (using direct convolution)
    and my cwt (using fft convolution).

    NOTE(review): only the setup is visible in this body. Both
    implementations are bound to local names and a Ricker-based analysis of
    random data is built, but no comparison is made; the function appears
    to be truncated.
    """
    # the two implementations to compare
    cwt = scipy.signal.cwt
    fft_cwt = wavelets.cwt

    # random test signal and an analysis object using the Ricker wavelet
    data = np.random.random(2000)
    wave_anal = WaveletAnalysis(data, wavelet=wavelets.Ricker())
Example #26
0
def _channel_wavelets(channels, factor):
    """Return the transposed real wavelet transform rows of every channel.

    Each channel is decimated by *factor* before it is transformed.
    """
    rows = []
    for channel in channels:
        channel = ss.decimate(channel, factor)
        rows.extend(
            np.transpose(np.real(WaveletAnalysis(channel).wavelet_transform)))
    return rows


def generate_datasets(data_source):
    """Build shuffled train / validation / test sets of wavelet features.

    Loads one patient (when "Patient" occurs in *data_source*) or all
    patients, wavelet-transforms every channel of the ictal and interictal
    recordings, reduces the features to 16 PCA components and splits them.
    A fifth of the points go to the test set, a quarter of the remainder to
    validation and the rest to training. Ictal rows are labelled 1 and
    interictal rows 0.

    Args:
        data_source: location handed to load_for_patient / load_all_patients.

    Returns:
        Tuple (train_data, train_labels, valid_data, valid_labels,
        test_data, test_labels).
    """
    # print(...) with a single argument behaves the same on Python 2 and 3.
    print("Preparing to load data from %s" % data_source)

    data_var = 'data'
    if "Patient" in data_source:
        patient_data = [load_for_patient(data_source)]
    else:
        start_time_load = time.time()
        patient_data = load_all_patients(data_source)
        print("Time taken to load %d patients: %f" % (
            len(patient_data), time.time() - start_time_load))

    wav_ictal_data = []
    wav_interictal_data = []

    start_time_patient_process = time.time()
    for pnum, patient in enumerate(patient_data, 1):
        print("Processing data for patient %d" % pnum)
        freq = patient[0]
        ictal_files = patient[1]
        interictal_files = patient[2]

        # Use the same number of files of both kinds.
        num_files = min(len(ictal_files), len(interictal_files))
        # decimation factor derived from the sampling frequency
        datasize = int(round(freq / 10))

        for index in range(num_files):
            wav_ictal_data.extend(
                _channel_wavelets(ictal_files[index].get(data_var), datasize))
            wav_interictal_data.extend(
                _channel_wavelets(interictal_files[index].get(data_var),
                                  datasize))

    # NOTE(review): the second fit_transform refits the PCA, so the two
    # classes are projected onto different components; confirm this is
    # intended.
    pca = PCA(n_components=16)
    ictal_data = pca.fit_transform(wav_ictal_data)
    interictal_data = pca.fit_transform(wav_interictal_data)

    print("Generating datasets")
    # NOTE(review): the split sizes come from the ictal set only and are
    # applied to both classes.
    num_datapts = len(ictal_data)

    np.random.shuffle(ictal_data)
    np.random.shuffle(interictal_data)

    num_test = int(round(num_datapts / 5))
    num_valid = int((num_datapts - num_test) / 4)
    num_train = num_datapts - num_valid - num_test

    train_data = np.array(
        np.vstack([ictal_data[:num_train], interictal_data[:num_train]]))
    train_labels = np.concatenate([np.ones(num_train), np.zeros(num_train)])

    # Restoring the RNG state before the second shuffle applies the same
    # permutation to the data and to the labels.
    seed = np.random.get_state()
    np.random.shuffle(train_data)
    np.random.set_state(seed)
    np.random.shuffle(train_labels)

    valid_data = np.array(
        np.vstack([
            ictal_data[num_train:num_train + num_valid],
            interictal_data[num_train:num_train + num_valid]
        ]))
    valid_labels = np.concatenate([np.ones(num_valid), np.zeros(num_valid)])

    seed = np.random.get_state()
    np.random.shuffle(valid_data)
    np.random.set_state(seed)
    np.random.shuffle(valid_labels)

    # The test set is left in class order (all ictal rows, then interictal).
    test_data = np.array(
        np.vstack([
            ictal_data[num_train + num_valid:num_datapts],
            interictal_data[num_train + num_valid:num_datapts]
        ]))
    test_labels = np.concatenate([np.ones(num_test), np.zeros(num_test)])

    print("Time taken to process files: %f" % (time.time() -
                                               start_time_patient_process))
    return train_data, train_labels, valid_data, valid_labels, test_data, test_labels
Example #27
0
                    metavar='depth_level',
                    type=int,
                    help='number of depths in file')
args = parser.parse_args()

for level in range(0, args.depth_level, 1):
    print level
    mooring = args.mooring
    (raw_data, x, dt, time, variance, time_base, depth) = \
    ADCP_2D(args.DataFile+mooring+'/'+mooring+'_ABS.txt', \
    args.DataFile+'/'+mooring+'/'+mooring+'_dates.txt', \
    args.DataFile+'/'+mooring+'/'+mooring+'_depth.txt', level=level)
    fig_name_base = 'images/' + mooring + '_ADCP'
    """-----------------------------wavelet analysis           ---------------------------"""

    wa = WaveletAnalysis(x, time=time, dt=dt, dj=0.125)

    # wavelet power spectrum
    power = wa.wavelet_power
    transform = wa.wavelet_transform

    # scales
    scales = wa.scales

    # associated time vector
    t = wa.time / 24.

    # reconstruction of the original data
    rx = wa.reconstruction()

    # determine acor factor for red noise
Example #28
0
# print(t1)
# Three sine segments (4, 6 and 10 cycles per unit of t, t1 and t2) are built
# as Python lists so that `+` below joins them end to end instead of adding
# them element-wise.
sig  = np.sin(4 *2 * np.pi * t) 
sig = sig.tolist()
sig1 = np.sin(6 *2 * np.pi * t1) 
sig1 = sig1.tolist()
sig2 =  np.sin(10 *2 * np.pi * t2)
sig2 = sig2.tolist()


# list concatenation: the three segments one after another
sig_data = sig + sig1 + sig2
print(type(sig_data))
sig_data = np.array( sig_data )

# only used by the commented-out cwt call below
widths = np.arange(1, 64)
# cwtmatr = cwt(sig_data, wavelet=Morlet() , widths=widths)
wa = WaveletAnalysis(sig_data, wavelet=Morlet() )
# wavelet power spectrum
power = wa.wavelet_power

# scales 
scales = wa.scales
plt.plot(scales)
plt.show()

# associated time vector (rebinds `t`, which held the first segment's time axis)
t = wa.time
plt.plot(t)
plt.show()

# reconstruction of the original data
rx = wa.reconstruction()
Example #29
0
def test_power_bias():
    """See if the global wavelet spectrum is biased or not.

    Wavelet transform a signal of 3 distinct Fourier frequencies.

    The power spectrum should contain peaks at the frequencies, all
    of which should be the same height.

    The biased, unbiased and cone-of-influence-masked global spectra are
    plotted against Fourier period and the figure is saved to
    'tests/test_power_bias.png'.

    Returns:
        The matplotlib figure that was drawn.
    """
    dt = 0.1
    x = np.arange(5000) * dt

    T1 = 20 * dt
    T2 = 100 * dt
    T3 = 500 * dt

    w1 = 2 * np.pi / T1
    w2 = 2 * np.pi / T2
    w3 = 2 * np.pi / T3

    signal = np.cos(w1 * x) + np.cos(w2 * x) + np.cos(w3 * x)

    wa = WaveletAnalysis(signal,
                         dt=dt,
                         wavelet=wavelets.Morlet(),
                         unbias=False)

    # Read the spectrum after each setting change: biased, then unbiased,
    # then unbiased with the cone of influence masked.
    power_biased = wa.global_wavelet_spectrum
    wa.unbias = True
    power = wa.global_wavelet_spectrum
    wa.mask_coi = True
    power_coi = wa.global_wavelet_spectrum

    periods = wa.fourier_periods

    fig, ax = plt.subplots(nrows=2)

    ax_transform = ax[0]
    fig_info = (r"Wavelet transform of "
                r"$cos(2 \pi / {T1}) + cos(2 \pi / {T2}) + cos(2 \pi / {T3})$")
    ax_transform.set_title(fig_info.format(T1=T1, T2=T2, T3=T3))
    X, Y = np.meshgrid(wa.time, wa.fourier_periods)
    ax_transform.set_xlabel('time')
    ax_transform.set_ylabel('fourier period')
    ax_transform.set_ylim(10 * dt, 1000 * dt)
    ax_transform.set_yscale('log')
    ax_transform.contourf(X, Y, wa.wavelet_power, 100)

    # shade the region between the edge and coi
    C, S = wa.coi
    F = wa.fourier_period(S)
    f_max = F.max()
    ax_transform.fill_between(x=C, y1=F, y2=f_max, color='gray', alpha=0.3)

    ax_power = ax[1]
    ax_power.set_title('Global wavelet spectrum '
                       '(estimator for power spectrum)')
    ax_power.plot(periods, power, 'k', label=r'unbiased all domain')
    ax_power.plot(periods, power_coi, 'g', label=r'unbiased coi only')
    ax_power.set_xscale('log')
    ax_power.set_xlim(10 * dt, wa.time.max())
    ax_power.set_xlabel('fourier period')
    ax_power.set_ylabel(r'power / $\sigma^2$  (bias corrected)')

    # The biased spectrum gets its own y axis on the right.
    ax_power_bi = ax_power.twinx()
    ax_power_bi.plot(periods, power_biased, 'r', label='biased all domain')
    ax_power_bi.set_xlim(10 * dt, wa.time.max())
    ax_power_bi.set_ylabel(r'power / $\sigma^2$  (bias uncorrected)')
    # Colour the right-hand tick labels red. tick_params keeps the labels
    # tied to the automatic ticks; set_yticklabels(get_yticks()) froze the
    # label text while the tick positions could still change.
    ax_power_bi.tick_params(axis='y', labelcolor='r')

    label = "T={0}"
    for T in (T1, T2, T3):
        ax_power.axvline(T)
        ax_power.annotate(label.format(T), (T, 1))

    ax_power.legend(fontsize='x-small', loc='lower right')
    ax_power_bi.legend(fontsize='x-small', loc='upper right')

    fig.tight_layout()
    fig.savefig('tests/test_power_bias.png')

    return fig
Example #30
0
Created on Fri Jul  7 19:59:05 2017

@author: tdong
"""

'Wavelet transformation'
import numpy as np
from wavelets import WaveletAnalysis
from matplotlib import pyplot as plt

# given a signal x(t)
x = np.random.randn(1000)
# and a sample spacing
dt = 1.0 / 512

wa = WaveletAnalysis(x, dt=dt)
#wa = WaveletAnalysis(last1second, dt=dt)

# wavelet power spectrum
power = wa.wavelet_power
# Real and imaginary parts of the transform (the original had them swapped).
re = wa.wavelet_transform.real
im = wa.wavelet_transform.imag
# arctan2 keeps the quadrant and avoids dividing by a zero real part.
phase = np.arctan2(im, re)

# scales
scales = wa.scales

# associated time vector
t = wa.time

# reconstruction of the original data
Example #31
0
# -*- coding: utf-8 -*-
"""
Created on Mon Jul  9 13:46:53 2018

@author: amoody
"""

import pywt
from wavelets import WaveletAnalysis
# Analyse the mean-removed series. The sample spacing is written as
# 1.0/365 so that it is not truncated to 0 by integer division on Python 2.
wa = WaveletAnalysis((s-s.mean()).values, dt=1.0/365)
# wavelet power spectrum
power = wa.wavelet_power
# scales 
scales = wa.scales
# associated time vector
t = wa.time
# reconstruction of the original data
rx = wa.reconstruction()

# filled contour plot of the power over time and (log-scaled) scale
fig, ax = plt.subplots()
T, S = np.meshgrid(t, scales)
c1=ax.contourf(T, S, power, 100,cmap='jet')
plt.colorbar(c1)
ax.set_yscale('log')


# result is not stored; only visible when run interactively
chisquare(power,axis=1)

#%%
plt.style.use('bmh')
data=DataNull.loc[DataNull.WellNumber == '01N 05E 17BCA1','WaterLevelBelowLSD']