def plot_model_results(self, model=None, fit=None):
    """Plot release-model output over the measured train amplitudes.

    Two stacked plots share a log-scaled x axis:

    * top ("induction frequency"): trains whose recovery delay is ~250 ms;
    * bottom ("recovery delay"): trains whose induction frequency is 50.

    Measured points are drawn as symbols and the evaluated model as a line of
    the same color.

    Parameters
    ----------
    model, fit : optional
        Release model and its fit result. When *model* is None the pair stored
        in ``self._last_model_fit`` is used.

    Returns
    -------
    PlotGrid
        The newly created and shown 2x1 grid.

    Raises
    ------
    Exception
        If *model* is None and no model fit has been run yet.
    """
    if model is None:
        if self._last_model_fit is None:
            raise Exception("Must run fit_release_model before plotting results.")
        model, fit = self._last_model_fit

    spike_sets = self.spike_sets

    rel_plots = PlotGrid()
    rel_plots.set_shape(2, 1)
    ind_plot = rel_plots[0, 0]
    ind_plot.setTitle('Release model fit - induction frequency')
    ind_plot.setLabels(bottom=('time', 's'), left='relative amplitude')
    rec_plot = rel_plots[1, 0]
    rec_plot.setTitle('Release model fit - recovery delay')
    rec_plot.setLabels(bottom=('time', 's'), left='relative amplitude')

    ind_plot.setLogMode(x=True, y=False)
    rec_plot.setLogMode(x=True, y=False)
    ind_plot.setXLink(rec_plot)

    for i, stim_params in enumerate(self.stim_param_order):
        x, y = spike_sets[i]
        output = model.eval(x, fit.values(), dt=0.5)
        y1 = output[:, 1]
        x1 = output[:, 0]
        # x is divided by 1000 for an axis labelled in seconds, so it is
        # presumably in ms; the +10 offset presumably keeps the first point
        # away from 0 on the log axis -- confirm.
        if abs(stim_params[1] - 0.250) < 5e-3:
            # ~250 ms recovery delay. abs() is required: a signed difference
            # would also accept every delay shorter than 255 ms.
            ind_plot.plot((x + 10) / 1000., y, pen=None, symbol='o', symbolBrush=(i, 10))
            ind_plot.plot((x1 + 10) / 1000., y1, pen=(i, 10))
        if stim_params[0] == 50:
            rec_plot.plot((x + 10) / 1000., y, pen=None, symbol='o', symbolBrush=(i, 10))
            rec_plot.plot((x1 + 10) / 1000., y1, pen=(i, 10))

    rel_plots.show()
    return rel_plots
def summary_plot_pulse(grand_trace, feature_list, feature_mean, labels, titles, i, plot=None, color=None,
                       name=None):
    """Plot single-pulse features for one group at x position *i*, next to a trace.

    For every feature, column 0 gets a jittered scatter ('x' symbols) of the
    individual values plus a larger 'o' marker at the supplied mean, and
    column 1 gets *grand_trace*.

    Parameters
    ----------
    grand_trace :
        Trace-like object with ``time_values`` and ``data`` attributes.
    feature_list : tuple of sequences | sequence
        A tuple is treated as several features (one grid row each); anything
        else is treated as a single feature.
    feature_mean :
        Value(s) to mark with the large symbol; indexed per feature when
        several features are given.
    labels :
        ``(name, units)`` pair(s) for the left axis; indexed per feature when
        several features are given.
    titles :
        Plot title(s); indexed per feature when several features are given.
    i : number
        x position of this group.
    plot : PlotGrid, optional
        Grid to draw into. If omitted, a new ``n_features x 2`` grid is created
        and shown.
    color :
        Color used for the trace and symbols.
    name : str, optional
        Legend name for the trace.

    Returns
    -------
    PlotGrid
        The grid that was drawn into.
    """
    if isinstance(feature_list, tuple):
        n_features = len(feature_list)
    else:
        n_features = 1

    if plot is None:
        plot = PlotGrid()
        plot.set_shape(n_features, 2)
        plot.show()
        for g in range(n_features):
            plot[g, 1].addLegend()
            plot[g, 1].setLabels(left=('Vm', 'V'))
            plot[g, 1].setLabels(bottom=('t', 's'))

    for feature in range(n_features):
        if n_features > 1:
            features = feature_list[feature]
            mean = feature_mean[feature]
            label = labels[feature]
            title = titles[feature]
        else:
            features = feature_list
            mean = feature_mean
            label = labels
            title = titles

        plot[feature, 0].setLabels(left=(label[0], label[1]))
        plot[feature, 0].hideAxis('bottom')
        plot[feature, 0].setTitle(title)
        plot[feature, 1].plot(grand_trace.time_values, grand_trace.data, pen=color, name=name)

        dx = pg.pseudoScatter(np.array(features).astype(float), 0.3, bidir=True)
        # pseudoScatter can return all-zero offsets (always so for a single
        # value); only normalize when there is a positive spread, otherwise
        # 0/0 would place every point at NaN.
        spread = dx.max() if len(dx) > 0 else 0
        if spread > 0:
            dx = 0.3 * dx / spread
        plot[feature, 0].plot(dx + i, features, pen=None, symbol='x', symbolSize=5, symbolBrush=color,
                              symbolPen=None)
        plot[feature, 0].plot([i], [mean], pen=None, symbol='o', symbolBrush=color, symbolPen='w', symbolSize=10)
    return plot
def plot_train_responses(self, plot_grid=None):
    """Plot single-trial and averaged train responses for each stimulus-parameter set.

    Each row corresponds to one key of ``self.train_responses``. Column 0 shows
    the induction period and column 1 the recovery period. Individual
    responses are drawn in translucent grey, each shifted by the median of its
    induction baseline. The ``bsub_mean()`` average is drawn in green.

    Parameters
    ----------
    plot_grid : PlotGrid, optional
        Grid to draw into; a new PlotGrid is created when omitted.

    Returns
    -------
    PlotGrid
        The grid that was drawn into (*plot_grid* if given, otherwise the new one).
    """
    responses_by_params = self.train_responses

    grid = PlotGrid() if plot_grid is None else plot_grid
    grid.set_shape(len(self.stim_param_order), 2)

    single_pen = (128, 128, 128, 100)
    for row, params in enumerate(responses_by_params.keys()):
        # induction and recovery groups for this parameter set
        groups = responses_by_params[params]
        ind_group = groups[0]
        rec_group = groups[1]
        ind_plt = grid[row, 0]
        rec_plt = grid[row, 1]

        for k in range(len(ind_group)):
            ind_trace = ind_group.responses[k]
            rec_trace = rec_group.responses[k]
            # the induction baseline is used for both periods of a trial
            offset = np.median(ind_group.baselines[k].data)
            ind_plt.plot(ind_trace.time_values, ind_trace.data - offset, pen=single_pen)
            rec_plt.plot(rec_trace.time_values, rec_trace.data - offset, pen=single_pen)

        ind_mean = ind_group.bsub_mean()
        rec_mean = rec_group.bsub_mean()

        freq, delay, hold = params
        delay = np.round(delay, 2)

        ind_plt.plot(ind_mean.time_values, ind_mean.data, pen='g', antialias=True)
        rec_plt.plot(rec_mean.time_values, rec_mean.data, pen='g', antialias=True)
        ind_plt.setLabels(left=('Vm', 'V'))

        caption = pg.LabelItem("ind: %0.0f rec: %0.0f hold: %0.0f" % (freq, delay * 1000, hold * 1000))
        caption.setParentItem(ind_plt.vb)
        ind_plt.label = caption

    grid.show()
    grid.setYLink(grid[0, 0])
    n_rows = grid.shape[0]
    for row in range(n_rows):
        grid[row, 0].setXLink(grid[0, 0])
        grid[row, 1].setXLink(grid[0, 1])

    layout = grid.grid.ci.layout
    layout.setColumnStretchFactor(0, 3)
    layout.setColumnStretchFactor(1, 2)

    grid.setClipToView(False)  # clip-to-view misbehaves here, so leave it off
    grid.setDownsampling(True, True, 'peak')

    return grid
def trace_average_matrix(expts, **kwds):
    """Plot average traces for every (pre, post) cell-type combination.

    Two square PlotGrids are created and shown. Rows are indexed by
    presynaptic type and columns by postsynaptic type. For each cell, one plot
    from each grid is handed to ``plot_trace_average``.

    Parameters
    ----------
    expts :
        Experiment collection forwarded to ``plot_trace_average``.
    **kwds :
        Extra keyword arguments forwarded to ``plot_trace_average``.

    Returns
    -------
    dict
        Maps ``(pre_type, post_type)`` to the value returned by
        ``plot_trace_average`` for that pair.
    """
    types = ['tlx3', 'sim1', 'pvalb', 'sst', 'vip']
    results = {}
    plots = PlotGrid()
    plots.set_shape(len(types), len(types))
    indplots = PlotGrid()
    indplots.set_shape(len(types), len(types))
    plots.show()
    indplots.show()
    for i, pre_type in enumerate(types):
        for j, post_type in enumerate(types):
            avg_plot = plots[i, j]
            ind_plot = indplots[i, j]
            # Use the caller's experiments, not a module-level global.
            avg = plot_trace_average(expts, pre_type, post_type, avg_plot, ind_plot, **kwds)
            results[(pre_type, post_type)] = avg
    return results
def summary_plot_pulse(feature_list, labels, titles, i, median=False, grand_trace=None, plot=None, color=None,
                       name=None):
    """Scatter-plot single-pulse features for one group at x position *i*.

    For every feature, column 0 gets a large marker at the group's mean (or
    median) and a jittered scatter of the individual values. Column 1
    optionally gets *grand_trace*.

    Parameters
    ----------
    feature_list : tuple of sequences | sequence
        A tuple is treated as several features (one grid row each); anything
        else is treated as a single feature.
    labels :
        ``(name, units)`` pair(s) for the left axis; indexed per feature when
        several features are given.
    titles :
        Plot title(s); indexed per feature when several features are given.
    i : number
        x position of this group.
    median : bool
        Mark the median instead of the mean (NaNs ignored either way).
    grand_trace : optional
        Trace-like object (``time_values``, ``data``) drawn in column 1.
    plot : PlotGrid, optional
        Grid to draw into. If omitted, a new ``n_features x 2`` grid is created
        and shown.
    color :
        Any value accepted by ``pg.mkColor``. The scatter uses it with alpha 100
        (grey if None).
    name : str, optional
        Legend name for the trace.

    Returns
    -------
    PlotGrid
        The grid that was drawn into.
    """
    if isinstance(feature_list, tuple):
        n_features = len(feature_list)
    else:
        n_features = 1

    if plot is None:
        plot = PlotGrid()
        plot.set_shape(n_features, 2)
        plot.show()
        for g in range(n_features):
            plot[g, 1].addLegend()

    for feature in range(n_features):
        if n_features > 1:
            current_feature = feature_list[feature]
            label = labels[feature]
            title = titles[feature]
        else:
            current_feature = feature_list
            label = labels
            title = titles
        # *median* applies to single-feature input as well.
        if median:
            center = np.nanmedian(current_feature)
        else:
            center = np.nanmean(current_feature)

        plot[feature, 0].setLabels(left=(label[0], label[1]))
        plot[feature, 0].hideAxis('bottom')
        plot[feature, 0].setTitle(title)
        if grand_trace is not None:
            plot[feature, 1].plot(grand_trace.time_values, grand_trace.data, pen=color, name=name)

        n_values = len(current_feature)
        if n_values > 1:
            dx = pg.pseudoScatter(np.array(current_feature).astype(float), 0.7, bidir=True)
            plot[feature, 0].plot([i], [center], symbol='o', symbolSize=20, symbolPen='k', symbolBrush=color)
            # pseudoScatter can return all-zero offsets when no values overlap;
            # only normalize when there is a positive spread (avoids 0/0 -> NaN).
            spread = dx.max()
            if spread > 0:
                dx = 0.3 * dx / spread
            rgb = (128, 128, 128) if color is None else tuple(pg.colorTuple(pg.mkColor(color))[:3])
            plot[feature, 0].plot(dx + i, current_feature, pen=None, symbol='o', symbolSize=10, symbolPen='w',
                                  symbolBrush=rgb + (100,))
        elif n_values == 1:
            plot[feature, 0].plot([i], current_feature, pen=None, symbol='o', symbolSize=10, symbolPen='w',
                                  symbolBrush=color)
    return plot
def pulse_average_matrix(expts, **kwds):
    """Plot average pulse responses for every (pre, post) cell-type combination.

    Two square PlotGrids are created and shown. Rows are indexed by
    presynaptic type and columns by postsynaptic type. For each cell, one plot
    from each grid is handed to ``plot_pulse_average``. Qt events are
    processed after each pair so the plots update while the matrix is filled.

    Parameters
    ----------
    expts :
        Experiment collection forwarded to ``plot_pulse_average``.
    **kwds :
        Extra keyword arguments forwarded to ``plot_pulse_average``.

    Returns
    -------
    dict
        Maps ``(pre_type, post_type)`` to the value returned by
        ``plot_pulse_average`` for that pair.
    """
    results = {}
    types = ['sim1', 'tlx3', 'pvalb', 'sst', 'vip']
    plots = PlotGrid()
    plots.set_shape(len(types), len(types))
    indplots = PlotGrid()
    indplots.set_shape(len(types), len(types))
    plots.show()
    indplots.show()
    for i, pre_type in enumerate(types):
        for j, post_type in enumerate(types):
            avg_plot = plots[i, j]
            ind_plot = indplots[i, j]
            # Use the caller's experiments, not a module-level global.
            avg = plot_pulse_average(expts, pre_type, post_type, avg_plot, ind_plot, **kwds)
            pg.QtGui.QApplication.processEvents()
            results[(pre_type, post_type)] = avg
    return results
def plot_model_results(self, model=None, fit=None):
    """Plot release-model output over the measured train amplitudes.

    The top plot shows trains with a ~250 ms recovery delay (induction-frequency
    series). The bottom plot shows trains with an induction frequency of 50
    (recovery-delay series). Both use a log-scaled, linked x axis. Measured
    points are symbols; the evaluated model is a line of the same color.

    Parameters
    ----------
    model, fit : optional
        Release model and its fit result. When *model* is None the pair stored
        in ``self._last_model_fit`` is used.

    Returns
    -------
    PlotGrid
        The newly created and shown 2x1 grid.

    Raises
    ------
    Exception
        If *model* is None and no model fit has been run yet.
    """
    if model is None:
        if self._last_model_fit is None:
            raise Exception(
                "Must run fit_release_model before plotting results.")
        model, fit = self._last_model_fit

    spike_sets = self.spike_sets

    rel_plots = PlotGrid()
    rel_plots.set_shape(2, 1)
    ind_plot = rel_plots[0, 0]
    ind_plot.setTitle('Release model fit - induction frequency')
    ind_plot.setLabels(bottom=('time', 's'), left='relative amplitude')
    rec_plot = rel_plots[1, 0]
    rec_plot.setTitle('Release model fit - recovery delay')
    rec_plot.setLabels(bottom=('time', 's'), left='relative amplitude')

    ind_plot.setLogMode(x=True, y=False)
    rec_plot.setLogMode(x=True, y=False)
    ind_plot.setXLink(rec_plot)

    for i, stim_params in enumerate(self.stim_param_order):
        x, y = spike_sets[i]
        output = model.eval(x, fit.values(), dt=0.5)
        y1 = output[:, 1]
        x1 = output[:, 0]
        # x is divided by 1000 for an axis labelled in seconds, so it is
        # presumably in ms; the +10 offset presumably keeps points off 0 on
        # the log axis -- confirm.
        if abs(stim_params[1] - 0.250) < 5e-3:
            # ~250 ms recovery delay; without abs() every shorter delay would
            # also match.
            ind_plot.plot((x + 10) / 1000., y, pen=None, symbol='o',
                          symbolBrush=(i, 10))
            ind_plot.plot((x1 + 10) / 1000., y1, pen=(i, 10))
        if stim_params[0] == 50:
            rec_plot.plot((x + 10) / 1000., y, pen=None, symbol='o',
                          symbolBrush=(i, 10))
            rec_plot.plot((x1 + 10) / 1000., y1, pen=(i, 10))

    rel_plots.show()
    return rel_plots
