def plot_model_results(self, model=None, fit=None):
    """Plot release-model fit results against the measured amplitudes.

    Uses the supplied (model, fit) pair, or falls back to the result of the
    most recent fit_release_model() call.  Returns the new PlotGrid.
    """
    if model is None:
        if self._last_model_fit is None:
            raise Exception("Must run fit_release_model before plotting results.")
        model, fit = self._last_model_fit

    grid = PlotGrid()
    grid.set_shape(2, 1)
    ind_plot = grid[0, 0]
    rec_plot = grid[1, 0]

    # Configure both panels identically except for their titles.
    for panel, title in [(ind_plot, 'Release model fit - induction frequency'),
                         (rec_plot, 'Release model fit - recovery delay')]:
        panel.setTitle(title)
        panel.setLabels(bottom=('time', 's'), left='relative amplitude')
        panel.setLogMode(x=True, y=False)
    ind_plot.setXLink(rec_plot)

    for idx, stim_params in enumerate(self.stim_param_order):
        meas_x, meas_y = self.spike_sets[idx]
        model_out = model.eval(meas_x, fit.values(), dt=0.5)
        fit_x = model_out[:, 0]
        fit_y = model_out[:, 1]
        pen = (idx, 10)
        # Induction panel: only traces whose recovery delay is ~250 ms.
        if stim_params[1] - 0.250 < 5e-3:
            ind_plot.plot((meas_x + 10) / 1000., meas_y, pen=None, symbol='o', symbolBrush=pen)
            ind_plot.plot((fit_x + 10) / 1000., fit_y, pen=pen)
        # Recovery panel: only traces whose induction frequency is 50 Hz.
        if stim_params[0] == 50:
            rec_plot.plot((meas_x + 10) / 1000., meas_y, pen=None, symbol='o', symbolBrush=pen)
            rec_plot.plot((fit_x + 10) / 1000., fit_y, pen=pen)

    grid.show()
    return grid
def summary_plot_pulse(grand_trace, feature_list, feature_mean, labels, titles, i, plot=None, color=None, name=None):
    """Scatter pulse features (one grid row per feature) alongside a grand-average trace.

    Column 0 of each row gets the jittered feature scatter plus a mean marker;
    column 1 gets the grand-average trace.  Returns the PlotGrid used.
    """
    n_features = len(feature_list) if type(feature_list) is tuple else 1

    if plot is None:
        plot = PlotGrid()
        plot.set_shape(n_features, 2)
        plot.show()
        for row in range(n_features):
            trace_col = plot[row, 1]
            trace_col.addLegend()
            trace_col.setLabels(left=('Vm', 'V'))
            trace_col.setLabels(bottom=('t', 's'))

    for row in range(n_features):
        # Select this row's data; with a single feature the inputs are used as-is.
        if n_features > 1:
            features = feature_list[row]
            mean = feature_mean[row]
            label = labels[row]
            title = titles[row]
        else:
            features, mean, label, title = feature_list, feature_mean, labels, titles

        scatter_col = plot[row, 0]
        scatter_col.setLabels(left=(label[0], label[1]))
        scatter_col.hideAxis('bottom')
        scatter_col.setTitle(title)
        plot[row, 1].plot(grand_trace.time_values, grand_trace.data, pen=color, name=name)

        # Jitter the points horizontally so overlapping values stay visible.
        dx = pg.pseudoScatter(np.array(features).astype(float), 0.3, bidir=True)
        scatter_col.plot((0.3 * dx / dx.max()) + i, features, pen=None, symbol='x', symbolSize=5, symbolBrush=color, symbolPen=None)
        scatter_col.plot([i], [mean], pen=None, symbol='o', symbolBrush=color, symbolPen='w', symbolSize=10)

    return plot
def plot_train_responses(self, plot_grid=None):
    """
    Plot individual and averaged train responses for each set of stimulus parameters.

    Parameters
    ----------
    plot_grid : optional existing PlotGrid to draw into; a new one is created
        when None.

    Return a new PlotGrid (column 0 = induction period, column 1 = recovery
    period; one row per stimulus-parameter set).
    """
    train_responses = self.train_responses

    if plot_grid is None:
        train_plots = PlotGrid()
    else:
        train_plots = plot_grid
    train_plots.set_shape(len(self.stim_param_order), 2)

    for i,stim_params in enumerate(train_responses.keys()):
        # Collect and plot average traces covering the induction and recovery
        # periods for this set of stim params
        ind_group = train_responses[stim_params][0]
        rec_group = train_responses[stim_params][1]
        for j in range(len(ind_group)):
            ind = ind_group.responses[j]
            rec = rec_group.responses[j]
            # Baseline-subtract each individual trace using the median of its
            # own baseline segment; individual traces drawn in translucent gray.
            base = np.median(ind_group.baselines[j].data)
            train_plots[i,0].plot(ind.time_values, ind.data - base, pen=(128, 128, 128, 100))
            train_plots[i,1].plot(rec.time_values, rec.data - base, pen=(128, 128, 128, 100))
        # Baseline-subtracted means drawn in green on top of the raw traces.
        ind_avg = ind_group.bsub_mean()
        rec_avg = rec_group.bsub_mean()

        ind_freq, rec_delay, holding = stim_params
        rec_delay = np.round(rec_delay, 2)
        train_plots[i,0].plot(ind_avg.time_values, ind_avg.data, pen='g', antialias=True)
        train_plots[i,1].plot(rec_avg.time_values, rec_avg.data, pen='g', antialias=True)
        train_plots[i,0].setLabels(left=('Vm', 'V'))
        # In-plot label with the row's stim params (delay/holding shown in ms/mV).
        label = pg.LabelItem("ind: %0.0f rec: %0.0f hold: %0.0f" % (ind_freq, rec_delay*1000, holding*1000))
        label.setParentItem(train_plots[i,0].vb)
        train_plots[i,0].label = label

    train_plots.show()
    # Share one y axis across the whole grid; x axes are linked per column.
    train_plots.setYLink(train_plots[0,0])
    for i in range(train_plots.shape[0]):
        train_plots[i,0].setXLink(train_plots[0,0])
        train_plots[i,1].setXLink(train_plots[0,1])
    # Induction column gets slightly more horizontal space than recovery.
    train_plots.grid.ci.layout.setColumnStretchFactor(0, 3)
    train_plots.grid.ci.layout.setColumnStretchFactor(1, 2)
    train_plots.setClipToView(False)  # has a bug :(
    train_plots.setDownsampling(True, True, 'peak')
    return train_plots
def trace_average_matrix(expts, **kwds):
    """Build a matrix of averaged traces for every pre/post cell-type pair.

    Parameters
    ----------
    expts : experiment list to average over.
    **kwds : extra keyword arguments passed through to plot_trace_average().

    Returns
    -------
    dict mapping (pre_type, post_type) -> averaged trace (or whatever
    plot_trace_average returns).
    """
    types = ['tlx3', 'sim1', 'pvalb', 'sst', 'vip']
    results = {}
    plots = PlotGrid()
    plots.set_shape(len(types), len(types))
    indplots = PlotGrid()
    indplots.set_shape(len(types), len(types))
    plots.show()
    indplots.show()
    for i, pre_type in enumerate(types):
        for j, post_type in enumerate(types):
            avg_plot = plots[i, j]
            ind_plot = indplots[i, j]
            # BUG FIX: use the `expts` argument rather than the module-level
            # `all_expts`, which this function previously read by accident,
            # silently ignoring its own parameter.
            avg = plot_trace_average(expts, pre_type, post_type, avg_plot, ind_plot, **kwds)
            results[(pre_type, post_type)] = avg
    return results
def summary_plot_pulse(feature_list, labels, titles, i, median=False, grand_trace=None, plot=None, color=None, name=None):
    """Scatter pulse features (one grid row per feature) with a central-tendency marker.

    Parameters
    ----------
    feature_list : tuple of per-feature sequences, or a single sequence.
    labels : (name, units) pair per feature, or a single pair.
    titles : title string per feature, or a single string.
    i : x position at which this group of points is drawn.
    median : mark the nanmedian instead of the nanmean when True.
    grand_trace : optional averaged trace drawn in the right-hand column.
    plot : existing PlotGrid to draw into; created (and shown) when None.
    color : RGB tuple used for the markers/trace pen.
    name : legend name for the grand trace.

    Returns the PlotGrid used.
    """
    if type(feature_list) is tuple:
        n_features = len(feature_list)
    else:
        n_features = 1
    if plot is None:
        plot = PlotGrid()
        plot.set_shape(n_features, 2)
        plot.show()
        for g in range(n_features):
            plot[g, 1].addLegend()
    for feature in range(n_features):
        if n_features > 1:
            current_feature = feature_list[feature]
            label = labels[feature]
            title = titles[feature]
        else:
            current_feature = feature_list
            label = labels
            title = titles
        # BUG FIX: honor `median` in the single-feature case too (previously
        # the mean was always used when n_features == 1).
        if median is True:
            mean = np.nanmedian(current_feature)
        else:
            mean = np.nanmean(current_feature)
        plot[feature, 0].setLabels(left=(label[0], label[1]))
        plot[feature, 0].hideAxis('bottom')
        plot[feature, 0].setTitle(title)
        if grand_trace is not None:
            plot[feature, 1].plot(grand_trace.time_values, grand_trace.data, pen=color, name=name)
        if len(current_feature) > 1:
            # Jitter the points horizontally so overlapping values stay visible.
            dx = pg.pseudoScatter(np.array(current_feature).astype(float), 0.7, bidir=True)
            plot[feature, 0].plot([i], [mean], symbol='o', symbolSize=20, symbolPen='k', symbolBrush=color)
            # (removed: unused SEM computation and dead commented-out bar/error-bar code)
            plot[feature, 0].plot((0.3 * dx / dx.max()) + i, current_feature, pen=None, symbol='o', symbolSize=10, symbolPen='w', symbolBrush=(color[0], color[1], color[2], 100))
        else:
            plot[feature, 0].plot([i], current_feature, pen=None, symbol='o', symbolSize=10, symbolPen='w', symbolBrush=color)
    return plot
def pulse_average_matrix(expts, **kwds):
    """Build a matrix of averaged pulse responses for every pre/post cell-type pair.

    Parameters
    ----------
    expts : experiment list to average over.
    **kwds : extra keyword arguments passed through to plot_pulse_average().

    Returns
    -------
    dict mapping (pre_type, post_type) -> averaged response (or whatever
    plot_pulse_average returns).
    """
    results = {}
    types = ['sim1', 'tlx3', 'pvalb', 'sst', 'vip']
    plots = PlotGrid()
    plots.set_shape(len(types), len(types))
    indplots = PlotGrid()
    indplots.set_shape(len(types), len(types))
    plots.show()
    indplots.show()
    for i, pre_type in enumerate(types):
        for j, post_type in enumerate(types):
            avg_plot = plots[i, j]
            ind_plot = indplots[i, j]
            # BUG FIX: use the `expts` argument rather than the module-level
            # `all_expts`, which this function previously read by accident,
            # silently ignoring its own parameter.
            avg = plot_pulse_average(expts, pre_type, post_type, avg_plot, ind_plot, **kwds)
            # Keep the GUI responsive during the long averaging loop.
            pg.QtGui.QApplication.processEvents()
            results[(pre_type, post_type)] = avg
    return results
def plot_model_results(self, model=None, fit=None):
    """Plot release-model fit results against the measured spike amplitudes.

    Parameters
    ----------
    model, fit : a release model and its fit result; when omitted, the pair
        stored by the last fit_release_model() call is used.

    Returns the new PlotGrid (row 0: induction frequency, row 1: recovery delay).
    """
    if model is None:
        if self._last_model_fit is None:
            raise Exception(
                "Must run fit_release_model before plotting results.")
        model, fit = self._last_model_fit
    spike_sets = self.spike_sets
    rel_plots = PlotGrid()
    rel_plots.set_shape(2, 1)
    ind_plot = rel_plots[0, 0]
    ind_plot.setTitle('Release model fit - induction frequency')
    ind_plot.setLabels(bottom=('time', 's'), left='relative amplitude')
    rec_plot = rel_plots[1, 0]
    rec_plot.setTitle('Release model fit - recovery delay')
    rec_plot.setLabels(bottom=('time', 's'), left='relative amplitude')
    ind_plot.setLogMode(x=True, y=False)
    rec_plot.setLogMode(x=True, y=False)
    ind_plot.setXLink(rec_plot)
    for i, stim_params in enumerate(self.stim_param_order):
        # Measured amplitudes (y) at spike times (x), plus the model's
        # evaluation of the same spike train.
        x, y = spike_sets[i]
        output = model.eval(x, fit.values(), dt=0.5)
        y1 = output[:, 1]
        x1 = output[:, 0]
        # Induction panel: only traces whose recovery delay is ~250 ms.
        if stim_params[1] - 0.250 < 5e-3:
            ind_plot.plot((x + 10) / 1000., y, pen=None, symbol='o', symbolBrush=(i, 10))
            ind_plot.plot((x1 + 10) / 1000., y1, pen=(i, 10))
        # Recovery panel: only traces whose induction frequency is 50 Hz.
        if stim_params[0] == 50:
            rec_plot.plot((x + 10) / 1000., y, pen=None, symbol='o', symbolBrush=(i, 10))
            rec_plot.plot((x1 + 10) / 1000., y1, pen=(i, 10))
    rel_plots.show()
    return rel_plots
class DistancePlot(object):
    """Two-row grid plotting connection probability vs. inter-cell distance.

    Row 1 (self.plots[0]) shows binned connection probability; row 0
    (self.plots[1]) shows the per-pair scatter.  A 'Distance binning window'
    parameter controls bin width/spacing and re-plots on change.
    """
    def __init__(self):
        self.grid = PlotGrid()
        self.grid.set_shape(2, 1)
        # Give the probability row more vertical space than the scatter row.
        self.grid.grid.ci.layout.setRowStretchFactor(0, 5)
        self.grid.grid.ci.layout.setRowStretchFactor(1, 10)
        # plots = (probability plot, scatter plot) — note the row order swap.
        self.plots = (self.grid[1,0], self.grid[0,0])
        self.plots[0].grid = self.grid
        self.plots[0].addLegend()
        self.grid.show()
        self.plots[0].setLabels(bottom=('distance', 'm'), left='connection probability')
        self.plots[0].setXRange(0, 200e-6)
        # Binning window parameter; changing it triggers update_plot().
        self.params = Parameter.create(name='Distance binning window', type='float', value=40.e-6, step=10.e-6, suffix='m', siPrefix=True)
        self.element_plot = None
        self.elements = []
        self.element_colors = []
        self.results = None
        self.color = None
        self.name = None
        self.params.sigTreeStateChanged.connect(self.update_plot)

    def plot_distance(self, results, color, name, size=10, suppress_scatter=False):
        """Results needs to be a DataFrame or Series object with 'Connected' and 'Distance' as columns
        """
        # Only rows with a known distance are used.
        connected = results[~results['Distance'].isnull()]['Connected']
        distance = results[~results['Distance'].isnull()]['Distance']
        dist_win = self.params.value()
        # Remember the first data set so update_plot() can redraw it later.
        if self.results is None:
            self.name = name
            self.color = color
            self.results = results
        if suppress_scatter is True:
            #suppress scatter plot for all results (takes forever to plot)
            plots = list(self.plots)
            plots[1] = None
            self.dist_plot = distance_plot(connected, distance, plots=plots, color=color, name=name, size=size, window=dist_win, spacing=dist_win)
        else:
            self.dist_plot = distance_plot(connected, distance, plots=self.plots, color=color, name=name, size=size, window=dist_win, spacing=dist_win)
        return self.dist_plot

    def invalidate_output(self):
        # Clear all plotted items from the grid.
        self.grid.clear()

    def element_distance(self, element, color, add_to_list=True):
        """Overlay one matrix element's distance profile.

        add_to_list=False is used by update_plot() to redraw stored elements
        without duplicating them in the bookkeeping lists.
        """
        if add_to_list is True:
            self.element_colors.append(color)
            self.elements.append(element)
        pre = element['pre_class'][0].name
        post = element['post_class'][0].name
        name = ('%s->%s' % (pre, post))
        self.element_plot = self.plot_distance(element, color=color, name=name, size=15)

    def element_distance_reset(self, results, color, name, suppress_scatter=False):
        """Drop all element overlays and redraw the base data set."""
        self.elements = []
        self.element_colors = []
        self.grid.clear()
        self.dist_plot = self.plot_distance(results, color=color, name=name, size=10, suppress_scatter=suppress_scatter)

    def update_plot(self):
        """Re-plot everything (base data + element overlays) after a parameter change."""
        self.invalidate_output()
        self.plot_distance(self.results, self.color, self.name, suppress_scatter=True)
        if self.element_plot is not None:
            for element, color in zip(self.elements, self.element_colors):
                self.element_distance(element, color, add_to_list=False)
class TSeriesPlot(pg.GraphicsLayoutWidget):
    """Four-row grid pairing presynaptic-spike plots with postsynaptic-response plots.

    Rows 0/2 show the presynaptic spikes; rows 1/3 show the postsynaptic
    traces, one pair per holding potential.  Fit overlays for -55/-70 mV
    holdings can be added, recolored by pass/fail, and cleared.
    """
    def __init__(self, title, units):
        pg.GraphicsLayoutWidget.__init__(self)
        self.grid = PlotGrid()
        self.grid.set_shape(4, 1)
        # Trace rows (1 and 3) get more vertical space than spike rows (0 and 2).
        self.grid.grid.ci.layout.setRowStretchFactor(0, 3)
        self.grid.grid.ci.layout.setRowStretchFactor(1, 8)
        self.grid.grid.ci.layout.setRowStretchFactor(2, 5)
        self.grid.grid.ci.layout.setRowStretchFactor(3, 10)
        self.grid.show()
        self.trace_plots = (self.grid[1, 0], self.grid[3, 0])
        self.spike_plots = (self.grid[0, 0], self.grid[2, 0])
        self.plots = self.spike_plots + self.trace_plots
        # Only the bottom-most plot keeps its time axis.
        for plot in self.plots[:-1]:
            plot.hideAxis('bottom')
        self.plots[-1].setLabel('bottom', text='Time from spike', units='s')
        self.fit_item_55 = None
        self.fit_item_70 = None
        # Green = fit passed, red = fit failed.
        self.fit_color = {True: 'g', False: 'r'}
        self.qc_color = {'qc_pass': (255, 255, 255, 100), 'qc_fail': (255, 0, 0, 100)}
        self.plots[0].setTitle(title)
        # NOTE(review): `holdings` is a module-level name not visible in this
        # chunk — presumably the sequence of holding potentials; confirm.
        for (plot, holding) in zip(self.trace_plots, holdings):
            plot.setXLink(self.plots[-1])
            plot.setLabel('left', text="%d holding" % int(holding), units=units)
        for plot in self.spike_plots:
            plot.setXLink(self.plots[-1])
            plot.setLabel('left', text="presynaptic spike")
            plot.addLine(x=0)
        self.plots[-1].setXRange(-5e-3, 10e-3)
        self.items = []

    def plot_responses(self, pulse_responses):
        """Plot both postsynaptic traces and presynaptic spikes."""
        self.plot_traces(pulse_responses)
        self.plot_spikes(pulse_responses)

    def plot_traces(self, pulse_responses):
        """Plot spike-aligned, baseline-subtracted postsynaptic traces per holding.

        qc_fail traces are drawn in red behind the rest; the qc_pass mean is
        drawn on top in bold blue.
        """
        for i, holding in enumerate(pulse_responses.keys()):
            for qc, prs in pulse_responses[holding].items():
                if len(prs) == 0:
                    continue
                prl = PulseResponseList(prs)
                post_ts = prl.post_tseries(align='spike', bsub=True)
                for trace in post_ts:
                    item = self.trace_plots[i].plot(trace.time_values, trace.data, pen=self.qc_color[qc])
                    if qc == 'qc_fail':
                        item.setZValue(-10)
                    self.items.append(item)
                if qc == 'qc_pass':
                    grand_trace = post_ts.mean()
                    item = self.trace_plots[i].plot(grand_trace.time_values, grand_trace.data, pen={'color': 'b', 'width': 2})
                    self.items.append(item)
            self.trace_plots[i].autoRange()
            self.trace_plots[i].setXRange(-5e-3, 10e-3)
            # y_range = [grand_trace.data.min(), grand_trace.data.max()]
            # self.plots[i].setYRange(y_range[0], y_range[1], padding=1)

    def plot_spikes(self, pulse_responses):
        """Plot spike-aligned presynaptic spikes; multi-spike pulses marked qc_fail."""
        for i, holding in enumerate(pulse_responses.keys()):
            for prs in pulse_responses[holding].values():
                if len(prs) == 0:
                    continue
                prl = PulseResponseList(prs)
                pre_ts = prl.pre_tseries(align='spike', bsub=True)
                for pr, spike in zip(prl, pre_ts):
                    # Exactly one evoked spike is required to pass QC.
                    qc = 'qc_pass' if pr.stim_pulse.n_spikes == 1 else 'qc_fail'
                    item = self.spike_plots[i].plot(spike.time_values, spike.data, pen=self.qc_color[qc])
                    if qc == 'qc_fail':
                        item.setZValue(-10)
                    self.items.append(item)

    def plot_fit(self, trace, holding, fit_pass=False):
        """Overlay (or replace) the fit curve for the given holding potential."""
        if holding == -55:
            if self.fit_item_55 is not None:
                self.trace_plots[0].removeItem(self.fit_item_55)
            self.fit_item_55 = pg.PlotDataItem(trace.time_values, trace.data, name='-55 holding', pen={'color': self.fit_color[fit_pass], 'width': 3})
            self.trace_plots[0].addItem(self.fit_item_55)
        elif holding == -70:
            if self.fit_item_70 is not None:
                self.trace_plots[1].removeItem(self.fit_item_70)
            self.fit_item_70 = pg.PlotDataItem(trace.time_values, trace.data, name='-70 holding', pen={'color': self.fit_color[fit_pass], 'width': 3})
            self.trace_plots[1].addItem(self.fit_item_70)

    def color_fit(self, name, value):
        """Recolor an existing fit overlay by pass/fail; `name` selects the holding."""
        if '-55' in name:
            if self.fit_item_55 is not None:
                self.fit_item_55.setPen({'color': self.fit_color[value], 'width': 3})
        if '-70' in name:
            if self.fit_item_70 is not None:
                self.fit_item_70.setPen({'color': self.fit_color[value], 'width': 3})

    def clear_plots(self):
        """Remove every plotted item (including fit overlays) and reset the view."""
        for item in self.items + [self.fit_item_55, self.fit_item_70]:
            if item is None:
                continue
            item.scene().removeItem(item)
        self.items = []
        self.plots[-1].autoRange()
        self.plots[-1].setXRange(-5e-3, 10e-3)
        self.fit_item_70 = None
        self.fit_item_55 = None
