def cwt_1d(x, srate, freq_bands,
           freqs=np.logspace(np.log10(3), np.log10(90), 40),
           n_cycles=np.logspace(np.log10(1), np.log10(10), 40),
           scale_type=None, base_tstart=None, base_tend=None):
    """ Compute the Continuous Wavelet Transform for a 1 dimension input array ``x``

    Parameters
    ----------
    x : array
        Input array - must have 1 dimension (singleton dimensions are squeezed)
    srate : int
        Sampling frequency
    freq_bands : array (size: n_freq_bands * 2)
        Frequency bands of interest. Mean value of both power and phase are
        calculated on these bands.
    freqs : array
        Pseudo frequency array to use for computing power and phase
    n_cycles : int | array
        Number of cycles for each wavelet
    scale_type : str | None
        Scaling type. Can be :

        * 'db_ratio' : :math:`P_{norm} = 10 \\cdot log10(\\frac{P}{mean(P_{baseline})})`
        * 'percent_change' : :math:`P_{norm} = 100 \\cdot \\frac{P - mean(P_{baseline})}{mean(P_{baseline})}`
        * 'z_transform' : :math:`P_{norm} = \\frac{P - mean(P_{baseline})}{std(P_{baseline})}`

        If None (default) or empty, no scaling is applied.
    base_tstart : float | None
        Baseline starting point, in seconds. It is converted to a sample index
        as ``int(base_tstart * srate)``.
    base_tend : float | None
        Baseline ending point, in seconds (converted like ``base_tstart``).
        Scaling is only applied when ``scale_type`` is given and both baseline
        bounds are single values.

    Returns
    -------
    coeffs_power : array (size: n_freqs * n_pnts)
        Power extracted from the wavelet coefficients
    coeffs_power_bands : array (size: n_freq_bands * n_pnts)
        Mean of ``coeffs_power`` over the frequency bands defined in ``freq_bands``
    phase : array (size: n_freqs * n_pnts)
        Phase angle
    phase_bands : array (size: n_freq_bands * n_pnts)
        Mean of ``phase`` over the frequency bands defined in ``freq_bands``
    """
    freq_bands = np.atleast_1d(freq_bands)
    # None means "no baseline": map it to an empty array so the size check
    # below skips scaling, exactly as the former ``[]`` defaults did.
    base_tstart = np.atleast_1d([] if base_tstart is None else base_tstart)
    base_tend = np.atleast_1d([] if base_tend is None else base_tend)

    # tfr_array_morlet works on (n_epochs, n_channels, n_times) arrays
    x_3d = np.array([[x.squeeze()]])
    cwt_complex = tfr_array_morlet(x_3d, srate, freqs, n_cycles,
                                   output='complex').squeeze()
    cwt_power, cwt_phase = np.abs(cwt_complex) ** 2, np.angle(cwt_complex)

    # Compute mean over frequency bands
    # NOTE(review): phase angles go through the same helper as power; if
    # compute_band_mean is an arithmetic mean this is not a circular mean.
    cwt_bandpower = compute_band_mean(cwt_power, freqs, freq_bands,
                                      interp_method='linear')
    cwt_bandphase = compute_band_mean(cwt_phase, freqs, freq_bands,
                                      interp_method='linear')

    # Normalization / Scaling
    if scale_type and base_tstart.size == 1 and base_tend.size == 1:
        # .item() extracts the scalar: int() on a 1-element array is
        # deprecated since NumPy 1.25.
        base_start_idx = int(base_tstart.item() * srate)
        base_end_idx = int(base_tend.item() * srate)
        cwt_power = tf_scaling(cwt_power, base_start_idx, base_end_idx,
                               scale_type)
        cwt_bandpower = tf_scaling(cwt_bandpower, base_start_idx,
                                   base_end_idx, scale_type)
    return cwt_power, cwt_bandpower, cwt_phase, cwt_bandphase
