예제 #1
0
def cwt_1d(x, srate, freq_bands, freqs=np.logspace(np.log10(3), np.log10(90), 40),
           n_cycles=np.logspace(np.log10(1), np.log10(10), 40), scale_type=None, base_tstart=None, base_tend=None):
    """ Compute the Continuous Wavelet Transform for a 1 dimension input array ``x``

    The transform uses ``tfr_array_morlet`` with complex output; power is ``|W|**2`` and phase is ``angle(W)``.
    Band averages are computed with ``compute_band_mean`` and optional baseline scaling with ``tf_scaling``.

    Parameters
    ----------
    x : array
        Input array - must have 1 dimension (singleton dimensions are squeezed away)
    srate : int | float
        Sampling frequency
    freq_bands : array (size: n_freq_bands * 2)
        Frequency bands of interest. Mean value of both power and phase are calculated on these bands.
    freqs : array
        Pseudo frequency array to use for computing power and phase
    n_cycles : int | array
        Number of cycles for each wavelet
    scale_type : str | None
        Scaling type. If None / empty (default), no scaling is applied. Can be :

        * 'db_ratio' : :math:`P_{norm} = 10 \\cdot log10(\\frac{P}{mean(P_{baseline})})`
        * 'percent_change' : :math:`P_{norm} = 100 \\cdot \\frac{P - mean(P_{baseline})}{mean(P_{baseline})}`
        * 'z_transform' : :math:`P_{norm} = \\frac{P - mean(P_{baseline})}{std(P_{baseline})}`

    base_tstart : float | None
        Baseline starting time, in seconds (converted to a sample index with ``int(base_tstart * srate)``)
    base_tend : float | None
        Baseline ending time, in seconds (converted to a sample index with ``int(base_tend * srate)``)

    Scaling is applied to ``coeffs_power`` and ``coeffs_power_bands`` only when ``scale_type`` is set
    and both baseline bounds are single values.

    Returns
    -------
    coeffs_power : array (size: n_freqs * n_pnts)
        Power extracted from the wavelet coefficients
    coeffs_power_bands : array (size: n_freq_bands * n_pnts)
        Mean of ``coeffs_power`` over the frequency bands defined in ``freq_bands``
    phase : array (size: n_freqs * n_pnts)
        Phase angle
    phase_bands : array (size: n_freq_bands * n_pnts)
        Mean of ``phase`` over the frequency bands defined in ``freq_bands``
    """
    freq_bands = np.atleast_1d(freq_bands)
    # A missing baseline (None, or the legacy empty list) becomes an empty array so the
    # size checks below skip the scaling step.
    base_tstart = np.atleast_1d([] if base_tstart is None else base_tstart)
    base_tend = np.atleast_1d([] if base_tend is None else base_tend)

    # tfr_array_morlet expects (n_epochs, n_channels, n_times): wrap x as 1 epoch x 1 channel
    x_3d = np.array([[x.squeeze()]])
    cwt_complex = tfr_array_morlet(x_3d, srate, freqs, n_cycles, output='complex').squeeze()
    cwt_power, cwt_phase = np.abs(cwt_complex)**2, np.angle(cwt_complex)

    # Compute mean over frequency bands
    cwt_bandpower = compute_band_mean(cwt_power, freqs, freq_bands, interp_method='linear')
    cwt_bandphase = compute_band_mean(cwt_phase, freqs, freq_bands, interp_method='linear')

    # Normalization / Scaling
    if scale_type and base_tstart.size == 1 and base_tend.size == 1:
        # .item() extracts the scalar explicitly; int() on a 1-element array is deprecated in NumPy
        base_start_idx = int(base_tstart.item() * srate)
        base_end_idx = int(base_tend.item() * srate)
        cwt_power = tf_scaling(cwt_power, base_start_idx, base_end_idx, scale_type)
        cwt_bandpower = tf_scaling(cwt_bandpower, base_start_idx, base_end_idx, scale_type)
    return cwt_power, cwt_bandpower, cwt_phase, cwt_bandphase
