# Example #1
# 0
class ComputePSPowers(RamTask):
    def __init__(self, params, mark_as_completed=True):
        RamTask.__init__(self, mark_as_completed)
        self.params = params
        self.wavelet_transform = MorletWaveletTransform()

    def restore(self):
        subject = self.pipeline.subject
        experiment = self.pipeline.experiment

        ps_pow_mat_pre = joblib.load(
            self.get_path_to_resource_in_workspace(subject + '-' + experiment +
                                                   '-ps_pow_mat_pre.pkl'))
        ps_pow_mat_post = joblib.load(
            self.get_path_to_resource_in_workspace(subject + '-' + experiment +
                                                   '-ps_pow_mat_post.pkl'))

        self.pass_object('ps_pow_mat_pre', ps_pow_mat_pre)
        self.pass_object('ps_pow_mat_post', ps_pow_mat_post)

    def run(self):
        subject = self.pipeline.subject
        experiment = self.pipeline.experiment

        #fetching objects from other tasks
        events = self.get_passed_object(self.pipeline.experiment + '_events')
        # channels = self.get_passed_object('channels')
        # tal_info = self.get_passed_object('tal_info')
        monopolar_channels = self.get_passed_object('monopolar_channels')
        bipolar_pairs = self.get_passed_object('bipolar_pairs')

        sessions = np.unique(events.session)
        print experiment, 'sessions:', sessions

        ps_pow_mat_pre, ps_pow_mat_post = self.compute_ps_powers(
            events, sessions, monopolar_channels, bipolar_pairs, experiment)

        joblib.dump(
            ps_pow_mat_pre,
            self.get_path_to_resource_in_workspace(subject + '-' + experiment +
                                                   '-ps_pow_mat_pre.pkl'))
        joblib.dump(
            ps_pow_mat_post,
            self.get_path_to_resource_in_workspace(subject + '-' + experiment +
                                                   '-ps_pow_mat_post.pkl'))

        self.pass_object('ps_pow_mat_pre', ps_pow_mat_pre)
        self.pass_object('ps_pow_mat_post', ps_pow_mat_post)

    def compute_ps_powers(self, events, sessions, monopolar_channels,
                          bipolar_pairs, experiment):
        n_freqs = len(self.params.freqs)
        n_bps = len(bipolar_pairs)

        pow_mat_pre = pow_mat_post = None

        pow_ev = None
        samplerate = winsize = bufsize = None

        monopolar_channels_list = list(monopolar_channels)
        for sess in sessions:
            sess_events = events[events.session == sess]
            # print type(sess_events)

            n_events = len(sess_events)

            print 'Loading EEG for', n_events, 'events of session', sess

            pre_start_time = self.params.ps_start_time - self.params.ps_offset
            pre_end_time = self.params.ps_end_time - self.params.ps_offset

            # eegs_pre = Events(sess_events).get_data(channels=channels, start_time=pre_start_time, end_time=pre_end_time,
            #             buffer_time=self.params.ps_buf, eoffset='eegoffset', keep_buffer=True, eoffset_in_time=False)

            eeg_pre_reader = EEGReader(
                events=sess_events,
                channels=np.array(monopolar_channels_list),
                start_time=pre_start_time,
                end_time=pre_end_time,
                buffer_time=self.params.ps_buf)

            eegs_pre = eeg_pre_reader.read()

            if samplerate is None:
                # samplerate = round(eegs_pre.samplerate)
                # samplerate = eegs_pre.attrs['samplerate']

                samplerate = float(eegs_pre['samplerate'])

                winsize = int(
                    round(samplerate * (pre_end_time - pre_start_time +
                                        2 * self.params.ps_buf)))
                bufsize = int(round(samplerate * self.params.ps_buf))
                print 'samplerate =', samplerate, 'winsize =', winsize, 'bufsize =', bufsize
                pow_ev = np.empty(shape=n_freqs * winsize, dtype=float)
                self.wavelet_transform.init(self.params.width,
                                            self.params.freqs[0],
                                            self.params.freqs[-1], n_freqs,
                                            samplerate, winsize)

            # mirroring
            nb_ = int(round(samplerate * (self.params.ps_buf)))
            eegs_pre[..., -nb_:] = eegs_pre[..., -nb_ - 2:-2 * nb_ - 2:-1]

            dim3_pre = eegs_pre.shape[
                2]  # because post-stim time inreval does not align for all stim events (stims have different duration)
            # we have to take care of aligning eegs_post ourselves time dim to dim3

            # eegs_post = np.zeros_like(eegs_pre)

            from ptsa.data.TimeSeriesX import TimeSeriesX
            eegs_post = TimeSeriesX(np.zeros_like(eegs_pre),
                                    dims=eegs_pre.dims,
                                    coords=eegs_pre.coords)

            post_start_time = self.params.ps_offset
            post_end_time = self.params.ps_offset + (self.params.ps_end_time -
                                                     self.params.ps_start_time)
            for i_ev in xrange(n_events):
                ev_offset = sess_events[i_ev].pulse_duration
                if ev_offset > 0:
                    if experiment == 'PS3' and sess_events[i_ev].nBursts > 0:
                        ev_offset *= sess_events[i_ev].nBursts + 1
                    ev_offset *= 0.001
                else:
                    ev_offset = 0.0

                # eeg_post = Events(sess_events[i_ev:i_ev+1]).get_data(channels=channels, start_time=post_start_time+ev_offset,
                #             end_time=post_end_time+ev_offset, buffer_time=self.params.ps_buf,
                #             eoffset='eegoffset', keep_buffer=True, eoffset_in_time=False)

                eeg_post_reader = EEGReader(
                    events=sess_events[i_ev:i_ev + 1],
                    channels=np.array(monopolar_channels_list),
                    start_time=post_start_time + ev_offset,
                    end_time=post_end_time + ev_offset,
                    buffer_time=self.params.ps_buf)

                eeg_post = eeg_post_reader.read()

                dim3_post = eeg_post.shape[2]
                # here we take care of possible mismatch of time dim length
                if dim3_pre == dim3_post:
                    eegs_post[:, i_ev:i_ev + 1, :] = eeg_post
                elif dim3_pre < dim3_post:
                    eegs_post[:, i_ev:i_ev + 1, :] = eeg_post[:, :, :-1]
                else:
                    eegs_post[:, i_ev:i_ev + 1, :-1] = eeg_post

            # mirroring
            eegs_post[..., :nb_] = eegs_post[..., 2 * nb_:nb_:-1]

            print 'Computing', experiment, 'powers'

            sess_pow_mat_pre = np.empty(shape=(n_events, n_bps, n_freqs),
                                        dtype=np.float)
            sess_pow_mat_post = np.empty_like(sess_pow_mat_pre)

            for i, ti in enumerate(bipolar_pairs):
                bp = ti['channel_str']
                print 'Computing powers for bipolar pair', bp
                elec1 = np.where(monopolar_channels == bp[0])[0][0]
                elec2 = np.where(monopolar_channels == bp[1])[0][0]

                #
                # for i,ti in enumerate(tal_info):
                #     bp = ti['channel_str']
                #     print 'Computing powers for bipolar pair', bp
                #     elec1 = np.where(channels == bp[0])[0][0]
                #     elec2 = np.where(channels == bp[1])[0][0]

                bp_data_pre = eegs_pre[elec1] - eegs_pre[elec2]
                # bp_data_pre.attrs['samplerate'] = samplerate

                bp_data_pre = bp_data_pre.filtered(
                    [58, 62], filt_type='stop', order=self.params.filt_order)
                for ev in xrange(n_events):
                    #pow_pre_ev = phase_pow_multi(self.params.freqs, bp_data_pre[ev], to_return='power')
                    self.wavelet_transform.multiphasevec(
                        bp_data_pre[ev][0:winsize], pow_ev)
                    #sess_pow_mat_pre[ev,i,:] = np.mean(pow_pre_ev[:,nb_:-nb_], axis=1)
                    pow_ev_stripped = np.reshape(
                        pow_ev, (n_freqs, winsize))[:,
                                                    bufsize:winsize - bufsize]
                    if self.params.log_powers:
                        np.log10(pow_ev_stripped, out=pow_ev_stripped)
                    sess_pow_mat_pre[ev, i, :] = np.nanmean(pow_ev_stripped,
                                                            axis=1)

                bp_data_post = eegs_post[elec1] - eegs_post[elec2]
                # bp_data_post.attrs['samplerate'] = samplerate

                bp_data_post = bp_data_post.filtered(
                    [58, 62], filt_type='stop', order=self.params.filt_order)
                for ev in xrange(n_events):
                    #pow_post_ev = phase_pow_multi(self.params.freqs, bp_data_post[ev], to_return='power')
                    self.wavelet_transform.multiphasevec(
                        bp_data_post[ev][0:winsize], pow_ev)
                    #sess_pow_mat_post[ev,i,:] = np.mean(pow_post_ev[:,nb_:-nb_], axis=1)
                    pow_ev_stripped = np.reshape(
                        pow_ev, (n_freqs, winsize))[:,
                                                    bufsize:winsize - bufsize]
                    if self.params.log_powers:
                        np.log10(pow_ev_stripped, out=pow_ev_stripped)
                    sess_pow_mat_post[ev, i, :] = np.nanmean(pow_ev_stripped,
                                                             axis=1)

            sess_pow_mat_pre = sess_pow_mat_pre.reshape(
                (n_events, n_bps * n_freqs))
            #sess_pow_mat_pre = zscore(sess_pow_mat_pre, axis=0, ddof=1)

            sess_pow_mat_post = sess_pow_mat_post.reshape(
                (n_events, n_bps * n_freqs))
            #sess_pow_mat_post = zscore(sess_pow_mat_post, axis=0, ddof=1)

            sess_pow_mat_joint = zscore(np.vstack(
                (sess_pow_mat_pre, sess_pow_mat_post)),
                                        axis=0,
                                        ddof=1)
            sess_pow_mat_pre = sess_pow_mat_joint[:n_events, ...]
            sess_pow_mat_post = sess_pow_mat_joint[n_events:, ...]

            pow_mat_pre = np.vstack(
                (pow_mat_pre, sess_pow_mat_pre
                 )) if pow_mat_pre is not None else sess_pow_mat_pre
            pow_mat_post = np.vstack(
                (pow_mat_post, sess_pow_mat_post
                 )) if pow_mat_post is not None else sess_pow_mat_post

        return pow_mat_pre, pow_mat_post
