def test_1():
    """Smoke-test Morlet wavelet power on one hard-coded RAM FR1 session (R1060M).

    Reads the WORD events of the first EEG file in the session, reads the
    subject's monopolar channels and bipolar pairs, loads EEG for those events
    (0.0-1.6 s with a 1.0 s buffer) and computes wavelet power at 8
    log-spaced frequencies between 3 and 180 Hz.

    Requires the data files under /Users/m/data and the PTSA reader API used
    below (``time_series=`` keyword; ``filter()`` unpacking into two values).

    Returns
    -------
    The first value returned by ``MorletWaveletFilter.filter()`` (power).
    """
    import time

    import numpy as np
    from ptsa.data.filters import MorletWaveletFilter
    from ptsa.data.readers import BaseEventReader
    from ptsa.data.readers.EEGReader import EEGReader
    from ptsa.data.readers.TalReader import TalReader

    start = time.time()

    e_path = '/Users/m/data/events/RAM_FR1/R1060M_events.mat'
    base_events = BaseEventReader(filename=e_path).read()
    base_events = base_events[base_events.type == 'WORD']
    # selecting only one session
    base_events = base_events[base_events.eegfile == base_events[0].eegfile]

    tal_path = '/Users/m/data/eeg/R1060M/tal/R1060M_talLocs_database_bipol.mat'
    tal_reader = TalReader(filename=tal_path)
    monopolar_channels = tal_reader.get_monopolar_channels()
    bipolar_pairs = tal_reader.get_bipolar_pairs()
    # print() call form so this also parses under Python 3
    print('bipolar_pairs=', bipolar_pairs)

    time_series_reader = EEGReader(events=base_events, start_time=0.0, channels=monopolar_channels,
                                   end_time=1.6, buffer_time=1.0)
    base_eegs = time_series_reader.read()

    wf = MorletWaveletFilter(time_series=base_eegs,
                             freqs=np.logspace(np.log10(3), np.log10(180), 8),
                             output='power',
                             )
    # filter() yields two values here; only the first (power) is used
    pow_wavelet, _ = wf.filter()

    print('total time = ', time.time() - start)
    return pow_wavelet
def compute_phase_at_single_freq(eeg, freq, buffer_len):
    """Return the Morlet wavelet phase of ``eeg`` at one frequency, buffer removed.

    The transform uses width 5 and 12 worker processes. ``buffer_len`` is
    passed straight to ``remove_buffer`` to trim both ends of the result.
    """
    wavelet = MorletWaveletFilter(eeg, np.array([freq]), output='phase',
                                  width=5, cpus=12, verbose=False)
    return wavelet.filter().remove_buffer(buffer_len)
def compute_wavelet_at_single_freq(eeg, freq, buffer_len):
    """Return the squeezed Morlet wavelet phase of ``eeg`` at one frequency.

    Requests ``output=['phase']`` with width 5 on a single CPU, trims
    ``buffer_len`` from each end via ``remove_buffer`` and drops
    length-1 dimensions before returning.

    NOTE(review): this file contains a second definition with this same name
    (which skips the buffer removal); if both live in one module, the later
    one shadows this one.
    """
    wavelet = MorletWaveletFilter(eeg, np.array([freq]), output=['phase'],
                                  width=5, cpus=1, verbose=False)
    trimmed = wavelet.filter().remove_buffer(buffer_len)
    return trimmed.squeeze()
def compute_wavelet_at_single_freq(eeg, freq, buffer_len):
    """Return the squeezed Morlet wavelet phase of ``eeg`` at one frequency.

    Requests ``output=['phase']`` with width 5 on a single CPU and drops
    length-1 dimensions. Buffer removal is disabled in this version, so
    ``buffer_len`` is accepted only for signature compatibility and unused.
    """
    phase = MorletWaveletFilter(eeg, np.array([freq]), output=['phase'],
                                width=5, cpus=1, verbose=False).filter()
    return phase.squeeze()
def test_morlet(self, output_type):
    """Check the result shape of a 3-frequency Morlet filter for ``output_type``.

    A single 'power' or 'phase' output is expected to be (3, 1000); any
    other number of outputs is expected to stack into (2, 3, 1000).
    """
    mwf = MorletWaveletFilter(timeseries=self.timeseries,
                              freqs=np.array([10., 20., 40.]),
                              width=4,
                              output=output_type)
    result = mwf.filter()
    outputs = mwf.output

    if len(outputs) != 1:
        assert result.data.shape == (2, 3, 1000)
    elif 'power' in outputs or 'phase' in outputs:
        assert result.data.shape == (3, 1000)
