Ejemplo n.º 1
0
class PairView(QtGui.QWidget):
    """For analyzing pre/post-synaptic pairs.

    Plots the presynaptic traces and the filtered postsynaptic traces of the
    selected sweeps, detects evoked presynaptic spikes, and averages the
    postsynaptic response that follows each stimulus pulse.  It then fits a
    PSP model to each per-pulse average and to the global average, and shows
    the fit parameters in a table.  Tables ("event sets") can be stored in a
    list and fed to a release-model fit with the "fit all" button.
    """
    def __init__(self, parent=None):
        # Analysis state (assigned before the Qt base initializer, as before).
        self.sweeps = []
        self.channels = []

        # (pre, post, events, sweeps) for the current analysis, or None.
        self.current_event_set = None
        # Event sets stored with the "add" button.
        self.event_sets = []

        QtGui.QWidget.__init__(self, parent)

        self.layout = QtGui.QGridLayout()
        self.setLayout(self.layout)
        self.layout.setContentsMargins(0, 0, 0, 0)

        self.vsplit = QtGui.QSplitter(QtCore.Qt.Vertical)
        self.layout.addWidget(self.vsplit, 0, 0)

        # Raw trace plots; the post plot shares the pre plot's x (time) axis.
        self.pre_plot = pg.PlotWidget()
        self.post_plot = pg.PlotWidget()
        for plt in (self.pre_plot, self.post_plot):
            plt.setClipToView(True)
            plt.setDownsampling(True, True, 'peak')

        self.post_plot.setXLink(self.pre_plot)
        self.vsplit.addWidget(self.pre_plot)
        self.vsplit.addWidget(self.post_plot)

        # One column per pulse plus one for the global average.
        self.response_plots = PlotGrid()
        self.vsplit.addWidget(self.response_plots)

        self.event_splitter = QtGui.QSplitter(QtCore.Qt.Horizontal)
        self.vsplit.addWidget(self.event_splitter)

        self.event_table = pg.TableWidget()
        self.event_splitter.addWidget(self.event_table)

        self.event_widget = QtGui.QWidget()
        self.event_widget_layout = QtGui.QGridLayout()
        self.event_widget.setLayout(self.event_widget_layout)
        self.event_splitter.addWidget(self.event_widget)

        self.event_splitter.setSizes([600, 100])

        # Row 0 ("current") always refers to current_event_set; stored sets
        # are appended after it.
        self.event_set_list = QtGui.QListWidget()
        self.event_widget_layout.addWidget(self.event_set_list, 0, 0, 1, 2)
        self.event_set_list.addItem("current")
        self.event_set_list.itemSelectionChanged.connect(
            self.event_set_selected)

        self.add_set_btn = QtGui.QPushButton('add')
        self.event_widget_layout.addWidget(self.add_set_btn, 1, 0, 1, 1)
        self.add_set_btn.clicked.connect(self.add_set_clicked)
        self.remove_set_btn = QtGui.QPushButton('del')
        self.event_widget_layout.addWidget(self.remove_set_btn, 1, 1, 1, 1)
        self.remove_set_btn.clicked.connect(self.remove_set_clicked)
        self.fit_btn = QtGui.QPushButton('fit all')
        self.event_widget_layout.addWidget(self.fit_btn, 2, 0, 1, 2)
        self.fit_btn.clicked.connect(self.fit_clicked)

        # Separate window for release-model fits (shown by fit_clicked).
        self.fit_plot = PlotGrid()

        # Created lazily by fit_curve_clicked.
        self.fit_explorer = None

        self.artifact_remover = ArtifactRemover(user_width=True)
        self.baseline_remover = BaselineRemover()
        self.filter = SignalFilter()

        self.params = pg.parametertree.Parameter(
            name='params',
            type='group',
            children=[
                {'name': 'pre', 'type': 'list', 'values': []},
                {'name': 'post', 'type': 'list', 'values': []},
                self.artifact_remover.params,
                self.baseline_remover.params,
                self.filter.params,
                {
                    'name': 'time constant',
                    'type': 'float',
                    'suffix': 's',
                    'siPrefix': True,
                    'value': 10e-3,
                    'dec': True,
                    'minStep': 100e-6,
                },
            ])
        # Any parameter change re-runs the whole analysis.
        self.params.sigTreeStateChanged.connect(self._update_plots)

    def data_selected(self, sweeps, channels):
        """Store a new sweep/channel selection and re-run the analysis.

        The available channels become the choices for the 'pre' and 'post'
        parameters.
        """
        self.sweeps = sweeps
        self.channels = channels

        self.params.child('pre').setLimits(channels)
        self.params.child('post').setLimits(channels)

        self._update_plots()

    @staticmethod
    def _event_row(ident, spike_time, spike_stdev, fit):
        """Build one event-table row: id, spike timing, then all fit values."""
        row = OrderedDict([('id', ident), ('spike_time', spike_time),
                           ('spike_stdev', spike_stdev)])
        row.update((k, fit.best_values[k]) for k in fit.params.keys())
        return row

    def _update_plots(self):
        """Recompute and redraw all traces, averaged responses and PSP fits.

        1. Plot the presynaptic traces and the filtered postsynaptic traces of
           every sweep, and mark detected presynaptic spikes.
        2. For each stimulus pulse, average the postsynaptic response after
           the spike across sweeps, and fit a PSP to that average.  Each
           response lasts at most 40 ms and stops at the next detected spike.
        3. Fit a PSP to the average of the per-pulse averages.
        4. Store the fit parameters as ``current_event_set`` and show them.

        If there are no sweeps, or the pre/post channels are equal or not in
        ``self.channels``, it returns early.  In that case the plots are
        cleared and ``current_event_set`` is None.
        """
        sweeps = self.sweeps
        self.current_event_set = None
        self.event_table.clear()

        # clear all plots
        self.pre_plot.clear()
        self.post_plot.clear()

        pre = self.params['pre']
        post = self.params['post']

        # If there are no selected sweeps or channels have not been set, return
        if (len(sweeps) == 0 or pre == post or pre not in self.channels
                or post not in self.channels):
            return

        pre_mode = sweeps[0][pre].clamp_mode
        post_mode = sweeps[0][post].clamp_mode
        for ch, mode, plot in [(pre, pre_mode, self.pre_plot),
                               (post, post_mode, self.post_plot)]:
            units = 'A' if mode == 'vc' else 'V'
            plot.setLabels(left=("Channel %d" % ch, units),
                           bottom=("Time", 's'))

        # Per sweep: pulse onset indices, spike info per pulse (None = no
        # spike detected), and the filtered postsynaptic trace.
        pulses = []
        spikes = []
        post_traces = []
        for i, sweep in enumerate(sweeps):
            pre_trace = sweep[pre]['primary']
            post_trace = sweep[post]['primary']

            # Detect pulse times from edges of the presynaptic command
            stim = sweep[pre]['command'].data
            sdiff = np.diff(stim)
            on_times = np.argwhere(sdiff > 0)[1:, 0]  # 1: skips test pulse
            off_times = np.argwhere(sdiff < 0)[1:, 0]
            pulses.append(on_times)

            # filter data: artifacts at every pulse edge, baseline, filter
            post_filt = self.artifact_remover.process(
                post_trace,
                list(on_times) + list(off_times))
            post_filt = self.baseline_remover.process(post_filt)
            post_filt = self.filter.process(post_filt)
            post_traces.append(post_filt)

            # plot raw data
            color = pg.intColor(i, hues=len(sweeps) * 1.3, sat=128)
            color.setAlpha(128)
            for trace, plot in [(pre_trace, self.pre_plot),
                                (post_filt, self.post_plot)]:
                plot.plot(trace.time_values,
                          trace.data,
                          pen=color,
                          antialias=False)

            # detect spike times, one entry per (on, off) pulse
            spike_info = [detect_evoked_spike(sweep[pre], [on, off])
                          for on, off in zip(on_times, off_times)]
            spikes.append(spike_info)
            rise_inds = [None if sp is None else sp['rise_index']
                         for sp in spike_info]

            # NOTE(review): dt is taken from each sweep's presynaptic trace,
            # and the value from the last sweep is used below.  This assumes
            # all sweeps share one sample rate -- confirm.
            dt = pre_trace.dt
            vticks = pg.VTickGroup(
                [x * dt for x in rise_inds if x is not None],
                yrange=[0.0, 0.2],
                pen=color)
            self.pre_plot.addItem(vticks)

        npulses = max(map(len, pulses))
        self.response_plots.clear()
        self.response_plots.set_shape(1, npulses + 1)  # 1 extra for global average
        self.response_plots.setYLink(self.response_plots[0, 0])
        for i in range(1, npulses + 1):
            self.response_plots[0, i].hideAxis('left')
        units = 'A' if post_mode == 'vc' else 'V'
        self.response_plots[0, 0].setLabels(
            left=("Averaged events (Channel %d)" % post, units))

        fit_pen = {'color': (30, 30, 255), 'width': 2, 'dash': [1, 1]}
        max_len = int(40e-3 / dt)  # don't take more than 40 ms for any response
        baseline_len = int(1e-3 / dt)  # first 1 ms is used as the baseline
        avg_responses = []
        # (pulse index, fit) pairs.  Pulses with no responses are skipped, so
        # the pulse index must travel with its fit.
        fits = []
        for i in range(npulses):
            # get the chunk of each sweep between spikes
            responses = []
            for sweep_spikes, trace in zip(spikes, post_traces):
                if i >= len(sweep_spikes) or sweep_spikes[i] is None:
                    continue
                start = sweep_spikes[i]['rise_index']
                stop = start + max_len
                # stop at the next detected spike in this sweep, if any
                next_spike = next(
                    (sp for sp in sweep_spikes[i + 1:] if sp is not None),
                    None)
                if next_spike is not None:
                    stop = min(stop, next_spike['rise_index'])
                responses.append(trace.data[start:stop].copy())

            if not responses:
                continue

            # extend all responses to the same length and take the mean,
            # then subtract the baseline mode
            avg = ragged_mean(responses, method='clip')
            avg -= float_mode(avg[:baseline_len])
            avg_responses.append(avg)

            # plot average response for this pulse
            t = np.arange(len(avg)) * dt
            pulse_plot = self.response_plots[0, i]
            pulse_plot.plot(t, avg, pen='w', antialias=True)

            # fit!
            fit = self.fit_psp(avg, t, dt, post_mode)
            fits.append((i, fit))

            # let the user mess with this fit (click opens the FitExplorer)
            curve = pulse_plot.plot(t, fit.eval(), pen=fit_pen,
                                    antialias=True).curve
            curve.setClickable(True)
            curve.fit = fit
            curve.sigClicked.connect(self.fit_curve_clicked)

        # display global average (skipped when no pulse had any response)
        global_fit = None
        if avg_responses:
            global_avg = ragged_mean(avg_responses, method='clip')
            t = np.arange(len(global_avg)) * dt
            global_plot = self.response_plots[0, -1]
            global_plot.plot(t, global_avg, pen='w', antialias=True)
            global_fit = self.fit_psp(global_avg, t, dt, post_mode)
            global_plot.plot(t, global_fit.eval(), pen=fit_pen,
                             antialias=True)

        # display fit parameters in table; the global average goes last
        # (fit_clicked relies on that)
        events = []
        for i, f in fits:
            spt = [s[i]['peak_index'] * dt for s in spikes
                   if i < len(s) and s[i] is not None]
            events.append(self._event_row(i, np.mean(spt), np.std(spt), f))
        if global_fit is not None:
            events.append(self._event_row('avg', np.nan, np.nan, global_fit))

        self.current_event_set = (pre, post, events, sweeps)
        self.event_set_list.setCurrentRow(0)
        # setCurrentRow does not emit a change if row 0 was already selected
        self.event_set_selected()

    def fit_psp(self, data, t, dt, clamp_mode):
        """Fit a ``fitting.Psp`` model to *data* sampled at times *t*.

        The initial amplitude sign comes from whether the mean of *data* is
        below or above the mode of its first 1 ms.  Initial values and bounds
        are larger for current clamp ('ic') than for voltage clamp.  Returns
        the result of ``Psp.fit``.  Callers use its ``eval()``,
        ``best_values`` and ``params``.
        """
        mode = float_mode(data[:int(1e-3 / dt)])
        sign = -1 if data.mean() - mode < 0 else 1
        # tuples presumably read (initial, min, max) / (value, 'fixed')
        params = OrderedDict([
            ('xoffset', (2e-3, 3e-4, 5e-3)),
            ('yoffset', data[0]),
            ('amp', sign * 10e-12),
            ('rise_time', (2e-3, 50e-6, 10e-3)),
            ('decay_tau', (4e-3, 500e-6, 50e-3)),
            ('rise_power', (2.0, 'fixed')),
        ])
        if clamp_mode == 'ic':
            params['amp'] = sign * 10e-3
            params['rise_time'] = (5e-3, 50e-6, 20e-3)
            params['decay_tau'] = (15e-3, 500e-6, 150e-3)

        fit_kws = {'xtol': 1e-3, 'maxfev': 100}

        psp = fitting.Psp()
        return psp.fit(data, x=t, fit_kws=fit_kws, params=params)

    def add_set_clicked(self):
        """Store the current event set and list it as "pre -> post"."""
        if self.current_event_set is None:
            return
        ces = self.current_event_set
        self.event_sets.append(ces)
        item = QtGui.QListWidgetItem("%d -> %d" % (ces[0], ces[1]))
        self.event_set_list.addItem(item)
        item.event_set = ces

    def remove_set_clicked(self):
        """Remove the selected stored event set.

        Row 0 ("current") cannot be removed.
        """
        row = self.event_set_list.currentRow()
        # row 0 is the permanent "current" entry; -1 means nothing selected
        if row <= 0:
            return
        item = self.event_set_list.takeItem(row)
        if item is None:
            return
        # Remove by identity, to avoid element-wise == on nested tables.
        for idx, event_set in enumerate(self.event_sets):
            if event_set is item.event_set:
                del self.event_sets[idx]
                break

    def event_set_selected(self):
        """Show the events of the selected list entry in the event table."""
        selected = self.event_set_list.selectedItems()
        if not selected:
            # selection can be cleared, e.g. after an item is removed
            return
        sel = selected[0]
        if sel.text() == "current":
            event_set = self.current_event_set
        else:
            event_set = sel.event_set
        if event_set is None:
            # no valid current analysis (e.g. pre == post)
            self.event_table.clear()
        else:
            self.event_table.setData(event_set[2])

    def fit_clicked(self):
        """Fit the release model to all stored event sets and plot the fits.

        Each event set's amplitudes are normalized to its first event and
        its spike times are made relative to the first spike.  The model is
        fit repeatedly, enabling one more dynamics mechanism each time
        (cumulative: Dep, then Dep+Fac, ...).
        """
        n_sets = len(self.event_sets)
        if n_sets == 0:
            # nothing stored; fit_plot[0, 0] would not exist
            return

        from synaptic_release import ReleaseModel

        self.fit_plot.clear()
        self.fit_plot.show()
        self.fit_plot.set_shape(n_sets, 1)

        first_plot = self.fit_plot[0, 0]
        legend = first_plot.legend
        if legend is not None:
            # QGraphicsItem.scene() is a method; it may return None
            scene = legend.scene()
            if scene is not None:
                scene.removeItem(legend)
            first_plot.legend = None
        first_plot.addLegend()

        for i in range(n_sets):
            self.fit_plot[i, 0].setXLink(first_plot)

        spike_sets = []
        for i, evset in enumerate(self.event_sets):
            events = evset[2][:-1]  # select events, skip last row (average)
            x = np.array([ev['spike_time'] for ev in events], dtype=float)
            y = np.array([ev['amp'] for ev in events], dtype=float)
            x -= x[0]
            x *= 1000  # presumably s -> ms (spike_time = index * dt)
            y /= y[0]
            spike_sets.append((x, y))

            self.fit_plot[i, 0].plot(x / 1000., y, pen=None, symbol='o')

        model = ReleaseModel()
        dynamics_types = ['Dep', 'Fac', 'UR', 'SMR', 'DSR']
        for k in dynamics_types:
            model.Dynamics[k] = 0

        fit_params = []
        with pg.ProgressDialog("Fitting release model..", 0,
                               len(dynamics_types)) as dlg:
            for k in dynamics_types:
                # not reset: each fit adds a mechanism to the previous ones
                model.Dynamics[k] = 1
                fit_params.append(model.run_fit(spike_sets))
                dlg += 1
                if dlg.wasCanceled():
                    return

        max_color = len(fit_params) * 1.5

        # NOTE(review): model.eval runs with all dynamics enabled (the state
        # left by the loop above) for every parameter set -- confirm intended.
        for i, params in enumerate(fit_params):
            for j, (x, _) in enumerate(spike_sets):
                output = model.eval(x, params.values())
                self.fit_plot[j, 0].plot(output[:, 0] / 1000.,
                                         output[:, 1],
                                         pen=(i, max_color),
                                         name=dynamics_types[i])

    def fit_curve_clicked(self, curve):
        """Open (or reuse) the FitExplorer for the fit attached to *curve*."""
        if self.fit_explorer is None:
            self.fit_explorer = FitExplorer(curve.fit)
        else:
            self.fit_explorer.set_fit(curve.fit)
        self.fit_explorer.show()