# Example #2
# 0
class ComputeFR1Powers(RamTask):
    def __init__(self, params, mark_as_completed=True):
        RamTask.__init__(self, mark_as_completed)
        self.params = params
        self.pow_mat = None
        self.samplerate = None
        self.wavelet_transform = MorletWaveletTransform()

    def restore(self):
        subject = self.pipeline.subject
        task = self.pipeline.task

        self.pow_mat = joblib.load(
            self.get_path_to_resource_in_workspace(subject + '-' + task +
                                                   '-pow_mat.pkl'))
        self.samplerate = joblib.load(
            self.get_path_to_resource_in_workspace(subject +
                                                   '-samplerate.pkl'))

        self.pass_object('pow_mat', self.pow_mat)
        self.pass_object('samplerate', self.samplerate)

    def run(self):
        subject = self.pipeline.subject
        task = self.pipeline.task

        events = self.get_passed_object(task + '_events')

        sessions = np.unique(events.session)
        print 'sessions:', sessions

        # channels = self.get_passed_object('channels')
        # tal_info = self.get_passed_object('tal_info')
        monopolar_channels = self.get_passed_object('monopolar_channels')
        bipolar_pairs = self.get_passed_object('bipolar_pairs')

        self.compute_powers(events, sessions, monopolar_channels,
                            bipolar_pairs)

        self.pass_object('pow_mat', self.pow_mat)
        self.pass_object('samplerate', self.samplerate)

        joblib.dump(
            self.pow_mat,
            self.get_path_to_resource_in_workspace(subject + '-' + task +
                                                   '-pow_mat.pkl'))
        joblib.dump(
            self.samplerate,
            self.get_path_to_resource_in_workspace(subject +
                                                   '-samplerate.pkl'))

    def compute_powers(self, events, sessions, monopolar_channels,
                       bipolar_pairs):
        n_freqs = len(self.params.freqs)
        n_bps = len(bipolar_pairs)

        self.pow_mat = None

        pow_ev = None
        winsize = bufsize = None
        for sess in sessions:
            sess_events = events[events.session == sess]
            n_events = len(sess_events)

            print 'Loading EEG for', n_events, 'events of session', sess

            # eegs = Events(sess_events).get_data(channels=channels, start_time=self.params.fr1_start_time, end_time=self.params.fr1_end_time,
            #                             buffer_time=self.params.fr1_buf, eoffset='eegoffset', keep_buffer=True, eoffset_in_time=False)

            # from ptsa.data.readers import TimeSeriesEEGReader
            # time_series_reader = TimeSeriesEEGReader(events=sess_events, start_time=self.params.fr1_start_time,
            #                                  end_time=self.params.fr1_end_time, buffer_time=self.params.fr1_buf, keep_buffer=True)
            #
            # eegs = time_series_reader.read(monopolar_channels)

            eeg_reader = EEGReader(events=sess_events,
                                   channels=monopolar_channels,
                                   start_time=self.params.fr1_start_time,
                                   end_time=self.params.fr1_end_time,
                                   buffer_time=self.params.fr1_buf)

            eegs = eeg_reader.read()

            # print 'eegs=',eegs.values[0,0,:2],eegs.values[0,0,-2:]
            # sys.exit()
            #
            # a = eegs[0]-eegs[1]

            # mirroring
            #eegs[...,:1365] = eegs[...,2730:1365:-1]
            #eegs[...,2731:4096] = eegs[...,2729:1364:-1]

            if self.samplerate is None:
                self.samplerate = float(eegs.samplerate)
                winsize = int(
                    round(
                        self.samplerate *
                        (self.params.fr1_end_time - self.params.fr1_start_time
                         + 2 * self.params.fr1_buf)))
                bufsize = int(round(self.samplerate * self.params.fr1_buf))
                print 'samplerate =', self.samplerate, 'winsize =', winsize, 'bufsize =', bufsize
                pow_ev = np.empty(shape=n_freqs * winsize, dtype=float)
                self.wavelet_transform.init(self.params.width,
                                            self.params.freqs[0],
                                            self.params.freqs[-1], n_freqs,
                                            self.samplerate, winsize)

            print 'Computing FR1 powers'

            sess_pow_mat = np.empty(shape=(n_events, n_bps, n_freqs),
                                    dtype=np.float)

            #monopolar_channels_np = np.array(monopolar_channels)
            for i, ti in enumerate(bipolar_pairs):
                # print bp
                # print monopolar_channels

                # print np.where(monopolar_channels == bp[0])
                # print np.where(monopolar_channels == bp[1])
                bp = ti['channel_str']
                print 'Computing powers for bipolar pair', bp
                elec1 = np.where(monopolar_channels == bp[0])[0][0]
                elec2 = np.where(monopolar_channels == bp[1])[0][0]
                # print 'elec1=',elec1
                # print 'elec2=',elec2
                # eegs_elec1 = eegs[elec1]
                # eegs_elec2 = eegs[elec2]
                # print 'eegs_elec1=',eegs_elec1
                # print 'eegs_elec2=',eegs_elec2
                # eegs_elec1.reset_coords('channels')
                # eegs_elec2.reset_coords('channels')

                bp_data = eegs[elec1] - eegs[elec2]
                bp_data.attrs['samplerate'] = self.samplerate

                # bp_data = eegs[elec1] - eegs[elec2]
                # bp_data = eegs[elec1] - eegs[elec2]
                # bp_data = eegs.values[elec1] - eegs.values[elec2]

                bp_data = bp_data.filtered([58, 62],
                                           filt_type='stop',
                                           order=self.params.filt_order)
                for ev in xrange(n_events):
                    self.wavelet_transform.multiphasevec(
                        bp_data[ev][0:winsize], pow_ev)
                    #if np.min(pow_ev) < 0.0:
                    #    print ev, events[ev]
                    #    joblib.dump(bp_data[ev], 'bad_bp_ev%d'%ev)
                    #    joblib.dump(eegs[elec1][ev], 'bad_elec1_ev%d'%ev)
                    #    joblib.dump(eegs[elec2][ev], 'bad_elec2_ev%d'%ev)
                    #    print 'Negative powers detected'
                    #    import sys
                    #    sys.exit(1)
                    pow_ev_stripped = np.reshape(
                        pow_ev, (n_freqs, winsize))[:,
                                                    bufsize:winsize - bufsize]
                    if self.params.log_powers:
                        np.log10(pow_ev_stripped, out=pow_ev_stripped)
                    sess_pow_mat[ev, i, :] = np.nanmean(pow_ev_stripped,
                                                        axis=1)

            self.pow_mat = np.concatenate(
                (self.pow_mat, sess_pow_mat),
                axis=0) if self.pow_mat is not None else sess_pow_mat

        self.pow_mat = np.reshape(self.pow_mat, (len(events), n_bps * n_freqs))