def plot_response_averages(expt, show_baseline=False, **kwds):
    """Plot the average evoked response for every pre/post device pair of *expt*.

    Responses come from ``analyzer.get_evoked_response_matrix(**kwds)``. Each
    off-diagonal cell shows the baseline-subtracted mean after a 2 kHz Bessel
    filter. All cells share the x/y ranges of the top-left plot. A second
    window shows an ln(SNR) vs ln(NRMSE) scatter of PSP fits.

    NOTE: PSP fitting is currently commented out, so that scatter is always
    empty and *show_baseline* has no effect.

    Parameters
    ----------
    expt :
        Experiment passed to ``MultiPatchExperimentAnalyzer.get``.
    show_baseline : bool
        Only used by the disabled fitting code below.
    **kwds :
        Forwarded to ``get_evoked_response_matrix``. ``clamp_mode`` ('ic' by
        default) also selects the y-axis units.

    Returns
    -------
    PlotGrid
        The response matrix grid.
    """
    analyzer = MultiPatchExperimentAnalyzer.get(expt)

    # First get average evoked responses for all pre/post pairs
    responses, rows, cols = analyzer.get_evoked_response_matrix(**kwds)

    # resize plot grid accordingly
    plots = PlotGrid()
    plots.set_shape(len(rows), len(cols))
    plots.show()

    ranges = [([], []), ([], [])]  # ([y mins], [y maxs]), ([t starts], [t ends])
    points = []

    # Plot each matrix element (PSP fit currently disabled)
    for i, dev1 in enumerate(rows):
        for j, dev2 in enumerate(cols):
            # select plot; only the bottom row keeps its x axis and only the
            # left column keeps its y axis (the grid has len(rows) rows)
            plt = plots[i, j]
            if i < len(rows) - 1:
                plt.getAxis('bottom').setVisible(False)
            if j > 0:
                plt.getAxis('left').setVisible(False)

            if dev1 == dev2:
                plt.getAxis('bottom').setVisible(False)
                plt.getAxis('left').setVisible(False)
                continue

            # adjust axes / labels
            plt.setXLink(plots[0, 0])
            plt.setYLink(plots[0, 0])
            plt.addLine(x=10e-3, pen=0.3)
            plt.addLine(y=0, pen=0.3)
            plt.setLabels(bottom=(str(dev2), 's'))
            if kwds.get('clamp_mode', 'ic') == 'ic':
                plt.setLabels(left=('%s' % dev1, 'V'))
            else:
                plt.setLabels(left=('%s' % dev1, 'A'))

            avg_response = responses[(dev1, dev2)].bsub_mean()
            if avg_response is None:
                continue

            avg_response.t0 = 0
            t = avg_response.time_values
            y = bessel_filter(Trace(avg_response.data, dt=avg_response.dt), 2e3).data
            plt.plot(t, y, antialias=True)

            # fit! (disabled; while it is, `points` stays empty)
            #fit = responses[(dev1, dev2)].fit_psp(yoffset=0, mask_stim_artifact=(abs(dev1-dev2) < 3))
            #lsnr = np.log(fit.snr)
            #lerr = np.log(fit.nrmse())
            #color = (
            #np.clip(255 * (-lerr/3.), 0, 255),
            #np.clip(50 * lsnr, 0, 255),
            #np.clip(255 * (1+lerr/3.), 0, 255)
            #)
            #plt.plot(t, fit.best_fit, pen=color)
            ## plt.plot(t, fit.init_fit, pen='y')
            #points.append({'x': lerr, 'y': lsnr, 'brush': color})
            #if show_baseline:
            ## plot baseline for reference
            #bl = avg_response.meta['baseline'] - avg_response.meta['baseline_med']
            #plt.plot(np.arange(len(bl)) * avg_response.dt, bl, pen=(0, 100, 0), antialias=True)

            # keep track of data range across all plots (t/y only exist here)
            ranges[0][0].append(y.min())
            ranges[0][1].append(y.max())
            ranges[1][0].append(t[0])
            ranges[1][1].append(t[-1])

    # min()/max() of an empty list would raise if no pair had a response
    if ranges[0][0]:
        plots[0, 0].setYRange(min(ranges[0][0]), max(ranges[0][1]))
        plots[0, 0].setXRange(min(ranges[1][0]), max(ranges[1][1]))

    # scatter plot of SNR vs NRMSE
    plt = pg.plot()
    plt.setLabels(left='ln(SNR)', bottom='ln(NRMSE)')
    plt.plot([p['x'] for p in points], [p['y'] for p in points], pen=None, symbol='o',
             symbolBrush=[pg.mkBrush(p['brush']) for p in points])
    # show threshold line (45 degrees through (0, 6))
    line = pg.InfiniteLine(pos=[0, 6], angle=180/np.pi * np.arctan(1))
    plt.addItem(line, ignoreBounds=True)

    return plots
def distance_plot(connected, distance, plots=None, color=(100, 100, 255), size=10, window=40e-6, name=None,
                  fill_alpha=30):
    """Draw connectivity vs distance profiles with confidence intervals.

    Parameters
    ----------
    connected : boolean array
        Whether a synaptic connection was found for each probe
    distance : array
        Distance between cells for each probe
    plots : list of PlotWidget | PlotItem (optional)
        Two plots: ``plots[0]`` shows the distance profile and ``plots[1]`` the
        scatter of probes (may be None to skip the scatter). If omitted, a new
        2x1 PlotGrid is created and shown.
    color : tuple
        (R, G, B) color (anything accepted by ``pg.mkColor``) for the profile
        line and confidence interval. The scatter uses alpha 100 and the
        interval fill uses *fill_alpha*.
    size : int
        Size of the scatter plot symbols.
    window : float
        Width of the fixed distance bins (from 0 up to 500 um) over which
        connection proportions are computed by ``connectivity_profile``.
    name : str | None
        Legend name for the scatter and profile curves.
    fill_alpha : int
        Opacity of confidence interval fill (0-255)

    Returns
    -------
    plots, xvals, prop, upper, lower
        The plots used, bin centers, and the proportion and confidence bounds
        returned by ``connectivity_profile``.
    """
    color = pg.colorTuple(pg.mkColor(color))[:3]
    connected = np.array(connected).astype(float)
    distance = np.array(distance)

    # scatter plot connections probed
    if plots is None:
        grid = PlotGrid()
        grid.set_shape(2, 1)
        grid.grid.ci.layout.setRowStretchFactor(0, 5)
        grid.grid.ci.layout.setRowStretchFactor(1, 10)
        plots = (grid[1, 0], grid[0, 0])
        plots[0].grid = grid
        plots[0].addLegend()
        grid.show()
    plots[0].setLabels(bottom=('distance', 'm'), left='connection probability')

    if plots[1] is not None:
        # scatter points a bit: connected probes are jittered below
        # -5e-5, unconnected probes above 0
        pts = np.vstack([distance, connected]).T
        conn = pts[:, 1] == 1
        unconn = pts[:, 1] == 0
        if np.any(conn):
            cscat = pg.pseudoScatter(pts[:, 0][conn], spacing=10e-6, bidir=False)
            pts[:, 1][conn] = -5e-5 - cscat * 0.2
        if np.any(unconn):
            uscat = pg.pseudoScatter(pts[:, 0][unconn], spacing=10e-6, bidir=False)
            pts[:, 1][unconn] = uscat * 0.2

        plots[1].setXLink(plots[0])
        plots[1].hideAxis('bottom')
        plots[1].hideAxis('left')

        # PlotDataItem takes the symbol size as `symbolSize`; a `size` kwarg
        # is silently ignored.
        scatter = plots[1].plot(pts[:, 0], pts[:, 1], pen=None, symbol='o', symbolSize=size,
                                symbolBrush=color + (100,), symbolPen=None, name=name)
        scatter.scatter.opts['compositionMode'] = pg.QtGui.QPainter.CompositionMode_Plus

    # bin the probes by distance and compute the proportion of connections
    # found, along with a confidence interval, for each bin
    bin_edges = np.arange(0, 500e-6, window)
    xvals, prop, lower, upper = connectivity_profile(connected, distance, bin_edges)

    # plot connection probability and confidence intervals at bin centers
    # (assumes connectivity_profile returns the bin edges as xvals -- confirm)
    xvals = (xvals[:-1] + xvals[1:]) * 0.5
    plots[0].plot(xvals, prop, pen={'color': color, 'width': 3}, antialias=True, name=name)
    upper_curve = plots[0].plot(xvals, upper, pen=(0, 0, 0, 0), antialias=True)
    lower_curve = plots[0].plot(xvals, lower, pen=(0, 0, 0, 0), antialias=True)
    upper_curve.setVisible(False)
    lower_curve.setVisible(False)
    fill = pg.FillBetweenItem(upper_curve, lower_curve, brush=color + (fill_alpha,))
    fill.setZValue(-10)
    plots[0].addItem(fill, ignoreBounds=True)

    return plots, xvals, prop, upper, lower
# Analysis settings for the train-response script.
holding = [-55, -75]
freqs = [10, 20, 50, 100]  # induction frequencies; ind plots below have one row per entry (4)
rec_t = [250, 500, 1000, 2000, 4000]  # recovery delays; rec plots below have one row per entry (5)
sweep_threshold = 3
deconv = True  # also create the deconvolved induction/recovery grids

# cache_file = 'train_response_cache.pkl'
# response_cache = load_cache(cache_file)
# cache_change = []
# log_rec_plt = pg.plot()
# log_rec_plt.setLogMode(x=True)

# Plot windows: one column per connection type (connection_types is defined
# elsewhere in this script).
qc_plot = pg.plot()
ind_plot = PlotGrid()
ind_plot.set_shape(4, len(connection_types))
ind_plot.show()
rec_plot = PlotGrid()
rec_plot.set_shape(5, len(connection_types))
rec_plot.show()
if deconv is True:
    deconv_ind_plot = PlotGrid()
    deconv_ind_plot.set_shape(4, len(connection_types))
    deconv_ind_plot.show()
    deconv_rec_plot = PlotGrid()
    deconv_rec_plot.set_shape(5, len(connection_types))
    deconv_rec_plot.show()
# one row per connection type, two columns
summary_plot = PlotGrid()
summary_plot.set_shape(len(connection_types), 2)
summary_plot.show()
symbols = ['o', 's', 'd', '+', 't']
trace_color = (0, 0, 0, 5)  # RGBA; alpha 5 makes individual traces nearly transparent
# model.Dynamics = {'Dep':1, 'Fac':0, 'UR':0, 'SMR':0, 'DSR':0} # params = {(('2/3', 'unknown'), ('2/3', 'unknown')): [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # ((None,'rorb'), (None,'rorb')): [0, 506.7, 0, 0, 0, 0, 0.22, 0, 0, 0, 0, 0], # ((None,'sim1'), (None,'sim1')): [0, 1213.8, 0, 0, 0, 0, 0.17, 0, 0, 0, 0, 0], # ((None,'tlx3'), (None,'tlx3')): [0, 319.4, 0, 0, 0, 0, 0.16, 0, 0, 0, 0, 0]} ind = data[0] rec = data[1] freq = 10 delta =250 order = human_connections.keys() delay_order = [250, 500, 1000, 2000, 4000] ind_plt = PlotGrid() ind_plt.set_shape(4, 1) ind_plt.show() rec_plt = PlotGrid() rec_plt.set_shape(1, 2) rec_plt.show() ind_plt_scatter = pg.plot() ind_plt_all = PlotGrid() ind_plt_all.set_shape(1, 2) ind_plt_all.show() ind_50 = {} ninth_pulse_250 = {} gain_plot = pg.plot() for t, type in enumerate(order): if type == (('2', 'unknown'), ('2', 'unknown')): continue
# `parser` is an argparse.ArgumentParser created earlier in this script.
parser.add_argument('--link-y-axis', action='store_true', default=False, dest='link-y-axis',
                    help='link all y-axis down a column')

# vars() turns the Namespace into a dict, which is needed to read the hyphenated
# dest 'link-y-axis'.
args = vars(parser.parse_args(sys.argv[1:]))
plot_sweeps = args['sweeps']
plot_trains = args['trains']
link_y_axis = args['link-y-axis']
# NOTE(review): machine-specific absolute path; consider making it an argument.
expt_cache = 'C:/Users/Stephanies/multipatch_analysis/tools/expts_cache.pkl'
all_expts = ExperimentList(cache=expt_cache)

test = PlotGrid()
test.set_shape(len(connection_types.keys()), 1)
test.show()
# One row per connection type: first-pulse column, plus train and
# deconvolution columns when --trains is given.
grid = PlotGrid()
if plot_trains is True:
    grid.set_shape(len(connection_types.keys()), 3)
    grid[0, 1].setTitle(title='50 Hz Train')
    grid[0, 2].setTitle(title='Exponential Deconvolution')
    tau = 15e-3
    lp = 1000
else:
    grid.set_shape(len(connection_types.keys()), 2)
grid.show()
row = 0
grid[0, 0].setTitle(title='First Pulse')

# accumulators for y-axis maxima, filled later in the script
maxYpulse = []
maxYtrain = []
def distance_plot(connected, distance, plots=None, color=(100, 100, 255), size=10, window=40e-6, spacing=None,
                  name=None, fill_alpha=30):
    """Draw connectivity vs distance profiles with confidence intervals.

    Parameters
    ----------
    connected : boolean array
        Whether a synaptic connection was found for each probe
    distance : array
        Distance between cells for each probe
    plots : list of PlotWidget | PlotItem (optional)
        Two plots: ``plots[0]`` shows the distance profile and ``plots[1]`` the
        scatter of probes (may be None to skip the scatter). If omitted, a new
        2x1 PlotGrid is created and shown.
    color : tuple
        (R, G, B) color (anything accepted by ``pg.mkColor``) for the profile
        line and confidence interval. The scatter uses alpha 100 and the
        interval fill uses *fill_alpha*.
    size : int
        Size of the scatter plot symbols.
    window : float
        Width of distance window over which proportions are calculated for each
        point on the profile line.
    spacing : float
        Distance spacing between points on the profile line (default
        ``window / 4``). Note: a spacing smaller than the window size may make
        an otherwise smooth decrease over distance look more like a series of
        downward steps.
    name : str | None
        Legend name for the scatter and profile curves.
    fill_alpha : int
        Opacity of confidence interval fill (0-255)

    Returns
    -------
    plots : tuple
        The two plots used.
    ci_xvals : list
        Window centers that contained at least one probe.
    prop : list
        Connection proportion at *every* window center, with NaN where a window
        is empty. It can therefore be longer than *ci_xvals*.
    upper, lower : list
        Clopper-Pearson ('beta') confidence bounds, aligned with *ci_xvals*.
    """
    color = pg.colorTuple(pg.mkColor(color))[:3]
    connected = np.array(connected).astype(float)
    distance = np.array(distance)

    # scatter plot connections probed
    if plots is None:
        grid = PlotGrid()
        grid.set_shape(2, 1)
        grid.grid.ci.layout.setRowStretchFactor(0, 5)
        grid.grid.ci.layout.setRowStretchFactor(1, 10)
        plots = (grid[1, 0], grid[0, 0])
        plots[0].grid = grid
        plots[0].addLegend()
        grid.show()
    plots[0].setLabels(bottom=('distance', 'm'), left='connection probability')

    if plots[1] is not None:
        # scatter points a bit: connected probes are jittered below -5e-5,
        # unconnected probes above 0
        pts = np.vstack([distance, connected]).T
        conn = pts[:, 1] == 1
        unconn = pts[:, 1] == 0
        if np.any(conn):
            cscat = pg.pseudoScatter(pts[:, 0][conn], spacing=10e-6, bidir=False)
            pts[:, 1][conn] = -5e-5 - cscat * 0.2
        if np.any(unconn):
            uscat = pg.pseudoScatter(pts[:, 0][unconn], spacing=10e-6, bidir=False)
            pts[:, 1][unconn] = uscat * 0.2

        plots[1].setXLink(plots[0])
        plots[1].hideAxis('bottom')
        plots[1].hideAxis('left')
        # PlotDataItem takes the symbol size as `symbolSize`; a `size` kwarg
        # is silently ignored.
        scatter = plots[1].plot(pts[:, 0], pts[:, 1], pen=None, symbol='o', symbolSize=size,
                                symbolBrush=color + (100,), symbolPen=None, name=name)
        scatter.scatter.opts['compositionMode'] = pg.QtGui.QPainter.CompositionMode_Plus

    # use a sliding window to plot the proportion of connections found along
    # with a 95% confidence interval for connection probability
    if spacing is None:
        spacing = window / 4.0
    xvals = np.arange(window / 2.0, 500e-6, spacing)
    upper = []
    lower = []
    prop = []
    ci_xvals = []
    half = window / 2.0
    for x in xvals:
        # select points inside this window
        mask = (distance >= x - half) & (distance <= x + half)
        pts_in_window = connected[mask]
        # compute stats for window
        n_probed = pts_in_window.shape[0]
        n_conn = pts_in_window.sum()
        if n_probed == 0:
            prop.append(np.nan)
        else:
            prop.append(n_conn / n_probed)
            ci = proportion_confint(n_conn, n_probed, method='beta')
            lower.append(ci[0])
            upper.append(ci[1])
            ci_xvals.append(x)

    # plot connection probability and confidence intervals
    plots[0].plot(xvals, prop, pen={'color': color, 'width': 3}, antialias=True, name=name)
    upper_curve = plots[0].plot(ci_xvals, upper, pen=(0, 0, 0, 0), antialias=True)
    lower_curve = plots[0].plot(ci_xvals, lower, pen=(0, 0, 0, 0), antialias=True)
    upper_curve.setVisible(False)
    lower_curve.setVisible(False)
    fill = pg.FillBetweenItem(upper_curve, lower_curve, brush=color + (fill_alpha,))
    fill.setZValue(-10)
    plots[0].addItem(fill, ignoreBounds=True)

    return plots, ci_xvals, prop, upper, lower