Ejemplo n.º 2
0
 def fit_curve_clicked(self, curve):
     """Show the fit carried by *curve* in the (lazily created) FitExplorer."""
     explorer = self.fit_explorer
     if explorer is not None:
         explorer.set_fit(curve.fit)
     else:
         explorer = FitExplorer(curve.fit)
         self.fit_explorer = explorer
     explorer.show()
Ejemplo n.º 3
0
 def explore_fit(self, mode, holding):
     """Open a new FitExplorer on the last fit stored for (mode, holding)."""
     key = (mode, holding)
     self.fit_explorer = FitExplorer(self.pair_analyzer.last_fit[key])
     self.fit_explorer.show()
Ejemplo n.º 4
0
class PairAnalysisWindow(pg.QtGui.QWidget):
    """Window for selecting experiments and reviewing/fitting pair responses.

    Layout, left to right:

    - Experiment filters (data type, rig, operator, hashtags) and a
      "Set Experiments" button.
    - The experiment browser, the current pair, the analyzer's user/output
      parameter trees, and the fit/save buttons.
    - The VC and IC plot grids.
    - The fit/metadata comparison widgets, which start hidden.
    """
    def __init__(self, default_session, notes_session):
        pg.QtGui.QWidget.__init__(self)
        self.default_session = default_session
        self.notes_session = notes_session
        self.layout = pg.QtGui.QGridLayout()
        self.setLayout(self.layout)
        self.h_splitter = pg.QtGui.QSplitter()
        self.h_splitter.setOrientation(pg.QtCore.Qt.Horizontal)
        self.pair_analyzer = PairAnalysis()
        self.fit_explorer = None

        # Parameter trees provided by the analyzer's control panel.
        self.ctrl_panel = self.pair_analyzer.ctrl_panel
        self.user_params = self.ctrl_panel.user_params
        self.output_params = self.ctrl_panel.output_params
        self.ptree = ptree.ParameterTree(showHeader=False)
        self.pair_param = Parameter.create(name='Current Pair', type='str', readonly=True)
        self.ptree.addParameters(self.pair_param)
        self.ptree.addParameters(self.user_params, showTop=False)
        self.fit_ptree = ptree.ParameterTree(showHeader=False)
        self.fit_ptree.addParameters(self.output_params, showTop=False)

        self.save_btn = pg.FeedbackButton('Save Analysis')
        self.expt_btn = pg.QtGui.QPushButton('Set Experiments')
        self.fit_btn = pg.QtGui.QPushButton('Fit Responses')
        self.ic_plot = self.pair_analyzer.ic_plot
        self.vc_plot = self.pair_analyzer.vc_plot

        # Experiment selection filters.  The first hashtag child is the
        # any/all mode selector; get_expts skips it via children()[1:].
        self.select_ptree = ptree.ParameterTree(showHeader=False)
        self.hash_select = Parameter.create(name='Hashtags', type='group', children=
            [{'name': 'With multiple selected:', 'type': 'list', 'values': ['Include if any appear', 'Include if all appear'], 'value': 'Include if any appear'}] +
            [{'name': '#', 'type': 'bool'}] +
            [{'name': ht, 'type': 'bool'} for ht in comment_hashtag[1:]])
        self.rigs = db.query(db.Experiment.rig_name).distinct().all()
        self.operators = db.query(db.Experiment.operator_name).distinct().all()
        self.rig_select = Parameter.create(name='Rig', type='group', children=[{'name': rig[0], 'type': 'bool'} for rig in self.rigs])
        self.operator_select = Parameter.create(name='Operator', type='group', children=[{'name': operator[0], 'type': 'bool'} for operator in self.operators])
        self.data_type = Parameter.create(name='Reduce data to:', type='group', children=[
            {'name': 'Pairs with data', 'type': 'bool', 'value': True},
            {'name': 'Synapse is None', 'type': 'bool'}])
        for param in [self.data_type, self.rig_select, self.operator_select, self.hash_select]:
            self.select_ptree.addParameters(param)

        self.experiment_browser = self.pair_analyzer.experiment_browser

        # Left column: filters + "Set Experiments".
        self.v_splitter = pg.QtGui.QSplitter()
        self.v_splitter.setOrientation(pg.QtCore.Qt.Vertical)
        self.expt_splitter = pg.QtGui.QSplitter()
        self.expt_splitter.setOrientation(pg.QtCore.Qt.Vertical)
        self.h_splitter.addWidget(self.expt_splitter)
        self.expt_splitter.addWidget(self.select_ptree)
        self.expt_splitter.addWidget(self.expt_btn)

        # Second column: browser, parameters, fit and save.
        self.h_splitter.addWidget(self.v_splitter)
        self.v_splitter.addWidget(self.experiment_browser)
        self.v_splitter.addWidget(self.ptree)
        self.v_splitter.addWidget(self.fit_btn)
        self.v_splitter.addWidget(self.fit_ptree)
        self.v_splitter.addWidget(self.save_btn)
        self.v_splitter.setSizes([200, 20, 20, 400, 20])

        # Plot grids, then the (hidden) comparison widgets.
        self.h_splitter.addWidget(self.vc_plot.grid)
        self.h_splitter.addWidget(self.ic_plot.grid)
        self.fit_compare = self.pair_analyzer.fit_compare
        self.meta_compare = self.pair_analyzer.meta_compare
        self.v2_splitter = pg.QtGui.QSplitter()
        self.v2_splitter.setOrientation(pg.QtCore.Qt.Vertical)
        self.v2_splitter.addWidget(self.fit_compare)
        self.v2_splitter.addWidget(self.meta_compare)
        self.h_splitter.addWidget(self.v2_splitter)
        self.h_splitter.setSizes([100, 200, 200, 200, 500])
        self.layout.addWidget(self.h_splitter)
        self.fit_compare.hide()
        self.meta_compare.hide()
        self.setGeometry(280, 130, 1500, 900)
        self.show()

        self.experiment_browser.itemSelectionChanged.connect(self.selected_pair)
        self.save_btn.clicked.connect(self.save_to_db)
        self.expt_btn.clicked.connect(self.get_expts)
        self.fit_btn.clicked.connect(self.pair_analyzer.fit_response_update)

    def save_to_db(self):
        """Save the analysis, flag the save button, and re-raise any error."""
        try:
            self.pair_analyzer.save_to_db()
        except Exception:
            self.save_btn.failure()
            raise
        self.save_btn.success()

    def get_expts(self):
        """Query experiments matching the selected filters and show them.

        A filter is applied only when it has a selection:

        - experiments that contain a pair whose ``has_synapse`` is None;
        - the selected rig names and operator names;
        - experiments whose pair notes match the selected hashtags.
        """
        expt_query = db.query(db.Experiment)
        if self.data_type['Synapse is None']:
            # `== None` is required here: SQLAlchemy translates it to IS NULL.
            subquery = db.query(db.Pair.experiment_id).filter(db.Pair.has_synapse == None).subquery()  # noqa: E711
            expt_query = expt_query.filter(db.Experiment.id.in_(subquery))

        selected_rigs = [rig.name() for rig in self.rig_select.children() if rig.value() is True]
        if selected_rigs:
            expt_query = expt_query.filter(db.Experiment.rig_name.in_(selected_rigs))

        selected_operators = [operator.name() for operator in self.operator_select.children() if operator.value() is True]
        if selected_operators:
            expt_query = expt_query.filter(db.Experiment.operator_name.in_(selected_operators))

        # children()[0] is the any/all mode selector, not a hashtag
        selected_hashtags = [ht.name() for ht in self.hash_select.children()[1:] if ht.value() is True]
        if selected_hashtags:
            timestamps = self.get_expts_hashtag(selected_hashtags)
            expt_query = expt_query.filter(db.Experiment.ext_id.in_(timestamps))

        self.set_expts(expt_query.all())

    def get_expts_hashtag(self, selected_hashtags):
        """Return the set of ``expt_id``s whose pair-note comments match.

        The test is a substring test on the ``'comments'`` note.

        - One hashtag selected: a pair matches if the hashtag appears.  The
          bare '#' matches only if no other known hashtag
          (``comment_hashtag[1:]``) appears.
        - Several selected: a pair matches if any or all of them appear,
          depending on the 'With multiple selected:' setting.

        Each matching pair is printed.
        """
        mode = self.hash_select['With multiple selected:']
        match_any = mode == 'Include if any appear'
        match_all = mode == 'Include if all appear'
        n_selected = len(selected_hashtags)

        note_pairs = self.notes_session.query(notes_db.PairNotes).all()
        note_pairs.sort(key=lambda p: p.expt_id)
        pairs_to_include = []
        for p in note_pairs:
            comments = p.notes.get('comments')
            if comments is None:
                continue

            if n_selected == 1:
                hashtag = selected_hashtags[0]
                if hashtag == '#':
                    include = hashtag in comments and all(ht not in comments for ht in comment_hashtag[1:])
                else:
                    include = hashtag in comments
            elif n_selected > 1:
                hashtag_present = [ht in comments for ht in selected_hashtags]
                include = (match_any and any(hashtag_present)) or (match_all and all(hashtag_present))
            else:
                include = False

            if include:
                print(p.expt_id, p.pre_cell_id, p.post_cell_id, comments)
                pairs_to_include.append(p)

        return {pair.expt_id for pair in pairs_to_include}

    def set_expts(self, expts):
        """Repopulate the experiment browser with *expts*.

        When 'Pairs with data' is unchecked, the browser is populated with
        ``all_pairs=True``.
        """
        with pg.BusyCursor():
            self.experiment_browser.clear()
            if self.data_type['Pairs with data']:
                self.experiment_browser.populate(experiments=expts)
            else:
                self.experiment_browser.populate(experiments=expts, all_pairs=True)

    def selected_pair(self):
        """Load and analyze the single selected pair in the browser.

        If a notes record exists for the pair, its saved fit is also loaded.
        """
        with pg.BusyCursor():
            self.fit_compare.hide()
            self.meta_compare.hide()
            selected = self.experiment_browser.selectedItems()
            if len(selected) != 1:
                return
            item = selected[0]
            if not hasattr(item, 'pair'):
                return
            pair = item.pair
            ## check to see if the pair has already been analyzed
            expt_id = pair.experiment.ext_id
            pre_cell_id = pair.pre_cell.ext_id
            post_cell_id = pair.post_cell.ext_id
            record = notes_db.get_pair_notes_record(expt_id, pre_cell_id, post_cell_id, session=self.notes_session)

            self.pair_param.setValue(pair)
            if record is None:
                self.pair_analyzer.load_pair(pair, self.default_session)
                self.pair_analyzer.analyze_responses()
            else:
                self.pair_analyzer.load_pair(pair, self.default_session, record=record)
                self.pair_analyzer.analyze_responses()
                self.pair_analyzer.load_saved_fit(record)

    def explore_fit(self, mode, holding):
        """Open a FitExplorer on the analyzer's last fit for (mode, holding)."""
        fit = self.pair_analyzer.last_fit[mode, holding]
        self.fit_explorer = FitExplorer(fit)
        self.fit_explorer.show()