class ComputeFR1Powers(RamTask):
    def __init__(self, params, mark_as_completed=True):
        RamTask.__init__(self, mark_as_completed)
        self.params = params
        self.pow_mat = None
        self.samplerate = None
        self.wavelet_transform = MorletWaveletTransform()

    def restore(self):
        subject = self.pipeline.subject
        task = self.pipeline.task

        self.pow_mat = joblib.load(self.get_path_to_resource_in_workspace(subject + '-' + task + '-pow_mat.pkl'))
        self.samplerate = joblib.load(self.get_path_to_resource_in_workspace(subject + '-samplerate.pkl'))

        self.pass_object('pow_mat', self.pow_mat)
        self.pass_object('samplerate', self.samplerate)

    def run(self):
        subject = self.pipeline.subject
        task = self.pipeline.task

        events = self.get_passed_object(task+'_events')

        sessions = np.unique(events.session)
        print 'sessions:', sessions

        # channels = self.get_passed_object('channels')
        # tal_info = self.get_passed_object('tal_info')
        monopolar_channels = self.get_passed_object('monopolar_channels')
        bipolar_pairs = self.get_passed_object('bipolar_pairs')

        self.compute_powers(events, sessions, monopolar_channels, bipolar_pairs)

        self.pass_object('pow_mat', self.pow_mat)
        self.pass_object('samplerate', self.samplerate)

        joblib.dump(self.pow_mat, self.get_path_to_resource_in_workspace(subject + '-' + task + '-pow_mat.pkl'))
        joblib.dump(self.samplerate, self.get_path_to_resource_in_workspace(subject + '-samplerate.pkl'))

    def compute_powers(self, events, sessions,monopolar_channels , bipolar_pairs ):
        n_freqs = len(self.params.freqs)
        n_bps = len(bipolar_pairs)

        self.pow_mat = None

        pow_ev = None
        winsize = bufsize = None
        for sess in sessions:
            sess_events = events[events.session == sess]
            n_events = len(sess_events)

            print 'Loading EEG for', n_events, 'events of session', sess

            # eegs = Events(sess_events).get_data(channels=channels, start_time=self.params.fr1_start_time, end_time=self.params.fr1_end_time,
            #                             buffer_time=self.params.fr1_buf, eoffset='eegoffset', keep_buffer=True, eoffset_in_time=False)

            # from ptsa.data.readers import TimeSeriesEEGReader
            # time_series_reader = TimeSeriesEEGReader(events=sess_events, start_time=self.params.fr1_start_time,
            #                                  end_time=self.params.fr1_end_time, buffer_time=self.params.fr1_buf, keep_buffer=True)
            #
            # eegs = time_series_reader.read(monopolar_channels)

            # VERSION 2/22/2016
            # eeg_reader = EEGReader(events=sess_events, channels=monopolar_channels,
            #                        start_time=self.params.fr1_start_time,
            #                        end_time=self.params.fr1_end_time, buffer_time=self.params.fr1_buf)

            # VERSION WITH MIRRORING
            eeg_reader = EEGReader(events=sess_events, channels=monopolar_channels,
                                   start_time=self.params.fr1_start_time,
                                   end_time=self.params.fr1_end_time, buffer_time=0.0)


            eegs = eeg_reader.read()
            if eeg_reader.removed_bad_data():
                print 'REMOVED SOME BAD EVENTS !!!'
                sess_events = eegs['events'].values.view(np.recarray)
                n_events = len(sess_events)
                events = np.hstack((events[events.session!=sess],sess_events)).view(np.recarray)
                ev_order = np.argsort(events, order=('session','list','mstime'))
                events = events[ev_order]
                self.pass_object(self.pipeline.task+'_events', events)


            # mirroring
            #eegs[...,:1365] = eegs[...,2730:1365:-1]
            #eegs[...,2731:4096] = eegs[...,2729:1364:-1]

            eegs = eegs.add_mirror_buffer(duration=self.params.fr1_buf)


            if self.samplerate is None:
                self.samplerate = float(eegs.samplerate)
                winsize = int(round(self.samplerate*(self.params.fr1_end_time-self.params.fr1_start_time+2*self.params.fr1_buf)))
                bufsize = int(round(self.samplerate*self.params.fr1_buf))
                print 'samplerate =', self.samplerate, 'winsize =', winsize, 'bufsize =', bufsize
                pow_ev = np.empty(shape=n_freqs*winsize, dtype=float)
                self.wavelet_transform.init(self.params.width, self.params.freqs[0], self.params.freqs[-1], n_freqs, self.samplerate, winsize)

            print 'Computing FR1 powers'

            sess_pow_mat = np.empty(shape=(n_events, n_bps, n_freqs), dtype=np.float)

            #monopolar_channels_np = np.array(monopolar_channels)
            for i,ti in enumerate(bipolar_pairs):
                # print bp
                # print monopolar_channels

                # print np.where(monopolar_channels == bp[0])
                # print np.where(monopolar_channels == bp[1])
                bp = ti['channel_str']
                print 'Computing powers for bipolar pair', bp
                elec1 = np.where(monopolar_channels == bp[0])[0][0]
                elec2 = np.where(monopolar_channels == bp[1])[0][0]
                # print 'elec1=',elec1
                # print 'elec2=',elec2
                # eegs_elec1 = eegs[elec1]
                # eegs_elec2 = eegs[elec2]
                # print 'eegs_elec1=',eegs_elec1
                # print 'eegs_elec2=',eegs_elec2
                # eegs_elec1.reset_coords('channels')
                # eegs_elec2.reset_coords('channels')

                bp_data = eegs[elec1] - eegs[elec2]
                bp_data.attrs['samplerate'] = self.samplerate

                # bp_data = eegs[elec1] - eegs[elec2]
                # bp_data = eegs[elec1] - eegs[elec2]
                # bp_data = eegs.values[elec1] - eegs.values[elec2]

                bp_data = bp_data.filtered([58,62], filt_type='stop', order=self.params.filt_order)
                for ev in xrange(n_events):
                    self.wavelet_transform.multiphasevec(bp_data[ev][0:winsize], pow_ev)
                    #if np.min(pow_ev) < 0.0:
                    #    print ev, events[ev]
                    #    joblib.dump(bp_data[ev], 'bad_bp_ev%d'%ev)
                    #    joblib.dump(eegs[elec1][ev], 'bad_elec1_ev%d'%ev)
                    #    joblib.dump(eegs[elec2][ev], 'bad_elec2_ev%d'%ev)
                    #    print 'Negative powers detected'
                    #    import sys
                    #    sys.exit(1)
                    pow_ev_stripped = np.reshape(pow_ev, (n_freqs,winsize))[:,bufsize:winsize-bufsize]
                    if self.params.log_powers:
                        np.log10(pow_ev_stripped, out=pow_ev_stripped)
                    sess_pow_mat[ev,i,:] = np.nanmean(pow_ev_stripped, axis=1)

            self.pow_mat = np.concatenate((self.pow_mat,sess_pow_mat), axis=0) if self.pow_mat is not None else sess_pow_mat

        self.pow_mat = np.reshape(self.pow_mat, (len(events), n_bps*n_freqs))
