class SynapseExplorer(QtGui.QWidget):
    """Browse synapses from a set of experiments and plot their train responses.

    Layout: a vertical splitter holding a ``SynapseTreeWidget`` above an
    ``ExperimentInfoWidget`` on the left, and a ``PlotGrid`` of train responses
    on the right.
    """

    def __init__(self, expts, parent=None):
        QtGui.QWidget.__init__(self, parent)
        self.expts = expts

        self.layout = QtGui.QGridLayout()
        self.setLayout(self.layout)

        self.hsplit = QtGui.QSplitter(QtCore.Qt.Horizontal)
        self.layout.addWidget(self.hsplit, 0, 0)

        self.vsplit = QtGui.QSplitter(QtCore.Qt.Vertical)
        self.hsplit.addWidget(self.vsplit)

        self.syn_tree = SynapseTreeWidget(self.expts)
        self.vsplit.addWidget(self.syn_tree)
        self.syn_tree.itemSelectionChanged.connect(self.selection_changed)

        self.expt_info = ExperimentInfoWidget()
        self.vsplit.addWidget(self.expt_info)

        self.train_plots = PlotGrid()
        self.hsplit.addWidget(self.train_plots)

        # Cache of DynamicsAnalyzer objects keyed by (expt, pre_cell, post_cell)
        # so re-selecting a synapse does not redo the analysis.
        self.analyzers = {}

    def selection_changed(self):
        """Show the selected synapse's experiment info and plot its train responses.

        Does nothing if the selection is empty; itemSelectionChanged is also
        emitted when the selection is cleared.

        Raises:
            Exception: if the analyzer finds no pulse responses for the pair.
        """
        selected = self.syn_tree.selectedItems()
        if len(selected) == 0:
            return

        with pg.BusyCursor():
            sel = selected[0]
            expt = sel.expt
            self.expt_info.set_experiment(expt)
            pre_cell = sel.cells[0].cell_id
            post_cell = sel.cells[1].cell_id

            key = (expt, pre_cell, post_cell)
            if key not in self.analyzers:
                self.analyzers[key] = DynamicsAnalyzer(*key)
            analyzer = self.analyzers[key]

            if len(analyzer.pulse_responses) == 0:
                raise Exception("No suitable data found for cell %d -> cell %d in expt %s" % (pre_cell, post_cell, expt))

            # Plot all individual and averaged train responses for all sets of stimulus parameters
            self.train_plots.clear()
            analyzer.plot_train_responses(plot_grid=self.train_plots)
class SynapseExplorer(QtGui.QWidget):
    """Browse synapses from a set of experiments and plot their train responses.

    Layout: a vertical splitter holding a ``SynapseTreeWidget`` above an
    ``ExperimentInfoWidget`` on the left, and a ``PlotGrid`` of train responses
    on the right.
    """

    def __init__(self, expts, parent=None):
        QtGui.QWidget.__init__(self, parent)
        self.expts = expts

        self.layout = QtGui.QGridLayout()
        self.setLayout(self.layout)

        self.hsplit = QtGui.QSplitter(QtCore.Qt.Horizontal)
        self.layout.addWidget(self.hsplit, 0, 0)

        self.vsplit = QtGui.QSplitter(QtCore.Qt.Vertical)
        self.hsplit.addWidget(self.vsplit)

        self.syn_tree = SynapseTreeWidget(self.expts)
        self.vsplit.addWidget(self.syn_tree)
        self.syn_tree.itemSelectionChanged.connect(self.selection_changed)

        self.expt_info = ExperimentInfoWidget()
        self.vsplit.addWidget(self.expt_info)

        self.train_plots = PlotGrid()
        self.hsplit.addWidget(self.train_plots)

        # Cache of DynamicsAnalyzer objects keyed by (expt, pre_cell, post_cell)
        # so re-selecting a synapse does not redo the analysis.
        self.analyzers = {}

    def selection_changed(self):
        """Show the selected synapse's experiment info and plot its train responses.

        Does nothing if the selection is empty; itemSelectionChanged is also
        emitted when the selection is cleared.

        Raises:
            Exception: if the analyzer finds no pulse responses for the pair.
        """
        selected = self.syn_tree.selectedItems()
        if len(selected) == 0:
            return

        with pg.BusyCursor():
            sel = selected[0]
            expt = sel.expt
            self.expt_info.set_experiment(expt)
            pre_cell = sel.cells[0].cell_id
            post_cell = sel.cells[1].cell_id

            key = (expt, pre_cell, post_cell)
            if key not in self.analyzers:
                self.analyzers[key] = DynamicsAnalyzer(*key)
            analyzer = self.analyzers[key]

            if len(analyzer.pulse_responses) == 0:
                raise Exception("No suitable data found for cell %d -> cell %d in expt %s" % (pre_cell, post_cell, expt))

            # Plot all individual and averaged train responses for all sets of stimulus parameters
            self.train_plots.clear()
            analyzer.plot_train_responses(plot_grid=self.train_plots)