예제 #2
0
def save_wavelet_complex(n_components):
    """Compute PCA-reduced complex Morlet wavelet features for samples 1-21 and pickle them.

    For each sample, ``get_epochs(sample, scale=False)`` is transformed with ``tfr_array_morlet``
    (complex output, 15 log-spaced frequencies from 2 to 15 Hz, ``n_cycles = freqs / 4``).
    For each frequency:

    * the real and imaginary parts are concatenated along the channel axis;
    * the result is reduced with ``UnsupervisedSpatialFilter(PCA(n_components), average=False)``;
    * it is reshaped to a 2-D ``(-1, n_components)`` matrix.

    Parameters
    ----------
    n_components : int
        Number of PCA components kept per frequency.

    The nested list ``all_x_train_samples[sample_index][freq_index]`` is pickled to
    ``DataTransformed/wavelet_complex/15hz/pca_{n_components}/x_train_all_samples.pkl``.
    The output folder is created if needed.
    """
    import os

    all_x_train_samples = []

    # The frequency grid does not depend on the sample, so build it once.
    freqs = np.logspace(*np.log10([2, 15]), num=15)
    n_cycles = freqs / 4.

    for sample in range(1, 22):
        print("sample {}".format(sample))
        epochs = get_epochs(sample, scale=False)

        print("applying morlet wavelet")
        wavelet_output = tfr_array_morlet(epochs.get_data(),
                                          sfreq=epochs.info['sfreq'],
                                          freqs=freqs,
                                          n_cycles=n_cycles,
                                          output='complex')

        all_x_train_freqs = []

        for freq in range(wavelet_output.shape[2]):
            print("frequency: {}".format(freqs[freq]))

            # Real-valued representation: stack real and imaginary parts as extra "channels".
            wavelet_epochs = wavelet_output[:, :, freq, :]
            wavelet_epochs = np.append(wavelet_epochs.real,
                                       wavelet_epochs.imag,
                                       axis=1)

            wavelet_info = mne.create_info(ch_names=wavelet_epochs.shape[1],
                                           sfreq=epochs.info['sfreq'],
                                           ch_types='mag')
            wavelet_epochs = mne.EpochsArray(wavelet_epochs,
                                             info=wavelet_info,
                                             events=epochs.events)

            pca = UnsupervisedSpatialFilter(PCA(n_components=n_components),
                                            average=False)
            print('fitting pca')
            reduced = pca.fit_transform(wavelet_epochs.get_data())
            print('fitting done')

            # Move the last axis before the component axis, then flatten to 2-D rows.
            x_train = reduced.transpose(0, 2, 1).reshape(-1, reduced.shape[1])
            all_x_train_freqs.append(x_train)

        all_x_train_samples.append(all_x_train_freqs)

    print('saving x_train for all samples')
    out_path = ("DataTransformed/wavelet_complex/15hz/pca_{}/x_train_all_samples.pkl"
                .format(n_components))
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    # Context manager guarantees the file handle is flushed and closed.
    with open(out_path, "wb") as f:
        pickle.dump(all_x_train_samples, f)
    print("x_train saved")
예제 #3
0
def time_frequency(data: pd.DataFrame, freqs: List[float], method: str = 'morlet', output: str = 'avg_power', **kwargs
                   ) -> np.ndarray:
    """Calculates time-frequency representation for each node.

    Parameters
    ----------
    data
        Simulation results. Each column is one node. If a ``'time'`` column is present, it is used as
        the time axis; otherwise the existing index is. The sampling frequency is taken from the
        first two index entries (``1 / (t[1] - t[0])``). ``data`` itself is not modified.
    freqs
        Frequencies of interest.
    method
        Method to be used for TFR calculation. Can be `morlet` for `mne.time_frequency.tfr_array_morlet` or
        `multitaper` for `mne.time_frequency.tfr_array_multitaper`.
    output
        Type of the output variable to be calculated. For options, see `mne.time_frequency.tfr_array_morlet`.
    kwargs
        Additional keyword arguments to be passed to the function used for tfr calculation.

    Returns
    -------
    np.ndarray
        Time-frequency representation (n x f x t) for each node (n) at each frequency of interest (f) and time (t).

    Raises
    ------
    ValueError
        If ``method`` is neither ``'morlet'`` nor ``'multitaper'``.

    """
    # Resolve the TFR function first so an invalid method fails before any work is done.
    if method == 'morlet':
        from mne.time_frequency import tfr_array_morlet as tfr_func
    elif method == 'multitaper':
        from mne.time_frequency import tfr_array_multitaper as tfr_func
    else:
        raise ValueError(f"Unknown TFR method {method!r}; expected 'morlet' or 'multitaper'.")

    # set_index returns a new frame, so the caller's DataFrame keeps its 'time' column.
    if 'time' in data.columns.values:
        data = data.set_index('time')

    # (n_times, n_nodes) -> (1 epoch, n_nodes, n_times), the layout expected by MNE's array API
    epoch_data = np.reshape(data.values.T, (1, data.shape[1], data.shape[0]))
    sfreq = 1. / (data.index[1] - data.index[0])

    return tfr_func(epoch_data, sfreq=sfreq, freqs=freqs, output=output, **kwargs)
