コード例 #1
0
def test_1():
    """Smoke test: load FR1 'WORD' events for one session, read the EEG and
    compute Morlet wavelet power at 8 log-spaced frequencies (3-180 Hz).

    Returns the wavelet power timeseries. Relies on hard-coded local data
    paths, so it only runs on a machine with the RAM dataset available.
    """
    import time
    import numpy as np  # was referenced below without being imported
    start = time.time()

    e_path = '/Users/m/data/events/RAM_FR1/R1060M_events.mat'

    from ptsa.data.readers import BaseEventReader

    base_e_reader = BaseEventReader(filename=e_path)

    base_events = base_e_reader.read()

    base_events = base_events[base_events.type == 'WORD']

    # selecting only one session (the one the first event belongs to)
    base_events = base_events[base_events.eegfile == base_events[0].eegfile]

    from ptsa.data.readers.TalReader import TalReader
    tal_path = '/Users/m/data/eeg/R1060M/tal/R1060M_talLocs_database_bipol.mat'
    tal_reader = TalReader(filename=tal_path)
    monopolar_channels = tal_reader.get_monopolar_channels()
    bipolar_pairs = tal_reader.get_bipolar_pairs()

    # Python 3 print() — the original used Python 2 print statements, which
    # are syntax errors under Python 3 and inconsistent with the rest of
    # this file.
    print('bipolar_pairs=', bipolar_pairs)

    from ptsa.data.readers.EEGReader import EEGReader

    time_series_reader = EEGReader(events=base_events, start_time=0.0,
                                   channels=monopolar_channels,
                                   end_time=1.6, buffer_time=1.0)

    base_eegs = time_series_reader.read()

    from ptsa.data.filters import MorletWaveletFilter
    wf = MorletWaveletFilter(time_series=base_eegs,
                             freqs=np.logspace(np.log10(3), np.log10(180), 8),
                             output='power',
                             )

    # NOTE(review): output='power' yet the result is unpacked into a
    # (power, phase) pair — confirm this ptsa version returns a tuple;
    # newer versions return a single timeseries.
    pow_wavelet, phase_wavelet = wf.filter()

    print('total time = ', time.time() - start)

    res_start = time.time()

    print('resample_time=', time.time() - res_start)
    return pow_wavelet
コード例 #2
0
def compute_phase_at_single_freq(eeg, freq, buffer_len):
    """Return the Morlet wavelet phase of *eeg* at one frequency.

    The leading/trailing buffer of length *buffer_len* is stripped from
    the filtered output before it is returned.
    """
    wavelet = MorletWaveletFilter(eeg,
                                  np.array([freq]),
                                  output='phase',
                                  width=5,
                                  cpus=12,
                                  verbose=False)
    # Strip the edge buffer so wavelet edge artifacts are excluded.
    return wavelet.filter().remove_buffer(buffer_len)
コード例 #3
0
def compute_phase_at_single_freq(eeg, freq, buffer_len):
    """Compute single-frequency wavelet phase for *eeg*, buffer removed."""
    freq_array = np.array([freq])
    phases = MorletWaveletFilter(eeg, freq_array, output='phase',
                                 width=5, cpus=12, verbose=False).filter()
    # Drop buffer_len from each end to discard filter edge effects.
    phases = phases.remove_buffer(buffer_len)
    return phases
def compute_wavelet_at_single_freq(eeg, freq, buffer_len):
    """Return squeezed single-frequency wavelet phase with buffer removed."""
    result = MorletWaveletFilter(eeg,
                                 np.array([freq]),
                                 output=['phase'],
                                 width=5,
                                 cpus=1,
                                 verbose=False).filter()
    trimmed = result.remove_buffer(buffer_len)
    # Squeeze out the singleton frequency dimension.
    return trimmed.squeeze()
def compute_wavelet_at_single_freq(eeg, freq, buffer_len):
    """Compute the Morlet wavelet transform (phase output) of *eeg* at a
    single frequency and return it with singleton dimensions squeezed out.

    NOTE(review): *buffer_len* is currently unused because the
    remove_buffer() call below is commented out — confirm whether the
    buffer should be stripped here (the sibling variant of this function
    does strip it).
    """

    # compute phase
    data = MorletWaveletFilter(eeg,
                               np.array([freq]),
                               output=['phase'],
                               width=5,
                               cpus=1,
                               verbose=False).filter()

    # remove the buffer from each end
    #     data = data.remove_buffer(buffer_len)
    return data.squeeze()
コード例 #6
0
    def test_morlet(self, output_type):
        """Filtering three frequencies yields shape (3, 1000) for a single
        output type, or (2, 3, 1000) when both power and phase are requested."""
        mwf = MorletWaveletFilter(timeseries=self.timeseries,
                                  freqs=np.array([10., 20., 40.]),
                                  width=4,
                                  output=output_type)
        result = mwf.filter()

        if len(mwf.output) != 1:
            # Power and phase are stacked along a leading output axis.
            assert result.data.shape == (2, 3, 1000)
        else:
            for kind in ('power', 'phase'):
                if kind in mwf.output:
                    assert result.data.shape == (3, 1000)