def _parallel_compute_power(arg_list):
    """
    Build a timeseries of wavelet power from the compute_power() inputs packed
    into the single sequence ``arg_list``. Probably don't need to call this
    directly.

    Unpack order: events, freqs, wave_num, elec_scheme, rel_start_ms,
    rel_stop_ms, buf_ms, noise_freq, resample_freq, mean_over_time, log_power,
    use_mirror_buf, time_bins, eeg.

    When ``eeg`` is None it is loaded with load_eeg(). The buffer is then
    removed from the power and log10 is optionally applied. The result is
    either averaged over all time (``mean_over_time``) or, if ``time_bins``
    is given, averaged within each [start, end] bin.
    """
    (events, freqs, wave_num, elec_scheme, rel_start_ms, rel_stop_ms, buf_ms,
     noise_freq, resample_freq, mean_over_time, log_power, use_mirror_buf,
     time_bins, eeg) = arg_list

    if eeg is None:
        eeg = load_eeg(events, rel_start_ms, rel_stop_ms, buf_ms=buf_ms, elec_scheme=elec_scheme,
                       noise_freq=noise_freq, resample_freq=resample_freq, use_mirror_buf=use_mirror_buf)

    wave_pow = MorletWaveletFilter(eeg, freqs, output='power', width=wave_num, cpus=12,
                                   verbose=False).filter()
    wave_pow = wave_pow.remove_buffer(buf_ms / 1000.)

    if log_power:
        # numexpr looks up the name 'data' in this frame, so the local must be called data
        data = wave_pow.data
        wave_pow.data = numexpr.evaluate('log10(data)')

    if mean_over_time:
        return wave_pow.mean(dim='time')
    if time_bins is None:
        return wave_pow

    # one averaged timeseries per bin, concatenated along a new 'time' dimension
    bin_means = []
    bin_times = []
    for t in time_bins:
        in_bin = (wave_pow.time >= t[0]) & (wave_pow.time <= t[1])
        bin_means.append(wave_pow.isel(time=in_bin).mean(dim='time'))
        bin_times.append(wave_pow.time.data[in_bin].mean())
    binned = xr.concat(bin_means, dim='time')
    binned.coords['time'] = bin_times
    return binned
def test_morlet(self):
    """Power and phase must not depend on the dimension order of the input."""
    def run_filter(ts):
        return MorletWaveletFilter(ts, freqs=self.freqs, width=4,
                                   output=['power', 'phase']).filter()

    results0 = run_filter(self.timeseries)
    results1 = run_filter(self.timeseries.transpose())
    print(results0)
    for kind in ('power', 'phase'):
        xr.testing.assert_allclose(results0.sel(output=kind), results1.sel(output=kind))
def test_non_double(self):
    """Filtering must work on a TimeSeries whose dtype is not double (here int16)."""
    lim = 10000
    n_rows, n_samples = 100, 1000
    data = np.random.uniform(-lim, lim, (n_rows, n_samples)).astype(np.int16)

    coords = {
        "x": np.linspace(0, data.shape[0], data.shape[0]),
        "time": np.arange(data.shape[1]),
        "samplerate": 1,
    }
    ts = timeseries.TimeSeries(data=data, dims=("x", "time"), coords=coords)

    freqs = np.array(range(70, 171, 10))
    MorletWaveletFilter(ts, freqs, output="power").filter()
def test_wavelets_with_event_data_chopper(self):
    """Session-level power chopped into events must match per-event power.

    Power is computed once over the first quarter of the session and chopped
    with DataChopper. It is compared with power computed directly on the
    event EEG. 500 samples are trimmed from each end of both before comparing
    relative differences.
    """
    trim = slice(500, -500)

    def wavelet_power(ts):
        return MorletWaveletFilter(timeseries=ts,
                                   freqs=np.logspace(np.log10(3), np.log10(180), 8),
                                   output='power',
                                   verbose=True).filter()

    quarter = int(self.session_eegs.shape[2] / 4)
    session_pow = wavelet_power(self.session_eegs[:, :, :quarter])

    chopper = DataChopper(events=self.base_events, timeseries=session_pow,
                          start_time=self.start_time, end_time=self.end_time,
                          buffer_time=self.buffer_time)
    chopped_pow = chopper.filter()[:, :, :, trim]

    event_pow = wavelet_power(self.base_eegs)[:, :, :, trim]

    relative_diff = (chopped_pow.data - event_pow.data) / event_pow.data
    assert_array_almost_equal(relative_diff, np.zeros_like(event_pow), decimal=5)