# Example #4
# 0
class ComputeFR1Powers(RamTask):
    def __init__(self, params, mark_as_completed=True):
        RamTask.__init__(self, mark_as_completed)
        self.params = params
        self.pow_mat = None
        self.samplerate = None
        self.wavelet_transform = MorletWaveletTransform()

    def restore(self):
        subject = self.pipeline.subject
        task = self.pipeline.task

        self.pow_mat = joblib.load(self.get_path_to_resource_in_workspace(subject + '-' + task + '-pow_mat.pkl'))
        self.samplerate = joblib.load(self.get_path_to_resource_in_workspace(subject + '-samplerate.pkl'))

        self.pass_object('pow_mat', self.pow_mat)
        self.pass_object('samplerate', self.samplerate)

    def run(self):
        subject = self.pipeline.subject
        task = self.pipeline.task

        events = self.get_passed_object(task+'_events')

        sessions = np.unique(events.session)
        print 'sessions:', sessions

        channels = self.get_passed_object('channels')
        tal_info = self.get_passed_object('tal_info')
        self.compute_powers(events, sessions, channels, tal_info)

        self.pass_object('pow_mat', self.pow_mat)
        self.pass_object('samplerate', self.samplerate)

        joblib.dump(self.pow_mat, self.get_path_to_resource_in_workspace(subject + '-' + task + '-pow_mat.pkl'))
        joblib.dump(self.samplerate, self.get_path_to_resource_in_workspace(subject + '-samplerate.pkl'))

    def compute_powers(self, events, sessions, channels, tal_info):
        n_freqs = len(self.params.freqs)
        n_bps = len(tal_info)

        self.wavelet_transform.init(5, self.params.freqs[0], self.params.freqs[-1], n_freqs, 1000.0, 4096)

        self.pow_mat = None

        pow_ev = np.empty(shape=n_freqs*4096, dtype=float)
        for sess in sessions:
            sess_events = events[events.session == sess]
            n_events = len(sess_events)

            print 'Loading EEG for', n_events, 'events of session', sess

            eegs = Events(sess_events).get_data(channels=channels, start_time=self.params.fr1_start_time, end_time=self.params.fr1_end_time,
                                        buffer_time=self.params.fr1_buf, eoffset='eegoffset', keep_buffer=True, eoffset_in_time=False)

            #print describe(eegs)

            # mirroring
            #eegs[...,:1365] = eegs[...,2730:1365:-1]
            #eegs[...,2731:4096] = eegs[...,2729:1364:-1]

            self.samplerate = eegs.samplerate

            print 'Computing FR1 powers'

            sess_pow_mat = np.empty(shape=(n_events, n_bps, n_freqs), dtype=np.float)

            for i,ti in enumerate(tal_info):
                bp = ti['channel_str']
                print 'Computing powers for bipolar pair', bp
                elec1 = np.where(channels == bp[0])[0][0]
                elec2 = np.where(channels == bp[1])[0][0]
                bp_data = eegs[elec1] - eegs[elec2]
                bp_data = bp_data.filtered([58,62], filt_type='stop', order=self.params.filt_order)
                for ev in xrange(n_events):
                    #pow_ev = phase_pow_multi(self.params.freqs, bp_data[ev][0:4096], to_return='power')
                    self.wavelet_transform.multiphasevec(bp_data[ev][0:4096], pow_ev)
                    if np.min(pow_ev) < 0.0:
                        print ev, events[ev]
                        joblib.dump(bp_data[ev], 'bad_bp_ev%d'%ev)
                        joblib.dump(eegs[elec1][ev], 'bad_elec1_ev%d'%ev)
                        joblib.dump(eegs[elec2][ev], 'bad_elec2_ev%d'%ev)
                        print 'Negative powers detected'
                        import sys
                        sys.exit(1)

                    if self.params.log_powers:
                        np.log10(pow_ev, out=pow_ev)

                    pow_ev_stripped = np.reshape(pow_ev, (n_freqs,4096))[:,1365:1365+1366]
                    sess_pow_mat[ev,i,:] = np.nanmean(pow_ev_stripped, axis=1)

            self.pow_mat = np.concatenate((self.pow_mat,sess_pow_mat), axis=0) if self.pow_mat is not None else sess_pow_mat

        self.pow_mat = np.reshape(self.pow_mat, (len(events), n_bps*n_freqs))
