Пример #1
0
class SynapseExplorer(QtGui.QWidget):
    """Browse synapses from a set of experiments and plot their train responses.

    Layout: a horizontal splitter whose left side is a vertical splitter
    holding a ``SynapseTreeWidget`` above an ``ExperimentInfoWidget``, and
    whose right side is a ``PlotGrid`` showing train responses of the pair
    currently selected in the tree.
    """
    def __init__(self, expts, parent=None):
        QtGui.QWidget.__init__(self, parent)
        self.expts = expts

        self.layout = QtGui.QGridLayout()
        self.setLayout(self.layout)

        self.hsplit = QtGui.QSplitter(QtCore.Qt.Horizontal)
        self.layout.addWidget(self.hsplit, 0, 0)

        self.vsplit = QtGui.QSplitter(QtCore.Qt.Vertical)
        self.hsplit.addWidget(self.vsplit)

        self.syn_tree = SynapseTreeWidget(self.expts)
        self.vsplit.addWidget(self.syn_tree)
        self.syn_tree.itemSelectionChanged.connect(self.selection_changed)

        self.expt_info = ExperimentInfoWidget()
        self.vsplit.addWidget(self.expt_info)

        self.train_plots = PlotGrid()
        self.hsplit.addWidget(self.train_plots)

        # DynamicsAnalyzer cache keyed by (expt, pre_cell_id, post_cell_id)
        self.analyzers = {}

    def selection_changed(self):
        """Show info for the selected pair's experiment and plot its train responses.

        Does nothing when the selection is empty (``itemSelectionChanged`` is
        also emitted when the selection is cleared).

        Raises:
            Exception: if the pair's analyzer has no pulse responses.
        """
        selected = self.syn_tree.selectedItems()
        if not selected:
            return

        with pg.BusyCursor():
            sel = selected[0]
            expt = sel.expt

            self.expt_info.set_experiment(expt)

            pre_cell = sel.cells[0].cell_id
            post_cell = sel.cells[1].cell_id

            key = (expt, pre_cell, post_cell)
            if key not in self.analyzers:
                self.analyzers[key] = DynamicsAnalyzer(*key)
            analyzer = self.analyzers[key]

            # Clear before the data check so the previous pair's plots are not
            # left on screen next to the newly selected experiment's info.
            self.train_plots.clear()

            if len(analyzer.pulse_responses) == 0:
                raise Exception(
                    "No suitable data found for cell %d -> cell %d in expt %s"
                    % (pre_cell, post_cell, expt))

            # Plot all individual and averaged train responses for all sets of stimulus parameters
            analyzer.plot_train_responses(plot_grid=self.train_plots)
class SynapseExplorer(QtGui.QWidget):
    """Browse synapses from a set of experiments and plot their train responses.

    Layout: a horizontal splitter whose left side is a vertical splitter
    holding a ``SynapseTreeWidget`` above an ``ExperimentInfoWidget``, and
    whose right side is a ``PlotGrid`` showing train responses of the pair
    currently selected in the tree.
    """
    def __init__(self, expts, parent=None):
        QtGui.QWidget.__init__(self, parent)
        self.expts = expts

        self.layout = QtGui.QGridLayout()
        self.setLayout(self.layout)

        self.hsplit = QtGui.QSplitter(QtCore.Qt.Horizontal)
        self.layout.addWidget(self.hsplit, 0, 0)

        self.vsplit = QtGui.QSplitter(QtCore.Qt.Vertical)
        self.hsplit.addWidget(self.vsplit)

        self.syn_tree = SynapseTreeWidget(self.expts)
        self.vsplit.addWidget(self.syn_tree)
        self.syn_tree.itemSelectionChanged.connect(self.selection_changed)

        self.expt_info = ExperimentInfoWidget()
        self.vsplit.addWidget(self.expt_info)

        self.train_plots = PlotGrid()
        self.hsplit.addWidget(self.train_plots)

        # DynamicsAnalyzer cache keyed by (expt, pre_cell_id, post_cell_id)
        self.analyzers = {}

    def selection_changed(self):
        """Show info for the selected pair's experiment and plot its train responses.

        Does nothing when the selection is empty (``itemSelectionChanged`` is
        also emitted when the selection is cleared).

        Raises:
            Exception: if the pair's analyzer has no pulse responses.
        """
        selected = self.syn_tree.selectedItems()
        if not selected:
            return

        with pg.BusyCursor():
            sel = selected[0]
            expt = sel.expt

            self.expt_info.set_experiment(expt)

            pre_cell = sel.cells[0].cell_id
            post_cell = sel.cells[1].cell_id

            key = (expt, pre_cell, post_cell)
            if key not in self.analyzers:
                self.analyzers[key] = DynamicsAnalyzer(*key)
            analyzer = self.analyzers[key]

            # Clear before the data check so the previous pair's plots are not
            # left on screen next to the newly selected experiment's info.
            self.train_plots.clear()

            if len(analyzer.pulse_responses) == 0:
                raise Exception(
                    "No suitable data found for cell %d -> cell %d in expt %s"
                    % (pre_cell, post_cell, expt))

            # Plot all individual and averaged train responses for all sets of stimulus parameters
            analyzer.plot_train_responses(plot_grid=self.train_plots)