def _parallel_compute_power(arg_list):
    """
    Returns a timeseries object of power values. Accepts the inputs of compute_power() as a single list. Probably
    don't really need to call this directly.

    arg_list is unpacked, in order, as: events, freqs, wave_num, elec_scheme, rel_start_ms, rel_stop_ms, buf_ms,
    noise_freq, resample_freq, mean_over_time, log_power, use_mirror_buf, time_bins, eeg.

    If eeg is None it is loaded with load_eeg(); otherwise the given eeg is used as-is. Morlet wavelet power
    (width=wave_num) is computed, the buf_ms buffer is removed, and log10 is optionally applied. Then either the
    mean over all time is taken (mean_over_time) or, if time_bins is given, a moving average is sampled at the
    end of each bin. time_bins is indexed as a 2-D array of [start, end] rows.
    """
    events, freqs, wave_num, elec_scheme, rel_start_ms, rel_stop_ms, buf_ms, noise_freq, resample_freq, mean_over_time, \
        log_power, use_mirror_buf, time_bins, eeg = arg_list

    # first load eeg
    if eeg is None:
        eeg = load_eeg(events, rel_start_ms, rel_stop_ms, buf_ms=buf_ms, elec_scheme=elec_scheme,
                       noise_freq=noise_freq, resample_freq=resample_freq, use_mirror_buf=use_mirror_buf)

    # then compute power
    wave_pow = MorletWaveletFilter(eeg, freqs, output='power', width=wave_num, cpus=12, verbose=False).filter()

    # remove the buffer (remove_buffer is given seconds, hence ms / 1000)
    wave_pow = wave_pow.remove_buffer(buf_ms / 1000.)

    # are we taking the log?
    if log_power:
        # numexpr resolves the name 'data' from this frame's locals
        data = wave_pow.data
        wave_pow.data = numexpr.evaluate('log10(data)')

    # mean over time if desired
    if mean_over_time:
        wave_pow = wave_pow.mean(dim='time')

    # or take the mean of each time bin, if given
    # create a new timeseries for each bin and the concat and add in new time dimension
    elif time_bins is not None:

        # figure out window size based on sample rate
        # (the window width is taken from the first bin only, so all bins are assumed to have equal width)
        # NOTE(review): the '/ 1000.' implies time_bins are in milliseconds despite the '_s' suffix. If they
        # are in seconds, this window is 1000x too small, and int() can truncate it to 0, which
        # bn.move_mean rejects. Confirm the units of time_bins / wave_pow.time.
        window_size_s = time_bins[0, 1] - time_bins[0, 0]
        window_size = int(window_size_s * wave_pow.samplerate.data / 1000.)

        # compute moving average with window size that we want to average over (in samples)
        # axis=3 assumes wave_pow is 4-D with time as the last axis.
        # bottleneck's moving window is trailing: the value at sample i averages samples [i - window + 1, i],
        # and the first window - 1 samples are NaN.
        pow_move_mean = bn.move_mean(wave_pow.data, window=window_size, axis=3)

        # reduce to just windows that are centered on the times we want
        # searchsorted(...) - 1 picks the last sample strictly before each bin's end time, so each kept value
        # is the mean of the window that ends just before that bin's end.
        wave_pow.data = pow_move_mean
        wave_pow = wave_pow[:, :, :, np.searchsorted(wave_pow.time.data, time_bins[:, 1]) - 1]

        # set the times bins to be the new times bins (ie, the center of the bins)
        wave_pow['time'] = time_bins.mean(axis=1)

        # Earlier loop-based implementation, kept for reference:
        # ts_list = []
        # time_list = []
        # for t in time_bins:
        #     t_inds = (wave_pow.time >= t[0]) & (wave_pow.time <= t[1])
        #     ts_list.append(wave_pow.isel(time=t_inds).mean(dim='time'))
        #     time_list.append(wave_pow.time.data[t_inds].mean())
        # wave_pow = xr.concat(ts_list, dim='time')
        # wave_pow.coords['time'] = time_list
    return wave_pow
def __BOSC_tf(self):
    '''
    Gets the time frequency matrix for events

    This function computes a continuous wavelet (Morlet) transform on the
    events of the BOSC object; this can be used to estimate the background
    spectrum (BOSC_bgfit) or to apply the BOSC method to detect oscillatory
    episodes in signal of interest (BOSC_detect).

    Side effects:
    - ``self.tfm``: one power matrix per usable list.
    - ``self.list_times``: the EEG time stamps of each usable list.
    - ``self.lists``: the list numbers that were kept.
    - ``self.interest_events``: filtered down to the kept lists.

    Raises
    ------
    Exception
        If a list's time window selects no power samples.
    '''
    from ptsa.data.filters import MorletWaveletFilter
    wf = MorletWaveletFilter(timeseries=self.eeg, freqs=self.freqs,
                             width=self.width, output='power')
    # NOTE(review): indexed below as pows[:, time_mask], i.e. axis 1 is
    # assumed to run along self.eeg.time
    pows = wf.filter().data

    # inconsistent event labeling
    start_type, end_type = self.__get_event_keywords()
    list_events = self.events[np.logical_or(self.events.type == start_type,
                                            self.events.type == end_type)]
    # drop leading events until the first start event
    while list_events.type.iloc[0] != start_type:
        list_events = list_events.iloc[1:]
    lists = list_events.list.unique()
    lists = lists[lists > 0]

    self.tfm = []
    self.list_times = []
    self.lists = []
    # eegoffset is in samples; convert to ms (same units as eeg.time)
    samples_to_ms = 1000 / self.sr
    for lst in lists:
        in_list = list_events.list == lst
        start = list_events[(list_events.type == start_type) & in_list].eegoffset.values
        end = list_events[(list_events.type == end_type) & in_list].eegoffset.values
        if (start.size != 1) | (end.size != 1):
            # ex: list has practice or distractor but is interrupted
            # and never starts/ends
            print('Bad start/end events for list {}'.format(lst))
            continue

        # take the single element explicitly: int() on a 1-element array is
        # deprecated in recent NumPy
        start = int(start[0] * samples_to_ms)
        end = int(end[0] * samples_to_ms)
        if start > self.eeg.time.values[-1]:
            print('No corresponding EEG data for list {}'.format(lst))
            continue

        in_window = (self.eeg.time >= start) & (self.eeg.time <= end)
        tfm = pows[:, in_window.values]
        # was `tfm == []`, which is never a valid emptiness test for a numpy
        # array (elementwise/shape-mismatch comparison)
        if tfm.size == 0:
            raise Exception('Empty powers')
        self.tfm.append(tfm)
        self.list_times.append(self.eeg[in_window].time.data)
        # only record successful lists
        self.lists.append(lst)

    self.interest_events = self.interest_events[np.isin(
        self.interest_events.list, self.lists)]