def save_wavelet_complex(n_components):
    """Compute PCA-reduced complex Morlet features for all samples and pickle them.

    For each of the 21 samples, the epochs are decomposed with a complex
    Morlet wavelet (15 log-spaced frequencies from 2 to 15 Hz,
    ``n_cycles = freqs / 4``). For each frequency, the real and imaginary
    parts are stacked along the channel axis and reduced with a PCA spatial
    filter of ``n_components`` components.

    Parameters
    ----------
    n_components : int
        Number of PCA components. Also selects the output folder
        ``DataTransformed/wavelet_complex/15hz/pca_<n_components>/``, which
        is created if missing.

    The pickled object is a nested list ``[sample][frequency]`` of 2-D arrays
    with one row per (epoch, time point) and one column per PCA component.
    """
    import os  # local import: only needed to create the output folder

    # Wavelet parameters do not depend on the sample; build them once.
    freqs = np.logspace(*np.log10([2, 15]), num=15)
    n_cycles = freqs / 4.

    all_x_train_samples = []
    for sample in range(1, 22):
        print("sample {}".format(sample))
        epochs = get_epochs(sample, scale=False)
        print("applying morlet wavelet")
        wavelet_output = tfr_array_morlet(epochs.get_data(),
                                          sfreq=epochs.info['sfreq'],
                                          freqs=freqs, n_cycles=n_cycles,
                                          output='complex')
        all_x_train_freqs = []
        for freq_idx, freq in enumerate(freqs):
            print("frequency: {}".format(freq))
            wavelet_epochs = wavelet_output[:, :, freq_idx, :]
            # Real and imaginary parts become separate "channels".
            wavelet_epochs = np.append(wavelet_epochs.real,
                                       wavelet_epochs.imag, axis=1)
            wavelet_info = mne.create_info(ch_names=wavelet_epochs.shape[1],
                                           sfreq=epochs.info['sfreq'],
                                           ch_types='mag')
            wavelet_epochs = mne.EpochsArray(wavelet_epochs, info=wavelet_info,
                                             events=epochs.events)
            pca = UnsupervisedSpatialFilter(PCA(n_components=n_components),
                                            average=False)
            print('fitting pca')
            reduced = pca.fit_transform(wavelet_epochs.get_data())
            print('fitting done')
            # (epochs, components, times) -> (epochs * times, components)
            x_train = reduced.transpose(0, 2, 1).reshape(-1, reduced.shape[1])
            all_x_train_freqs.append(x_train)
        all_x_train_samples.append(all_x_train_freqs)

    print('saving x_train for all samples')
    out_path = ("DataTransformed/wavelet_complex/15hz/pca_{}/"
                "x_train_all_samples.pkl".format(n_components))
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    # The context manager guarantees the file is flushed and closed.
    with open(out_path, "wb") as out_file:
        pickle.dump(all_x_train_samples, out_file)
    print("x_train saved")
def time_frequency(data: pd.DataFrame,
                   freqs: List[float],
                   method: str = 'morlet',
                   output: str = 'avg_power',
                   **kwargs
                   ) -> np.ndarray:
    """Calculates time-frequency representation for each node.

    Parameters
    ----------
    data
        Simulation results, one column per node. If a ``time`` column is
        present it is used as the time index (the caller's DataFrame is not
        modified). The sampling frequency is derived from the spacing of the
        first two index values.
    freqs
        Frequencies of interest.
    method
        Method to be used for TFR calculation. Can be `morlet` for
        `mne.time_frequency.tfr_array_morlet` or `multitaper` for
        `mne.time_frequency.tfr_array_multitaper`.
    output
        Type of the output variable to be calculated. For options, see
        `mne.time_frequency.tfr_array_morlet`.
    kwargs
        Additional keyword arguments to be passed to the function used for
        tfr calculation.

    Returns
    -------
    np.ndarray
        Time-frequency representation (n x f x t) for each node (n) at each
        frequency of interest (f) and time (t).

    Raises
    ------
    ValueError
        If ``method`` is neither ``'morlet'`` nor ``'multitaper'``.
    """
    if method not in ('morlet', 'multitaper'):
        raise ValueError(
            "method must be 'morlet' or 'multitaper', got {!r}".format(method))

    if 'time' in data.columns:
        # set_index returns a new frame, so the caller's data keeps its column.
        data = data.set_index('time')

    sfreq = 1. / (data.index[1] - data.index[0])
    # MNE expects (n_epochs, n_channels, n_times): one "epoch" of all nodes.
    tfr_input = data.values.T[np.newaxis, :, :]

    if method == 'morlet':
        from mne.time_frequency import tfr_array_morlet as tfr_func
    else:
        from mne.time_frequency import tfr_array_multitaper as tfr_func
    return tfr_func(tfr_input, sfreq=sfreq, freqs=freqs, output=output,
                    **kwargs)