Ejemplo n.º 5
0
 def fit_curve_clicked(self, curve):
     """Display *curve*'s attached fit, creating the explorer on first use."""
     if self.fit_explorer is not None:
         self.fit_explorer.set_fit(curve.fit)
     else:
         self.fit_explorer = FitExplorer(curve.fit)
     self.fit_explorer.show()
Ejemplo n.º 6
0
class PairView(QtGui.QWidget):
    """For analyzing pre/post-synaptic pairs.
    """
    def __init__(self, parent=None):
        # Analysis state, set up before the Qt base initializer runs.
        self.sweeps = []
        self.channels = []
        self.current_event_set = None
        self.event_sets = []

        QtGui.QWidget.__init__(self, parent)

        grid = QtGui.QGridLayout()
        self.layout = grid
        self.setLayout(grid)
        grid.setContentsMargins(0, 0, 0, 0)

        splitter = QtGui.QSplitter(QtCore.Qt.Vertical)
        self.vsplit = splitter
        grid.addWidget(splitter, 0, 0)

        # Raw pre/post trace plots; the post plot follows the pre plot's x axis.
        self.pre_plot, self.post_plot = pg.PlotWidget(), pg.PlotWidget()
        for trace_plot in (self.pre_plot, self.post_plot):
            trace_plot.setClipToView(True)
            trace_plot.setDownsampling(True, True, 'peak')
        self.post_plot.setXLink(self.pre_plot)
        splitter.addWidget(self.pre_plot)
        splitter.addWidget(self.post_plot)

        self.response_plots = PlotGrid()
        splitter.addWidget(self.response_plots)

        # Bottom row: event table on the left, event-set controls on the right.
        self.event_splitter = QtGui.QSplitter(QtCore.Qt.Horizontal)
        splitter.addWidget(self.event_splitter)
        self.event_table = pg.TableWidget()
        self.event_splitter.addWidget(self.event_table)

        self.event_widget = QtGui.QWidget()
        controls = QtGui.QGridLayout()
        self.event_widget_layout = controls
        self.event_widget.setLayout(controls)
        self.event_splitter.addWidget(self.event_widget)
        self.event_splitter.setSizes([600, 100])

        self.event_set_list = QtGui.QListWidget()
        controls.addWidget(self.event_set_list, 0, 0, 1, 2)
        self.event_set_list.addItem("current")
        self.event_set_list.itemSelectionChanged.connect(self.event_set_selected)

        # (attribute name, label, grid cell, click handler), in creation order.
        button_specs = (
            ('add_set_btn', 'add', (1, 0, 1, 1), self.add_set_clicked),
            ('remove_set_btn', 'del', (1, 1, 1, 1), self.remove_set_clicked),
            ('fit_btn', 'fit all', (2, 0, 1, 2), self.fit_clicked),
        )
        for attr_name, label, cell, handler in button_specs:
            button = QtGui.QPushButton(label)
            setattr(self, attr_name, button)
            controls.addWidget(button, *cell)
            button.clicked.connect(handler)

        self.fit_plot = PlotGrid()
        self.fit_explorer = None

        self.artifact_remover = ArtifactRemover(user_width=True)
        self.baseline_remover = BaselineRemover()
        self.filter = SignalFilter()

        param_children = [
            dict(name='pre', type='list', values=[]),
            dict(name='post', type='list', values=[]),
            self.artifact_remover.params,
            self.baseline_remover.params,
            self.filter.params,
            dict(name='time constant', type='float', suffix='s', siPrefix=True,
                 value=10e-3, dec=True, minStep=100e-6),
        ]
        self.params = pg.parametertree.Parameter(name='params', type='group', children=param_children)
        # Any parameter change triggers a full re-analysis.
        self.params.sigTreeStateChanged.connect(self._update_plots)

    def data_selected(self, sweeps, channels):
        """Remember the new selection, offer its channels as pre/post, re-plot."""
        self.sweeps, self.channels = sweeps, channels
        for choice in ('pre', 'post'):
            self.params.child(choice).setLimits(channels)
        self._update_plots()

    def _update_plots(self):
        sweeps = self.sweeps
        self.current_event_set = None
        self.event_table.clear()
        
        # clear all plots
        self.pre_plot.clear()
        self.post_plot.clear()

        pre = self.params['pre']
        post = self.params['post']
        
        # If there are no selected sweeps or channels have not been set, return
        if len(sweeps) == 0 or pre == post or pre not in self.channels or post not in self.channels:
            return

        pre_mode = sweeps[0][pre].clamp_mode
        post_mode = sweeps[0][post].clamp_mode
        for ch, mode, plot in [(pre, pre_mode, self.pre_plot), (post, post_mode, self.post_plot)]:
            units = 'A' if mode == 'vc' else 'V'
            plot.setLabels(left=("Channel %d" % ch, units), bottom=("Time", 's'))
        
        # Iterate over selected channels of all sweeps, plotting traces one at a time
        # Collect information about pulses and spikes
        pulses = []
        spikes = []
        post_traces = []
        for i,sweep in enumerate(sweeps):
            pre_trace = sweep[pre]['primary']
            post_trace = sweep[post]['primary']
            
            # Detect pulse times
            stim = sweep[pre]['command'].data
            sdiff = np.diff(stim)
            on_times = np.argwhere(sdiff > 0)[1:, 0]  # 1: skips test pulse
            off_times = np.argwhere(sdiff < 0)[1:, 0]
            pulses.append(on_times)

            # filter data
            post_filt = self.artifact_remover.process(post_trace, list(on_times) + list(off_times))
            post_filt = self.baseline_remover.process(post_filt)
            post_filt = self.filter.process(post_filt)
            post_traces.append(post_filt)
            
            # plot raw data
            color = pg.intColor(i, hues=len(sweeps)*1.3, sat=128)
            color.setAlpha(128)
            for trace, plot in [(pre_trace, self.pre_plot), (post_filt, self.post_plot)]:
                plot.plot(trace.time_values, trace.data, pen=color, antialias=False)

            # detect spike times
            spike_inds = []
            spike_info = []
            for on, off in zip(on_times, off_times):
                spike = detect_evoked_spike(sweep[pre], [on, off])
                spike_info.append(spike)
                if spike is None:
                    spike_inds.append(None)
                else:
                    spike_inds.append(spike['rise_index'])
            spikes.append(spike_info)
                    
            dt = pre_trace.dt
            vticks = pg.VTickGroup([x * dt for x in spike_inds if x is not None], yrange=[0.0, 0.2], pen=color)
            self.pre_plot.addItem(vticks)

        # Iterate over spikes, plotting average response
        all_responses = []
        avg_responses = []
        fits = []
        fit = None
        
        npulses = max(map(len, pulses))
        self.response_plots.clear()
        self.response_plots.set_shape(1, npulses+1) # 1 extra for global average
        self.response_plots.setYLink(self.response_plots[0,0])
        for i in range(1, npulses+1):
            self.response_plots[0,i].hideAxis('left')
        units = 'A' if post_mode == 'vc' else 'V'
        self.response_plots[0, 0].setLabels(left=("Averaged events (Channel %d)" % post, units))
        
        fit_pen = {'color':(30, 30, 255), 'width':2, 'dash': [1, 1]}
        for i in range(npulses):
            # get the chunk of each sweep between spikes
            responses = []
            all_responses.append(responses)
            for j, sweep in enumerate(sweeps):
                # get the current spike
                if i >= len(spikes[j]):
                    continue
                spike = spikes[j][i]
                if spike is None:
                    continue
                
                # find next spike
                next_spike = None
                for sp in spikes[j][i+1:]:
                    if sp is not None:
                        next_spike = sp
                        break
                    
                # determine time range for response
                max_len = int(40e-3 / dt)  # don't take more than 50ms for any response
                start = spike['rise_index']
                if next_spike is not None:
                    stop = min(start + max_len, next_spike['rise_index'])
                else:
                    stop = start + max_len
                    
                # collect data from this trace
                trace = post_traces[j]
                d = trace.data[start:stop].copy()
                responses.append(d)

            if len(responses) == 0:
                continue
                
            # extend all responses to the same length and take nanmean
            avg = ragged_mean(responses, method='clip')
            avg -= float_mode(avg[:int(1e-3/dt)])
            avg_responses.append(avg)
            
            # plot average response for this pulse
            start = np.median([sp[i]['rise_index'] for sp in spikes if sp[i] is not None]) * dt
            t = np.arange(len(avg)) * dt
            self.response_plots[0,i].plot(t, avg, pen='w', antialias=True)

            # fit!
            fit = self.fit_psp(avg, t, dt, post_mode)
            fits.append(fit)
            
            # let the user mess with this fit
            curve = self.response_plots[0,i].plot(t, fit.eval(), pen=fit_pen, antialias=True).curve
            curve.setClickable(True)
            curve.fit = fit
            curve.sigClicked.connect(self.fit_curve_clicked)
            
        # display global average
        global_avg = ragged_mean(avg_responses, method='clip')
        t = np.arange(len(global_avg)) * dt
        self.response_plots[0,-1].plot(t, global_avg, pen='w', antialias=True)
        global_fit = self.fit_psp(global_avg, t, dt, post_mode)
        self.response_plots[0,-1].plot(t, global_fit.eval(), pen=fit_pen, antialias=True)
            
        # display fit parameters in table
        events = []
        for i,f in enumerate(fits + [global_fit]):
            if f is None:
                continue
            if i >= len(fits):
                vals = OrderedDict([('id', 'avg'), ('spike_time', np.nan), ('spike_stdev', np.nan)])
            else:
                spt = [s[i]['peak_index'] * dt for s in spikes if s[i] is not None]
                vals = OrderedDict([('id', i), ('spike_time', np.mean(spt)), ('spike_stdev', np.std(spt))])
            vals.update(OrderedDict([(k,f.best_values[k]) for k in f.params.keys()]))
            events.append(vals)
            
        self.current_event_set = (pre, post, events, sweeps)
        self.event_set_list.setCurrentRow(0)
        self.event_set_selected()

    def fit_psp(self, data, t, dt, clamp_mode):
        """Fit a single PSP/PSC model (``fitting.Psp``) to one averaged response.

        Parameters
        ----------
        data : array
            Baseline-subtracted averaged response.
        t : array
            Time values for *data* (same length as *data*).
        dt : float
            Sample spacing of *data*; the first ``1e-3 / dt`` samples are used
            as the baseline window.
        clamp_mode : str
            'ic' selects current-clamp starting guesses (amplitude of order
            10e-3, slower kinetics); any other value uses voltage-clamp guesses
            (amplitude of order 10e-12).

        Returns
        -------
        The result object returned by ``fitting.Psp().fit``.
        """
        # Decide the polarity of the response relative to its early baseline.
        baseline = float_mode(data[:int(1e-3/dt)])
        polarity = 1
        if data.mean() - baseline < 0:
            polarity = -1

        # Starting guesses per clamp mode. Tuples are presumably
        # (initial, min, max) per the fitting module's convention -- confirm.
        if clamp_mode == 'ic':
            amp_guess = polarity * 10e-3
            rise_guess = (5e-3, 50e-6, 20e-3)
            decay_guess = (15e-3, 500e-6, 150e-3)
        else:
            amp_guess = polarity * 10e-12
            rise_guess = (2e-3, 50e-6, 10e-3)
            decay_guess = (4e-3, 500e-6, 50e-3)

        initial = OrderedDict()
        initial['xoffset'] = (2e-3, 3e-4, 5e-3)
        initial['yoffset'] = data[0]
        initial['amp'] = amp_guess
        initial['rise_time'] = rise_guess
        initial['decay_tau'] = decay_guess
        initial['rise_power'] = (2.0, 'fixed')

        optimizer_opts = {'xtol': 1e-3, 'maxfev': 100}
        model = fitting.Psp()
        return model.fit(data, x=t, fit_kws=optimizer_opts, params=initial)

    def add_set_clicked(self):
        """Save the current event set and list it as "<pre> -> <post>".

        Does nothing when no event set has been computed yet. The saved tuple
        is appended to ``self.event_sets`` and also attached to the new list
        item as ``item.event_set`` so it can be found again on removal.
        """
        snapshot = self.current_event_set
        if snapshot is not None:
            self.event_sets.append(snapshot)
            pre_ch, post_ch = snapshot[0], snapshot[1]
            entry = QtGui.QListWidgetItem("%d -> %d" % (pre_ch, post_ch))
            self.event_set_list.addItem(entry)
            entry.event_set = snapshot
        
    def remove_set_clicked(self):
        """Remove the selected saved event set from the list and ``self.event_sets``.

        Row 0 is the permanent "current" entry and is never removed. If no
        item is current (``currentRow()`` returns -1) this is a no-op.
        """
        row = self.event_set_list.currentRow()
        # -1 means no current item; 0 is the "current" placeholder row.
        if row < 1:
            return
        item = self.event_set_list.takeItem(row)
        if item is None:
            return
        target = item.event_set
        # Remove by identity: list.remove() matches by equality and could
        # delete a different-but-equal saved set, desynchronizing the list
        # widget from self.event_sets.
        for index, saved in enumerate(self.event_sets):
            if saved is target:
                del self.event_sets[index]
                break
        
    def event_set_selected(self):
        """Show the events of the selected event set in ``self.event_table``.

        The "current" entry shows ``self.current_event_set``; any other entry
        shows the event set stored on the list item. Element ``[2]`` of an
        event set tuple is its list of per-event rows.

        Does nothing when nothing is selected, or when "current" is selected
        but no current event set has been computed yet.
        """
        selected = self.event_set_list.selectedItems()
        # selectedItems() is empty when nothing is selected (e.g. after the
        # selected item has been taken out of the list).
        if not selected:
            return
        item = selected[0]
        if item.text() == "current":
            if self.current_event_set is None:
                return
            events = self.current_event_set[2]
        else:
            events = item.event_set[2]
        self.event_table.setData(events)
    
    def fit_clicked(self):
        """Fit synaptic release models to the saved event sets and plot them.

        Each saved event set (``self.event_sets``) gets one row in
        ``self.fit_plot``. The row plots its per-pulse spike times, relative
        to the first pulse, against fitted response amplitudes normalized to
        the first pulse. The last event row (the global average) is skipped.

        A ``ReleaseModel`` is then fit once per entry in ``dynamics_types``.
        Each run enables one more dynamics term than the previous one. Every
        resulting model output is overlaid on every event-set plot.

        Does nothing if no event sets have been saved.
        """
        from synaptic_release import ReleaseModel

        n_sets = len(self.event_sets)
        if n_sets == 0:
            # Nothing to fit; fit_plot[0, 0] would presumably not exist after
            # set_shape(0, 1).
            return

        self.fit_plot.clear()
        self.fit_plot.show()
        self.fit_plot.set_shape(n_sets, 1)

        # Replace any legend left over from a previous run.
        first_plot = self.fit_plot[0, 0]
        legend = first_plot.legend
        if legend is not None:
            # scene() is a method on Qt graphics items and returns None when
            # the item is not in a scene.
            scene = legend.scene()
            if scene is not None:
                scene.removeItem(legend)
            first_plot.legend = None
        first_plot.addLegend()

        for i in range(n_sets):
            self.fit_plot[i, 0].setXLink(first_plot)

        spike_sets = []
        for i, event_set in enumerate(self.event_sets):
            events = event_set[2][:-1]  # select events, skip last row (average)
            x = np.array([ev['spike_time'] for ev in events])
            y = np.array([ev['amp'] for ev in events])
            x -= x[0]   # times relative to the first spike
            x *= 1000   # s -> ms
            y /= y[0]   # amplitudes relative to the first response
            spike_sets.append((x, y))

            self.fit_plot[i, 0].plot(x / 1000., y, pen=None, symbol='o')

        model = ReleaseModel()
        dynamics_types = ['Dep', 'Fac', 'UR', 'SMR', 'DSR']
        for k in dynamics_types:
            model.Dynamics[k] = 0

        # Enable the dynamics terms cumulatively, fitting after each addition.
        fit_params = []
        with pg.ProgressDialog("Fitting release model..", 0, len(dynamics_types)) as dlg:
            for k in dynamics_types:
                model.Dynamics[k] = 1
                fit_params.append(model.run_fit(spike_sets))
                dlg += 1
                if dlg.wasCanceled():
                    return

        max_color = len(fit_params) * 1.5

        # NOTE(review): by now every term in model.Dynamics is enabled, so each
        # parameter set is evaluated with all dynamics on rather than only the
        # terms active when it was fit -- confirm this is intended.
        for i, params in enumerate(fit_params):
            for j, (x, _) in enumerate(spike_sets):
                output = model.eval(x, params.values())
                # output columns: [:, 0] time in ms, [:, 1] model amplitude
                self.fit_plot[j, 0].plot(output[:, 0] / 1000., output[:, 1],
                                         pen=(i, max_color), name=dynamics_types[i])

    def fit_curve_clicked(self, curve):
        """Show the fit attached to a clicked curve in a FitExplorer window.

        The explorer is created on first use and reused (via ``set_fit``)
        afterwards; it is shown in both cases.
        """
        explorer = self.fit_explorer
        if explorer is not None:
            explorer.set_fit(curve.fit)
        else:
            explorer = FitExplorer(curve.fit)
            self.fit_explorer = explorer
        explorer.show()