def raw_trace(self, freq_idx, list_idx=0, filtered=False, ax=None):
    """
    Visualize the raw_eeg at a specific frequency, with significant
    oscillations highlighted.

    Parameters:
    freq_idx - index of 1D array in P_episode().freqs attribute
    list_idx - index of list to visualize
    filtered - if True, plot the real part of the complex Morlet transform
               at self.freqs[freq_idx] instead of the raw EEG
    ax - matplotlib axes to draw into; when None a new subplot is created
         and a legend is added
    """
    plot_legend = False
    if ax is None:
        plot_legend = True
        ax = plt.subplot(1, 1, 1)

    # eegoffset is in samples; list_times are compared against it after
    # conversion to ms (and divided by 1000 below for seconds on the x axis)
    samples_to_ms = 1000 / self.sr
    list_start = self.list_times[list_idx][0]
    list_end = self.list_times[list_idx][-1]

    bools = np.logical_and(self.eeg.time >= list_start, self.eeg.time < list_end)
    if filtered:
        complex_mat = MorletWaveletFilter(timeseries=self.eeg, freqs=self.freqs,
                                          width=self.width,
                                          output='complex').filter()
        target_signal = np.real(complex_mat)[freq_idx][bools]
    else:
        target_signal = self.eeg[bools]

    osc = np.copy(target_signal)  # where the oscillations are
    # TODO: incompatible length with self.detected necessitates [:-1]
    osc[np.nonzero(self.detected[list_idx, freq_idx][:-1] == 0)[0]] = np.nan

    # range(...) / float raises TypeError in Python 3; use a numpy array
    time = np.arange(len(target_signal)) / self.sr

    # plot the normal graph
    ax.plot(time, target_signal, 'k', linewidth=.25, label='EEG Time Series')
    # highlight the oscillations
    ax.plot(time, osc, 'r', linewidth=2, label='Detected Oscillations')

    # shade word presentation events
    if self.event_type == 'WORD' and np.any(np.isin(self.events.type, ['WORD_OFF'])):
        local_events = self.events[np.logical_or(
            self.events.type == 'WORD', self.events.type == 'WORD_OFF')]
        local_ms = local_events.eegoffset * samples_to_ms
        local_events = local_events[np.logical_and(local_ms >= list_start,
                                                   local_ms <= list_end)]
        word_on = local_events[local_events.type == 'WORD'].eegoffset * samples_to_ms
        word_off = local_events[local_events.type == 'WORD_OFF'].eegoffset * samples_to_ms
        for start, end in zip(word_on, word_off):
            ax.axvspan((start - list_start) / 1000, (end - list_start) / 1000,
                       alpha=0.2)
    else:
        interest_ms = self.interest_events.eegoffset * samples_to_ms
        local_events = self.interest_events[np.logical_and(
            interest_ms >= list_start, interest_ms <= list_end)]
        for start in local_events.eegoffset.values * samples_to_ms:
            ax.axvspan((start - list_start) / 1000,
                       (start + self.relstop - list_start) / 1000, alpha=0.2)

    ax.set_ylabel(r'Voltage [$\mu V$]')
    ax.set_xlabel('Time [s]')
    ax.set_title('Frequency: {} Hz'.format(round(self.freqs[freq_idx], 2)))
    if plot_legend:
        ax.legend(['EEG Time Series', 'Detected Oscillations', 'Event'])