Пример #3
0
class PairView(QtGui.QWidget):
    """For analyzing pre/post-synaptic pairs.

    For every selected sweep, plots the presynaptic trace and the processed
    (artifact-removed, baseline-removed, filtered) postsynaptic trace.
    Postsynaptic responses that follow each detected presynaptic spike are
    averaged per pulse number and fitted with a PSP model (``fit_psp``); the
    fit parameters are listed in ``event_table``. The resulting event sets can
    be saved to a list and fitted jointly with a release model
    (``fit_clicked``).
    """
    def __init__(self, parent=None):
        self.sweeps = []
        self.channels = []

        # (pre, post, events, sweeps) from the latest _update_plots, or None
        self.current_event_set = None
        # event sets saved with the "add" button
        self.event_sets = []

        QtGui.QWidget.__init__(self, parent)

        self.layout = QtGui.QGridLayout()
        self.setLayout(self.layout)
        self.layout.setContentsMargins(0, 0, 0, 0)

        self.vsplit = QtGui.QSplitter(QtCore.Qt.Vertical)
        self.layout.addWidget(self.vsplit, 0, 0)

        self.pre_plot = pg.PlotWidget()
        self.post_plot = pg.PlotWidget()
        for plt in (self.pre_plot, self.post_plot):
            plt.setClipToView(True)
            plt.setDownsampling(True, True, 'peak')

        self.post_plot.setXLink(self.pre_plot)
        self.vsplit.addWidget(self.pre_plot)
        self.vsplit.addWidget(self.post_plot)

        self.response_plots = PlotGrid()
        self.vsplit.addWidget(self.response_plots)

        self.event_splitter = QtGui.QSplitter(QtCore.Qt.Horizontal)
        self.vsplit.addWidget(self.event_splitter)

        self.event_table = pg.TableWidget()
        self.event_splitter.addWidget(self.event_table)

        self.event_widget = QtGui.QWidget()
        self.event_widget_layout = QtGui.QGridLayout()
        self.event_widget.setLayout(self.event_widget_layout)
        self.event_splitter.addWidget(self.event_widget)

        self.event_splitter.setSizes([600, 100])

        # Row 0 ("current") shows the latest analysis; saved sets follow it.
        self.event_set_list = QtGui.QListWidget()
        self.event_widget_layout.addWidget(self.event_set_list, 0, 0, 1, 2)
        self.event_set_list.addItem("current")
        self.event_set_list.itemSelectionChanged.connect(self.event_set_selected)

        self.add_set_btn = QtGui.QPushButton('add')
        self.event_widget_layout.addWidget(self.add_set_btn, 1, 0, 1, 1)
        self.add_set_btn.clicked.connect(self.add_set_clicked)
        self.remove_set_btn = QtGui.QPushButton('del')
        self.event_widget_layout.addWidget(self.remove_set_btn, 1, 1, 1, 1)
        self.remove_set_btn.clicked.connect(self.remove_set_clicked)
        self.fit_btn = QtGui.QPushButton('fit all')
        self.event_widget_layout.addWidget(self.fit_btn, 2, 0, 1, 2)
        self.fit_btn.clicked.connect(self.fit_clicked)

        self.fit_plot = PlotGrid()

        self.fit_explorer = None

        self.artifact_remover = ArtifactRemover(user_width=True)
        self.baseline_remover = BaselineRemover()
        self.filter = SignalFilter()

        self.params = pg.parametertree.Parameter(name='params', type='group', children=[
            {'name': 'pre', 'type': 'list', 'values': []},
            {'name': 'post', 'type': 'list', 'values': []},
            self.artifact_remover.params,
            self.baseline_remover.params,
            self.filter.params,
            {'name': 'time constant', 'type': 'float', 'suffix': 's', 'siPrefix': True,
             'value': 10e-3, 'dec': True, 'minStep': 100e-6},
        ])
        self.params.sigTreeStateChanged.connect(self._update_plots)

    def data_selected(self, sweeps, channels):
        """Use *sweeps*/*channels* as the new data, offer the channels as pre/post choices, and redraw."""
        self.sweeps = sweeps
        self.channels = channels

        self.params.child('pre').setLimits(channels)
        self.params.child('post').setLimits(channels)

        self._update_plots()

    def _update_plots(self, *args):
        """Re-plot all sweeps, re-average and re-fit responses, and rebuild the event table.

        Extra positional arguments (as emitted by ``sigTreeStateChanged``) are
        ignored. Returns early, leaving the plots empty, when there are no
        sweeps or the pre/post channels are equal or not among ``self.channels``.
        """
        sweeps = self.sweeps
        self.current_event_set = None
        self.event_table.clear()

        # clear all plots
        self.pre_plot.clear()
        self.post_plot.clear()

        pre = self.params['pre']
        post = self.params['post']

        # If there are no selected sweeps or channels have not been set, return
        if (len(sweeps) == 0 or pre == post
                or pre not in self.channels or post not in self.channels):
            return

        pre_mode = sweeps[0][pre].clamp_mode
        post_mode = sweeps[0][post].clamp_mode
        for ch, mode, plot in [(pre, pre_mode, self.pre_plot), (post, post_mode, self.post_plot)]:
            units = 'A' if mode == 'vc' else 'V'
            plot.setLabels(left=("Channel %d" % ch, units), bottom=("Time", 's'))

        # Iterate over selected channels of all sweeps, plotting traces one at a time.
        # Collect pulse onsets, detected spikes and processed postsynaptic traces.
        pulses = []       # per sweep: stimulus onset indices (test pulse excluded)
        spikes = []       # per sweep: one spike dict (or None) per pulse
        post_traces = []  # per sweep: processed postsynaptic trace
        for i, sweep in enumerate(sweeps):
            pre_trace = sweep[pre]['primary']
            post_trace = sweep[post]['primary']

            # Detect pulse times from rising/falling edges of the command
            stim = sweep[pre]['command'].data
            sdiff = np.diff(stim)
            on_times = np.argwhere(sdiff > 0)[1:, 0]  # 1: skips test pulse
            off_times = np.argwhere(sdiff < 0)[1:, 0]
            pulses.append(on_times)

            # filter data
            post_filt = self.artifact_remover.process(post_trace, list(on_times) + list(off_times))
            post_filt = self.baseline_remover.process(post_filt)
            post_filt = self.filter.process(post_filt)
            post_traces.append(post_filt)

            # plot raw data
            color = pg.intColor(i, hues=len(sweeps) * 1.3, sat=128)
            color.setAlpha(128)
            for trace, plot in [(pre_trace, self.pre_plot), (post_filt, self.post_plot)]:
                plot.plot(trace.time_values, trace.data, pen=color, antialias=False)

            # detect spike times (None where no spike was evoked)
            spike_info = [detect_evoked_spike(sweep[pre], [on, off])
                          for on, off in zip(on_times, off_times)]
            spikes.append(spike_info)

            dt = pre_trace.dt
            spike_inds = [sp['rise_index'] for sp in spike_info if sp is not None]
            vticks = pg.VTickGroup([x * dt for x in spike_inds if x is not None],
                                   yrange=[0.0, 0.2], pen=color)
            self.pre_plot.addItem(vticks)

        # Iterate over pulse numbers, plotting the average response to each
        avg_responses = []
        fits = []  # (pulse index, fit) for every pulse with at least one response

        npulses = max(map(len, pulses))
        self.response_plots.clear()
        self.response_plots.set_shape(1, npulses + 1)  # 1 extra for global average
        self.response_plots.setYLink(self.response_plots[0, 0])
        for i in range(1, npulses + 1):
            self.response_plots[0, i].hideAxis('left')
        units = 'A' if post_mode == 'vc' else 'V'
        self.response_plots[0, 0].setLabels(left=("Averaged events (Channel %d)" % post, units))

        fit_pen = {'color': (30, 30, 255), 'width': 2, 'dash': [1, 1]}
        # NOTE: dt is taken from the last sweep; assumes all sweeps share one sample rate.
        max_len = int(40e-3 / dt)  # don't take more than 40 ms for any response
        baseline_len = int(1e-3 / dt)  # leading samples used as the baseline of each average
        for i in range(npulses):
            # get the chunk of each sweep between this spike and the next one
            responses = []
            for sweep_spikes, trace in zip(spikes, post_traces):
                if i >= len(sweep_spikes) or sweep_spikes[i] is None:
                    continue
                spike = sweep_spikes[i]

                # find the next detected spike in this sweep
                next_spike = next((sp for sp in sweep_spikes[i + 1:] if sp is not None), None)

                # determine time range for response
                start = spike['rise_index']
                stop = start + max_len
                if next_spike is not None:
                    stop = min(stop, next_spike['rise_index'])

                responses.append(trace.data[start:stop].copy())

            if len(responses) == 0:
                continue

            # extend all responses to the same length and take nanmean
            avg = ragged_mean(responses, method='clip')
            avg -= float_mode(avg[:baseline_len])
            avg_responses.append(avg)

            # plot average response for this pulse
            t = np.arange(len(avg)) * dt
            self.response_plots[0, i].plot(t, avg, pen='w', antialias=True)

            # fit!
            fit = self.fit_psp(avg, t, dt, post_mode)
            fits.append((i, fit))

            # let the user mess with this fit
            curve = self.response_plots[0, i].plot(t, fit.eval(), pen=fit_pen, antialias=True).curve
            curve.setClickable(True)
            curve.fit = fit
            curve.sigClicked.connect(self.fit_curve_clicked)

        # display global average (only possible if some pulse produced responses)
        global_fit = None
        if avg_responses:
            global_avg = ragged_mean(avg_responses, method='clip')
            t = np.arange(len(global_avg)) * dt
            self.response_plots[0, -1].plot(t, global_avg, pen='w', antialias=True)
            global_fit = self.fit_psp(global_avg, t, dt, post_mode)
            self.response_plots[0, -1].plot(t, global_fit.eval(), pen=fit_pen, antialias=True)

        # display fit parameters in table: one row per fitted pulse, then the global average
        events = []
        for i, fit in fits:
            # i is the pulse index, so it is also the index into each sweep's spike list
            spt = [s[i]['peak_index'] * dt for s in spikes if i < len(s) and s[i] is not None]
            vals = OrderedDict([('id', i), ('spike_time', np.mean(spt)), ('spike_stdev', np.std(spt))])
            vals.update((k, fit.best_values[k]) for k in fit.params.keys())
            events.append(vals)
        if global_fit is not None:
            vals = OrderedDict([('id', 'avg'), ('spike_time', np.nan), ('spike_stdev', np.nan)])
            vals.update((k, global_fit.best_values[k]) for k in global_fit.params.keys())
            events.append(vals)

        self.current_event_set = (pre, post, events, sweeps)
        self.event_set_list.setCurrentRow(0)
        self.event_set_selected()

    def fit_psp(self, data, t, dt, clamp_mode):
        """Fit ``fitting.Psp`` to an averaged response and return the fit result.

        The sign of the initial amplitude guess follows whether the mean of
        *data* lies below or above the mode of its first ``1e-3 / dt`` samples.
        ``clamp_mode == 'ic'`` switches to current-clamp starting values for
        amplitude, rise time and decay tau.
        """
        mode = float_mode(data[:int(1e-3 / dt)])
        sign = -1 if data.mean() - mode < 0 else 1
        # tuples are presumably (initial, min, max) or (value, 'fixed') for fitting.Psp
        params = OrderedDict([
            ('xoffset', (2e-3, 3e-4, 5e-3)),
            ('yoffset', data[0]),
            ('amp', sign * 10e-12),
            ('rise_time', (2e-3, 50e-6, 10e-3)),
            ('decay_tau', (4e-3, 500e-6, 50e-3)),
            ('rise_power', (2.0, 'fixed')),
        ])
        if clamp_mode == 'ic':
            params['amp'] = sign * 10e-3
            params['rise_time'] = (5e-3, 50e-6, 20e-3)
            params['decay_tau'] = (15e-3, 500e-6, 150e-3)

        fit_kws = {'xtol': 1e-3, 'maxfev': 100}

        psp = fitting.Psp()
        return psp.fit(data, x=t, fit_kws=fit_kws, params=params)

    def add_set_clicked(self):
        """Save the current event set and list it as "<pre> -> <post>"."""
        ces = self.current_event_set
        # Nothing analyzed yet, or no pulse produced a fit: an empty set would
        # later break fit_clicked, which needs at least one event per set.
        if ces is None or len(ces[2]) == 0:
            return
        self.event_sets.append(ces)
        item = QtGui.QListWidgetItem("%d -> %d" % (ces[0], ces[1]))
        self.event_set_list.addItem(item)
        item.event_set = ces

    def remove_set_clicked(self):
        """Remove the selected saved event set (row 0, "current", is never removed)."""
        row = self.event_set_list.currentRow()
        if row <= 0:  # 0 is "current"; -1 means no current row
            return
        item = self.event_set_list.takeItem(row)
        self.event_sets.remove(item.event_set)

    def event_set_selected(self):
        """Show the selected event set (or the current analysis) in the event table."""
        selected = self.event_set_list.selectedItems()
        if not selected:
            return
        sel = selected[0]
        if sel.text() == "current":
            if self.current_event_set is None:
                self.event_table.clear()
            else:
                self.event_table.setData(self.current_event_set[2])
        else:
            self.event_table.setData(sel.event_set[2])

    def fit_clicked(self):
        """Fit ReleaseModel variants to all saved event sets and plot the results.

        Dynamics terms are enabled cumulatively in the order of
        ``dynamics_types``; one fit is run after each is enabled.
        """
        n_sets = len(self.event_sets)
        if n_sets == 0:
            return

        from synaptic_release import ReleaseModel

        self.fit_plot.clear()
        self.fit_plot.show()

        self.fit_plot.set_shape(n_sets, 1)
        legend = self.fit_plot[0, 0].legend
        if legend is not None:
            # scene() is a method on graphics items
            scene = legend.scene()
            if scene is not None:
                scene.removeItem(legend)
            self.fit_plot[0, 0].legend = None
        self.fit_plot[0, 0].addLegend()

        for i in range(n_sets):
            self.fit_plot[i, 0].setXLink(self.fit_plot[0, 0])

        spike_sets = []
        for i, evset in enumerate(self.event_sets):
            events = evset[2][:-1]  # select events, skip last row (average)
            x = np.array([ev['spike_time'] for ev in events])
            y = np.array([ev['amp'] for ev in events])
            x -= x[0]
            x *= 1000  # seconds -> milliseconds for the release model
            y /= y[0]  # normalize amplitudes to the first event
            spike_sets.append((x, y))

            self.fit_plot[i, 0].plot(x / 1000., y, pen=None, symbol='o')

        model = ReleaseModel()
        dynamics_types = ['Dep', 'Fac', 'UR', 'SMR', 'DSR']
        for k in dynamics_types:
            model.Dynamics[k] = 0

        fit_params = []
        with pg.ProgressDialog("Fitting release model..", 0, len(dynamics_types)) as dlg:
            for k in dynamics_types:
                model.Dynamics[k] = 1
                fit_params.append(model.run_fit(spike_sets))
                dlg += 1
                if dlg.wasCanceled():
                    return

        max_color = len(fit_params) * 1.5

        for i, params in enumerate(fit_params):
            for j, (x, _y) in enumerate(spike_sets):
                output = model.eval(x, params.values())
                self.fit_plot[j, 0].plot(output[:, 0] / 1000., output[:, 1],
                                         pen=(i, max_color), name=dynamics_types[i])

    def fit_curve_clicked(self, curve):
        """Open (or retarget) the FitExplorer on the fit attached to the clicked curve."""
        if self.fit_explorer is None:
            self.fit_explorer = FitExplorer(curve.fit)
        else:
            self.fit_explorer.set_fit(curve.fit)
        self.fit_explorer.show()