예제 #4
0
def main():
    """Fit one shrinkage-LDA model per (sample, wavelet frequency, time index) and pickle them.

    For each sample 1..21, the epochs are transformed with ``tfr_array_morlet``: complex output,
    15 log-spaced frequencies from 2 to 25 Hz, and ``n_cycles = freqs / 4``. For each frequency:

    * the real and imaginary parts are stacked along the channel axis;
    * the stacked data are reduced with the project's ``pca(80, ...)`` helper;
    * 50 ``LinearDiscriminantAnalysis(solver='lsqr', shrinkage='auto')`` models are fitted.

    Per-sample models go to ``Results/lda/wavelet_class/lsqr/complex/sample_{n}/all_freq_models.pkl``.
    All samples together go to ``.../all_models.pkl``. Missing folders are created.
    """
    import os

    model_type = "lda"
    exp_name = "wavelet_class/lsqr/complex"

    save_dir = "Results/{}/{}".format(model_type, exp_name)

    # The frequency grid is the same for every sample.
    freqs = np.logspace(*np.log10([2, 25]), num=15)
    n_cycles = freqs / 4.

    sample_models = []

    for sample in range(1, 22):
        print("sample {}".format(sample))

        epochs = get_epochs(sample, scale=False)

        print("applying morlet wavelet")
        # returns (n_epochs, n_channels, n_freqs, n_times)
        wavelet_output = tfr_array_morlet(epochs.get_data(),
                                          sfreq=epochs.info['sfreq'],
                                          freqs=freqs,
                                          n_cycles=n_cycles,
                                          output='complex')
        y_train = get_y_train(sample)

        freq_models = []

        for freq in range(wavelet_output.shape[2]):
            print("frequency: {}".format(freqs[freq]))

            # Real-valued representation: real and imaginary parts as separate "channels".
            wavelet_epochs = wavelet_output[:, :, freq, :]
            wavelet_epochs = np.append(wavelet_epochs.real,
                                       wavelet_epochs.imag,
                                       axis=1)

            wavelet_info = mne.create_info(ch_names=wavelet_epochs.shape[1],
                                           sfreq=epochs.info['sfreq'],
                                           ch_types='mag')
            wavelet_epochs = mne.EpochsArray(wavelet_epochs,
                                             info=wavelet_info,
                                             events=epochs.events)

            reduced = pca(80, wavelet_epochs, plot=False)
            x_train = reduced.transpose(0, 2, 1).reshape(-1, reduced.shape[1])

            time_models = []

            # NOTE(review): taking every 50th row starting at `time` selects the same time
            # index across epochs only if each epoch has 50 time samples; confirm.
            for time in range(50):
                print("time {}".format(time))
                intervals = np.arange(start=time,
                                      stop=x_train.shape[0],
                                      step=50)

                x_sample = x_train[intervals, :]
                y_sample = y_train[intervals]
                model = LinearDiscriminantAnalysis(solver='lsqr',
                                                   shrinkage='auto')
                model.fit(x_sample, y_sample)

                time_models.append(model)

            freq_models.append(time_models)

        sample_models.append(freq_models)
        print('saving models for sample {}'.format(sample))
        sample_dir = "{}/sample_{}".format(save_dir, sample)
        os.makedirs(sample_dir, exist_ok=True)
        with open("{}/all_freq_models.pkl".format(sample_dir), "wb") as f:
            pickle.dump(freq_models, f)
        print("models saved")

    print('saving models for all samples')
    os.makedirs(save_dir, exist_ok=True)
    with open("{}/all_models.pkl".format(save_dir), "wb") as f:
        pickle.dump(sample_models, f)
    print("models saved")
def tfr_morlet(X, sfreq, freqs, n_cycles, mode):
    '''
    Morlet-wavelet time-frequency transform, applied per event with mne's ``tfr_array_morlet``.

    Each event is transformed separately via ``tfr_array_morlet(data[i], ..., output=mode)``.
    The per-event results are stacked along a leading event axis.

    :param X: input data array (n_events, n_epochs, n_chans, n_times).
              A 3-D array (n_events, n_epochs, n_times) is expanded to 2 identical channels.
    :param sfreq: sampling frequency
    :param freqs: array-like, frequencies used in the time-frequency transform
    :param n_cycles: number of cycles in the Morlet wavelet;
                    fixed number or one per frequency
    :param mode: complex, power, phase, avg_power, itc, avg_power_itc
        (1)complex: single trial complex (n_events, n_epochs, n_chans, n_freqs, n_times), complex dtype
        (2)power: single trial power (n_events, n_epochs, n_chans, n_freqs, n_times)
        (3)phase: single trial phase (n_events, n_epochs, n_chans, n_freqs, n_times)
        (4)avg_power: average of single trial power (n_events, n_chans, n_freqs, n_times)
        (5)itc: inter-trial coherence (n_events, n_chans, n_freqs, n_times)
        (6)avg_power_itc: average of single trial power and inter-trial coherence
            across trials :avg_power+i*itc (n_events, n_chans, n_freqs, n_times), complex dtype
    :raises ValueError: if ``mode`` is not one of the options above
    '''
    single_trial_modes = ('complex', 'power', 'phase')
    averaged_modes = ('avg_power', 'itc', 'avg_power_itc')
    if mode not in single_trial_modes + averaged_modes:
        raise ValueError("Unknown mode {!r}; expected one of {}".format(
            mode, single_trial_modes + averaged_modes))

    # tfr_array_morlet needs a channel axis: duplicate a channel-less signal into 2 channels.
    if X.ndim < 4:
        data = np.zeros((X.shape[0], X.shape[1], 2, X.shape[2]))
        data[:] = X[:, :, np.newaxis, :]
    else:
        data = X

    n_events, n_epochs, n_chans, n_times = data.shape
    n_freqs = len(freqs)
    if mode in single_trial_modes:
        out_shape = (n_events, n_epochs, n_chans, n_freqs, n_times)
    else:
        # averaged outputs collapse the epoch axis
        out_shape = (n_events, n_chans, n_freqs, n_times)

    # These two modes return complex values; a float buffer would silently drop the imaginary
    # part (losing the phase for 'complex' and the ITC for 'avg_power_itc').
    out_dtype = complex if mode in ('complex', 'avg_power_itc') else float
    out = np.zeros(out_shape, dtype=out_dtype)

    for i, event_data in enumerate(data):
        out[i] = tfr_array_morlet(event_data,
                                  sfreq=sfreq,
                                  freqs=freqs,
                                  n_cycles=n_cycles,
                                  output=mode)
    return out
              vmax=vmax,
              title='Using Morlet wavelets and EpochsTFR',
              show=False)