help='plot 50Hz train and deconvolution')
parser.add_argument('--link-y-axis', action='store_true', default=False, dest='link-y-axis',
                    help='link all y-axis down a column')
# vars() is needed because the dest above contains a hyphen and so is not a
# valid attribute name on the Namespace.
args = vars(parser.parse_args(sys.argv[1:]))
plot_sweeps = args['sweeps']
plot_trains = args['trains']
link_y_axis = args['link-y-axis']
all_expts = ExperimentList(cache='expts_cache.pkl')
# Grid layout: 8 rows in trace/spike pairs; 3 columns when train plots are
# requested, otherwise a single column.
grid = PlotGrid()
if plot_trains is True:
    grid.set_shape(8, 3)
else:
    grid.set_shape(8, 1)
grid.show()
for g in range(0, grid.shape[0], 2):
    # Every even row (small) pairs with the odd row below it (large).
    grid.grid.ci.layout.setRowStretchFactor(g, 5)
    grid.grid.ci.layout.setRowStretchFactor(g + 1, 20)
    grid[g, 0].hideAxis('bottom')
    if plot_trains is True:
        grid[g, 1].hideAxis('bottom')
        grid[g, 2].hideAxis('bottom')
        grid[g, 2].hideAxis('left')
        grid[g + 1, 2].hideAxis('left')
if plot_trains is True:
    grid[0, 1].hideAxis('bottom')
    grid[0, 2].hideAxis('bottom')
class PairView(QtGui.QWidget):
    """For analyzing pre/post-synaptic pairs.

    Top-to-bottom layout: presynaptic plot, postsynaptic plot, a row of
    per-pulse averaged-response plots, and an event table with a list of
    saved event sets.  Selecting sweeps/channels triggers pulse detection,
    spike detection, per-pulse averaging, and PSP fitting.
    """
    def __init__(self, parent=None):
        self.sweeps = []
        self.channels = []
        self.current_event_set = None
        self.event_sets = []

        QtGui.QWidget.__init__(self, parent)
        self.layout = QtGui.QGridLayout()
        self.setLayout(self.layout)
        self.layout.setContentsMargins(0, 0, 0, 0)
        self.vsplit = QtGui.QSplitter(QtCore.Qt.Vertical)
        self.layout.addWidget(self.vsplit, 0, 0)
        self.pre_plot = pg.PlotWidget()
        self.post_plot = pg.PlotWidget()
        for plt in (self.pre_plot, self.post_plot):
            plt.setClipToView(True)
            plt.setDownsampling(True, True, 'peak')
        self.post_plot.setXLink(self.pre_plot)
        self.vsplit.addWidget(self.pre_plot)
        self.vsplit.addWidget(self.post_plot)
        self.response_plots = PlotGrid()
        self.vsplit.addWidget(self.response_plots)
        self.event_splitter = QtGui.QSplitter(QtCore.Qt.Horizontal)
        self.vsplit.addWidget(self.event_splitter)
        self.event_table = pg.TableWidget()
        self.event_splitter.addWidget(self.event_table)
        self.event_widget = QtGui.QWidget()
        self.event_widget_layout = QtGui.QGridLayout()
        self.event_widget.setLayout(self.event_widget_layout)
        self.event_splitter.addWidget(self.event_widget)
        self.event_splitter.setSizes([600, 100])
        # Event-set list: row 0 is always the live "current" set.
        self.event_set_list = QtGui.QListWidget()
        self.event_widget_layout.addWidget(self.event_set_list, 0, 0, 1, 2)
        self.event_set_list.addItem("current")
        self.event_set_list.itemSelectionChanged.connect(self.event_set_selected)
        self.add_set_btn = QtGui.QPushButton('add')
        self.event_widget_layout.addWidget(self.add_set_btn, 1, 0, 1, 1)
        self.add_set_btn.clicked.connect(self.add_set_clicked)
        self.remove_set_btn = QtGui.QPushButton('del')
        self.event_widget_layout.addWidget(self.remove_set_btn, 1, 1, 1, 1)
        self.remove_set_btn.clicked.connect(self.remove_set_clicked)
        self.fit_btn = QtGui.QPushButton('fit all')
        self.event_widget_layout.addWidget(self.fit_btn, 2, 0, 1, 2)
        self.fit_btn.clicked.connect(self.fit_clicked)
        self.fit_plot = PlotGrid()
        self.fit_explorer = None

        # Signal-processing chain applied to the postsynaptic trace.
        self.artifact_remover = ArtifactRemover(user_width=True)
        self.baseline_remover = BaselineRemover()
        self.filter = SignalFilter()

        self.params = pg.parametertree.Parameter(name='params', type='group', children=[
            {'name': 'pre', 'type': 'list', 'values': []},
            {'name': 'post', 'type': 'list', 'values': []},
            self.artifact_remover.params,
            self.baseline_remover.params,
            self.filter.params,
            {'name': 'time constant', 'type': 'float', 'suffix': 's', 'siPrefix': True, 'value': 10e-3, 'dec': True, 'minStep': 100e-6}
        ])
        self.params.sigTreeStateChanged.connect(self._update_plots)

    def data_selected(self, sweeps, channels):
        """Receive the user's sweep/channel selection and refresh everything."""
        self.sweeps = sweeps
        self.channels = channels
        self.params.child('pre').setLimits(channels)
        self.params.child('post').setLimits(channels)
        self._update_plots()

    def _update_plots(self):
        """Re-run the full analysis: plot raw traces, detect pulses/spikes,
        average responses per pulse, fit PSPs, and fill the event table."""
        sweeps = self.sweeps
        self.current_event_set = None
        self.event_table.clear()

        # clear all plots
        self.pre_plot.clear()
        self.post_plot.clear()

        pre = self.params['pre']
        post = self.params['post']

        # If there are no selected sweeps or channels have not been set, return
        if len(sweeps) == 0 or pre == post or pre not in self.channels or post not in self.channels:
            return

        pre_mode = sweeps[0][pre].clamp_mode
        post_mode = sweeps[0][post].clamp_mode
        for ch, mode, plot in [(pre, pre_mode, self.pre_plot), (post, post_mode, self.post_plot)]:
            units = 'A' if mode == 'vc' else 'V'
            plot.setLabels(left=("Channel %d" % ch, units), bottom=("Time", 's'))

        # Iterate over selected channels of all sweeps, plotting traces one at a time
        # Collect information about pulses and spikes
        pulses = []
        spikes = []
        post_traces = []
        for i,sweep in enumerate(sweeps):
            pre_trace = sweep[pre]['primary']
            post_trace = sweep[post]['primary']

            # Detect pulse times
            stim = sweep[pre]['command'].data
            sdiff = np.diff(stim)
            on_times = np.argwhere(sdiff > 0)[1:, 0]  # 1: skips test pulse
            off_times = np.argwhere(sdiff < 0)[1:, 0]
            pulses.append(on_times)

            # filter data
            post_filt = self.artifact_remover.process(post_trace, list(on_times) + list(off_times))
            post_filt = self.baseline_remover.process(post_filt)
            post_filt = self.filter.process(post_filt)
            post_traces.append(post_filt)

            # plot raw data
            color = pg.intColor(i, hues=len(sweeps)*1.3, sat=128)
            color.setAlpha(128)
            for trace, plot in [(pre_trace, self.pre_plot), (post_filt, self.post_plot)]:
                plot.plot(trace.time_values, trace.data, pen=color, antialias=False)

            # detect spike times
            spike_inds = []
            spike_info = []
            for on, off in zip(on_times, off_times):
                spike = detect_evoked_spike(sweep[pre], [on, off])
                spike_info.append(spike)
                if spike is None:
                    spike_inds.append(None)
                else:
                    spike_inds.append(spike['rise_index'])
            spikes.append(spike_info)

            # Mark detected spike times on the presynaptic plot.
            dt = pre_trace.dt
            vticks = pg.VTickGroup([x * dt for x in spike_inds if x is not None], yrange=[0.0, 0.2], pen=color)
            self.pre_plot.addItem(vticks)

        # Iterate over spikes, plotting average response
        all_responses = []
        avg_responses = []
        fits = []
        fit = None

        npulses = max(map(len, pulses))
        self.response_plots.clear()
        self.response_plots.set_shape(1, npulses+1)  # 1 extra for global average
        self.response_plots.setYLink(self.response_plots[0,0])
        for i in range(1, npulses+1):
            self.response_plots[0,i].hideAxis('left')
        units = 'A' if post_mode == 'vc' else 'V'
        self.response_plots[0, 0].setLabels(left=("Averaged events (Channel %d)" % post, units))

        fit_pen = {'color':(30, 30, 255), 'width':2, 'dash': [1, 1]}
        for i in range(npulses):
            # get the chunk of each sweep between spikes
            responses = []
            all_responses.append(responses)
            for j, sweep in enumerate(sweeps):
                # get the current spike
                if i >= len(spikes[j]):
                    continue
                spike = spikes[j][i]
                if spike is None:
                    continue

                # find next spike
                next_spike = None
                for sp in spikes[j][i+1:]:
                    if sp is not None:
                        next_spike = sp
                        break

                # determine time range for response
                max_len = int(40e-3 / dt)  # don't take more than 50ms for any response
                start = spike['rise_index']
                if next_spike is not None:
                    stop = min(start + max_len, next_spike['rise_index'])
                else:
                    stop = start + max_len

                # collect data from this trace
                trace = post_traces[j]
                d = trace.data[start:stop].copy()
                responses.append(d)

            if len(responses) == 0:
                continue

            # extend all responses to the same length and take nanmean
            avg = ragged_mean(responses, method='clip')
            avg -= float_mode(avg[:int(1e-3/dt)])
            avg_responses.append(avg)

            # plot average response for this pulse
            start = np.median([sp[i]['rise_index'] for sp in spikes if sp[i] is not None]) * dt
            t = np.arange(len(avg)) * dt
            self.response_plots[0,i].plot(t, avg, pen='w', antialias=True)

            # fit!
            fit = self.fit_psp(avg, t, dt, post_mode)
            fits.append(fit)

            # let the user mess with this fit
            curve = self.response_plots[0,i].plot(t, fit.eval(), pen=fit_pen, antialias=True).curve
            curve.setClickable(True)
            curve.fit = fit
            curve.sigClicked.connect(self.fit_curve_clicked)

        # display global average
        global_avg = ragged_mean(avg_responses, method='clip')
        t = np.arange(len(global_avg)) * dt
        self.response_plots[0,-1].plot(t, global_avg, pen='w', antialias=True)
        global_fit = self.fit_psp(global_avg, t, dt, post_mode)
        self.response_plots[0,-1].plot(t, global_fit.eval(), pen=fit_pen, antialias=True)

        # display fit parameters in table
        events = []
        for i,f in enumerate(fits + [global_fit]):
            if f is None:
                continue
            if i >= len(fits):
                # The last entry is the global average fit.
                vals = OrderedDict([('id', 'avg'), ('spike_time', np.nan), ('spike_stdev', np.nan)])
            else:
                spt = [s[i]['peak_index'] * dt for s in spikes if s[i] is not None]
                vals = OrderedDict([('id', i), ('spike_time', np.mean(spt)), ('spike_stdev', np.std(spt))])
            vals.update(OrderedDict([(k,f.best_values[k]) for k in f.params.keys()]))
            events.append(vals)

        self.current_event_set = (pre, post, events, sweeps)
        self.event_set_list.setCurrentRow(0)
        self.event_set_selected()

    def fit_psp(self, data, t, dt, clamp_mode):
        """Fit a PSP template to one averaged response; returns the fit result.

        The response sign is inferred from the deviation of the mean from the
        initial baseline mode; bounds are widened for current clamp.
        """
        mode = float_mode(data[:int(1e-3/dt)])
        sign = -1 if data.mean() - mode < 0 else 1
        params = OrderedDict([
            ('xoffset', (2e-3, 3e-4, 5e-3)),
            ('yoffset', data[0]),
            ('amp', sign * 10e-12),
            #('k', (2e-3, 50e-6, 10e-3)),
            ('rise_time', (2e-3, 50e-6, 10e-3)),
            ('decay_tau', (4e-3, 500e-6, 50e-3)),
            ('rise_power', (2.0, 'fixed')),
        ])
        if clamp_mode == 'ic':
            params['amp'] = sign * 10e-3
            #params['k'] = (5e-3, 50e-6, 20e-3)
            params['rise_time'] = (5e-3, 50e-6, 20e-3)
            params['decay_tau'] = (15e-3, 500e-6, 150e-3)
        # Loose tolerances keep the interactive fit fast.
        fit_kws = {'xtol': 1e-3, 'maxfev': 100}
        psp = fitting.Psp()
        return psp.fit(data, x=t, fit_kws=fit_kws, params=params)

    def add_set_clicked(self):
        """Save the current event set into the list for later model fitting."""
        if self.current_event_set is None:
            return
        ces = self.current_event_set
        self.event_sets.append(ces)
        item = QtGui.QListWidgetItem("%d -> %d" % (ces[0], ces[1]))
        self.event_set_list.addItem(item)
        item.event_set = ces

    def remove_set_clicked(self):
        """Remove the selected saved event set (row 0, "current", is protected)."""
        if self.event_set_list.currentRow() == 0:
            return
        sel = self.event_set_list.takeItem(self.event_set_list.currentRow())
        self.event_sets.remove(sel.event_set)

    def event_set_selected(self):
        """Show the selected event set's fit parameters in the table."""
        sel = self.event_set_list.selectedItems()[0]
        if sel.text() == "current":
            self.event_table.setData(self.current_event_set[2])
        else:
            self.event_table.setData(sel.event_set[2])

    def fit_clicked(self):
        """Fit a release model to all saved event sets, adding one dynamics
        term at a time, and plot each cumulative fit."""
        from synaptic_release import ReleaseModel
        self.fit_plot.clear()
        self.fit_plot.show()
        n_sets = len(self.event_sets)
        self.fit_plot.set_shape(n_sets, 1)
        # Replace any stale legend on the top plot.
        l = self.fit_plot[0, 0].legend
        if l is not None:
            l.scene.removeItem(l)
            self.fit_plot[0, 0].legend = None
        self.fit_plot[0, 0].addLegend()
        for i in range(n_sets):
            self.fit_plot[i,0].setXLink(self.fit_plot[0, 0])

        # Normalize each event set: times relative to first spike (in ms),
        # amplitudes relative to the first amplitude.
        spike_sets = []
        for i,evset in enumerate(self.event_sets):
            evset = evset[2][:-1]  # select events, skip last row (average)
            x = np.array([ev['spike_time'] for ev in evset])
            y = np.array([ev['amp'] for ev in evset])
            x -= x[0]
            x *= 1000
            y /= y[0]
            spike_sets.append((x, y))
            self.fit_plot[i,0].plot(x/1000., y, pen=None, symbol='o')

        model = ReleaseModel()
        dynamics_types = ['Dep', 'Fac', 'UR', 'SMR', 'DSR']
        for k in dynamics_types:
            model.Dynamics[k] = 0

        # Fit repeatedly, cumulatively enabling one dynamics term per pass.
        fit_params = []
        with pg.ProgressDialog("Fitting release model..", 0, len(dynamics_types)) as dlg:
            for k in dynamics_types:
                model.Dynamics[k] = 1
                fit_params.append(model.run_fit(spike_sets))
                dlg += 1
                if dlg.wasCanceled():
                    return

        max_color = len(fit_params)*1.5
        for i,params in enumerate(fit_params):
            for j,spikes in enumerate(spike_sets):
                x, y = spikes
                t = np.linspace(0, x.max(), 1000)
                output = model.eval(x, params.values())
                y = output[:,1]
                x = output[:,0]/1000.
                self.fit_plot[j,0].plot(x, y, pen=(i,max_color), name=dynamics_types[i])

    def fit_curve_clicked(self, curve):
        """Open (or retarget) the FitExplorer on the clicked fit curve."""
        if self.fit_explorer is None:
            self.fit_explorer = FitExplorer(curve.fit)
        else:
            self.fit_explorer.set_fit(curve.fit)
        self.fit_explorer.show()
# Analysis configuration: one connection type, two holding potentials (mV),
# induction frequencies (Hz), and recovery delays (ms).
connection_types = [((c_type[0], 'unknown'), (c_type[1], 'unknown'))]
holding = [-55, -75]
freqs = [10, 20, 50, 100]
rec_t = [250, 500, 1000, 2000, 4000]
sweep_threshold = 3  # minimum sweeps required per condition
deconv = True  # also build deconvolved-response grids
# cache_file = 'train_response_cache.pkl'
# response_cache = load_cache(cache_file)
# cache_change = []
# log_rec_plt = pg.plot()
# log_rec_plt.setLogMode(x=True)
qc_plot = pg.plot()
# Induction grid: one row per frequency; recovery grid: one row per delay.
ind_plot = PlotGrid()
ind_plot.set_shape(4, len(connection_types))
ind_plot.show()
rec_plot = PlotGrid()
rec_plot.set_shape(5, len(connection_types))
rec_plot.show()
if deconv is True:
    # Parallel grids for the deconvolved responses.
    deconv_ind_plot = PlotGrid()
    deconv_ind_plot.set_shape(4, len(connection_types))
    deconv_ind_plot.show()
    deconv_rec_plot = PlotGrid()
    deconv_rec_plot.set_shape(5, len(connection_types))
    deconv_rec_plot.show()
summary_plot = PlotGrid()
summary_plot.set_shape(len(connection_types), 2)
summary_plot.show()
symbols = ['o', 's', 'd', '+', 't']  # one marker per condition
pg.plot(col + np.random.uniform(size=len(col)) * 0.7, rms, pen=None, symbol='o', symbolPen=None, symbolBrush=(255, 255, 255, 50))
plt = pg.plot(
    labels={
        'left': 'number of sweeps (normalized per rig)',
        'bottom': ('baseline rms error', 'V')
    })