class DistancePlot(object):
    """Plot connection probability as a function of distance.

    Builds a two-row ``PlotGrid`` (row 1 given twice the stretch of row 0).
    ``self.plots`` is ``(grid[1, 0], grid[0, 0])``: the first carries the
    legend and the probability axes, the second is the scatter plot; both are
    handed to ``distance_plot``. ``self.params`` holds the distance binning
    window; changing it re-plots the stored base results and any element
    overlays.
    """
    def __init__(self):
        self.grid = PlotGrid()
        self.grid.set_shape(2, 1)
        self.grid.grid.ci.layout.setRowStretchFactor(0, 5)
        self.grid.grid.ci.layout.setRowStretchFactor(1, 10)
        self.plots = (self.grid[1, 0], self.grid[0, 0])
        self.plots[0].grid = self.grid
        self.plots[0].addLegend()
        self.grid.show()
        self.plots[0].setLabels(bottom=('distance', 'm'), left='connection probability')
        self.plots[0].setXRange(0, 200e-6)
        self.params = Parameter.create(name='Distance binning window', type='float', value=40.e-6,
                                       step=10.e-6, suffix='m', siPrefix=True)
        self.element_plot = None  # return value of the latest element_distance() plot
        self.elements = []        # element results overlaid on the base plot
        self.element_colors = []  # color for each entry of self.elements
        self.results = None       # base results re-plotted by update_plot()
        self.color = None
        self.name = None

        self.params.sigTreeStateChanged.connect(self.update_plot)

    def plot_distance(self, results, color, name, size=10, suppress_scatter=False):
        """Plot connection probability vs. distance for *results*.

        Args:
            results: DataFrame with 'Connected' and 'Distance' columns; rows
                whose 'Distance' is null are ignored. The first results ever
                plotted (with their *color* and *name*) become the base data
                that ``update_plot`` re-draws.
            color, name, size: passed through to ``distance_plot``.
            suppress_scatter: if True, skip the (slow) scatter plot.

        Returns:
            The value returned by ``distance_plot`` (also kept in ``self.dist_plot``).
        """
        has_distance = ~results['Distance'].isnull()
        connected = results[has_distance]['Connected']
        distance = results[has_distance]['Distance']
        dist_win = self.params.value()
        if self.results is None:
            self.name = name
            self.color = color
            self.results = results
        plots = self.plots
        if suppress_scatter is True:
            # suppress scatter plot for all results (takes forever to plot)
            plots = list(self.plots)
            plots[1] = None
        self.dist_plot = distance_plot(connected, distance, plots=plots, color=color, name=name,
                                       size=size, window=dist_win, spacing=dist_win)
        return self.dist_plot

    def invalidate_output(self):
        """Clear everything drawn in the plot grid."""
        self.grid.clear()

    def element_distance(self, element, color, add_to_list=True):
        """Overlay the distance plot for one element (a pre->post class subset).

        *element* is a DataFrame like ``plot_distance`` expects, plus
        'pre_class'/'post_class' columns whose first entries supply the
        legend name. When *add_to_list* is True the element is remembered so
        ``update_plot`` can redraw it.
        """
        if add_to_list is True:
            self.element_colors.append(color)
            self.elements.append(element)
        # positional lookup: a filtered DataFrame need not contain index label 0
        pre = element['pre_class'].iloc[0].name
        post = element['post_class'].iloc[0].name
        name = ('%s->%s' % (pre, post))
        self.element_plot = self.plot_distance(element, color=color, name=name, size=15)

    def element_distance_reset(self, results, color, name, suppress_scatter=False):
        """Drop all element overlays and redraw with *results* as the new base data."""
        self.elements = []
        self.element_colors = []
        self.element_plot = None
        # Forget the old base results so plot_distance records *results*;
        # otherwise update_plot would keep re-drawing the stale data.
        self.results = None
        self.grid.clear()
        self.dist_plot = self.plot_distance(results, color=color, name=name, size=10,
                                            suppress_scatter=suppress_scatter)

    def update_plot(self, *args):
        """Redraw base results and element overlays (e.g. after the bin window changed).

        Extra positional arguments (as emitted by ``sigTreeStateChanged``) are
        ignored. Does nothing when no base results have been plotted yet.
        """
        if self.results is None:
            return
        self.invalidate_output()
        self.plot_distance(self.results, self.color, self.name, suppress_scatter=True)
        if self.element_plot is not None:
            for element, color in zip(self.elements, self.element_colors):
                self.element_distance(element, color, add_to_list=False)