コード例 #7
0
def _parallel_compute_power(arg_list):
    """
    Returns a timeseries object of power values. Accepts the inputs of compute_power() as a single list. Probably
    don't really need to call this directly.
    """

    # Single-argument unpacking keeps this usable with multiprocessing map.
    events, freqs, wave_num, elec_scheme, rel_start_ms, rel_stop_ms, buf_ms, noise_freq, resample_freq, mean_over_time, \
    log_power, use_mirror_buf, time_bins, eeg = arg_list

    # first load eeg (only when the caller did not pass pre-loaded data)
    if eeg is None:
        eeg = load_eeg(events,
                       rel_start_ms,
                       rel_stop_ms,
                       buf_ms=buf_ms,
                       elec_scheme=elec_scheme,
                       noise_freq=noise_freq,
                       resample_freq=resample_freq,
                       use_mirror_buf=use_mirror_buf)

    # then compute power
    wave_pow = MorletWaveletFilter(eeg,
                                   freqs,
                                   output='power',
                                   width=wave_num,
                                   cpus=12,
                                   verbose=False).filter()

    # remove the buffer (buf_ms is in ms; remove_buffer takes seconds)
    wave_pow = wave_pow.remove_buffer(buf_ms / 1000.)

    # are we taking the log? numexpr evaluates log10 in place over the array
    if log_power:
        data = wave_pow.data
        wave_pow.data = numexpr.evaluate('log10(data)')

    # mean over time if desired
    if mean_over_time:
        wave_pow = wave_pow.mean(dim='time')

    # or take the mean of each time bin, if given
    # create a new timeseries for each bin and the concat and add in new time dimension
    elif time_bins is not None:
        ts_list = []
        time_list = []
        for t in time_bins:
            # inclusive on both edges, so adjacent bins can share a sample
            t_inds = (wave_pow.time >= t[0]) & (wave_pow.time <= t[1])
            ts_list.append(wave_pow.isel(time=t_inds).mean(dim='time'))
            time_list.append(wave_pow.time.data[t_inds].mean())
        wave_pow = xr.concat(ts_list, dim='time')
        # label each bin with the mean of the timepoints it covered
        wave_pow.coords['time'] = time_list

    return wave_pow
コード例 #8
0
    def test_morlet(self):
        """Transposing the input timeseries must not change wavelet
        power or phase results."""
        kwargs = dict(freqs=self.freqs, width=4, output=['power', 'phase'])
        results0 = MorletWaveletFilter(self.timeseries, **kwargs).filter()
        results1 = MorletWaveletFilter(self.timeseries.transpose(),
                                       **kwargs).filter()
        print(results0)

        # Both output types must agree between the two orientations.
        for out in ('power', 'phase'):
            xr.testing.assert_allclose(results0.sel(output=out),
                                       results1.sel(output=out))
コード例 #9
0
    def test_non_double(self):
        """Test that we can use a TimeSeries that starts out as a dtype other
        than double.

        """
        lim = 10000
        raw = np.random.uniform(-lim, lim, (100, 1000)).astype(np.int16)

        # Build coordinates from the data's own shape.
        coords = {
            "x": np.linspace(0, raw.shape[0], raw.shape[0]),
            "time": np.arange(raw.shape[1]),
            "samplerate": 1,
        }
        ts = timeseries.TimeSeries(data=raw,
                                   dims=("x", "time"),
                                   coords=coords)

        # The filter should run without complaining about the int16 input.
        freqs = np.array(range(70, 171, 10))
        MorletWaveletFilter(ts, freqs, output="power").filter()
コード例 #10
0
    def test_wavelets_with_event_data_chopper(self):
        """Wavelet power computed on a continuous session and then chopped
        into events must match power computed on pre-chopped event EEG."""
        # Wavelet-transform the first quarter of the continuous session EEG.
        wf_session = MorletWaveletFilter(
            timeseries=self.
            session_eegs[:, :, :int(self.session_eegs.shape[2] / 4)],
            freqs=np.logspace(np.log10(3), np.log10(180), 8),
            output='power',
            verbose=True)

        pow_wavelet_session = wf_session.filter()

        # Chop the session-level power into per-event epochs.
        sedc = DataChopper(events=self.base_events,
                           timeseries=pow_wavelet_session,
                           start_time=self.start_time,
                           end_time=self.end_time,
                           buffer_time=self.buffer_time)
        chopped_session_pow_wavelet = sedc.filter()

        # removing buffer
        # NOTE(review): the 500-sample trim presumably equals buffer_time at
        # the recording samplerate — confirm against the fixture.
        chopped_session_pow_wavelet = chopped_session_pow_wavelet[:, :, :,
                                                                  500:-500]

        # Same transform, this time on the already-epoched event EEG.
        wf = MorletWaveletFilter(timeseries=self.base_eegs,
                                 freqs=np.logspace(np.log10(3), np.log10(180),
                                                   8),
                                 output='power',
                                 verbose=True)

        pow_wavelet = wf.filter()

        pow_wavelet = pow_wavelet[:, :, :, 500:-500]

        # Compare relative (not absolute) differences to 5 decimal places.
        assert_array_almost_equal(
            (chopped_session_pow_wavelet.data - pow_wavelet.data) /
            pow_wavelet.data,
            np.zeros_like(pow_wavelet),
            decimal=5)