class PairView(QtGui.QWidget):
    """For analyzing pre/post-synaptic pairs.

    Shows the presynaptic and (filtered) postsynaptic traces of the selected
    sweeps, the spike-aligned average postsynaptic response for each stimulus
    pulse with a PSP fit, and a table of the fit parameters. Tables ("event
    sets") can be stored with the 'add' button and fed to a release-model fit
    with 'fit all'.
    """

    def __init__(self, parent=None):
        self.sweeps = []
        self.channels = []
        # (pre, post, events, sweeps) for the currently displayed pair, or None
        self.current_event_set = None
        # event sets stored by the user via the 'add' button
        self.event_sets = []

        QtGui.QWidget.__init__(self, parent)

        self.layout = QtGui.QGridLayout()
        self.setLayout(self.layout)
        self.layout.setContentsMargins(0, 0, 0, 0)

        self.vsplit = QtGui.QSplitter(QtCore.Qt.Vertical)
        self.layout.addWidget(self.vsplit, 0, 0)

        self.pre_plot = pg.PlotWidget()
        self.post_plot = pg.PlotWidget()
        for plt in (self.pre_plot, self.post_plot):
            plt.setClipToView(True)
            plt.setDownsampling(True, True, 'peak')

        self.post_plot.setXLink(self.pre_plot)
        self.vsplit.addWidget(self.pre_plot)
        self.vsplit.addWidget(self.post_plot)

        self.response_plots = PlotGrid()
        self.vsplit.addWidget(self.response_plots)

        self.event_splitter = QtGui.QSplitter(QtCore.Qt.Horizontal)
        self.vsplit.addWidget(self.event_splitter)

        self.event_table = pg.TableWidget()
        self.event_splitter.addWidget(self.event_table)

        self.event_widget = QtGui.QWidget()
        self.event_widget_layout = QtGui.QGridLayout()
        self.event_widget.setLayout(self.event_widget_layout)
        self.event_splitter.addWidget(self.event_widget)

        self.event_splitter.setSizes([600, 100])

        # Row 0 of this list is always the "current" (unsaved) event set.
        self.event_set_list = QtGui.QListWidget()
        self.event_widget_layout.addWidget(self.event_set_list, 0, 0, 1, 2)
        self.event_set_list.addItem("current")
        self.event_set_list.itemSelectionChanged.connect(self.event_set_selected)

        self.add_set_btn = QtGui.QPushButton('add')
        self.event_widget_layout.addWidget(self.add_set_btn, 1, 0, 1, 1)
        self.add_set_btn.clicked.connect(self.add_set_clicked)

        self.remove_set_btn = QtGui.QPushButton('del')
        self.event_widget_layout.addWidget(self.remove_set_btn, 1, 1, 1, 1)
        self.remove_set_btn.clicked.connect(self.remove_set_clicked)

        self.fit_btn = QtGui.QPushButton('fit all')
        self.event_widget_layout.addWidget(self.fit_btn, 2, 0, 1, 2)
        self.fit_btn.clicked.connect(self.fit_clicked)

        self.fit_plot = PlotGrid()

        self.fit_explorer = None

        self.artifact_remover = ArtifactRemover(user_width=True)
        self.baseline_remover = BaselineRemover()
        self.filter = SignalFilter()

        self.params = pg.parametertree.Parameter(name='params', type='group', children=[
            {'name': 'pre', 'type': 'list', 'values': []},
            {'name': 'post', 'type': 'list', 'values': []},
            self.artifact_remover.params,
            self.baseline_remover.params,
            self.filter.params,
            {'name': 'time constant', 'type': 'float', 'suffix': 's', 'siPrefix': True,
             'value': 10e-3, 'dec': True, 'minStep': 100e-6},
        ])
        self.params.sigTreeStateChanged.connect(self._update_plots)

    def data_selected(self, sweeps, channels):
        """Set the sweeps/channels to analyze, update the pre/post choices, and redraw."""
        self.sweeps = sweeps
        self.channels = channels

        self.params.child('pre').setLimits(channels)
        self.params.child('post').setLimits(channels)

        self._update_plots()

    def _update_plots(self):
        """Redraw all plots and rebuild the current event table.

        Returns early (leaving current_event_set as None) when no sweeps are
        selected, pre/post are invalid, or no pulse evoked a detected spike.
        """
        sweeps = self.sweeps
        self.current_event_set = None
        self.event_table.clear()

        # clear all plots
        self.pre_plot.clear()
        self.post_plot.clear()

        pre = self.params['pre']
        post = self.params['post']

        # If there are no selected sweeps or channels have not been set, return
        if len(sweeps) == 0 or pre == post or pre not in self.channels or post not in self.channels:
            return

        pre_mode = sweeps[0][pre].clamp_mode
        post_mode = sweeps[0][post].clamp_mode
        for ch, mode, plot in [(pre, pre_mode, self.pre_plot), (post, post_mode, self.post_plot)]:
            units = 'A' if mode == 'vc' else 'V'
            plot.setLabels(left=("Channel %d" % ch, units), bottom=("Time", 's'))

        # Iterate over selected channels of all sweeps, plotting traces one at a time
        # Collect information about pulses and spikes
        pulses = []
        spikes = []  # per sweep: one spike dict (or None) per pulse
        post_traces = []
        for i, sweep in enumerate(sweeps):
            pre_trace = sweep[pre]['primary']
            post_trace = sweep[post]['primary']

            # Detect pulse times from steps in the presynaptic command
            stim = sweep[pre]['command'].data
            sdiff = np.diff(stim)
            on_times = np.argwhere(sdiff > 0)[1:, 0]  # 1: skips test pulse
            off_times = np.argwhere(sdiff < 0)[1:, 0]
            pulses.append(on_times)

            # filter data
            post_filt = self.artifact_remover.process(post_trace, list(on_times) + list(off_times))
            post_filt = self.baseline_remover.process(post_filt)
            post_filt = self.filter.process(post_filt)
            post_traces.append(post_filt)

            # plot raw data
            color = pg.intColor(i, hues=len(sweeps) * 1.3, sat=128)
            color.setAlpha(128)
            for trace, plot in [(pre_trace, self.pre_plot), (post_filt, self.post_plot)]:
                plot.plot(trace.time_values, trace.data, pen=color, antialias=False)

            # detect spike times
            spike_inds = []
            spike_info = []
            for on, off in zip(on_times, off_times):
                spike = detect_evoked_spike(sweep[pre], [on, off])
                spike_info.append(spike)
                if spike is None:
                    spike_inds.append(None)
                else:
                    spike_inds.append(spike['rise_index'])
            spikes.append(spike_info)

            dt = pre_trace.dt
            vticks = pg.VTickGroup([x * dt for x in spike_inds if x is not None], yrange=[0.0, 0.2], pen=color)
            self.pre_plot.addItem(vticks)

        # Iterate over spikes, plotting average response
        avg_responses = []
        fits = []  # (pulse index, fit result) for every pulse that had responses
        npulses = max(map(len, pulses))
        self.response_plots.clear()
        self.response_plots.set_shape(1, npulses + 1)  # 1 extra for global average
        self.response_plots.setYLink(self.response_plots[0, 0])
        for i in range(1, npulses + 1):
            self.response_plots[0, i].hideAxis('left')
        units = 'A' if post_mode == 'vc' else 'V'
        self.response_plots[0, 0].setLabels(left=("Averaged events (Channel %d)" % post, units))

        fit_pen = {'color': (30, 30, 255), 'width': 2, 'dash': [1, 1]}
        # don't take more than 40 ms for any response
        max_len = int(40e-3 / dt)
        for i in range(npulses):
            # get the chunk of each sweep between spikes
            responses = []
            for j, sweep_spikes in enumerate(spikes):
                # get the current spike
                if i >= len(sweep_spikes):
                    continue
                spike = sweep_spikes[i]
                if spike is None:
                    continue

                # find next spike
                next_spike = None
                for sp in sweep_spikes[i + 1:]:
                    if sp is not None:
                        next_spike = sp
                        break

                # determine time range for response: stop at the next spike if it comes first
                start = spike['rise_index']
                if next_spike is not None:
                    stop = min(start + max_len, next_spike['rise_index'])
                else:
                    stop = start + max_len

                # collect data from this trace
                responses.append(post_traces[j].data[start:stop].copy())

            if len(responses) == 0:
                continue

            # extend all responses to the same length and take nanmean
            avg = ragged_mean(responses, method='clip')
            avg -= float_mode(avg[:int(1e-3 / dt)])
            avg_responses.append(avg)

            # plot average response for this pulse
            t = np.arange(len(avg)) * dt
            self.response_plots[0, i].plot(t, avg, pen='w', antialias=True)

            # fit!
            fit = self.fit_psp(avg, t, dt, post_mode)
            fits.append((i, fit))

            # let the user mess with this fit
            curve = self.response_plots[0, i].plot(t, fit.eval(), pen=fit_pen, antialias=True).curve
            curve.setClickable(True)
            curve.fit = fit
            curve.sigClicked.connect(self.fit_curve_clicked)

        if len(avg_responses) == 0:
            # no pulse evoked a detected spike: nothing to average or fit
            return

        # display global average
        global_avg = ragged_mean(avg_responses, method='clip')
        t = np.arange(len(global_avg)) * dt
        self.response_plots[0, -1].plot(t, global_avg, pen='w', antialias=True)
        global_fit = self.fit_psp(global_avg, t, dt, post_mode)
        self.response_plots[0, -1].plot(t, global_fit.eval(), pen=fit_pen, antialias=True)

        # display fit parameters in table: one row per fitted pulse (id = pulse
        # index), then a final 'avg' row (fit_clicked skips that last row)
        events = []
        for pulse_ind, f in fits + [(None, global_fit)]:
            if f is None:
                continue
            if pulse_ind is None:
                vals = OrderedDict([('id', 'avg'), ('spike_time', np.nan), ('spike_stdev', np.nan)])
            else:
                spt = [s[pulse_ind]['peak_index'] * dt for s in spikes
                       if pulse_ind < len(s) and s[pulse_ind] is not None]
                vals = OrderedDict([('id', pulse_ind), ('spike_time', np.mean(spt)), ('spike_stdev', np.std(spt))])
            vals.update(OrderedDict([(k, f.best_values[k]) for k in f.params.keys()]))
            events.append(vals)

        self.current_event_set = (pre, post, events, sweeps)
        self.event_set_list.setCurrentRow(0)
        self.event_set_selected()

    def fit_psp(self, data, t, dt, clamp_mode):
        """Fit a PSP shape (fitting.Psp) to *data* sampled at times *t*.

        The amplitude sign is guessed from whether the mean of *data* lies
        below or above the mode of its first 1 ms. Initial guesses/bounds are
        scaled for voltage clamp (default) or current clamp ('ic'). Tuple
        parameters are presumably (initial, min, max) or (value, 'fixed') —
        confirm against fitting.Psp.
        """
        mode = float_mode(data[:int(1e-3 / dt)])
        sign = -1 if data.mean() - mode < 0 else 1
        params = OrderedDict([
            ('xoffset', (2e-3, 3e-4, 5e-3)),
            ('yoffset', data[0]),
            ('amp', sign * 10e-12),
            ('rise_time', (2e-3, 50e-6, 10e-3)),
            ('decay_tau', (4e-3, 500e-6, 50e-3)),
            ('rise_power', (2.0, 'fixed')),
        ])
        if clamp_mode == 'ic':
            params['amp'] = sign * 10e-3
            params['rise_time'] = (5e-3, 50e-6, 20e-3)
            params['decay_tau'] = (15e-3, 500e-6, 150e-3)

        fit_kws = {'xtol': 1e-3, 'maxfev': 100}

        psp = fitting.Psp()
        return psp.fit(data, x=t, fit_kws=fit_kws, params=params)

    def add_set_clicked(self):
        """Store the current event set and list it as "<pre> -> <post>"."""
        if self.current_event_set is None:
            return
        ces = self.current_event_set
        self.event_sets.append(ces)
        item = QtGui.QListWidgetItem("%d -> %d" % (ces[0], ces[1]))
        self.event_set_list.addItem(item)
        item.event_set = ces

    def remove_set_clicked(self):
        """Remove the selected stored event set (the "current" row cannot be removed)."""
        row = self.event_set_list.currentRow()
        # row 0 is "current"; -1 means there is no current row
        if row < 1:
            return
        sel = self.event_set_list.takeItem(row)
        self.event_sets.remove(sel.event_set)

    def event_set_selected(self):
        """Show the selected event set's rows in the event table."""
        selected = self.event_set_list.selectedItems()
        if len(selected) == 0:
            return
        sel = selected[0]
        if sel.text() == "current":
            if self.current_event_set is None:
                # no valid pair analyzed yet
                self.event_table.clear()
                return
            self.event_table.setData(self.current_event_set[2])
        else:
            self.event_table.setData(sel.event_set[2])

    def fit_clicked(self):
        """Fit a release model to all stored event sets and plot the fits.

        Each stored set is reduced to spike times (ms, relative to the first
        event) and amplitudes normalized to the first event. Dynamics types are
        enabled cumulatively (Dep, then Dep+Fac, ...) with one fit per step.
        """
        from synaptic_release import ReleaseModel

        n_sets = len(self.event_sets)
        if n_sets == 0:
            # nothing stored; fit_plot[0, 0] would not exist after set_shape(0, 1)
            return

        self.fit_plot.clear()
        self.fit_plot.show()

        self.fit_plot.set_shape(n_sets, 1)
        legend = self.fit_plot[0, 0].legend
        if legend is not None:
            # QGraphicsItem.scene() is a method, not an attribute
            scene = legend.scene()
            if scene is not None:
                scene.removeItem(legend)
            self.fit_plot[0, 0].legend = None
        self.fit_plot[0, 0].addLegend()

        for i in range(n_sets):
            self.fit_plot[i, 0].setXLink(self.fit_plot[0, 0])

        spike_sets = []
        for i, evset in enumerate(self.event_sets):
            evset = evset[2][:-1]  # select events, skip last row (average)
            x = np.array([ev['spike_time'] for ev in evset])
            y = np.array([ev['amp'] for ev in evset])
            x -= x[0]
            x *= 1000  # seconds -> ms
            y /= y[0]
            spike_sets.append((x, y))

            self.fit_plot[i, 0].plot(x / 1000., y, pen=None, symbol='o')

        model = ReleaseModel()
        dynamics_types = ['Dep', 'Fac', 'UR', 'SMR', 'DSR']
        for k in dynamics_types:
            model.Dynamics[k] = 0

        fit_params = []
        with pg.ProgressDialog("Fitting release model..", 0, len(dynamics_types)) as dlg:
            for k in dynamics_types:
                # enable one more mechanism (never disabled again) and refit
                model.Dynamics[k] = 1
                fit_params.append(model.run_fit(spike_sets))
                dlg += 1
                if dlg.wasCanceled():
                    return

        max_color = len(fit_params) * 1.5

        for i, params in enumerate(fit_params):
            for j, (x, _) in enumerate(spike_sets):
                output = model.eval(x, params.values())
                self.fit_plot[j, 0].plot(output[:, 0] / 1000., output[:, 1],
                                         pen=(i, max_color), name=dynamics_types[i])

    def fit_curve_clicked(self, curve):
        """Open (or update) the FitExplorer for the fit attached to a clicked curve."""
        if self.fit_explorer is None:
            self.fit_explorer = FitExplorer(curve.fit)
        else:
            self.fit_explorer.set_fit(curve.fit)
        self.fit_explorer.show()