# Example #5
# 0
class ComputePSPowers(RamTask):
    def __init__(self, params, mark_as_completed=True):
        RamTask.__init__(self, mark_as_completed)
        self.params = params
        self.wavelet_transform = MorletWaveletTransform()

    def restore(self):
        subject = self.pipeline.subject
        experiment = self.pipeline.experiment

        ps_pow_mat_pre = joblib.load(self.get_path_to_resource_in_workspace(subject+'-'+experiment+'-ps_pow_mat_pre.pkl'))
        ps_pow_mat_post = joblib.load(self.get_path_to_resource_in_workspace(subject+'-'+experiment+'-ps_pow_mat_post.pkl'))

        self.pass_object('ps_pow_mat_pre',ps_pow_mat_pre)
        self.pass_object('ps_pow_mat_post',ps_pow_mat_post)


    def run(self):
        subject = self.pipeline.subject
        experiment = self.pipeline.experiment

        #fetching objects from other tasks
        events = self.get_passed_object(self.pipeline.experiment+'_events')
        channels = self.get_passed_object('channels')
        tal_info = self.get_passed_object('tal_info')

        sessions = np.unique(events.session)
        print experiment, 'sessions:', sessions

        ps_pow_mat_pre, ps_pow_mat_post = self.compute_ps_powers(events, sessions, channels, tal_info, experiment)

        joblib.dump(ps_pow_mat_pre, self.get_path_to_resource_in_workspace(subject+'-'+experiment+'-ps_pow_mat_pre.pkl'))
        joblib.dump(ps_pow_mat_post, self.get_path_to_resource_in_workspace(subject+'-'+experiment+'-ps_pow_mat_post.pkl'))

        self.pass_object('ps_pow_mat_pre',ps_pow_mat_pre)
        self.pass_object('ps_pow_mat_post',ps_pow_mat_post)

    def compute_ps_powers(self, events, sessions, channels, tal_info, experiment):
        n_freqs = len(self.params.freqs)
        n_bps = len(tal_info)

        pow_mat_pre = pow_mat_post = None

        pow_ev = None
        samplerate = winsize = bufsize = None
        for sess in sessions:
            sess_events = events[events.session == sess]
            n_events = len(sess_events)

            print 'Loading EEG for', n_events, 'events of session', sess

            pre_start_time = self.params.ps_start_time - self.params.ps_offset
            pre_end_time = self.params.ps_end_time - self.params.ps_offset
            eegs_pre = Events(sess_events).get_data(channels=channels, start_time=pre_start_time, end_time=pre_end_time,
                        buffer_time=self.params.ps_buf, eoffset='eegoffset', keep_buffer=True, eoffset_in_time=False)

            if samplerate is None:
                samplerate = round(eegs_pre.samplerate)
                winsize = int(round(samplerate*(pre_end_time-pre_start_time+2*self.params.ps_buf)))
                bufsize = int(round(samplerate*self.params.ps_buf))
                print 'samplerate =', samplerate, 'winsize =', winsize, 'bufsize =', bufsize
                pow_ev = np.empty(shape=n_freqs*winsize, dtype=float)
                self.wavelet_transform.init(self.params.width, self.params.freqs[0], self.params.freqs[-1], n_freqs, samplerate, winsize)

            # mirroring
            nb_ = int(round(samplerate*(self.params.ps_buf)))
            eegs_pre[...,-nb_:] = eegs_pre[...,-nb_-1:-2*nb_-1:-1]

            dim3_pre = eegs_pre.shape[2]  # because post-stim time inreval does not align for all stim events (stims have different duration)
                                          # we have to take care of aligning eegs_post ourselves time dim to dim3

            eegs_post = np.zeros_like(eegs_pre)
            post_start_time = self.params.ps_offset
            post_end_time = self.params.ps_offset + (self.params.ps_end_time - self.params.ps_start_time)
            for i_ev in xrange(n_events):
                ev_offset = sess_events[i_ev].pulse_duration
                if ev_offset > 0:
                    if experiment == 'PS3' and sess_events[i_ev].nBursts > 0:
                        ev_offset *= sess_events[i_ev].nBursts + 1
                    ev_offset *= 0.001
                else:
                    ev_offset = 0.0

                eeg_post = Events(sess_events[i_ev:i_ev+1]).get_data(channels=channels, start_time=post_start_time+ev_offset,
                            end_time=post_end_time+ev_offset, buffer_time=self.params.ps_buf,
                            eoffset='eegoffset', keep_buffer=True, eoffset_in_time=False)
                
                dim3_post = eeg_post.shape[2]
                # here we take care of possible mismatch of time dim length
                if dim3_pre == dim3_post:
                    eegs_post[:,i_ev:i_ev+1,:] = eeg_post
                elif dim3_pre < dim3_post:
                    eegs_post[:,i_ev:i_ev+1,:] = eeg_post[:,:,:-1]
                else:
                    eegs_post[:,i_ev:i_ev+1,:-1] = eeg_post

            # mirroring
            eegs_post[...,:nb_] = eegs_post[...,2*nb_-1:nb_-1:-1]

            print 'Computing', experiment, 'powers'

            sess_pow_mat_pre = np.empty(shape=(n_events, n_bps, n_freqs), dtype=np.float)
            sess_pow_mat_post = np.empty_like(sess_pow_mat_pre)

            for i,ti in enumerate(tal_info):
                bp = ti['channel_str']
                print 'Computing powers for bipolar pair', bp
                elec1 = np.where(channels == bp[0])[0][0]
                elec2 = np.where(channels == bp[1])[0][0]

                bp_data_pre = eegs_pre[elec1] - eegs_pre[elec2]
                bp_data_pre = bp_data_pre.filtered([58,62], filt_type='stop', order=self.params.filt_order)
                for ev in xrange(n_events):
                    #pow_pre_ev = phase_pow_multi(self.params.freqs, bp_data_pre[ev], to_return='power')
                    self.wavelet_transform.multiphasevec(bp_data_pre[ev][0:winsize], pow_ev)
                    #sess_pow_mat_pre[ev,i,:] = np.mean(pow_pre_ev[:,nb_:-nb_], axis=1)
                    pow_ev_stripped = np.reshape(pow_ev, (n_freqs,winsize))[:,bufsize:winsize-bufsize]
                    if self.params.log_powers:
                        np.log10(pow_ev_stripped, out=pow_ev_stripped)
                    sess_pow_mat_pre[ev,i,:] = np.nanmean(pow_ev_stripped, axis=1)

                bp_data_post = eegs_post[elec1] - eegs_post[elec2]
                bp_data_post = bp_data_post.filtered([58,62], filt_type='stop', order=self.params.filt_order)
                for ev in xrange(n_events):
                    #pow_post_ev = phase_pow_multi(self.params.freqs, bp_data_post[ev], to_return='power')
                    self.wavelet_transform.multiphasevec(bp_data_post[ev][0:winsize], pow_ev)
                    #sess_pow_mat_post[ev,i,:] = np.mean(pow_post_ev[:,nb_:-nb_], axis=1)
                    pow_ev_stripped = np.reshape(pow_ev, (n_freqs,winsize))[:,bufsize:winsize-bufsize]
                    if self.params.log_powers:
                        np.log10(pow_ev_stripped, out=pow_ev_stripped)
                    sess_pow_mat_post[ev,i,:] = np.nanmean(pow_ev_stripped, axis=1)

            sess_pow_mat_pre = sess_pow_mat_pre.reshape((n_events, n_bps*n_freqs))
            sess_pow_mat_pre = zscore(sess_pow_mat_pre, axis=0, ddof=1)

            sess_pow_mat_post = sess_pow_mat_post.reshape((n_events, n_bps*n_freqs))
            sess_pow_mat_post = zscore(sess_pow_mat_post, axis=0, ddof=1)

            pow_mat_pre = np.vstack((pow_mat_pre,sess_pow_mat_pre)) if pow_mat_pre is not None else sess_pow_mat_pre
            pow_mat_post = np.vstack((pow_mat_post,sess_pow_mat_post)) if pow_mat_post is not None else sess_pow_mat_post

        return pow_mat_pre, pow_mat_post