###############################################################################
# Operating on arrays
# -------------------
#
# MNE also has versions of the functions above which operate on numpy arrays
# instead of MNE objects. They expect inputs of the shape
# ``(n_epochs, n_channels, n_times)``. They will also return a numpy array
# of shape ``(n_epochs, n_channels, n_freqs, n_times)`` for single-trial
# outputs; averaged outputs such as ``'avg_power'`` drop the epoch axis and
# return ``(n_channels, n_freqs, n_times)``.
#
# ``epochs``, ``freqs``, ``n_cycles``, ``vmin``, ``vmax``, ``rescale`` and ``plt``
# are defined/imported earlier in the script (not visible in this section).

power = tfr_array_morlet(epochs.get_data(),
                         sfreq=epochs.info['sfreq'],
                         freqs=freqs,
                         n_cycles=n_cycles,
                         output='avg_power')
# Baseline the output in place (copy=False): mode='mean' subtracts the mean
# power of the 0-0.1 s window.
rescale(power, epochs.times, (0., 0.1), mode='mean', copy=False)
fig, ax = plt.subplots()
# power[0] is the first channel's (n_freqs, n_times) map; time axis in ms.
mesh = ax.pcolormesh(epochs.times * 1000,
                     freqs,
                     power[0],
                     cmap='RdBu_r',
                     vmin=vmin,
                     vmax=vmax)
ax.set_title('TFR calculated on a numpy array')
ax.set(ylim=freqs[[0, -1]], xlabel='Time (ms)')
fig.colorbar(mesh)
plt.tight_layout()
# NOTE(review): ``tfr_morlet`` here is presumably MNE's ``mne.time_frequency.tfr_morlet``
# (it takes ``return_itc``/``average`` and returns a TFR object with ``.average()``), not a
# local helper with a different signature; confirm the import in this script.
power = tfr_morlet(epochs, freqs=freqs,
                   n_cycles=n_cycles, return_itc=False, average=False)
print(type(power))
avgpower = power.average()
avgpower.plot([0], baseline=(0., 0.1), mode='mean', vmin=vmin, vmax=vmax,
              title='Using Morlet wavelets and EpochsTFR', show=False)

###############################################################################
# Operating on arrays
# -------------------
#
# MNE also has versions of the functions above which operate on numpy arrays
# instead of MNE objects. They expect inputs of the shape
# ``(n_epochs, n_channels, n_times)``. They will also return a numpy array
# of shape ``(n_epochs, n_channels, n_frequencies, n_times)`` for single-trial
# outputs; averaged outputs such as ``'avg_power'`` drop the epoch axis.

# The frequency argument is ``freqs``; the old ``frequencies`` keyword is not accepted by current MNE.
power = tfr_array_morlet(epochs.get_data(), sfreq=epochs.info['sfreq'],
                         freqs=freqs, n_cycles=n_cycles,
                         output='avg_power')
# Baseline the output in place (copy=False)
rescale(power, epochs.times, (0., 0.1), mode='mean', copy=False)
fig, ax = plt.subplots()
mesh = ax.pcolormesh(epochs.times * 1000, freqs, power[0],
                     cmap='RdBu_r', vmin=vmin, vmax=vmax)
ax.set_title('TFR calculated on a numpy array')
ax.set(ylim=freqs[[0, -1]], xlabel='Time (ms)')
fig.colorbar(mesh)
plt.tight_layout()