class PairView(QtGui.QWidget):
    """For analyzing pre/post-synaptic pairs.

    The view shows:

    * pre/post traces for the selected channels;
    * per-pulse and global average post-synaptic responses with PSP fits;
    * a table of fit parameters.

    Event sets (fit tables) can be stored with "add" and fitted together with
    a release model via "fit all".
    """
    def __init__(self, parent=None):
        self.sweeps = []
        self.channels = []
        self.current_event_set = None
        self.event_sets = []

        QtGui.QWidget.__init__(self, parent)

        # NOTE: this attribute shadows QWidget.layout(); kept for compatibility.
        self.layout = QtGui.QGridLayout()
        self.setLayout(self.layout)
        self.layout.setContentsMargins(0, 0, 0, 0)

        self.vsplit = QtGui.QSplitter(QtCore.Qt.Vertical)
        self.layout.addWidget(self.vsplit, 0, 0)

        self.pre_plot = pg.PlotWidget()
        self.post_plot = pg.PlotWidget()
        for plt in (self.pre_plot, self.post_plot):
            plt.setClipToView(True)
            plt.setDownsampling(True, True, 'peak')

        self.post_plot.setXLink(self.pre_plot)
        self.vsplit.addWidget(self.pre_plot)
        self.vsplit.addWidget(self.post_plot)

        self.response_plots = PlotGrid()
        self.vsplit.addWidget(self.response_plots)

        self.event_splitter = QtGui.QSplitter(QtCore.Qt.Horizontal)
        self.vsplit.addWidget(self.event_splitter)

        self.event_table = pg.TableWidget()
        self.event_splitter.addWidget(self.event_table)

        self.event_widget = QtGui.QWidget()
        self.event_widget_layout = QtGui.QGridLayout()
        self.event_widget.setLayout(self.event_widget_layout)
        self.event_splitter.addWidget(self.event_widget)

        self.event_splitter.setSizes([600, 100])

        # Row 0 of this list is always the "current" (unsaved) event set.
        self.event_set_list = QtGui.QListWidget()
        self.event_widget_layout.addWidget(self.event_set_list, 0, 0, 1, 2)
        self.event_set_list.addItem("current")
        self.event_set_list.itemSelectionChanged.connect(self.event_set_selected)

        self.add_set_btn = QtGui.QPushButton('add')
        self.event_widget_layout.addWidget(self.add_set_btn, 1, 0, 1, 1)
        self.add_set_btn.clicked.connect(self.add_set_clicked)

        self.remove_set_btn = QtGui.QPushButton('del')
        self.event_widget_layout.addWidget(self.remove_set_btn, 1, 1, 1, 1)
        self.remove_set_btn.clicked.connect(self.remove_set_clicked)

        self.fit_btn = QtGui.QPushButton('fit all')
        self.event_widget_layout.addWidget(self.fit_btn, 2, 0, 1, 2)
        self.fit_btn.clicked.connect(self.fit_clicked)

        self.fit_plot = PlotGrid()

        self.fit_explorer = None

        self.artifact_remover = ArtifactRemover(user_width=True)
        self.baseline_remover = BaselineRemover()
        self.filter = SignalFilter()

        self.params = pg.parametertree.Parameter(name='params', type='group', children=[
            {'name': 'pre', 'type': 'list', 'values': []},
            {'name': 'post', 'type': 'list', 'values': []},
            self.artifact_remover.params,
            self.baseline_remover.params,
            self.filter.params,
            {'name': 'time constant', 'type': 'float', 'suffix': 's', 'siPrefix': True, 'value': 10e-3,
             'dec': True, 'minStep': 100e-6}
        ])
        self.params.sigTreeStateChanged.connect(self._update_plots)

    def data_selected(self, sweeps, channels):
        """Set the sweeps/channels to analyze and redraw."""
        self.sweeps = sweeps
        self.channels = channels

        self.params.child('pre').setLimits(channels)
        self.params.child('post').setLimits(channels)

        self._update_plots()

    def _update_plots(self):
        """Re-process the selected sweeps and redraw all plots and the event table."""
        sweeps = self.sweeps
        self.current_event_set = None
        self.event_table.clear()

        # clear all plots
        self.pre_plot.clear()
        self.post_plot.clear()

        pre = self.params['pre']
        post = self.params['post']

        # If there are no selected sweeps or channels have not been set, return
        if len(sweeps) == 0 or pre == post or pre not in self.channels or post not in self.channels:
            return

        pre_mode = sweeps[0][pre].clamp_mode
        post_mode = sweeps[0][post].clamp_mode
        for ch, mode, plot in [(pre, pre_mode, self.pre_plot), (post, post_mode, self.post_plot)]:
            units = 'A' if mode == 'vc' else 'V'
            plot.setLabels(left=("Channel %d" % ch, units), bottom=("Time", 's'))

        # Iterate over selected channels of all sweeps, plotting traces one at a time
        # Collect information about pulses and spikes
        pulses = []
        spikes = []
        post_traces = []
        for i, sweep in enumerate(sweeps):
            pre_trace = sweep[pre]['primary']
            post_trace = sweep[post]['primary']

            # Detect pulse times from rising/falling edges of the command
            stim = sweep[pre]['command'].data
            sdiff = np.diff(stim)
            on_times = np.argwhere(sdiff > 0)[1:, 0]  # 1: skips test pulse
            off_times = np.argwhere(sdiff < 0)[1:, 0]
            pulses.append(on_times)

            # filter data
            post_filt = self.artifact_remover.process(post_trace, list(on_times) + list(off_times))
            post_filt = self.baseline_remover.process(post_filt)
            post_filt = self.filter.process(post_filt)
            post_traces.append(post_filt)

            # plot raw pre trace and filtered post trace
            color = pg.intColor(i, hues=len(sweeps)*1.3, sat=128)
            color.setAlpha(128)
            for trace, plot in [(pre_trace, self.pre_plot), (post_filt, self.post_plot)]:
                plot.plot(trace.time_values, trace.data, pen=color, antialias=False)

            # detect spike times (None where a pulse evoked no spike)
            spike_inds = []
            spike_info = []
            for on, off in zip(on_times, off_times):
                spike = detect_evoked_spike(sweep[pre], [on, off])
                spike_info.append(spike)
                if spike is None:
                    spike_inds.append(None)
                else:
                    spike_inds.append(spike['rise_index'])
            spikes.append(spike_info)

            dt = pre_trace.dt
            vticks = pg.VTickGroup([x * dt for x in spike_inds if x is not None], yrange=[0.0, 0.2], pen=color)
            self.pre_plot.addItem(vticks)

        # Iterate over spikes, plotting average response
        avg_responses = []
        fits = []  # (pulse index, fit) pairs; pulses without responses are skipped

        npulses = max(map(len, pulses))
        self.response_plots.clear()
        self.response_plots.set_shape(1, npulses+1)  # 1 extra for global average
        self.response_plots.setYLink(self.response_plots[0, 0])
        for i in range(1, npulses+1):
            self.response_plots[0, i].hideAxis('left')
        units = 'A' if post_mode == 'vc' else 'V'
        self.response_plots[0, 0].setLabels(left=("Averaged events (Channel %d)" % post, units))

        fit_pen = {'color': (30, 30, 255), 'width': 2, 'dash': [1, 1]}
        max_len = int(40e-3 / dt)  # don't take more than 40 ms for any response
        for i in range(npulses):
            # get the chunk of each sweep between spikes
            responses = []
            for j, sweep in enumerate(sweeps):
                # get the current spike
                if i >= len(spikes[j]):
                    continue
                spike = spikes[j][i]
                if spike is None:
                    continue

                # find next spike
                next_spike = None
                for sp in spikes[j][i+1:]:
                    if sp is not None:
                        next_spike = sp
                        break

                # determine time range for response
                start = spike['rise_index']
                if next_spike is not None:
                    stop = min(start + max_len, next_spike['rise_index'])
                else:
                    stop = start + max_len

                # collect data from this trace
                trace = post_traces[j]
                d = trace.data[start:stop].copy()
                responses.append(d)

            if len(responses) == 0:
                continue

            # combine responses of differing length and average them
            avg = ragged_mean(responses, method='clip')
            avg -= float_mode(avg[:int(1e-3/dt)])
            avg_responses.append(avg)

            # plot average response for this pulse
            t = np.arange(len(avg)) * dt
            self.response_plots[0, i].plot(t, avg, pen='w', antialias=True)

            # fit!
            fit = self.fit_psp(avg, t, dt, post_mode)
            fits.append((i, fit))

            # let the user mess with this fit
            curve = self.response_plots[0, i].plot(t, fit.eval(), pen=fit_pen, antialias=True).curve
            curve.setClickable(True)
            curve.fit = fit
            curve.sigClicked.connect(self.fit_curve_clicked)

        # display global average (only possible if some pulse had responses)
        global_fit = None
        if len(avg_responses) > 0:
            global_avg = ragged_mean(avg_responses, method='clip')
            t = np.arange(len(global_avg)) * dt
            self.response_plots[0, -1].plot(t, global_avg, pen='w', antialias=True)
            global_fit = self.fit_psp(global_avg, t, dt, post_mode)
            self.response_plots[0, -1].plot(t, global_fit.eval(), pen=fit_pen, antialias=True)

        # display fit parameters in table; per-pulse rows first, then the
        # global average row last (fit_clicked relies on this ordering)
        events = []
        for pulse_ind, f in fits:
            spt = [s[pulse_ind]['peak_index'] * dt for s in spikes
                   if pulse_ind < len(s) and s[pulse_ind] is not None]
            vals = OrderedDict([('id', pulse_ind), ('spike_time', np.mean(spt)), ('spike_stdev', np.std(spt))])
            vals.update(OrderedDict([(k, f.best_values[k]) for k in f.params.keys()]))
            events.append(vals)
        if global_fit is not None:
            vals = OrderedDict([('id', 'avg'), ('spike_time', np.nan), ('spike_stdev', np.nan)])
            vals.update(OrderedDict([(k, global_fit.best_values[k]) for k in global_fit.params.keys()]))
            events.append(vals)

        self.current_event_set = (pre, post, events, sweeps)
        self.event_set_list.setCurrentRow(0)
        self.event_set_selected()

    def fit_psp(self, data, t, dt, clamp_mode):
        """Fit a PSP model to *data*; the amplitude sign is guessed from the data mean."""
        # compare the mean to the baseline mode of the first 1 ms to pick a sign
        mode = float_mode(data[:int(1e-3/dt)])
        sign = -1 if data.mean() - mode < 0 else 1
        # tuples are presumably (initial, min, max) -- see fitting.Psp
        params = OrderedDict([
            ('xoffset', (2e-3, 3e-4, 5e-3)),
            ('yoffset', data[0]),
            ('amp', sign * 10e-12),
            ('rise_time', (2e-3, 50e-6, 10e-3)),
            ('decay_tau', (4e-3, 500e-6, 50e-3)),
            ('rise_power', (2.0, 'fixed')),
        ])
        if clamp_mode == 'ic':
            params['amp'] = sign * 10e-3
            params['rise_time'] = (5e-3, 50e-6, 20e-3)
            params['decay_tau'] = (15e-3, 500e-6, 150e-3)

        fit_kws = {'xtol': 1e-3, 'maxfev': 100}

        psp = fitting.Psp()
        return psp.fit(data, x=t, fit_kws=fit_kws, params=params)

    def add_set_clicked(self):
        """Store the current event set and list it as "pre -> post"."""
        if self.current_event_set is None:
            return
        ces = self.current_event_set
        self.event_sets.append(ces)
        item = QtGui.QListWidgetItem("%d -> %d" % (ces[0], ces[1]))
        self.event_set_list.addItem(item)
        item.event_set = ces

    def remove_set_clicked(self):
        """Remove the selected stored event set (the "current" row cannot be removed)."""
        row = self.event_set_list.currentRow()
        # row 0 is the permanent "current" entry; -1 means no current row
        if row < 1:
            return
        sel = self.event_set_list.takeItem(row)
        self.event_sets.remove(sel.event_set)

    def event_set_selected(self):
        """Show the selected event set in the table."""
        selected = self.event_set_list.selectedItems()
        if len(selected) == 0:
            # the selection can become empty, e.g. when it is cleared
            return
        sel = selected[0]
        if sel.text() == "current":
            if self.current_event_set is None:
                self.event_table.clear()
            else:
                self.event_table.setData(self.current_event_set[2])
        else:
            self.event_table.setData(sel.event_set[2])

    def fit_clicked(self):
        """Fit release models with progressively more dynamics to all stored event sets."""
        from synaptic_release import ReleaseModel

        n_sets = len(self.event_sets)
        if n_sets == 0:
            # nothing stored yet; fit_plot[0, 0] would not exist
            return

        self.fit_plot.clear()
        self.fit_plot.show()

        self.fit_plot.set_shape(n_sets, 1)
        legend = self.fit_plot[0, 0].legend
        if legend is not None:
            # QGraphicsItem.scene() is a method and returns None if detached
            scene = legend.scene()
            if scene is not None:
                scene.removeItem(legend)
            self.fit_plot[0, 0].legend = None
        self.fit_plot[0, 0].addLegend()

        for i in range(n_sets):
            self.fit_plot[i, 0].setXLink(self.fit_plot[0, 0])

        spike_sets = []
        for i, evset in enumerate(self.event_sets):
            evset = evset[2][:-1]  # select events, skip last row (average)
            x = np.array([ev['spike_time'] for ev in evset])
            y = np.array([ev['amp'] for ev in evset])
            # times relative to the first spike, in ms; amplitudes relative to the first
            x -= x[0]
            x *= 1000
            y /= y[0]
            spike_sets.append((x, y))

            self.fit_plot[i, 0].plot(x/1000., y, pen=None, symbol='o')

        model = ReleaseModel()
        dynamics_types = ['Dep', 'Fac', 'UR', 'SMR', 'DSR']
        for k in dynamics_types:
            model.Dynamics[k] = 0

        fit_params = []
        with pg.ProgressDialog("Fitting release model..", 0, len(dynamics_types)) as dlg:
            for k in dynamics_types:
                # dynamics are enabled cumulatively: each fit adds one mechanism
                model.Dynamics[k] = 1
                fit_params.append(model.run_fit(spike_sets))
                dlg += 1
                if dlg.wasCanceled():
                    return

        max_color = len(fit_params)*1.5

        # NOTE(review): every fit is evaluated with all dynamics enabled (the
        # final model.Dynamics state) -- confirm this is intended.
        for i, params in enumerate(fit_params):
            for j, spikes in enumerate(spike_sets):
                x, y = spikes
                output = model.eval(x, params.values())
                y = output[:, 1]
                x = output[:, 0]/1000.
                self.fit_plot[j, 0].plot(x, y, pen=(i, max_color), name=dynamics_types[i])

    def fit_curve_clicked(self, curve):
        """Open (or update) the FitExplorer for the clicked fit curve."""
        if self.fit_explorer is None:
            self.fit_explorer = FitExplorer(curve.fit)
        else:
            self.fit_explorer.set_fit(curve.fit)
        self.fit_explorer.show()
def summary_plot_pulse(feature_list, labels, titles, i, median=False, grand_trace=None, plot=None, color=None,
                       name=None):
    """
    Plots features of single-pulse responses such as amplitude, latency, etc. for group analysis. Can be used for
    one group but is ideal for comparing across many groups in the feature_list.

    Parameters
    ----------
    feature_list : list of floats, or tuple of lists of floats
        single-pulse features such as amplitude. Pass a tuple of lists to plot several features (one row each);
        any other sequence is treated as a single feature.
    labels : sequence
        left-axis label given as a (text, units) pair; for several features, a list of such pairs of the same
        length as feature_list
    titles : str or list of str
        plot title; for several features, a list of the same length as feature_list
    i : integer
        x-axis position at which this group is drawn
    median : boolean
        draw the median (True) instead of the mean (False) as the large summary marker, default is False.
        NaN values are ignored in both cases.
    grand_trace : neuroanalysis.data.TraceView object
        option to plot response trace alongside scatter plot, default is None
    plot : PlotGrid
        If not None, draw into this existing grid instead of creating a new one.
    color : tuple, str or None
        plot color. The individual points use the same RGB with alpha 100; with None they are drawn unfilled.
    name : str
        legend entry for the grand_trace curve

    Returns
    -------
    plot : PlotGrid
        n x 2 grid with scatter plot (column 0) and optional trace response plot (column 1) for each feature (n)
    """
    n_features = len(feature_list) if isinstance(feature_list, tuple) else 1

    if plot is None:
        plot = PlotGrid()
        plot.set_shape(n_features, 2)
        plot.show()
        for g in range(n_features):
            plot[g, 1].addLegend()

    for feature in range(n_features):
        if n_features > 1:
            current_feature = feature_list[feature]
            label = labels[feature]
            title = titles[feature]
        else:
            current_feature = feature_list
            label = labels
            title = titles

        scatter_plot = plot[feature, 0]
        scatter_plot.setLabels(left=(label[0], label[1]))
        scatter_plot.hideAxis('bottom')
        scatter_plot.setTitle(title)
        if grand_trace is not None:
            plot[feature, 1].plot(grand_trace.time_values, grand_trace.data, pen=color, name=name)

        if len(current_feature) > 1:
            # Summary marker: median or mean, ignoring NaNs (honoured for single and multiple features alike).
            center = np.nanmedian(current_feature) if median else np.nanmean(current_feature)
            dx = pg.pseudoScatter(np.array(current_feature).astype(float), 0.7, bidir=True)
            scatter_plot.plot([i], [center], symbol='o', symbolSize=20, symbolPen='k', symbolBrush=color)

            # Normalise the jitter to +/-0.3 around x=i. When all points are further apart than the
            # pseudoScatter spacing they all get offset 0; dividing by dx.max() would then yield NaN
            # x positions and the points would silently disappear.
            dx_scale = dx.max()
            if dx_scale <= 0:
                dx_scale = np.abs(dx).max()
            x_offsets = 0.3 * dx / dx_scale if dx_scale > 0 else np.zeros_like(dx)

            if color is None:
                scatter_brush = None
            else:
                if len(color) != 3:
                    # glColor returns RGBA floats in 0-1; rescale RGB to 0-255 and drop alpha
                    gl_color = pg.glColor(color)
                    rgb = (gl_color[0] * 255, gl_color[1] * 255, gl_color[2] * 255)
                else:
                    rgb = color
                scatter_brush = (rgb[0], rgb[1], rgb[2], 100)
            scatter_plot.plot(x_offsets + i, current_feature, pen=None, symbol='o', symbolSize=10, symbolPen='w',
                              symbolBrush=scatter_brush)
        elif len(current_feature) == 1:
            scatter_plot.plot([i], current_feature, pen=None, symbol='o', symbolSize=10, symbolPen='w',
                              symbolBrush=color)
        # an empty feature has nothing to scatter

    return plot