plt.addLegend()
# One time-series panel per rig.
grid = PlotGrid()
grid.set_shape(3, 1)
grid.show()
for r, c in ((1, 'r'), (2, 'g'), (3, 'b')):
    mask = rig == r
    rig_data = rms[mask]
    # Normalized histogram of baseline RMS noise for this rig.
    y, x = np.histogram(rig_data, bins=np.linspace(0, 0.002, 1000))
    plt.plot(x, y / len(rig_data), stepMode=True, connect='finite', pen=c, name="Rig %d" % r)
    p = grid[r - 1, 0]
    p.plot(ts[mask],
def summary_plot_pulse(feature_list, feature_mean, labels, titles, i, grand_trace=None, plot=None, color=None, name=None):
    """Scatter pulse features with a mean bar and SEM error bar.

    One grid row per feature: column 0 holds the jittered scatter, bar, and
    error bar; column 1 holds the optional grand-average trace.  Returns the
    PlotGrid used.
    """
    n_features = len(feature_list) if type(feature_list) is tuple else 1

    if plot is None:
        plot = PlotGrid()
        plot.set_shape(n_features, 2)
        plot.show()
        for row in range(n_features):
            trace_col = plot[row, 1]
            trace_col.addLegend()
            trace_col.setLabels(left=('Vm', 'V'))
            trace_col.setLabels(bottom=('t', 's'))

    for row in range(n_features):
        # Select this row's data; with a single feature the inputs are used as-is.
        if n_features > 1:
            features = feature_list[row]
            mean = feature_mean[row]
            label = labels[row]
            title = titles[row]
        else:
            features, mean, label, title = feature_list, feature_mean, labels, titles

        scatter_col = plot[row, 0]
        scatter_col.setLabels(left=(label[0], label[1]))
        scatter_col.hideAxis('bottom')
        scatter_col.setTitle(title)

        if grand_trace is not None:
            plot[row, 1].plot(grand_trace.time_values, grand_trace.data, pen=color, name=name)

        if len(features) > 1:
            # Jitter points horizontally so overlapping values stay visible.
            dx = pg.pseudoScatter(np.array(features).astype(float), 0.3, bidir=True)
            bar = pg.BarGraphItem(x=[i], height=mean, width=0.7, brush='w', pen={'color': color, 'width': 2})
            scatter_col.addItem(bar)
            sem = stats.sem(features)
            err = pg.ErrorBarItem(x=np.asarray([i]), y=np.asarray([mean]), height=sem, beam=0.3)
            scatter_col.addItem(err)
            scatter_col.plot((0.3 * dx / dx.max()) + i, features, pen=None, symbol='o', symbolSize=10, symbolPen='w', symbolBrush=color)
        else:
            scatter_col.plot([i], features, pen=None, symbol='o', symbolSize=10, symbolPen='w', symbolBrush=color)

    return plot
class MultipatchMatrixView(QtGui.QWidget):
    """NxN matrix of plots showing each channel's response to each channel's stimulus.

    Cell (i, j) shows channel i's recording around channel j's pulses.
    A parameter tree controls display mode (sweeps / sweep average / pulse
    average), filtering, artifact and baseline removal, and tick display.
    """
    def __init__(self, parent=None):
        QtGui.QWidget.__init__(self, parent)
        self.layout = QtGui.QGridLayout()
        self.setLayout(self.layout)
        self.layout.setContentsMargins(0, 0, 0, 0)
        self.plots = PlotGrid(parent=self)
        self.layout.addWidget(self.plots, 0, 0)
        self.plots.scene().sigMouseClicked.connect(self._plot_clicked)
        #self.pair_view = PairAnalyzer()
        self.params = pg.parametertree.Parameter(name='params', type='group', children=[
            {'name': 'show', 'type': 'list', 'values': ['sweep avg', 'sweep avg + sweeps', 'sweeps', 'pulse avg']},
            {'name': 'lowpass', 'type': 'bool', 'value': True, 'children': [
                {'name': 'sigma', 'type': 'float', 'value': 200e-6, 'step': 1e-5, 'limits': [0, None], 'suffix': 's', 'siPrefix': True},
            ]},
            {'name': 'first pulse', 'type': 'int', 'value': 0, 'limits': [0, None]},
            {'name': 'last pulse', 'type': 'int', 'value': 7, 'limits': [0, None]},
            {'name': 'window', 'type': 'float', 'value': 30e-3, 'step': 1e-3, 'limits': [0, None], 'suffix': 's', 'siPrefix': True},
            {'name': 'remove artifacts', 'type': 'bool', 'value': True, 'children': [
                {'name': 'window', 'type': 'float', 'suffix': 's', 'siPrefix': True, 'value': 1e-3, 'step': 1e-4, 'bounds': [0, None]},
            ]},
            {'name': 'remove baseline', 'type': 'bool', 'value': True},
            {'name': 'show ticks', 'type': 'bool', 'value': True},
        ])
        self.params.sigTreeStateChanged.connect(self._params_changed)

    def show_group(self, grp):
        # NOTE(review): show_sweeps is not defined in this class as visible
        # here — presumably defined elsewhere or dead code; confirm.
        self.show_sweeps(grp.sweeps)

    def data_selected(self, sweeps, channels):
        """Receive the user's sweep/channel selection and rebuild the matrix."""
        self.sweeps = sweeps
        self.channels = channels
        self._update_plots(auto_range=True)

    def _params_changed(self, *args):
        self._update_plots()

    def _plot_clicked(self, ev):
        """Highlight the clicked matrix cell by tinting its background."""
        item = self.plots.scene().itemAt(ev.scenePos())
        r,c = self.plots.item_index(item)
        for i in range(self.plots.rows):
            for j in range(self.plots.cols):
                color = None if (i, j) != (r, c) else pg.mkColor(30, 30, 50)
                self.plots[i,j].vb.setBackgroundColor(color)

    def _update_plots(self, auto_range=False):
        """Rebuild the full NxN matrix from the selected sweeps/channels."""
        sweeps = self.sweeps
        chans = self.channels
        self.plots.clear()
        if len(sweeps) == 0 or len(chans) == 0:
            return

        # collect data
        data = MiesNwb.pack_sweep_data(sweeps)
        data, stim = data[...,0], data[...,1]  # unpack stim and recordings
        dt = sweeps[0].recordings[0]['primary'].dt

        # mask for selected channels
        mask = np.array([ch in chans for ch in sweeps[0].devices])
        data = data[:, mask]
        stim = stim[:, mask]
        chans = np.array(sweeps[0].devices)[mask]
        modes = [sweeps[0][ch].clamp_mode for ch in chans]

        # get pulse times for each channel
        stim = stim[0]
        diff = stim[:,1:] - stim[:,:-1]
        # note: the [1:] here skips the test pulse
        on_times = [np.argwhere(diff[i] > 0)[1:,0] for i in range(diff.shape[0])]
        off_times = [np.argwhere(diff[i] < 0)[1:,0] for i in range(diff.shape[0])]

        # remove capacitive artifacts from adjacent electrodes
        if self.params['remove artifacts']:
            npts = int(self.params['remove artifacts', 'window'] / dt)
            for i in range(stim.shape[0]):
                for j in range(stim.shape[0]):
                    if i == j:
                        continue
                    # are these headstages adjacent?
                    hs1, hs2 = chans[i], chans[j]
                    if abs(hs2-hs1) > 3:
                        continue
                    # remove artifacts: replace a short window after each pulse
                    # edge with the mean of the preceding window
                    for k in range(len(on_times[i])):
                        on = on_times[i][k]
                        off = off_times[i][k]
                        data[:, j, on:on+npts] = data[:, j, max(0,on-npts):on].mean(axis=1)[:,None]
                        data[:, j, off:off+npts] = data[:, j, max(0,off-npts):off].mean(axis=1)[:,None]

        # lowpass filter
        if self.params['lowpass']:
            data = gaussian_filter(data, (0, 0, self.params['lowpass', 'sigma'] / dt))

        # prepare to plot
        window = int(self.params['window'] / dt)
        n_sweeps = data.shape[0]
        n_channels = data.shape[1]
        self.plots.set_shape(n_channels, n_channels)
        self.plots.setClipToView(True)
        self.plots.setDownsampling(True, True, 'peak')
        self.plots.enableAutoRange(False, False)

        show_sweeps = 'sweeps' in self.params['show']
        show_sweep_avg = 'sweep avg' in self.params['show']
        show_pulse_avg = self.params['show'] == 'pulse avg'

        for i in range(n_channels):
            for j in range(n_channels):
                plt = self.plots[i, j]
                # Segment boundaries: from `window` before the first selected
                # pulse to `window` after the last.
                start = on_times[j][self.params['first pulse']] - window
                if start < 0:
                    frontpad = -start
                    start = 0
                else:
                    frontpad = 0
                stop = on_times[j][self.params['last pulse']] + window

                # select the data segment to be displayed in this matrix cell
                # add padding if necessary
                if frontpad == 0:
                    seg = data[:, i, start:stop].copy()
                else:
                    seg = np.empty((data.shape[0], stop + frontpad), data.dtype)
                    seg[:, frontpad:] = data[:, i, start:stop]
                    seg[:, :frontpad] = seg[:, frontpad:frontpad+1]

                # subtract off baseline for each sweep
                if self.params['remove baseline']:
                    seg -= seg[:, :window].mean(axis=1)[:,None]

                if show_sweeps:
                    alpha = 100 if show_sweep_avg else 200
                    color = (255, 255, 255, alpha)
                    t = np.arange(seg.shape[1]) * dt
                    for k in range(n_sweeps):
                        plt.plot(t, seg[k], pen={'color': color, 'width': 1}, antialias=True)

                if show_sweep_avg or show_pulse_avg:
                    # average selected segments over all sweeps
                    segm = seg.mean(axis=0)

                    if show_pulse_avg:
                        # average over all selected pulses
                        pulses = []
                        for k in range(self.params['first pulse'], self.params['last pulse'] + 1):
                            pstart = on_times[j][k] - on_times[j][self.params['first pulse']]
                            pstop = pstart + (window * 2)
                            pulses.append(segm[pstart:pstop])
                        # for p in pulses:
                        #     t = np.arange(p.shape[0]) * dt
                        #     plt.plot(t, p)
                        segm = np.vstack(pulses).mean(axis=0)

                    t = np.arange(segm.shape[0]) * dt

                    if i == j:
                        color = (80, 80, 80)
                    else:
                        # Color off-diagonal averages by an excitation (blue)
                        # vs inhibition (red) metric relative to baseline noise.
                        dif = segm - segm[:window].mean()
                        qe = 30 * np.clip(dif, 0, 1e20).mean() / segm[:window].std()
                        qi = 30 * np.clip(-dif, 0, 1e20).mean() / segm[:window].std()
                        if modes[i] == 'ic':
                            qi, qe = qe, qi  # invert color metric for current clamp
                        g = 100
                        r = np.clip(g + max(qi, 0), 0, 255)
                        b = np.clip(g + max(qe, 0), 0, 255)
                        color = (r, g, b)
                    plt.plot(t, segm, pen={'color': color, 'width': 1}, antialias=True)

                if self.params['show ticks']:
                    vt = pg.VTickGroup((on_times[j]-start) * dt, [0, 0.15], pen=0.4)
                    plt.addItem(vt)

                # Link all plots along x axis
                plt.setXLink(self.plots[0, 0])
                if i == j:
                    # link y axes of all diagonal plots
                    plt.setYLink(self.plots[0, 0])
                else:
                    # link y axes of all plots within a row
                    plt.setYLink(self.plots[i, (i+1) % 2])  # (i+1)%2 just avoids linking to 0,0

                if i < n_channels - 1:
                    plt.getAxis('bottom').setVisible(False)
                if j > 0:
                    plt.getAxis('left').setVisible(False)
                if i == n_channels - 1:
                    plt.setLabels(bottom=('CH%d'%chans[j], 's'))
                if j == 0:
                    plt.setLabels(left=('CH%d'%chans[i], 'A' if modes[i] == 'vc' else 'V'))

        if auto_range:
            # Initial y ranges for off-diagonal (small-signal) and diagonal cells.
            r = 14e-12 if modes[i] == 'vc' else 5e-3
            self.plots[0, 1].setYRange(-r, r)
            r = 2e-9 if modes[i] == 'vc' else 100e-3
            self.plots[0, 0].setYRange(-r, r)
            self.plots[0, 0].setXRange(t[0], t[-1])
# Script fragment: summarize baseline RMS noise per rig/headstage.
# NOTE(review): `rows` is produced earlier in the file (not visible here);
# each row appears to be (rms, rig, headstage, timestamp) — confirm upstream.
rms = np.array([row[0] for row in rows])
rig = np.array([row[1] for row in rows]).astype(int)
hs = np.array([row[2] for row in rows]).astype(int)
# flatten (rig, headstage) into a single column index: 8 headstages per rig
col = rig*8 + hs
ts = np.array([time.mktime(row[3].timetuple()) for row in rows])
#ts -= ts[0]

# scatter of rms per (rig, headstage); horizontal jitter avoids overplotting
pg.plot(col + np.random.uniform(size=len(col))*0.7, rms, pen=None, symbol='o', symbolPen=None, symbolBrush=(255, 255, 255, 50))

# per-rig histogram of rms values, normalized by each rig's sweep count
plt = pg.plot(labels={'left': 'number of sweeps (normalized per rig)', 'bottom': ('baseline rms error', 'V')})
plt.addLegend()

# one row per rig: rms noise over time
grid = PlotGrid()
grid.set_shape(3, 1)
grid.show()
for r, c in ((1, 'r'), (2, 'g'), (3, 'b')):
    mask = rig==r
    rig_data = rms[mask]
    y, x = np.histogram(rig_data, bins=np.linspace(0, 0.002, 1000))
    plt.plot(x, y/len(rig_data), stepMode=True, connect='finite', pen=c, name="Rig %d" % r)
    p = grid[r-1, 0]
    p.plot(ts[mask], rig_data, pen=None, symbol='o', symbolPen=None, symbolBrush=(255, 255, 255, 100))
    p.setLabels(left=('rig %d baseline rms noise'%r, 'V'))
# share both axes across all three rig plots
grid.setXLink(grid[0, 0])
grid.setYLink(grid[0, 0])
class PulseResponsePlot(object):
    """Two-row plot grid: presynaptic spikes on top, postsynaptic responses
    below, x-linked so both rows share the time axis (t=0 at the spike).

    NOTE(review): plot_fit / color_fit / clear_plots reference attributes
    (self.trace_plots, self.fit_color, self.fit_item_55, self.fit_item_70,
    self.plots) that are never initialized in __init__ — presumably assigned
    elsewhere or left over from a merge; confirm before relying on them.
    """
    def __init__(self, title, units):
        # 2x1 grid; response row gets more vertical space than the spike row
        self.grid = PlotGrid()
        self.grid.set_shape(2, 1)
        self.grid.grid.ci.layout.setRowStretchFactor(0, 3)
        self.grid.grid.ci.layout.setRowStretchFactor(1, 8)
        self.spike_plot = self.grid[0, 0]
        self.response_plot = self.grid[1, 0]
        self.spike_plot.hideAxis('bottom')
        self.spike_plot.setTitle(title)
        self.spike_plot.setLabel('left', text="presynaptic spike")
        # vertical marker at the spike-alignment time
        self.spike_plot.addLine(x=0)
        self.response_plot.setLabel('bottom', text='Time from spike', units='s')
        self.response_plot.setXLink(self.spike_plot)
        # self.response_plot.setLabel('left', text="%d holding" % int(holding), units=units)
        self.response_plot.enableAutoRange(False, False)
        # plot items added so far, so clear_plots can remove them
        self.items = []

    def show_pulse_responses(self, passed_prs, failed_prs):
        """Plot spike-aligned, baseline-subtracted traces: failed QC in red,
        passed QC in white (passed drawn last, on top)."""
        for prs, pen in [(failed_prs, (255, 0, 0, 40)), (passed_prs, (255, 255, 255, 40))]:
            if len(prs) == 0:
                continue
            prl = PulseResponseList(prs)
            post_ts = prl.post_tseries(align='spike', bsub=True)
            for ts in post_ts:
                item = self.response_plot.plot(ts.time_values, ts.data, pen=pen)
                self.items.append(item)
            pre_ts = prl.pre_tseries(align='spike', bsub=True)
            for ts in pre_ts:
                item = self.spike_plot.plot(ts.time_values, ts.data, pen=pen)
                self.items.append(item)
        self.response_plot.autoRange()

    def show_expected_fit(self, fit_params, qc_pass):
        """Overlay the expected PSP model curve (dashed); green if QC passed,
        orange otherwise."""
        psp = StackedPsp()
        t = np.linspace(-10e-3, 20e-3, 1000)
        v = psp.eval(x=t, **fit_params)
        pen = {'color': (0, 150, 0) if qc_pass else (255, 100, 0), 'dash': [5, 5]}
        self.response_plot.plot(t, v, pen=pen, zValue=10)

    def show_fit_results(self, results, average, qc_pass):
        """Plot the average response (blue) with the best-fit PSP curve on
        top; fit pen is green if QC passed, orange otherwise."""
        self.response_plot.plot(average.time_values, average.data, pen='b')
        psp = StackedPsp()
        t = average.time_values
        v = psp.eval(x=t, **results.best_values)
        pen = {'color': (0, 150, 0) if qc_pass else (255, 100, 0), 'width': 2}
        self.response_plot.plot(t, v, pen=pen)

    def plot_fit(self, trace, fit, holding, fit_pass=False):
        """Replace the fit curve for the given holding potential ('-55' or
        '-70'); any previous curve for that holding is removed first.

        NOTE(review): uses self.trace_plots / self.fit_item_* / self.fit_color,
        which are not set up in __init__ (see class docstring).
        """
        if holding == '-55':
            if self.fit_item_55 is not None:
                self.trace_plots[0].removeItem(self.fit_item_55)
            self.fit_item_55 = pg.PlotDataItem(trace.time_values, fit, name='-55 holding', pen={'color': self.fit_color[fit_pass], 'width': 3})
            self.trace_plots[0].addItem(self.fit_item_55)
        elif holding == '-70':
            if self.fit_item_70 is not None:
                self.trace_plots[1].removeItem(self.fit_item_70)
            self.fit_item_70 = pg.PlotDataItem(trace.time_values, fit, name='-70 holding', pen={'color': self.fit_color[fit_pass], 'width': 3})
            self.trace_plots[1].addItem(self.fit_item_70)

    def color_fit(self, name, value):
        """Recolor an existing fit curve (selected by holding substring in
        *name*) according to pass/fail *value*."""
        if '-55' in name:
            if self.fit_item_55 is not None:
                self.fit_item_55.setPen({'color': self.fit_color[value], 'width': 3})
        if '-70' in name:
            if self.fit_item_70 is not None:
                self.fit_item_70.setPen({'color': self.fit_color[value], 'width': 3})

    def clear_plots(self):
        """Remove all plotted traces and fit items, then reset the view range.

        NOTE(review): self.plots is not defined on this class in the visible
        code — confirm where it comes from.
        """
        for item in self.items + [self.fit_item_55, self.fit_item_70]:
            if item is None:
                continue
            item.scene().removeItem(item)
        self.items = []
        self.plots[-1].autoRange()
        self.plots[-1].setXRange(-5e-3, 10e-3)
        self.fit_item_70 = None
        self.fit_item_55 = None
class DistancePlot(object):
    """Interactive connection-probability-vs-distance plot.

    Wraps distance_plot() in a 2x1 PlotGrid (scatter of probed pairs on top,
    probability profile below) and re-renders when the binning-window
    parameter changes.
    """
    def __init__(self):
        self.grid = PlotGrid()
        self.grid.set_shape(2, 1)
        # profile row gets twice the height of the scatter row
        self.grid.grid.ci.layout.setRowStretchFactor(0, 5)
        self.grid.grid.ci.layout.setRowStretchFactor(1, 10)
        # (profile plot, scatter plot) — same ordering distance_plot expects
        self.plots = (self.grid[1,0], self.grid[0,0])
        self.plots[0].grid = self.grid
        self.plots[0].addLegend()
        self.grid.show()
        self.plots[0].setLabels(bottom=('distance', 'm'), left='connection probability')
        # user-adjustable sliding-window width for the probability profile
        self.params = Parameter.create(name='Distance binning window', type='float', value=40.e-6, step=10.e-6, suffix='m', siPrefix=True)
        self.element_plot = None
        self.elements = []
        self.element_colors = []
        # first results/color/name plotted are remembered for redraws
        self.results = None
        self.color = None
        self.name = None
        self.params.sigTreeStateChanged.connect(self.update_plot)

    def plot_distance(self, results, color, name, size=10):
        """Results needs to be a DataFrame or Series object with 'connected' and 'distance' as columns
        """
        # remember the first dataset so update_plot can redraw it
        if self.results is None:
            self.name = name
            self.color = color
            self.results = results
        connected = results['connected']
        distance = results['distance']
        dist_win = self.params.value()
        self.dist_plot = distance_plot(connected, distance, plots=self.plots, color=color, name=name, size=size, window=dist_win, spacing=dist_win)
        self.plots[0].setXRange(0, 200e-6)
        # return self.dist_plot
        # NOTE(review): because the return above is commented out, this method
        # returns None, so element_distance() always sets self.element_plot to
        # None and the element-redraw branch of update_plot() never runs —
        # confirm whether the return should be re-enabled.

    def invalidate_output(self):
        # drop all plot items before a full redraw
        self.grid.clear()

    def element_distance(self, element, color, add_to_list=True):
        """Overlay the distance profile for a single pre->post element;
        optionally remember it so update_plot can redraw it later."""
        if add_to_list is True:
            self.element_colors.append(color)
            self.elements.append(element)
        pre = element['pre_class'][0].name
        post = element['post_class'][0].name
        name = ('%s->%s' % (pre, post))
        self.element_plot = self.plot_distance(element, color=color, name=name, size=15)

    def element_distance_reset(self, results, color, name):
        """Clear all element overlays and replot the base dataset."""
        self.elements = []
        self.element_colors = []
        self.grid.clear()
        self.dist_plot = self.plot_distance(results, color=color, name=name, size=10)

    def update_plot(self):
        """Redraw everything (base dataset plus remembered elements) — used
        as the slot for binning-window parameter changes."""
        self.invalidate_output()
        self.plot_distance(self.results, self.color, self.name)
        if self.element_plot is not None:
            for element, color in zip(self.elements, self.element_colors):
                self.element_distance(element, color, add_to_list=False)
def distance_plot(connected, distance, plots=None, color=(100, 100, 255), size=10, window=40e-6, name=None, fill_alpha=30):
    """Draw connectivity vs distance profiles with confidence intervals.

    Parameters
    ----------
    connected : boolean array
        Whether a synaptic connection was found for each probe
    distance : array
        Distance between cells for each probe
    plots : list of PlotWidget | PlotItem (optional)
        Two plots used to display distance profile and scatter plot.
    color : tuple
        (R, G, B) color values for line and confidence interval. The
        confidence interval will be drawn with reduced opacity (see
        *fill_alpha*)
    size: int
        size of scatter plot symbol
    window : float
        Width of distance window over which proportions are calculated for
        each point on the profile line.
    fill_alpha : int
        Opacity of confidence interval fill (0-255)

    Note: using a spacing value that is smaller than the window size may
    cause an otherwise smooth decrease over distance to instead look more
    like a series of downward steps.
    """
    # normalize inputs: RGB 3-tuple, float arrays
    color = pg.colorTuple(pg.mkColor(color))[:3]
    connected = np.array(connected).astype(float)
    distance = np.array(distance)

    # scatter plot connections probed
    if plots is None:
        grid = PlotGrid()
        grid.set_shape(2, 1)
        grid.grid.ci.layout.setRowStretchFactor(0, 5)
        grid.grid.ci.layout.setRowStretchFactor(1, 10)
        plots = (grid[1, 0], grid[0, 0])
        plots[0].grid = grid
        plots[0].addLegend()
        grid.show()
    plots[0].setLabels(bottom=('distance', 'm'), left='connection probability')
    if plots[1] is not None:
        # scatter points a bit
        pts = np.vstack([distance, connected]).T
        conn = pts[:, 1] == 1
        unconn = pts[:, 1] == 0
        # jitter connected points below zero, unconnected above, so the two
        # populations are visually separated
        if np.any(conn):
            cscat = pg.pseudoScatter(pts[:, 0][conn], spacing=10e-6, bidir=False)
            mx = abs(cscat).max()
            if mx != 0:
                cscat = cscat * 0.2  # / mx
            pts[:, 1][conn] = -5e-5 - cscat
        if np.any(unconn):
            uscat = pg.pseudoScatter(pts[:, 0][unconn], spacing=10e-6, bidir=False)
            mx = abs(uscat).max()
            if mx != 0:
                uscat = uscat * 0.2  # / mx
            pts[:, 1][unconn] = uscat
        plots[1].setXLink(plots[0])
        plots[1].hideAxis('bottom')
        plots[1].hideAxis('left')
        color2 = color + (100, )
        scatter = plots[1].plot(pts[:, 0], pts[:, 1], pen=None, symbol='o', labels={'bottom': ('distance', 'm')}, size=size, symbolBrush=color2, symbolPen=None, name=name)
        # additive composition makes overlapping points brighter
        scatter.scatter.opts['compositionMode'] = pg.QtGui.QPainter.CompositionMode_Plus

    # use a sliding window to plot the proportion of connections found along
    # with a 95% confidence interval for connection probability
    bin_edges = np.arange(0, 500e-6, window)
    xvals, prop, lower, upper = connectivity_profile(connected, distance, bin_edges)

    # plot connection probability and confidence intervals
    color2 = [c / 3.0 for c in color]
    # convert bin edges to bin centers for plotting
    xvals = (xvals[:-1] + xvals[1:]) * 0.5
    mid_curve = plots[0].plot(xvals, prop, pen={'color': color, 'width': 3}, antialias=True, name=name)
    # invisible boundary curves carry the CI band for FillBetweenItem
    upper_curve = plots[0].plot(xvals, upper, pen=(0, 0, 0, 0), antialias=True)
    lower_curve = plots[0].plot(xvals, lower, pen=(0, 0, 0, 0), antialias=True)
    upper_curve.setVisible(False)
    lower_curve.setVisible(False)
    color2 = color + (fill_alpha, )
    fill = pg.FillBetweenItem(upper_curve, lower_curve, brush=color2)
    fill.setZValue(-10)
    plots[0].addItem(fill, ignoreBounds=True)
    return plots, xvals, prop, upper, lower
# Script fragment: parameter and plot-grid setup for train-response analysis.
# NOTE(review): `sign` and `connection_types` are defined earlier in the file
# (not visible here).
qc_params = (sign, [1e-3])
holding = [-55, -75]                   # holding potentials to include
freqs = [10, 20, 50, 100]              # induction frequencies
rec_t = [250, 500, 1000, 2000, 4000]   # recovery delays
sweep_threshold = 3                    # minimum sweeps required per condition
deconv = True                          # also build deconvolution plot grids
# cache_file = 'train_response_cache.pkl'
# response_cache = load_cache(cache_file)
# cache_change = []
# log_rec_plt = pg.plot()
# log_rec_plt.setLogMode(x=True)
qc_plot = pg.plot()
# induction plots: one row per frequency, one column per connection type
ind_plot = PlotGrid()
ind_plot.set_shape(4, len(connection_types))
ind_plot.show()
# recovery plots: one row per recovery delay
rec_plot = PlotGrid()
rec_plot.set_shape(5, len(connection_types))
rec_plot.show()
if deconv is True:
    deconv_ind_plot = PlotGrid()
    deconv_ind_plot.set_shape(4, len(connection_types))
    deconv_ind_plot.show()
    deconv_rec_plot = PlotGrid()
    deconv_rec_plot.set_shape(5,len(connection_types))
    deconv_rec_plot.show()
summary_plot = PlotGrid()
summary_plot.set_shape(len(connection_types), 2)
summary_plot.show()
symbols = ['o', 's', 'd', '+', 't']
def summary_plot_pulse(feature_list, labels, titles, i, median=False, grand_trace=None, plot=None, color=None, name=None):
    """ Plots features of single-pulse responses such as amplitude, latency, etc. for group analysis. Can be used for
    one group but ideal for comparing across many groups in the feature_list

    Parameters
    ----------
    feature_list : list of lists of floats
        single-pulse features such as amplitude. Can be multiple features each a list themselves
    labels : list of pyqtgraph.LabelItem
        axis labels, must be a list of same length as feature_list
    titles : list of strings
        plot title, must be a list of same length as feature_list
    i : integer
        iterator to place groups along x-axis
    median : boolean
        to calculate median (True) vs mean (False), default is False
    grand_trace : neuroanalysis.data.TraceView object
        option to plot response trace alongside scatter plot, default is None
    plot : pyqtgraph.PlotItem
        If not None, plot the data on the referenced pyqtgraph object.
    color : tuple
        plot color
    name : pyqtgraph.LegendItem

    Returns
    -------
    plot : pyqtgraph.PlotItem
        2 x n plot with scatter plot and optional trace response plot for each feature (n)
    """
    # a tuple of feature lists means multiple features; anything else is one
    if type(feature_list) is tuple:
        n_features = len(feature_list)
    else:
        n_features = 1
    if plot is None:
        plot = PlotGrid()
        plot.set_shape(n_features, 2)
        plot.show()
        for g in range(n_features):
            plot[g, 1].addLegend()
    for feature in range(n_features):
        if n_features > 1:
            current_feature = feature_list[feature]
            label = labels[feature]
            title = titles[feature]
        else:
            current_feature = feature_list
            label = labels
            title = titles
        # BUG FIX: the single-feature path previously ignored *median* and
        # always used np.nanmean; compute the center statistic uniformly.
        if median is True:
            center = np.nanmedian(current_feature)
        else:
            center = np.nanmean(current_feature)
        plot[feature, 0].setLabels(left=(label[0], label[1]))
        plot[feature, 0].hideAxis('bottom')
        plot[feature, 0].setTitle(title)
        if grand_trace is not None:
            plot[feature, 1].plot(grand_trace.time_values, grand_trace.data, pen=color, name=name)
        if len(current_feature) > 1:
            # jitter the points horizontally so overlapping values are visible
            dx = pg.pseudoScatter(np.array(current_feature).astype(float), 0.7, bidir=True)
            # large marker for the group's center statistic
            plot[feature, 0].plot([i], [center], symbol='o', symbolSize=20, symbolPen='k', symbolBrush=color)
            # normalize color to an RGB tuple so an alpha can be appended
            if len(color) != 3:
                new_color = pg.glColor(color)
                color = (new_color[0]*255, new_color[1]*255, new_color[2]*255)
            # BUG FIX: guard against pseudoScatter returning all zeros (0/0)
            spread = dx.max() if dx.max() != 0 else 1.0
            plot[feature, 0].plot((0.3 * dx / spread) + i, current_feature, pen=None, symbol='o', symbolSize=10, symbolPen='w', symbolBrush=(color[0], color[1], color[2], 100))
        else:
            plot[feature, 0].plot([i], current_feature, pen=None, symbol='o', symbolSize=10, symbolPen='w', symbolBrush=color)
    return plot
def save_fit_psp_test_set():
    """NOTE THIS CODE DOES NOT WORK BUT IS HERE FOR DOCUMENTATION PURPOSES SO
    THAT WE CAN TRACE BACK HOW THE TEST DATA WAS CREATED IF NEEDED.

    Create a test set of data for testing the fit_psp function. Uses Steph's
    original first_puls_feature.py code to filter out error causing data.

    Example run statement:
    python save save_fit_psp_test_set.py --organism mouse --connection ee

    Comment in the code that does the saving at the bottom.
    """
    # Function-local imports: this archival function was lifted from a
    # standalone script, so its whole preamble lives inside the body.
    import pyqtgraph as pg
    import numpy as np
    import csv
    import sys
    import argparse
    from multipatch_analysis.experiment_list import cached_experiments
    from manuscript_figures import get_response, get_amplitude, response_filter, feature_anova, write_cache, trace_plot, \
        colors_human, colors_mouse, fail_rate, pulse_qc, feature_kw
    from synapse_comparison import load_cache, summary_plot_pulse
    from neuroanalysis.data import TraceList, Trace
    from neuroanalysis.ui.plot_grid import PlotGrid
    from multipatch_analysis.connection_detection import fit_psp
    from rep_connections import ee_connections, human_connections, no_include, all_connections, ie_connections, ii_connections, ei_connections
    from multipatch_analysis.synaptic_dynamics import DynamicsAnalyzer
    from scipy import stats
    import time
    import pandas as pd
    import json
    import os

    app = pg.mkQApp()
    pg.dbg()
    pg.setConfigOption('background', 'w')
    pg.setConfigOption('foreground', 'k')

    parser = argparse.ArgumentParser(
        description=
        'Enter organism and type of connection you"d like to analyze ex: mouse ee (all mouse excitatory-'
        'excitatory). Alternatively enter a cre-type connection ex: sim1-sim1')
    parser.add_argument('--organism', dest='organism', help='Select mouse or human')
    parser.add_argument('--connection', dest='connection', help='Specify connections to analyze')
    args = vars(parser.parse_args(sys.argv[1:]))

    all_expts = cached_experiments()
    # accumulators for per-connection fit features (never filled below)
    manifest = {
        'Type': [],
        'Connection': [],
        'amp': [],
        'latency': [],
        'rise': [],
        'rise2080': [],
        'rise1090': [],
        'rise1080': [],
        'decay': [],
        'nrmse': [],
        'CV': []
    }
    # QC limits applied to fit results
    fit_qc = {'nrmse': 8, 'decay': 499e-3}

    # Resolve the requested connection types and analysis thresholds
    if args['organism'] == 'mouse':
        color_palette = colors_mouse
        calcium = 'high'
        age = '40-60'
        sweep_threshold = 3
        threshold = 0.03e-3
        connection = args['connection']
        if connection == 'ee':
            connection_types = ee_connections.keys()
        elif connection == 'ii':
            connection_types = ii_connections.keys()
        elif connection == 'ei':
            connection_types = ei_connections.keys()
        elif connection == 'ie':
            # NOTE(review): '==' here is a no-op comparison; '=' was almost
            # certainly intended — connection_types is left undefined for 'ie'.
            connection_types == ie_connections.keys()
        elif connection == 'all':
            connection_types = all_connections.keys()
        elif len(connection.split('-')) == 2:
            c_type = connection.split('-')
            if c_type[0] == '2/3':
                pre_type = ('2/3', 'unknown')
            else:
                pre_type = (None, c_type[0])
            if c_type[1] == '2/3':
                post_type = ('2/3', 'unknown')
            else:
                # NOTE(review): c_type[0] looks like a copy/paste slip;
                # c_type[1] (the post side) was presumably intended.
                post_type = (None, c_type[0])
            connection_types = [(pre_type, post_type)]
    elif args['organism'] == 'human':
        color_palette = colors_human
        calcium = None
        age = None
        sweep_threshold = 5
        threshold = None
        connection = args['connection']
        if connection == 'ee':
            connection_types = human_connections.keys()
        else:
            c_type = connection.split('-')
            connection_types = [((c_type[0], 'unknown'), (c_type[1], 'unknown'))]

    # Plot scaffolding
    plt = pg.plot()
    scale_offset = (-20, -20)
    scale_anchor = (0.4, 1)
    holding = [-65, -75]
    qc_plot = pg.plot()
    grand_response = {}
    expt_ids = {}
    feature_plot = None
    feature2_plot = PlotGrid()
    feature2_plot.set_shape(5, 1)
    feature2_plot.show()
    feature3_plot = PlotGrid()
    feature3_plot.set_shape(1, 3)
    feature3_plot.show()
    amp_plot = pg.plot()
    synapse_plot = PlotGrid()
    synapse_plot.set_shape(len(connection_types), 1)
    synapse_plot.show()

    # For each connection type, collect qualifying responses and fit PSPs
    for c in range(len(connection_types)):
        cre_type = (connection_types[c][0][1], connection_types[c][1][1])
        target_layer = (connection_types[c][0][0], connection_types[c][1][0])
        conn_type = connection_types[c]
        expt_list = all_expts.select(cre_type=cre_type, target_layer=target_layer, calcium=calcium, age=age)
        color = color_palette[c]
        grand_response[conn_type[0]] = {
            'trace': [],
            'amp': [],
            'latency': [],
            'rise': [],
            'dist': [],
            'decay': [],
            'CV': [],
            'amp_measured': []
        }
        expt_ids[conn_type[0]] = []
        synapse_plot[c, 0].addLegend()
        for expt in expt_list:
            for pre, post in expt.connections:
                if [expt.uid, pre, post] in no_include:
                    continue
                # both cells must match the requested cre types and layers
                cre_check = expt.cells[pre].cre_type == cre_type[0] and expt.cells[post].cre_type == cre_type[1]
                layer_check = expt.cells[pre].target_layer == target_layer[0] and expt.cells[post].target_layer == target_layer[1]
                if cre_check is True and layer_check is True:
                    pulse_response, artifact = get_response(expt, pre, post, analysis_type='pulse')
                    # skip pairs with excessive stimulus artifact
                    if threshold is not None and artifact > threshold:
                        continue
                    response_subset, hold = response_filter(pulse_response, freq_range=[0, 50], holding_range=holding, pulse=True)
                    if len(response_subset) >= sweep_threshold:
                        qc_plot.clear()
                        qc_list = pulse_qc(response_subset, baseline=1.5, pulse=None, plot=qc_plot)
                        if len(qc_list) >= sweep_threshold:
                            avg_trace, avg_amp, amp_sign, peak_t = get_amplitude(qc_list)
                            # if amp_sign is '-':
                            #     continue
                            # #print ('%s, %0.0f' %((expt.uid, pre, post), hold, ))
                            # all_amps = fail_rate(response_subset, '+', peak_t)
                            # cv = np.std(all_amps)/np.mean(all_amps)
                            #
                            # # weight parts of the trace during fitting
                            dt = avg_trace.dt
                            weight = np.ones(len(avg_trace.data)) * 10.  #set everything to ten initially
                            weight[int(10e-3 / dt):int(12e-3 / dt)] = 0.  #area around stim artifact
                            weight[int(12e-3 / dt):int(19e-3 / dt)] = 30.  #area around steep PSP rise

                            # check if the test data dir is there and if not create it
                            test_data_dir = 'test_psp_fit'
                            if not os.path.isdir(test_data_dir):
                                os.mkdir(test_data_dir)

                            # record every fit input so the fit can be replayed exactly
                            save_dict = {}
                            save_dict['input'] = {
                                'data': avg_trace.data.tolist(),
                                'dtype': str(avg_trace.data.dtype),
                                'dt': float(avg_trace.dt),
                                'amp_sign': amp_sign,
                                'yoffset': 0,
                                'xoffset': 14e-3,
                                'avg_amp': float(avg_amp),
                                'method': 'leastsq',
                                'stacked': False,
                                'rise_time_mult_factor': 10.,
                                'weight': weight.tolist()
                            }

                            # need to remake trace because different output is created
                            avg_trace_simple = Trace(
                                data=np.array(save_dict['input']['data']),
                                dt=save_dict['input']['dt'])  # create Trace object

                            # fit the original trace and the round-tripped copy;
                            # both should yield identical results
                            psp_fits_original = fit_psp(
                                avg_trace,
                                sign=save_dict['input']['amp_sign'],
                                yoffset=save_dict['input']['yoffset'],
                                xoffset=save_dict['input']['xoffset'],
                                amp=save_dict['input']['avg_amp'],
                                method=save_dict['input']['method'],
                                stacked=save_dict['input']['stacked'],
                                rise_time_mult_factor=save_dict['input']['rise_time_mult_factor'],
                                fit_kws={'weights': save_dict['input']['weight']})
                            psp_fits_simple = fit_psp(
                                avg_trace_simple,
                                sign=save_dict['input']['amp_sign'],
                                yoffset=save_dict['input']['yoffset'],
                                xoffset=save_dict['input']['xoffset'],
                                amp=save_dict['input']['avg_amp'],
                                method=save_dict['input']['method'],
                                stacked=save_dict['input']['stacked'],
                                rise_time_mult_factor=save_dict['input']['rise_time_mult_factor'],
                                fit_kws={'weights': save_dict['input']['weight']})
                            print expt.uid, pre, post

                            # sanity check: a mismatch means the round trip altered the data
                            if psp_fits_original.nrmse() != psp_fits_simple.nrmse():
                                print ' the nrmse values dont match'
                                print '\toriginal', psp_fits_original.nrmse()
                                print '\tsimple', psp_fits_simple.nrmse()
parser = argparse.ArgumentParser() # plot options parser.add_argument('--sweeps', action='store_true', default=False, dest='sweeps', help='plot individual sweeps behing average') parser.add_argument('--trains', action='store_true', default=False, dest='trains', help='plot 50Hz train and deconvolution') parser.add_argument('--link-y-axis', action='store_true', default=False, dest='link-y-axis', help='link all y-axis down a column') args = vars(parser.parse_args(sys.argv[1:])) plot_sweeps = args['sweeps'] plot_trains = args['trains'] link_y_axis = args['link-y-axis'] expt_cache = 'C:/Users/Stephanies/multipatch_analysis/tools/expts_cache.pkl' all_expts = cached_experiments() test = PlotGrid() test.set_shape(len(connection_types.keys()), 1) test.show() grid = PlotGrid() if plot_trains is True: grid.set_shape(len(connection_types.keys()), 3) grid[0, 1].setTitle(title='50 Hz Train') grid[0, 2].setTitle(title='Exponential Deconvolution') tau = 15e-3 lp = 1000 else: grid.set_shape(len(connection_types.keys()), 2) grid.show() row = 0 grid[0, 0].setTitle(title='First Pulse') maxYpulse = []
def summary_plot_pulse(feature_list, labels, titles, i, median=False, grand_trace=None, plot=None, color=None, name=None):
    """ Plots features of single-pulse responses such as amplitude, latency, etc. for group analysis. Can be used for
    one group but ideal for comparing across many groups in the feature_list

    Parameters
    ----------
    feature_list : list of lists of floats
        single-pulse features such as amplitude. Can be multiple features each a list themselves
    labels : list of pyqtgraph.LabelItem
        axis labels, must be a list of same length as feature_list
    titles : list of strings
        plot title, must be a list of same length as feature_list
    i : integer
        iterator to place groups along x-axis
    median : boolean
        to calculate median (True) vs mean (False), default is False
    grand_trace : neuroanalysis.data.TSeriesView object
        option to plot response trace alongside scatter plot, default is None
    plot : pyqtgraph.PlotItem
        If not None, plot the data on the referenced pyqtgraph object.
    color : tuple
        plot color
    name : pyqtgraph.LegendItem

    Returns
    -------
    plot : pyqtgraph.PlotItem
        2 x n plot with scatter plot and optional trace response plot for each feature (n)
    """
    # a tuple of feature lists means multiple features; anything else is one
    if type(feature_list) is tuple:
        n_features = len(feature_list)
    else:
        n_features = 1
    if plot is None:
        plot = PlotGrid()
        plot.set_shape(n_features, 2)
        plot.show()
        for g in range(n_features):
            plot[g, 1].addLegend()
    for feature in range(n_features):
        if n_features > 1:
            current_feature = feature_list[feature]
            label = labels[feature]
            title = titles[feature]
        else:
            current_feature = feature_list
            label = labels
            title = titles
        # BUG FIX: the single-feature path previously ignored *median* and
        # always used np.nanmean; compute the center statistic uniformly.
        if median is True:
            center = np.nanmedian(current_feature)
        else:
            center = np.nanmean(current_feature)
        plot[feature, 0].setLabels(left=(label[0], label[1]))
        plot[feature, 0].hideAxis('bottom')
        plot[feature, 0].setTitle(title)
        if grand_trace is not None:
            plot[feature, 1].plot(grand_trace.time_values, grand_trace.data, pen=color, name=name)
        if len(current_feature) > 1:
            # jitter the points horizontally so overlapping values are visible
            dx = pg.pseudoScatter(np.array(current_feature).astype(float), 0.7, bidir=True)
            # large marker for the group's center statistic
            plot[feature, 0].plot([i], [center], symbol='o', symbolSize=20, symbolPen='k', symbolBrush=color)
            # normalize color to an RGB tuple so an alpha can be appended
            if len(color) != 3:
                new_color = pg.glColor(color)
                color = (new_color[0] * 255, new_color[1] * 255, new_color[2] * 255)
            # BUG FIX: guard against pseudoScatter returning all zeros (0/0)
            spread = dx.max() if dx.max() != 0 else 1.0
            plot[feature, 0].plot(
                (0.3 * dx / spread) + i,
                current_feature,
                pen=None,
                symbol='o',
                symbolSize=10,
                symbolPen='w',
                symbolBrush=(color[0], color[1], color[2], 100))
        else:
            plot[feature, 0].plot([i], current_feature, pen=None, symbol='o', symbolSize=10, symbolPen='w', symbolBrush=color)
    return plot
#pulses = analyzer.get_evoked_responses(pre_id, post_id, clamp_mode='ic', pulse_ids=[0]) analyzer = DynamicsAnalyzer(expt, pre_id, post_id, align_to='spike') # collect all first pulse responses responses = analyzer.amp_group # collect all events #responses = analyzer.all_events n_responses = len(responses) # do exponential deconvolution on all responses deconv = TraceList() grid1 = PlotGrid() grid1.set_shape(2, 1) for i in range(n_responses): r = responses.responses[i] grid1[0, 0].plot(r.time_values, r.data) filt = bessel_filter(r - np.median(r.time_slice(0, 10e-3).data), 300.) responses.responses[i] = filt dec = exp_deconvolve(r, 15e-3) baseline = np.median(dec.data[:100]) r2 = bessel_filter(dec-baseline, 300.) grid1[1, 0].plot(r2.time_values, r2.data) deconv.append(r2) grid1.show()
def distance_plot(connected, distance, plots=None, color=(100, 100, 255), size=10, window=40e-6, spacing=None, name=None, fill_alpha=30):
    """Draw connectivity vs distance profiles with confidence intervals.

    Parameters
    ----------
    connected : boolean array
        Whether a synaptic connection was found for each probe
    distance : array
        Distance between cells for each probe
    plots : list of PlotWidget | PlotItem (optional)
        Two plots used to display distance profile and scatter plot.
    color : tuple
        (R, G, B) color values for line and confidence interval. The
        confidence interval will be drawn with alpha=100
    size: int
        size of scatter plot symbol
    window : float
        Width of distance window over which proportions are calculated for
        each point on the profile line.
    spacing : float
        Distance spacing between points on the profile line

    Note: using a spacing value that is smaller than the window size may
    cause an otherwise smooth decrease over distance to instead look more
    like a series of downward steps.
    """
    # normalize inputs: RGB 3-tuple, float arrays
    color = pg.colorTuple(pg.mkColor(color))[:3]
    connected = np.array(connected).astype(float)
    distance = np.array(distance)
    pts = np.vstack([distance, connected]).T

    # scatter points a bit: connected points are jittered below zero,
    # unconnected above, so the two populations are visually separated
    conn = pts[:,1] == 1
    unconn = pts[:,1] == 0
    if np.any(conn):
        cscat = pg.pseudoScatter(pts[:,0][conn], spacing=10e-6, bidir=False)
        mx = abs(cscat).max()
        if mx != 0:
            cscat = cscat * 0.2# / mx
        pts[:,1][conn] = -5e-5 - cscat
    if np.any(unconn):
        uscat = pg.pseudoScatter(pts[:,0][unconn], spacing=10e-6, bidir=False)
        mx = abs(uscat).max()
        if mx != 0:
            uscat = uscat * 0.2# / mx
        pts[:,1][unconn] = uscat

    # scatter plot connections probed
    if plots is None:
        grid = PlotGrid()
        grid.set_shape(2, 1)
        grid.grid.ci.layout.setRowStretchFactor(0, 5)
        grid.grid.ci.layout.setRowStretchFactor(1, 10)
        plots = (grid[1,0], grid[0,0])
        plots[0].grid = grid
        plots[0].addLegend()
        grid.show()
    plots[0].setLabels(bottom=('distance', 'm'), left='connection probability')
    if plots[1] is not None:
        plots[1].setXLink(plots[0])
        plots[1].hideAxis('bottom')
        plots[1].hideAxis('left')
        color2 = color + (100,)
        scatter = plots[1].plot(pts[:,0], pts[:,1], pen=None, symbol='o', labels={'bottom': ('distance', 'm')}, size=size, symbolBrush=color2, symbolPen=None, name=name)
        # additive composition makes overlapping points brighter
        scatter.scatter.opts['compositionMode'] = pg.QtGui.QPainter.CompositionMode_Plus

    # use a sliding window to plot the proportion of connections found along
    # with a 95% confidence interval for connection probability
    if spacing is None:
        spacing = window / 4.0

    xvals = np.arange(window / 2.0, 500e-6, spacing)
    upper = []
    lower = []
    prop = []
    ci_xvals = []
    for x in xvals:
        minx = x - window / 2.0
        maxx = x + window / 2.0

        # select points inside this window
        mask = (distance >= minx) & (distance <= maxx)
        pts_in_window = connected[mask]
        # compute stats for window
        n_probed = pts_in_window.shape[0]
        n_conn = pts_in_window.sum()
        if n_probed == 0:
            # no CI entry for empty windows — ci_xvals tracks which x got one
            prop.append(np.nan)
        else:
            prop.append(n_conn / n_probed)
            ci = proportion_confint(n_conn, n_probed, method='beta')
            lower.append(ci[0])
            upper.append(ci[1])
            ci_xvals.append(x)

    # plot connection probability and confidence intervals
    color2 = [c / 3.0 for c in color]
    mid_curve = plots[0].plot(xvals, prop, pen={'color': color, 'width': 3}, antialias=True, name=name)
    # invisible boundary curves carry the CI band for FillBetweenItem
    upper_curve = plots[0].plot(ci_xvals, upper, pen=(0, 0, 0, 0), antialias=True)
    lower_curve = plots[0].plot(ci_xvals, lower, pen=(0, 0, 0, 0), antialias=True)
    upper_curve.setVisible(False)
    lower_curve.setVisible(False)
    color2 = color + (fill_alpha,)
    fill = pg.FillBetweenItem(upper_curve, lower_curve, brush=color2)
    fill.setZValue(-10)
    plots[0].addItem(fill, ignoreBounds=True)

    return plots, ci_xvals, prop, upper, lower
def plot_response_averages(expt, show_baseline=False, **kwds):
    """Plot a matrix of average evoked responses for all pre/post device
    pairs in *expt*; returns the PlotGrid.

    NOTE(review): the PSP-fit section is commented out, so `points` stays
    empty and the SNR/NRMSE scatter at the end is always blank. Also, if no
    pair yields an average response, `ranges` stays empty and the min()/max()
    calls below will raise — confirm whether that case can occur.
    """
    analyzer = MultiPatchExperimentAnalyzer.get(expt)
    devs = analyzer.list_devs()

    # First get average evoked responses for all pre/post pairs
    responses, rows, cols = analyzer.get_evoked_response_matrix(**kwds)

    # resize plot grid accordingly
    plots = PlotGrid()
    plots.set_shape(len(rows), len(cols))
    plots.show()

    ranges = [([], []), ([], [])]   # ((ymins, ymaxs), (xmins, xmaxs))
    points = []

    # Plot each matrix element with PSP fit
    for i, dev1 in enumerate(rows):
        for j, dev2 in enumerate(cols):
            # select plot and hide axes
            plt = plots[i, j]
            if i < len(devs) - 1:
                plt.getAxis('bottom').setVisible(False)
            if j > 0:
                plt.getAxis('left').setVisible(False)

            # diagonal cells (same device pre and post) are left blank
            if dev1 == dev2:
                plt.getAxis('bottom').setVisible(False)
                plt.getAxis('left').setVisible(False)
                continue

            # adjust axes / labels
            plt.setXLink(plots[0, 0])
            plt.setYLink(plots[0, 0])
            plt.addLine(x=10e-3, pen=0.3)
            plt.addLine(y=0, pen=0.3)
            plt.setLabels(bottom=(str(dev2), 's'))
            if kwds.get('clamp_mode', 'ic') == 'ic':
                plt.setLabels(left=('%s' % dev1, 'V'))
            else:
                plt.setLabels(left=('%s' % dev1, 'A'))

            # print "==========", dev1, dev2
            avg_response = responses[(dev1, dev2)].bsub_mean()
            if avg_response is not None:
                avg_response.t0 = 0
                t = avg_response.time_values
                # lowpass at 2 kHz before display
                y = bessel_filter(Trace(avg_response.data, dt=avg_response.dt), 2e3).data
                plt.plot(t, y, antialias=True)

                # fit!
                #fit = responses[(dev1, dev2)].fit_psp(yoffset=0, mask_stim_artifact=(abs(dev1-dev2) < 3))
                #lsnr = np.log(fit.snr)
                #lerr = np.log(fit.nrmse())
                #color = (
                    #np.clip(255 * (-lerr/3.), 0, 255),
                    #np.clip(50 * lsnr, 0, 255),
                    #np.clip(255 * (1+lerr/3.), 0, 255)
                #)
                #plt.plot(t, fit.best_fit, pen=color)
                ## plt.plot(t, fit.init_fit, pen='y')
                #points.append({'x': lerr, 'y': lsnr, 'brush': color})

                #if show_baseline:
                    ## plot baseline for reference
                    #bl = avg_response.meta['baseline'] - avg_response.meta['baseline_med']
                    #plt.plot(np.arange(len(bl)) * avg_response.dt, bl, pen=(0, 100, 0), antialias=True)

                # keep track of data range across all plots
                ranges[0][0].append(y.min())
                ranges[0][1].append(y.max())
                ranges[1][0].append(t[0])
                ranges[1][1].append(t[-1])

    plots[0,0].setYRange(min(ranges[0][0]), max(ranges[0][1]))
    plots[0,0].setXRange(min(ranges[1][0]), max(ranges[1][1]))

    # scatter plot of SNR vs NRMSE
    plt = pg.plot()
    plt.setLabels(left='ln(SNR)', bottom='ln(NRMSE)')
    plt.plot([p['x'] for p in points], [p['y'] for p in points], pen=None, symbol='o', symbolBrush=[pg.mkBrush(p['brush']) for p in points])

    # show threshold line
    line = pg.InfiniteLine(pos=[0, 6], angle=180/np.pi * np.arctan(1))
    plt.addItem(line, ignoreBounds=True)

    return plots
def plot_train_responses(self, plot_grid=None):
    """
    Plot individual and averaged train responses for each set of stimulus parameters.

    One grid row per stimulus parameter set; column 0 shows the induction
    period, column 1 the recovery period.  Individual baseline-subtracted
    sweeps are drawn in grey, the baseline-subtracted average in green.

    Parameters
    ----------
    plot_grid : PlotGrid | None
        Existing grid to draw into; a new PlotGrid is created when None.

    Return a new PlotGrid.
    """
    train_responses = self.train_responses

    if plot_grid is None:
        train_plots = PlotGrid()
    else:
        train_plots = plot_grid
    train_plots.set_shape(len(self.stim_param_order), 2)

    for i, stim_params in enumerate(train_responses.keys()):
        # Collect and plot average traces covering the induction and recovery
        # periods for this set of stim params
        ind_group = train_responses[stim_params][0]
        rec_group = train_responses[stim_params][1]

        # Individual sweeps, baseline-subtracted using the median of each
        # sweep's own baseline region
        for j in range(len(ind_group)):
            ind = ind_group.responses[j]
            rec = rec_group.responses[j]
            base = np.median(ind_group.baselines[j].data)
            train_plots[i, 0].plot(ind.time_values, ind.data - base, pen=(128, 128, 128, 100))
            train_plots[i, 1].plot(rec.time_values, rec.data - base, pen=(128, 128, 128, 100))

        # Baseline-subtracted group averages, drawn on top in green
        ind_avg = ind_group.bsub_mean()
        rec_avg = rec_group.bsub_mean()

        # stim_params is (induction frequency, recovery delay, holding potential)
        ind_freq, rec_delay, holding = stim_params
        rec_delay = np.round(rec_delay, 2)
        train_plots[i, 0].plot(ind_avg.time_values, ind_avg.data, pen='g', antialias=True)
        train_plots[i, 1].plot(rec_avg.time_values, rec_avg.data, pen='g', antialias=True)
        train_plots[i, 0].setLabels(left=('Vm', 'V'))
        # Label each row with its stimulus parameters (delay/holding shown x1000)
        label = pg.LabelItem("ind: %0.0f rec: %0.0f hold: %0.0f"
                             % (ind_freq, rec_delay * 1000, holding * 1000))
        label.setParentItem(train_plots[i, 0].vb)
        train_plots[i, 0].label = label

    train_plots.show()
    # Share y range across all plots; x ranges are shared per-column
    train_plots.setYLink(train_plots[0, 0])
    for i in range(train_plots.shape[0]):
        train_plots[i, 0].setXLink(train_plots[0, 0])
        train_plots[i, 1].setXLink(train_plots[0, 1])
    # Induction column gets more horizontal space than recovery column
    train_plots.grid.ci.layout.setColumnStretchFactor(0, 3)
    train_plots.grid.ci.layout.setColumnStretchFactor(1, 2)
    train_plots.setClipToView(False)  # has a bug :(
    train_plots.setDownsampling(True, True, 'peak')
    return train_plots
app = pg.mkQApp() pg.dbg() all_expts = cached_experiments() # results = pulse_average_matrix(all_expts, clamp_mode='ic', min_duration=25e-3, pulse_ids=[0, 8]) # 1. collect a list of all experiments containing connections conn_expts = {} for c in all_expts.connection_summary(): conn_expts.setdefault(c['expt'], []).append(c['cells']) # 2. for each experiment, get and cache the full set of first-pulse averages types = ['sim1', 'tlx3', 'pvalb', 'sst', 'vip'] indplots = PlotGrid() indplots.set_shape(len(types), len(types)) indplots.show() cachefile = "first_pulse_average.pkl" cache = pickle.load(open(cachefile, 'rb')) if os.path.isfile(cachefile) else {} for expt, conns in sorted(conn_expts.items()): if expt.source_id not in cache: print "Load:", expt.source_id try: with expt.data: analyzer = MultiPatchExperimentAnalyzer.get(expt.data) responses = [] for pre_cell, post_cell in conns: pre_id, post_id = pre_cell.cell_id-1, post_cell.cell_id-1 resp = analyzer.get_evoked_responses(pre_id, post_id,
class PairView(QtGui.QWidget):
    """For analyzing pre/post-synaptic pairs.

    Layout (vertical splitter): presynaptic trace plot, postsynaptic trace
    plot, a row of per-pulse averaged-response plots, and an event table /
    event-set manager.  Selecting sweeps + pre/post channels triggers pulse
    and spike detection, per-pulse response averaging, PSP fitting, and
    (optionally) release-model fitting over saved event sets.
    """

    def __init__(self, parent=None):
        # Current selection / analysis state
        self.sweeps = []                 # sweeps selected via data_selected()
        self.channels = []               # channels available for pre/post selection
        self.current_event_set = None    # (pre, post, events, sweeps) of last analysis
        self.event_sets = []             # event sets saved via the 'add' button

        QtGui.QWidget.__init__(self, parent)

        self.layout = QtGui.QGridLayout()
        self.setLayout(self.layout)
        self.layout.setContentsMargins(0, 0, 0, 0)

        self.vsplit = QtGui.QSplitter(QtCore.Qt.Vertical)
        self.layout.addWidget(self.vsplit, 0, 0)

        # Raw pre/post trace plots; x axes linked so they scroll together
        self.pre_plot = pg.PlotWidget()
        self.post_plot = pg.PlotWidget()
        for plt in (self.pre_plot, self.post_plot):
            plt.setClipToView(True)
            plt.setDownsampling(True, True, 'peak')
        self.post_plot.setXLink(self.pre_plot)
        self.vsplit.addWidget(self.pre_plot)
        self.vsplit.addWidget(self.post_plot)

        # One plot per stimulus pulse (filled in by _update_plots)
        self.response_plots = PlotGrid()
        self.vsplit.addWidget(self.response_plots)

        # Bottom area: event table on the left, event-set controls on the right
        self.event_splitter = QtGui.QSplitter(QtCore.Qt.Horizontal)
        self.vsplit.addWidget(self.event_splitter)

        self.event_table = pg.TableWidget()
        self.event_splitter.addWidget(self.event_table)

        self.event_widget = QtGui.QWidget()
        self.event_widget_layout = QtGui.QGridLayout()
        self.event_widget.setLayout(self.event_widget_layout)
        self.event_splitter.addWidget(self.event_widget)
        self.event_splitter.setSizes([600, 100])

        # List of saved event sets; row 0 is always the "current" analysis
        self.event_set_list = QtGui.QListWidget()
        self.event_widget_layout.addWidget(self.event_set_list, 0, 0, 1, 2)
        self.event_set_list.addItem("current")
        self.event_set_list.itemSelectionChanged.connect(self.event_set_selected)

        self.add_set_btn = QtGui.QPushButton('add')
        self.event_widget_layout.addWidget(self.add_set_btn, 1, 0, 1, 1)
        self.add_set_btn.clicked.connect(self.add_set_clicked)

        self.remove_set_btn = QtGui.QPushButton('del')
        self.event_widget_layout.addWidget(self.remove_set_btn, 1, 1, 1, 1)
        self.remove_set_btn.clicked.connect(self.remove_set_clicked)

        self.fit_btn = QtGui.QPushButton('fit all')
        self.event_widget_layout.addWidget(self.fit_btn, 2, 0, 1, 2)
        self.fit_btn.clicked.connect(self.fit_clicked)

        self.fit_plot = PlotGrid()
        self.fit_explorer = None

        # Signal-processing stages applied to the postsynaptic trace
        self.artifact_remover = ArtifactRemover(user_width=True)
        self.baseline_remover = BaselineRemover()
        self.filter = SignalFilter()

        # User-adjustable analysis parameters; any change re-runs the analysis
        self.params = pg.parametertree.Parameter(
            name='params', type='group', children=[{
                'name': 'pre',
                'type': 'list',
                'values': []
            }, {
                'name': 'post',
                'type': 'list',
                'values': []
            },
            self.artifact_remover.params,
            self.baseline_remover.params,
            self.filter.params, {
                'name': 'time constant',
                'type': 'float',
                'suffix': 's',
                'siPrefix': True,
                'value': 10e-3,
                'dec': True,
                'minStep': 100e-6
            }])
        self.params.sigTreeStateChanged.connect(self._update_plots)

    def data_selected(self, sweeps, channels):
        """Receive a new sweep/channel selection and re-run the analysis."""
        self.sweeps = sweeps
        self.channels = channels

        # Restrict the pre/post channel choices to the selected channels
        self.params.child('pre').setLimits(channels)
        self.params.child('post').setLimits(channels)

        self._update_plots()

    def _update_plots(self):
        """Re-run the full analysis: plot raw traces, detect pulses/spikes,
        average per-pulse responses, fit PSPs, and fill the event table.
        """
        sweeps = self.sweeps
        self.current_event_set = None
        self.event_table.clear()

        # clear all plots
        self.pre_plot.clear()
        self.post_plot.clear()

        pre = self.params['pre']
        post = self.params['post']

        # If there are no selected sweeps or channels have not been set, return
        if len(sweeps) == 0 or pre == post or pre not in self.channels or post not in self.channels:
            return

        # Label the raw plots with channel number and mode-appropriate units
        pre_mode = sweeps[0][pre].clamp_mode
        post_mode = sweeps[0][post].clamp_mode
        for ch, mode, plot in [(pre, pre_mode, self.pre_plot), (post, post_mode, self.post_plot)]:
            units = 'A' if mode == 'vc' else 'V'
            plot.setLabels(left=("Channel %d" % ch, units), bottom=("Time", 's'))

        # Iterate over selected channels of all sweeps, plotting traces one at a time
        # Collect information about pulses and spikes
        pulses = []       # per-sweep arrays of pulse onset indices
        spikes = []       # per-sweep lists of detected spike info (or None)
        post_traces = []  # per-sweep filtered postsynaptic traces
        for i, sweep in enumerate(sweeps):
            pre_trace = sweep[pre]['primary']
            post_trace = sweep[post]['primary']

            # Detect pulse times from rising/falling edges of the command signal
            stim = sweep[pre]['command'].data
            sdiff = np.diff(stim)
            on_times = np.argwhere(sdiff > 0)[1:, 0]  # 1: skips test pulse
            off_times = np.argwhere(sdiff < 0)[1:, 0]
            pulses.append(on_times)

            # filter data (artifact removal -> baseline removal -> user filter)
            post_filt = self.artifact_remover.process(
                post_trace, list(on_times) + list(off_times))
            post_filt = self.baseline_remover.process(post_filt)
            post_filt = self.filter.process(post_filt)
            post_traces.append(post_filt)

            # plot raw data, one semi-transparent color per sweep
            color = pg.intColor(i, hues=len(sweeps) * 1.3, sat=128)
            color.setAlpha(128)
            for trace, plot in [(pre_trace, self.pre_plot), (post_filt, self.post_plot)]:
                plot.plot(trace.time_values, trace.data, pen=color, antialias=False)

            # detect spike times (one entry per pulse; None if no spike found)
            spike_inds = []
            spike_info = []
            for on, off in zip(on_times, off_times):
                spike = detect_evoked_spike(sweep[pre], [on, off])
                spike_info.append(spike)
                if spike is None:
                    spike_inds.append(None)
                else:
                    spike_inds.append(spike['rise_index'])
            spikes.append(spike_info)

            # mark detected spike times on the presynaptic plot
            dt = pre_trace.dt
            vticks = pg.VTickGroup(
                [x * dt for x in spike_inds if x is not None],
                yrange=[0.0, 0.2], pen=color)
            self.pre_plot.addItem(vticks)

        # Iterate over spikes, plotting average response
        all_responses = []
        avg_responses = []
        fits = []
        fit = None

        npulses = max(map(len, pulses))
        self.response_plots.clear()
        self.response_plots.set_shape(1, npulses + 1)  # 1 extra for global average
        self.response_plots.setYLink(self.response_plots[0, 0])
        for i in range(1, npulses + 1):
            self.response_plots[0, i].hideAxis('left')
        units = 'A' if post_mode == 'vc' else 'V'
        self.response_plots[0, 0].setLabels(
            left=("Averaged events (Channel %d)" % post, units))

        fit_pen = {'color': (30, 30, 255), 'width': 2, 'dash': [1, 1]}
        for i in range(npulses):
            # get the chunk of each sweep between spikes
            responses = []
            all_responses.append(responses)
            for j, sweep in enumerate(sweeps):
                # get the current spike
                if i >= len(spikes[j]):
                    continue
                spike = spikes[j][i]
                if spike is None:
                    continue

                # find next spike
                next_spike = None
                for sp in spikes[j][i + 1:]:
                    if sp is not None:
                        next_spike = sp
                        break

                # determine time range for response
                # NOTE(review): comment below said 50ms but the code uses 40ms
                max_len = int(40e-3 / dt)  # don't take more than 40ms for any response
                start = spike['rise_index']
                if next_spike is not None:
                    stop = min(start + max_len, next_spike['rise_index'])
                else:
                    stop = start + max_len

                # collect data from this trace
                trace = post_traces[j]
                d = trace.data[start:stop].copy()
                responses.append(d)

            if len(responses) == 0:
                continue

            # extend all responses to the same length and take nanmean
            avg = ragged_mean(responses, method='clip')
            # subtract the mode of the first 1 ms as a baseline estimate
            avg -= float_mode(avg[:int(1e-3 / dt)])
            avg_responses.append(avg)

            # plot average response for this pulse
            # NOTE(review): this `start` (median spike time) is computed but
            # not used below; presumably intended as a plot offset.
            start = np.median(
                [sp[i]['rise_index'] for sp in spikes if sp[i] is not None]) * dt
            t = np.arange(len(avg)) * dt
            self.response_plots[0, i].plot(t, avg, pen='w', antialias=True)

            # fit!
            fit = self.fit_psp(avg, t, dt, post_mode)
            fits.append(fit)

            # let the user mess with this fit
            curve = self.response_plots[0, i].plot(
                t, fit.eval(), pen=fit_pen, antialias=True).curve
            curve.setClickable(True)
            curve.fit = fit
            curve.sigClicked.connect(self.fit_curve_clicked)

        # display global average (mean across all per-pulse averages)
        global_avg = ragged_mean(avg_responses, method='clip')
        t = np.arange(len(global_avg)) * dt
        self.response_plots[0, -1].plot(t, global_avg, pen='w', antialias=True)
        global_fit = self.fit_psp(global_avg, t, dt, post_mode)
        self.response_plots[0, -1].plot(
            t, global_fit.eval(), pen=fit_pen, antialias=True)

        # display fit parameters in table; last row summarizes the global fit
        events = []
        for i, f in enumerate(fits + [global_fit]):
            if f is None:
                continue
            if i >= len(fits):
                # global-average row has no meaningful spike timing
                vals = OrderedDict([('id', 'avg'), ('spike_time', np.nan),
                                    ('spike_stdev', np.nan)])
            else:
                spt = [
                    s[i]['peak_index'] * dt for s in spikes if s[i] is not None
                ]
                vals = OrderedDict([('id', i), ('spike_time', np.mean(spt)),
                                    ('spike_stdev', np.std(spt))])
            vals.update(
                OrderedDict([(k, f.best_values[k]) for k in f.params.keys()]))
            events.append(vals)

        self.current_event_set = (pre, post, events, sweeps)
        self.event_set_list.setCurrentRow(0)
        self.event_set_selected()

    def fit_psp(self, data, t, dt, clamp_mode):
        """Fit a PSP model to *data* sampled at times *t* (spacing *dt*).

        Initial values / bounds depend on *clamp_mode* ('vc' or 'ic').
        Returns the lmfit-style fit result from fitting.Psp().fit().
        """
        # Estimate polarity from the difference between the overall mean and
        # the mode of the first 1 ms (baseline)
        mode = float_mode(data[:int(1e-3 / dt)])
        sign = -1 if data.mean() - mode < 0 else 1
        # (value, min, max) tuples are fit bounds; scalars are fixed initial values
        params = OrderedDict([
            ('xoffset', (2e-3, 3e-4, 5e-3)),
            ('yoffset', data[0]),
            ('amp', sign * 10e-12),
            #('k', (2e-3, 50e-6, 10e-3)),
            ('rise_time', (2e-3, 50e-6, 10e-3)),
            ('decay_tau', (4e-3, 500e-6, 50e-3)),
            ('rise_power', (2.0, 'fixed')),
        ])
        if clamp_mode == 'ic':
            # current clamp: larger amplitude scale, slower kinetics
            params['amp'] = sign * 10e-3
            #params['k'] = (5e-3, 50e-6, 20e-3)
            params['rise_time'] = (5e-3, 50e-6, 20e-3)
            params['decay_tau'] = (15e-3, 500e-6, 150e-3)

        fit_kws = {'xtol': 1e-3, 'maxfev': 100}

        psp = fitting.Psp()
        return psp.fit(data, x=t, fit_kws=fit_kws, params=params)

    def add_set_clicked(self):
        """Save the current event set and add it to the event-set list."""
        if self.current_event_set is None:
            return
        ces = self.current_event_set
        self.event_sets.append(ces)
        item = QtGui.QListWidgetItem("%d -> %d" % (ces[0], ces[1]))
        self.event_set_list.addItem(item)
        item.event_set = ces

    def remove_set_clicked(self):
        """Remove the selected saved event set (row 0, 'current', is protected)."""
        if self.event_set_list.currentRow() == 0:
            return
        sel = self.event_set_list.takeItem(self.event_set_list.currentRow())
        self.event_sets.remove(sel.event_set)

    def event_set_selected(self):
        """Show the selected event set's events in the table."""
        # NOTE(review): raises IndexError if nothing is selected.
        sel = self.event_set_list.selectedItems()[0]
        if sel.text() == "current":
            self.event_table.setData(self.current_event_set[2])
        else:
            self.event_table.setData(sel.event_set[2])

    def fit_clicked(self):
        """Fit the synaptic release model to all saved event sets and plot
        one model curve per enabled dynamics mechanism.
        """
        from synaptic_release import ReleaseModel

        self.fit_plot.clear()
        self.fit_plot.show()
        n_sets = len(self.event_sets)
        self.fit_plot.set_shape(n_sets, 1)
        # Reset the legend on the first plot (stale legends persist otherwise)
        l = self.fit_plot[0, 0].legend
        if l is not None:
            l.scene.removeItem(l)
            self.fit_plot[0, 0].legend = None
        self.fit_plot[0, 0].addLegend()
        for i in range(n_sets):
            self.fit_plot[i, 0].setXLink(self.fit_plot[0, 0])

        # Build (x, y) spike trains: times normalized to the first spike and
        # converted to ms; amplitudes normalized to the first event
        spike_sets = []
        for i, evset in enumerate(self.event_sets):
            evset = evset[2][:-1]  # select events, skip last row (average)
            x = np.array([ev['spike_time'] for ev in evset])
            y = np.array([ev['amp'] for ev in evset])
            x -= x[0]
            x *= 1000
            y /= y[0]
            spike_sets.append((x, y))

            self.fit_plot[i, 0].plot(x / 1000., y, pen=None, symbol='o')

        model = ReleaseModel()
        dynamics_types = ['Dep', 'Fac', 'UR', 'SMR', 'DSR']
        for k in dynamics_types:
            model.Dynamics[k] = 0

        # NOTE(review): mechanisms are enabled cumulatively (each iteration
        # turns one more on without resetting the previous ones).
        fit_params = []
        with pg.ProgressDialog("Fitting release model..", 0, len(dynamics_types)) as dlg:
            for k in dynamics_types:
                model.Dynamics[k] = 1
                fit_params.append(model.run_fit(spike_sets))
                dlg += 1
                if dlg.wasCanceled():
                    return

        max_color = len(fit_params) * 1.5
        for i, params in enumerate(fit_params):
            for j, spikes in enumerate(spike_sets):
                x, y = spikes
                t = np.linspace(0, x.max(), 1000)
                output = model.eval(x, params.values())
                y = output[:, 1]
                x = output[:, 0] / 1000.
                self.fit_plot[j, 0].plot(x, y, pen=(i, max_color), name=dynamics_types[i])

    def fit_curve_clicked(self, curve):
        """Open (or re-target) the fit explorer for a clicked fit curve."""
        if self.fit_explorer is None:
            self.fit_explorer = FitExplorer(curve.fit)
        else:
            self.fit_explorer.set_fit(curve.fit)
        self.fit_explorer.show()
# model = ReleaseModel() # model.Dynamics = {'Dep':1, 'Fac':0, 'UR':0, 'SMR':0, 'DSR':0} # params = {(('2/3', 'unknown'), ('2/3', 'unknown')): [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # ((None,'rorb'), (None,'rorb')): [0, 506.7, 0, 0, 0, 0, 0.22, 0, 0, 0, 0, 0], # ((None,'sim1'), (None,'sim1')): [0, 1213.8, 0, 0, 0, 0, 0.17, 0, 0, 0, 0, 0], # ((None,'tlx3'), (None,'tlx3')): [0, 319.4, 0, 0, 0, 0, 0.16, 0, 0, 0, 0, 0]} ind = data[0] rec = data[1] freq = 10 delta = 250 order = human_connections.keys() delay_order = [250, 500, 1000, 2000, 4000] ind_plt = PlotGrid() ind_plt.set_shape(4, 1) ind_plt.show() rec_plt = PlotGrid() rec_plt.set_shape(1, 2) rec_plt.show() ind_plt_scatter = pg.plot() ind_plt_all = PlotGrid() ind_plt_all.set_shape(1, 2) ind_plt_all.show() ind_50 = {} ninth_pulse_250 = {} gain_plot = pg.plot() for t, type in enumerate(order): if type == (('2', 'unknown'), ('2', 'unknown')): continue
class ExperimentTimeline(QtGui.QWidget):
    """Timeline view of an experiment: one plot row per headstage showing
    per-recording QC metrics over time, plus a parameter tree for managing
    pipette metadata (channel, patch time range, status, internals).
    """

    def __init__(self):
        QtGui.QWidget.__init__(self)
        self.channels = None
        self.start_time = None  # starting time according to NWB file

        self.layout = QtGui.QGridLayout()
        self.setLayout(self.layout)
        self.layout.setContentsMargins(0, 0, 0, 0)

        # One plot per headstage, all sharing the same time (x) axis
        self.plots = PlotGrid()
        self.plots.set_shape(config.n_headstages, 1)
        self.plots.setXLink(self.plots[0, 0])
        self.layout.addWidget(self.plots, 0, 0)

        self.ptree = pg.parametertree.ParameterTree(showHeader=False)
        self.layout.addWidget(self.ptree, 0, 1)
        self.ptree.setMaximumWidth(250)

        self.params = pg.parametertree.Parameter.create(name='params',
                                                        type='group',
                                                        addText='Add pipette')
        self.params.addNew = self.add_pipette_clicked  # monkey!
        self.ptree.setParameters(self.params, showTop=False)

    def list_channels(self):
        """Return the list of AD channels currently known to the timeline."""
        return self.channels

    def get_channel_plot(self, chan):
        """Return the plot row corresponding to AD channel *chan*."""
        return self.plots[chan, 0]

    def add_pipette_clicked(self):
        """Handler for the 'Add pipette' button: add one on the first channel."""
        self.add_pipette(channel=self.channels[0], start=0, stop=500)

    def remove_pipettes(self):
        """Remove all pipette parameters and their plot regions; clear plots."""
        for ch in self.params.children():
            self.params.removeChild(ch)
            ch.region.scene().removeItem(ch.region)

        for i in range(self.plots.shape[0]):
            self.plots[i, 0].clear()

    def load_site(self, site_dh):
        """Generate pipette list for this site
        """
        self.remove_pipettes()

        # automatically fill pipette fluorophore field
        expt_dh = site_dh.parent().parent()
        expt_info = expt_dh.info()
        dye = expt_info.get('internal_dye', None)
        internal = expt_info.get('internal', None)

        # automatically select electrode regions
        self.channels = list(range(config.n_headstages))

        site_info = site_dh.info()
        for i in self.channels:
            # map the recorded headstage state abbreviation to a readable status
            hs_state = site_info.get('Headstage %d' % (i + 1), None)
            status = {
                'NS': 'No seal',
                'LS': 'Low seal',
                'GS': 'GOhm seal',
                'TF': 'Technical failure',
                'NA': 'No attempt',
                None: 'Not recorded',
            }.get(hs_state, hs_state)
            self.add_pipette(i, status=status, internal_dye=dye, internal=internal)

    def load_nwb(self, nwb_handle):
        """Load and display QC metrics from an NWB file (with busy cursor)."""
        with pg.BusyCursor():
            self._load_nwb(nwb_handle)

    def _load_nwb(self, nwb_handle):
        self.nwb_handle = nwb_handle
        self.nwb = MiesNwb(nwb_handle.name())

        # load all recordings, grouped by AD channel
        recs = {}
        for srec in self.nwb.contents:
            for chan in srec.devices:
                recs.setdefault(chan, []).append(srec[chan])

        chans = sorted(recs.keys())

        # find time of first recording; all times below are relative to it
        start_time = min([rec[0].start_time for rec in recs.values()])
        self.start_time = start_time
        end_time = max([rec[-1].start_time for rec in recs.values()])
        self.plots.setXRange(0, (end_time - start_time).seconds)

        # plot all recordings
        for i, chan in enumerate(chans):
            n_recs = len(recs[chan])
            times = np.empty(n_recs)
            i_hold = np.empty(n_recs)
            v_hold = np.empty(n_recs)
            v_noise = np.empty(n_recs)
            i_noise = np.empty(n_recs)

            # load QC metrics for all recordings
            for j, rec in enumerate(recs[chan]):
                dt = (rec.start_time - start_time).seconds
                times[j] = dt
                v_hold[j] = rec.baseline_potential
                i_hold[j] = rec.baseline_current
                # noise is attributed to whichever signal was recorded;
                # the other is marked NaN so it is not plotted
                if rec.clamp_mode == 'vc':
                    v_noise[j] = np.nan
                    i_noise[j] = rec.baseline_rms_noise
                else:
                    v_noise[j] = rec.baseline_rms_noise
                    i_noise[j] = np.nan

            # scale all qc metrics to the range 0-1 (|value| > 1 means QC fail)
            pass_brush = pg.mkBrush(100, 100, 255, 200)
            fail_brush = pg.mkBrush(255, 0, 0, 200)
            v_hold = (v_hold + 60e-3) / 20e-3
            i_hold = i_hold / 400e-12
            v_noise = v_noise / 5e-3
            i_noise = i_noise / 100e-12

            plt = self.get_channel_plot(chan)
            plt.setLabels(left=("Ch %d" % chan))
            # one symbol type per metric; recording markers at y=0
            for data, symbol in [(np.zeros_like(times), 'o'), (v_hold, 't'),
                                 (i_hold, 'x'), (v_noise, 't1'), (i_noise, 'x')]:
                brushes = np.where(np.abs(data) > 1.0, fail_brush, pass_brush)
                plt.plot(times,
                         data,
                         pen=None,
                         symbol=symbol,
                         symbolPen=None,
                         symbolBrush=brushes)

        # set each pipette's time region to span its channel's recordings
        for i in recs.keys():
            start = (recs[i][0].start_time - start_time).seconds - 1
            stop = (recs[i][-1].start_time - start_time).seconds + 1
            pip_param = self.params.child('Pipette %d' % (i + 1))
            pip_param.set_time_range(start, stop)

            # a channel with more than 2 recordings is considered to have data
            got_data = len(recs[i]) > 2
            pip_param['got data'] = got_data

    def add_pipette(self, channel, status=None, **kwds):
        """Create a PipetteParameter on *channel* and hook up its signals."""
        elec = PipetteParameter(self, channel, status=status, **kwds)
        self.params.addChild(elec, autoIncrementName=True)
        elec.child('channel').sigValueChanged.connect(
            self._pipette_channel_changed)
        elec.region.sigRegionChangeFinished.connect(
            self._pipette_region_changed)
        self._pipette_channel_changed(elec.child('channel'))

    def _pipette_channel_changed(self, param):
        # move the pipette's time region onto its (new) channel's plot
        plt = self.get_channel_plot(param.value())
        plt.addItem(param.parent().region)
        self._rename_pipettes()

    def _pipette_region_changed(self):
        self._rename_pipettes()

    def _rename_pipettes(self):
        """Reassign pipette names/ids based on channel and patch start time."""
        # sort electrodes by channel
        elecs = {}
        for elec in self.params.children():
            elecs.setdefault(elec['channel'], []).append(elec)

        for chan in elecs:
            # sort all electrodes on this channel by start time
            chan_elecs = sorted(elecs[chan],
                                key=lambda e: e.region.getRegion()[0])

            # assign names
            for i, elec in enumerate(chan_elecs):
                # rename all first to avoid name colisions
                elec.setName('rename%d' % i)
            for i, elec in enumerate(chan_elecs):
                # If there are multiple electrodes on this channel, then
                # each extra electrode increments its name by the number of
                # headstages (for example, on AD channel 3, the first electrode
                # is called "Electrode 4", and on an 8-headstage system, the
                # second electrode will be "Electrode 12").
                e_id = (chan + 1) + (i * config.n_headstages)
                elec.id = e_id
                elec.setName('Pipette %d' % e_id)

    def save(self):
        """Return an OrderedDict of per-pipette metadata keyed by pipette id."""
        state = {}
        for elec in self.params.children():
            # convert relative region bounds back to absolute datetimes
            rgn = elec.region.getRegion()
            if self.start_time is None:
                start = None
                stop = None
            else:
                start = self.start_time + datetime.timedelta(seconds=rgn[0])
                stop = self.start_time + datetime.timedelta(seconds=rgn[1])
            state[elec.id] = OrderedDict([
                ('pipette_status', elec['status']),
                ('got_data', elec['got data']),
                ('ad_channel', elec['channel']),
                ('patch_start', start),
                ('patch_stop', stop),
                ('cell_labels', {
                    'biocytin': '',
                    'red': '',
                    'green': '',
                    'blue': ''
                }),
                #('cell_qc', {'holding': None, 'access': None, 'spiking': None}),
                ('target_layer', elec['target layer']),
                ('morphology', elec['morphology']),
                ('internal_solution', elec['internal']),
                ('internal_dye', elec['internal dye']),
                ('synapse_to', None),
                ('gap_to', None),
                ('notes', ''),
            ])
        return state
from neuroanalysis.data import Trace # Load test data data = np.load('test_data/evoked_spikes/vc_evoked_spikes.npz')['arr_0'] dt = 20e-6 # gaussian filtering constant sigma = 20e-6 / dt # Initialize Qt pg.mkQApp() pg.dbg() # Create a window with a grid of plots (N rows, 1 column) win = PlotGrid() win.set_shape(data.shape[0], 1) win.show() # Loop over all 10 channels for i in range(data.shape[0]): # select the data for this channel trace = data[i, :, 0] stim = data[i, :, 1] # select the plot we will use for this trace plot = win[i, 0] # link all x-axes together plot.setXLink(win[0, 0]) xaxis = plot.getAxis('bottom') if i == data.shape[0] - 1:
# model = ReleaseModel() # model.Dynamics = {'Dep':1, 'Fac':0, 'UR':0, 'SMR':0, 'DSR':0} # params = {(('2/3', 'unknown'), ('2/3', 'unknown')): [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # ((None,'rorb'), (None,'rorb')): [0, 506.7, 0, 0, 0, 0, 0.22, 0, 0, 0, 0, 0], # ((None,'sim1'), (None,'sim1')): [0, 1213.8, 0, 0, 0, 0, 0.17, 0, 0, 0, 0, 0], # ((None,'tlx3'), (None,'tlx3')): [0, 319.4, 0, 0, 0, 0, 0.16, 0, 0, 0, 0, 0]} ind = data[0] rec = data[1] freq = 10 delta =250 order = human_connections.keys() delay_order = [250, 500, 1000, 2000, 4000] ind_plt = PlotGrid() ind_plt.set_shape(4, 1) ind_plt.show() rec_plt = PlotGrid() rec_plt.set_shape(1, 2) rec_plt.show() ind_plt_scatter = pg.plot() ind_plt_all = PlotGrid() ind_plt_all.set_shape(1, 2) ind_plt_all.show() ind_50 = {} ninth_pulse_250 = {} gain_plot = pg.plot() for t, type in enumerate(order): if type == (('2', 'unknown'), ('2', 'unknown')):
# Build the parameter tree: a 'data' selector plus the event detector's
# own parameters (`trace_names`, `evd`, `hs`, `traces` are defined earlier
# in the original file -- not visible in this fragment).
pt = pg.parametertree.ParameterTree(showHeader=False)
params = pg.parametertree.Parameter.create(name='params', type='group', children=[
    dict(name='data', type='list', values=trace_names),
    evd.params,
])
pt.setParameters(params, showTop=False)
hs.addWidget(pt)

# Two stacked plots with linked x axes; the detector draws input on the
# top plot and its output on the bottom plot.
plots = PlotGrid()
plots.set_shape(2, 1)
plots.setXLink(plots[0, 0])
hs.addWidget(plots)
evd.set_plots(plots[0, 0], plots[1, 0])


def update(auto_range=False):
    """Re-run event detection on the currently selected trace.

    auto_range : rescale the top plot to the new data (used when the
    selected trace changes, not for mere parameter tweaks).
    """
    evd.process(traces[params['data']])
    if auto_range:
        plots[0, 0].autoRange()


# Parameter tweaks keep the current view; switching traces rescales.
evd.parameters_changed.connect(lambda: update(auto_range=False))
params.child('data').sigValueChanged.connect(lambda: update(auto_range=True))
help='plot 50Hz train and deconvolution') parser.add_argument('--link-y-axis', action='store_true', default=False, dest='link-y-axis', help='link all y-axis down a column') args = vars(parser.parse_args(sys.argv[1:])) plot_sweeps = args['sweeps'] plot_trains = args['trains'] link_y_axis = args['link-y-axis'] expt_cache = 'C:/Users/Stephanies/multipatch_analysis/tools/expts_cache.pkl' all_expts = ExperimentList(cache=expt_cache) test = PlotGrid() test.set_shape(len(connection_types.keys()), 1) test.show() grid = PlotGrid() if plot_trains is True: grid.set_shape(len(connection_types.keys()), 3) grid[0, 1].setTitle(title='50 Hz Train') grid[0, 2].setTitle(title='Exponential Deconvolution') tau = 15e-3 lp = 1000 else: grid.set_shape(len(connection_types.keys()), 2) grid.show() row = 0 grid[0, 0].setTitle(title='First Pulse') maxYpulse = []
def distance_plot(connected, distance, plots=None, color=(100, 100, 255), window=40e-6, spacing=None, name=None, fill_alpha=30):
    """Draw connectivity vs distance profiles with confidence intervals.

    Parameters
    ----------
    connected : boolean array
        Whether a synaptic connection was found for each probe
    distance : array
        Distance between cells for each probe
    plots : list of PlotWidget | PlotItem (optional)
        Two plots used to display distance profile and scatter plot.
    color : tuple
        (R, G, B) color values for line and confidence interval. The confidence interval
        will be drawn with alpha=100
    window : float
        Width of distance window over which proportions are calculated for
        each point on the profile line.
    spacing : float
        Distance spacing between points on the profile line


    Note: using a spacing value that is smaller than the window size may cause
    an otherwise smooth decrease over distance to instead look more like a
    series of downward steps.

    name : str (optional)
        Legend name for the plotted curves.
    fill_alpha : int
        Alpha used for the confidence-interval fill (default 30).

    Returns the two plots as (profile_plot, scatter_plot).
    """
    color = pg.colorTuple(pg.mkColor(color))[:3]
    connected = np.array(connected).astype(float)
    distance = np.array(distance)
    pts = np.vstack([distance, connected]).T

    # scatter points a bit: spread connected probes below 0 and unconnected
    # probes above 0 so overlapping points remain visible
    conn = pts[:, 1] == 1
    unconn = pts[:, 1] == 0
    if np.any(conn):
        cscat = pg.pseudoScatter(pts[:, 0][conn], spacing=10e-6, bidir=False)
        mx = abs(cscat).max()
        if mx != 0:
            cscat = cscat * 0.2  # / mx
        pts[:, 1][conn] = -2e-5 - cscat
    if np.any(unconn):
        uscat = pg.pseudoScatter(pts[:, 0][unconn], spacing=10e-6, bidir=False)
        mx = abs(uscat).max()
        if mx != 0:
            uscat = uscat * 0.2  # / mx
        pts[:, 1][unconn] = uscat

    # scatter plot connections probed
    if plots is None:
        grid = PlotGrid()
        grid.set_shape(2, 1)
        grid.grid.ci.layout.setRowStretchFactor(0, 5)
        grid.grid.ci.layout.setRowStretchFactor(1, 10)
        plots = (grid[1, 0], grid[0, 0])
        # keep a reference so the grid is not garbage-collected
        plots[0].grid = grid
        plots[0].addLegend()
        grid.show()
    plots[0].setLabels(bottom=('distance', 'm'), left='connection probability')

    if plots[1] is not None:
        plots[1].setXLink(plots[0])
        plots[1].hideAxis('bottom')
        plots[1].hideAxis('left')

        color2 = color + (100, )
        scatter = plots[1].plot(pts[:, 0],
                                pts[:, 1],
                                pen=None,
                                symbol='o',
                                labels={'bottom': ('distance', 'm')},
                                symbolBrush=color2,
                                symbolPen=None,
                                name=name)
        # additive composition so overlapping points brighten instead of occlude
        scatter.scatter.opts[
            'compositionMode'] = pg.QtGui.QPainter.CompositionMode_Plus

    # use a sliding window to plot the proportion of connections found along with a 95% confidence interval
    # for connection probability
    if spacing is None:
        spacing = window / 4.0

    xvals = np.arange(window / 2.0, 500e-6, spacing)
    upper = []
    lower = []
    prop = []
    ci_xvals = []
    for x in xvals:
        minx = x - window / 2.0
        maxx = x + window / 2.0

        # select points inside this window
        mask = (distance >= minx) & (distance <= maxx)
        pts_in_window = connected[mask]

        # compute stats for window
        n_probed = pts_in_window.shape[0]
        n_conn = pts_in_window.sum()
        if n_probed == 0:
            prop.append(np.nan)
        else:
            prop.append(n_conn / n_probed)
            ci = binomial_ci(n_conn, n_probed)
            lower.append(ci[0])
            upper.append(ci[1])
            ci_xvals.append(x)

    # plot connection probability and confidence intervals
    color2 = [c / 3.0 for c in color]
    mid_curve = plots[0].plot(xvals,
                              prop,
                              pen={
                                  'color': color,
                                  'width': 3
                              },
                              antialias=True,
                              name=name)
    # CI bound curves are kept invisible; only the fill between them is shown
    upper_curve = plots[0].plot(ci_xvals,
                                upper,
                                pen=(0, 0, 0, 0),
                                antialias=True)
    lower_curve = plots[0].plot(ci_xvals,
                                lower,
                                pen=(0, 0, 0, 0),
                                antialias=True)
    upper_curve.setVisible(False)
    lower_curve.setVisible(False)
    color2 = color + (fill_alpha, )
    fill = pg.FillBetweenItem(upper_curve, lower_curve, brush=color2)
    fill.setZValue(-10)
    plots[0].addItem(fill, ignoreBounds=True)

    return plots