plt.show()
예제 #8
0
파일: TF.py 프로젝트: seapsy/DvM
    def TFanalysisMNE(self,
                      sj,
                      cnds,
                      cnd_header,
                      base_period,
                      time_period,
                      method='hilbert',
                      flip=None,
                      base_type='conspec',
                      downsample=1,
                      min_freq=5,
                      max_freq=40,
                      num_frex=25,
                      cycle_range=(3, 12),
                      freq_scaling='log'):
        '''
        Time-frequency analysis with Morlet wavelets (MNE ``tfr_array_morlet``, ``output='avg_power'``).

        For every condition in ``cnds``, the trials whose ``beh[cnd_header]`` equals that condition
        are decomposed and averaged. The power and its log-ratio baseline-corrected version are
        pickled per subject.

        NOTE(review): ``method`` only selects the output folder; the decomposition is always
        Morlet wavelet convolution here. ``base_type`` and ``freq_scaling`` are currently unused
        (frequencies are always log-spaced).

        TODO: Add option to subtract ERP to get evoked power
        TODO: Add option to match trial number

        Arguments
        - - - - -
        sj (int): subject number
        cnds (list): list of conditions as stored in behavior file
        cnd_header (str): key in behavior file that contains condition info
        base_period (tuple | list): time window used for baseline correction
        time_period (tuple | list): time window of interest (used for the plotting time axis)
        method (str): name of the output sub-folder under 'tf'
        flip (dict): flips a subset of trials. Key of dictionary specifies header in beh that contains flip info.
            List in dict contains variables that need to be flipped. Note: flipping is done from right to left hemifield
        base_type (str): unused ('conspec' / 'conavg' DB conversion is not implemented here)
        downsample (int): keep every `downsample`-th time point of time_period in the saved plot times
        min_freq (int): minimum frequency for TF analysis
        max_freq (int): maximum frequency for TF analysis
        num_frex (int): number of frequencies in TF analysis
        cycle_range (tuple): number of cycles, log-spaced from cycle_range[0] to cycle_range[1] over num_frex steps
        freq_scaling (str): unused; frequencies are always logarithmically spaced

        Returns
        - - -
        None. Writes two pickles via self.FolderTracker(['tf', method], ...):
        - '{sj}-tf-mne.pickle': {cnd: {'power', 'base_power', 'phase'}} where 'power' is the
          avg_power output with its first two axes swapped, 'base_power' is that array after
          rescale(mode='logratio') on base_period, and 'phase' is a placeholder ('?').
        - 'plot_dict.pickle': {'ch_names', 'times', 'frex'}.
          NOTE(review): 'times' is restricted/downsampled to time_period, whereas the stored
          power covers all time points; confirm this is intended.
        '''

        # read in data
        eegs, beh, times, s_freq, ch_names = self.selectTFData(sj)

        # flip subset of trials (allows for lateralization indices)
        if flip is not None:
            # dict views are not indexable in Python 3; take the first key via an iterator
            key = next(iter(flip))
            eegs = self.topoFlip(eegs,
                                 beh[key],
                                 ch_names,
                                 left=[flip.get(key)])

        # log-spaced frequencies with a matching log-spaced number of wavelet cycles
        freqs = np.logspace(np.log10(min_freq), np.log10(max_freq), num_frex)
        nr_cycles = np.logspace(np.log10(cycle_range[0]),
                                np.log10(cycle_range[1]), num_frex)

        # time indices inside time_period, keeping every `downsample`-th one
        idx_time = np.where(
            (times >= time_period[0]) & (times <= time_period[1]))[0]
        idx_2_save = idx_time[::downsample]

        tf = {}

        # loop over conditions
        for cnd in cnds:
            # select the trials of this condition using the caller-specified header
            cnd_idx = np.where(beh[cnd_header] == cnd)[0]

            power = tfr_array_morlet(eegs[cnd_idx],
                                     sfreq=s_freq,
                                     freqs=freqs,
                                     n_cycles=nr_cycles,
                                     output='avg_power')
            # swap the first two axes (channels <-> frequencies) before storing
            power = np.swapaxes(power, 0, 1)

            # rescale returns a new array by default, so 'power' stays uncorrected
            tf[cnd] = {
                'power': power,
                'base_power': rescale(power,
                                      times,
                                      base_period,
                                      mode='logratio'),
                'phase': '?',
            }

        # save TF matrices
        with open(
                self.FolderTracker(['tf', method],
                                   '{}-tf-mne.pickle'.format(sj)),
                'wb') as handle:
            pickle.dump(tf, handle)

        # store dictionary with variables for plotting
        plot_dict = {
            'ch_names': ch_names,
            'times': times[idx_2_save],
            'frex': freqs
        }

        with open(
                self.FolderTracker(['tf', method],
                                   filename='plot_dict.pickle'),
                'wb') as handle:
            pickle.dump(plot_dict, handle)