class PairView(QtGui.QWidget):
    """For analyzing pre/post-synaptic pairs.

    For every selected sweep, plots the presynaptic trace and the processed
    (artifact-removed, baseline-removed, filtered) postsynaptic trace.
    Postsynaptic responses that follow each detected presynaptic spike are
    averaged per pulse number and fitted with a PSP model (``fit_psp``); the
    fit parameters are listed in ``event_table``. The resulting event sets can
    be saved to a list and fitted jointly with a release model
    (``fit_clicked``).
    """
    def __init__(self, parent=None):
        self.sweeps = []
        self.channels = []

        # (pre, post, events, sweeps) from the latest _update_plots, or None
        self.current_event_set = None
        # event sets saved with the "add" button
        self.event_sets = []

        QtGui.QWidget.__init__(self, parent)

        self.layout = QtGui.QGridLayout()
        self.setLayout(self.layout)
        self.layout.setContentsMargins(0, 0, 0, 0)

        self.vsplit = QtGui.QSplitter(QtCore.Qt.Vertical)
        self.layout.addWidget(self.vsplit, 0, 0)

        self.pre_plot = pg.PlotWidget()
        self.post_plot = pg.PlotWidget()
        for plt in (self.pre_plot, self.post_plot):
            plt.setClipToView(True)
            plt.setDownsampling(True, True, 'peak')

        self.post_plot.setXLink(self.pre_plot)
        self.vsplit.addWidget(self.pre_plot)
        self.vsplit.addWidget(self.post_plot)

        self.response_plots = PlotGrid()
        self.vsplit.addWidget(self.response_plots)

        self.event_splitter = QtGui.QSplitter(QtCore.Qt.Horizontal)
        self.vsplit.addWidget(self.event_splitter)

        self.event_table = pg.TableWidget()
        self.event_splitter.addWidget(self.event_table)

        self.event_widget = QtGui.QWidget()
        self.event_widget_layout = QtGui.QGridLayout()
        self.event_widget.setLayout(self.event_widget_layout)
        self.event_splitter.addWidget(self.event_widget)

        self.event_splitter.setSizes([600, 100])

        # Row 0 ("current") shows the latest analysis; saved sets follow it.
        self.event_set_list = QtGui.QListWidget()
        self.event_widget_layout.addWidget(self.event_set_list, 0, 0, 1, 2)
        self.event_set_list.addItem("current")
        self.event_set_list.itemSelectionChanged.connect(self.event_set_selected)

        self.add_set_btn = QtGui.QPushButton('add')
        self.event_widget_layout.addWidget(self.add_set_btn, 1, 0, 1, 1)
        self.add_set_btn.clicked.connect(self.add_set_clicked)
        self.remove_set_btn = QtGui.QPushButton('del')
        self.event_widget_layout.addWidget(self.remove_set_btn, 1, 1, 1, 1)
        self.remove_set_btn.clicked.connect(self.remove_set_clicked)
        self.fit_btn = QtGui.QPushButton('fit all')
        self.event_widget_layout.addWidget(self.fit_btn, 2, 0, 1, 2)
        self.fit_btn.clicked.connect(self.fit_clicked)

        self.fit_plot = PlotGrid()

        self.fit_explorer = None

        self.artifact_remover = ArtifactRemover(user_width=True)
        self.baseline_remover = BaselineRemover()
        self.filter = SignalFilter()

        self.params = pg.parametertree.Parameter(name='params', type='group', children=[
            {'name': 'pre', 'type': 'list', 'values': []},
            {'name': 'post', 'type': 'list', 'values': []},
            self.artifact_remover.params,
            self.baseline_remover.params,
            self.filter.params,
            {'name': 'time constant', 'type': 'float', 'suffix': 's', 'siPrefix': True,
             'value': 10e-3, 'dec': True, 'minStep': 100e-6},
        ])
        self.params.sigTreeStateChanged.connect(self._update_plots)

    def data_selected(self, sweeps, channels):
        """Use *sweeps*/*channels* as the new data, offer the channels as pre/post choices, and redraw."""
        self.sweeps = sweeps
        self.channels = channels

        self.params.child('pre').setLimits(channels)
        self.params.child('post').setLimits(channels)

        self._update_plots()

    def _update_plots(self, *args):
        """Re-plot all sweeps, re-average and re-fit responses, and rebuild the event table.

        Extra positional arguments (as emitted by ``sigTreeStateChanged``) are
        ignored. Returns early, leaving the plots empty, when there are no
        sweeps or the pre/post channels are equal or not among ``self.channels``.
        """
        sweeps = self.sweeps
        self.current_event_set = None
        self.event_table.clear()

        # clear all plots
        self.pre_plot.clear()
        self.post_plot.clear()

        pre = self.params['pre']
        post = self.params['post']

        # If there are no selected sweeps or channels have not been set, return
        if (len(sweeps) == 0 or pre == post
                or pre not in self.channels or post not in self.channels):
            return

        pre_mode = sweeps[0][pre].clamp_mode
        post_mode = sweeps[0][post].clamp_mode
        for ch, mode, plot in [(pre, pre_mode, self.pre_plot), (post, post_mode, self.post_plot)]:
            units = 'A' if mode == 'vc' else 'V'
            plot.setLabels(left=("Channel %d" % ch, units), bottom=("Time", 's'))

        # Iterate over selected channels of all sweeps, plotting traces one at a time.
        # Collect pulse onsets, detected spikes and processed postsynaptic traces.
        pulses = []       # per sweep: stimulus onset indices (test pulse excluded)
        spikes = []       # per sweep: one spike dict (or None) per pulse
        post_traces = []  # per sweep: processed postsynaptic trace
        for i, sweep in enumerate(sweeps):
            pre_trace = sweep[pre]['primary']
            post_trace = sweep[post]['primary']

            # Detect pulse times from rising/falling edges of the command
            stim = sweep[pre]['command'].data
            sdiff = np.diff(stim)
            on_times = np.argwhere(sdiff > 0)[1:, 0]  # 1: skips test pulse
            off_times = np.argwhere(sdiff < 0)[1:, 0]
            pulses.append(on_times)

            # filter data
            post_filt = self.artifact_remover.process(post_trace, list(on_times) + list(off_times))
            post_filt = self.baseline_remover.process(post_filt)
            post_filt = self.filter.process(post_filt)
            post_traces.append(post_filt)

            # plot raw data
            color = pg.intColor(i, hues=len(sweeps) * 1.3, sat=128)
            color.setAlpha(128)
            for trace, plot in [(pre_trace, self.pre_plot), (post_filt, self.post_plot)]:
                plot.plot(trace.time_values, trace.data, pen=color, antialias=False)

            # detect spike times (None where no spike was evoked)
            spike_info = [detect_evoked_spike(sweep[pre], [on, off])
                          for on, off in zip(on_times, off_times)]
            spikes.append(spike_info)

            dt = pre_trace.dt
            spike_inds = [sp['rise_index'] for sp in spike_info if sp is not None]
            vticks = pg.VTickGroup([x * dt for x in spike_inds if x is not None],
                                   yrange=[0.0, 0.2], pen=color)
            self.pre_plot.addItem(vticks)

        # Iterate over pulse numbers, plotting the average response to each
        avg_responses = []
        fits = []  # (pulse index, fit) for every pulse with at least one response

        npulses = max(map(len, pulses))
        self.response_plots.clear()
        self.response_plots.set_shape(1, npulses + 1)  # 1 extra for global average
        self.response_plots.setYLink(self.response_plots[0, 0])
        for i in range(1, npulses + 1):
            self.response_plots[0, i].hideAxis('left')
        units = 'A' if post_mode == 'vc' else 'V'
        self.response_plots[0, 0].setLabels(left=("Averaged events (Channel %d)" % post, units))

        fit_pen = {'color': (30, 30, 255), 'width': 2, 'dash': [1, 1]}
        # NOTE: dt is taken from the last sweep; assumes all sweeps share one sample rate.
        max_len = int(40e-3 / dt)  # don't take more than 40 ms for any response
        baseline_len = int(1e-3 / dt)  # leading samples used as the baseline of each average
        for i in range(npulses):
            # get the chunk of each sweep between this spike and the next one
            responses = []
            for sweep_spikes, trace in zip(spikes, post_traces):
                if i >= len(sweep_spikes) or sweep_spikes[i] is None:
                    continue
                spike = sweep_spikes[i]

                # find the next detected spike in this sweep
                next_spike = next((sp for sp in sweep_spikes[i + 1:] if sp is not None), None)

                # determine time range for response
                start = spike['rise_index']
                stop = start + max_len
                if next_spike is not None:
                    stop = min(stop, next_spike['rise_index'])

                responses.append(trace.data[start:stop].copy())

            if len(responses) == 0:
                continue

            # extend all responses to the same length and take nanmean
            avg = ragged_mean(responses, method='clip')
            avg -= float_mode(avg[:baseline_len])
            avg_responses.append(avg)

            # plot average response for this pulse
            t = np.arange(len(avg)) * dt
            self.response_plots[0, i].plot(t, avg, pen='w', antialias=True)

            # fit!
            fit = self.fit_psp(avg, t, dt, post_mode)
            fits.append((i, fit))

            # let the user mess with this fit
            curve = self.response_plots[0, i].plot(t, fit.eval(), pen=fit_pen, antialias=True).curve
            curve.setClickable(True)
            curve.fit = fit
            curve.sigClicked.connect(self.fit_curve_clicked)

        # display global average (only possible if some pulse produced responses)
        global_fit = None
        if avg_responses:
            global_avg = ragged_mean(avg_responses, method='clip')
            t = np.arange(len(global_avg)) * dt
            self.response_plots[0, -1].plot(t, global_avg, pen='w', antialias=True)
            global_fit = self.fit_psp(global_avg, t, dt, post_mode)
            self.response_plots[0, -1].plot(t, global_fit.eval(), pen=fit_pen, antialias=True)

        # display fit parameters in table: one row per fitted pulse, then the global average
        events = []
        for i, fit in fits:
            # i is the pulse index, so it is also the index into each sweep's spike list
            spt = [s[i]['peak_index'] * dt for s in spikes if i < len(s) and s[i] is not None]
            vals = OrderedDict([('id', i), ('spike_time', np.mean(spt)), ('spike_stdev', np.std(spt))])
            vals.update((k, fit.best_values[k]) for k in fit.params.keys())
            events.append(vals)
        if global_fit is not None:
            vals = OrderedDict([('id', 'avg'), ('spike_time', np.nan), ('spike_stdev', np.nan)])
            vals.update((k, global_fit.best_values[k]) for k in global_fit.params.keys())
            events.append(vals)

        self.current_event_set = (pre, post, events, sweeps)
        self.event_set_list.setCurrentRow(0)
        self.event_set_selected()

    def fit_psp(self, data, t, dt, clamp_mode):
        """Fit ``fitting.Psp`` to an averaged response and return the fit result.

        The sign of the initial amplitude guess follows whether the mean of
        *data* lies below or above the mode of its first ``1e-3 / dt`` samples.
        ``clamp_mode == 'ic'`` switches to current-clamp starting values for
        amplitude, rise time and decay tau.
        """
        mode = float_mode(data[:int(1e-3 / dt)])
        sign = -1 if data.mean() - mode < 0 else 1
        # tuples are presumably (initial, min, max) or (value, 'fixed') for fitting.Psp
        params = OrderedDict([
            ('xoffset', (2e-3, 3e-4, 5e-3)),
            ('yoffset', data[0]),
            ('amp', sign * 10e-12),
            ('rise_time', (2e-3, 50e-6, 10e-3)),
            ('decay_tau', (4e-3, 500e-6, 50e-3)),
            ('rise_power', (2.0, 'fixed')),
        ])
        if clamp_mode == 'ic':
            params['amp'] = sign * 10e-3
            params['rise_time'] = (5e-3, 50e-6, 20e-3)
            params['decay_tau'] = (15e-3, 500e-6, 150e-3)

        fit_kws = {'xtol': 1e-3, 'maxfev': 100}

        psp = fitting.Psp()
        return psp.fit(data, x=t, fit_kws=fit_kws, params=params)

    def add_set_clicked(self):
        """Save the current event set and list it as "<pre> -> <post>"."""
        ces = self.current_event_set
        # Nothing analyzed yet, or no pulse produced a fit: an empty set would
        # later break fit_clicked, which needs at least one event per set.
        if ces is None or len(ces[2]) == 0:
            return
        self.event_sets.append(ces)
        item = QtGui.QListWidgetItem("%d -> %d" % (ces[0], ces[1]))
        self.event_set_list.addItem(item)
        item.event_set = ces

    def remove_set_clicked(self):
        """Remove the selected saved event set (row 0, "current", is never removed)."""
        row = self.event_set_list.currentRow()
        if row <= 0:  # 0 is "current"; -1 means no current row
            return
        item = self.event_set_list.takeItem(row)
        self.event_sets.remove(item.event_set)

    def event_set_selected(self):
        """Show the selected event set (or the current analysis) in the event table."""
        selected = self.event_set_list.selectedItems()
        if not selected:
            return
        sel = selected[0]
        if sel.text() == "current":
            if self.current_event_set is None:
                self.event_table.clear()
            else:
                self.event_table.setData(self.current_event_set[2])
        else:
            self.event_table.setData(sel.event_set[2])

    def fit_clicked(self):
        """Fit ReleaseModel variants to all saved event sets and plot the results.

        Dynamics terms are enabled cumulatively in the order of
        ``dynamics_types``; one fit is run after each is enabled.
        """
        n_sets = len(self.event_sets)
        if n_sets == 0:
            return

        from synaptic_release import ReleaseModel

        self.fit_plot.clear()
        self.fit_plot.show()

        self.fit_plot.set_shape(n_sets, 1)
        legend = self.fit_plot[0, 0].legend
        if legend is not None:
            # scene() is a method on graphics items
            scene = legend.scene()
            if scene is not None:
                scene.removeItem(legend)
            self.fit_plot[0, 0].legend = None
        self.fit_plot[0, 0].addLegend()

        for i in range(n_sets):
            self.fit_plot[i, 0].setXLink(self.fit_plot[0, 0])

        spike_sets = []
        for i, evset in enumerate(self.event_sets):
            events = evset[2][:-1]  # select events, skip last row (average)
            x = np.array([ev['spike_time'] for ev in events])
            y = np.array([ev['amp'] for ev in events])
            x -= x[0]
            x *= 1000  # seconds -> milliseconds for the release model
            y /= y[0]  # normalize amplitudes to the first event
            spike_sets.append((x, y))

            self.fit_plot[i, 0].plot(x / 1000., y, pen=None, symbol='o')

        model = ReleaseModel()
        dynamics_types = ['Dep', 'Fac', 'UR', 'SMR', 'DSR']
        for k in dynamics_types:
            model.Dynamics[k] = 0

        fit_params = []
        with pg.ProgressDialog("Fitting release model..", 0, len(dynamics_types)) as dlg:
            for k in dynamics_types:
                model.Dynamics[k] = 1
                fit_params.append(model.run_fit(spike_sets))
                dlg += 1
                if dlg.wasCanceled():
                    return

        max_color = len(fit_params) * 1.5

        for i, params in enumerate(fit_params):
            for j, (x, _y) in enumerate(spike_sets):
                output = model.eval(x, params.values())
                self.fit_plot[j, 0].plot(output[:, 0] / 1000., output[:, 1],
                                         pen=(i, max_color), name=dynamics_types[i])

    def fit_curve_clicked(self, curve):
        """Open (or retarget) the FitExplorer on the fit attached to the clicked curve."""
        if self.fit_explorer is None:
            self.fit_explorer = FitExplorer(curve.fit)
        else:
            self.fit_explorer.set_fit(curve.fit)
        self.fit_explorer.show()