class DistancePlot(object):
    """Plot connection probability versus distance.

    Uses a two-row PlotGrid. The taller bottom row shows connection
    probability vs distance. The top row is passed to ``distance_plot`` as its
    second plot (the scatter plot, which can be suppressed). A float parameter
    sets the distance binning window; changing it redraws everything.
    """

    def __init__(self):
        self.grid = PlotGrid()
        self.grid.set_shape(2, 1)
        self.grid.grid.ci.layout.setRowStretchFactor(0, 5)
        self.grid.grid.ci.layout.setRowStretchFactor(1, 10)
        # plots[0]: bottom row (probability); plots[1]: top row (scatter)
        self.plots = (self.grid[1, 0], self.grid[0, 0])
        self.plots[0].grid = self.grid
        self.plots[0].addLegend()
        self.grid.show()
        self.plots[0].setLabels(bottom=('distance', 'm'), left='connection probability')
        self.plots[0].setXRange(0, 200e-6)
        self.params = Parameter.create(name='Distance binning window', type='float', value=40.e-6,
                                       step=10.e-6, suffix='m', siPrefix=True)
        self.element_plot = None
        self.elements = []
        self.element_colors = []
        # First data set given to plot_distance; update_plot redraws it.
        self.results = None
        self.color = None
        self.name = None
        self.dist_plot = None
        self.params.sigTreeStateChanged.connect(self.update_plot)

    def plot_distance(self, results, color, name, size=10, suppress_scatter=False):
        """Plot connection probability vs distance for *results*.

        Args:
            results: DataFrame with 'Connected' and 'Distance' columns; rows
                with a null 'Distance' are ignored.
            color, name, size: passed through to ``distance_plot``.
            suppress_scatter: if true, the scatter plot is not drawn (it is
                slow for large result sets).

        Returns:
            The value returned by ``distance_plot`` (also stored as self.dist_plot).
        """
        has_distance = ~results['Distance'].isnull()
        connected = results[has_distance]['Connected']
        distance = results[has_distance]['Distance']
        dist_win = self.params.value()
        if self.results is None:
            # remember the first data set so update_plot can redraw it
            self.name = name
            self.color = color
            self.results = results

        plots = self.plots
        if suppress_scatter:
            # suppress scatter plot for all results (takes forever to plot)
            plots = [self.plots[0], None]

        self.dist_plot = distance_plot(connected, distance, plots=plots, color=color, name=name,
                                       size=size, window=dist_win, spacing=dist_win)
        return self.dist_plot

    def invalidate_output(self):
        self.grid.clear()

    def element_distance(self, element, color, add_to_list=True):
        """Overlay the distance plot for one element (a subset of results).

        *element* must also provide 'pre_class'/'post_class' columns whose
        first entries have a ``name``; they label the curve "pre->post".
        """
        if add_to_list is True:
            self.element_colors.append(color)
            self.elements.append(element)
        pre = element['pre_class'][0].name
        post = element['post_class'][0].name
        name = ('%s->%s' % (pre, post))
        self.element_plot = self.plot_distance(element, color=color, name=name, size=15)

    def element_distance_reset(self, results, color, name, suppress_scatter=False):
        """Forget all overlaid elements, clear the grid and plot *results*."""
        self.elements = []
        self.element_colors = []
        self.grid.clear()
        self.dist_plot = self.plot_distance(results, color=color, name=name, size=10,
                                            suppress_scatter=suppress_scatter)

    def update_plot(self, *args):
        """Redraw the base results and any overlaid elements.

        Connected to sigTreeStateChanged; signal arguments are ignored. Does
        nothing until plot_distance has been called at least once.
        """
        if self.results is None:
            return
        self.invalidate_output()
        self.plot_distance(self.results, self.color, self.name, suppress_scatter=True)
        if self.element_plot is not None:
            for element, color in zip(self.elements, self.element_colors):
                self.element_distance(element, color, add_to_list=False)