예제 #9
0
def main():
    """Frequency-generalization decoding at each time point, for samples 1-21.

    For each sample and each time index of the complex Morlet transform (15 log-spaced
    frequencies from 2 to 25 Hz), the real and imaginary parts are stacked along the channel
    axis. They are reduced with the project's ``pca(80, ...)`` helper and scored with
    ``GeneralizingEstimator`` (shrinkage LDA, 5-fold ``cross_val_multiscore``) across the
    frequency axis.

    Outputs per sample folder ``Results/lda/freq_gen_matrix//sample_{n}/``:

    * a bar plot of the diagonal accuracies (``time_{t}_accuracy.png``);
    * the generalization matrix image (``time_{t}_matrix.png``);
    * a CSV with one flattened matrix per time point (``all_time_matrix_results.csv``).
    """
    model_type = "lda"
    exp_name = "freq_gen_matrix/"

    # Frequency grid shared by all samples; fmin/fmax also define the imshow extent below.
    fmin, fmax = 2, 25
    freqs = np.logspace(*np.log10([fmin, fmax]), num=15)
    n_cycles = freqs / 4.
    string_freqs = [round(x, 2) for x in freqs]

    for sample in range(1, 22):
        print("sample {}".format(sample))

        sample_dir = "Results/{}/{}/sample_{}".format(model_type, exp_name, sample)
        # makedirs also creates missing parent folders and is a no-op if the folder exists
        os.makedirs(sample_dir, exist_ok=True)

        epochs = get_epochs(sample, scale=False)
        y_train = epochs.events[:, 2]

        print("applying morlet wavelet")

        wavelet_output = tfr_array_morlet(epochs.get_data(),
                                          sfreq=epochs.info['sfreq'],
                                          freqs=freqs,
                                          n_cycles=n_cycles,
                                          output='complex')

        time_results = np.zeros(
            (wavelet_output.shape[3], len(freqs), len(freqs)))

        for time in range(wavelet_output.shape[3]):
            print("time: {}".format(time))

            # Slice one time point; the frequency axis now plays the role of "time"
            # for the EpochsArray / GeneralizingEstimator below.
            wavelet_epochs = wavelet_output[:, :, :, time]
            wavelet_epochs = np.append(wavelet_epochs.real,
                                       wavelet_epochs.imag,
                                       axis=1)

            wavelet_info = mne.create_info(ch_names=wavelet_epochs.shape[1],
                                           sfreq=epochs.info['sfreq'],
                                           ch_types='mag')
            wavelet_epochs = mne.EpochsArray(wavelet_epochs,
                                             info=wavelet_info,
                                             events=epochs.events)

            x_train = pca(80, wavelet_epochs, plot=False)

            model = LinearModel(
                LinearDiscriminantAnalysis(solver='lsqr', shrinkage='auto'))
            freq_gen = GeneralizingEstimator(model,
                                             n_jobs=1,
                                             scoring='accuracy',
                                             verbose=True)
            scores = cross_val_multiscore(freq_gen,
                                          x_train,
                                          y_train,
                                          cv=5,
                                          n_jobs=1)
            # average over CV folds -> (n_train_freqs, n_test_freqs)
            scores = np.mean(scores, axis=0)
            time_results[time] = scores

            sns.set()
            # seaborn >= 0.12 requires x/y as keyword arguments
            ax = sns.barplot(
                x=np.sort(string_freqs),
                y=np.diag(scores),
            )
            ax.set(ylim=(0, 0.8),
                   xlabel='Frequencies',
                   ylabel='Accuracy',
                   title='Cross Val Accuracy {} for Subject {} for Time {}'.
                   format(model_type, sample, time))
            # NOTE(review): 0.12 is presumably the chance level; confirm against the class count
            ax.axhline(0.12, color='k', linestyle='--')
            ax.figure.set_size_inches(8, 6)
            ax.figure.savefig(
                "{}/time_{}_accuracy.png".format(sample_dir, time),
                dpi=300)
            plt.close('all')

            fig, ax = plt.subplots(1, 1)
            im = ax.imshow(scores,
                           interpolation='lanczos',
                           origin='lower',
                           cmap='RdBu_r',
                           extent=[fmin, fmax, fmin, fmax],
                           vmin=0.,
                           vmax=0.8)
            ax.set_xlabel('Testing Frequency (hz)')
            ax.set_ylabel('Training Frequency (hz)')
            ax.set_title(
                'Frequency generalization for Subject {} at Time {}'.format(
                    sample, time))
            plt.colorbar(im, ax=ax)
            ax.grid(False)
            ax.figure.savefig(
                "{}/time_{}_matrix.png".format(sample_dir, time),
                dpi=300)
            plt.close('all')

        # one row per time point, flattened (n_freqs * n_freqs) matrix per row
        time_results = time_results.reshape(time_results.shape[0], -1)
        all_results_df = pd.DataFrame(time_results)
        all_results_df.to_csv(
            "{}/all_time_matrix_results.csv".format(sample_dir))
예제 #10
0
        method,
        pick_ori=None)

    for label in labels_sel:
        label_ts = []
        for j in range(len(stcs)):
            ts = mne.extract_label_time_course(
                stcs[j], labels=label, src=src, mode="mean_flip")
            ts = np.squeeze(ts)
            # ts *= np.sign(ts[np.argmax(np.abs(ts))])
            label_ts.append(ts)

        label_ts = np.asarray(label_ts)
        label_ts = label_ts[:, np.newaxis, :]
        tfr = tfr_array_morlet(
            label_ts,
            epochs.info["sfreq"],
            freqs,
            # use_fft=True,
            n_cycles=n_cycle)

        np.save(tf_folder + "%s_%s_%s_%s_%s_pca_snr_3-tfr" %
                (subject, condition[:3], condition[4:], label.name, method),
                tfr)
        np.save(tf_folder + "%s_%s_%s_%s_%s_pca_snr_3-ts" %
                (subject, condition[:3], condition[4:], label.name, method),
                label_ts)

    del stcs
    del tfr
예제 #11
0
import mne
import numpy as np
from my_settings import (epochs_folder, tf_folder)
from mne.time_frequency import tfr_array_morlet
import sys

# Usage: python <script> SUBJECT
# Saves the complex Morlet TFR (8-12 Hz, 4 cycles) for each condition/side of
# the subject's trial-start epochs (resampled to 250 Hz) as .npy files in tf_folder.
subject = sys.argv[1]