class MultipatchMatrixView(QtGui.QWidget):
    """Matrix of plots comparing every selected channel against every other.

    Cell ``(i, j)`` shows the signal recorded on channel ``i`` in a time window
    around the stimulus pulses detected on channel ``j``; diagonal cells
    therefore show each channel's response to its own stimulus. Display
    options live in ``self.params`` (a pyqtgraph Parameter tree) and any change
    to them redraws the matrix.
    """

    def __init__(self, parent=None):
        QtGui.QWidget.__init__(self, parent)
        self.layout = QtGui.QGridLayout()
        self.setLayout(self.layout)
        self.layout.setContentsMargins(0, 0, 0, 0)

        self.plots = PlotGrid(parent=self)
        self.layout.addWidget(self.plots, 0, 0)
        self.plots.scene().sigMouseClicked.connect(self._plot_clicked)

        # Nothing is selected until data_selected() is called. Initializing the
        # selection lets _update_plots() (triggered by any parameter change)
        # return early instead of raising AttributeError.
        self.sweeps = []
        self.channels = []

        self.params = pg.parametertree.Parameter(name='params', type='group', children=[
            {'name': 'show', 'type': 'list', 'values': ['sweep avg', 'sweep avg + sweeps', 'sweeps', 'pulse avg']},
            {'name': 'lowpass', 'type': 'bool', 'value': True, 'children': [
                {'name': 'sigma', 'type': 'float', 'value': 200e-6, 'step': 1e-5, 'limits': [0, None], 'suffix': 's', 'siPrefix': True},
            ]},
            {'name': 'first pulse', 'type': 'int', 'value': 0, 'limits': [0, None]},
            {'name': 'last pulse', 'type': 'int', 'value': 7, 'limits': [0, None]},
            {'name': 'window', 'type': 'float', 'value': 30e-3, 'step': 1e-3, 'limits': [0, None], 'suffix': 's', 'siPrefix': True},
            {'name': 'remove artifacts', 'type': 'bool', 'value': True, 'children': [
                # 'limits' is the option name used by every other numeric
                # parameter in this tree
                {'name': 'window', 'type': 'float', 'suffix': 's', 'siPrefix': True, 'value': 1e-3, 'step': 1e-4, 'limits': [0, None]},
            ]},
            {'name': 'remove baseline', 'type': 'bool', 'value': True},
            {'name': 'show ticks', 'type': 'bool', 'value': True},
        ])
        self.params.sigTreeStateChanged.connect(self._params_changed)

    def show_group(self, grp):
        """Display the sweeps belonging to *grp*.

        NOTE(review): ``show_sweeps`` is not defined on this class (its data
        entry point is ``data_selected(sweeps, channels)``); confirm whether a
        subclass provides it, otherwise this raises AttributeError.
        """
        self.show_sweeps(grp.sweeps)

    def data_selected(self, sweeps, channels):
        """Select the sweeps and channels to display, then redraw with auto-range.

        *channels* are matched against ``sweeps[0].devices`` to choose which
        recorded channels appear in the matrix.
        """
        self.sweeps = sweeps
        self.channels = channels
        self._update_plots(auto_range=True)

    def _params_changed(self, *args):
        """Redraw after any parameter change; the signal arguments are ignored."""
        self._update_plots()

    def _plot_clicked(self, ev):
        """Highlight the clicked matrix cell by tinting its view background."""
        item = self.plots.scene().itemAt(ev.scenePos())
        r,c = self.plots.item_index(item)
        for i in range(self.plots.rows):
            for j in range(self.plots.cols):
                color = None if (i, j) != (r, c) else pg.mkColor(30, 30, 50)
                self.plots[i,j].vb.setBackgroundColor(color)

    def _update_plots(self, auto_range=False):
        """Rebuild the plot matrix from the current selection and parameters.

        Steps: pack the selected sweeps into arrays, keep only the selected
        channels, detect stimulus pulse edges on each channel, optionally
        replace stimulus artifacts on nearby channels, optionally low-pass
        filter, then draw one cell per (recorded channel, stimulated channel).

        Parameters
        ----------
        auto_range : bool
            If True, reset the y ranges (scaled for the first channel's clamp
            mode) and the shared x range after drawing.
        """
        sweeps = self.sweeps
        chans = self.channels
        self.plots.clear()
        if len(sweeps) == 0 or len(chans) == 0:
            return

        # collect data; axes are (sweep, channel, sample) and the packed
        # array's last axis holds (recording, stimulus)
        data = MiesNwb.pack_sweep_data(sweeps)
        data, stim = data[..., 0], data[..., 1]
        dt = sweeps[0].recordings[0]['primary'].dt

        # keep only the selected channels
        mask = np.array([ch in chans for ch in sweeps[0].devices])
        data = data[:, mask]
        stim = stim[:, mask]
        chans = np.array(sweeps[0].devices)[mask]

        modes = [sweeps[0][ch].clamp_mode for ch in chans]

        # get pulse edge times for each channel from the first sweep's stimulus
        stim = stim[0]
        diff = stim[:, 1:] - stim[:, :-1]
        # note: the [1:] here skips the test pulse
        on_times = [np.argwhere(d > 0)[1:, 0] for d in diff]
        off_times = [np.argwhere(d < 0)[1:, 0] for d in diff]

        # remove capacitive artifacts from adjacent electrodes
        if self.params['remove artifacts']:
            npts = int(self.params['remove artifacts', 'window'] / dt)
            n_stim = stim.shape[0]
            for i in range(n_stim):
                for j in range(n_stim):
                    # only headstages whose ids differ by at most 3 are
                    # treated as adjacent
                    if i == j or abs(chans[j] - chans[i]) > 3:
                        continue

                    # overwrite npts samples after each edge of channel i's
                    # pulses with the mean of the npts samples before it.
                    # zip() stops at the shorter list, so a pulse that is
                    # still on at the end of the trace cannot raise IndexError.
                    for on, off in zip(on_times[i], off_times[i]):
                        data[:, j, on:on+npts] = data[:, j, max(0, on-npts):on].mean(axis=1)[:, None]
                        data[:, j, off:off+npts] = data[:, j, max(0, off-npts):off].mean(axis=1)[:, None]

        # lowpass filter along the time axis only
        if self.params['lowpass']:
            data = gaussian_filter(data, (0, 0, self.params['lowpass', 'sigma'] / dt))

        # prepare to plot
        window = int(self.params['window'] / dt)
        n_sweeps = data.shape[0]
        n_channels = data.shape[1]
        self.plots.set_shape(n_channels, n_channels)
        self.plots.setClipToView(True)
        self.plots.setDownsampling(True, True, 'peak')
        self.plots.enableAutoRange(False, False)

        show_sweeps = 'sweeps' in self.params['show']
        show_sweep_avg = 'sweep avg' in self.params['show']
        show_pulse_avg = self.params['show'] == 'pulse avg'

        for i in range(n_channels):
            for j in range(n_channels):
                plt = self.plots[i, j]
                start = on_times[j][self.params['first pulse']] - window
                # if the window reaches before the first sample, pad the front
                frontpad = max(0, -start)
                start = max(0, start)
                stop = on_times[j][self.params['last pulse']] + window

                # select the data segment to be displayed in this matrix cell
                # add padding if necessary
                if frontpad == 0:
                    seg = data[:, i, start:stop].copy()
                else:
                    seg = np.empty((data.shape[0], stop + frontpad), data.dtype)
                    seg[:, frontpad:] = data[:, i, start:stop]
                    seg[:, :frontpad] = seg[:, frontpad:frontpad+1]

                # subtract off baseline for each sweep
                if self.params['remove baseline']:
                    seg -= seg[:, :window].mean(axis=1)[:, None]

                if show_sweeps:
                    alpha = 100 if show_sweep_avg else 200
                    color = (255, 255, 255, alpha)
                    t = np.arange(seg.shape[1]) * dt
                    for k in range(n_sweeps):
                        plt.plot(t, seg[k], pen={'color': color, 'width': 1}, antialias=True)

                if show_sweep_avg or show_pulse_avg:
                    # average selected segments over all sweeps
                    segm = seg.mean(axis=0)

                    if show_pulse_avg:
                        # average over all selected pulses; segm index 0
                        # corresponds to the first selected pulse minus window
                        pulses = []
                        for k in range(self.params['first pulse'], self.params['last pulse'] + 1):
                            pstart = on_times[j][k] - on_times[j][self.params['first pulse']]
                            pstop = pstart + (window * 2)
                            pulses.append(segm[pstart:pstop])
                        segm = np.vstack(pulses).mean(axis=0)

                    t = np.arange(segm.shape[0]) * dt

                    if i == j:
                        color = (80, 80, 80)
                    else:
                        # tint by positive/negative deflection relative to
                        # baseline noise
                        baseline = segm[:window]
                        dif = segm - baseline.mean()
                        noise = baseline.std()
                        if noise > 0:
                            qe = 30 * np.clip(dif, 0, 1e20).mean() / noise
                            qi = 30 * np.clip(-dif, 0, 1e20).mean() / noise
                        else:
                            # flat (or empty) baseline: no usable noise
                            # estimate, so avoid a NaN color
                            qe = qi = 0
                        if modes[i] == 'ic':
                            qi, qe = qe, qi  # invert color metric for current clamp
                        g = 100
                        r = np.clip(g + max(qi, 0), 0, 255)
                        b = np.clip(g + max(qe, 0), 0, 255)
                        color = (r, g, b)

                    plt.plot(t, segm, pen={'color': color, 'width': 1}, antialias=True)

                if self.params['show ticks']:
                    # tick times are relative to the start of the (possibly
                    # front-padded) segment, so frontpad must be added back
                    vt = pg.VTickGroup((on_times[j] - start + frontpad) * dt, [0, 0.15], pen=0.4)
                    plt.addItem(vt)

                # Link all plots along x axis
                plt.setXLink(self.plots[0, 0])

                if i == j:
                    # link y axes of all diagonal plots
                    plt.setYLink(self.plots[0, 0])
                else:
                    # link y axes of all plots within a row
                    plt.setYLink(self.plots[i, (i+1) % 2])  # (i+1)%2 just avoids linking to 0,0

                if i < n_channels - 1:
                    plt.getAxis('bottom').setVisible(False)
                if j > 0:
                    plt.getAxis('left').setVisible(False)

                if i == n_channels - 1:
                    plt.setLabels(bottom=('CH%d'%chans[j], 's'))
                if j == 0:
                    plt.setLabels(left=('CH%d'%chans[i], 'A' if modes[i] == 'vc' else 'V'))

        if auto_range:
            # Row 0 shows data recorded on the first channel, so its clamp mode
            # (not whichever channel the loop ended on) sets the units/scale.
            mode = modes[0]
            if n_channels > 1:
                # plots[0, 1] is the y-link target for row 0's off-diagonal cells
                r = 14e-12 if mode == 'vc' else 5e-3
                self.plots[0, 1].setYRange(-r, r)
            r = 2e-9 if mode == 'vc' else 100e-3
            self.plots[0, 0].setYRange(-r, r)

            # x axes are all linked; t is the time axis of the last cell drawn
            self.plots[0, 0].setXRange(t[0], t[-1])