def main():
    """Train one LDA classifier per (sample, frequency, time point) and pickle them.

    For each of the 21 samples the epochs are decomposed with a complex
    Morlet wavelet (15 log-spaced frequencies, 2-25 Hz, ``n_cycles =
    freqs / 4``). For every frequency the real and imaginary parts are
    stacked along the channel axis and reduced with the project ``pca``
    helper (80 components). Then an LDA (lsqr solver, automatic shrinkage)
    is fitted on all epochs at each time point. The models are pickled per
    sample (``sample_<n>/all_freq_models.pkl``) and for all samples
    (``all_models.pkl``) under ``Results/lda/wavelet_class/lsqr/complex``.
    """
    import os  # local import: only needed to create the output folders

    model_type = "lda"
    exp_name = "wavelet_class/lsqr/complex"
    save_dir = "Results/{}/{}".format(model_type, exp_name)
    # Wavelet parameters do not depend on the sample; build them once.
    freqs = np.logspace(*np.log10([2, 25]), num=15)
    n_cycles = freqs / 4.

    sample_models = []
    for sample in range(1, 22):
        print("sample {}".format(sample))
        epochs = get_epochs(sample, scale=False)
        print("applying morlet wavelet")
        # returns (n_epochs, n_channels, n_freqs, n_times)
        wavelet_output = tfr_array_morlet(epochs.get_data(),
                                          sfreq=epochs.info['sfreq'],
                                          freqs=freqs, n_cycles=n_cycles,
                                          output='complex')
        y_train = get_y_train(sample)
        freq_models = []
        for freq_idx, freq in enumerate(freqs):
            print("frequency: {}".format(freq))
            wavelet_epochs = wavelet_output[:, :, freq_idx, :]
            # Real and imaginary parts become separate "channels".
            wavelet_epochs = np.append(wavelet_epochs.real,
                                       wavelet_epochs.imag, axis=1)
            wavelet_info = mne.create_info(ch_names=wavelet_epochs.shape[1],
                                           sfreq=epochs.info['sfreq'],
                                           ch_types='mag')
            wavelet_epochs = mne.EpochsArray(wavelet_epochs, info=wavelet_info,
                                             events=epochs.events)
            reduced = pca(80, wavelet_epochs, plot=False)
            # reduced is (epochs, components, times); after the transpose and
            # reshape, row index = epoch * n_times + t.
            n_times = reduced.shape[2]
            x_train = reduced.transpose(0, 2, 1).reshape(-1, reduced.shape[1])
            time_models = []
            for t in range(n_times):
                print("time {}".format(t))
                # every n_times-th row starting at t: all epochs at time t
                intervals = np.arange(start=t, stop=x_train.shape[0],
                                      step=n_times)
                x_sample = x_train[intervals, :]
                y_sample = y_train[intervals]
                model = LinearDiscriminantAnalysis(solver='lsqr',
                                                   shrinkage='auto')
                model.fit(x_sample, y_sample)
                time_models.append(model)
            freq_models.append(time_models)
        sample_models.append(freq_models)

        print('saving models for sample {}'.format(sample))
        sample_dir = "{}/sample_{}".format(save_dir, sample)
        os.makedirs(sample_dir, exist_ok=True)
        with open("{}/all_freq_models.pkl".format(sample_dir), "wb") as out_file:
            pickle.dump(freq_models, out_file)
        print("models saved")

    print('saving models for all samples')
    with open("{}/all_models.pkl".format(save_dir), "wb") as out_file:
        pickle.dump(sample_models, out_file)
    print("models saved")
def tfr_morlet(X, sfreq, freqs, n_cycles, mode):
    '''
    Basic library is mne.
    Use Morlet wavelet (``tfr_array_morlet``) to do time-frequency transform,
    applied separately to each event type.

    :param X: input data array (n_events, n_epochs, n_chans, n_times).
              A 3-D array (n_events, n_epochs, n_times) is also accepted: its
              single channel is duplicated into 2 identical channels to fit
              tfr_array_morlet.
    :param sfreq: sampling frequency
    :param freqs: list or array, the frequencies used in the time-frequency
                  transform
    :param n_cycles: number of cycles in the Morlet wavelet; fixed number or
                     one per frequency
    :param mode: complex, power, phase, avg_power, itc, avg_power_itc
        (1) complex: single trial complex (n_events, n_epochs, n_chans, n_freqs, n_times)
        (2) power: single trial power (n_events, n_epochs, n_chans, n_freqs, n_times)
        (3) phase: single trial phase (n_events, n_epochs, n_chans, n_freqs, n_times)
        (4) avg_power: average of single trial power (n_events, n_chans, n_freqs, n_times)
        (5) itc: inter-trial coherence (n_events, n_chans, n_freqs, n_times)
        (6) avg_power_itc: average of single trial power and inter-trial
            coherence across trials: avg_power + 1j * itc
            (n_events, n_chans, n_freqs, n_times)
    :return: array of the shape listed above; complex dtype for 'complex'
             and 'avg_power_itc', float otherwise
    :raises ValueError: if ``mode`` is not one of the options above
    '''
    single_trial_modes = ('complex', 'power', 'phase')
    averaged_modes = ('avg_power', 'itc', 'avg_power_itc')
    if mode not in single_trial_modes + averaged_modes:
        raise ValueError("mode must be one of {}, got {!r}".format(
            single_trial_modes + averaged_modes, mode))

    if X.ndim < 4:
        # Expand the channel dimension by duplicating the single channel.
        data = np.repeat(np.asarray(X, dtype=float)[:, :, np.newaxis, :],
                         2, axis=2)
    else:
        data = X

    n_events, n_epochs, n_chans, n_times = data.shape
    n_freqs = len(freqs)
    if mode in single_trial_modes:
        out_shape = (n_events, n_epochs, n_chans, n_freqs, n_times)
    else:
        out_shape = (n_events, n_chans, n_freqs, n_times)
    # A float buffer would silently drop the imaginary part of complex
    # outputs ('complex' coefficients, and the ITC in 'avg_power_itc').
    out_dtype = complex if mode in ('complex', 'avg_power_itc') else float
    out = np.zeros(out_shape, dtype=out_dtype)

    for event_idx, event_data in enumerate(data):
        out[event_idx] = tfr_array_morlet(event_data, sfreq=sfreq,
                                          freqs=freqs, n_cycles=n_cycles,
                                          output=mode)
    return out