class PairView(QtGui.QWidget):
    """For analyzing pre/post-synaptic pairs.

    Shows the presynaptic and (filtered) postsynaptic traces of the selected
    sweeps, the spike-aligned average postsynaptic response for each stimulus
    pulse with a PSP fit, and a table of the fit parameters. Tables ("event
    sets") can be stored with the 'add' button and fed to a release-model fit
    with 'fit all'.
    """

    def __init__(self, parent=None):
        self.sweeps = []
        self.channels = []
        # (pre, post, events, sweeps) for the currently displayed pair, or None
        self.current_event_set = None
        # event sets stored by the user via the 'add' button
        self.event_sets = []

        QtGui.QWidget.__init__(self, parent)

        self.layout = QtGui.QGridLayout()
        self.setLayout(self.layout)
        self.layout.setContentsMargins(0, 0, 0, 0)

        self.vsplit = QtGui.QSplitter(QtCore.Qt.Vertical)
        self.layout.addWidget(self.vsplit, 0, 0)

        self.pre_plot = pg.PlotWidget()
        self.post_plot = pg.PlotWidget()
        for plt in (self.pre_plot, self.post_plot):
            plt.setClipToView(True)
            plt.setDownsampling(True, True, 'peak')

        self.post_plot.setXLink(self.pre_plot)
        self.vsplit.addWidget(self.pre_plot)
        self.vsplit.addWidget(self.post_plot)

        self.response_plots = PlotGrid()
        self.vsplit.addWidget(self.response_plots)

        self.event_splitter = QtGui.QSplitter(QtCore.Qt.Horizontal)
        self.vsplit.addWidget(self.event_splitter)

        self.event_table = pg.TableWidget()
        self.event_splitter.addWidget(self.event_table)

        self.event_widget = QtGui.QWidget()
        self.event_widget_layout = QtGui.QGridLayout()
        self.event_widget.setLayout(self.event_widget_layout)
        self.event_splitter.addWidget(self.event_widget)

        self.event_splitter.setSizes([600, 100])

        # Row 0 of this list is always the "current" (unsaved) event set.
        self.event_set_list = QtGui.QListWidget()
        self.event_widget_layout.addWidget(self.event_set_list, 0, 0, 1, 2)
        self.event_set_list.addItem("current")
        self.event_set_list.itemSelectionChanged.connect(self.event_set_selected)

        self.add_set_btn = QtGui.QPushButton('add')
        self.event_widget_layout.addWidget(self.add_set_btn, 1, 0, 1, 1)
        self.add_set_btn.clicked.connect(self.add_set_clicked)

        self.remove_set_btn = QtGui.QPushButton('del')
        self.event_widget_layout.addWidget(self.remove_set_btn, 1, 1, 1, 1)
        self.remove_set_btn.clicked.connect(self.remove_set_clicked)

        self.fit_btn = QtGui.QPushButton('fit all')
        self.event_widget_layout.addWidget(self.fit_btn, 2, 0, 1, 2)
        self.fit_btn.clicked.connect(self.fit_clicked)

        self.fit_plot = PlotGrid()

        self.fit_explorer = None

        self.artifact_remover = ArtifactRemover(user_width=True)
        self.baseline_remover = BaselineRemover()
        self.filter = SignalFilter()

        self.params = pg.parametertree.Parameter(name='params', type='group', children=[
            {'name': 'pre', 'type': 'list', 'values': []},
            {'name': 'post', 'type': 'list', 'values': []},
            self.artifact_remover.params,
            self.baseline_remover.params,
            self.filter.params,
            {'name': 'time constant', 'type': 'float', 'suffix': 's', 'siPrefix': True,
             'value': 10e-3, 'dec': True, 'minStep': 100e-6},
        ])
        self.params.sigTreeStateChanged.connect(self._update_plots)

    def data_selected(self, sweeps, channels):
        """Set the sweeps/channels to analyze, update the pre/post choices, and redraw."""
        self.sweeps = sweeps
        self.channels = channels

        self.params.child('pre').setLimits(channels)
        self.params.child('post').setLimits(channels)

        self._update_plots()

    def _update_plots(self):
        """Redraw all plots and rebuild the current event table.

        Returns early (leaving current_event_set as None) when no sweeps are
        selected, pre/post are invalid, or no pulse evoked a detected spike.
        """
        sweeps = self.sweeps
        self.current_event_set = None
        self.event_table.clear()

        # clear all plots
        self.pre_plot.clear()
        self.post_plot.clear()

        pre = self.params['pre']
        post = self.params['post']

        # If there are no selected sweeps or channels have not been set, return
        if len(sweeps) == 0 or pre == post or pre not in self.channels or post not in self.channels:
            return

        pre_mode = sweeps[0][pre].clamp_mode
        post_mode = sweeps[0][post].clamp_mode
        for ch, mode, plot in [(pre, pre_mode, self.pre_plot), (post, post_mode, self.post_plot)]:
            units = 'A' if mode == 'vc' else 'V'
            plot.setLabels(left=("Channel %d" % ch, units), bottom=("Time", 's'))

        # Iterate over selected channels of all sweeps, plotting traces one at a time
        # Collect information about pulses and spikes
        pulses = []
        spikes = []  # per sweep: one spike dict (or None) per pulse
        post_traces = []
        for i, sweep in enumerate(sweeps):
            pre_trace = sweep[pre]['primary']
            post_trace = sweep[post]['primary']

            # Detect pulse times from steps in the presynaptic command
            stim = sweep[pre]['command'].data
            sdiff = np.diff(stim)
            on_times = np.argwhere(sdiff > 0)[1:, 0]  # 1: skips test pulse
            off_times = np.argwhere(sdiff < 0)[1:, 0]
            pulses.append(on_times)

            # filter data
            post_filt = self.artifact_remover.process(post_trace, list(on_times) + list(off_times))
            post_filt = self.baseline_remover.process(post_filt)
            post_filt = self.filter.process(post_filt)
            post_traces.append(post_filt)

            # plot raw data
            color = pg.intColor(i, hues=len(sweeps) * 1.3, sat=128)
            color.setAlpha(128)
            for trace, plot in [(pre_trace, self.pre_plot), (post_filt, self.post_plot)]:
                plot.plot(trace.time_values, trace.data, pen=color, antialias=False)

            # detect spike times
            spike_inds = []
            spike_info = []
            for on, off in zip(on_times, off_times):
                spike = detect_evoked_spike(sweep[pre], [on, off])
                spike_info.append(spike)
                if spike is None:
                    spike_inds.append(None)
                else:
                    spike_inds.append(spike['rise_index'])
            spikes.append(spike_info)

            dt = pre_trace.dt
            vticks = pg.VTickGroup([x * dt for x in spike_inds if x is not None], yrange=[0.0, 0.2], pen=color)
            self.pre_plot.addItem(vticks)

        # Iterate over spikes, plotting average response
        avg_responses = []
        fits = []  # (pulse index, fit result) for every pulse that had responses
        npulses = max(map(len, pulses))
        self.response_plots.clear()
        self.response_plots.set_shape(1, npulses + 1)  # 1 extra for global average
        self.response_plots.setYLink(self.response_plots[0, 0])
        for i in range(1, npulses + 1):
            self.response_plots[0, i].hideAxis('left')
        units = 'A' if post_mode == 'vc' else 'V'
        self.response_plots[0, 0].setLabels(left=("Averaged events (Channel %d)" % post, units))

        fit_pen = {'color': (30, 30, 255), 'width': 2, 'dash': [1, 1]}
        # don't take more than 40 ms for any response
        max_len = int(40e-3 / dt)
        for i in range(npulses):
            # get the chunk of each sweep between spikes
            responses = []
            for j, sweep_spikes in enumerate(spikes):
                # get the current spike
                if i >= len(sweep_spikes):
                    continue
                spike = sweep_spikes[i]
                if spike is None:
                    continue

                # find next spike
                next_spike = None
                for sp in sweep_spikes[i + 1:]:
                    if sp is not None:
                        next_spike = sp
                        break

                # determine time range for response: stop at the next spike if it comes first
                start = spike['rise_index']
                if next_spike is not None:
                    stop = min(start + max_len, next_spike['rise_index'])
                else:
                    stop = start + max_len

                # collect data from this trace
                responses.append(post_traces[j].data[start:stop].copy())

            if len(responses) == 0:
                continue

            # extend all responses to the same length and take nanmean
            avg = ragged_mean(responses, method='clip')
            avg -= float_mode(avg[:int(1e-3 / dt)])
            avg_responses.append(avg)

            # plot average response for this pulse
            t = np.arange(len(avg)) * dt
            self.response_plots[0, i].plot(t, avg, pen='w', antialias=True)

            # fit!
            fit = self.fit_psp(avg, t, dt, post_mode)
            fits.append((i, fit))

            # let the user mess with this fit
            curve = self.response_plots[0, i].plot(t, fit.eval(), pen=fit_pen, antialias=True).curve
            curve.setClickable(True)
            curve.fit = fit
            curve.sigClicked.connect(self.fit_curve_clicked)

        if len(avg_responses) == 0:
            # no pulse evoked a detected spike: nothing to average or fit
            return

        # display global average
        global_avg = ragged_mean(avg_responses, method='clip')
        t = np.arange(len(global_avg)) * dt
        self.response_plots[0, -1].plot(t, global_avg, pen='w', antialias=True)
        global_fit = self.fit_psp(global_avg, t, dt, post_mode)
        self.response_plots[0, -1].plot(t, global_fit.eval(), pen=fit_pen, antialias=True)

        # display fit parameters in table: one row per fitted pulse (id = pulse
        # index), then a final 'avg' row (fit_clicked skips that last row)
        events = []
        for pulse_ind, f in fits + [(None, global_fit)]:
            if f is None:
                continue
            if pulse_ind is None:
                vals = OrderedDict([('id', 'avg'), ('spike_time', np.nan), ('spike_stdev', np.nan)])
            else:
                spt = [s[pulse_ind]['peak_index'] * dt for s in spikes
                       if pulse_ind < len(s) and s[pulse_ind] is not None]
                vals = OrderedDict([('id', pulse_ind), ('spike_time', np.mean(spt)), ('spike_stdev', np.std(spt))])
            vals.update(OrderedDict([(k, f.best_values[k]) for k in f.params.keys()]))
            events.append(vals)

        self.current_event_set = (pre, post, events, sweeps)
        self.event_set_list.setCurrentRow(0)
        self.event_set_selected()

    def fit_psp(self, data, t, dt, clamp_mode):
        """Fit a PSP shape (fitting.Psp) to *data* sampled at times *t*.

        The amplitude sign is guessed from whether the mean of *data* lies
        below or above the mode of its first 1 ms. Initial guesses/bounds are
        scaled for voltage clamp (default) or current clamp ('ic'). Tuple
        parameters are presumably (initial, min, max) or (value, 'fixed') —
        confirm against fitting.Psp.
        """
        mode = float_mode(data[:int(1e-3 / dt)])
        sign = -1 if data.mean() - mode < 0 else 1
        params = OrderedDict([
            ('xoffset', (2e-3, 3e-4, 5e-3)),
            ('yoffset', data[0]),
            ('amp', sign * 10e-12),
            ('rise_time', (2e-3, 50e-6, 10e-3)),
            ('decay_tau', (4e-3, 500e-6, 50e-3)),
            ('rise_power', (2.0, 'fixed')),
        ])
        if clamp_mode == 'ic':
            params['amp'] = sign * 10e-3
            params['rise_time'] = (5e-3, 50e-6, 20e-3)
            params['decay_tau'] = (15e-3, 500e-6, 150e-3)

        fit_kws = {'xtol': 1e-3, 'maxfev': 100}

        psp = fitting.Psp()
        return psp.fit(data, x=t, fit_kws=fit_kws, params=params)

    def add_set_clicked(self):
        """Store the current event set and list it as "<pre> -> <post>"."""
        if self.current_event_set is None:
            return
        ces = self.current_event_set
        self.event_sets.append(ces)
        item = QtGui.QListWidgetItem("%d -> %d" % (ces[0], ces[1]))
        self.event_set_list.addItem(item)
        item.event_set = ces

    def remove_set_clicked(self):
        """Remove the selected stored event set (the "current" row cannot be removed)."""
        row = self.event_set_list.currentRow()
        # row 0 is "current"; -1 means there is no current row
        if row < 1:
            return
        sel = self.event_set_list.takeItem(row)
        self.event_sets.remove(sel.event_set)

    def event_set_selected(self):
        """Show the selected event set's rows in the event table."""
        selected = self.event_set_list.selectedItems()
        if len(selected) == 0:
            return
        sel = selected[0]
        if sel.text() == "current":
            if self.current_event_set is None:
                # no valid pair analyzed yet
                self.event_table.clear()
                return
            self.event_table.setData(self.current_event_set[2])
        else:
            self.event_table.setData(sel.event_set[2])

    def fit_clicked(self):
        """Fit a release model to all stored event sets and plot the fits.

        Each stored set is reduced to spike times (ms, relative to the first
        event) and amplitudes normalized to the first event. Dynamics types are
        enabled cumulatively (Dep, then Dep+Fac, ...) with one fit per step.
        """
        from synaptic_release import ReleaseModel

        n_sets = len(self.event_sets)
        if n_sets == 0:
            # nothing stored; fit_plot[0, 0] would not exist after set_shape(0, 1)
            return

        self.fit_plot.clear()
        self.fit_plot.show()

        self.fit_plot.set_shape(n_sets, 1)
        legend = self.fit_plot[0, 0].legend
        if legend is not None:
            # QGraphicsItem.scene() is a method, not an attribute
            scene = legend.scene()
            if scene is not None:
                scene.removeItem(legend)
            self.fit_plot[0, 0].legend = None
        self.fit_plot[0, 0].addLegend()

        for i in range(n_sets):
            self.fit_plot[i, 0].setXLink(self.fit_plot[0, 0])

        spike_sets = []
        for i, evset in enumerate(self.event_sets):
            evset = evset[2][:-1]  # select events, skip last row (average)
            x = np.array([ev['spike_time'] for ev in evset])
            y = np.array([ev['amp'] for ev in evset])
            x -= x[0]
            x *= 1000  # seconds -> ms
            y /= y[0]
            spike_sets.append((x, y))

            self.fit_plot[i, 0].plot(x / 1000., y, pen=None, symbol='o')

        model = ReleaseModel()
        dynamics_types = ['Dep', 'Fac', 'UR', 'SMR', 'DSR']
        for k in dynamics_types:
            model.Dynamics[k] = 0

        fit_params = []
        with pg.ProgressDialog("Fitting release model..", 0, len(dynamics_types)) as dlg:
            for k in dynamics_types:
                # enable one more mechanism (never disabled again) and refit
                model.Dynamics[k] = 1
                fit_params.append(model.run_fit(spike_sets))
                dlg += 1
                if dlg.wasCanceled():
                    return

        max_color = len(fit_params) * 1.5

        for i, params in enumerate(fit_params):
            for j, (x, _) in enumerate(spike_sets):
                output = model.eval(x, params.values())
                self.fit_plot[j, 0].plot(output[:, 0] / 1000., output[:, 1],
                                         pen=(i, max_color), name=dynamics_types[i])

    def fit_curve_clicked(self, curve):
        """Open (or update) the FitExplorer for the fit attached to a clicked curve."""
        if self.fit_explorer is None:
            self.fit_explorer = FitExplorer(curve.fit)
        else:
            self.fit_explorer.set_fit(curve.fit)
        self.fit_explorer.show()