class DynamicsWindow(pg.QtGui.QSplitter):
    """Experiment browser beside a grid of per-stimulus pulse-response plots.

    Selecting exactly one pair item in the browser loads that pair's pulse
    responses and draws one plot row per stimulus key, with pulses laid out
    left-to-right by pulse number and stored PSP fits overlaid.
    """

    def __init__(self):
        pg.QtGui.QSplitter.__init__(self, pg.QtCore.Qt.Horizontal)

        browser = ExperimentBrowser()
        self.browser = browser
        self.addWidget(browser)

        plots = PlotGrid()
        self.plots = plots
        self.addWidget(plots)

        browser.itemSelectionChanged.connect(self.browser_item_selected)

    def browser_item_selected(self):
        """Load the selected pair when exactly one item carrying a pair is selected."""
        with pg.BusyCursor():
            items = self.browser.selectedItems()
            if len(items) == 1 and hasattr(items[0], 'pair'):
                self.load_pair(items[0].pair)

    def load_pair(self, pair):
        """Query all pulse responses (with data) for *pair* and redraw."""
        print("Loading:", pair)
        records = pulse_response_query(pair, data=True).all()
        self.sorted_recs = sorted_pulse_responses(records)
        self.plot_all()

    def plot_all(self):
        """Draw every loaded pulse response, one plot row per stimulus key.

        Stimulus keys are 3-tuples formatted into the row title as
        "<str>  <n> Hz  <x> s" (presumably stimulus name, frequency and
        recovery delay - confirm against sorted_pulse_responses).
        """
        self.plots.clear()
        self.plots.set_shape(len(self.sorted_recs), 1)
        psp = StackedPsp()

        for row, stim_key in enumerate(sorted(self.sorted_recs.keys())):
            by_recording = self.sorted_recs[stim_key]
            plt = self.plots[row, 0]
            plt.setTitle("%s  %0.0f Hz  %0.2f s" % stim_key)

            for recording in by_recording:
                responses = by_recording[recording]
                for pulse_n in sorted(responses.keys()):
                    self._plot_pulse(plt, psp, responses[pulse_n], pulse_n)

    def _plot_pulse(self, plt, psp, rec, pulse_n):
        """Draw one spike-aligned pulse response, and its stored fit, into *plt*.

        Responses failing QC are drawn dark red behind other curves and get no
        fit overlay; responses without a fitted amplitude also get no overlay.
        """
        spike_t = rec.stim_pulse.first_spike_time
        if spike_t is None:
            # no detected spike: assume one 1 ms after pulse onset
            spike_t = rec.stim_pulse.onset_time + 1e-3

        use_in_qc = rec.synapse.synapse_type == 'in'
        response = rec.pulse_response
        qc_pass = response.in_qc_pass if use_in_qc else response.ex_qc_pass
        if qc_pass:
            pen = (255, 255, 255, 100)
        else:
            pen = (100, 0, 0, 100)

        t0 = response.data_start_time - spike_t
        ts = TSeries(data=rec.data, t0=t0, sample_rate=db.default_sample_rate)
        trace = plt.plot(ts.time_values, ts.data, pen=pen)

        # lay pulses out side by side: 35 ms per pulse number, plus an extra
        # 30 ms gap for pulses numbered above 8
        x_offset = pulse_n * 35e-3 + (30e-3 if pulse_n > 8 else 0)
        trace.setPos(x_offset, 0)

        if not qc_pass:
            trace.setZValue(-10)
            return

        # evaluate the recorded fit for this response
        fit_par = rec.pulse_response_fit
        if fit_par.fit_amp is None:
            return
        fit = psp.eval(
            x=ts.time_values,
            exp_amp=fit_par.fit_exp_amp,
            # NOTE(review): exp_tau is fed from fit_decay_tau (same value as
            # decay_tau below) - confirm this is intended
            exp_tau=fit_par.fit_decay_tau,
            amp=fit_par.fit_amp,
            rise_time=fit_par.fit_rise_time,
            decay_tau=fit_par.fit_decay_tau,
            xoffset=fit_par.fit_latency,
            yoffset=fit_par.fit_yoffset,
            rise_power=2,
        )
        fit_trace = plt.plot(ts.time_values, fit, pen=(0, 255, 0, 100))
        fit_trace.setZValue(10)
        fit_trace.setPos(x_offset, 0)