def plot_train_responses(self, plot_grid=None):
    """
    Plot individual and averaged train responses for each set of stimulus parameters.

    One grid row is drawn per entry of self.train_responses. Column 0 holds the induction
    responses and column 1 the recovery responses. Each individual sweep is shifted by the
    median of its induction baseline and drawn in translucent gray. The baseline-subtracted
    average (bsub_mean) is drawn in green on top.

    Parameters
    ----------
    plot_grid : PlotGrid or None
        Grid to draw into; a new PlotGrid is created when None.

    Returns the PlotGrid that was drawn into.
    """
    grid = PlotGrid() if plot_grid is None else plot_grid
    grid.set_shape(len(self.stim_param_order), 2)

    responses = self.train_responses
    sweep_pen = (128, 128, 128, 100)
    for row, params in enumerate(responses.keys()):
        # induction / recovery groups for this set of stim params
        ind_group = responses[params][0]
        rec_group = responses[params][1]
        ind_plot = grid[row, 0]
        rec_plot = grid[row, 1]

        for sweep in range(len(ind_group)):
            ind_trace = ind_group.responses[sweep]
            rec_trace = rec_group.responses[sweep]
            # the induction baseline is used for both halves of the sweep
            offset = np.median(ind_group.baselines[sweep].data)
            ind_plot.plot(ind_trace.time_values, ind_trace.data - offset, pen=sweep_pen)
            rec_plot.plot(rec_trace.time_values, rec_trace.data - offset, pen=sweep_pen)

        ind_mean = ind_group.bsub_mean()
        rec_mean = rec_group.bsub_mean()
        ind_freq, rec_delay, holding = params
        rec_delay = np.round(rec_delay, 2)
        ind_plot.plot(ind_mean.time_values, ind_mean.data, pen='g', antialias=True)
        rec_plot.plot(rec_mean.time_values, rec_mean.data, pen='g', antialias=True)
        ind_plot.setLabels(left=('Vm', 'V'))

        text = "ind: %0.0f rec: %0.0f hold: %0.0f" % (ind_freq, rec_delay * 1000, holding * 1000)
        label = pg.LabelItem(text)
        label.setParentItem(ind_plot.vb)
        ind_plot.label = label  # keep a reference on the plot

    grid.show()
    grid.setYLink(grid[0, 0])
    ind_anchor = grid[0, 0]
    rec_anchor = grid[0, 1]
    for row in range(grid.shape[0]):
        grid[row, 0].setXLink(ind_anchor)
        grid[row, 1].setXLink(rec_anchor)

    layout = grid.grid.ci.layout
    layout.setColumnStretchFactor(0, 3)
    layout.setColumnStretchFactor(1, 2)
    grid.setClipToView(False)  # has a bug :(
    grid.setDownsampling(True, True, 'peak')
    return grid
def save_fit_psp_test_set():
    """NOTE THIS CODE DOES NOT WORK BUT IS HERE FOR DOCUMENTATION PURPOSES SO THAT WE CAN TRACE BACK HOW THE TEST
    DATA WAS CREATED IF NEEDED.

    Create a test set of data for testing the fit_psp function. Uses Steph's original first_puls_feature.py code
    to filter out error causing data.

    For every accepted connection this builds the fit inputs (``save_dict['input']``). It then fits the original
    averaged trace and a Trace rebuilt from the serialized inputs, and prints a warning if their nrmse values
    differ. The step that writes ``save_dict`` into ``test_psp_fit`` is not present here.

    Example run statement
    python save_fit_psp_test_set.py --organism mouse --connection ee

    Comment in the code that does the saving at the bottom
    """
    import pyqtgraph as pg
    import numpy as np
    import csv
    import sys
    import argparse
    from multipatch_analysis.experiment_list import cached_experiments
    from manuscript_figures import get_response, get_amplitude, response_filter, feature_anova, write_cache, trace_plot, \
        colors_human, colors_mouse, fail_rate, pulse_qc, feature_kw
    from synapse_comparison import load_cache, summary_plot_pulse
    from neuroanalysis.data import TraceList, Trace
    from neuroanalysis.ui.plot_grid import PlotGrid
    from multipatch_analysis.connection_detection import fit_psp
    from rep_connections import ee_connections, human_connections, no_include, all_connections, ie_connections, \
        ii_connections, ei_connections
    from multipatch_analysis.synaptic_dynamics import DynamicsAnalyzer
    from scipy import stats
    import time
    import pandas as pd
    import json
    import os

    app = pg.mkQApp()
    pg.dbg()
    pg.setConfigOption('background', 'w')
    pg.setConfigOption('foreground', 'k')

    parser = argparse.ArgumentParser(
        description='Enter organism and type of connection you"d like to analyze ex: mouse ee (all mouse excitatory-'
                    'excitatory). Alternatively enter a cre-type connection ex: sim1-sim1')
    parser.add_argument('--organism', dest='organism', help='Select mouse or human')
    parser.add_argument('--connection', dest='connection', help='Specify connections to analyze')
    args = vars(parser.parse_args(sys.argv[1:]))

    all_expts = cached_experiments()
    manifest = {'Type': [], 'Connection': [], 'amp': [], 'latency': [], 'rise': [], 'rise2080': [], 'rise1090': [],
                'rise1080': [], 'decay': [], 'nrmse': [], 'CV': []}
    fit_qc = {'nrmse': 8, 'decay': 499e-3}

    # list(...) so connection_types can be indexed under Python 3 as well (keys() is a view there)
    if args['organism'] == 'mouse':
        color_palette = colors_mouse
        calcium = 'high'
        age = '40-60'
        sweep_threshold = 3
        threshold = 0.03e-3
        connection = args['connection']
        if connection == 'ee':
            connection_types = list(ee_connections.keys())
        elif connection == 'ii':
            connection_types = list(ii_connections.keys())
        elif connection == 'ei':
            connection_types = list(ei_connections.keys())
        elif connection == 'ie':
            # was `==` (a no-op comparison) which left connection_types unset
            connection_types = list(ie_connections.keys())
        elif connection == 'all':
            connection_types = list(all_connections.keys())
        elif len(connection.split('-')) == 2:
            c_type = connection.split('-')
            if c_type[0] == '2/3':
                pre_type = ('2/3', 'unknown')
            else:
                pre_type = (None, c_type[0])
            if c_type[1] == '2/3':
                post_type = ('2/3', 'unknown')
            else:
                # the post-synaptic type comes from the second half of "pre-post"
                post_type = (None, c_type[1])
            connection_types = [(pre_type, post_type)]
        else:
            raise ValueError("Unrecognized --connection %r; expected ee, ii, ei, ie, all or pre-post" % connection)
    elif args['organism'] == 'human':
        color_palette = colors_human
        calcium = None
        age = None
        sweep_threshold = 5
        threshold = None
        connection = args['connection']
        if connection == 'ee':
            connection_types = list(human_connections.keys())
        else:
            c_type = connection.split('-')
            connection_types = [((c_type[0], 'unknown'), (c_type[1], 'unknown'))]
    else:
        raise ValueError("--organism must be 'mouse' or 'human', got %r" % args['organism'])

    plt = pg.plot()
    scale_offset = (-20, -20)
    scale_anchor = (0.4, 1)
    holding = [-65, -75]
    qc_plot = pg.plot()
    grand_response = {}
    expt_ids = {}
    feature_plot = None
    feature2_plot = PlotGrid()
    feature2_plot.set_shape(5, 1)
    feature2_plot.show()
    feature3_plot = PlotGrid()
    feature3_plot.set_shape(1, 3)
    feature3_plot.show()
    amp_plot = pg.plot()
    synapse_plot = PlotGrid()
    synapse_plot.set_shape(len(connection_types), 1)
    synapse_plot.show()

    for c in range(len(connection_types)):
        cre_type = (connection_types[c][0][1], connection_types[c][1][1])
        target_layer = (connection_types[c][0][0], connection_types[c][1][0])
        conn_type = connection_types[c]
        expt_list = all_expts.select(cre_type=cre_type, target_layer=target_layer, calcium=calcium, age=age)
        color = color_palette[c]
        grand_response[conn_type[0]] = {'trace': [], 'amp': [], 'latency': [], 'rise': [], 'dist': [], 'decay': [],
                                        'CV': [], 'amp_measured': []}
        expt_ids[conn_type[0]] = []
        synapse_plot[c, 0].addLegend()
        for expt in expt_list:
            for pre, post in expt.connections:
                if [expt.uid, pre, post] in no_include:
                    continue
                cre_check = expt.cells[pre].cre_type == cre_type[0] and expt.cells[post].cre_type == cre_type[1]
                layer_check = (expt.cells[pre].target_layer == target_layer[0]
                               and expt.cells[post].target_layer == target_layer[1])
                if not (cre_check is True and layer_check is True):
                    continue
                pulse_response, artifact = get_response(expt, pre, post, analysis_type='pulse')
                if threshold is not None and artifact > threshold:
                    continue
                response_subset, hold = response_filter(pulse_response, freq_range=[0, 50],
                                                        holding_range=holding, pulse=True)
                if len(response_subset) < sweep_threshold:
                    continue
                qc_plot.clear()
                qc_list = pulse_qc(response_subset, baseline=1.5, pulse=None, plot=qc_plot)
                if len(qc_list) < sweep_threshold:
                    continue
                avg_trace, avg_amp, amp_sign, peak_t = get_amplitude(qc_list)
                # if amp_sign is '-':
                #     continue
                # #print ('%s, %0.0f' %((expt.uid, pre, post), hold, ))
                # all_amps = fail_rate(response_subset, '+', peak_t)
                # cv = np.std(all_amps)/np.mean(all_amps)

                # weight parts of the trace during fitting
                dt = avg_trace.dt
                weight = np.ones(len(avg_trace.data)) * 10.  # set everything to ten initially
                weight[int(10e-3 / dt):int(12e-3 / dt)] = 0.  # area around stim artifact
                weight[int(12e-3 / dt):int(19e-3 / dt)] = 30.  # area around steep PSP rise

                # check if the test data dir is there and if not create it
                test_data_dir = 'test_psp_fit'
                if not os.path.isdir(test_data_dir):
                    os.mkdir(test_data_dir)

                save_dict = {}
                save_dict['input'] = {'data': avg_trace.data.tolist(),
                                      'dtype': str(avg_trace.data.dtype),
                                      'dt': float(avg_trace.dt),
                                      'amp_sign': amp_sign,
                                      'yoffset': 0,
                                      'xoffset': 14e-3,
                                      'avg_amp': float(avg_amp),
                                      'method': 'leastsq',
                                      'stacked': False,
                                      'rise_time_mult_factor': 10.,
                                      'weight': weight.tolist()}
                fit_inputs = save_dict['input']

                # need to remake trace because different output is created
                avg_trace_simple = Trace(data=np.array(fit_inputs['data']), dt=fit_inputs['dt'])  # create Trace object

                psp_fits_original = fit_psp(avg_trace,
                                            sign=fit_inputs['amp_sign'],
                                            yoffset=fit_inputs['yoffset'],
                                            xoffset=fit_inputs['xoffset'],
                                            amp=fit_inputs['avg_amp'],
                                            method=fit_inputs['method'],
                                            stacked=fit_inputs['stacked'],
                                            rise_time_mult_factor=fit_inputs['rise_time_mult_factor'],
                                            fit_kws={'weights': fit_inputs['weight']})
                psp_fits_simple = fit_psp(avg_trace_simple,
                                          sign=fit_inputs['amp_sign'],
                                          yoffset=fit_inputs['yoffset'],
                                          xoffset=fit_inputs['xoffset'],
                                          amp=fit_inputs['avg_amp'],
                                          method=fit_inputs['method'],
                                          stacked=fit_inputs['stacked'],
                                          rise_time_mult_factor=fit_inputs['rise_time_mult_factor'],
                                          fit_kws={'weights': fit_inputs['weight']})

                # single-argument print() gives identical output under Python 2 and 3
                print('%s %s %s' % (expt.uid, pre, post))
                if psp_fits_original.nrmse() != psp_fits_simple.nrmse():
                    print(' the nrmse values dont match')
                    print('\toriginal %s' % psp_fits_original.nrmse())
                    print('\tsimple %s' % psp_fits_simple.nrmse())
def summary_plot_pulse(feature_list, labels, titles, i, median=False, grand_trace=None, plot=None, color=None,
                       name=None):
    """
    Plots features of single-pulse responses such as amplitude, latency, etc. for group analysis. Can be used for
    one group but is ideal for comparing across many groups in the feature_list.

    Parameters
    ----------
    feature_list : list of floats, or tuple of lists of floats
        single-pulse features such as amplitude. Pass a tuple of lists to plot several features (one row each);
        any other sequence is treated as a single feature.
    labels : sequence
        left-axis label given as a (text, units) pair; for several features, a list of such pairs of the same
        length as feature_list
    titles : str or list of str
        plot title; for several features, a list of the same length as feature_list
    i : integer
        x-axis position at which this group is drawn
    median : boolean
        draw the median (True) instead of the mean (False) as the large summary marker, default is False.
        NaN values are ignored in both cases.
    grand_trace : neuroanalysis.data.TSeriesView object
        option to plot response trace alongside scatter plot, default is None
    plot : PlotGrid
        If not None, draw into this existing grid instead of creating a new one.
    color : tuple, str or None
        plot color. The individual points use the same RGB with alpha 100; with None they are drawn unfilled.
    name : str
        legend entry for the grand_trace curve

    Returns
    -------
    plot : PlotGrid
        n x 2 grid with scatter plot (column 0) and optional trace response plot (column 1) for each feature (n)
    """
    n_features = len(feature_list) if isinstance(feature_list, tuple) else 1

    if plot is None:
        plot = PlotGrid()
        plot.set_shape(n_features, 2)
        plot.show()
        for g in range(n_features):
            plot[g, 1].addLegend()

    for feature in range(n_features):
        if n_features > 1:
            current_feature = feature_list[feature]
            label = labels[feature]
            title = titles[feature]
        else:
            current_feature = feature_list
            label = labels
            title = titles

        scatter_plot = plot[feature, 0]
        scatter_plot.setLabels(left=(label[0], label[1]))
        scatter_plot.hideAxis('bottom')
        scatter_plot.setTitle(title)
        if grand_trace is not None:
            plot[feature, 1].plot(grand_trace.time_values, grand_trace.data, pen=color, name=name)

        if len(current_feature) > 1:
            # Summary marker: median or mean, ignoring NaNs (honoured for single and multiple features alike).
            center = np.nanmedian(current_feature) if median else np.nanmean(current_feature)
            dx = pg.pseudoScatter(np.array(current_feature).astype(float), 0.7, bidir=True)
            scatter_plot.plot([i], [center], symbol='o', symbolSize=20, symbolPen='k', symbolBrush=color)

            # Normalise the jitter to +/-0.3 around x=i. When all points are further apart than the
            # pseudoScatter spacing they all get offset 0; dividing by dx.max() would then yield NaN
            # x positions and the points would silently disappear.
            dx_scale = dx.max()
            if dx_scale <= 0:
                dx_scale = np.abs(dx).max()
            x_offsets = 0.3 * dx / dx_scale if dx_scale > 0 else np.zeros_like(dx)

            if color is None:
                scatter_brush = None
            else:
                if len(color) != 3:
                    # glColor returns RGBA floats in 0-1; rescale RGB to 0-255 and drop alpha
                    gl_color = pg.glColor(color)
                    rgb = (gl_color[0] * 255, gl_color[1] * 255, gl_color[2] * 255)
                else:
                    rgb = color
                scatter_brush = (rgb[0], rgb[1], rgb[2], 100)
            scatter_plot.plot(x_offsets + i, current_feature, pen=None, symbol='o', symbolSize=10, symbolPen='w',
                              symbolBrush=scatter_brush)
        elif len(current_feature) == 1:
            scatter_plot.plot([i], current_feature, pen=None, symbol='o', symbolSize=10, symbolPen='w',
                              symbolBrush=color)
        # an empty feature has nothing to scatter

    return plot