コード例 #11
0
def _parallel_compute_power(arg_list):
    """
    Returns a timeseries object of power values. Accepts the inputs of compute_power() as a single list. Probably
    don't really need to call this directly.
    """

    # Single-argument unpacking keeps this usable with multiprocessing map.
    events, freqs, wave_num, elec_scheme, rel_start_ms, rel_stop_ms, buf_ms, noise_freq, resample_freq, mean_over_time, \
    log_power, use_mirror_buf, time_bins, eeg = arg_list

    # first load eeg (only when the caller did not pass pre-loaded data)
    if eeg is None:
        eeg = load_eeg(events,
                       rel_start_ms,
                       rel_stop_ms,
                       buf_ms=buf_ms,
                       elec_scheme=elec_scheme,
                       noise_freq=noise_freq,
                       resample_freq=resample_freq,
                       use_mirror_buf=use_mirror_buf)

    # then compute power
    wave_pow = MorletWaveletFilter(eeg,
                                   freqs,
                                   output='power',
                                   width=wave_num,
                                   cpus=12,
                                   verbose=False).filter()

    # remove the buffer (buf_ms is in ms; remove_buffer takes seconds)
    wave_pow = wave_pow.remove_buffer(buf_ms / 1000.)

    # are we taking the log? numexpr evaluates log10 in place over the array
    if log_power:
        data = wave_pow.data
        wave_pow.data = numexpr.evaluate('log10(data)')

    # mean over time if desired
    if mean_over_time:
        wave_pow = wave_pow.mean(dim='time')

    # or take the mean of each time bin, if given
    # create a new timeseries for each bin and the concat and add in new time dimension
    elif time_bins is not None:

        # figure out window size based on sample rate
        # NOTE(review): only time_bins[0] is inspected, so all bins are
        # assumed to share the same width — confirm with callers.
        window_size_s = time_bins[0, 1] - time_bins[0, 0]
        window_size = int(window_size_s * wave_pow.samplerate.data / 1000.)

        # compute moving average with window size that we want to average over (in samples)
        # NOTE(review): axis=3 assumes time is the 4th dimension of the data.
        pow_move_mean = bn.move_mean(wave_pow.data, window=window_size, axis=3)

        # reduce to just windows that are centered on the times we want
        # (move_mean windows end at each index, so pick the sample just
        # before each bin's right edge)
        wave_pow.data = pow_move_mean
        wave_pow = wave_pow[:, :, :,
                            np.searchsorted(wave_pow.time.data, time_bins[:,
                                                                          1]) -
                            1]

        # set the times bins to be the new times bins (ie, the center of the bins)
        wave_pow['time'] = time_bins.mean(axis=1)

        # ts_list = []
        # time_list = []
        # for t in time_bins:
        #     t_inds = (wave_pow.time >= t[0]) & (wave_pow.time <= t[1])
        #     ts_list.append(wave_pow.isel(time=t_inds).mean(dim='time'))
        #     time_list.append(wave_pow.time.data[t_inds].mean())
        # wave_pow = xr.concat(ts_list, dim='time')
        # wave_pow.coords['time'] = time_list

    return wave_pow
コード例 #12
0
File: P_episode.py  Project: jrudoler/cmlcode
    def __BOSC_tf(self):
        '''
        Gets the time frequency matrix for events

        This function computes a continuous wavelet (Morlet) transform on
        the events of the BOSC object; this can be used to estimate the
        background spectrum (BOSC_bgfit) or to apply the BOSC method to
        detect oscillatory episodes in signal of interest (BOSC_detect).

        Populates self.tfm, self.list_times and self.lists with, per
        successful list, the power matrix, its sample times, and the list
        number; filters self.interest_events down to those lists.
        '''

        from ptsa.data.filters import MorletWaveletFilter

        wf = MorletWaveletFilter(timeseries=self.eeg,
                                 freqs=self.freqs,
                                 width=self.width,
                                 output='power')
        pows = wf.filter().data  # output is freqs, events, and time

        # inconsistent event labeling
        start_type, end_type = self.__get_event_keywords()

        list_events = self.events[np.logical_or(self.events.type == start_type,
                                                self.events.type == end_type)]
        # drop leading end-events so the sequence begins with a list start
        while list_events.type.iloc[0] != start_type:
            list_events = list_events.iloc[1:]
        lists = list_events.list.unique()
        lists = lists[lists > 0]  # skip practice / non-positive list numbers
        self.tfm = []
        self.list_times = []
        self.lists = []
        for lst in lists:
            start = list_events[(list_events.type == start_type)
                                & (list_events.list == lst)].eegoffset.values
            end = list_events[(list_events.type == end_type)
                              & (list_events.list == lst)].eegoffset.values

            if (start.size != 1) | (end.size != 1):
                print('Bad start/end events for list {}'.format(lst))
                continue
                # ex: list has practice or distractor but is interrupted
                # and never starts/ends

            # account for differences in samplerate - eegoffset is in samples,
            # so convert to time in ms (same as eeg.time)
            start = int(start * (1000 / self.sr))
            end = int(end * (1000 / self.sr))

            if start > self.eeg.time.values[-1]:
                print('No corresponding EEG data for list {}'.format(lst))
                continue
            tfm = pows[:, (self.eeg.time >= start) & (self.eeg.time <= end)]
            # BUGFIX: the original `if tfm == []` compared a NumPy array to an
            # empty list, which is always False, so the guard could never
            # fire. Check the array's size instead.
            if tfm.size == 0:
                raise Exception('Empty powers')
            self.tfm.append(tfm)
            self.list_times.append(
                self.eeg[(self.eeg.time >= start)
                         & (self.eeg.time <= end)].time.data)
            # only record successful lists
            self.lists.append(lst)
        # keep only interest events belonging to lists that produced data
        self.interest_events = self.interest_events[np.isin(
            self.interest_events.list, self.lists)]