Ejemplo n.º 7
0
import sys
import pyqtgraph as pg
from neuroanalysis.ui.fitting import FitExplorer
from multipatch_analysis.database import default_db as db
from multipatch_analysis.avg_response_fit import get_pair_avg_fits
from multipatch_analysis.ui.avg_response_fit import AvgResponseFitUi

# Usage: <script> expt_id pre_cell_id post_cell_id
# Computes the averaged-response fits for one cell pair, shows the fit UI,
# and opens a FitExplorer on the ('vc', -55) entry of the returned fits.
if len(sys.argv) < 4:
    sys.exit("usage: %s expt_id pre_cell_id post_cell_id" % sys.argv[0])

app = pg.mkQApp()
pg.dbg()

session = db.session()

expt_id, pre_cell_id, post_cell_id = sys.argv[1:4]
pair = db.experiment_from_ext_id(expt_id, session=session).pairs[pre_cell_id,
                                                                 post_cell_id]

ui = AvgResponseFitUi()

fits = get_pair_avg_fits(pair, session=session, ui=ui)

ui.widget.show()

fe = FitExplorer(fits['vc', -55]['fit_result'])
fe.show()

# Run the Qt event loop unless started interactively (python -i).
if sys.flags.interactive == 0:
    app.exec_()
Ejemplo n.º 8
0
import sys