freqs = np.arange(8, 13, 1)  # frequencies of interest: 8, 9, 10, 11, 12 Hz
n_cycles = 4.  # same number of cycles for every frequency (alternative: freqs / 2.)

sides = ["left", "right"]
conditions = ["ctl", "ent"]

epochs = mne.read_epochs(
    epochs_folder + "%s_trial_start-epo.fif" % subject, preload=True)
epochs.resample(250)

for cond in conditions:
    for side in sides:
        # tfr_array_morlet needs a plain (n_epochs, n_channels, n_times) array, not an
        # Epochs object, and its frequency argument is named ``freqs``.
        power = tfr_array_morlet(
            epochs[cond + "/" + side].get_data(),
            sfreq=epochs.info["sfreq"],
            freqs=freqs,
            n_cycles=n_cycles,
            use_fft=True,
            output="complex",
            n_jobs=1)
        np.save(tf_folder + "%s_%s_%s-4-complex-tfr.npy" %
                (subject, cond, side), power)
예제 #12
0
def main():
    """Classify each Morlet-wavelet frequency separately and save the results.

    For every sample 1..21 the epochs are transformed with
    ``tfr_array_morlet``. For each frequency, the chosen representation is
    reduced with ``pca`` and scored with ``linear_models``. A per-frequency
    accuracy plot and a CSV with one row per frequency are written under
    ``Results/<model_type>/<exp_name>/sample_<n>/``.

    The second-to-last ``/``-separated component of ``exp_name`` selects the
    representation:

    * ``"real"``: real part of the complex coefficients.
    * ``"complex"``: real and imaginary parts stacked as separate channels.
    * ``"power"``: MNE's ``output='power'``.
    * ``"phase"``: MNE's ``output='phase'``.

    Raises
    ------
    ValueError
        If that component is not one of the representations above. The
        check runs before any data is loaded or directories are created.
    """
    model_type = "lda"
    exp_name = "wavelet_class/lsqr/complex/15hz"

    output_kind = exp_name.split("/")[-2]
    # MNE has no "real" output: the real part is taken from the complex
    # coefficients further down.
    mne_outputs = {"real": "complex", "complex": "complex",
                   "power": "power", "phase": "phase"}
    if output_kind not in mne_outputs:
        raise ValueError("{} not an output of wavelet function".format(
            output_kind))
    mne_output = mne_outputs[output_kind]

    # The frequency grid is the same for every sample, so build it once.
    freqs = np.logspace(*np.log10([2, 15]), num=15)
    n_cycles = freqs / 4.

    sns.set()

    for sample in range(1, 22):
        print("sample {}".format(sample))

        out_dir = "Results/{}/{}/sample_{}".format(model_type, exp_name,
                                                   sample)
        # makedirs also creates missing parent directories (os.mkdir fails
        # on a fresh Results tree) and tolerates an existing directory.
        os.makedirs(out_dir, exist_ok=True)

        epochs = get_epochs(sample, scale=False)

        print("applying morlet wavelet")

        # returns (n_epochs, n_channels, n_freqs, n_times)
        wavelet_output = tfr_array_morlet(epochs.get_data(),
                                          sfreq=epochs.info['sfreq'],
                                          freqs=freqs,
                                          n_cycles=n_cycles,
                                          output=mne_output)

        y_train = get_y_train(sample)

        # One row of scores per frequency; the row width follows whatever
        # linear_models returns instead of a hard-coded column count.
        freq_results = []

        for freq_idx, freq in enumerate(freqs):
            print("frequency: {}".format(freq))

            wavelet_epochs = wavelet_output[:, :, freq_idx, :]

            if output_kind == "real":
                wavelet_epochs = wavelet_epochs.real
            elif output_kind == "complex":
                # Real and imaginary parts become separate channels.
                wavelet_epochs = np.append(wavelet_epochs.real,
                                           wavelet_epochs.imag,
                                           axis=1)

            wavelet_info = mne.create_info(ch_names=wavelet_epochs.shape[1],
                                           sfreq=epochs.info['sfreq'],
                                           ch_types='mag')
            wavelet_epochs = mne.EpochsArray(wavelet_epochs,
                                             info=wavelet_info,
                                             events=epochs.events)

            reduced = pca(80, wavelet_epochs, plot=False)
            # Flatten the 3-D result to (n0 * n2, n1) rows; presumably
            # (samples x PCA components) — confirm against pca().
            x_train = reduced.transpose(0, 2, 1).reshape(-1, reduced.shape[1])

            results = linear_models(x_train, y_train, model_type=model_type)
            freq_results.append(results)

            curr_freq = str(round(freq, 2))

            ax = sns.lineplot(data=results, dashes=False)
            ax.set(ylim=(0, 1),
                   xlabel='Time',
                   ylabel='Accuracy',
                   title='Cross Val Accuracy {} for Subject {} for Freq {}'.
                   format(model_type, sample, curr_freq))
            # Dashed reference line at x=15 (its meaning is not shown here;
            # presumably an event/onset index — confirm).
            plt.axvline(x=15, color='b', linestyle='--')
            ax.figure.savefig("{}/freq_{}.png".format(out_dir, curr_freq),
                              dpi=300)
            plt.clf()

        all_results_df = pd.DataFrame(np.asarray(freq_results, dtype=float))
        all_results_df.to_csv("{}/all_freq_results.csv".format(out_dir))