コード例 #13
0
File: P_episode.py  Project: jrudoler/cmlcode
    def raw_trace(self, freq_idx, list_idx=0, filtered=False, ax=None):
        """
        Visualize the raw_eeg at a specific frequency, with
            significant oscillations highlighted.

        Parameters:
        freq_idx - index of 1D array in P_episode().freqs attribute
        list_idx - index of list to visualize
        filtered - if True, plot the real part of the Morlet wavelet
            transform at freqs[freq_idx] instead of the raw EEG
        ax - optional matplotlib axes to draw on; when omitted, a new
            subplot is created and a legend is added

        """
        plot_legend = False
        if ax is None:
            plot_legend = True
            ax = plt.subplot(1, 1, 1)

        # samples that fall inside the requested list's time window
        bools = np.logical_and(self.eeg.time >= self.list_times[list_idx][0],
                               self.eeg.time < self.list_times[list_idx][-1])
        if filtered:
            complex_mat = MorletWaveletFilter(timeseries=self.eeg,
                                              freqs=self.freqs,
                                              width=self.width,
                                              output='complex').filter()
            target_signal = np.real(complex_mat)[freq_idx][bools]
        else:
            target_signal = self.eeg[bools]
        osc = np.copy(target_signal)
        # where the oscilations are
        # TODO: incompatible length with self.detected necessitates [:-1]
        osc[np.nonzero(self.detected[list_idx, freq_idx][:-1] == 0)[0]] = None
        # BUGFIX: `range(...) / self.sr` raises TypeError (a range object
        # cannot be divided); use np.arange so the sample indices can be
        # converted to seconds.
        time = np.arange(len(target_signal)) / self.sr
        # plot the normal graph
        ax.plot(time,
                target_signal,
                'k',
                linewidth=.25,
                label='EEG Time Series')
        # highlight the oscilations
        ax.plot(time, osc, 'r', linewidth=2, label='Detected Oscillations')
        # shade word presentation events
        if (self.event_type == 'WORD') & \
                np.any(np.isin(self.events.type, ['WORD_OFF'])):
            # pair each WORD onset with its WORD_OFF offset and shade the span
            local_events = self.events[np.logical_or(
                self.events.type == 'WORD', self.events.type == 'WORD_OFF')]
            local_events = local_events[np.logical_and(
                local_events.eegoffset *
                (1000 / self.sr) >= self.list_times[list_idx][0],
                local_events.eegoffset *
                (1000 / self.sr) <= self.list_times[list_idx][-1])]
            list_start = self.list_times[list_idx][0]
            for start, end in \
             zip(local_events[local_events.type == 'WORD'].eegoffset * (1000 / self.sr),
              local_events[local_events.type == 'WORD_OFF'].eegoffset * (1000 / self.sr)):
                ax.axvspan((start - list_start) / 1000,
                           (end - list_start) / 1000,
                           alpha=0.2)
        else:
            # no explicit offsets: shade a fixed window after each event onset
            local_events = self.interest_events[np.logical_and(
                self.interest_events.eegoffset *
                (1000 / self.sr) >= self.list_times[list_idx][0],
                self.interest_events.eegoffset *
                (1000 / self.sr) <= self.list_times[list_idx][-1])]
            list_start = self.list_times[list_idx][0]
            for start in local_events.eegoffset.values * (1000 / self.sr):
                ax.axvspan((start - list_start) / 1000,
                           (start + self.relstop - list_start) / 1000,
                           alpha=0.2)
        ax.set_ylabel(r'Voltage [$\mu V$]')
        ax.set_xlabel('Time [s]')
        ax.set_title('Frequency: {} Hz'.format(round(self.freqs[freq_idx], 2)))
        if plot_legend:
            ax.legend(['EEG Time Series', 'Detected Oscillations', 'Event'])