pg.dbg()

app = pg.mkQApp()

# Synthetic test: build a train of n_psp PSPs spaced 25e-3 apart on a trace of
# 2000 samples (spacing 1e-4), add Gaussian noise, then fit it with PspTrain.
t = np.arange(2000) * 1e-4

rise_time = 5e-3
decay_tau = 20e-3
n_psp = 4
# Parameter specs: a plain value or a tuple whose first element is the initial
# value; presumably (value, min, max) or (value, 'fixed') per the fitting
# module's convention -- confirm.
args = {
    'yoffset': (0, 'fixed'),
    'xoffset': (2e-3, -1e-3, 5e-3),
    'rise_time': (rise_time, rise_time*0.5, rise_time*2),
    'decay_tau': (decay_tau, decay_tau*0.5, decay_tau*2),
    'rise_power': (2, 'fixed'),
}
for i in range(n_psp):
    args['xoffset%d' % i] = (25e-3*i, 'fixed')
    args['amp%d' % i] = (250e-6, 0, 10e-3)

# NOTE(review): lmfit documents nan_policy as an argument of Model.fit rather
# than an optimizer option; confirm it actually takes effect via fit_kws.
fit_kws = {'xtol': 1e-4, 'maxfev': 1000, 'nan_policy': 'omit'}
model = PspTrain(n_psp)