vmax=vmax, title='Using Morlet wavelets and EpochsTFR', show=False)

###############################################################################
# Operating on arrays
# -------------------
#
# MNE also has versions of the functions above which operate on numpy arrays
# instead of MNE objects. They expect inputs of the shape
# ``(n_epochs, n_channels, n_times)``. They will also return a numpy array
# of shape ``(n_epochs, n_channels, n_freqs, n_times)``.

# With output='avg_power' the epoch axis is averaged away, so ``power`` is
# (n_channels, n_freqs, n_times); ``power[0]`` below is the first channel.
power = tfr_array_morlet(epochs.get_data(), sfreq=epochs.info['sfreq'],
                         freqs=freqs, n_cycles=n_cycles, output='avg_power')
# Baseline the output
# (mode='mean' subtracts the mean of the 0-0.1 s window; copy=False modifies
# ``power`` in place, which is why the return value is not kept)
rescale(power, epochs.times, (0., 0.1), mode='mean', copy=False)

# Time axis converted to ms; vmin/vmax are defined earlier in the script
# (not visible here).
fig, ax = plt.subplots()
mesh = ax.pcolormesh(epochs.times * 1000, freqs, power[0],
                     cmap='RdBu_r', vmin=vmin, vmax=vmax)
ax.set_title('TFR calculated on a numpy array')
ax.set(ylim=freqs[[0, -1]], xlabel='Time (ms)')
fig.colorbar(mesh)
plt.tight_layout()
# Single-trial (average=False) Morlet TFR on the Epochs object, then averaged.
power = tfr_morlet(epochs, freqs=freqs, n_cycles=n_cycles, return_itc=False,
                   average=False)
print(type(power))
avgpower = power.average()
avgpower.plot([0], baseline=(0., 0.1), mode='mean', vmin=vmin, vmax=vmax,
              title='Using Morlet wavelets and EpochsTFR', show=False)

###############################################################################
# Operating on arrays
# -------------------
#
# MNE also has versions of the functions above which operate on numpy arrays
# instead of MNE objects. They expect inputs of the shape
# ``(n_epochs, n_channels, n_times)``. They will also return a numpy array
# of shape ``(n_epochs, n_channels, n_freqs, n_times)``.

# The keyword is ``freqs``: the old ``frequencies`` name is no longer accepted
# by tfr_array_morlet. With output='avg_power' the epoch axis is averaged away,
# so ``power[0]`` below is the (n_freqs, n_times) TFR of the first channel.
power = tfr_array_morlet(epochs.get_data(), sfreq=epochs.info['sfreq'],
                         freqs=freqs, n_cycles=n_cycles, output='avg_power')
# Baseline the output (in place: copy=False)
rescale(power, epochs.times, (0., 0.1), mode='mean', copy=False)
fig, ax = plt.subplots()
mesh = ax.pcolormesh(epochs.times * 1000, freqs, power[0],
                     cmap='RdBu_r', vmin=vmin, vmax=vmax)
ax.set_title('TFR calculated on a numpy array')
ax.set(ylim=freqs[[0, -1]], xlabel='Time (ms)')
fig.colorbar(mesh)
plt.tight_layout()