def summary_plot_pulse(feature_list, feature_mean, labels, titles, i, grand_trace=None, plot=None, color=None,
                       name=None):
    """Plot single-pulse features as a bar of a precomputed summary value with an SEM error bar and jittered points.

    Parameters
    ----------
    feature_list : list of floats, or tuple of lists of floats
        Feature values. A tuple of lists plots several features (one grid row each); any other
        sequence is treated as a single feature.
    feature_mean : float, or sequence of floats
        Bar height; one value per feature when several features are plotted.
    labels : sequence
        Left-axis label as a (text, units) pair; for several features a list of such pairs.
    titles : str or list of str
        Plot title(s); a list of the same length as feature_list for several features.
    i : number
        x-axis position at which this group is drawn.
    grand_trace : object with ``time_values`` and ``data``, optional
        Trace drawn in column 1 of each row.
    plot : PlotGrid, optional
        Existing grid to draw into; a new n x 2 grid is created when None.
    color : optional
        Bar outline, point fill and trace pen color.
    name : str, optional
        Legend entry for the grand_trace curve.

    Returns
    -------
    plot : PlotGrid
        The grid that was drawn into.
    """
    n_features = len(feature_list) if isinstance(feature_list, tuple) else 1

    if plot is None:
        plot = PlotGrid()
        plot.set_shape(n_features, 2)
        plot.show()
        for g in range(n_features):
            plot[g, 1].addLegend()
            plot[g, 1].setLabels(left=('Vm', 'V'))
            plot[g, 1].setLabels(bottom=('t', 's'))

    for feature in range(n_features):
        if n_features > 1:
            features = feature_list[feature]
            mean = feature_mean[feature]
            label = labels[feature]
            title = titles[feature]
        else:
            features = feature_list
            mean = feature_mean
            label = labels
            title = titles

        scatter_plot = plot[feature, 0]
        scatter_plot.setLabels(left=(label[0], label[1]))
        scatter_plot.hideAxis('bottom')
        scatter_plot.setTitle(title)
        if grand_trace is not None:
            plot[feature, 1].plot(grand_trace.time_values, grand_trace.data, pen=color, name=name)

        if len(features) > 1:
            dx = pg.pseudoScatter(np.array(features).astype(float), 0.3, bidir=True)
            bar = pg.BarGraphItem(x=[i], height=mean, width=0.7, brush='w', pen={'color': color, 'width': 2})
            scatter_plot.addItem(bar)
            sem = stats.sem(features)
            err = pg.ErrorBarItem(x=np.asarray([i]), y=np.asarray([mean]), height=sem, beam=0.3)
            scatter_plot.addItem(err)

            # Normalise the jitter to +/-0.3 around x=i. pseudoScatter returns all zeros when the
            # points are further apart than its spacing; dividing by dx.max() would then produce
            # NaN x positions and hide every point.
            dx_scale = dx.max()
            if dx_scale <= 0:
                dx_scale = np.abs(dx).max()
            x_offsets = 0.3 * dx / dx_scale if dx_scale > 0 else np.zeros_like(dx)
            scatter_plot.plot(x_offsets + i, features, pen=None, symbol='o', symbolSize=10, symbolPen='w',
                              symbolBrush=color)
        elif len(features) == 1:
            scatter_plot.plot([i], features, pen=None, symbol='o', symbolSize=10, symbolPen='w', symbolBrush=color)
        # an empty feature has nothing to scatter

    return plot
# Per-rig baseline RMS noise summary.
# NOTE(review): `rows` is produced by code outside this chunk; columns are read here as
# (rms, rig, hs, timestamp) with row[3] exposing .timetuple() -- confirm against the query.
rms = np.array([row[0] for row in rows])
rig = np.array([row[1] for row in rows]).astype(int)
hs = np.array([row[2] for row in rows]).astype(int)
col = rig * 8 + hs  # one x column per (rig, hs) pair
ts = np.array([time.mktime(row[3].timetuple()) for row in rows])
# ts -= ts[0]  # uncomment to plot time relative to the first record

# jittered scatter of rms noise, one column per (rig, hs)
pg.plot(col + np.random.uniform(size=len(col)) * 0.7, rms, pen=None, symbol='o', symbolPen=None,
        symbolBrush=(255, 255, 255, 50))

plt = pg.plot(labels={'left': 'number of sweeps (normalized per rig)', 'bottom': ('baseline rms error', 'V')})
plt.addLegend()
grid = PlotGrid()
grid.set_shape(3, 1)
grid.show()
for r, c in ((1, 'r'), (2, 'g'), (3, 'b')):
    mask = rig == r
    rig_data = rms[mask]
    y, x = np.histogram(rig_data, bins=np.linspace(0, 0.002, 1000))
    # Histogram counts are integers: divide by a float so the normalisation is true division
    # under Python 2 as well (otherwise every bin floors to 0). max(..., 1) avoids 0/0 for a
    # rig with no records.
    n_rig = float(max(len(rig_data), 1))
    plt.plot(x, y / n_rig, stepMode=True, connect='finite', pen=c, name="Rig %d" % r)
    p = grid[r - 1, 0]
    p.plot(ts[mask], rig_data, pen=None, symbol='o', symbolPen=None, symbolBrush=(255, 255, 255, 100))
    p.setLabels(left=('rig %d baseline rms noise' % r, 'V'))
grid.setXLink(grid[0, 0])
grid.setYLink(grid[0, 0])
class DistancePlot(object):
    """Connection probability vs. distance plot whose binning window can be adjusted interactively.

    A 2 x 1 PlotGrid holds the probability profile (row 1, larger stretch) and a scatter plot of
    the probed pairs (row 0). Changing the 'Distance binning window' parameter redraws the first
    results that were plotted, followed by any elements added with element_distance().
    """

    def __init__(self):
        self.grid = PlotGrid()
        self.grid.set_shape(2, 1)
        self.grid.grid.ci.layout.setRowStretchFactor(0, 5)
        self.grid.grid.ci.layout.setRowStretchFactor(1, 10)
        # plots[0] is the probability profile (row 1), plots[1] the scatter plot (row 0)
        self.plots = (self.grid[1, 0], self.grid[0, 0])
        self.plots[0].grid = self.grid
        self.plots[0].addLegend()
        self.grid.show()
        self.plots[0].setLabels(bottom=('distance', 'm'), left='connection probability')
        self.params = Parameter.create(name='Distance binning window', type='float', value=40.e-6, step=10.e-6,
                                       suffix='m', siPrefix=True)
        self.element_plot = None
        self.elements = []
        self.element_colors = []
        self.results = None
        self.color = None
        self.name = None
        self.dist_plot = None
        self.params.sigTreeStateChanged.connect(self.update_plot)

    def plot_distance(self, results, color, name, size=10):
        """Plot connection probability vs. distance for ``results``.

        results needs to be a DataFrame or Series object with 'connected' and 'distance' as
        columns. The first results plotted (with their color and name) are remembered so that
        update_plot() can redraw them when the binning window changes.

        Returns whatever distance_plot returns (also stored as self.dist_plot).
        """
        if self.results is None:
            self.name = name
            self.color = color
            self.results = results
        connected = results['connected']
        distance = results['distance']
        dist_win = self.params.value()
        self.dist_plot = distance_plot(connected, distance, plots=self.plots, color=color, name=name, size=size,
                                       window=dist_win, spacing=dist_win)
        self.plots[0].setXRange(0, 200e-6)
        # Returning the plots lets element_distance() record that elements exist, so update_plot()
        # redraws them, and keeps element_distance_reset() from overwriting dist_plot with None.
        return self.dist_plot

    def invalidate_output(self):
        """Clear all plots in the grid."""
        self.grid.clear()

    def element_distance(self, element, color, add_to_list=True):
        """Plot one matrix element's results; optionally remember it for redraws on window change."""
        if add_to_list is True:
            self.element_colors.append(color)
            self.elements.append(element)
        pre = element['pre_class'][0].name
        post = element['post_class'][0].name
        name = ('%s->%s' % (pre, post))
        self.element_plot = self.plot_distance(element, color=color, name=name, size=15)

    def element_distance_reset(self, results, color, name):
        """Forget all elements, clear the grid and plot ``results`` again."""
        self.elements = []
        self.element_colors = []
        self.grid.clear()
        self.dist_plot = self.plot_distance(results, color=color, name=name, size=10)

    def update_plot(self, *args):
        """Redraw everything using the current binning window.

        Connected to params.sigTreeStateChanged; any signal arguments are accepted and ignored.
        """
        if self.results is None:
            # nothing has been plotted yet, so there is nothing to redraw
            return
        self.invalidate_output()
        self.plot_distance(self.results, self.color, self.name)
        if self.element_plot is not None:
            for element, color in zip(self.elements, self.element_colors):
                self.element_distance(element, color, add_to_list=False)