class DynamicsWindow(pg.QtGui.QSplitter):
    """Browser + plot window for inspecting pulse-response dynamics of a
    selected cell pair: experiment browser on the left, one plot row per
    stimulus condition on the right.
    """

    def __init__(self):
        pg.QtGui.QSplitter.__init__(self, pg.QtCore.Qt.Horizontal)

        self.browser = ExperimentBrowser()
        self.addWidget(self.browser)

        self.plots = PlotGrid()
        self.addWidget(self.plots)

        self.browser.itemSelectionChanged.connect(self.browser_item_selected)

    def browser_item_selected(self):
        """Load the selected pair (ignores multi-selections and non-pair items)."""
        with pg.BusyCursor():
            selected = self.browser.selectedItems()
            if len(selected) != 1:
                return
            item = selected[0]
            if not hasattr(item, 'pair'):
                return
            pair = item.pair

            self.load_pair(pair)

    def load_pair(self, pair):
        """Query all pulse responses for *pair*, sort them, and re-plot."""
        print("Loading:", pair)

        q = pulse_response_query(pair, data=True)
        self.sorted_recs = sorted_pulse_responses(q.all())

        self.plot_all()

    def plot_all(self):
        """Plot every pulse response, one row per stimulus condition.

        Responses are spike-aligned, offset horizontally by pulse number, and
        drawn dim red if they failed QC; QC-passing responses also get their
        stored PSP fit overlaid in green.
        """
        self.plots.clear()
        self.plots.set_shape(len(self.sorted_recs), 1)

        psp = StackedPsp()

        stim_keys = sorted(list(self.sorted_recs.keys()))
        for i, stim_key in enumerate(stim_keys):
            prs = self.sorted_recs[stim_key]
            plt = self.plots[i, 0]
            plt.setTitle("%s %0.0f Hz %0.2f s" % stim_key)

            for recording in prs:
                pulses = sorted(list(prs[recording].keys()))
                for pulse_n in pulses:
                    rec = prs[recording][pulse_n]

                    # spike-align pulse + offset for pulse number
                    spike_t = rec.stim_pulse.first_spike_time
                    if spike_t is None:
                        # no spike detected; assume 1 ms after pulse onset
                        spike_t = rec.stim_pulse.onset_time + 1e-3

                    # choose the QC criterion matching the synapse sign
                    qc_pass = rec.pulse_response.in_qc_pass if rec.synapse.synapse_type == 'in' else rec.pulse_response.ex_qc_pass
                    pen = (255, 255, 255, 100) if qc_pass else (100, 0, 0, 100)

                    t0 = rec.pulse_response.data_start_time - spike_t
                    ts = TSeries(data=rec.data, t0=t0, sample_rate=db.default_sample_rate)
                    c = plt.plot(ts.time_values, ts.data, pen=pen)

                    # arrange plots nicely (extra gap after the 8th pulse
                    # separates the recovery pulses)
                    shift = (pulse_n * 35e-3 + (30e-3 if pulse_n > 8 else 0), 0)
                    c.setPos(*shift)

                    if not qc_pass:
                        # push failed responses behind passing ones
                        c.setZValue(-10)
                        continue

                    # evaluate recorded fit for this response
                    fit_par = rec.pulse_response_fit
                    if fit_par.fit_amp is None:
                        continue
                    fit = psp.eval(
                        x=ts.time_values,
                        exp_amp=fit_par.fit_exp_amp,
                        exp_tau=fit_par.fit_decay_tau,
                        amp=fit_par.fit_amp,
                        rise_time=fit_par.fit_rise_time,
                        decay_tau=fit_par.fit_decay_tau,
                        xoffset=fit_par.fit_latency,
                        yoffset=fit_par.fit_yoffset,
                        rise_power=2,
                    )
                    c = plt.plot(ts.time_values, fit, pen=(0, 255, 0, 100))
                    c.setZValue(10)
                    c.setPos(*shift)