def power_spectra_from_spike_times(s_times, clust_nums, channel_file, rel_start_ms, rel_stop_ms, freqs,
                                   noise_freq=[58., 62.], downsample_freq=250, mean_over_spikes=True):
    """
    Function to compute power relative to spike times.

    This computes power at given frequencies for the ENTIRE session and then bins it relative to spike times.
    You WILL run out of memory if you don't let it downsample first. Default downsample is to 250 Hz.

    Parameters
    ----------
    s_times: np.ndarray
        Array (or list) of timestamps of when spikes occurred. EEG will be loaded relative to these times.
    clust_nums: np.ndarray
        Array (or list) of cluster IDs, same size as s_times
    channel_file: str
        Path to Ncs file from which to load eeg.
    rel_start_ms: int
        Initial time (in ms), relative to the onset of each spike
    rel_stop_ms: int
        End time (in ms), relative to the onset of each spike
    freqs: np.ndarray
        array of frequencies at which to compute power
    noise_freq: list
        Stop filter will be applied to the given range. Default=[58., 62.]. May also be a list of [low, high]
        ranges, each filtered in turn. None skips line-noise filtering.
    downsample_freq: int or float
        Frequency to downsample the data. Use decimate, so we will likely not reach the exact frequency.
    mean_over_spikes: bool
        After computing the spike x frequency array, do we mean over spikes and return only the mean power spectra

    Returns
    -------
    dict
        Keys are cluster numbers. Values are either a spike x frequency array of log10 power values or, if
        mean_over_spikes, the mean over spikes (one value per frequency). Empty if no spike has a valid epoch.
    """

    # make a df with 'stTime' column for epoching
    events = pd.DataFrame(data=np.stack([s_times, clust_nums], -1), columns=['stTime', 'cluster_num'])

    # load channel data
    signals, timestamps, sr = load_ncs(channel_file)

    # downsample the session
    if downsample_freq is not None:
        signals, timestamps, sr = _my_downsample(signals, timestamps, sr, downsample_freq)
    else:
        print('I HIGHLY recommend you downsample the data before computing power across the whole session...')
        print('You will probably run out of memory.')

    # make into timeseries (timestamps / 1e6 as in the original conversion)
    eeg = TimeSeries.create(signals, samplerate=sr, dims=['time'], coords={'time': timestamps / 1e6})

    # filter line noise
    if noise_freq is not None:
        # a single [low, high] pair is wrapped so we can always loop over ranges. np.isscalar also catches
        # integer bounds such as [58, 62], which an isinstance(..., float) check would misread as two ranges.
        if np.isscalar(noise_freq[0]):
            noise_freq = [noise_freq]
        for this_noise_freq in noise_freq:
            b_filter = ButterworthFilter(eeg, this_noise_freq, filt_type='stop', order=4)
            eeg = b_filter.filter()

    # compute power
    wave_pow = MorletWaveletFilter(eeg, freqs, output='power', width=5, cpus=12, verbose=False).filter()

    # log the power (numexpr resolves the local name 'data')
    data = wave_pow.data
    wave_pow.data = numexpr.evaluate('log10(data)')

    # get start and stop relative to the spikes, dropping epochs that fall outside the recording
    epochs = _compute_epochs(events, rel_start_ms, rel_stop_ms, timestamps, sr)
    bad_epochs = (np.any(epochs < 0, 1)) | (np.any(epochs > len(signals), 1))
    epochs = epochs[~bad_epochs]
    events = events[~bad_epochs].reset_index(drop=True)

    # np.stack below raises on an empty list; with no valid epochs there is nothing to report
    if len(epochs) == 0:
        return {}

    # mean over time within epochs
    spikes_x_freqs = np.stack([np.mean(wave_pow.data[:, x[0]:x[1]], axis=1) for x in epochs])

    # make dict with keys being cluster numbers. Mean over spikes if desired.
    pow_spect_dict = {}
    for this_cluster in events.cluster_num.unique():
        cluster_rows = spikes_x_freqs[events.cluster_num == this_cluster]
        pow_spect_dict[this_cluster] = cluster_rows.mean(axis=0) if mean_over_spikes else cluster_rows
    return pow_spect_dict