class MultipatchMatrixView(QtGui.QWidget):
    """Matrix of plots pairing every selected recording channel with every stimulus channel.

    Plot (i, j) shows channel i's recording in a window that spans the selected
    pulses delivered on channel j, so diagonal plots show each channel's
    response to its own stimulus. Off-diagonal averages are tinted by a simple
    positive/negative deflection score. Display options live in ``self.params``.
    """

    def __init__(self, parent=None):
        QtGui.QWidget.__init__(self, parent)
        # Filled in by data_selected(); initialized here so a parameter change
        # before any data is selected does not raise AttributeError.
        self.sweeps = []
        self.channels = []

        self.layout = QtGui.QGridLayout()
        self.setLayout(self.layout)
        self.layout.setContentsMargins(0, 0, 0, 0)

        self.plots = PlotGrid(parent=self)
        self.layout.addWidget(self.plots, 0, 0)
        self.plots.scene().sigMouseClicked.connect(self._plot_clicked)

        self.params = pg.parametertree.Parameter(name='params', type='group', children=[
            {'name': 'show', 'type': 'list', 'values': ['sweep avg', 'sweep avg + sweeps', 'sweeps', 'pulse avg']},
            {'name': 'lowpass', 'type': 'bool', 'value': True, 'children': [
                {'name': 'sigma', 'type': 'float', 'value': 200e-6, 'step': 1e-5, 'limits': [0, None],
                 'suffix': 's', 'siPrefix': True},
            ]},
            {'name': 'first pulse', 'type': 'int', 'value': 0, 'limits': [0, None]},
            {'name': 'last pulse', 'type': 'int', 'value': 7, 'limits': [0, None]},
            {'name': 'window', 'type': 'float', 'value': 30e-3, 'step': 1e-3, 'limits': [0, None],
             'suffix': 's', 'siPrefix': True},
            {'name': 'remove artifacts', 'type': 'bool', 'value': True, 'children': [
                {'name': 'window', 'type': 'float', 'suffix': 's', 'siPrefix': True, 'value': 1e-3,
                 'step': 1e-4, 'bounds': [0, None]},
            ]},
            {'name': 'remove baseline', 'type': 'bool', 'value': True},
            {'name': 'show ticks', 'type': 'bool', 'value': True},
        ])
        self.params.sigTreeStateChanged.connect(self._params_changed)

    def show_group(self, grp):
        # NOTE(review): show_sweeps is not defined on this class as visible
        # here, so this raises AttributeError unless a subclass provides it.
        self.show_sweeps(grp.sweeps)

    def data_selected(self, sweeps, channels):
        """Set the sweeps and channels to display and redraw with auto-ranging."""
        self.sweeps = sweeps
        self.channels = channels
        self._update_plots(auto_range=True)

    def _params_changed(self, *args):
        self._update_plots()

    def _plot_clicked(self, ev):
        """Highlight the clicked plot with a dark background; clear the others."""
        item = self.plots.scene().itemAt(ev.scenePos())
        r, c = self.plots.item_index(item)
        for i in range(self.plots.rows):
            for j in range(self.plots.cols):
                color = None if (i, j) != (r, c) else pg.mkColor(30, 30, 50)
                self.plots[i, j].vb.setBackgroundColor(color)

    def _update_plots(self, auto_range=False):
        """Rebuild the plot matrix from the current sweeps, channels and parameters."""
        sweeps = self.sweeps
        chans = self.channels
        self.plots.clear()
        if len(sweeps) == 0 or len(chans) == 0:
            return

        # collect data; after unpacking, data/stim are indexed (sweep, channel, sample)
        data = MiesNwb.pack_sweep_data(sweeps)
        data, stim = data[..., 0], data[..., 1]  # unpack stim and recordings
        dt = sweeps[0].recordings[0]['primary'].dt

        # mask for selected channels
        mask = np.array([ch in chans for ch in sweeps[0].devices])
        data = data[:, mask]
        stim = stim[:, mask]
        chans = np.array(sweeps[0].devices)[mask]
        modes = [sweeps[0][ch].clamp_mode for ch in chans]

        # get pulse times for each channel from the first sweep's stimulus
        stim = stim[0]
        diff = stim[:, 1:] - stim[:, :-1]
        # note: the [1:] here skips the test pulse
        on_times = [np.argwhere(diff[i] > 0)[1:, 0] for i in range(diff.shape[0])]
        off_times = [np.argwhere(diff[i] < 0)[1:, 0] for i in range(diff.shape[0])]

        # remove capacitive artifacts from adjacent electrodes
        if self.params['remove artifacts']:
            npts = int(self.params['remove artifacts', 'window'] / dt)
            for i in range(stim.shape[0]):
                for j in range(stim.shape[0]):
                    if i == j:
                        continue

                    # are these headstages adjacent?
                    hs1, hs2 = chans[i], chans[j]
                    if abs(hs2 - hs1) > 3:
                        continue

                    # remove artifacts: replace npts after each step on channel i
                    # with channel j's mean over the npts preceding it
                    for k in range(len(on_times[i])):
                        on = on_times[i][k]
                        off = off_times[i][k]
                        data[:, j, on:on + npts] = data[:, j, max(0, on - npts):on].mean(axis=1)[:, None]
                        data[:, j, off:off + npts] = data[:, j, max(0, off - npts):off].mean(axis=1)[:, None]

        # lowpass filter along the sample axis only
        if self.params['lowpass']:
            data = gaussian_filter(data, (0, 0, self.params['lowpass', 'sigma'] / dt))

        # prepare to plot
        window = int(self.params['window'] / dt)
        first_pulse = self.params['first pulse']
        last_pulse = self.params['last pulse']
        n_sweeps = data.shape[0]
        n_channels = data.shape[1]
        self.plots.set_shape(n_channels, n_channels)
        self.plots.setClipToView(True)
        self.plots.setDownsampling(True, True, 'peak')
        self.plots.enableAutoRange(False, False)

        show_sweeps = 'sweeps' in self.params['show']
        show_sweep_avg = 'sweep avg' in self.params['show']
        show_pulse_avg = self.params['show'] == 'pulse avg'

        for i in range(n_channels):
            for j in range(n_channels):
                plt = self.plots[i, j]
                start = on_times[j][first_pulse] - window
                if start < 0:
                    frontpad = -start
                    start = 0
                else:
                    frontpad = 0
                stop = on_times[j][last_pulse] + window

                # select the data segment to be displayed in this matrix cell
                # add padding if necessary
                if frontpad == 0:
                    seg = data[:, i, start:stop].copy()
                else:
                    seg = np.empty((data.shape[0], stop + frontpad), data.dtype)
                    seg[:, frontpad:] = data[:, i, start:stop]
                    # pad by repeating the first real sample
                    seg[:, :frontpad] = seg[:, frontpad:frontpad + 1]

                # subtract off baseline for each sweep
                if self.params['remove baseline']:
                    seg -= seg[:, :window].mean(axis=1)[:, None]

                if show_sweeps:
                    alpha = 100 if show_sweep_avg else 200
                    color = (255, 255, 255, alpha)
                    t = np.arange(seg.shape[1]) * dt
                    for k in range(n_sweeps):
                        plt.plot(t, seg[k], pen={'color': color, 'width': 1}, antialias=True)

                if show_sweep_avg or show_pulse_avg:
                    # average selected segments over all sweeps
                    segm = seg.mean(axis=0)

                    if show_pulse_avg:
                        # average over all selected pulses
                        pulses = []
                        for k in range(first_pulse, last_pulse + 1):
                            pstart = on_times[j][k] - on_times[j][first_pulse]
                            pstop = pstart + (window * 2)
                            pulses.append(segm[pstart:pstop])
                        segm = np.vstack(pulses).mean(axis=0)

                    t = np.arange(segm.shape[0]) * dt

                    if i == j:
                        color = (80, 80, 80)
                    else:
                        dif = segm - segm[:window].mean()
                        qe = 30 * np.clip(dif, 0, 1e20).mean() / segm[:window].std()
                        qi = 30 * np.clip(-dif, 0, 1e20).mean() / segm[:window].std()
                        if modes[i] == 'ic':
                            qi, qe = qe, qi  # invert color metric for current clamp
                        g = 100
                        r = np.clip(g + max(qi, 0), 0, 255)
                        b = np.clip(g + max(qe, 0), 0, 255)
                        color = (r, g, b)

                    plt.plot(t, segm, pen={'color': color, 'width': 1}, antialias=True)

                if self.params['show ticks']:
                    # segment sample 0 is at absolute index (start - frontpad),
                    # so shift by frontpad when the window was padded
                    vt = pg.VTickGroup((on_times[j] - start + frontpad) * dt, [0, 0.15], pen=0.4)
                    plt.addItem(vt)

                # Link all plots along x axis
                plt.setXLink(self.plots[0, 0])

                if i == j:
                    # link y axes of all diagonal plots
                    plt.setYLink(self.plots[0, 0])
                else:
                    # link y axes of all plots within a row
                    plt.setYLink(self.plots[i, (i + 1) % 2])  # (i+1)%2 just avoids linking to 0,0

                if i < n_channels - 1:
                    plt.getAxis('bottom').setVisible(False)
                if j > 0:
                    plt.getAxis('left').setVisible(False)

                if i == n_channels - 1:
                    plt.setLabels(bottom=('CH%d' % chans[j], 's'))
                if j == 0:
                    plt.setLabels(left=('CH%d' % chans[i], 'A' if modes[i] == 'vc' else 'V'))

        if auto_range:
            # uses the clamp mode of the last channel and the last plotted time axis
            if n_channels > 1:
                r = 14e-12 if modes[i] == 'vc' else 5e-3
                self.plots[0, 1].setYRange(-r, r)
            r = 2e-9 if modes[i] == 'vc' else 100e-3
            self.plots[0, 0].setYRange(-r, r)

            self.plots[0, 0].setXRange(t[0], t[-1])