# Ground-truth values: the first element of each tuple spec.
args2 = {k: (v[0] if isinstance(v, tuple) else v) for k, v in args.items()}
y = np.random.normal(size=len(t), scale=30e-6) + model.eval(x=t, **args2)

fit = model.fit(y, x=t, params=args, fit_kws=fit_kws, method='leastsq')

ex = FitExplorer(fit)
ex.show()

# Run the Qt event loop unless started interactively (python -i).
if sys.flags.interactive == 0:
    app.exec_()
Ejemplo n.º 9
0
from neuroanalysis.fitting import Psp
from neuroanalysis.ui.fitting import FitExplorer

app = pg.mkQApp()
pg.dbg()

# Load PSP data from the test_data repository, or from the JSON path given as
# the first command-line argument.
if len(sys.argv) == 1:
    data_file = 'test_data/test_psp_fit/1485904693.10_8_2NOTstacked.json'
else:
    data_file = sys.argv[1]

with open(data_file) as fh:
    data = json.load(fh)

y = np.array(data['input']['data'])
x = np.arange(len(y)) * data['input']['dt']

psp = Psp()
params = OrderedDict([
    ('xoffset', (10e-3, 10e-3, 15e-3)),
    ('yoffset', 0),
    ('amp', 0.1e-3),
    ('rise_time', (2e-3, 500e-6, 10e-3)),
    ('decay_tau', (4e-3, 1e-3, 50e-3)),
    ('rise_power', (2.0, 'fixed')),
])
# xtol/maxfev are optimizer options, so they go in fit_kws; as bare keywords
# they would presumably be treated as model-function arguments and ignored.
fit = psp.fit(y, x=x, fit_kws={'xtol': 1e-3, 'maxfev': 100}, params=params)

explorer = FitExplorer(fit=fit)
explorer.show()

# Run the Qt event loop unless started interactively (python -i).
if sys.flags.interactive == 0:
    app.exec_()