def analysis(self):
    """
    Runs the phase synchrony analysis.

    For every pair of ROIs in self.roi_list, the circular phase difference is taken between each electrode
    of the first ROI and each electrode of the second. Circular statistics of those differences are stored
    in self.res[<region pair key>].

    Prints a message and returns early in three cases:
    - the data have not been loaded;
    - no recall filter function is set;
    - any ROI has no electrodes.
    """
    if self.subject_data is None:
        print('%s: compute or load data first with .load_data()!' % self.subject)
        return

    # Get recalled or not labels
    if self.recall_filter_func is None:
        print('%s SME: please provide a .recall_filter_func function.' % self.subject)
        return
    recalled = self.recall_filter_func(self.subject_data)

    # filter to electrodes in ROIs. First get broad electrode region labels
    region_df = self.bin_eloctrodes_into_rois()
    region_df['merged_col'] = region_df['hemi'] + '-' + region_df['region']

    # make sure we have electrodes in each unique region
    for roi in self.roi_list:
        if not region_df.merged_col.isin(roi).any():
            print('{}: no {} electrodes, cannot compute synchrony.'.format(self.subject, roi))
            return

    # then filter into just the ROIs defined above
    all_roi_labels = [label for roi in self.roi_list for label in roi]
    elecs_to_use = region_df.merged_col.isin(all_roi_labels)
    elec_scheme = self.elec_info.copy(deep=True)
    elec_scheme['ROI'] = region_df.merged_col[elecs_to_use]
    elec_scheme = elec_scheme[elecs_to_use].reset_index()

    if self.use_wavelets:
        # NOTE(review): phase_data is indexed as [:, elec] below (axis 1 = channels); confirm the wavelet
        # output keeps that layout for a single frequency.
        phase_data = MorletWaveletFilter(self.subject_data[:, elecs_to_use], self.wavelet_freq,
                                         output='phase', width=5, cpus=12, verbose=False).filter()
    else:
        # band pass eeg
        phase_data = RAM_helpers.band_pass_eeg(self.subject_data[:, elecs_to_use], self.hilbert_band_pass_range)

        # get phase at each timepoint
        phase_data.data = np.angle(hilbert(phase_data.data, N=phase_data.shape[-1], axis=-1))

    # remove the buffer
    phase_data = phase_data.remove_buffer(self.buf_ms / 1000.)

    # (key returned by calc_circ_stats, key stored in self.res), in storage order
    stat_keys = (
        ('elec_pair_pval', 'elec_pair_pvals'),
        ('elec_pair_z', 'elec_pair_zs'),
        ('elec_pair_rvl', 'elec_pair_rvls'),
        ('elec_pair_pval_rec', 'elec_pair_pvals_rec'),
        ('elec_pair_z_rec', 'elec_pair_zs_rec'),
        ('elec_pair_pval_nrec', 'elec_pair_pvals_nrec'),
        ('elec_pair_z_nrec', 'elec_pair_zs_nrec'),
        ('elec_pair_rvl_rec', 'elec_pair_rvls_rec'),
        ('elec_pair_rvl_nrec', 'elec_pair_rvls_nrec'),
    )

    # loop over each pair of ROIs
    for region_pair in combinations(self.roi_list, 2):
        elecs_region_1 = np.where(elec_scheme.ROI.isin(region_pair[0]))[0]
        elecs_region_2 = np.where(elec_scheme.ROI.isin(region_pair[1]))[0]

        elec_label_pairs = []
        stat_values = {res_key: [] for _, res_key in stat_keys}
        delta_mem_rayleigh_zscores = []
        delta_mem_rvl_zscores = []
        elec_pair_phase_diffs = []

        # loop over all pairs of electrodes in the ROIs
        for elec_1 in elecs_region_1:
            for elec_2 in elecs_region_2:
                elec_label_pairs.append([elec_scheme.iloc[elec_1].label, elec_scheme.iloc[elec_2].label])

                # and take the difference in phase values for this electrode pair
                elec_pair_phase_diff = pycircstat.cdiff(phase_data[:, elec_1], phase_data[:, elec_2])
                if self.include_phase_diffs_in_res:
                    elec_pair_phase_diffs.append(elec_pair_phase_diff)

                # compute the circular stats
                elec_pair_stats = calc_circ_stats(elec_pair_phase_diff, recalled, do_perm=False)
                for stat_key, res_key in stat_keys:
                    stat_values[res_key].append(elec_pair_stats[stat_key])

                # compute null distributions for the memory stats
                if self.do_perm_test:
                    delta_mem_rayleigh_zscore, delta_mem_rvl_zscore = self.compute_null_stats(elec_pair_phase_diff,
                                                                                              recalled,
                                                                                              elec_pair_stats)
                    delta_mem_rayleigh_zscores.append(delta_mem_rayleigh_zscore)
                    delta_mem_rvl_zscores.append(delta_mem_rvl_zscore)

        region_pair_key = '+'.join(['-'.join(r) for r in region_pair])
        pair_res = {'elec_label_pairs': elec_label_pairs}
        for res_key, values in stat_values.items():
            pair_res[res_key] = np.stack(values, 0)
        if self.do_perm_test:
            pair_res['delta_mem_rayleigh_zscores'] = np.stack(delta_mem_rayleigh_zscores, 0)
            pair_res['delta_mem_rvl_zscores'] = np.stack(delta_mem_rvl_zscores, 0)
        if self.include_phase_diffs_in_res:
            pair_res['elec_pair_phase_diffs'] = np.stack(elec_pair_phase_diffs, -1)
        pair_res['time'] = phase_data.time.data
        pair_res['recalled'] = recalled
        self.res[region_pair_key] = pair_res