connection_types = human_connections.keys() else: c_type = connection.split('-') connection_types = [((c_type[0], 'unknown'), (c_type[1], 'unknown'))] plt = pg.plot() scale_offset = (-20, -20) scale_anchor = (0.4, 1) holding = [-65, -75] qc_plot = pg.plot() grand_response = {} expt_ids = {} feature_plot = None feature2_plot = PlotGrid() feature2_plot.set_shape(5,1) feature2_plot.show() feature3_plot = PlotGrid() feature3_plot.set_shape(1, 3) feature3_plot.show() amp_plot = pg.plot() synapse_plot = PlotGrid() synapse_plot.set_shape(len(connection_types), 1) synapse_plot.show() for c in range(len(connection_types)): cre_type = (connection_types[c][0][1], connection_types[c][1][1]) target_layer = (connection_types[c][0][0], connection_types[c][1][0]) type = connection_types[c] expt_list = all_expts.select(cre_type=cre_type, target_layer=target_layer, calcium=calcium, age=age) color = color_palette[c] grand_response[type[0]] = {'trace': [], 'amp': [], 'latency': [], 'rise': [], 'dist': [], 'decay':[], 'CV': [], 'amp_measured': []}
connection = args['connection'] if connection == 'ee': connection_types = human_connections.keys() else: c_type = connection.split('-') connection_types = [((c_type[0], 'unknown'), (c_type[1], 'unknown'))] sweep_threshold = 5 threshold = 0.03e-3 scale_offset = (-20, -20) scale_anchor = (0.4, 1) qc_plot = pg.plot() grand_response = {} feature_plot = None synapse_plot = PlotGrid() synapse_plot.set_shape(len(connection_types), 1) synapse_plot.show() for c in range(len(connection_types)): cre_type = (connection_types[c][0][1], connection_types[c][1][1]) target_layer = (connection_types[c][0][0], connection_types[c][1][0]) type = connection_types[c] expt_list = all_expts.select(cre_type=cre_type, target_layer=target_layer, calcium=calcium, age=age) color = color_palette[c] grand_response[type[0]] = { 'trace': [], 'amp': [], 'latency': [], 'rise': [],
def save_fit_psp_test_set():
    """NOTE THIS CODE DOES NOT WORK BUT IS HERE FOR DOCUMENTATION PURPOSES SO THAT WE
    CAN TRACE BACK HOW THE TEST DATA WAS CREATED IF NEEDED.

    Create a test set of data for testing the fit_psp function.  Uses Steph's original
    first_puls_feature.py code to filter out error causing data.

    Example run statement
    python save save_fit_psp_test_set.py --organism mouse --connection ee

    Comment in the code that does the saving at the bottom
    """
    import pyqtgraph as pg
    import numpy as np
    import csv
    import sys
    import argparse
    from multipatch_analysis.experiment_list import cached_experiments
    from manuscript_figures import get_response, get_amplitude, response_filter, feature_anova, write_cache, \
        trace_plot, colors_human, colors_mouse, fail_rate, pulse_qc, feature_kw
    from synapse_comparison import load_cache, summary_plot_pulse
    from neuroanalysis.data import TraceList, Trace
    from neuroanalysis.ui.plot_grid import PlotGrid
    from multipatch_analysis.connection_detection import fit_psp
    from rep_connections import ee_connections, human_connections, no_include, all_connections, ie_connections, ii_connections, ei_connections
    from multipatch_analysis.synaptic_dynamics import DynamicsAnalyzer
    from scipy import stats
    import time
    import pandas as pd
    import json
    import os

    app = pg.mkQApp()
    pg.dbg()
    pg.setConfigOption('background', 'w')
    pg.setConfigOption('foreground', 'k')

    parser = argparse.ArgumentParser(
        description='Enter organism and type of connection you"d like to analyze ex: mouse ee (all mouse excitatory-'
                    'excitatory). Alternatively enter a cre-type connection ex: sim1-sim1')
    parser.add_argument('--organism', dest='organism', help='Select mouse or human')
    parser.add_argument('--connection', dest='connection', help='Specify connections to analyze')
    args = vars(parser.parse_args(sys.argv[1:]))

    all_expts = cached_experiments()
    manifest = {'Type': [], 'Connection': [], 'amp': [], 'latency': [], 'rise': [], 'rise2080': [],
                'rise1090': [], 'rise1080': [], 'decay': [], 'nrmse': [], 'CV': []}
    fit_qc = {'nrmse': 8, 'decay': 499e-3}

    # Choose analysis parameters and the set of connection types based on organism.
    if args['organism'] == 'mouse':
        color_palette = colors_mouse
        calcium = 'high'
        age = '40-60'
        sweep_threshold = 3
        threshold = 0.03e-3
        connection = args['connection']
        if connection == 'ee':
            connection_types = ee_connections.keys()
        elif connection == 'ii':
            connection_types = ii_connections.keys()
        elif connection == 'ei':
            connection_types = ei_connections.keys()
        elif connection == 'ie':
            # BUG FIX: the original used '==' here (a no-op comparison), which left
            # connection_types unassigned for the 'ie' case.
            connection_types = ie_connections.keys()
        elif connection == 'all':
            connection_types = all_connections.keys()
        elif len(connection.split('-')) == 2:
            c_type = connection.split('-')
            if c_type[0] == '2/3':
                pre_type = ('2/3', 'unknown')
            else:
                pre_type = (None, c_type[0])
            if c_type[1] == '2/3':
                post_type = ('2/3', 'unknown')
            else:
                # BUG FIX: the original used c_type[0] (the presynaptic half) here,
                # so any 'X-Y' connection produced an 'X-X' post type.
                post_type = (None, c_type[1])
            connection_types = [(pre_type, post_type)]
    elif args['organism'] == 'human':
        color_palette = colors_human
        calcium = None
        age = None
        sweep_threshold = 5
        threshold = None
        connection = args['connection']
        if connection == 'ee':
            connection_types = human_connections.keys()
        else:
            c_type = connection.split('-')
            connection_types = [((c_type[0], 'unknown'), (c_type[1], 'unknown'))]

    plt = pg.plot()
    scale_offset = (-20, -20)
    scale_anchor = (0.4, 1)
    holding = [-65, -75]
    qc_plot = pg.plot()
    grand_response = {}
    expt_ids = {}
    feature_plot = None
    feature2_plot = PlotGrid()
    feature2_plot.set_shape(5, 1)
    feature2_plot.show()
    feature3_plot = PlotGrid()
    feature3_plot.set_shape(1, 3)
    feature3_plot.show()
    amp_plot = pg.plot()
    synapse_plot = PlotGrid()
    synapse_plot.set_shape(len(connection_types), 1)
    synapse_plot.show()

    for c in range(len(connection_types)):
        cre_type = (connection_types[c][0][1], connection_types[c][1][1])
        target_layer = (connection_types[c][0][0], connection_types[c][1][0])
        conn_type = connection_types[c]
        expt_list = all_expts.select(cre_type=cre_type, target_layer=target_layer, calcium=calcium, age=age)
        color = color_palette[c]
        grand_response[conn_type[0]] = {'trace': [], 'amp': [], 'latency': [], 'rise': [],
                                        'dist': [], 'decay': [], 'CV': [], 'amp_measured': []}
        expt_ids[conn_type[0]] = []
        synapse_plot[c, 0].addLegend()
        for expt in expt_list:
            for pre, post in expt.connections:
                if [expt.uid, pre, post] in no_include:
                    continue
                cre_check = expt.cells[pre].cre_type == cre_type[0] and expt.cells[post].cre_type == cre_type[1]
                layer_check = expt.cells[pre].target_layer == target_layer[0] and expt.cells[post].target_layer == target_layer[1]
                if cre_check is True and layer_check is True:
                    pulse_response, artifact = get_response(expt, pre, post, analysis_type='pulse')
                    if threshold is not None and artifact > threshold:
                        continue
                    response_subset, hold = response_filter(pulse_response, freq_range=[0, 50], holding_range=holding, pulse=True)
                    if len(response_subset) >= sweep_threshold:
                        qc_plot.clear()
                        qc_list = pulse_qc(response_subset, baseline=1.5, pulse=None, plot=qc_plot)
                        if len(qc_list) >= sweep_threshold:
                            avg_trace, avg_amp, amp_sign, peak_t = get_amplitude(qc_list)
                            # if amp_sign is '-':
                            #     continue
                            # #print ('%s, %0.0f' %((expt.uid, pre, post), hold, ))
                            # all_amps = fail_rate(response_subset, '+', peak_t)
                            # cv = np.std(all_amps)/np.mean(all_amps)

                            # weight parts of the trace during fitting
                            dt = avg_trace.dt
                            weight = np.ones(len(avg_trace.data)) * 10.       # set everything to ten initially
                            weight[int(10e-3 / dt):int(12e-3 / dt)] = 0.      # area around stim artifact
                            weight[int(12e-3 / dt):int(19e-3 / dt)] = 30.     # area around steep PSP rise

                            # check if the test data dir is there and if not create it
                            test_data_dir = 'test_psp_fit'
                            if not os.path.isdir(test_data_dir):
                                os.mkdir(test_data_dir)

                            save_dict = {}
                            save_dict['input'] = {
                                'data': avg_trace.data.tolist(),
                                'dtype': str(avg_trace.data.dtype),
                                'dt': float(avg_trace.dt),
                                'amp_sign': amp_sign,
                                'yoffset': 0,
                                'xoffset': 14e-3,
                                'avg_amp': float(avg_amp),
                                'method': 'leastsq',
                                'stacked': False,
                                'rise_time_mult_factor': 10.,
                                'weight': weight.tolist(),
                            }

                            # need to remake trace because different output is created
                            avg_trace_simple = Trace(data=np.array(save_dict['input']['data']),
                                                     dt=save_dict['input']['dt'])  # create Trace object

                            psp_fits_original = fit_psp(avg_trace,
                                                        sign=save_dict['input']['amp_sign'],
                                                        yoffset=save_dict['input']['yoffset'],
                                                        xoffset=save_dict['input']['xoffset'],
                                                        amp=save_dict['input']['avg_amp'],
                                                        method=save_dict['input']['method'],
                                                        stacked=save_dict['input']['stacked'],
                                                        rise_time_mult_factor=save_dict['input']['rise_time_mult_factor'],
                                                        fit_kws={'weights': save_dict['input']['weight']})

                            psp_fits_simple = fit_psp(avg_trace_simple,
                                                      sign=save_dict['input']['amp_sign'],
                                                      yoffset=save_dict['input']['yoffset'],
                                                      xoffset=save_dict['input']['xoffset'],
                                                      amp=save_dict['input']['avg_amp'],
                                                      method=save_dict['input']['method'],
                                                      stacked=save_dict['input']['stacked'],
                                                      rise_time_mult_factor=save_dict['input']['rise_time_mult_factor'],
                                                      fit_kws={'weights': save_dict['input']['weight']})

                            # Converted from Python-2 print statements for py3 compatibility
                            # (the rest of this file already uses print() calls).
                            print(expt.uid, pre, post)
                            if psp_fits_original.nrmse() != psp_fits_simple.nrmse():
                                print(' the nrmse values dont match')
                                print('\toriginal', psp_fits_original.nrmse())
                                print('\tsimple', psp_fits_simple.nrmse())
def plot_features(organism=None, conn_type=None, calcium=None, age=None, sweep_thresh=None, fit_thresh=None):
    """Plot average PSP responses grouped by experiment/connection filters.

    Each filter argument may be None (no filtering) or a comma-separated string of
    values; the cross-product of all supplied values defines one plot row per
    combination.

    Parameters
    ----------
    organism : str or None
        Species filter (e.g. 'mouse,human').
    conn_type : str or None
        'pre-post' pair, each half formatted as 'layer;cre_type' ('None' for no layer).
    calcium : str or None
        ACSF calcium-concentration prefix filter.
    age : str or None
        Age range as 'lower-upper'.
    sweep_thresh : int or None
        Minimum number of sweeps per pair.
    fit_thresh :
        Currently unused — TODO confirm intent.

    Returns
    -------
    (PlotGrid, PlotGrid)
        The response grid (one row per selection) and the (currently unpopulated)
        feature grid.
    """
    s = db.Session()
    filters = {
        'organism': organism,
        'conn_type': conn_type,
        'calcium': calcium,
        'age': age,
    }

    # Expand comma-separated filter values into the cross-product of selections.
    selection = [{}]
    # BUG FIX: .items() replaces Python-2-only .iteritems() (works on both 2 and 3).
    for key, value in filters.items():
        if value is not None:
            temp_list = []
            value_list = value.split(',')
            for v in value_list:
                temp = [s1.copy() for s1 in selection]
                for t in temp:
                    t[key] = v
                temp_list = temp_list + temp
            selection = list(temp_list)

    if len(selection) > 0:
        response_grid = PlotGrid()
        response_grid.set_shape(len(selection), 1)
        response_grid.show()
        feature_grid = PlotGrid()
        feature_grid.set_shape(6, 1)
        feature_grid.show()

    for i, select in enumerate(selection):
        pre_cell = db.aliased(db.Cell)
        post_cell = db.aliased(db.Cell)
        q_filter = []
        # BUG FIX: these are referenced in the titles below even when no conn_type
        # filter is given; initialize to avoid a NameError in that case.
        pre_layer = pre_cre = post_layer = post_cre = None
        if sweep_thresh is not None:
            q_filter.append(FirstPulseFeatures.n_sweeps >= sweep_thresh)
        species = select.get('organism')
        if species is not None:
            q_filter.append(db.Slice.species == species)
        c_type = select.get('conn_type')
        if c_type is not None:
            pre_type, post_type = c_type.split('-')
            pre_layer, pre_cre = pre_type.split(';')
            if pre_layer == 'None':
                pre_layer = None
            post_layer, post_cre = post_type.split(';')
            if post_layer == 'None':
                post_layer = None
            q_filter.extend([pre_cell.cre_type == pre_cre,
                             pre_cell.target_layer == pre_layer,
                             post_cell.cre_type == post_cre,
                             post_cell.target_layer == post_layer])
        calc_conc = select.get('calcium')
        if calc_conc is not None:
            q_filter.append(db.Experiment.acsf.like(calc_conc + '%'))
        age_range = select.get('age')
        if age_range is not None:
            age_lower, age_upper = age_range.split('-')
            q_filter.append(db.Slice.age.between(int(age_lower), int(age_upper)))

        q = s.query(FirstPulseFeatures).join(db.Pair, FirstPulseFeatures.pair_id==db.Pair.id)\
            .join(pre_cell, db.Pair.pre_cell_id==pre_cell.id)\
            .join(post_cell, db.Pair.post_cell_id==post_cell.id)\
            .join(db.Experiment, db.Experiment.id==db.Pair.expt_id)\
            .join(db.Slice, db.Slice.id==db.Experiment.slice_id)
        for filter_arg in q_filter:
            q = q.filter(filter_arg)
        results = q.all()

        trace_list = []
        for pair in results:
            #TODO set t0 to latency to align to foot of PSP
            trace = Trace(data=pair.avg_psp, sample_rate=db.default_sample_rate)
            trace_list.append(trace)
            response_grid[i, 0].plot(trace.time_values, trace.data)
        if len(trace_list) > 0:
            grand_trace = TraceList(trace_list).mean()
            response_grid[i, 0].plot(grand_trace.time_values, grand_trace.data, pen='b')
            response_grid[i, 0].setTitle('layer %s, %s-> layer %s, %s; n_synapses = %d' %
                                         (pre_layer, pre_cre, post_layer, post_cre, len(trace_list)))
        else:
            print('No synapses for layer %s, %s-> layer %s, %s' % (pre_layer, pre_cre, post_layer, post_cre))

    return response_grid, feature_grid
def plot_features(organism=None, conn_type=None, calcium=None, age=None, sweep_thresh=None, fit_thresh=None):
    """Plot average PSP responses grouped by experiment/connection filters.

    Each filter argument may be None (no filtering) or a comma-separated string of
    values; the cross-product of all supplied values defines one plot row per
    combination.

    Parameters
    ----------
    organism : str or None
        Species filter (e.g. 'mouse,human').
    conn_type : str or None
        'pre-post' pair, each half formatted as 'layer;cre_type' ('None' for no layer).
    calcium : str or None
        ACSF calcium-concentration prefix filter.
    age : str or None
        Age range as 'lower-upper'.
    sweep_thresh : int or None
        Minimum number of sweeps per pair.
    fit_thresh :
        Currently unused — TODO confirm intent.

    Returns
    -------
    (PlotGrid, PlotGrid)
        The response grid (one row per selection) and the (currently unpopulated)
        feature grid.
    """
    s = db.session()
    filters = {
        'organism': organism,
        'conn_type': conn_type,
        'calcium': calcium,
        'age': age,
    }

    # Expand comma-separated filter values into the cross-product of selections.
    selection = [{}]
    # BUG FIX: .items() replaces Python-2-only .iteritems() (works on both 2 and 3).
    for key, value in filters.items():
        if value is not None:
            temp_list = []
            value_list = value.split(',')
            for v in value_list:
                temp = [s1.copy() for s1 in selection]
                for t in temp:
                    t[key] = v
                temp_list = temp_list + temp
            selection = list(temp_list)

    if len(selection) > 0:
        response_grid = PlotGrid()
        response_grid.set_shape(len(selection), 1)
        response_grid.show()
        feature_grid = PlotGrid()
        feature_grid.set_shape(6, 1)
        feature_grid.show()

    for i, select in enumerate(selection):
        pre_cell = aliased(db.Cell)
        post_cell = aliased(db.Cell)
        q_filter = []
        # BUG FIX: these are referenced in the titles below even when no conn_type
        # filter is given; initialize to avoid a NameError in that case.
        pre_layer = pre_cre = post_layer = post_cre = None
        if sweep_thresh is not None:
            q_filter.append(FirstPulseFeatures.n_sweeps >= sweep_thresh)
        species = select.get('organism')
        if species is not None:
            q_filter.append(db.Slice.species == species)
        c_type = select.get('conn_type')
        if c_type is not None:
            pre_type, post_type = c_type.split('-')
            pre_layer, pre_cre = pre_type.split(';')
            if pre_layer == 'None':
                pre_layer = None
            post_layer, post_cre = post_type.split(';')
            if post_layer == 'None':
                post_layer = None
            q_filter.extend([
                pre_cell.cre_type == pre_cre,
                pre_cell.target_layer == pre_layer,
                post_cell.cre_type == post_cre,
                post_cell.target_layer == post_layer
            ])
        calc_conc = select.get('calcium')
        if calc_conc is not None:
            q_filter.append(db.Experiment.acsf.like(calc_conc + '%'))
        age_range = select.get('age')
        if age_range is not None:
            age_lower, age_upper = age_range.split('-')
            q_filter.append(
                db.Slice.age.between(int(age_lower), int(age_upper)))

        q = s.query(FirstPulseFeatures).join(db.Pair, FirstPulseFeatures.pair_id==db.Pair.id)\
            .join(pre_cell, db.Pair.pre_cell_id==pre_cell.id)\
            .join(post_cell, db.Pair.post_cell_id==post_cell.id)\
            .join(db.Experiment, db.Experiment.id==db.Pair.expt_id)\
            .join(db.Slice, db.Slice.id==db.Experiment.slice_id)
        for filter_arg in q_filter:
            q = q.filter(filter_arg)
        results = q.all()

        trace_list = []
        for pair in results:
            #TODO set t0 to latency to align to foot of PSP
            trace = TSeries(data=pair.avg_psp, sample_rate=db.default_sample_rate)
            trace_list.append(trace)
            response_grid[i, 0].plot(trace.time_values, trace.data)
        if len(trace_list) > 0:
            grand_trace = TSeriesList(trace_list).mean()
            response_grid[i, 0].plot(grand_trace.time_values, grand_trace.data, pen='b')
            response_grid[i, 0].setTitle(
                'layer %s, %s-> layer %s, %s; n_synapses = %d' %
                (pre_layer, pre_cre, post_layer, post_cre, len(trace_list)))
        else:
            print('No synapses for layer %s, %s-> layer %s, %s' % (pre_layer, pre_cre, post_layer, post_cre))

    return response_grid, feature_grid
class ExperimentTimeline(QtGui.QWidget):
    """Widget combining per-headstage recording timelines (a PlotGrid, one row
    per headstage) with an editable pipette parameter tree.

    Pipettes are represented as PipetteParameter children of ``self.params``;
    each owns a region item displayed on its channel's plot row.
    """

    def __init__(self):
        QtGui.QWidget.__init__(self)
        self.channels = None
        self.start_time = None  # starting time according to NWB file

        self.layout = QtGui.QGridLayout()
        self.setLayout(self.layout)
        self.layout.setContentsMargins(0, 0, 0, 0)

        # One plot row per headstage; all rows share the x axis of row 0.
        self.plots = PlotGrid()
        self.plots.set_shape(config.n_headstages, 1)
        self.plots.setXLink(self.plots[0, 0])
        self.layout.addWidget(self.plots, 0, 0)

        self.ptree = pg.parametertree.ParameterTree(showHeader=False)
        self.layout.addWidget(self.ptree, 0, 1)
        self.ptree.setMaximumWidth(250)

        self.params = pg.parametertree.Parameter.create(
            name='params', type='group', addText='Add pipette')
        self.params.addNew = self.add_pipette_clicked  # monkey!
        self.ptree.setParameters(self.params, showTop=False)

    def list_channels(self):
        """Return the list of known AD channels (None until a site is loaded)."""
        return self.channels

    def get_channel_plot(self, chan):
        """Return the plot row belonging to AD channel *chan*."""
        return self.plots[chan, 0]

    def add_pipette_clicked(self):
        """Handler for the parameter tree 'Add pipette' button."""
        self.add_pipette(channel=self.channels[0], start=0, stop=500)

    def remove_pipettes(self):
        """Remove every pipette parameter and its plot region, then clear all rows."""
        for pip in self.params.children():
            self.params.removeChild(pip)
            pip.region.scene().removeItem(pip.region)

        for row in range(self.plots.shape[0]):
            self.plots[row, 0].clear()

    def load_site(self, site_dh):
        """Generate pipette list for this site
        """
        self.remove_pipettes()

        # automatically fill pipette fluorophore field
        expt_dh = site_dh.parent().parent()
        expt_info = expt_dh.info()
        dye = expt_info.get('internal_dye', None)
        internal = expt_info.get('internal', None)

        # automatically select electrode regions
        self.channels = list(range(config.n_headstages))
        site_info = site_dh.info()
        status_labels = {
            'NS': 'No seal',
            'LS': 'Low seal',
            'GS': 'GOhm seal',
            'TF': 'Technical failure',
            'NA': 'No attempt',
            None: 'Not recorded',
        }
        for chan in self.channels:
            hs_state = site_info.get('Headstage %d' % (chan + 1), None)
            # Unknown abbreviations fall through unchanged.
            status = status_labels.get(hs_state, hs_state)
            self.add_pipette(chan, status=status, internal_dye=dye, internal=internal)

    def load_nwb(self, nwb_handle):
        """Load an NWB file, showing a busy cursor while working."""
        with pg.BusyCursor():
            self._load_nwb(nwb_handle)

    def _load_nwb(self, nwb_handle):
        self.nwb_handle = nwb_handle
        self.nwb = MiesNwb(nwb_handle.name())

        # Gather recordings keyed by AD channel.
        recs = {}
        for srec in self.nwb.contents:
            for chan in srec.devices:
                recs.setdefault(chan, []).append(srec[chan])
        chans = sorted(recs.keys())

        # Time of the first recording anchors the x axis at 0.
        start_time = min([rec[0].start_time for rec in recs.values()])
        self.start_time = start_time
        end_time = max([rec[-1].start_time for rec in recs.values()])
        self.plots.setXRange(0, (end_time - start_time).seconds)

        # Plot QC metrics for every recording on every channel.
        for plot_idx, chan in enumerate(chans):
            chan_recs = recs[chan]
            n_recs = len(chan_recs)
            times = np.empty(n_recs)
            i_hold = np.empty(n_recs)
            v_hold = np.empty(n_recs)
            v_noise = np.empty(n_recs)
            i_noise = np.empty(n_recs)

            # Collect QC metrics for all recordings on this channel.
            for rec_idx, rec in enumerate(chan_recs):
                times[rec_idx] = (rec.start_time - start_time).seconds
                v_hold[rec_idx] = rec.baseline_potential
                i_hold[rec_idx] = rec.baseline_current
                if rec.clamp_mode == 'vc':
                    v_noise[rec_idx] = np.nan
                    i_noise[rec_idx] = rec.baseline_rms_noise
                else:
                    v_noise[rec_idx] = rec.baseline_rms_noise
                    i_noise[rec_idx] = np.nan

            # Scale all QC metrics to the range 0-1 (|value| > 1 fails QC).
            pass_brush = pg.mkBrush(100, 100, 255, 200)
            fail_brush = pg.mkBrush(255, 0, 0, 200)
            v_hold = (v_hold + 60e-3) / 20e-3
            i_hold = i_hold / 400e-12
            v_noise = v_noise / 5e-3
            i_noise = i_noise / 100e-12

            plt = self.get_channel_plot(chan)
            plt.setLabels(left=("Ch %d" % chan))
            series = [
                (np.zeros_like(times), 'o'),
                (v_hold, 't'),
                (i_hold, 'x'),
                (v_noise, 't1'),
                (i_noise, 'x'),
            ]
            for data, symbol in series:
                brushes = np.where(np.abs(data) > 1.0, fail_brush, pass_brush)
                plt.plot(times, data, pen=None, symbol=symbol, symbolPen=None, symbolBrush=brushes)

        # Fit each pipette's region around its channel's recordings and record
        # whether the channel produced usable data.
        for chan in recs.keys():
            start = (recs[chan][0].start_time - start_time).seconds - 1
            stop = (recs[chan][-1].start_time - start_time).seconds + 1
            pip_param = self.params.child('Pipette %d' % (chan + 1))
            pip_param.set_time_range(start, stop)
            pip_param['got data'] = len(recs[chan]) > 2

    def add_pipette(self, channel, status=None, **kwds):
        """Create a PipetteParameter on *channel* and wire up its signals."""
        elec = PipetteParameter(self, channel, status=status, **kwds)
        self.params.addChild(elec, autoIncrementName=True)
        elec.child('channel').sigValueChanged.connect(self._pipette_channel_changed)
        elec.region.sigRegionChangeFinished.connect(self._pipette_region_changed)
        self._pipette_channel_changed(elec.child('channel'))

    def _pipette_channel_changed(self, param):
        # Move the pipette's region onto the plot row of its (new) channel.
        plt = self.get_channel_plot(param.value())
        plt.addItem(param.parent().region)
        self._rename_pipettes()

    def _pipette_region_changed(self):
        self._rename_pipettes()

    def _rename_pipettes(self):
        # Group electrodes by channel.
        elecs = {}
        for elec in self.params.children():
            elecs.setdefault(elec['channel'], []).append(elec)

        for chan in elecs:
            # Sort all electrodes on this channel by region start time.
            chan_elecs = sorted(elecs[chan], key=lambda e: e.region.getRegion()[0])

            # Rename everything first to avoid name collisions.
            for order, elec in enumerate(chan_elecs):
                elec.setName('rename%d' % order)

            # If there are multiple electrodes on this channel, then each extra
            # electrode increments its name by the number of headstages (for
            # example, on AD channel 3, the first electrode is called
            # "Electrode 4", and on an 8-headstage system, the second electrode
            # will be "Electrode 12").
            for order, elec in enumerate(chan_elecs):
                e_id = (chan + 1) + (order * config.n_headstages)
                elec.id = e_id
                elec.setName('Pipette %d' % e_id)

    def save(self):
        """Serialize all pipettes into a dict keyed by electrode id."""
        state = {}
        for elec in self.params.children():
            rgn = elec.region.getRegion()
            if self.start_time is None:
                start = None
                stop = None
            else:
                # Region coordinates are seconds relative to self.start_time.
                start = self.start_time + datetime.timedelta(seconds=rgn[0])
                stop = self.start_time + datetime.timedelta(seconds=rgn[1])
            state[elec.id] = OrderedDict([
                ('pipette_status', elec['status']),
                ('got_data', elec['got data']),
                ('ad_channel', elec['channel']),
                ('patch_start', start),
                ('patch_stop', stop),
                ('cell_labels', {'biocytin': '', 'red': '', 'green': '', 'blue': ''}),
                #('cell_qc', {'holding': None, 'access': None, 'spiking': None}),
                ('target_layer', elec['target layer']),
                ('morphology', elec['morphology']),
                ('internal_solution', elec['internal']),
                ('internal_dye', elec['internal dye']),
                ('synapse_to', None),
                ('gap_to', None),
                ('notes', ''),
            ])
        return state