class DistancePlot(object):
    def __init__(self):
        """Create and show a two-row plot grid for distance plots.

        ``self.plots[0]`` is the bottom row (labelled connection probability
        vs. distance, with a legend); ``self.plots[1]`` is the top row. Both
        are passed to ``distance_plot`` by plot_distance(). Changing the
        binning-window parameter triggers update_plot().
        """
        self.grid = PlotGrid()
        self.grid.set_shape(2, 1)
        # bottom row gets twice the stretch of the top row
        self.grid.grid.ci.layout.setRowStretchFactor(0, 5)
        self.grid.grid.ci.layout.setRowStretchFactor(1, 10)
        # note the order: bottom row first, top row second
        self.plots = (self.grid[1,0], self.grid[0,0])
        self.plots[0].grid = self.grid
        self.plots[0].addLegend()
        self.grid.show()
        self.plots[0].setLabels(bottom=('distance', 'm'), left='connection probability')
        # used as both bin window and bin spacing in plot_distance()
        self.params = Parameter.create(name='Distance binning window', type='float', value=40.e-6, step=10.e-6, suffix='m', siPrefix=True)
        self.element_plot = None
        self.elements = []
        self.element_colors = []
        # baseline dataset (first one passed to plot_distance) and its style
        self.results = None
        self.color = None
        self.name = None
        # most recent distance_plot() result; defined here so it exists
        # before the first plot
        self.dist_plot = None

        self.params.sigTreeStateChanged.connect(self.update_plot)

    def plot_distance(self, results, color, name, size=10):
        """Plot connection probability vs. distance for *results*.

        *results* needs to be a DataFrame or Series object with 'connected'
        and 'distance' as columns. The first dataset ever passed in is
        remembered (with its color and name) as the baseline that
        update_plot() redraws. The current binning-window parameter is used
        as both bin window and bin spacing.

        Returns the value of ``distance_plot(...)``, also kept in
        ``self.dist_plot``.
        """
        if self.results is None:
            self.name, self.color, self.results = name, color, results
        connected, distance = results['connected'], results['distance']
        bin_width = self.params.value()
        self.dist_plot = distance_plot(
            connected,
            distance,
            plots=self.plots,
            color=color,
            name=name,
            size=size,
            window=bin_width,
            spacing=bin_width,
        )
        # default view: 0 to 200 um
        self.plots[0].setXRange(0, 200e-6)
        return self.dist_plot

    def invalidate_output(self):
        self.grid.clear()

    def element_distance(self, element, color, add_to_list=True):
        if add_to_list is True:
            self.element_colors.append(color)
            self.elements.append(element)
        pre = element['pre_class'][0].name
        post = element['post_class'][0].name
        name = ('%s->%s' % (pre, post))
        self.element_plot = self.plot_distance(element, color=color, name=name, size=15)

    def element_distance_reset(self, results, color, name):
        self.elements = []
        self.element_colors = []
        self.grid.clear()
        self.dist_plot = self.plot_distance(results, color=color, name=name, size=10)

    def update_plot(self):
        self.invalidate_output()
        self.plot_distance(self.results, self.color, self.name)
        if self.element_plot is not None:
            for element, color in zip(self.elements, self.element_colors):
                self.element_distance(element, color, add_to_list=False)