コード例 #14
0
def power_spectra_from_spike_times(s_times, clust_nums, channel_file, rel_start_ms, rel_stop_ms, freqs,
                                           noise_freq=[58., 62.], downsample_freq=250, mean_over_spikes=True):
    """
    Function to compute power relative to spike times. This computes power at given frequencies for the ENTIRE session
    and then bins it relative to spike times. You WILL run out of memory if you don't let it downsample first. Default
    downsample is to 250 Hz.

    Parameters
    ----------
    s_times: np.ndarray
        Array (or list) of timestamps of when spikes occured. EEG will be loaded relative to these times.
    clust_nums:
        s_times: np.ndarray
        Array (or list) of cluster IDs, same size as s_times
    channel_file: str
        Path to Ncs file from which to load eeg.
    rel_start_ms: int
        Initial time (in ms), relative to the onset of each spike
    rel_stop_ms: int
        End time (in ms), relative to the onset of each spike
    freqs: np.ndarray
        array of frequencies at which to compute power
    noise_freq: list
        Stop filter will be applied to the given range. Default=[58. 62]
    downsample_freq: int or float
        Frequency to downsample the data. Use decimate, so we will likely not reach the exact frequency.
    mean_over_spikes: bool
        After computing the spike x frequency array, do we mean over spikes and return only the mean power spectra

    Returns
    -------
    dict
        dict of either spike x frequency array of power values or just frequencies, if mean_over_spikes. Keys are
        cluster numbers
    """
    # NOTE(review): noise_freq has a mutable default; it is only rebound (not
    # mutated) below, so this is safe, but a tuple default would be more robust.

    # make a df with 'stTime' column for epoching
    events = pd.DataFrame(data=np.stack([s_times, clust_nums], -1), columns=['stTime', 'cluster_num'])

    # load channel data
    signals, timestamps, sr = load_ncs(channel_file)

    # downsample the session
    if downsample_freq is not None:
        signals, timestamps, sr = _my_downsample(signals, timestamps, sr, downsample_freq)
    else:
        print('I HIGHLY recommend you downsample the data before computing power across the whole session...')
        print('You will probably run out of memory.')

    # make into timeseries (timestamps are converted from microseconds to seconds)
    eeg = TimeSeries.create(signals, samplerate=sr, dims=['time'], coords={'time': timestamps / 1e6})

    # filter line noise; a single [low, high] pair is wrapped into a list of pairs
    if noise_freq is not None:
        if isinstance(noise_freq[0], float):
            noise_freq = [noise_freq]
        for this_noise_freq in noise_freq:
            b_filter = ButterworthFilter(eeg, this_noise_freq, filt_type='stop', order=4)
            eeg = b_filter.filter()

    # compute power
    wave_pow = MorletWaveletFilter(eeg, freqs, output='power', width=5, cpus=12, verbose=False).filter()

    # log the power (numexpr evaluates log10 over the full array)
    data = wave_pow.data
    wave_pow.data = numexpr.evaluate('log10(data)')

    # get start and stop relative to the spikes, dropping epochs that fall
    # outside the recorded session
    epochs = _compute_epochs(events, rel_start_ms, rel_stop_ms, timestamps, sr)
    bad_epochs = (np.any(epochs < 0, 1)) | (np.any(epochs > len(signals), 1))
    epochs = epochs[~bad_epochs]
    events = events[~bad_epochs].reset_index(drop=True)

    # mean over time within epochs -> one power spectrum per spike
    spikes_x_freqs = np.stack([np.mean(wave_pow.data[:, x[0]:x[1]], axis=1) for x in epochs])

    # make dict with keys being cluster numbers. Mean over spikes if desired.
    pow_spect_dict = {}
    for this_cluster in events.cluster_num.unique():
        if mean_over_spikes:
            pow_spect_dict[this_cluster] = spikes_x_freqs[events.cluster_num == this_cluster].mean(axis=0)
        else:
            pow_spect_dict[this_cluster] = spikes_x_freqs[events.cluster_num == this_cluster]

    return pow_spect_dict