plt.show()
def TFanalysisMNE(self, sj, cnds, cnd_header, base_period, time_period,
                  method='hilbert', flip=None, base_type='conspec',
                  downsample=1, min_freq=5, max_freq=40, num_frex=25,
                  cycle_range=(3, 12), freq_scaling='log'):
    '''
    Time frequency analysis of one subject using MNE's Morlet wavelets
    (``tfr_array_morlet``). Power is averaged over the trials of each
    condition and stored both raw and log-ratio baseline corrected.

    Arguments
    - - - - -
    sj (int): subject number
    cnds (list): list of conditions as stored in behavior file
    cnd_header (str): key in behavior file that contains condition info
    base_period (tuple | list): time window used for baseline correction
        (passed to ``rescale`` with mode='logratio')
    time_period (tuple | list): time window of interest; selects the time
        points stored in the plot dictionary
    method (str): name of the output sub-folder ``tf/<method>``.
        NOTE: the decomposition here is always Morlet wavelets, whatever the
        value.
    flip (dict): flips a subset of trials. Key of dictionary specifies header
        in beh that contains flip info. List in dict contains variables that
        need to be flipped. Note: flipping is done from right to left
        hemifield
    base_type (str): currently unused by this implementation
    downsample (int): keep every ``downsample``-th time point of
        ``time_period`` in the plot dictionary (the TF matrices themselves are
        not downsampled)
    min_freq (int): minimum frequency for TF analysis
    max_freq (int): maximum frequency for TF analysis
    num_frex (int): number of frequencies in TF analysis
    cycle_range (tuple): number of cycles, log-spaced from cycle_range[0] to
        cycle_range[1] in ``num_frex`` steps
    freq_scaling (str): currently unused; frequencies are always
        logarithmically spaced

    Returns
    - - - -
    None. Writes ``<sj>-tf-mne.pickle`` (dict: condition -> {'power',
    'base_power', 'phase'}) and ``plot_dict.pickle`` to ``tf/<method>``.
    '''
    # read in data
    eegs, beh, times, s_freq, ch_names = self.selectTFData(sj)

    # flip subset of trials (allows for lateralization indices)
    if flip is not None:
        # dict views are not indexable in Python 3 (flip.keys()[0] raises
        # TypeError), so take the first key via an iterator.
        key = next(iter(flip))
        eegs = self.topoFlip(eegs, beh[key], ch_names, left=[flip.get(key)])

    # log-spaced frequencies and wavelet cycle counts (num_frex steps each)
    freqs = np.logspace(np.log10(min_freq), np.log10(max_freq), num_frex)
    nr_cycles = np.logspace(np.log10(cycle_range[0]),
                            np.log10(cycle_range[1]), num_frex)

    # time points inside time_period, keeping every `downsample`-th one
    idx_time = np.where((times >= time_period[0]) &
                        (times <= time_period[1]))[0]
    idx_2_save = idx_time[::downsample]

    tf = {}
    for cnd in cnds:
        # trials of this condition, taken from the cnd_header column
        cnd_idx = np.where(beh[cnd_header] == cnd)[0]
        power = tfr_array_morlet(eegs[cnd_idx], sfreq=s_freq, freqs=freqs,
                                 n_cycles=nr_cycles, output='avg_power')
        # avg_power is (n_chans, n_freqs, n_times); stored as
        # (n_freqs, n_chans, n_times)
        power = np.swapaxes(power, 0, 1)
        tf[cnd] = {
            'power': power,
            # rescale returns a new array (copy=True by default)
            'base_power': rescale(power, times, base_period, mode='logratio'),
            'phase': '?',  # placeholder: phase is not computed here
        }

    # save TF matrices
    with open(self.FolderTracker(['tf', method],
                                 '{}-tf-mne.pickle'.format(sj)),
              'wb') as handle:
        pickle.dump(tf, handle)

    # store dictionary with variables for plotting
    plot_dict = {
        'ch_names': ch_names,
        'times': times[idx_2_save],
        'frex': freqs
    }
    with open(self.FolderTracker(['tf', method],
                                 filename='plot_dict.pickle'),
              'wb') as handle:
        pickle.dump(plot_dict, handle)