fmax = 150
bandwidth = 3
tmin_crop = -0.5
tmax_crop = 1.75
fname = 'visual_channels_BP_montage.csv'
ieeg_path = cf.cifar_ieeg_path(home='~')
visual_chan_table_path = ieeg_path.joinpath(fname)
df = pd.read_csv(visual_chan_table_path)

# Per-subject time series and that subject's channel rows (sorted by
# latency) are collected in lists and combined once after the loop.
ts = []
sorted_rows = []
for sub_id in subjects:
    subject = cf.Subject(name=sub_id)
    datadir = subject.processing_stage_path(proc=proc)
    visual_chan = subject.pick_visual_chan()
    HFB = hf.visually_responsive_HFB(sub_id=sub_id)
    ts_sub, time = fun.ts_all_categories(HFB, sfreq=sfreq,
                                         tmin_crop=tmin_crop,
                                         tmax_crop=tmax_crop)
    ts.append(ts_sub)
    sorted_rows.append(
        df.loc[df['subject_id'] == sub_id].sort_values(by='latency'))

# DataFrame.append was removed in pandas 2.0; concatenate once instead.
# ignore_index=True yields the same 0..n-1 index as reset_index(drop=True).
if sorted_rows:
    df_sorted = pd.concat(sorted_rows, ignore_index=True)
else:
    df_sorted = pd.DataFrame(columns=df.columns)
X = np.concatenate(ts, axis=0)
# NOTE(review): assumes each ts entry is (chan, time, trial, category), which
# would make X (category, trial, chan, time) here — TODO confirm.
X = np.transpose(X, (3, 2, 0, 1))
#%%
# The loop slices X along axis 0 (the former last axis after the transpose),
# so the iteration count is X.shape[0]; X.shape[-1] is a different axis.
nstate = X.shape[0]
freqs = np.arange(5., 100., 3.)
time_freq = [tfr_array_morlet(epoch_data, sfreq, freqs, n_cycles=freqs / 2.,
                              output='complex')
             for epoch_data in X]
예제 #14
0
    def get_normalized_cwt_data(self, channels=(7, 9, 11)):
        '''
        Apply a Morlet continuous wavelet transform to the left, right and
        bimanual task epochs and return min-max normalised magnitudes.

        The epochs come from ``self.get_crop_filtered_data_from_fif()``, and
        only the channel indices in ``channels`` are kept. The defaults are
        described as C3, CPz and C4; confirm against the montage. Each task
        is transformed and normalised independently. This preprocessing is
        done up front to reduce the computational load during training time.

        The CWT uses 25 log-spaced frequencies from 5 to 30 Hz,
        n_cycles = freqs / 2, sfreq = 256 and decim = 3. The absolute value
        of the complex coefficients (their magnitude) is used, so the
        returned data are real.

        :param channels: channel indices to keep; ``None`` falls back to
            ``[7, 9, 11]``.
        :return: three ``(normalised_cwt, labels)`` tuples in the order
            left, right, bimanual. ``normalised_cwt`` has the shape
            (trial, freq, time, channel), and every (freq, time, channel)
            feature is scaled to [-1, 1] across trials.
        '''
        if channels is None:
            channels = [7, 9, 11]

        left_filtered, right_filtered, bimanual_filtered = \
            self.get_crop_filtered_data_from_fif()

        freqs = np.logspace(*np.log10([5, 30]), num=25)
        n_cycles = freqs / 2.
        sfreq = 256

        def _normalized_cwt(task_data):
            # Keep only the selected channels: (trial, channel, time).
            selected = task_data[:, channels, :]
            # Morlet CWT per epoch -> complex (trial, channel, freq, time).
            coeffs = time_frequency.tfr_array_morlet(selected,
                                                     sfreq=sfreq,
                                                     freqs=freqs,
                                                     n_cycles=n_cycles,
                                                     use_fft=True,
                                                     decim=3,
                                                     output='complex',
                                                     n_jobs=1)
            magnitude = np.abs(coeffs)
            # Reorder to (trial, freq, time, channel) to feed the GAN models.
            magnitude = np.swapaxes(np.swapaxes(magnitude, 1, 3), 1, 2)
            # Min-max scale each feature to [-1, 1] across trials.
            low = np.min(magnitude, axis=0)
            span = np.max(magnitude, axis=0) - low
            # A feature that is constant across trials has zero span; divide
            # by 1 instead so it maps to -1 rather than NaN (0 / 0).
            span = np.where(span > 0, span, 1)
            return 2 * (magnitude - low) / span - 1

        # Each *_filtered value is a (data, labels) pair.
        return ((_normalized_cwt(left_filtered[0]), left_filtered[1]),
                (_normalized_cwt(right_filtered[0]), right_filtered[1]),
                (_normalized_cwt(bimanual_filtered[0]), bimanual_filtered[1]))