コード例 #15
0
    def analysis(self):
        """
        Runs the pairwise phase-synchrony analysis.

        For every pair of ROIs in self.roi_list, computes the phase
        difference between each cross-ROI electrode pair and stores the
        circular statistics (and optional permutation z-scores) in
        self.res, keyed by the ROI-pair label.
        """

        if self.subject_data is None:
            print('%s: compute or load data first with .load_data()!' % self.subject)
            # NOTE(review): execution continues past this warning; the code
            # below will fail on None — confirm whether an early return is intended.

        # Get recalled or not labels
        if self.recall_filter_func is None:
            print('%s SME: please provide a .recall_filter_func function.' % self.subject)
        recalled = self.recall_filter_func(self.subject_data)

        # filter to electrodes in ROIs. First get broad electrode region labels
        region_df = self.bin_eloctrodes_into_rois()
        region_df['merged_col'] = region_df['hemi'] + '-' + region_df['region']

        # make sure we have electrodes in each unique region
        for roi in self.roi_list:
            has_elecs = []
            for label in roi:
                if np.any(region_df.merged_col == label):
                    has_elecs.append(True)
            # has_elecs only ever collects True, so ~np.any(...) is True
            # exactly when no label in this ROI matched any electrode
            if ~np.any(has_elecs):
                print('{}: no {} electrodes, cannot compute synchrony.'.format(self.subject, roi))
                return

        # then filter into just to ROIs defined above
        elecs_to_use = region_df.merged_col.isin([item for sublist in self.roi_list for item in sublist])
        elec_scheme = self.elec_info.copy(deep=True)
        elec_scheme['ROI'] = region_df.merged_col[elecs_to_use]
        elec_scheme = elec_scheme[elecs_to_use].reset_index()

        # phase per event x electrode x time, either by wavelet or by
        # band-pass + Hilbert transform
        if self.use_wavelets:
            phase_data = MorletWaveletFilter(self.subject_data[:, elecs_to_use], self.wavelet_freq,
                                             output='phase', width=5, cpus=12,
                                             verbose=False).filter()
        else:
            # band pass eeg
            phase_data = RAM_helpers.band_pass_eeg(self.subject_data[:, elecs_to_use], self.hilbert_band_pass_range)

            # get phase at each timepoint
            phase_data.data = np.angle(hilbert(phase_data.data, N=phase_data.shape[-1], axis=-1))

        # remove the buffer (self.buf_ms is in ms; remove_buffer takes seconds)
        phase_data = phase_data.remove_buffer(self.buf_ms / 1000.)

        # loop over each pair of ROIs
        for region_pair in combinations(self.roi_list, 2):
            elecs_region_1 = np.where(elec_scheme.ROI.isin(region_pair[0]))[0]
            elecs_region_2 = np.where(elec_scheme.ROI.isin(region_pair[1]))[0]

            # accumulators for per-electrode-pair statistics
            elec_label_pairs = []
            elec_pair_pvals = []
            elec_pair_zs = []
            elec_pair_rvls = []
            elec_pair_pvals_rec = []
            elec_pair_zs_rec = []
            elec_pair_rvls_rec = []
            elec_pair_pvals_nrec = []
            elec_pair_zs_nrec = []
            elec_pair_rvls_nrec = []
            delta_mem_rayleigh_zscores = []
            delta_mem_rvl_zscores = []

            elec_pair_phase_diffs = []

            # loop over all pairs of electrodes in the ROIs
            for elec_1 in elecs_region_1:
                for elec_2 in elecs_region_2:
                    elec_label_pairs.append([elec_scheme.iloc[elec_1].label, elec_scheme.iloc[elec_2].label])

                    # and take the difference in phase values for this electrode pair
                    elec_pair_phase_diff = pycircstat.cdiff(phase_data[:, elec_1], phase_data[:, elec_2])
                    if self.include_phase_diffs_in_res:
                        elec_pair_phase_diffs.append(elec_pair_phase_diff)

                    # compute the circular stats
                    elec_pair_stats = calc_circ_stats(elec_pair_phase_diff, recalled, do_perm=False)
                    elec_pair_pvals.append(elec_pair_stats['elec_pair_pval'])
                    elec_pair_zs.append(elec_pair_stats['elec_pair_z'])
                    elec_pair_rvls.append(elec_pair_stats['elec_pair_rvl'])
                    elec_pair_pvals_rec.append(elec_pair_stats['elec_pair_pval_rec'])
                    elec_pair_zs_rec.append(elec_pair_stats['elec_pair_z_rec'])
                    elec_pair_pvals_nrec.append(elec_pair_stats['elec_pair_pval_nrec'])
                    elec_pair_zs_nrec.append(elec_pair_stats['elec_pair_z_nrec'])
                    elec_pair_rvls_rec.append(elec_pair_stats['elec_pair_rvl_rec'])
                    elec_pair_rvls_nrec.append(elec_pair_stats['elec_pair_rvl_nrec'])

                    # compute null distributions for the memory stats
                    if self.do_perm_test:
                        delta_mem_rayleigh_zscore, delta_mem_rvl_zscore = self.compute_null_stats(elec_pair_phase_diff,
                                                                                                  recalled,
                                                                                                  elec_pair_stats)
                        delta_mem_rayleigh_zscores.append(delta_mem_rayleigh_zscore)
                        delta_mem_rvl_zscores.append(delta_mem_rvl_zscore)

            # stack the per-pair lists into arrays and store under a key
            # like "left-HC+right-PFC"
            region_pair_key = '+'.join(['-'.join(r) for r in region_pair])
            self.res[region_pair_key] = {}
            self.res[region_pair_key]['elec_label_pairs'] = elec_label_pairs
            self.res[region_pair_key]['elec_pair_pvals'] = np.stack(elec_pair_pvals, 0)
            self.res[region_pair_key]['elec_pair_zs'] = np.stack(elec_pair_zs, 0)
            self.res[region_pair_key]['elec_pair_rvls'] = np.stack(elec_pair_rvls, 0)
            self.res[region_pair_key]['elec_pair_pvals_rec'] = np.stack(elec_pair_pvals_rec, 0)
            self.res[region_pair_key]['elec_pair_zs_rec'] = np.stack(elec_pair_zs_rec, 0)
            self.res[region_pair_key]['elec_pair_pvals_nrec'] = np.stack(elec_pair_pvals_nrec, 0)
            self.res[region_pair_key]['elec_pair_zs_nrec'] = np.stack(elec_pair_zs_nrec, 0)
            self.res[region_pair_key]['elec_pair_rvls_rec'] = np.stack(elec_pair_rvls_rec, 0)
            self.res[region_pair_key]['elec_pair_rvls_nrec'] = np.stack(elec_pair_rvls_nrec, 0)
            if self.do_perm_test:
                self.res[region_pair_key]['delta_mem_rayleigh_zscores'] = np.stack(delta_mem_rayleigh_zscores, 0)
                self.res[region_pair_key]['delta_mem_rvl_zscores'] = np.stack(delta_mem_rvl_zscores, 0)
            if self.include_phase_diffs_in_res:
                self.res[region_pair_key]['elec_pair_phase_diffs'] = np.stack(elec_pair_phase_diffs, -1)
            self.res[region_pair_key]['time'] = phase_data.time.data
            self.res[region_pair_key]['recalled'] = recalled