def main():
    """Frequency-generalization decoding for every sample and time point.

    For each of the 21 samples the epochs are decomposed with a complex
    Morlet wavelet (15 log-spaced frequencies, 2-25 Hz). At each time point
    the real and imaginary parts are stacked along the channel axis, reduced
    with the project ``pca`` helper (80 components), and a
    ``GeneralizingEstimator`` (LDA, 5-fold CV) is trained at one frequency
    and tested at every other. Per-time bar plots, generalization matrices
    and a CSV of all scores are written to
    ``Results/lda/freq_gen_matrix/sample_<n>/``.
    """
    model_type = "lda"
    exp_name = "freq_gen_matrix/"
    # Wavelet parameters do not depend on the sample; build them once.
    freqs = np.logspace(*np.log10([2, 25]), num=15)
    n_cycles = freqs / 4.
    string_freqs = [round(x, 2) for x in freqs]

    for sample in range(1, 22):
        print("sample {}".format(sample))
        sample_dir = "Results/{}/{}/sample_{}".format(model_type, exp_name,
                                                     sample)
        # makedirs also creates missing parents (os.mkdir would fail on a
        # fresh Results tree) and tolerates an existing folder.
        os.makedirs(sample_dir, exist_ok=True)

        epochs = get_epochs(sample, scale=False)
        y_train = epochs.events[:, 2]
        print("applying morlet wavelet")
        # returns (n_epochs, n_channels, n_freqs, n_times)
        wavelet_output = tfr_array_morlet(epochs.get_data(),
                                          sfreq=epochs.info['sfreq'],
                                          freqs=freqs, n_cycles=n_cycles,
                                          output='complex')
        n_times = wavelet_output.shape[3]
        time_results = np.zeros((n_times, len(freqs), len(freqs)))
        for t in range(n_times):
            print("time: {}".format(t))
            # At a fixed time point the frequency axis takes the place of the
            # "time" axis, so the estimator generalizes across frequencies.
            wavelet_epochs = wavelet_output[:, :, :, t]
            wavelet_epochs = np.append(wavelet_epochs.real,
                                       wavelet_epochs.imag, axis=1)
            wavelet_info = mne.create_info(ch_names=wavelet_epochs.shape[1],
                                           sfreq=epochs.info['sfreq'],
                                           ch_types='mag')
            wavelet_epochs = mne.EpochsArray(wavelet_epochs, info=wavelet_info,
                                             events=epochs.events)
            x_train = pca(80, wavelet_epochs, plot=False)

            model = LinearModel(
                LinearDiscriminantAnalysis(solver='lsqr', shrinkage='auto'))
            freq_gen = GeneralizingEstimator(model, n_jobs=1,
                                             scoring='accuracy', verbose=True)
            scores = cross_val_multiscore(freq_gen, x_train, y_train, cv=5,
                                          n_jobs=1)
            # average over CV folds -> (train freq, test freq)
            scores = np.mean(scores, axis=0)
            time_results[t] = scores

            sns.set()
            # seaborn >= 0.12 only accepts x/y as keywords
            ax = sns.barplot(x=np.sort(string_freqs), y=np.diag(scores))
            ax.set(ylim=(0, 0.8),
                   xlabel='Frequencies',
                   ylabel='Accuracy',
                   title='Cross Val Accuracy {} for Subject {} for Time {}'.
                   format(model_type, sample, t))
            # NOTE(review): 0.12 looks like a chance-level line; confirm
            ax.axhline(0.12, color='k', linestyle='--')
            ax.figure.set_size_inches(8, 6)
            ax.figure.savefig("{}/time_{}_accuracy.png".format(sample_dir, t),
                              dpi=300)
            plt.close('all')

            fig, ax = plt.subplots(1, 1)
            im = ax.imshow(scores,
                           interpolation='lanczos',
                           origin='lower',
                           cmap='RdBu_r',
                           extent=[2, 25, 2, 25],
                           vmin=0.,
                           vmax=0.8)
            ax.set_xlabel('Testing Frequency (hz)')
            ax.set_ylabel('Training Frequency (hz)')
            ax.set_title(
                'Frequency generalization for Subject {} at Time {}'.format(
                    sample, t))
            plt.colorbar(im, ax=ax)
            ax.grid(False)
            ax.figure.savefig("{}/time_{}_matrix.png".format(sample_dir, t),
                              dpi=300)
            plt.close('all')

        # one row per time point, flattened (train freq x test freq) scores
        time_results = time_results.reshape(time_results.shape[0], -1)
        all_results_df = pd.DataFrame(time_results)
        all_results_df.to_csv(
            "{}/all_time_matrix_results.csv".format(sample_dir))
method, pick_ori=None)
# For each selected label: extract one time course per source estimate,
# compute its Morlet time-frequency representation and save both arrays
# to tf_folder.
for label in labels_sel:
    label_ts = []
    for j in range(len(stcs)):
        # mode="mean_flip": mean over the label's sources with MNE's
        # sign-flip correction
        ts = mne.extract_label_time_course(
            stcs[j], labels=label, src=src, mode="mean_flip")
        # presumably (1, n_times) -> (n_times,) for a single label; verify
        ts = np.squeeze(ts)
        # ts *= np.sign(ts[np.argmax(np.abs(ts))])
        label_ts.append(ts)

    label_ts = np.asarray(label_ts)
    # add a singleton channel axis: tfr_array_morlet expects
    # (n_epochs, n_channels, n_times)
    label_ts = label_ts[:, np.newaxis, :]

    # NOTE(review): no `output` is passed, so tfr_array_morlet's default
    # applies (complex coefficients in MNE) — confirm this is intended.
    tfr = tfr_array_morlet(
        label_ts, epochs.info["sfreq"], freqs,  # use_fft=True,
        n_cycles=n_cycle)

    np.save(tf_folder + "%s_%s_%s_%s_%s_pca_snr_3-tfr" %
            (subject, condition[:3], condition[4:], label.name, method), tfr)
    np.save(tf_folder + "%s_%s_%s_%s_%s_pca_snr_3-ts" %
            (subject, condition[:3], condition[4:], label.name, method),
            label_ts)

# free the large arrays before the next iteration
# NOTE(review): `del tfr` raises NameError if labels_sel is empty.
del stcs
del tfr
import sys

import mne
import numpy as np
from mne.time_frequency import tfr_array_morlet

from my_settings import epochs_folder, tf_folder

# Usage: python <script> <subject>
subject = sys.argv[1]