def analysis(self):
    """
    Runs the phase synchrony analysis.

    For every pair of ROIs in self.roi_list, the circular phase difference is taken between each electrode
    of the first ROI and each electrode of the second. Circular statistics of those differences are stored
    in self.res[<region pair key>].

    Prints a message and returns early in three cases:
    - the data have not been loaded;
    - no recall filter function is set;
    - any ROI has no electrodes.
    """
    if self.subject_data is None:
        print('%s: compute or load data first with .load_data()!' % self.subject)
        return

    # Get recalled or not labels
    if self.recall_filter_func is None:
        print('%s SME: please provide a .recall_filter_func function.' % self.subject)
        return
    recalled = self.recall_filter_func(self.subject_data)

    # filter to electrodes in ROIs. First get broad electrode region labels
    region_df = self.bin_eloctrodes_into_rois()
    region_df['merged_col'] = region_df['hemi'] + '-' + region_df['region']

    # make sure we have electrodes in each unique region
    for roi in self.roi_list:
        if not region_df.merged_col.isin(roi).any():
            print('{}: no {} electrodes, cannot compute synchrony.'.format(self.subject, roi))
            return

    # then filter into just the ROIs defined above
    all_roi_labels = [label for roi in self.roi_list for label in roi]
    elecs_to_use = region_df.merged_col.isin(all_roi_labels)
    elec_scheme = self.elec_info.copy(deep=True)
    elec_scheme['ROI'] = region_df.merged_col[elecs_to_use]
    elec_scheme = elec_scheme[elecs_to_use].reset_index()

    if self.use_wavelets:
        # NOTE(review): phase_data is indexed as [:, elec] below (axis 1 = channels); confirm the wavelet
        # output keeps that layout for a single frequency.
        phase_data = MorletWaveletFilter(self.subject_data[:, elecs_to_use], self.wavelet_freq,
                                         output='phase', width=5, cpus=12, verbose=False).filter()
    else:
        # band pass eeg
        phase_data = ecog_helpers.band_pass_eeg(self.subject_data[:, elecs_to_use], self.hilbert_band_pass_range)

        # get phase at each timepoint
        phase_data.data = np.angle(hilbert(phase_data.data, N=phase_data.shape[-1], axis=-1))

    # remove the buffer
    phase_data = phase_data.remove_buffer(self.buf_ms / 1000.)

    # (key returned by calc_circ_stats, key stored in self.res), in storage order
    stat_keys = (
        ('elec_pair_pval', 'elec_pair_pvals'),
        ('elec_pair_z', 'elec_pair_zs'),
        ('elec_pair_rvl', 'elec_pair_rvls'),
        ('elec_pair_pval_rec', 'elec_pair_pvals_rec'),
        ('elec_pair_z_rec', 'elec_pair_zs_rec'),
        ('elec_pair_pval_nrec', 'elec_pair_pvals_nrec'),
        ('elec_pair_z_nrec', 'elec_pair_zs_nrec'),
        ('elec_pair_rvl_rec', 'elec_pair_rvls_rec'),
        ('elec_pair_rvl_nrec', 'elec_pair_rvls_nrec'),
    )

    # loop over each pair of ROIs
    for region_pair in combinations(self.roi_list, 2):
        elecs_region_1 = np.where(elec_scheme.ROI.isin(region_pair[0]))[0]
        elecs_region_2 = np.where(elec_scheme.ROI.isin(region_pair[1]))[0]

        elec_label_pairs = []
        stat_values = {res_key: [] for _, res_key in stat_keys}
        delta_mem_rayleigh_zscores = []
        delta_mem_rvl_zscores = []
        elec_pair_phase_diffs = []

        # loop over all pairs of electrodes in the ROIs
        for elec_1 in elecs_region_1:
            for elec_2 in elecs_region_2:
                elec_label_pairs.append([elec_scheme.iloc[elec_1].label, elec_scheme.iloc[elec_2].label])

                # and take the difference in phase values for this electrode pair
                elec_pair_phase_diff = pycircstat.cdiff(phase_data[:, elec_1], phase_data[:, elec_2])
                if self.include_phase_diffs_in_res:
                    elec_pair_phase_diffs.append(elec_pair_phase_diff)

                # compute the circular stats
                elec_pair_stats = calc_circ_stats(elec_pair_phase_diff, recalled, do_perm=False)
                for stat_key, res_key in stat_keys:
                    stat_values[res_key].append(elec_pair_stats[stat_key])

                # compute null distributions for the memory stats
                if self.do_perm_test:
                    delta_mem_rayleigh_zscore, delta_mem_rvl_zscore = self.compute_null_stats(elec_pair_phase_diff,
                                                                                              recalled,
                                                                                              elec_pair_stats)
                    delta_mem_rayleigh_zscores.append(delta_mem_rayleigh_zscore)
                    delta_mem_rvl_zscores.append(delta_mem_rvl_zscore)

        region_pair_key = '+'.join(['-'.join(r) for r in region_pair])
        pair_res = {'elec_label_pairs': elec_label_pairs}
        for res_key, values in stat_values.items():
            pair_res[res_key] = np.stack(values, 0)
        if self.do_perm_test:
            pair_res['delta_mem_rayleigh_zscores'] = np.stack(delta_mem_rayleigh_zscores, 0)
            pair_res['delta_mem_rvl_zscores'] = np.stack(delta_mem_rvl_zscores, 0)
        if self.include_phase_diffs_in_res:
            pair_res['elec_pair_phase_diffs'] = np.stack(elec_pair_phase_diffs, -1)
        pair_res['time'] = phase_data.time.data
        pair_res['recalled'] = recalled
        self.res[region_pair_key] = pair_res