class DynamicsWindow(pg.QtGui.QSplitter):
    """Horizontal splitter: an ExperimentBrowser on the left, a PlotGrid on the right.

    Selecting a single browser item that carries a ``pair`` attribute loads
    that pair's pulse responses and plots them, one grid row per stimulus key.
    """

    def __init__(self):
        pg.QtGui.QSplitter.__init__(self, pg.QtCore.Qt.Horizontal)
        self.browser = ExperimentBrowser()
        self.addWidget(self.browser)
        self.plots = PlotGrid()
        self.addWidget(self.plots)
        # reload the plots whenever the browser selection changes
        self.browser.itemSelectionChanged.connect(self.browser_item_selected)

    def browser_item_selected(self):
        """Load the selected pair; ignore multi/empty selections and items without a pair."""
        with pg.BusyCursor():
            chosen = self.browser.selectedItems()
            if len(chosen) == 1 and hasattr(chosen[0], 'pair'):
                self.load_pair(chosen[0].pair)

    def load_pair(self, pair):
        """Query the pulse responses for *pair*, sort them into self.sorted_recs, and replot."""
        print("Loading:", pair)
        responses = pulse_response_query(pair, data=True).all()
        self.sorted_recs = sorted_pulse_responses(responses)
        self.plot_all()

    def plot_all(self):
        """Redraw every stimulus key in self.sorted_recs as its own grid row.

        self.sorted_recs is indexed as sorted_recs[stim_key][recording][pulse_n].
        """
        self.plots.clear()
        self.plots.set_shape(len(self.sorted_recs), 1)
        psp = StackedPsp()
        for row, stim_key in enumerate(sorted(self.sorted_recs.keys())):
            by_recording = self.sorted_recs[stim_key]
            plot = self.plots[row, 0]
            # presumably stim_key is (label, frequency, delay) -- TODO confirm
            plot.setTitle("%s %0.0f Hz %0.2f s" % stim_key)
            for recording in by_recording:
                by_pulse = by_recording[recording]
                for pulse_n in sorted(by_pulse.keys()):
                    self._plot_pulse(plot, psp, by_pulse[pulse_n], pulse_n)

    def _plot_pulse(self, plot, psp, rec, pulse_n):
        """Draw one spike-aligned response on *plot*, plus its stored PSP fit if QC passed."""
        # align to the first spike; fall back to 1 ms after pulse onset if none was detected
        spike_t = rec.stim_pulse.first_spike_time
        if spike_t is None:
            spike_t = rec.stim_pulse.onset_time + 1e-3

        uses_in_qc = rec.synapse.synapse_type == 'in'
        qc_pass = rec.pulse_response.in_qc_pass if uses_in_qc else rec.pulse_response.ex_qc_pass
        pen = (255, 255, 255, 100) if qc_pass else (100, 0, 0, 100)

        trace = TSeries(
            data=rec.data,
            t0=rec.pulse_response.data_start_time - spike_t,
            sample_rate=db.default_sample_rate,
        )
        curve = plot.plot(trace.time_values, trace.data, pen=pen)

        # lay pulses out side by side; pulses after #8 get an extra 30 ms gap
        x_offset = pulse_n * 35e-3 + (30e-3 if pulse_n > 8 else 0)
        curve.setPos(x_offset, 0)

        if not qc_pass:
            # failed responses are pushed behind and get no fit overlay
            curve.setZValue(-10)
            return

        fit_par = rec.pulse_response_fit
        if fit_par.fit_amp is None:
            return

        fit = psp.eval(
            x=trace.time_values,
            exp_amp=fit_par.fit_exp_amp,
            exp_tau=fit_par.fit_decay_tau,
            amp=fit_par.fit_amp,
            rise_time=fit_par.fit_rise_time,
            decay_tau=fit_par.fit_decay_tau,
            xoffset=fit_par.fit_latency,
            yoffset=fit_par.fit_yoffset,
            rise_power=2,
        )
        fit_curve = plot.plot(trace.time_values, fit, pen=(0, 255, 0, 100))
        fit_curve.setZValue(10)
        fit_curve.setPos(x_offset, 0)