freqs = np.arange(8, 13, 1)  # define frequencies of interest
n_cycles = 4.  # fixed number of cycles (alternative: freqs / 2., one per frequency)

sides = ["left", "right"]
conditions = ["ctl", "ent"]

epochs = mne.read_epochs(
    epochs_folder + "%s_trial_start-epo.fif" % subject, preload=True)
epochs.resample(250)

for cond in conditions:
    for side in sides:
        # tfr_array_morlet works on ndarrays (n_epochs, n_channels, n_times),
        # not on Epochs objects, and takes the `freqs` keyword.
        tfr_complex = tfr_array_morlet(
            epochs[cond + "/" + side].get_data(),
            sfreq=epochs.info["sfreq"],
            freqs=freqs,
            n_cycles=n_cycles,
            use_fft=True,
            output="complex",
            n_jobs=1)
        np.save(tf_folder + "%s_%s_%s-4-complex-tfr.npy" % (subject, cond, side),
                tfr_complex)
def main():
    """Per-frequency decoding accuracy over time from Morlet wavelet features.

    The wavelet feature is selected by the second-to-last component of
    ``exp_name``:

    * ``real``    - real part of the complex coefficients
    * ``complex`` - real and imaginary parts stacked along the channel axis
    * ``power`` / ``phase`` - the corresponding ``tfr_array_morlet`` output

    For each of the 21 samples and each of the 15 frequencies (2-15 Hz,
    log-spaced) the features are reduced with the project ``pca`` helper
    (80 components) and scored with ``linear_models``. A per-frequency
    accuracy plot and a CSV of all results are written under
    ``Results/<model_type>/<exp_name>/sample_<n>/``.

    Raises
    ------
    ValueError
        If the feature name in ``exp_name`` is not one of the above.
    """
    model_type = "lda"
    exp_name = "wavelet_class/lsqr/complex/15hz"

    # Parse and validate the feature once, before any heavy work.
    feature = exp_name.split("/")[-2]
    if feature in ("real", "complex"):
        tfr_output = 'complex'  # both are derived from complex coefficients
    elif feature in ("power", "phase"):
        tfr_output = feature
    else:
        raise ValueError("{} not an output of wavelet function".format(feature))

    # Wavelet parameters do not depend on the sample; build them once.
    freqs = np.logspace(*np.log10([2, 15]), num=15)
    n_cycles = freqs / 4.

    for sample in range(1, 22):
        print("sample {}".format(sample))
        sample_dir = "Results/{}/{}/sample_{}".format(model_type, exp_name,
                                                     sample)
        # makedirs also creates missing parents and tolerates an existing dir
        os.makedirs(sample_dir, exist_ok=True)
        epochs = get_epochs(sample, scale=False)
        print("applying morlet wavelet")
        # returns (n_epochs, n_channels, n_freqs, n_times)
        wavelet_output = tfr_array_morlet(epochs.get_data(),
                                          sfreq=epochs.info['sfreq'],
                                          freqs=freqs, n_cycles=n_cycles,
                                          output=tfr_output)
        y_train = get_y_train(sample)
        # NOTE: assumes linear_models returns 50 scores (one per time point)
        freq_results = np.zeros((len(freqs), 50))
        for freq_idx, freq in enumerate(freqs):
            print("frequency: {}".format(freq))
            wavelet_epochs = wavelet_output[:, :, freq_idx, :]
            if feature == "real":
                wavelet_epochs = wavelet_epochs.real
            elif feature == "complex":
                # real and imaginary parts become separate "channels"
                wavelet_epochs = np.append(wavelet_epochs.real,
                                           wavelet_epochs.imag, axis=1)
            wavelet_info = mne.create_info(ch_names=wavelet_epochs.shape[1],
                                           sfreq=epochs.info['sfreq'],
                                           ch_types='mag')
            wavelet_epochs = mne.EpochsArray(wavelet_epochs, info=wavelet_info,
                                             events=epochs.events)
            reduced = pca(80, wavelet_epochs, plot=False)
            x_train = reduced.transpose(0, 2, 1).reshape(-1, reduced.shape[1])
            results = linear_models(x_train, y_train, model_type=model_type)
            freq_results[freq_idx] = results
            curr_freq = str(round(freq, 2))

            sns.set()
            ax = sns.lineplot(data=results, dashes=False)
            ax.set(ylim=(0, 1),
                   xlabel='Time',
                   ylabel='Accuracy',
                   title='Cross Val Accuracy {} for Subject {} for Freq {}'.
                   format(model_type, sample, curr_freq))
            plt.axvline(x=15, color='b', linestyle='--')
            ax.figure.savefig("{}/freq_{}.png".format(sample_dir, curr_freq),
                              dpi=300)
            plt.clf()

        all_results_df = pd.DataFrame(freq_results)
        all_results_df.to_csv("{}/all_freq_results.csv".format(sample_dir))