class ComputePSPowers(RamTask):
    def __init__(self, params, mark_as_completed=True):
        RamTask.__init__(self, mark_as_completed)
        self.params = params
        self.wavelet_transform = MorletWaveletTransform()

    def restore(self):
        subject = self.pipeline.subject
        experiment = self.pipeline.experiment

        ps_pow_mat_pre = joblib.load(
            self.get_path_to_resource_in_workspace(subject + "-" + experiment + "-ps_pow_mat_pre.pkl")
        )
        ps_pow_mat_post = joblib.load(
            self.get_path_to_resource_in_workspace(subject + "-" + experiment + "-ps_pow_mat_post.pkl")
        )

        self.pass_object("ps_pow_mat_pre", ps_pow_mat_pre)
        self.pass_object("ps_pow_mat_post", ps_pow_mat_post)

    def run(self):
        subject = self.pipeline.subject
        experiment = self.pipeline.experiment

        # fetching objects from other tasks
        events = self.get_passed_object(self.pipeline.experiment + "_events")
        # channels = self.get_passed_object('channels')
        # tal_info = self.get_passed_object('tal_info')
        monopolar_channels = self.get_passed_object("monopolar_channels")
        bipolar_pairs = self.get_passed_object("bipolar_pairs")

        sessions = np.unique(events.session)
        print experiment, "sessions:", sessions

        ps_pow_mat_pre, ps_pow_mat_post = self.compute_ps_powers(
            events, sessions, monopolar_channels, bipolar_pairs, experiment
        )

        joblib.dump(
            ps_pow_mat_pre, self.get_path_to_resource_in_workspace(subject + "-" + experiment + "-ps_pow_mat_pre.pkl")
        )
        joblib.dump(
            ps_pow_mat_post, self.get_path_to_resource_in_workspace(subject + "-" + experiment + "-ps_pow_mat_post.pkl")
        )

        self.pass_object("ps_pow_mat_pre", ps_pow_mat_pre)
        self.pass_object("ps_pow_mat_post", ps_pow_mat_post)

    def compute_ps_powers(self, events, sessions, monopolar_channels, bipolar_pairs, experiment):
        n_freqs = len(self.params.freqs)
        n_bps = len(bipolar_pairs)

        pow_mat_pre = pow_mat_post = None

        pow_ev = None
        samplerate = winsize = bufsize = None

        monopolar_channels_list = list(monopolar_channels)
        for sess in sessions:
            sess_events = events[events.session == sess]
            # print type(sess_events)

            n_events = len(sess_events)

            print "Loading EEG for", n_events, "events of session", sess

            pre_start_time = self.params.ps_start_time - self.params.ps_offset
            pre_end_time = self.params.ps_end_time - self.params.ps_offset

            # eegs_pre = Events(sess_events).get_data(channels=channels, start_time=pre_start_time, end_time=pre_end_time,
            #             buffer_time=self.params.ps_buf, eoffset='eegoffset', keep_buffer=True, eoffset_in_time=False)

            eeg_pre_reader = EEGReader(
                events=sess_events,
                channels=np.array(monopolar_channels_list),
                start_time=pre_start_time,
                end_time=pre_end_time,
                buffer_time=self.params.ps_buf,
            )

            eegs_pre = eeg_pre_reader.read()
            if eeg_pre_reader.removed_bad_data():
                print "REMOVED SOME BAD EVENTS !!!"
                sess_events = eegs_pre["events"].values.view(np.recarray)
                n_events = len(sess_events)
                events = np.hstack((events[events.session != sess], sess_events)).view(np.recarray)
                ev_order = np.argsort(events, order=("session", "mstime"))
                events = events[ev_order]
                self.pass_object(self.pipeline.experiment + "_events", events)

            if samplerate is None:
                # samplerate = round(eegs_pre.samplerate)
                # samplerate = eegs_pre.attrs['samplerate']

                samplerate = float(eegs_pre["samplerate"])

                winsize = int(round(samplerate * (pre_end_time - pre_start_time + 2 * self.params.ps_buf)))
                bufsize = int(round(samplerate * self.params.ps_buf))
                print "samplerate =", samplerate, "winsize =", winsize, "bufsize =", bufsize
                pow_ev = np.empty(shape=n_freqs * winsize, dtype=float)
                self.wavelet_transform.init(
                    self.params.width, self.params.freqs[0], self.params.freqs[-1], n_freqs, samplerate, winsize
                )

            # mirroring
            nb_ = int(round(samplerate * (self.params.ps_buf)))
            eegs_pre[..., -nb_:] = eegs_pre[..., -nb_ - 2 : -2 * nb_ - 2 : -1]

            dim3_pre = eegs_pre.shape[
                2
            ]  # because post-stim time inreval does not align for all stim events (stims have different duration)
            # we have to take care of aligning eegs_post ourselves time dim to dim3

            # eegs_post = np.zeros_like(eegs_pre)

            from ptsa.data.TimeSeriesX import TimeSeriesX

            eegs_post = TimeSeriesX(np.zeros_like(eegs_pre), dims=eegs_pre.dims, coords=eegs_pre.coords)

            post_start_time = self.params.ps_offset
            post_end_time = self.params.ps_offset + (self.params.ps_end_time - self.params.ps_start_time)
            for i_ev in xrange(n_events):
                ev_offset = (
                    sess_events[i_ev].pulse_duration if experiment != "PS3" else sess_events[i_ev].train_duration
                )
                if ev_offset > 0:
                    ev_offset *= 0.001
                else:
                    ev_offset = 0.0

                # eeg_post = Events(sess_events[i_ev:i_ev+1]).get_data(channels=channels, start_time=post_start_time+ev_offset,
                #             end_time=post_end_time+ev_offset, buffer_time=self.params.ps_buf,
                #             eoffset='eegoffset', keep_buffer=True, eoffset_in_time=False)

                eeg_post_reader = EEGReader(
                    events=sess_events[i_ev : i_ev + 1],
                    channels=np.array(monopolar_channels_list),
                    start_time=post_start_time + ev_offset,
                    end_time=post_end_time + ev_offset,
                    buffer_time=self.params.ps_buf,
                )

                eeg_post = eeg_post_reader.read()

                dim3_post = eeg_post.shape[2]
                # here we take care of possible mismatch of time dim length
                if dim3_pre == dim3_post:
                    eegs_post[:, i_ev : i_ev + 1, :] = eeg_post
                elif dim3_pre < dim3_post:
                    eegs_post[:, i_ev : i_ev + 1, :] = eeg_post[:, :, :-1]
                else:
                    eegs_post[:, i_ev : i_ev + 1, :-1] = eeg_post

            # mirroring
            eegs_post[..., :nb_] = eegs_post[..., 2 * nb_ : nb_ : -1]

            print "Computing", experiment, "powers"

            sess_pow_mat_pre = np.empty(shape=(n_events, n_bps, n_freqs), dtype=np.float)
            sess_pow_mat_post = np.empty_like(sess_pow_mat_pre)

            for i, ti in enumerate(bipolar_pairs):
                bp = ti["channel_str"]
                print "Computing powers for bipolar pair", bp
                elec1 = np.where(monopolar_channels == bp[0])[0][0]
                elec2 = np.where(monopolar_channels == bp[1])[0][0]

                #
                # for i,ti in enumerate(tal_info):
                #     bp = ti['channel_str']
                #     print 'Computing powers for bipolar pair', bp
                #     elec1 = np.where(channels == bp[0])[0][0]
                #     elec2 = np.where(channels == bp[1])[0][0]

                bp_data_pre = eegs_pre[elec1] - eegs_pre[elec2]
                # bp_data_pre.attrs['samplerate'] = samplerate

                bp_data_pre = bp_data_pre.filtered([58, 62], filt_type="stop", order=self.params.filt_order)
                for ev in xrange(n_events):
                    # pow_pre_ev = phase_pow_multi(self.params.freqs, bp_data_pre[ev], to_return='power')
                    self.wavelet_transform.multiphasevec(bp_data_pre[ev][0:winsize], pow_ev)
                    # sess_pow_mat_pre[ev,i,:] = np.mean(pow_pre_ev[:,nb_:-nb_], axis=1)
                    pow_ev_stripped = np.reshape(pow_ev, (n_freqs, winsize))[:, bufsize : winsize - bufsize]
                    if self.params.log_powers:
                        np.log10(pow_ev_stripped, out=pow_ev_stripped)
                    sess_pow_mat_pre[ev, i, :] = np.nanmean(pow_ev_stripped, axis=1)

                bp_data_post = eegs_post[elec1] - eegs_post[elec2]
                # bp_data_post.attrs['samplerate'] = samplerate

                bp_data_post = bp_data_post.filtered([58, 62], filt_type="stop", order=self.params.filt_order)
                for ev in xrange(n_events):
                    # pow_post_ev = phase_pow_multi(self.params.freqs, bp_data_post[ev], to_return='power')
                    self.wavelet_transform.multiphasevec(bp_data_post[ev][0:winsize], pow_ev)
                    # sess_pow_mat_post[ev,i,:] = np.mean(pow_post_ev[:,nb_:-nb_], axis=1)
                    pow_ev_stripped = np.reshape(pow_ev, (n_freqs, winsize))[:, bufsize : winsize - bufsize]
                    if self.params.log_powers:
                        np.log10(pow_ev_stripped, out=pow_ev_stripped)
                    sess_pow_mat_post[ev, i, :] = np.nanmean(pow_ev_stripped, axis=1)

            sess_pow_mat_pre = sess_pow_mat_pre.reshape((n_events, n_bps * n_freqs))
            # sess_pow_mat_pre = zscore(sess_pow_mat_pre, axis=0, ddof=1)

            sess_pow_mat_post = sess_pow_mat_post.reshape((n_events, n_bps * n_freqs))
            # sess_pow_mat_post = zscore(sess_pow_mat_post, axis=0, ddof=1)

            sess_pow_mat_joint = zscore(np.vstack((sess_pow_mat_pre, sess_pow_mat_post)), axis=0, ddof=1)
            sess_pow_mat_pre = sess_pow_mat_joint[:n_events, ...]
            sess_pow_mat_post = sess_pow_mat_joint[n_events:, ...]

            pow_mat_pre = np.vstack((pow_mat_pre, sess_pow_mat_pre)) if pow_mat_pre is not None else sess_pow_mat_pre
            pow_mat_post = (
                np.vstack((pow_mat_post, sess_pow_mat_post)) if pow_mat_post is not None else sess_pow_mat_post
            )

        return pow_mat_pre, pow_mat_post