class DistancePlot(object):
    """Connection probability vs. distance, with an adjustable binning window.

    One "base" result set is kept: the first results passed to plot_distance(),
    or the results passed to element_distance_reset(). Per-element overlays
    are added with element_distance(). Changing the 'Distance binning window'
    parameter redraws the base plot and every recorded overlay.
    """

    def __init__(self):
        self.grid = PlotGrid()
        self.grid.set_shape(2, 1)
        # the lower row (self.plots[0]) gets twice the stretch of the upper row
        self.grid.grid.ci.layout.setRowStretchFactor(0, 5)
        self.grid.grid.ci.layout.setRowStretchFactor(1, 10)
        # ordered (lower plot, upper plot); the lower plot holds labels and legend
        self.plots = (self.grid[1, 0], self.grid[0, 0])
        self.plots[0].grid = self.grid
        self.plots[0].addLegend()
        self.grid.show()
        self.plots[0].setLabels(bottom=('distance', 'm'), left='connection probability')
        self.params = Parameter.create(name='Distance binning window', type='float', value=40.e-6,
                                       step=10.e-6, suffix='m', siPrefix=True)
        self.dist_plot = None
        self.element_plot = None
        # overlays redrawn by update_plot(), kept in parallel lists
        self.elements = []
        self.element_colors = []
        # base results/color/name redrawn by update_plot()
        self.results = None
        self.color = None
        self.name = None
        self.params.sigTreeStateChanged.connect(self.update_plot)

    def plot_distance(self, results, color, name, size=10):
        """Plot connection probability vs. distance for *results*.

        Results needs to be a DataFrame or Series object with 'connected' and
        'distance' as columns.

        The first results ever plotted become the base results that
        update_plot() redraws; later calls (element overlays) do not replace them.

        Returns the value of ``distance_plot`` (also stored in self.dist_plot).
        """
        if self.results is None:
            self.name = name
            self.color = color
            self.results = results
        connected = results['connected']
        distance = results['distance']
        # the binning window is also used as the bin spacing
        dist_win = self.params.value()
        self.dist_plot = distance_plot(connected, distance, plots=self.plots, color=color, name=name,
                                       size=size, window=dist_win, spacing=dist_win)
        self.plots[0].setXRange(0, 200e-6)
        return self.dist_plot

    def invalidate_output(self):
        """Clear everything drawn in the plot grid."""
        self.grid.clear()

    def element_distance(self, element, color, add_to_list=True):
        """Overlay the distance plot for one matrix element.

        *element* is indexed like the results passed to plot_distance(). Its
        'pre_class' / 'post_class' entries at [0] must have a ``.name`` (used
        for the legend label).

        If *add_to_list* is true, the element is recorded so update_plot() can
        redraw it. update_plot() itself passes False to avoid duplicates.
        """
        if add_to_list:
            self.element_colors.append(color)
            self.elements.append(element)
        pre = element['pre_class'][0].name
        post = element['post_class'][0].name
        name = '%s->%s' % (pre, post)
        self.element_plot = self.plot_distance(element, color=color, name=name, size=15)

    def element_distance_reset(self, results, color, name):
        """Drop all overlays and redraw *results* as the new base plot."""
        self.elements = []
        self.element_colors = []
        self.element_plot = None
        self.grid.clear()
        # make these the base results, so a later update_plot() redraws them
        # instead of whatever was plotted first
        self.results = results
        self.color = color
        self.name = name
        self.dist_plot = self.plot_distance(results, color=color, name=name, size=10)

    def update_plot(self, *args):
        """Redraw the base plot and every recorded element overlay.

        Connected to self.params.sigTreeStateChanged. Any signal arguments are
        ignored. Does nothing until some results have been plotted.
        """
        if self.results is None:
            return
        self.invalidate_output()
        self.plot_distance(self.results, self.color, self.name)
        # re-plot every recorded overlay (an empty list simply draws nothing)
        for element, color in zip(self.elements, self.element_colors):
            self.element_distance(element, color, add_to_list=False)