# Analysis parameters. fmax and bandwidth are not used in this section;
# presumably they are consumed further down the script (TODO confirm).
fmax = 150
bandwidth = 3
tmin_crop = -0.5
tmax_crop = 1.75

# Table of visually responsive channels; rows are re-ordered by latency
# within each subject.
fname = 'visual_channels_BP_montage.csv'
ieeg_path = cf.cifar_ieeg_path(home='~')
visual_chan_table_path = ieeg_path.joinpath(fname)
df = pd.read_csv(visual_chan_table_path)

ts = [0] * len(subjects)
sorted_tables = []
for s, sub_id in enumerate(subjects):
    subject = cf.Subject(name=sub_id)
    datadir = subject.processing_stage_path(proc=proc)
    visual_chan = subject.pick_visual_chan()
    HFB = hf.visually_responsive_HFB(sub_id=sub_id)
    ts[s], time = fun.ts_all_categories(HFB, sfreq=sfreq, tmin_crop=tmin_crop,
                                        tmax_crop=tmax_crop)
    sorted_tables.append(
        df.loc[df['subject_id'] == sub_id].sort_values(by='latency'))

# DataFrame.append was removed in pandas 2.0: concatenate once instead.
# ignore_index=True reproduces the former reset_index(drop=True).
if sorted_tables:
    df_sorted = pd.concat(sorted_tables, ignore_index=True)
else:
    df_sorted = pd.DataFrame(columns=df.columns)

X = np.concatenate(ts, axis=0)
X = np.transpose(X, (3, 2, 0, 1))
#%%
# X[i, ...] selects along axis 0, so the number of states is X.shape[0];
# bounding the loop with X.shape[-1] would over- or under-run that axis.
nstate = X.shape[0]
freqs = np.arange(5., 100., 3.)
time_freq = [tfr_array_morlet(X[i, ...], sfreq, freqs, n_cycles=freqs / 2.,
                              output='complex')
             for i in range(nstate)]
def get_normalized_cwt_data(self, channels=(7, 9, 11)):
    '''
    Apply a Morlet Continuous Wavelet Transform for feature extraction to the
    left, right and bimanual task epochs of a channel subset (e.g. C3, CPz,
    C4). Take the magnitude, min-max scale every feature to [-1, 1] across
    trials, and return each task's data with its labels. This is done to
    reduce the computational load during training time.

    The data come from ``self.get_crop_filtered_data_from_fif()``. Each of the
    three returned items holds the epoch data at index 0 and the labels at
    index 1. The CWT uses 25 log-spaced frequencies from 5 to 30 Hz,
    ``n_cycles = freqs / 2``, a sampling rate of 256 Hz and ``decim=3``.

    NOTE: the absolute value is taken to remove imaginary components.

    :param channels: channel indices to keep (default ``(7, 9, 11)``;
        ``None`` selects the same default)
    :return: ``((Norm_Left_MEpoch, left_label), (Norm_Right_MEpoch,
        right_label), (Norm_Biman_MEpoch, bimanual_label))``. Each data array
        is shaped (Trial, Freq Sample, Time Sample, Channel). Features that
        are constant across trials are set to 0 instead of NaN.
    '''
    if channels is None:
        channels = [7, 9, 11]
    # a list is an unambiguous fancy index on the channel axis
    channels = list(channels)

    freqs = np.logspace(*np.log10([5, 30]), num=25)
    n_cycles = freqs / 2.
    sfreq = 256

    def _cwt_magnitude(epoch_data):
        # Morlet CWT magnitude, reordered from (trial, chan, freq, time) to
        # (trial, freq, time, chan) to feed the GAN models later.
        coeffs = time_frequency.tfr_array_morlet(epoch_data, sfreq=sfreq,
                                                 freqs=freqs,
                                                 n_cycles=n_cycles,
                                                 use_fft=True, decim=3,
                                                 output='complex', n_jobs=1)
        return np.transpose(np.abs(coeffs), (0, 2, 3, 1))

    def _scale_to_unit_range(arr):
        # Min-max scale each feature across trials (axis 0) into [-1, 1].
        # Zero-range features would divide by zero, so they are mapped to 0.
        lo = np.min(arr, axis=0)
        span = np.max(arr, axis=0) - lo
        safe_span = np.where(span == 0, 1, span)
        return np.where(span == 0, 0.0, 2 * (arr - lo) / safe_span - 1)

    left_filtered, right_filtered, bimanual_filtered = \
        self.get_crop_filtered_data_from_fif()

    results = []
    for task in (left_filtered, right_filtered, bimanual_filtered):
        task_data, task_labels = task[0], task[1]
        features = _cwt_magnitude(task_data[:, channels, :])
        results.append((_scale_to_unit_range(features), task_labels))
    return tuple(results)