rms, pen=None, symbol='o', symbolPen=None, symbolBrush=(255, 255, 255, 50)) plt = pg.plot( labels={ 'left': 'number of sweeps (normalized per rig)', 'bottom': ('baseline rms error', 'V') }) plt.addLegend() grid = PlotGrid() grid.set_shape(3, 1) grid.show() for r, c in ((1, 'r'), (2, 'g'), (3, 'b')): mask = rig == r rig_data = rms[mask] y, x = np.histogram(rig_data, bins=np.linspace(0, 0.002, 1000)) plt.plot(x, y / len(rig_data), stepMode=True, connect='finite', pen=c, name="Rig %d" % r) p = grid[r - 1, 0] p.plot(ts[mask], rig_data,
def distance_plot(connected, distance, plots=None, color=(100, 100, 255), window=40e-6, spacing=None, name=None,
                  fill_alpha=30, size=10, max_distance=500e-6):
    """Draw connectivity vs distance profiles with confidence intervals.

    Parameters
    ----------
    connected : boolean array
        Whether a synaptic connection was found for each probe
    distance : array
        Distance between cells for each probe
    plots : list of PlotWidget | PlotItem (optional)
        Two plots used to display distance profile and scatter plot. If the second plot is
        None, the scatter plot is skipped. A new 2 x 1 PlotGrid is created when omitted.
    color : tuple
        (R, G, B) color values for line and confidence interval. Scatter points are drawn with
        alpha=100 and the confidence interval fill with alpha=fill_alpha.
    window : float
        Width of distance window over which proportions are calculated for each point on
        the profile line.
    spacing : float
        Distance spacing between points on the profile line (default window / 4).
    name : str or None
        Legend name for the profile line and scatter points.
    fill_alpha : int
        Alpha (0-255) of the confidence-interval fill.
    size : float
        Symbol size of the scatter points (default 10, pyqtgraph's default symbol size).
    max_distance : float
        Profile points are computed for window centers in [window / 2, max_distance).

    Returns
    -------
    plots : the (profile, scatter) plots that were drawn into.

    Note: using a spacing value that is smaller than the window size may cause an otherwise
    smooth decrease over distance to instead look more like a series of downward steps.
    """
    color = pg.colorTuple(pg.mkColor(color))[:3]
    connected = np.array(connected).astype(float)
    distance = np.array(distance)
    pts = np.vstack([distance, connected]).T

    # scatter points a bit: connected pairs below zero, unconnected pairs above
    # (the jitter is only scaled by 0.2; a zero jitter stays zero)
    conn = pts[:, 1] == 1
    unconn = pts[:, 1] == 0
    if np.any(conn):
        cscat = pg.pseudoScatter(pts[:, 0][conn], spacing=10e-6, bidir=False)
        pts[:, 1][conn] = -2e-5 - cscat * 0.2
    if np.any(unconn):
        uscat = pg.pseudoScatter(pts[:, 0][unconn], spacing=10e-6, bidir=False)
        pts[:, 1][unconn] = uscat * 0.2

    # scatter plot connections probed
    if plots is None:
        grid = PlotGrid()
        grid.set_shape(2, 1)
        grid.grid.ci.layout.setRowStretchFactor(0, 5)
        grid.grid.ci.layout.setRowStretchFactor(1, 10)
        plots = (grid[1, 0], grid[0, 0])
        plots[0].grid = grid
        plots[0].addLegend()
        grid.show()
    plots[0].setLabels(bottom=('distance', 'm'), left='connection probability')

    if plots[1] is not None:
        plots[1].setXLink(plots[0])
        plots[1].hideAxis('bottom')
        plots[1].hideAxis('left')

        scatter_color = color + (100, )
        scatter = plots[1].plot(pts[:, 0], pts[:, 1], pen=None, symbol='o', labels={'bottom': ('distance', 'm')},
                                symbolBrush=scatter_color, symbolPen=None, symbolSize=size, name=name)
        scatter.scatter.opts['compositionMode'] = pg.QtGui.QPainter.CompositionMode_Plus

    # use a sliding window to plot the proportion of connections found along with a 95% confidence interval
    # for connection probability
    if spacing is None:
        spacing = window / 4.0

    xvals = np.arange(window / 2.0, max_distance, spacing)
    upper = []
    lower = []
    prop = []
    ci_xvals = []
    for x in xvals:
        minx = x - window / 2.0
        maxx = x + window / 2.0
        # select points inside this window
        mask = (distance >= minx) & (distance <= maxx)
        pts_in_window = connected[mask]
        # compute stats for window
        n_probed = pts_in_window.shape[0]
        n_conn = pts_in_window.sum()
        if n_probed == 0:
            prop.append(np.nan)
        else:
            prop.append(n_conn / n_probed)  # connected is float, so this is true division
            ci = binomial_ci(n_conn, n_probed)
            lower.append(ci[0])
            upper.append(ci[1])
            ci_xvals.append(x)

    # plot connection probability and confidence intervals
    mid_curve = plots[0].plot(xvals, prop, pen={'color': color, 'width': 3}, antialias=True, name=name)
    # invisible curves that only serve as the bounds of the fill
    upper_curve = plots[0].plot(ci_xvals, upper, pen=(0, 0, 0, 0), antialias=True)
    lower_curve = plots[0].plot(ci_xvals, lower, pen=(0, 0, 0, 0), antialias=True)
    upper_curve.setVisible(False)
    lower_curve.setVisible(False)
    fill_color = color + (fill_alpha, )
    fill = pg.FillBetweenItem(upper_curve, lower_curve, brush=fill_color)
    fill.setZValue(-10)
    plots[0].addItem(fill, ignoreBounds=True)

    return plots
pg.dbg() all_expts = cached_experiments() # results = pulse_average_matrix(all_expts, clamp_mode='ic', min_duration=25e-3, pulse_ids=[0, 8]) # 1. collect a list of all experiments containing connections conn_expts = {} for c in all_expts.connection_summary(): conn_expts.setdefault(c['expt'], []).append(c['cells']) # 2. for each experiment, get and cache the full set of first-pulse averages types = ['sim1', 'tlx3', 'pvalb', 'sst', 'vip'] indplots = PlotGrid() indplots.set_shape(len(types), len(types)) indplots.show() cachefile = "first_pulse_average.pkl" cache = pickle.load(open(cachefile, 'rb')) if os.path.isfile(cachefile) else {} for expt, conns in sorted(conn_expts.items()): if expt.source_id not in cache: print "Load:", expt.source_id try: with expt.data: analyzer = MultiPatchExperimentAnalyzer.get(expt.data) responses = [] for pre_cell, post_cell in conns: pre_id, post_id = pre_cell.cell_id-1, post_cell.cell_id-1 resp = analyzer.get_evoked_responses(pre_id, post_id, clamp_mode='ic',
else: c_type = connection.split('-') connection_types = [((c_type[0], 'unknown'), (c_type[1], 'unknown'))] plt = pg.plot() scale_offset = (-20, -20) scale_anchor = (0.4, 1) holding = [-65, -75] qc_plot = pg.plot() grand_response = {} expt_ids = {} feature_plot = None feature2_plot = PlotGrid() feature2_plot.set_shape(5,1) feature2_plot.show() feature3_plot = PlotGrid() feature3_plot.set_shape(1, 3) feature3_plot.show() amp_plot = pg.plot() synapse_plot = PlotGrid() synapse_plot.set_shape(len(connection_types), 1) synapse_plot.show() for c in range(len(connection_types)): cre_type = (connection_types[c][0][1], connection_types[c][1][1]) target_layer = (connection_types[c][0][0], connection_types[c][1][0]) type = connection_types[c] expt_list = all_expts.select(cre_type=cre_type, target_layer=target_layer, calcium=calcium, age=age) color = color_palette[c] grand_response[type[0]] = {'trace': [], 'amp': [], 'latency': [], 'rise': [], 'dist': [], 'decay':[], 'CV': [], 'amp_measured': []} expt_ids[type[0]] = []
# model.Dynamics = {'Dep':1, 'Fac':0, 'UR':0, 'SMR':0, 'DSR':0} # params = {(('2/3', 'unknown'), ('2/3', 'unknown')): [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], # ((None,'rorb'), (None,'rorb')): [0, 506.7, 0, 0, 0, 0, 0.22, 0, 0, 0, 0, 0], # ((None,'sim1'), (None,'sim1')): [0, 1213.8, 0, 0, 0, 0, 0.17, 0, 0, 0, 0, 0], # ((None,'tlx3'), (None,'tlx3')): [0, 319.4, 0, 0, 0, 0, 0.16, 0, 0, 0, 0, 0]} ind = data[0] rec = data[1] freq = 10 delta = 250 order = human_connections.keys() delay_order = [250, 500, 1000, 2000, 4000] ind_plt = PlotGrid() ind_plt.set_shape(4, 1) ind_plt.show() rec_plt = PlotGrid() rec_plt.set_shape(1, 2) rec_plt.show() ind_plt_scatter = pg.plot() ind_plt_all = PlotGrid() ind_plt_all.set_shape(1, 2) ind_plt_all.show() ind_50 = {} ninth_pulse_250 = {} gain_plot = pg.plot() for t, type in enumerate(order): if type == (('2', 'unknown'), ('2', 'unknown')): continue if type == (('4', 'unknown'), ('4', 'unknown')):
def save_fit_psp_test_set():
    """NOTE THIS CODE DOES NOT WORK BUT IS HERE FOR DOCUMENTATION PURPOSES SO THAT WE CAN TRACE BACK HOW THE TEST
    DATA WAS CREATED IF NEEDED.

    Create a test set of data for testing the fit_psp function. Uses Steph's original first_puls_feature.py code
    to filter out error causing data.

    For every accepted connection this builds the fit inputs (``save_dict['input']``). It then fits the original
    averaged trace and a Trace rebuilt from the serialized inputs, and prints a warning if their nrmse values
    differ. The step that writes ``save_dict`` into ``test_psp_fit`` is not present here.

    Example run statement
    python save_fit_psp_test_set.py --organism mouse --connection ee

    Comment in the code that does the saving at the bottom
    """
    import pyqtgraph as pg
    import numpy as np
    import csv
    import sys
    import argparse
    from multipatch_analysis.experiment_list import cached_experiments
    from manuscript_figures import get_response, get_amplitude, response_filter, feature_anova, write_cache, trace_plot, \
        colors_human, colors_mouse, fail_rate, pulse_qc, feature_kw
    from synapse_comparison import load_cache, summary_plot_pulse
    from neuroanalysis.data import TraceList, Trace
    from neuroanalysis.ui.plot_grid import PlotGrid
    from multipatch_analysis.connection_detection import fit_psp
    from rep_connections import ee_connections, human_connections, no_include, all_connections, ie_connections, \
        ii_connections, ei_connections
    from multipatch_analysis.synaptic_dynamics import DynamicsAnalyzer
    from scipy import stats
    import time
    import pandas as pd
    import json
    import os

    app = pg.mkQApp()
    pg.dbg()
    pg.setConfigOption('background', 'w')
    pg.setConfigOption('foreground', 'k')

    parser = argparse.ArgumentParser(
        description='Enter organism and type of connection you"d like to analyze ex: mouse ee (all mouse excitatory-'
                    'excitatory). Alternatively enter a cre-type connection ex: sim1-sim1')
    parser.add_argument('--organism', dest='organism', help='Select mouse or human')
    parser.add_argument('--connection', dest='connection', help='Specify connections to analyze')
    args = vars(parser.parse_args(sys.argv[1:]))

    all_expts = cached_experiments()
    manifest = {'Type': [], 'Connection': [], 'amp': [], 'latency': [], 'rise': [], 'rise2080': [], 'rise1090': [],
                'rise1080': [], 'decay': [], 'nrmse': [], 'CV': []}
    fit_qc = {'nrmse': 8, 'decay': 499e-3}

    # list(...) so connection_types can be indexed under Python 3 as well (keys() is a view there)
    if args['organism'] == 'mouse':
        color_palette = colors_mouse
        calcium = 'high'
        age = '40-60'
        sweep_threshold = 3
        threshold = 0.03e-3
        connection = args['connection']
        if connection == 'ee':
            connection_types = list(ee_connections.keys())
        elif connection == 'ii':
            connection_types = list(ii_connections.keys())
        elif connection == 'ei':
            connection_types = list(ei_connections.keys())
        elif connection == 'ie':
            # was `==` (a no-op comparison) which left connection_types unset
            connection_types = list(ie_connections.keys())
        elif connection == 'all':
            connection_types = list(all_connections.keys())
        elif len(connection.split('-')) == 2:
            c_type = connection.split('-')
            if c_type[0] == '2/3':
                pre_type = ('2/3', 'unknown')
            else:
                pre_type = (None, c_type[0])
            if c_type[1] == '2/3':
                post_type = ('2/3', 'unknown')
            else:
                # the post-synaptic type comes from the second half of "pre-post"
                post_type = (None, c_type[1])
            connection_types = [(pre_type, post_type)]
        else:
            raise ValueError("Unrecognized --connection %r; expected ee, ii, ei, ie, all or pre-post" % connection)
    elif args['organism'] == 'human':
        color_palette = colors_human
        calcium = None
        age = None
        sweep_threshold = 5
        threshold = None
        connection = args['connection']
        if connection == 'ee':
            connection_types = list(human_connections.keys())
        else:
            c_type = connection.split('-')
            connection_types = [((c_type[0], 'unknown'), (c_type[1], 'unknown'))]
    else:
        raise ValueError("--organism must be 'mouse' or 'human', got %r" % args['organism'])

    plt = pg.plot()
    scale_offset = (-20, -20)
    scale_anchor = (0.4, 1)
    holding = [-65, -75]
    qc_plot = pg.plot()
    grand_response = {}
    expt_ids = {}
    feature_plot = None
    feature2_plot = PlotGrid()
    feature2_plot.set_shape(5, 1)
    feature2_plot.show()
    feature3_plot = PlotGrid()
    feature3_plot.set_shape(1, 3)
    feature3_plot.show()
    amp_plot = pg.plot()
    synapse_plot = PlotGrid()
    synapse_plot.set_shape(len(connection_types), 1)
    synapse_plot.show()

    for c in range(len(connection_types)):
        cre_type = (connection_types[c][0][1], connection_types[c][1][1])
        target_layer = (connection_types[c][0][0], connection_types[c][1][0])
        conn_type = connection_types[c]
        expt_list = all_expts.select(cre_type=cre_type, target_layer=target_layer, calcium=calcium, age=age)
        color = color_palette[c]
        grand_response[conn_type[0]] = {'trace': [], 'amp': [], 'latency': [], 'rise': [], 'dist': [], 'decay': [],
                                        'CV': [], 'amp_measured': []}
        expt_ids[conn_type[0]] = []
        synapse_plot[c, 0].addLegend()
        for expt in expt_list:
            for pre, post in expt.connections:
                if [expt.uid, pre, post] in no_include:
                    continue
                cre_check = expt.cells[pre].cre_type == cre_type[0] and expt.cells[post].cre_type == cre_type[1]
                layer_check = (expt.cells[pre].target_layer == target_layer[0]
                               and expt.cells[post].target_layer == target_layer[1])
                if not (cre_check is True and layer_check is True):
                    continue
                pulse_response, artifact = get_response(expt, pre, post, analysis_type='pulse')
                if threshold is not None and artifact > threshold:
                    continue
                response_subset, hold = response_filter(pulse_response, freq_range=[0, 50],
                                                        holding_range=holding, pulse=True)
                if len(response_subset) < sweep_threshold:
                    continue
                qc_plot.clear()
                qc_list = pulse_qc(response_subset, baseline=1.5, pulse=None, plot=qc_plot)
                if len(qc_list) < sweep_threshold:
                    continue
                avg_trace, avg_amp, amp_sign, peak_t = get_amplitude(qc_list)
                # if amp_sign is '-':
                #     continue
                # #print ('%s, %0.0f' %((expt.uid, pre, post), hold, ))
                # all_amps = fail_rate(response_subset, '+', peak_t)
                # cv = np.std(all_amps)/np.mean(all_amps)

                # weight parts of the trace during fitting
                dt = avg_trace.dt
                weight = np.ones(len(avg_trace.data)) * 10.  # set everything to ten initially
                weight[int(10e-3 / dt):int(12e-3 / dt)] = 0.  # area around stim artifact
                weight[int(12e-3 / dt):int(19e-3 / dt)] = 30.  # area around steep PSP rise

                # check if the test data dir is there and if not create it
                test_data_dir = 'test_psp_fit'
                if not os.path.isdir(test_data_dir):
                    os.mkdir(test_data_dir)

                save_dict = {}
                save_dict['input'] = {'data': avg_trace.data.tolist(),
                                      'dtype': str(avg_trace.data.dtype),
                                      'dt': float(avg_trace.dt),
                                      'amp_sign': amp_sign,
                                      'yoffset': 0,
                                      'xoffset': 14e-3,
                                      'avg_amp': float(avg_amp),
                                      'method': 'leastsq',
                                      'stacked': False,
                                      'rise_time_mult_factor': 10.,
                                      'weight': weight.tolist()}
                fit_inputs = save_dict['input']

                # need to remake trace because different output is created
                avg_trace_simple = Trace(data=np.array(fit_inputs['data']), dt=fit_inputs['dt'])  # create Trace object

                psp_fits_original = fit_psp(avg_trace,
                                            sign=fit_inputs['amp_sign'],
                                            yoffset=fit_inputs['yoffset'],
                                            xoffset=fit_inputs['xoffset'],
                                            amp=fit_inputs['avg_amp'],
                                            method=fit_inputs['method'],
                                            stacked=fit_inputs['stacked'],
                                            rise_time_mult_factor=fit_inputs['rise_time_mult_factor'],
                                            fit_kws={'weights': fit_inputs['weight']})
                psp_fits_simple = fit_psp(avg_trace_simple,
                                          sign=fit_inputs['amp_sign'],
                                          yoffset=fit_inputs['yoffset'],
                                          xoffset=fit_inputs['xoffset'],
                                          amp=fit_inputs['avg_amp'],
                                          method=fit_inputs['method'],
                                          stacked=fit_inputs['stacked'],
                                          rise_time_mult_factor=fit_inputs['rise_time_mult_factor'],
                                          fit_kws={'weights': fit_inputs['weight']})

                # single-argument print() gives identical output under Python 2 and 3
                print('%s %s %s' % (expt.uid, pre, post))
                if psp_fits_original.nrmse() != psp_fits_simple.nrmse():
                    print(' the nrmse values dont match')
                    print('\toriginal %s' % psp_fits_original.nrmse())
                    print('\tsimple %s' % psp_fits_simple.nrmse())
class DistancePlot(object):
    """Connection probability vs. distance, with an adjustable binning window.

    Two plots on a 2-row PlotGrid are handed to ``distance_plot``:
    ``self.plots[0]`` is the lower (larger) row, which carries the legend and the
    "connection probability" axis, and ``self.plots[1]`` is the upper row. The
    first results plotted become the baseline. Changing the
    "Distance binning window" parameter redraws that baseline and every element
    overlay added with ``element_distance``.
    """

    def __init__(self):
        self.grid = PlotGrid()
        self.grid.set_shape(2, 1)
        self.grid.grid.ci.layout.setRowStretchFactor(0, 5)
        self.grid.grid.ci.layout.setRowStretchFactor(1, 10)
        self.plots = (self.grid[1, 0], self.grid[0, 0])
        self.plots[0].grid = self.grid
        self.plots[0].addLegend()
        self.grid.show()
        self.plots[0].setLabels(bottom=('distance', 'm'), left='connection probability')
        self.plots[0].setXRange(0, 200e-6)
        self.params = Parameter.create(name='Distance binning window', type='float', value=40.e-6,
                                       step=10.e-6, suffix='m', siPrefix=True)
        self.element_plot = None
        self.elements = []
        self.element_colors = []
        # baseline results and their styling; set by the first plot_distance call
        self.results = None
        self.color = None
        self.name = None
        self.dist_plot = None
        self.params.sigTreeStateChanged.connect(self.update_plot)

    def plot_distance(self, results, color, name, size=10, suppress_scatter=False):
        """Plot connection probability vs. distance for *results*.

        Results needs to be a DataFrame (or Series-like object) with 'Connected'
        and 'Distance' columns. Rows with a null 'Distance' are ignored. If no
        baseline has been recorded yet, *results*, *color* and *name* become the
        baseline that ``update_plot`` redraws. With *suppress_scatter* True, the
        scatter plot (``self.plots[1]``) is not passed to ``distance_plot``,
        because plotting every point is slow.

        Returns the value of ``distance_plot``, also stored as ``self.dist_plot``.
        """
        has_distance = ~results['Distance'].isnull()
        connected = results[has_distance]['Connected']
        distance = results[has_distance]['Distance']
        dist_win = self.params.value()
        if self.results is None:
            self.name = name
            self.color = color
            self.results = results
        if suppress_scatter is True:
            # suppress scatter plot for all results (takes forever to plot)
            plots = list(self.plots)
            plots[1] = None
        else:
            plots = self.plots
        self.dist_plot = distance_plot(connected, distance, plots=plots, color=color, name=name,
                                       size=size, window=dist_win, spacing=dist_win)
        return self.dist_plot

    def invalidate_output(self):
        """Clear everything drawn on the grid."""
        self.grid.clear()

    def element_distance(self, element, color, add_to_list=True):
        """Overlay the distance profile of one matrix *element*.

        *element* must have 'pre_class' and 'post_class' columns whose first
        entries have a ``.name``. These form the legend label "pre->post". When
        *add_to_list* is True, the element and its color are remembered so
        ``update_plot`` can redraw them.
        """
        if add_to_list is True:
            self.element_colors.append(color)
            self.elements.append(element)
        pre = element['pre_class'][0].name
        post = element['post_class'][0].name
        name = ('%s->%s' % (pre, post))
        self.element_plot = self.plot_distance(element, color=color, name=name, size=15)

    def element_distance_reset(self, results, color, name, suppress_scatter=False):
        """Drop all element overlays and redraw *results* as the new baseline."""
        self.elements = []
        self.element_colors = []
        self.element_plot = None
        self.grid.clear()
        # Forget the previous baseline so plot_distance adopts *results*; otherwise
        # update_plot would keep redrawing whichever results were plotted first.
        self.results = None
        self.dist_plot = self.plot_distance(results, color=color, name=name, size=10,
                                            suppress_scatter=suppress_scatter)

    def update_plot(self, *args):
        """Redraw the baseline and element overlays with the current binning window.

        Positional arguments (such as those emitted by ``sigTreeStateChanged``)
        are accepted and ignored.
        """
        self.invalidate_output()
        if self.results is None:
            # nothing has been plotted yet, so there is nothing to redraw
            return
        self.plot_distance(self.results, self.color, self.name, suppress_scatter=True)
        if self.element_plot is not None:
            for element, color in zip(self.elements, self.element_colors):
                self.element_distance(element, color, add_to_list=False)
# Script fragment: plot each response (top row), replace it in ``responses`` with a
# baseline-subtracted, Bessel-filtered copy, and plot an exponentially deconvolved,
# baseline-subtracted, filtered version (bottom row), collecting those in ``deconv``.
# NOTE(review): grid1, n_responses, responses and deconv are defined earlier in
# the script (not visible here).
grid1.set_shape(2, 1)
for i in range(n_responses):
    r = responses.responses[i]
    grid1[0, 0].plot(r.time_values, r.data)

    # subtract the median over time_slice(0, 10e-3) as baseline, then filter
    filt = bessel_filter(r - np.median(r.time_slice(0, 10e-3).data), 300.)
    responses.responses[i] = filt

    # NOTE(review): deconvolution is applied to the unfiltered ``r``, not ``filt``
    # -- confirm this is intended.
    dec = exp_deconvolve(r, 15e-3)
    baseline = np.median(dec.data[:100])  # median of the first 100 samples
    r2 = bessel_filter(dec-baseline, 300.)
    grid1[1, 0].plot(r2.time_values, r2.data)

    deconv.append(r2)

grid1.show()


def measure_amp(trace, baseline=(6e-3, 8e-3), response=(13e-3, 17e-3)):
    """Return the max of *trace* in the *response* window minus its mean in the *baseline* window.

    Both windows are (start, stop) pairs unpacked into ``trace.time_slice``.
    Only the maximum is taken, so this measures positive-going peaks.
    """
    baseline = trace.time_slice(*baseline).data.mean()
    peak = trace.time_slice(*response).data.max()
    return peak - baseline

# Chop up responses into groups of varying size and plot the average
# amplitude as measured from these chunks. We expect to see variance
# decrease as a function of the number of responses being averaged together.
x = []
y_deconv1 = []
y_deconv2 = []
y_deconv3 = []
class TSeriesPlot(pg.GraphicsLayoutWidget):
    """Presynaptic spikes and postsynaptic responses at two holding potentials.

    The plots live on a separate 4-row PlotGrid (``self.grid``), top to bottom:
    spikes for the first holding, traces for the first holding, spikes for the
    second holding, traces for the second holding. All x-axes are linked to the
    bottom plot. Every plotted item except the fit curves is tracked in
    ``self.items`` so ``clear_plots`` can remove them.
    """

    def __init__(self, title, units):
        """Build the grid; *title* goes on the top plot, *units* on the trace plots' left axes.

        NOTE(review): the trace-plot labels iterate a module-level ``holdings``
        that is not a parameter -- confirm it is defined where this class is used.
        """
        pg.GraphicsLayoutWidget.__init__(self)
        self.grid = PlotGrid()
        self.grid.set_shape(4, 1)
        self.grid.grid.ci.layout.setRowStretchFactor(0, 3)
        self.grid.grid.ci.layout.setRowStretchFactor(1, 8)
        self.grid.grid.ci.layout.setRowStretchFactor(2, 5)
        self.grid.grid.ci.layout.setRowStretchFactor(3, 10)
        self.grid.show()
        self.trace_plots = (self.grid[1, 0], self.grid[3, 0])
        self.spike_plots = (self.grid[0, 0], self.grid[2, 0])
        self.plots = self.spike_plots + self.trace_plots
        # only the bottom plot (self.plots[-1]) keeps a visible time axis
        for plot in self.plots[:-1]:
            plot.hideAxis('bottom')
        self.plots[-1].setLabel('bottom', text='Time from spike', units='s')
        self.fit_item_55 = None
        self.fit_item_70 = None
        self.fit_color = {True: 'g', False: 'r'}
        self.qc_color = {'qc_pass': (255, 255, 255, 100), 'qc_fail': (255, 0, 0, 100)}

        self.plots[0].setTitle(title)
        for (plot, holding) in zip(self.trace_plots, holdings):
            plot.setXLink(self.plots[-1])
            plot.setLabel('left', text="%d holding" % int(holding), units=units)
        for plot in self.spike_plots:
            plot.setXLink(self.plots[-1])
            plot.setLabel('left', text="presynaptic spike")
            plot.addLine(x=0)

        self.plots[-1].setXRange(-5e-3, 10e-3)

        self.items = []

    def plot_responses(self, pulse_responses):
        """Plot postsynaptic traces and presynaptic spikes for *pulse_responses*.

        *pulse_responses* maps holding -> {qc key ('qc_pass'/'qc_fail') -> list of
        pulse responses}; the i-th holding key is drawn in the i-th row pair.
        """
        self.plot_traces(pulse_responses)
        self.plot_spikes(pulse_responses)

    def plot_traces(self, pulse_responses):
        """Plot spike-aligned, baseline-subtracted postsynaptic traces.

        QC-failed traces are drawn behind the others (z = -10); the mean of the
        QC-passed traces is drawn in blue.
        """
        for i, holding in enumerate(pulse_responses.keys()):
            for qc, prs in pulse_responses[holding].items():
                if len(prs) == 0:
                    continue
                prl = PulseResponseList(prs)
                post_ts = prl.post_tseries(align='spike', bsub=True)

                for trace in post_ts:
                    item = self.trace_plots[i].plot(trace.time_values, trace.data, pen=self.qc_color[qc])
                    if qc == 'qc_fail':
                        item.setZValue(-10)
                    self.items.append(item)
                if qc == 'qc_pass':
                    grand_trace = post_ts.mean()
                    item = self.trace_plots[i].plot(grand_trace.time_values, grand_trace.data,
                                                    pen={'color': 'b', 'width': 2})
                    self.items.append(item)
            self.trace_plots[i].autoRange()
            self.trace_plots[i].setXRange(-5e-3, 10e-3)
            # y_range = [grand_trace.data.min(), grand_trace.data.max()]
            # self.plots[i].setYRange(y_range[0], y_range[1], padding=1)

    def plot_spikes(self, pulse_responses):
        """Plot spike-aligned presynaptic traces.

        A trace counts as 'qc_pass' only when its stimulus pulse evoked exactly
        one spike; the others are drawn behind (z = -10) in the failure color.
        """
        for i, holding in enumerate(pulse_responses.keys()):
            for prs in pulse_responses[holding].values():
                if len(prs) == 0:
                    continue
                prl = PulseResponseList(prs)
                pre_ts = prl.pre_tseries(align='spike', bsub=True)
                for pr, spike in zip(prl, pre_ts):
                    qc = 'qc_pass' if pr.stim_pulse.n_spikes == 1 else 'qc_fail'
                    item = self.spike_plots[i].plot(spike.time_values, spike.data, pen=self.qc_color[qc])
                    if qc == 'qc_fail':
                        item.setZValue(-10)
                    self.items.append(item)

    def plot_fit(self, trace, holding, fit_pass=False):
        """Draw (or replace) the fit curve for holding -55 or -70.

        -55 goes on the first trace plot and -70 on the second; any other
        *holding* is ignored. The curve is green when *fit_pass* is True,
        red otherwise.
        """
        if holding == -55:
            if self.fit_item_55 is not None:
                self.trace_plots[0].removeItem(self.fit_item_55)
            self.fit_item_55 = pg.PlotDataItem(trace.time_values, trace.data, name='-55 holding',
                                               pen={'color': self.fit_color[fit_pass], 'width': 3})
            self.trace_plots[0].addItem(self.fit_item_55)
        elif holding == -70:
            if self.fit_item_70 is not None:
                self.trace_plots[1].removeItem(self.fit_item_70)
            self.fit_item_70 = pg.PlotDataItem(trace.time_values, trace.data, name='-70 holding',
                                               pen={'color': self.fit_color[fit_pass], 'width': 3})
            self.trace_plots[1].addItem(self.fit_item_70)

    def color_fit(self, name, value):
        """Recolor the fit curve(s) whose holding ('-55'/'-70') appears in *name*; *value* picks green/red."""
        if '-55' in name:
            if self.fit_item_55 is not None:
                self.fit_item_55.setPen({'color': self.fit_color[value], 'width': 3})
        if '-70' in name:
            if self.fit_item_70 is not None:
                self.fit_item_70.setPen({'color': self.fit_color[value], 'width': 3})

    def clear_plots(self):
        """Remove all tracked items and fit curves, then reset the view range."""
        for item in self.items + [self.fit_item_55, self.fit_item_70]:
            if item is None:
                continue
            scene = item.scene()
            if scene is None:
                # already detached from its scene (e.g. its plot was cleared elsewhere)
                continue
            scene.removeItem(item)
        self.items = []

        self.plots[-1].autoRange()
        self.plots[-1].setXRange(-5e-3, 10e-3)
        self.fit_item_70 = None
        self.fit_item_55 = None
class PairView(QtGui.QWidget):
    """For analyzing pre/post-synaptic pairs.

    Shows the raw presynaptic and filtered postsynaptic traces of the selected
    sweeps, a PSP-fitted average response for every stimulus pulse plus a
    global average, a table of fit parameters, and a list of saved event sets
    that can be jointly fitted with a release model ("fit all").
    """

    def __init__(self, parent=None):
        self.sweeps = []
        self.channels = []
        # (pre, post, events, sweeps) from the most recent _update_plots, or None
        self.current_event_set = None
        self.event_sets = []

        QtGui.QWidget.__init__(self, parent)

        self.layout = QtGui.QGridLayout()
        self.setLayout(self.layout)
        self.layout.setContentsMargins(0, 0, 0, 0)

        self.vsplit = QtGui.QSplitter(QtCore.Qt.Vertical)
        self.layout.addWidget(self.vsplit, 0, 0)

        self.pre_plot = pg.PlotWidget()
        self.post_plot = pg.PlotWidget()
        for plt in (self.pre_plot, self.post_plot):
            plt.setClipToView(True)
            plt.setDownsampling(True, True, 'peak')

        self.post_plot.setXLink(self.pre_plot)
        self.vsplit.addWidget(self.pre_plot)
        self.vsplit.addWidget(self.post_plot)

        self.response_plots = PlotGrid()
        self.vsplit.addWidget(self.response_plots)

        self.event_splitter = QtGui.QSplitter(QtCore.Qt.Horizontal)
        self.vsplit.addWidget(self.event_splitter)

        self.event_table = pg.TableWidget()
        self.event_splitter.addWidget(self.event_table)

        self.event_widget = QtGui.QWidget()
        self.event_widget_layout = QtGui.QGridLayout()
        self.event_widget.setLayout(self.event_widget_layout)
        self.event_splitter.addWidget(self.event_widget)

        self.event_splitter.setSizes([600, 100])

        # row 0 is always the "current" (unsaved) event set
        self.event_set_list = QtGui.QListWidget()
        self.event_widget_layout.addWidget(self.event_set_list, 0, 0, 1, 2)
        self.event_set_list.addItem("current")
        self.event_set_list.itemSelectionChanged.connect(self.event_set_selected)

        self.add_set_btn = QtGui.QPushButton('add')
        self.event_widget_layout.addWidget(self.add_set_btn, 1, 0, 1, 1)
        self.add_set_btn.clicked.connect(self.add_set_clicked)

        self.remove_set_btn = QtGui.QPushButton('del')
        self.event_widget_layout.addWidget(self.remove_set_btn, 1, 1, 1, 1)
        self.remove_set_btn.clicked.connect(self.remove_set_clicked)

        self.fit_btn = QtGui.QPushButton('fit all')
        self.event_widget_layout.addWidget(self.fit_btn, 2, 0, 1, 2)
        self.fit_btn.clicked.connect(self.fit_clicked)

        self.fit_plot = PlotGrid()

        self.fit_explorer = None

        self.artifact_remover = ArtifactRemover(user_width=True)
        self.baseline_remover = BaselineRemover()
        self.filter = SignalFilter()

        self.params = pg.parametertree.Parameter(name='params', type='group', children=[
            {'name': 'pre', 'type': 'list', 'values': []},
            {'name': 'post', 'type': 'list', 'values': []},
            self.artifact_remover.params,
            self.baseline_remover.params,
            self.filter.params,
            {'name': 'time constant', 'type': 'float', 'suffix': 's', 'siPrefix': True,
             'value': 10e-3, 'dec': True, 'minStep': 100e-6},
        ])
        self.params.sigTreeStateChanged.connect(self._update_plots)

    def data_selected(self, sweeps, channels):
        """Set the sweeps and available channels to analyze, then redraw."""
        self.sweeps = sweeps
        self.channels = channels

        self.params.child('pre').setLimits(channels)
        self.params.child('post').setLimits(channels)

        self._update_plots()

    def _update_plots(self, *args):
        """Redraw traces, per-pulse averages and fits, and rebuild the event table.

        Positional arguments (such as those emitted by ``sigTreeStateChanged``)
        are accepted and ignored.
        """
        sweeps = self.sweeps
        self.current_event_set = None
        self.event_table.clear()

        # clear all plots
        self.pre_plot.clear()
        self.post_plot.clear()

        pre = self.params['pre']
        post = self.params['post']

        # If there are no selected sweeps or channels have not been set, return
        if len(sweeps) == 0 or pre == post or pre not in self.channels or post not in self.channels:
            return

        pre_mode = sweeps[0][pre].clamp_mode
        post_mode = sweeps[0][post].clamp_mode
        for ch, mode, plot in [(pre, pre_mode, self.pre_plot), (post, post_mode, self.post_plot)]:
            units = 'A' if mode == 'vc' else 'V'
            plot.setLabels(left=("Channel %d" % ch, units), bottom=("Time", 's'))

        # Iterate over selected channels of all sweeps, plotting traces one at a time
        # Collect information about pulses and spikes
        pulses = []
        spikes = []
        post_traces = []
        for i, sweep in enumerate(sweeps):
            pre_trace = sweep[pre]['primary']
            post_trace = sweep[post]['primary']

            # Detect pulse times
            stim = sweep[pre]['command'].data
            sdiff = np.diff(stim)
            on_times = np.argwhere(sdiff > 0)[1:, 0]  # 1: skips test pulse
            off_times = np.argwhere(sdiff < 0)[1:, 0]
            pulses.append(on_times)

            # filter data
            post_filt = self.artifact_remover.process(post_trace, list(on_times) + list(off_times))
            post_filt = self.baseline_remover.process(post_filt)
            post_filt = self.filter.process(post_filt)
            post_traces.append(post_filt)

            # plot raw data
            color = pg.intColor(i, hues=len(sweeps) * 1.3, sat=128)
            color.setAlpha(128)
            for trace, plot in [(pre_trace, self.pre_plot), (post_filt, self.post_plot)]:
                plot.plot(trace.time_values, trace.data, pen=color, antialias=False)

            # detect spike times; spike_info keeps None for pulses without a spike
            spike_inds = []
            spike_info = []
            for on, off in zip(on_times, off_times):
                spike = detect_evoked_spike(sweep[pre], [on, off])
                spike_info.append(spike)
                if spike is None:
                    spike_inds.append(None)
                else:
                    spike_inds.append(spike['rise_index'])
            spikes.append(spike_info)

            dt = pre_trace.dt
            vticks = pg.VTickGroup([x * dt for x in spike_inds if x is not None], yrange=[0.0, 0.2], pen=color)
            self.pre_plot.addItem(vticks)

        # Iterate over spikes, plotting average response
        avg_responses = []
        fits = []  # (pulse index, fit) pairs, so table rows keep the true pulse index

        npulses = max(map(len, pulses))
        self.response_plots.clear()
        self.response_plots.set_shape(1, npulses + 1)  # 1 extra for global average
        self.response_plots.setYLink(self.response_plots[0, 0])
        for i in range(1, npulses + 1):
            self.response_plots[0, i].hideAxis('left')
        units = 'A' if post_mode == 'vc' else 'V'
        self.response_plots[0, 0].setLabels(left=("Averaged events (Channel %d)" % post, units))

        fit_pen = {'color': (30, 30, 255), 'width': 2, 'dash': [1, 1]}
        for i in range(npulses):
            # get the chunk of each sweep between spikes
            responses = []
            for j, sweep in enumerate(sweeps):
                # get the current spike
                if i >= len(spikes[j]):
                    continue
                spike = spikes[j][i]
                if spike is None:
                    continue

                # find next spike
                next_spike = None
                for sp in spikes[j][i + 1:]:
                    if sp is not None:
                        next_spike = sp
                        break

                # determine time range for response
                max_len = int(40e-3 / dt)  # don't take more than 40ms for any response
                start = spike['rise_index']
                if next_spike is not None:
                    stop = min(start + max_len, next_spike['rise_index'])
                else:
                    stop = start + max_len

                # collect data from this trace
                trace = post_traces[j]
                d = trace.data[start:stop].copy()
                responses.append(d)

            if len(responses) == 0:
                continue

            # extend all responses to the same length and take nanmean
            avg = ragged_mean(responses, method='clip')
            avg -= float_mode(avg[:int(1e-3 / dt)])
            avg_responses.append(avg)

            # plot average response for this pulse
            t = np.arange(len(avg)) * dt
            self.response_plots[0, i].plot(t, avg, pen='w', antialias=True)

            # fit!
            fit = self.fit_psp(avg, t, dt, post_mode)
            fits.append((i, fit))

            # let the user mess with this fit
            curve = self.response_plots[0, i].plot(t, fit.eval(), pen=fit_pen, antialias=True).curve
            curve.setClickable(True)
            curve.fit = fit
            curve.sigClicked.connect(self.fit_curve_clicked)

        # display global average (only if at least one pulse produced an average)
        global_fit = None
        if len(avg_responses) > 0:
            global_avg = ragged_mean(avg_responses, method='clip')
            t = np.arange(len(global_avg)) * dt
            self.response_plots[0, -1].plot(t, global_avg, pen='w', antialias=True)
            global_fit = self.fit_psp(global_avg, t, dt, post_mode)
            self.response_plots[0, -1].plot(t, global_fit.eval(), pen=fit_pen, antialias=True)

        def fit_row(row_id, spike_time, spike_stdev, f):
            # one table row: identifying columns followed by the fit's best values
            vals = OrderedDict([('id', row_id), ('spike_time', spike_time), ('spike_stdev', spike_stdev)])
            vals.update(OrderedDict([(k, f.best_values[k]) for k in f.params.keys()]))
            return vals

        # display fit parameters in table; the global average (if any) is the last row
        events = []
        for pulse_ind, f in fits:
            spt = [s[pulse_ind]['peak_index'] * dt for s in spikes
                   if pulse_ind < len(s) and s[pulse_ind] is not None]
            events.append(fit_row(pulse_ind, np.mean(spt), np.std(spt), f))
        if global_fit is not None:
            events.append(fit_row('avg', np.nan, np.nan, global_fit))

        self.current_event_set = (pre, post, events, sweeps)
        self.event_set_list.setCurrentRow(0)
        self.event_set_selected()

    def fit_psp(self, data, t, dt, clamp_mode):
        """Fit a PSP model to *data* sampled at times *t*.

        The sign of the initial amplitude comes from comparing the mean of
        *data* with the mode of its first 1 ms. Initial guesses and bounds are
        wider/slower for current clamp ('ic') than for voltage clamp.
        """
        mode = float_mode(data[:int(1e-3 / dt)])
        sign = -1 if data.mean() - mode < 0 else 1
        params = OrderedDict([
            ('xoffset', (2e-3, 3e-4, 5e-3)),
            ('yoffset', data[0]),
            ('amp', sign * 10e-12),
            #('k', (2e-3, 50e-6, 10e-3)),
            ('rise_time', (2e-3, 50e-6, 10e-3)),
            ('decay_tau', (4e-3, 500e-6, 50e-3)),
            ('rise_power', (2.0, 'fixed')),
        ])
        if clamp_mode == 'ic':
            params['amp'] = sign * 10e-3
            #params['k'] = (5e-3, 50e-6, 20e-3)
            params['rise_time'] = (5e-3, 50e-6, 20e-3)
            params['decay_tau'] = (15e-3, 500e-6, 150e-3)

        fit_kws = {'xtol': 1e-3, 'maxfev': 100}

        psp = fitting.Psp()
        return psp.fit(data, x=t, fit_kws=fit_kws, params=params)

    def add_set_clicked(self):
        """Save the current event set as a new "pre -> post" list entry."""
        if self.current_event_set is None:
            return
        ces = self.current_event_set
        self.event_sets.append(ces)
        item = QtGui.QListWidgetItem("%d -> %d" % (ces[0], ces[1]))
        self.event_set_list.addItem(item)
        item.event_set = ces

    def remove_set_clicked(self):
        """Delete the selected saved event set; the "current" row cannot be removed."""
        row = self.event_set_list.currentRow()
        if row <= 0:
            # row 0 is the permanent "current" entry; -1 means nothing is selected
            return
        sel = self.event_set_list.takeItem(row)
        # remove by identity so an equal-looking set saved twice is not confused
        for idx, evset in enumerate(self.event_sets):
            if evset is sel.event_set:
                del self.event_sets[idx]
                break

    def event_set_selected(self):
        """Show the events of the selected list entry in the event table."""
        selected = self.event_set_list.selectedItems()
        if len(selected) == 0:
            return
        sel = selected[0]
        if sel.text() == "current":
            if self.current_event_set is None:
                # no analysis results yet (or the last update returned early)
                self.event_table.clear()
                return
            self.event_table.setData(self.current_event_set[2])
        else:
            self.event_table.setData(sel.event_set[2])

    def fit_clicked(self):
        """Fit the release model to all saved event sets and plot each variant.

        Dynamics terms are enabled cumulatively in the order
        Dep, Fac, UR, SMR, DSR, and one fit is run after each is enabled.
        """
        from synaptic_release import ReleaseModel

        self.fit_plot.clear()
        self.fit_plot.show()

        n_sets = len(self.event_sets)
        if n_sets == 0:
            return
        self.fit_plot.set_shape(n_sets, 1)
        l = self.fit_plot[0, 0].legend
        if l is not None:
            scene = l.scene()
            if scene is not None:
                scene.removeItem(l)
            self.fit_plot[0, 0].legend = None
        self.fit_plot[0, 0].addLegend()

        for i in range(n_sets):
            self.fit_plot[i, 0].setXLink(self.fit_plot[0, 0])

        spike_sets = []
        for i, evset in enumerate(self.event_sets):
            evset = evset[2][:-1]  # select events, skip last row (average)
            x = np.array([ev['spike_time'] for ev in evset])
            y = np.array([ev['amp'] for ev in evset])
            # times relative to the first event, in ms; amplitudes relative to the first event
            x -= x[0]
            x *= 1000
            y /= y[0]
            spike_sets.append((x, y))

            self.fit_plot[i, 0].plot(x / 1000., y, pen=None, symbol='o')

        model = ReleaseModel()
        dynamics_types = ['Dep', 'Fac', 'UR', 'SMR', 'DSR']
        for k in dynamics_types:
            model.Dynamics[k] = 0

        fit_params = []
        with pg.ProgressDialog("Fitting release model..", 0, len(dynamics_types)) as dlg:
            for k in dynamics_types:
                model.Dynamics[k] = 1
                fit_params.append(model.run_fit(spike_sets))
                dlg += 1
                if dlg.wasCanceled():
                    return

        max_color = len(fit_params) * 1.5

        for i, params in enumerate(fit_params):
            for j, spikes in enumerate(spike_sets):
                x, y = spikes
                output = model.eval(x, params.values())
                y = output[:, 1]
                x = output[:, 0] / 1000.
                self.fit_plot[j, 0].plot(x, y, pen=(i, max_color), name=dynamics_types[i])

    def fit_curve_clicked(self, curve):
        """Open (or retarget) the FitExplorer on the fit attached to *curve*."""
        if self.fit_explorer is None:
            self.fit_explorer = FitExplorer(curve.fit)
        else:
            self.fit_explorer.set_fit(curve.fit)
        self.fit_explorer.show()
# Script fragment: analysis settings and the plot windows used later in the script.
# NOTE(review): connection_types is defined earlier in the script (not visible here).
holding = [-55, -75]
freqs = [10, 20, 50, 100]
rec_t = [250, 500, 1000, 2000, 4000]
sweep_threshold = 3
deconv = True

# cache_file = 'train_response_cache.pkl'
# response_cache = load_cache(cache_file)
# cache_change = []
# log_rec_plt = pg.plot()
# log_rec_plt.setLogMode(x=True)
qc_plot = pg.plot()
# One column per connection type. The 4 and 5 rows match len(freqs) and len(rec_t);
# presumably one row per induction frequency / recovery delay -- confirm.
ind_plot = PlotGrid()
ind_plot.set_shape(4, len(connection_types))
ind_plot.show()
rec_plot = PlotGrid()
rec_plot.set_shape(5, len(connection_types))
rec_plot.show()
# matching grids for the deconvolved data, only when deconv is enabled
if deconv is True:
    deconv_ind_plot = PlotGrid()
    deconv_ind_plot.set_shape(4, len(connection_types))
    deconv_ind_plot.show()
    deconv_rec_plot = PlotGrid()
    deconv_rec_plot.set_shape(5,len(connection_types))
    deconv_rec_plot.show()
summary_plot = PlotGrid()
summary_plot.set_shape(len(connection_types), 2)
summary_plot.show()
symbols = ['o', 's', 'd', '+', 't']
trace_color = (0, 0, 0, 5)  # RGBA tuple with very low alpha
if connection == 'ee': connection_types = human_connections.keys() else: c_type = connection.split('-') connection_types = [((c_type[0], 'unknown'), (c_type[1], 'unknown'))] sweep_threshold = 5 threshold = 0.03e-3 scale_offset = (-20, -20) scale_anchor = (0.4, 1) qc_plot = pg.plot() grand_response = {} feature_plot = None synapse_plot = PlotGrid() synapse_plot.set_shape(len(connection_types), 1) synapse_plot.show() for c in range(len(connection_types)): cre_type = (connection_types[c][0][1], connection_types[c][1][1]) target_layer = (connection_types[c][0][0], connection_types[c][1][0]) type = connection_types[c] expt_list = all_expts.select(cre_type=cre_type, target_layer=target_layer, calcium=calcium, age=age) color = color_palette[c] grand_response[type[0]] = { 'trace': [], 'amp': [], 'latency': [], 'rise': [], 'decay': [],
# Load test data data = np.load('test_data/evoked_spikes/vc_evoked_spikes.npz')['arr_0'] dt = 20e-6 # gaussian filtering constant sigma = 20e-6 / dt # Initialize Qt pg.mkQApp() pg.dbg() # Create a window with a grid of plots (N rows, 1 column) win = PlotGrid() win.set_shape(data.shape[0], 1) win.show() # Loop over all 10 channels for i in range(data.shape[0]): # select the data for this channel trace = data[i, :, 0] stim = data[i, :, 1] # select the plot we will use for this trace plot = win[i, 0] # link all x-axes together plot.setXLink(win[0, 0]) xaxis = plot.getAxis('bottom') if i == data.shape[0] - 1: xaxis.setLabel('Time', 's')
def plot_features(organism=None, conn_type=None, calcium=None, age=None, sweep_thresh=None, fit_thresh=None):
    """Plot averaged first-pulse PSPs for every combination of the given filters.

    Each filter is None (no filtering) or a comma-separated string of values.
    One plot row is produced per combination of the listed values. Formats:

    - organism: compared with ``db.Slice.species``
    - conn_type: ``"<pre_layer>;<pre_cre>-<post_layer>;<post_cre>"``; a layer
      given as the text ``None`` is converted to None before filtering
    - calcium: prefix matched against ``db.Experiment.acsf``
    - age: ``"<lower>-<upper>"`` integer range for ``db.Slice.age``

    sweep_thresh: if given, require ``FirstPulseFeatures.n_sweeps >= sweep_thresh``.
    fit_thresh: accepted for interface compatibility; not used here.

    Returns (response_grid, feature_grid). feature_grid is created and shown
    but nothing is plotted into it here.
    """
    s = db.Session()

    # A list (not a dict) so selections, and hence plot rows, come out in a fixed order.
    filters = [
        ('organism', organism),
        ('conn_type', conn_type),
        ('calcium', calcium),
        ('age', age),
    ]

    # Expand comma-separated values into every combination of filter settings.
    selection = [{}]
    for key, value in filters:
        if value is None:
            continue
        expanded = []
        for v in value.split(','):
            for sel in selection:
                new_sel = sel.copy()
                new_sel[key] = v
                expanded.append(new_sel)
        selection = expanded

    # selection always holds at least one (possibly empty) dict here
    response_grid = PlotGrid()
    response_grid.set_shape(len(selection), 1)
    response_grid.show()
    feature_grid = PlotGrid()
    feature_grid.set_shape(6, 1)
    feature_grid.show()

    for i, select in enumerate(selection):
        pre_cell = db.aliased(db.Cell)
        post_cell = db.aliased(db.Cell)
        # Title/message fields; stay None when no connection type was requested.
        pre_layer = pre_cre = post_layer = post_cre = None
        q_filter = []
        if sweep_thresh is not None:
            q_filter.append(FirstPulseFeatures.n_sweeps >= sweep_thresh)
        species = select.get('organism')
        if species is not None:
            q_filter.append(db.Slice.species == species)
        c_type = select.get('conn_type')
        if c_type is not None:
            pre_type, post_type = c_type.split('-')
            pre_layer, pre_cre = pre_type.split(';')
            if pre_layer == 'None':
                pre_layer = None
            post_layer, post_cre = post_type.split(';')
            if post_layer == 'None':
                post_layer = None
            q_filter.extend([pre_cell.cre_type == pre_cre, pre_cell.target_layer == pre_layer,
                             post_cell.cre_type == post_cre, post_cell.target_layer == post_layer])
        calc_conc = select.get('calcium')
        if calc_conc is not None:
            q_filter.append(db.Experiment.acsf.like(calc_conc + '%'))
        age_range = select.get('age')
        if age_range is not None:
            age_lower, age_upper = age_range.split('-')
            q_filter.append(db.Slice.age.between(int(age_lower), int(age_upper)))

        q = s.query(FirstPulseFeatures).join(db.Pair, FirstPulseFeatures.pair_id == db.Pair.id)\
            .join(pre_cell, db.Pair.pre_cell_id == pre_cell.id)\
            .join(post_cell, db.Pair.post_cell_id == post_cell.id)\
            .join(db.Experiment, db.Experiment.id == db.Pair.expt_id)\
            .join(db.Slice, db.Slice.id == db.Experiment.slice_id)

        for filter_arg in q_filter:
            q = q.filter(filter_arg)

        results = q.all()

        trace_list = []
        for pair in results:
            #TODO set t0 to latency to align to foot of PSP
            trace = Trace(data=pair.avg_psp, sample_rate=db.default_sample_rate)
            trace_list.append(trace)
            response_grid[i, 0].plot(trace.time_values, trace.data)
        if len(trace_list) > 0:
            grand_trace = TraceList(trace_list).mean()
            response_grid[i, 0].plot(grand_trace.time_values, grand_trace.data, pen='b')
            response_grid[i, 0].setTitle('layer %s, %s-> layer %s, %s; n_synapses = %d' %
                                         (pre_layer, pre_cre, post_layer, post_cre, len(trace_list)))
        else:
            print('No synapses for layer %s, %s-> layer %s, %s' % (pre_layer, pre_cre, post_layer, post_cre))

    return response_grid, feature_grid
def plot_features(organism=None, conn_type=None, calcium=None, age=None, sweep_thresh=None, fit_thresh=None):
    """Plot averaged first-pulse PSPs for every combination of the given filters.

    Each filter is None (no filtering) or a comma-separated string of values.
    One plot row is produced per combination of the listed values. Formats:

    - organism: compared with ``db.Slice.species``
    - conn_type: ``"<pre_layer>;<pre_cre>-<post_layer>;<post_cre>"``; a layer
      given as the text ``None`` is converted to None before filtering
    - calcium: prefix matched against ``db.Experiment.acsf``
    - age: ``"<lower>-<upper>"`` integer range for ``db.Slice.age``

    sweep_thresh: if given, require ``FirstPulseFeatures.n_sweeps >= sweep_thresh``.
    fit_thresh: accepted for interface compatibility; not used here.

    Returns (response_grid, feature_grid). feature_grid is created and shown
    but nothing is plotted into it here.
    """
    s = db.session()

    # A list (not a dict) so selections, and hence plot rows, come out in a fixed order.
    filters = [
        ('organism', organism),
        ('conn_type', conn_type),
        ('calcium', calcium),
        ('age', age),
    ]

    # Expand comma-separated values into every combination of filter settings.
    selection = [{}]
    for key, value in filters:
        if value is None:
            continue
        expanded = []
        for v in value.split(','):
            for sel in selection:
                new_sel = sel.copy()
                new_sel[key] = v
                expanded.append(new_sel)
        selection = expanded

    # selection always holds at least one (possibly empty) dict here
    response_grid = PlotGrid()
    response_grid.set_shape(len(selection), 1)
    response_grid.show()
    feature_grid = PlotGrid()
    feature_grid.set_shape(6, 1)
    feature_grid.show()

    for i, select in enumerate(selection):
        pre_cell = aliased(db.Cell)
        post_cell = aliased(db.Cell)
        # Title/message fields; stay None when no connection type was requested.
        pre_layer = pre_cre = post_layer = post_cre = None
        q_filter = []
        if sweep_thresh is not None:
            q_filter.append(FirstPulseFeatures.n_sweeps >= sweep_thresh)
        species = select.get('organism')
        if species is not None:
            q_filter.append(db.Slice.species == species)
        c_type = select.get('conn_type')
        if c_type is not None:
            pre_type, post_type = c_type.split('-')
            pre_layer, pre_cre = pre_type.split(';')
            if pre_layer == 'None':
                pre_layer = None
            post_layer, post_cre = post_type.split(';')
            if post_layer == 'None':
                post_layer = None
            q_filter.extend([
                pre_cell.cre_type == pre_cre,
                pre_cell.target_layer == pre_layer,
                post_cell.cre_type == post_cre,
                post_cell.target_layer == post_layer,
            ])
        calc_conc = select.get('calcium')
        if calc_conc is not None:
            q_filter.append(db.Experiment.acsf.like(calc_conc + '%'))
        age_range = select.get('age')
        if age_range is not None:
            age_lower, age_upper = age_range.split('-')
            q_filter.append(db.Slice.age.between(int(age_lower), int(age_upper)))

        q = s.query(FirstPulseFeatures).join(db.Pair, FirstPulseFeatures.pair_id == db.Pair.id)\
            .join(pre_cell, db.Pair.pre_cell_id == pre_cell.id)\
            .join(post_cell, db.Pair.post_cell_id == post_cell.id)\
            .join(db.Experiment, db.Experiment.id == db.Pair.expt_id)\
            .join(db.Slice, db.Slice.id == db.Experiment.slice_id)

        for filter_arg in q_filter:
            q = q.filter(filter_arg)

        results = q.all()

        trace_list = []
        for pair in results:
            #TODO set t0 to latency to align to foot of PSP
            trace = TSeries(data=pair.avg_psp, sample_rate=db.default_sample_rate)
            trace_list.append(trace)
            response_grid[i, 0].plot(trace.time_values, trace.data)
        if len(trace_list) > 0:
            grand_trace = TSeriesList(trace_list).mean()
            response_grid[i, 0].plot(grand_trace.time_values, grand_trace.data, pen='b')
            response_grid[i, 0].setTitle(
                'layer %s, %s-> layer %s, %s; n_synapses = %d' %
                (pre_layer, pre_cre, post_layer, post_cre, len(trace_list)))
        else:
            print('No synapses for layer %s, %s-> layer %s, %s' % (pre_layer, pre_cre, post_layer, post_cre))

    return response_grid, feature_grid
# Script fragment: command-line plot options and the plot grids used later.
# NOTE(review): parser, cached_experiments and connection_types come from earlier
# in the script (not visible here).

# plot options
parser.add_argument('--sweeps', action='store_true', default=False, dest='sweeps',
                    help='plot individual sweeps behind average')
parser.add_argument('--trains', action='store_true', default=False, dest='trains',
                    help='plot 50Hz train and deconvolution')
parser.add_argument('--link-y-axis', action='store_true', default=False, dest='link-y-axis',
                    help='link all y-axes down a column')

# vars() gives dict access, which the hyphenated 'link-y-axis' dest requires.
args = vars(parser.parse_args(sys.argv[1:]))
plot_sweeps = args['sweeps']
plot_trains = args['trains']
link_y_axis = args['link-y-axis']

# NOTE(review): expt_cache is not used in this fragment (cached_experiments() is
# called without it) -- confirm whether later code needs it.
expt_cache = 'C:/Users/Stephanies/multipatch_analysis/tools/expts_cache.pkl'
all_expts = cached_experiments()

# one row per connection type
test = PlotGrid()
test.set_shape(len(connection_types.keys()), 1)
test.show()

# main grid: one row per connection type; 3 columns with --trains, otherwise 2
grid = PlotGrid()
if plot_trains is True:
    grid.set_shape(len(connection_types.keys()), 3)
    grid[0, 1].setTitle(title='50 Hz Train')
    grid[0, 2].setTitle(title='Exponential Deconvolution')
    # presumably the deconvolution time constant and low-pass cutoff -- confirm
    tau = 15e-3
    lp = 1000
else:
    grid.set_shape(len(connection_types.keys()), 2)
grid.show()
row = 0
grid[0, 0].setTitle(title='First Pulse')

maxYpulse = []
maxYtrain = []