コード例 #16
0
    def analysis(self):
        """
        Runs the phase synchrony analysis.

        For every pair of ROIs in self.roi_list, computes the phase
        difference between each cross-ROI electrode pair and stores the
        circular statistics (and optional permutation z-scores) in
        self.res, keyed by the ROI-pair label.
        """

        if self.subject_data is None:
            print('%s: compute or load data first with .load_data()!' % self.subject)
            # NOTE(review): execution continues past this warning; the code
            # below will fail on None — confirm whether an early return is intended.

        # Get recalled or not labels
        if self.recall_filter_func is None:
            print('%s SME: please provide a .recall_filter_func function.' % self.subject)
        recalled = self.recall_filter_func(self.subject_data)

        # filter to electrodes in ROIs. First get broad electrode region labels
        region_df = self.bin_eloctrodes_into_rois()
        region_df['merged_col'] = region_df['hemi'] + '-' + region_df['region']

        # make sure we have electrodes in each unique region
        for roi in self.roi_list:
            has_elecs = []
            for label in roi:
                if np.any(region_df.merged_col == label):
                    has_elecs.append(True)
            # has_elecs only ever collects True, so ~np.any(...) is True
            # exactly when no label in this ROI matched any electrode
            if ~np.any(has_elecs):
                print('{}: no {} electrodes, cannot compute synchrony.'.format(self.subject, roi))
                return

        # then filter into just to ROIs defined above
        elecs_to_use = region_df.merged_col.isin([item for sublist in self.roi_list for item in sublist])
        elec_scheme = self.elec_info.copy(deep=True)
        elec_scheme['ROI'] = region_df.merged_col[elecs_to_use]
        elec_scheme = elec_scheme[elecs_to_use].reset_index()

        # phase per event x electrode x time, either by wavelet or by
        # band-pass + Hilbert transform
        if self.use_wavelets:
            phase_data = MorletWaveletFilter(self.subject_data[:, elecs_to_use], self.wavelet_freq,
                                             output='phase', width=5, cpus=12,
                                             verbose=False).filter()
        else:
            # band pass eeg
            phase_data = ecog_helpers.band_pass_eeg(self.subject_data[:, elecs_to_use], self.hilbert_band_pass_range)

            # get phase at each timepoint
            phase_data.data = np.angle(hilbert(phase_data.data, N=phase_data.shape[-1], axis=-1))

        # remove the buffer (self.buf_ms is in ms; remove_buffer takes seconds)
        phase_data = phase_data.remove_buffer(self.buf_ms / 1000.)

        # loop over each pair of ROIs
        for region_pair in combinations(self.roi_list, 2):
            elecs_region_1 = np.where(elec_scheme.ROI.isin(region_pair[0]))[0]
            elecs_region_2 = np.where(elec_scheme.ROI.isin(region_pair[1]))[0]

            # accumulators for per-electrode-pair statistics
            elec_label_pairs = []
            elec_pair_pvals = []
            elec_pair_zs = []
            elec_pair_rvls = []
            elec_pair_pvals_rec = []
            elec_pair_zs_rec = []
            elec_pair_rvls_rec = []
            elec_pair_pvals_nrec = []
            elec_pair_zs_nrec = []
            elec_pair_rvls_nrec = []
            delta_mem_rayleigh_zscores = []
            delta_mem_rvl_zscores = []

            elec_pair_phase_diffs = []

            # loop over all pairs of electrodes in the ROIs
            for elec_1 in elecs_region_1:
                for elec_2 in elecs_region_2:
                    elec_label_pairs.append([elec_scheme.iloc[elec_1].label, elec_scheme.iloc[elec_2].label])

                    # and take the difference in phase values for this electrode pair
                    elec_pair_phase_diff = pycircstat.cdiff(phase_data[:, elec_1], phase_data[:, elec_2])
                    if self.include_phase_diffs_in_res:
                        elec_pair_phase_diffs.append(elec_pair_phase_diff)

                    # compute the circular stats
                    elec_pair_stats = calc_circ_stats(elec_pair_phase_diff, recalled, do_perm=False)
                    elec_pair_pvals.append(elec_pair_stats['elec_pair_pval'])
                    elec_pair_zs.append(elec_pair_stats['elec_pair_z'])
                    elec_pair_rvls.append(elec_pair_stats['elec_pair_rvl'])
                    elec_pair_pvals_rec.append(elec_pair_stats['elec_pair_pval_rec'])
                    elec_pair_zs_rec.append(elec_pair_stats['elec_pair_z_rec'])
                    elec_pair_pvals_nrec.append(elec_pair_stats['elec_pair_pval_nrec'])
                    elec_pair_zs_nrec.append(elec_pair_stats['elec_pair_z_nrec'])
                    elec_pair_rvls_rec.append(elec_pair_stats['elec_pair_rvl_rec'])
                    elec_pair_rvls_nrec.append(elec_pair_stats['elec_pair_rvl_nrec'])

                    # compute null distributions for the memory stats
                    if self.do_perm_test:
                        delta_mem_rayleigh_zscore, delta_mem_rvl_zscore = self.compute_null_stats(elec_pair_phase_diff,
                                                                                                  recalled,
                                                                                                  elec_pair_stats)
                        delta_mem_rayleigh_zscores.append(delta_mem_rayleigh_zscore)
                        delta_mem_rvl_zscores.append(delta_mem_rvl_zscore)

            # stack the per-pair lists into arrays and store under a key
            # like "left-HC+right-PFC"
            region_pair_key = '+'.join(['-'.join(r) for r in region_pair])
            self.res[region_pair_key] = {}
            self.res[region_pair_key]['elec_label_pairs'] = elec_label_pairs
            self.res[region_pair_key]['elec_pair_pvals'] = np.stack(elec_pair_pvals, 0)
            self.res[region_pair_key]['elec_pair_zs'] = np.stack(elec_pair_zs, 0)
            self.res[region_pair_key]['elec_pair_rvls'] = np.stack(elec_pair_rvls, 0)
            self.res[region_pair_key]['elec_pair_pvals_rec'] = np.stack(elec_pair_pvals_rec, 0)
            self.res[region_pair_key]['elec_pair_zs_rec'] = np.stack(elec_pair_zs_rec, 0)
            self.res[region_pair_key]['elec_pair_pvals_nrec'] = np.stack(elec_pair_pvals_nrec, 0)
            self.res[region_pair_key]['elec_pair_zs_nrec'] = np.stack(elec_pair_zs_nrec, 0)
            self.res[region_pair_key]['elec_pair_rvls_rec'] = np.stack(elec_pair_rvls_rec, 0)
            self.res[region_pair_key]['elec_pair_rvls_nrec'] = np.stack(elec_pair_rvls_nrec, 0)
            if self.do_perm_test:
                self.res[region_pair_key]['delta_mem_rayleigh_zscores'] = np.stack(delta_mem_rayleigh_zscores, 0)
                self.res[region_pair_key]['delta_mem_rvl_zscores'] = np.stack(delta_mem_rvl_zscores, 0)
            if self.include_phase_diffs_in_res:
                self.res[region_pair_key]['elec_pair_phase_diffs'] = np.stack(elec_pair_phase_diffs, -1)
            self.res[region_pair_key]['time'] = phase_data.time.data
            self.res[region_pair_key]['recalled'] = recalled