def power_spectra_from_spike_times(s_times, clust_nums, channel_file, rel_start_ms, rel_stop_ms, freqs,
                                   noise_freq=[58., 62.], downsample_freq=250, mean_over_spikes=True):
    """
    Function to compute power relative to spike times.

    This computes power at given frequencies for the ENTIRE session and then bins it relative to spike times.
    You WILL run out of memory if you don't let it downsample first. Default downsample is to 250 Hz.

    Parameters
    ----------
    s_times: np.ndarray
        Array (or list) of timestamps of when spikes occurred. EEG will be loaded relative to these times.
    clust_nums: np.ndarray
        Array (or list) of cluster IDs, same size as s_times
    channel_file: str
        Path to Ncs file from which to load eeg.
    rel_start_ms: int
        Initial time (in ms), relative to the onset of each spike
    rel_stop_ms: int
        End time (in ms), relative to the onset of each spike
    freqs: np.ndarray
        array of frequencies at which to compute power
    noise_freq: list
        Stop filter will be applied to the given range. Default=[58., 62.]. May also be a list of [low, high]
        ranges, each filtered in turn. None skips line-noise filtering.
    downsample_freq: int or float
        Frequency to downsample the data. Use decimate, so we will likely not reach the exact frequency.
    mean_over_spikes: bool
        After computing the spike x frequency array, do we mean over spikes and return only the mean power spectra

    Returns
    -------
    dict
        Keys are cluster numbers. Values are either a spike x frequency array of log10 power values or, if
        mean_over_spikes, the mean over spikes (one value per frequency). Empty if no spike has a valid epoch.
    """

    # make a df with 'stTime' column for epoching
    events = pd.DataFrame(data=np.stack([s_times, clust_nums], -1), columns=['stTime', 'cluster_num'])

    # load channel data
    signals, timestamps, sr = load_ncs(channel_file)

    # downsample the session
    if downsample_freq is not None:
        signals, timestamps, sr = _my_downsample(signals, timestamps, sr, downsample_freq)
    else:
        print(
            'I HIGHLY recommend you downsample the data before computing power across the whole session...'
        )
        print('You will probably run out of memory.')

    # make into timeseries (timestamps / 1e6 as in the original conversion)
    eeg = TimeSeries.create(signals, samplerate=sr, dims=['time'], coords={'time': timestamps / 1e6})

    # filter line noise
    if noise_freq is not None:
        # a single [low, high] pair is wrapped so we can always loop over ranges. np.isscalar also catches
        # integer bounds such as [58, 62], which an isinstance(..., float) check would misread as two ranges.
        if np.isscalar(noise_freq[0]):
            noise_freq = [noise_freq]
        for this_noise_freq in noise_freq:
            b_filter = ButterworthFilter(eeg, this_noise_freq, filt_type='stop', order=4)
            eeg = b_filter.filter()

    # compute power
    wave_pow = MorletWaveletFilter(eeg, freqs, output='power', width=5, cpus=12, verbose=False).filter()

    # log the power (numexpr resolves the local name 'data')
    data = wave_pow.data
    wave_pow.data = numexpr.evaluate('log10(data)')

    # get start and stop relative to the spikes, dropping epochs that fall outside the recording
    epochs = _compute_epochs(events, rel_start_ms, rel_stop_ms, timestamps, sr)
    bad_epochs = (np.any(epochs < 0, 1)) | (np.any(epochs > len(signals), 1))
    epochs = epochs[~bad_epochs]
    events = events[~bad_epochs].reset_index(drop=True)

    # np.stack below raises on an empty list; with no valid epochs there is nothing to report
    if len(epochs) == 0:
        return {}

    # mean over time within epochs
    spikes_x_freqs = np.stack(
        [np.mean(wave_pow.data[:, x[0]:x[1]], axis=1) for x in epochs])

    # make dict with keys being cluster numbers. Mean over spikes if desired.
    pow_spect_dict = {}
    for this_cluster in events.cluster_num.unique():
        cluster_rows = spikes_x_freqs[events.cluster_num == this_cluster]
        pow_spect_dict[this_cluster] = cluster_rows.mean(axis=0) if mean_over_spikes else cluster_rows
    return pow_spect_dict