class ComputeControlPowers(RamTask):
    def __init__(self, params, mark_as_completed=True):
        """Store the task parameters and set up empty per-run state.

        ``pow_mat`` and ``samplerate`` start as None and are filled in by
        ``compute_powers`` / ``restore``. The Morlet wavelet transform object
        is created once here; ``compute_powers`` calls its ``init`` later.
        """
        RamTask.__init__(self, mark_as_completed)
        self.params = params
        self.pow_mat, self.samplerate = None, None
        self.wavelet_transform = MorletWaveletTransform()

    def restore(self):
        """Reload the four cached control power matrices and pass them downstream.

        Each matrix is read from ``<subject>-<task><file_tag>`` in the
        workspace and published under its object name. ``self.pow_mat`` is
        left holding the last matrix loaded (the 1.2 one).
        """
        prefix = self.pipeline.subject + '-' + self.pipeline.task
        cached = (
            ('-control_pow_mat_pre.pkl', 'control_pow_mat_pre'),
            ('-control_pow_mat_0.45.pkl', 'control_pow_mat_045'),
            ('-control_pow_mat_0.7.pkl', 'control_pow_mat_07'),
            ('-control_pow_mat_1.2.pkl', 'control_pow_mat_12'),
        )
        for file_tag, object_name in cached:
            self.pow_mat = joblib.load(self.get_path_to_resource_in_workspace(prefix + file_tag))
            self.pass_object(object_name, self.pow_mat)

    def run(self):
        subject = self.pipeline.subject
        task = self.pipeline.task

        events = self.get_passed_object(task+'_control_events')

        sessions = np.unique(events.session)
        print 'sessions:', sessions

        # channels = self.get_passed_object('channels')
        # tal_info = self.get_passed_object('tal_info')
        monopolar_channels = self.get_passed_object('monopolar_channels')
        bipolar_pairs = self.get_passed_object('bipolar_pairs')

        self.compute_powers(events, sessions, monopolar_channels, bipolar_pairs, self.params.control_start_time, self.params.control_end_time, False, True)
        self.pass_object('control_pow_mat_pre', self.pow_mat)
        joblib.dump(self.pow_mat, self.get_path_to_resource_in_workspace(subject + '-' + task + '-control_pow_mat_pre.pkl'))

        self.samplerate = None
        self.compute_powers(events, sessions, monopolar_channels, bipolar_pairs, self.params.control_start_time+0.45, self.params.control_end_time+0.45, True, False)
        self.pass_object('control_pow_mat_045', self.pow_mat)
        joblib.dump(self.pow_mat, self.get_path_to_resource_in_workspace(subject + '-' + task + '-control_pow_mat_0.45.pkl'))

        self.samplerate = None
        self.compute_powers(events, sessions, monopolar_channels, bipolar_pairs, self.params.control_start_time+0.7, self.params.control_end_time+0.7, True, False)
        self.pass_object('control_pow_mat_07', self.pow_mat)
        joblib.dump(self.pow_mat, self.get_path_to_resource_in_workspace(subject + '-' + task + '-control_pow_mat_0.7.pkl'))

        self.samplerate = None
        self.compute_powers(events, sessions, monopolar_channels, bipolar_pairs, self.params.control_start_time+1.2, self.params.control_end_time+1.2, True, False)
        self.pass_object('control_pow_mat_12', self.pow_mat)
        joblib.dump(self.pow_mat, self.get_path_to_resource_in_workspace(subject + '-' + task + '-control_pow_mat_1.2.pkl'))

        #self.pass_object('samplerate', self.samplerate)


    def compute_powers(self, events, sessions, monopolar_channels, bipolar_pairs, start_time, end_time, mirror_front, mirror_back):
        n_freqs = len(self.params.freqs)
        n_bps = len(bipolar_pairs)

        self.pow_mat = None

        pow_ev = None
        winsize = bufsize = None
        for sess in sessions:
            sess_events = events[events.session == sess]
            n_events = len(sess_events)

            print 'Loading EEG for', n_events, 'events of session', sess

            # eegs = Events(sess_events).get_data(channels=channels, start_time=self.params.control_start_time, end_time=self.params.control_end_time,
            #                             buffer_time=self.params.control_buf, eoffset='eegoffset', keep_buffer=True, eoffset_in_time=False)

            # from ptsa.data.readers import TimeSeriesEEGReader
            # time_series_reader = TimeSeriesEEGReader(events=sess_events, start_time=self.params.control_start_time,
            #                                  end_time=self.params.control_end_time, buffer_time=self.params.control_buf, keep_buffer=True)
            #
            # eegs = time_series_reader.read(monopolar_channels)

            eeg_reader = EEGReader(events=sess_events, channels = monopolar_channels,
                                   start_time=start_time, end_time=end_time, buffer_time=self.params.control_buf)

            eegs = eeg_reader.read()

            # print 'eegs=',eegs.values[0,0,:2],eegs.values[0,0,-2:]
            # sys.exit()
            #
            # a = eegs[0]-eegs[1]

            #eegs[...,:1365] = eegs[...,2730:1365:-1]
            #eegs[...,2731:4096] = eegs[...,2729:1364:-1]

            if self.samplerate is None:
                self.samplerate = float(eegs.samplerate)
                winsize = int(round(self.samplerate*(self.params.control_end_time-self.params.control_start_time+2*self.params.control_buf)))
                bufsize = int(round(self.samplerate*self.params.control_buf))
                print 'samplerate =', self.samplerate, 'winsize =', winsize, 'bufsize =', bufsize
                pow_ev = np.empty(shape=n_freqs*winsize, dtype=float)
                self.wavelet_transform.init(self.params.width, self.params.freqs[0], self.params.freqs[-1], n_freqs, self.samplerate, winsize)

            # mirroring
            nb_ = int(round(self.samplerate*(self.params.control_buf)))
            if mirror_front:
                eegs[...,:nb_] = eegs[...,2*nb_-1:nb_-1:-1]
            if mirror_back:
                eegs[...,-nb_:] = eegs[...,-nb_-1:-2*nb_-1:-1]

            print 'Computing control powers'

            sess_pow_mat = np.empty(shape=(n_events, n_bps, n_freqs), dtype=np.float)

            #monopolar_channels_np = np.array(monopolar_channels)
            for i,ti in enumerate(bipolar_pairs):
                # print bp
                # print monopolar_channels

                # print np.where(monopolar_channels == bp[0])
                # print np.where(monopolar_channels == bp[1])
                bp = ti['channel_str']
                print 'Computing powers for bipolar pair', bp
                elec1 = np.where(monopolar_channels == bp[0])[0][0]
                elec2 = np.where(monopolar_channels == bp[1])[0][0]
                # print 'elec1=',elec1
                # print 'elec2=',elec2
                # eegs_elec1 = eegs[elec1]
                # eegs_elec2 = eegs[elec2]
                # print 'eegs_elec1=',eegs_elec1
                # print 'eegs_elec2=',eegs_elec2
                # eegs_elec1.reset_coords('channels')
                # eegs_elec2.reset_coords('channels')

                bp_data = eegs[elec1] - eegs[elec2]
                bp_data.attrs['samplerate'] = self.samplerate

                # bp_data = eegs[elec1] - eegs[elec2]
                # bp_data = eegs[elec1] - eegs[elec2]
                # bp_data = eegs.values[elec1] - eegs.values[elec2]

                bp_data = bp_data.filtered([58,62], filt_type='stop', order=self.params.filt_order)
                for ev in xrange(n_events):
                    self.wavelet_transform.multiphasevec(bp_data[ev][0:winsize], pow_ev)
                    #if np.min(pow_ev) < 0.0:
                    #    print ev, events[ev]
                    #    joblib.dump(bp_data[ev], 'bad_bp_ev%d'%ev)
                    #    joblib.dump(eegs[elec1][ev], 'bad_elec1_ev%d'%ev)
                    #    joblib.dump(eegs[elec2][ev], 'bad_elec2_ev%d'%ev)
                    #    print 'Negative powers detected'
                    #    import sys
                    #    sys.exit(1)
                    pow_ev_stripped = np.reshape(pow_ev, (n_freqs,winsize))[:,bufsize:winsize-bufsize]
                    if self.params.log_powers:
                        np.log10(pow_ev_stripped, out=pow_ev_stripped)
                    sess_pow_mat[ev,i,:] = np.nanmean(pow_ev_stripped, axis=1)

            sess_pow_mat = zscore(sess_pow_mat, axis=0, ddof=1)
            self.pow_mat = np.concatenate((self.pow_mat,sess_pow_mat), axis=0) if self.pow_mat is not None else sess_pow_mat

        self.pow_mat = np.reshape(self.pow_mat, (len(events), n_bps*n_freqs))