コード例 #17
0
def power_spectra_from_spike_times(s_times,
                                   clust_nums,
                                   channel_file,
                                   rel_start_ms,
                                   rel_stop_ms,
                                   freqs,
                                   noise_freq=(58., 62.),
                                   downsample_freq=250,
                                   mean_over_spikes=True):
    """
    Function to compute power relative to spike times. This computes power at given frequencies for the ENTIRE session
    and then bins it relative to spike times. You WILL run out of memory if you don't let it downsample first. Default
    downsample is to 250 Hz.

    Parameters
    ----------
    s_times: np.ndarray
        Array (or list) of timestamps of when spikes occurred. EEG will be loaded relative to these times.
    clust_nums: np.ndarray
        Array (or list) of cluster IDs, same size as s_times
    channel_file: str
        Path to Ncs file from which to load eeg.
    rel_start_ms: int
        Initial time (in ms), relative to the onset of each spike
    rel_stop_ms: int
        End time (in ms), relative to the onset of each spike
    freqs: np.ndarray
        array of frequencies at which to compute power
    noise_freq: list
        Stop filter will be applied to the given range. Default=(58., 62.)
    downsample_freq: int or float
        Frequency to downsample the data. Uses decimate, so we will likely not reach the exact frequency.
    mean_over_spikes: bool
        After computing the spike x frequency array, do we mean over spikes and return only the mean power spectra

    Returns
    -------
    dict
        Keyed by cluster number. Each value is either a spike x frequency array of (log) power values, or, if
        mean_over_spikes, the mean power spectrum (one value per frequency) for that cluster.
    """

    # make a df with 'stTime' column for epoching; keep the cluster ID alongside
    # each spike so bad epochs can be dropped from both in lockstep
    events = pd.DataFrame(data=np.stack([s_times, clust_nums], -1),
                          columns=['stTime', 'cluster_num'])

    # load channel data
    signals, timestamps, sr = load_ncs(channel_file)

    # downsample the session first: wavelet power for a whole session at the
    # native sampling rate generally does not fit in memory
    if downsample_freq is not None:
        signals, timestamps, sr = _my_downsample(signals, timestamps, sr,
                                                 downsample_freq)
    else:
        print(
            'I HIGHLY recommend you downsample the data before computing power across the whole session...'
        )
        print('You will probably run out of memory.')

    # make into timeseries (timestamps assumed to be in microseconds -> seconds)
    eeg = TimeSeries.create(signals,
                            samplerate=sr,
                            dims=['time'],
                            coords={'time': timestamps / 1e6})

    # filter line noise. Accept either a single [low, high] pair or a list of
    # such pairs; an int pair like [58, 62] is treated the same as floats.
    if noise_freq is not None:
        if isinstance(noise_freq[0], (int, float)):
            noise_freq = [noise_freq]
        for this_noise_freq in noise_freq:
            b_filter = ButterworthFilter(eeg,
                                         this_noise_freq,
                                         filt_type='stop',
                                         order=4)
            eeg = b_filter.filter()

    # compute power over the full session
    wave_pow = MorletWaveletFilter(eeg,
                                   freqs,
                                   output='power',
                                   width=5,
                                   cpus=12,
                                   verbose=False).filter()

    # log the power in place (numexpr keeps the peak memory footprint down)
    data = wave_pow.data
    wave_pow.data = numexpr.evaluate('log10(data)')

    # get start and stop sample offsets relative to the spikes, dropping any
    # epoch that falls outside the recorded session
    epochs = _compute_epochs(events, rel_start_ms, rel_stop_ms, timestamps, sr)
    bad_epochs = (np.any(epochs < 0, 1)) | (np.any(epochs > len(signals), 1))
    epochs = epochs[~bad_epochs]
    events = events[~bad_epochs].reset_index(drop=True)

    # mean over time within each epoch -> one power spectrum per spike
    spikes_x_freqs = np.stack(
        [np.mean(wave_pow.data[:, x[0]:x[1]], axis=1) for x in epochs])

    # make dict with keys being cluster numbers. Mean over spikes if desired.
    pow_spect_dict = {}
    for this_cluster in events.cluster_num.unique():
        cluster_mask = events.cluster_num == this_cluster
        if mean_over_spikes:
            pow_spect_dict[this_cluster] = spikes_x_freqs[cluster_mask].mean(axis=0)
        else:
            pow_spect_dict[this_cluster] = spikes_x_freqs[cluster_mask]

    return pow_spect_dict