def _get_tserieslist(self, ts_name, align, bsub):
    tsl = []
    for pr in self.prs:
        ts = getattr(pr, ts_name)
        stim_time = pr.stim_pulse.onset_time
        if bsub is True:
            start_time = max(ts.t0, stim_time - 5e-3)
            baseline_data = ts.time_slice(start_time, stim_time).data
            if len(baseline_data) == 0:
                baseline = ts.data[0]
            else:
                baseline = float_mode(baseline_data)
            ts = ts - baseline
        if align is not None:
            if align == 'spike':
                align_t = pr.stim_pulse.first_spike_time
                # ignore PRs with no known spike time
                if align_t is None:
                    continue
            elif align == 'pulse':
                align_t = stim_time
            else:
                raise ValueError("align must be None, 'spike', or 'pulse'.")
            ts = ts.copy(t0=ts.t0 - align_t)
        tsl.append(ts)
    return TSeriesList(tsl)
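# Hedged usage sketch for the method above; `prl` stands in for whatever
# pulse-response-list object exposes it, and 'post_tseries' is an assumed
# attribute name on the pulse responses:
tsl = prl._get_tserieslist('post_tseries', align='spike', bsub=True)
avg = tsl.mean()  # TSeriesList averages the spike-aligned, baseline-subtracted traces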
def first_pulse_features(pair, pulse_responses, pulse_response_amps):
    avg_psp = TSeriesList(pulse_responses).mean()
    dt = avg_psp.dt
    avg_psp_baseline = float_mode(avg_psp.data[:int(10e-3 / dt)])
    avg_psp_bsub = avg_psp.copy(data=avg_psp.data - avg_psp_baseline)
    lower_bound = -float('inf')
    upper_bound = float('inf')
    xoffset = pair.synapse_prediction.ic_fit_xoffset
    if xoffset is None:
        xoffset = 14 * 10e-3
    synapse_type = pair.synapse_prediction.synapse_type
    if synapse_type == 'ex':
        amp_sign = '+'
    elif synapse_type == 'in':
        amp_sign = '-'
    else:
        raise Exception('Synapse type is not defined, reconsider fitting this pair %s %d->%d' %
                        (pair.expt_id, pair.pre_cell_id, pair.post_cell_id))

    weight = np.ones(len(avg_psp.data)) * 10.     # set everything to ten initially
    weight[int(10e-3 / dt):int(12e-3 / dt)] = 0.  # area around stim artifact
    weight[int(12e-3 / dt):int(19e-3 / dt)] = 30. # area around steep PSP rise

    psp_fits = fit_psp(avg_psp,
                       xoffset=(xoffset, lower_bound, upper_bound),
                       yoffset=(avg_psp_baseline, lower_bound, upper_bound),
                       sign=amp_sign,
                       weight=weight)

    amp_cv = np.std(pulse_response_amps) / np.mean(pulse_response_amps)

    features = {'ic_fit_amp': psp_fits.best_values['amp'],
                'ic_fit_latency': psp_fits.best_values['xoffset'] - 10e-3,
                'ic_fit_rise_time': psp_fits.best_values['rise_time'],
                'ic_fit_decay_tau': psp_fits.best_values['decay_tau'],
                'ic_amp_cv': amp_cv,
                'avg_psp': avg_psp_bsub.data}
                # 'ic_fit_NRMSE': psp_fits.nrmse()}  # TODO: nrmse not returned from psp_fits?
    return features
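# Minimal sketch of the fit-weighting scheme used above, assuming the averaged
# trace starts 10 ms before the stimulus; the sample rate here is an arbitrary
# stand-in for db.default_sample_rate:
import numpy as np

dt = 1.0 / 50000.0
n_samples = int(30e-3 / dt)
weight = np.ones(n_samples) * 10.             # default weight everywhere
weight[int(10e-3 / dt):int(12e-3 / dt)] = 0.  # ignore the stimulus artifact
weight[int(12e-3 / dt):int(19e-3 / dt)] = 30. # emphasize the steep PSP rise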
def clicked(sp, pts):
    data = pts[0].data()
    print("-----------------------\nclicked:", data['rise_time'], data['amp'],
          data['prediction'], data['confidence'])
    for r in data['results']:
        print({k: r[k] for k in classifier.features})
    traces = data['traces']
    plt = pg.plot()
    bsub = [t.copy(data=t.data - np.median(t.time_slice(0, 1e-3).data)) for t in traces]
    for t in bsub:
        plt.plot(t.time_values, t.data, pen=(0, 0, 0, 50))
    mean = TSeriesList(bsub).mean()
    plt.plot(mean.time_values, mean.data, pen='g')
def get_average_pulse_response(pair, desired_clamp='ic'):
    """
    Inputs
    ------
    pair : aisynphys.database.database.Pair object

    desired_clamp : string
        Specifies whether current or voltage clamp sweeps are desired.
        Options are:
            'ic': current clamp
            'vc': voltage clamp

    Returns
    -------
    Note that all returned variables are set to None if there are no
    acceptable (qc passing) sweeps.

    pulse_responses : TSeriesList
        traces where the start of each trace is 10 ms before the spike
    pulse_ids : list of ints
        pulse ids of *pulse_responses*
    psp_amps_measured : list of floats
        amplitude of *pulse_responses* from the *pulse_response* table
    freq : list of floats
        the stimulation frequency corresponding to the *pulse_responses*
    avg_psp : TSeries
        average of the pulse_responses
    measured_relative_amp : float
        measured amplitude relative to baseline
    measured_baseline : float
        value of baseline
    """
    # get pulses that pass qc
    pulse_responses, pulse_ids, psp_amps_measured, freq = extract_first_pulse_info_from_Pair_object(
        pair, desired_clamp=desired_clamp)

    # if pulses are returned, take the average
    if len(pulse_responses) > 0:
        avg_psp = TSeriesList(pulse_responses).mean()
    else:
        return None, None, None, None, None, None, None

    # get the measured baseline and amplitude of the psp
    measured_relative_amp, measured_baseline = measure_amp(
        avg_psp.data,
        [0, int((time_before_spike - 1.e-3) / avg_psp.dt)],
        [int((time_before_spike + .5e-3) / avg_psp.dt), -1])

    return pulse_responses, pulse_ids, psp_amps_measured, freq, avg_psp, measured_relative_amp, measured_baseline
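# Hedged usage sketch, assuming `pair` was queried from the database as in the
# other snippets here:
(pulse_responses, pulse_ids, psp_amps_measured, freq,
 avg_psp, measured_relative_amp, measured_baseline) = get_average_pulse_response(pair, desired_clamp='ic')
if avg_psp is not None:
    print("PSP amplitude relative to baseline: %g V" % measured_relative_amp)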
def plot_element_data(self, pre_class, post_class, element, field_name, color='g', trace_plt=None):
    trace_plt = None  # trace plotting disabled in this variant; the parameter is ignored
    val = element[field_name].mean()
    line = pg.InfiniteLine(val, pen={'color': color, 'width': 2}, movable=False)
    scatter = None
    baseline_window = int(db.default_sample_rate * 5e-3)
    values = []
    traces = []
    point_data = []
    for pair, value in element[field_name].iteritems():
        if np.isnan(value):
            continue
        if trace_plt is not None:
            rsf = pair.resting_state_fit  # assumed source for rsf (undefined in the original), matching the other plot_element_data variants below
            if rsf is not None:
                trace = rsf.ic_avg_data
                start_time = rsf.ic_avg_data_start_time
                latency = pair.synapse.latency
                if latency is not None and start_time is not None:
                    xoffset = start_time - latency
                    trace = format_trace(trace, baseline_window, x_offset=xoffset, align='psp')
                    trace_plt.plot(trace.time_values, trace.data)
                    traces.append(trace)
        values.append(value)
        point_data.append(pair)
    y_values = pg.pseudoScatter(np.asarray(values, dtype=float), spacing=1)
    scatter = pg.ScatterPlotItem(symbol='o', brush=(color + (150,)), pen='w', size=12)
    scatter.setData(values, y_values + 10., data=point_data)
    for point in scatter.points():
        pair_id = point.data().id
        self.pair_items[pair_id] = [point, color]
    scatter.sigClicked.connect(self.scatter_plot_clicked)
    if len(traces) > 0:
        grand_trace = TSeriesList(traces).mean()
        trace_plt.plot(grand_trace.time_values, grand_trace.data, pen={'color': color, 'width': 3})
        units = 'V' if field_name.startswith('ic') else 'A'
        trace_plt.setXRange(0, 20e-3)
        trace_plt.setLabels(left=('', units), bottom=('Time from stimulus', 's'))
    return line, scatter
def trace_avg(response_list):
    # doc string commented out to discourage code reuse given the change of values of t0
    # """
    # Parameters
    # ----------
    # response_list : list of neuroanalysis.data.TSeriesView objects
    #     neuroanalysis.data.TSeriesView object contains waveform data.
    #
    # Returns
    # -------
    # bsub_mean : neuroanalysis.data.TSeries object
    #     averages and baseline subtracts the ephys waveform data in the
    #     input response_list TSeriesView objects and replaces the .t0 value with 0.
    # """

    # align traces for use of the TSeriesList().mean() function
    for trace in response_list:
        trace.t0 = 0

    # average of the waveforms, as a neuroanalysis.data.TSeries object
    avg_trace = TSeriesList(response_list).mean()

    # bsub() returns a copy of avg_trace whose .data waveform has been baseline-subtracted
    bsub_mean = bsub(avg_trace)
    return bsub_mean
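# A minimal sketch of the kind of baseline subtraction bsub() performs above
# (illustrative only; the real bsub() and its baseline window live elsewhere in
# this codebase):
import numpy as np

def bsub_sketch(trace):
    baseline = np.median(trace.time_slice(0, 1e-3).data)  # window is an assumption
    return trace.copy(data=trace.data - baseline)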
def get_tseries(self, series, bsub=True, align='stim', bsub_window=(-3e-3, 0)):
    """Return a TSeriesList of timeseries, optionally baseline-subtracted
    and time-aligned.

    Parameters
    ----------
    series : str
        "stim", "pre", or "post"
    """
    assert series in ('stim', 'pre', 'post'), "series must be one of 'stim', 'pre', or 'post'"
    tseries = []
    for i, sr in enumerate(self.srs):
        ts = getattr(sr, series + '_tseries')
        if bsub:
            bstart = sr.stim_pulse.onset_time + bsub_window[0]
            bstop = sr.stim_pulse.onset_time + bsub_window[1]
            baseline = np.median(ts.time_slice(bstart, bstop).data)
            ts = ts - baseline
        if align is not None:
            if align == 'stim':
                t_align = sr.stim_pulse.onset_time
            elif align == 'pre':
                t_align = sr.stim_pulse.spikes[0].max_dvdt_time
            elif align == 'post':
                raise NotImplementedError()
            else:
                raise ValueError("invalid time alignment mode %r" % align)
            t_align = t_align or 0
            ts = ts.copy(t0=ts.t0 - t_align)
        tseries.append(ts)
    return TSeriesList(tseries)
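# Hedged usage sketch, assuming `sr_list` is an object (e.g. a stim/spike
# response list) providing the method above:
post_tsl = sr_list.get_tseries('post', bsub=True, align='stim')
avg_post = post_tsl.mean()  # traces share t=0 at stimulus onset after alignment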
        train_response, freqs, holding, thresh=sweep_threshold,
        ind_dict=grand_induction, offset_dict=offset_ind)
    grand_recovery, offset_rec = recovery_summary(
        train_response, t_rec, holding, thresh=sweep_threshold,
        rec_dict=grand_recovery, offset_dict=offset_rec)
    if len(grand_pulse_response) > 0:
        grand_pulse_trace = TSeriesList(grand_pulse_response).mean()
        p2 = trace_plot(grand_pulse_trace, color=avg_color, plot=p2,
                        x_range=[0, 27e-3],
                        name=('n = %d' % len(grand_pulse_response)))
    if len(grand_induction) > 0:
        for f, freq in enumerate(freqs):
            if freq in grand_induction:
                offset = offset_ind[freq]
                ind_pass_qc = train_qc(grand_induction[freq], offset,
                                       amp=amp_thresh, sign=sign)
                n = len(ind_pass_qc[0])
                if n > 0:
def analyze_pair_connectivity(amps, sign=None):
    """Given response strength records for a single pair, generate summary
    statistics characterizing strength, latency, and connectivity.

    Parameters
    ----------
    amps : dict
        Contains foreground and background strength analysis records
        (see input format below)
    sign : None, -1, or +1
        If None, then automatically determine whether to treat this connection
        as inhibitory or excitatory.

    Input must have the following structure::

        amps = {
            ('ic'): recs,
            ('vc'): recs,
        }

    Where each *recs* must be a structured array containing fields as returned
    by get_amps().

    The overall strategy here is:

    1. Make an initial decision on whether to treat this pair as excitatory or
       inhibitory, based on differences between foreground and background
       amplitude measurements
    2. Generate mean and stdev for amplitudes, deconvolved amplitudes, and
       deconvolved latencies
    3. Generate KS test p values describing the differences between foreground
       and background distributions for amplitude, deconvolved amplitude, and
       deconvolved latency
    """
    # Filter by QC
    for k, v in amps.items():
        mask = v['qc_pass'].astype(bool)
        amps[k] = v[mask]

    # See if any data remains (iterate over the record arrays, not the dict keys)
    if all([len(a) == 0 for a in amps.values()]):
        return None

    requested_sign = sign
    fields = {}  # used to fill the new DB record

    # Use KS p value to check for differences between foreground and background
    qc_amps = {}
    ks_pvals = {}
    amp_means = {}
    amp_diffs = {}
    for clamp_mode in ('ic', 'vc'):
        clamp_mode_amps = amps[clamp_mode]
        if len(clamp_mode_amps) == 0:
            continue
        for sign in ('pos', 'neg'):
            # Separate into positive/negative tests and filter out responses that failed qc
            qc_field = {
                'vc': {'pos': 'in_qc_pass', 'neg': 'ex_qc_pass'},
                'ic': {'pos': 'ex_qc_pass', 'neg': 'in_qc_pass'},
            }[clamp_mode][sign]
            fg = clamp_mode_amps[clamp_mode_amps[qc_field]]
            qc_amps[sign, clamp_mode] = fg
            if len(fg) == 0:
                continue

            # Measure some statistics from these records
            bg = fg['baseline_' + sign + '_dec_amp']
            fg = fg[sign + '_dec_amp']
            pval = scipy.stats.ks_2samp(fg, bg).pvalue
            ks_pvals[(sign, clamp_mode)] = pval
            # we could ensure that the average amplitude is in the right direction:
            fg_mean = np.mean(fg)
            bg_mean = np.mean(bg)
            amp_means[sign, clamp_mode] = {'fg': fg_mean, 'bg': bg_mean}
            amp_diffs[sign, clamp_mode] = fg_mean - bg_mean

    if requested_sign is None:
        # Decide whether to treat this connection as excitatory or inhibitory.
        # strategy: accumulate evidence for either possibility by checking
        # the ks p-values for each sign/clamp mode and the direction of the deflection
        is_exc = 0
        # print(expt.acq_timestamp, pair.pre_cell.ext_id, pair.post_cell.ext_id)
        for sign in ('pos', 'neg'):
            for mode in ('ic', 'vc'):
                ks = ks_pvals.get((sign, mode), None)
                if ks is None:
                    continue
                # turn p value into a reasonable scale factor
                ks = norm_pvalue(ks)
                dif_sign = 1 if amp_diffs[sign, mode] > 0 else -1
                if mode == 'vc':
                    dif_sign *= -1
                is_exc += dif_sign * ks
                # print("   ", sign, mode, is_exc, dif_sign * ks)
    else:
        is_exc = requested_sign

    if is_exc > 0:
        fields['synapse_type'] = 'ex'
        signs = {'ic': 'pos', 'vc': 'neg'}
    else:
        fields['synapse_type'] = 'in'
        signs = {'ic': 'neg', 'vc': 'pos'}

    # compute the rest of statistics for only positive or negative deflections
    for clamp_mode in ('ic', 'vc'):
        sign = signs[clamp_mode]
        fg = qc_amps.get((sign, clamp_mode))
        if fg is None or len(fg) == 0:
            fields[clamp_mode + '_n_samples'] = 0
            continue

        fields[clamp_mode + '_n_samples'] = len(fg)
        fields[clamp_mode + '_crosstalk_mean'] = np.mean(fg['crosstalk'])
        fields[clamp_mode + '_base_crosstalk_mean'] = np.mean(fg['baseline_crosstalk'])

        # measure mean, stdev, and statistical differences between
        # fg and bg for each measurement
        for val, field in [('amp', 'amp'), ('deconv_amp', 'dec_amp'), ('latency', 'dec_latency')]:
            f = fg[sign + '_' + field]
            b = fg['baseline_' + sign + '_' + field]
            fields[clamp_mode + '_' + val + '_mean'] = np.mean(f)
            fields[clamp_mode + '_' + val + '_stdev'] = np.std(f)
            fields[clamp_mode + '_base_' + val + '_mean'] = np.mean(b)
            fields[clamp_mode + '_base_' + val + '_stdev'] = np.std(b)
            # statistical tests comparing fg vs bg
            # Note: we use log(1-log(pval)) because it's nicer to plot and easier to
            # use as a classifier input
            tt_pval = scipy.stats.ttest_ind(f, b, equal_var=False).pvalue
            ks_pval = scipy.stats.ks_2samp(f, b).pvalue
            fields[clamp_mode + '_' + val + '_ttest'] = norm_pvalue(tt_pval)
            fields[clamp_mode + '_' + val + '_ks2samp'] = norm_pvalue(ks_pval)

        ### generate the average response and psp fit

        # collect all bg and fg traces
        fg_traces = TSeriesList()
        for rec in fg:
            if not np.isfinite(rec['max_slope_time']) or rec['max_slope_time'] is None:
                continue
            t0 = rec['response_start_time'] - rec['max_slope_time']  # time-align to presynaptic spike
            trace = TSeries(rec['data'], sample_rate=db.default_sample_rate, t0=t0)
            fg_traces.append(trace)

        # get averages
        if len(fg_traces) == 0:
            continue
        # bg_avg = bg_traces.mean()
        fg_avg = fg_traces.mean()
        base_rgn = fg_avg.time_slice(-6e-3, 0)
        base = float_mode(base_rgn.data)
        fields[clamp_mode + '_average_response'] = fg_avg.data
        fields[clamp_mode + '_average_response_t0'] = fg_avg.t0
        fields[clamp_mode + '_average_base_stdev'] = base_rgn.std()

        sign = {'pos': 1, 'neg': -1}[signs[clamp_mode]]
        fg_bsub = fg_avg.copy(data=fg_avg.data - base)  # remove base to help fitting
        try:
            fit = fit_psp(fg_bsub, clamp_mode=clamp_mode, sign=sign, search_window=[0, 6e-3])
            for param, val in fit.best_values.items():
                fields['%s_fit_%s' % (clamp_mode, param)] = val
            fields[clamp_mode + '_fit_yoffset'] = fit.best_values['yoffset'] + base
            fields[clamp_mode + '_fit_nrmse'] = fit.nrmse()
        except:
            print("Error in PSP fit:")
            sys.excepthook(*sys.exc_info())
            continue

        # global fit_plot
        # if fit_plot is None:
        #     fit_plot = FitExplorer(fit)
        #     fit_plot.show()
        # else:
        #     fit_plot.set_fit(fit)
        # raw_input("Waiting to continue..")

    return fields
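# A minimal sketch of norm_pvalue, inferred from the comment above that p values
# are rescaled as log(1 - log(pval)); illustrative only, since the real
# implementation may clip or normalize differently:
import numpy as np

def norm_pvalue_sketch(pval):
    # maps pval in (0, 1] to [0, inf); smaller p values give larger scale factors
    return np.log(1 - np.log(pval))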
def add_connection_plots(i, name, timestamp, pre_id, post_id):
    global session, win, filtered
    p = pg.debug.Profiler(disabled=True, delayed=False)
    trace_plot = win.addPlot(i, 1)
    trace_plots.append(trace_plot)
    trace_plot.setYRange(-1.4e-3, 2.1e-3)
    # deconv_plot = win.addPlot(i, 2)
    # deconv_plots.append(deconv_plot)
    # deconv_plot.hide()

    hist_plot = win.addPlot(i, 2)
    hist_plots.append(hist_plot)
    limit_plot = win.addPlot(i, 3)
    limit_plot.addLegend()
    limit_plot.setLogMode(True, False)
    limit_plot.addLine(y=classifier.prob_threshold)

    # Find this connection in the pair list
    idx = np.argwhere((abs(filtered['acq_timestamp'] - timestamp) < 1) &
                      (filtered['pre_cell_id'] == pre_id) &
                      (filtered['post_cell_id'] == post_id))
    if idx.size == 0:
        print("not in filtered connections")
        return
    idx = idx[0, 0]
    p()

    # Mark the point in scatter plot
    scatter_plot.plot([background[idx]], [signal[idx]], pen='k', symbol='o',
                      size=10, symbolBrush='r', symbolPen=None)

    # Plot example traces and histograms
    for plts in [trace_plots]:  # , deconv_plots]:
        plt = plts[-1]
        plt.setXLink(plts[0])
        plt.setYLink(plts[0])
        plt.setXRange(-10e-3, 17e-3, padding=0)
        plt.hideAxis('left')
        plt.hideAxis('bottom')
        plt.addLine(x=0)
        plt.setDownsampling(auto=True, mode='peak')
        plt.setClipToView(True)
        hbar = pg.QtGui.QGraphicsLineItem(0, 0, 2e-3, 0)
        hbar.setPen(pg.mkPen(color='k', width=5))
        plt.addItem(hbar)
        vbar = pg.QtGui.QGraphicsLineItem(0, 0, 0, 100e-6)
        vbar.setPen(pg.mkPen(color='k', width=5))
        plt.addItem(vbar)

    hist_plot.setXLink(hist_plots[0])

    pair = session.query(db.Pair).filter(db.Pair.id == filtered[idx]['pair_id']).all()[0]
    p()
    amps = strength_analysis.get_amps(session, pair)
    p()
    base_amps = strength_analysis.get_baseline_amps(session, pair, amps=amps, clamp_mode='ic')
    p()

    q = strength_analysis.response_query(session)
    p()
    q = q.join(strength_analysis.PulseResponseStrength)
    q = q.filter(strength_analysis.PulseResponseStrength.id.in_(amps['id']))
    q = q.join(db.MultiPatchProbe)
    q = q.filter(db.MultiPatchProbe.induction_frequency < 100)
    # pre_cell = db.aliased(db.Cell)
    # post_cell = db.aliased(db.Cell)
    # q = q.join(db.Pair).join(db.Experiment).join(pre_cell, db.Pair.pre_cell_id==pre_cell.id).join(post_cell, db.Pair.post_cell_id==post_cell.id)
    # q = q.filter(db.Experiment.id==filtered[idx]['experiment_id'])
    # q = q.filter(pre_cell.ext_id==pre_id)
    # q = q.filter(post_cell.ext_id==post_id)

    fg_recs = q.all()
    p()

    traces = []
    deconvs = []
    for i, rec in enumerate(fg_recs[:100]):
        result = strength_analysis.analyze_response_strength(
            rec, source='pulse_response', lpf=True, lowpass=2000,
            remove_artifacts=False, bsub=True)
        trace = result['raw_trace']
        trace.t0 = -result['spike_time']
        trace = trace - np.median(trace.time_slice(-0.5e-3, 0.5e-3).data)
        traces.append(trace)
        trace_plot.plot(trace.time_values, trace.data, pen=(0, 0, 0, 20))
        write_csv(csv_file, trace,
                  "Figure 3B; {name}; trace {trace_n}".format(name=name, trace_n=i))

        # trace = result['dec_trace']
        # trace.t0 = -result['spike_time']
        # trace = trace - np.median(trace.time_slice(-0.5e-3, 0.5e-3).data)
        # deconvs.append(trace)
        # deconv_plot.plot(trace.time_values, trace.data, pen=(0, 0, 0, 20))

    # plot average trace
    mean = TSeriesList(traces).mean()
    trace_plot.plot(mean.time_values, mean.data,
                    pen={'color': 'g', 'width': 2},
                    shadowPen={'color': 'k', 'width': 3}, antialias=True)
    write_csv(csv_file, mean, "Figure 3B; {name}; average".format(name=name))
    # mean = TSeriesList(deconvs).mean()
    # deconv_plot.plot(mean.time_values, mean.data, pen={'color': 'g', 'width': 2},
    #                  shadowPen={'color': 'k', 'width': 3}, antialias=True)
    # add label
    label = pg.LabelItem(name)
    label.setParentItem(trace_plot)

    p("analyze_response_strength")

    # bins = np.arange(-0.0005, 0.002, 0.0001)
    # field = 'pos_amp'
    bins = np.arange(-0.001, 0.015, 0.0005)
    field = 'pos_dec_amp'
    n = min(len(amps), len(base_amps))
    hist_y, hist_bins = np.histogram(base_amps[:n][field], bins=bins)
    hist_plot.plot(hist_bins, hist_y, stepMode=True, pen=None,
                   brush=(200, 0, 0, 150), fillLevel=0)
    write_csv(csv_file, hist_bins,
              "Figure 3C; {name}; background noise amplitude distribution bin edges (V)".format(name=name))
    write_csv(csv_file, hist_y,
              "Figure 3C; {name}; background noise amplitude distribution counts per bin".format(name=name))
    hist_y, hist_bins = np.histogram(amps[:n][field], bins=bins)
    hist_plot.plot(hist_bins, hist_y, stepMode=True, pen='k',
                   brush=(0, 150, 150, 100), fillLevel=0)
    write_csv(csv_file, hist_bins,
              "Figure 3C; {name}; PSP amplitude distribution bin edges (V)".format(name=name))
    write_csv(csv_file, hist_y,
              "Figure 3C; {name}; PSP amplitude distribution counts per bin".format(name=name))
    p()

    pg.QtGui.QApplication.processEvents()

    # Plot detectability analysis
    q = strength_analysis.baseline_query(session)
    q = q.join(strength_analysis.BaselineResponseStrength)
    q = q.filter(strength_analysis.BaselineResponseStrength.id.in_(base_amps['id']))
    # q = q.limit(100)
    bg_recs = q.all()

    def clicked(sp, pts):
        data = pts[0].data()
        print("-----------------------\nclicked:", data['rise_time'], data['amp'],
              data['prediction'], data['confidence'])
        for r in data['results']:
            print({k: r[k] for k in classifier.features})
        traces = data['traces']
        plt = pg.plot()
        bsub = [t.copy(data=t.data - np.median(t.time_slice(0, 1e-3).data)) for t in traces]
        for t in bsub:
            plt.plot(t.time_values, t.data, pen=(0, 0, 0, 50))
        mean = TSeriesList(bsub).mean()
        plt.plot(mean.time_values, mean.data, pen='g')

    # def analyze_response_strength(recs, source, dtype):
    #     results = []
    #     for i, rec in enumerate(recs):
    #         result = strength_analysis.analyze_response_strength(rec, source)
    #         results.append(result)
    #     return str_analysis_result_table(results)

    # measure background connection strength
    bg_results = [strength_analysis.analyze_response_strength(rec, 'baseline')
                  for rec in bg_recs]
    bg_results = strength_analysis.str_analysis_result_table(bg_results, bg_recs)

    # for this example, we use background data to simulate foreground
    # (but this will be biased due to lack of crosstalk in background data)
    fg_recs = bg_recs

    # now measure foreground simulated under different conditions
    amps = 2e-6 * 2 ** np.arange(9)
    amps[0] = 0
    rtimes = [1e-3, 2e-3, 4e-3, 6e-3]
    dt = 1 / db.default_sample_rate
    results = np.empty((len(amps), len(rtimes)),
                       dtype=[('results', object), ('predictions', object),
                              ('confidence', object), ('traces', object),
                              ('rise_time', float), ('amp', float)])
    print(" Simulating synaptic events..")
    cachefile = 'fig_3_cache.pkl'
    if os.path.exists(cachefile):
        cache = pickle.load(open(cachefile, 'rb'))
    else:
        cache = {}
    pair_key = (timestamp, pre_id, post_id)
    pair_cache = cache.setdefault(pair_key, {})
    for j, rtime in enumerate(rtimes):
        new_results = False
        for i, amp in enumerate(amps):
            print("--------------------------------------- %d/%d %d/%d \r" %
                  (i, len(amps), j, len(rtimes)), )
            result = pair_cache.get((rtime, amp))
            if result is None:
                result = strength_analysis.simulate_connection(
                    fg_recs, bg_results, classifier, amp, rtime)
                pair_cache[rtime, amp] = result
                new_results = True
            for k, v in result.items():
                results[i, j][k] = v

        x, y = amps, [np.mean(x) for x in results[:, j]['confidence']]
        c = limit_plot.plot(x, y, pen=pg.intColor(j, len(rtimes) * 1.3, maxValue=150),
                            symbol='o', antialias=True, name="%dus" % (rtime * 1e6),
                            data=results[:, j], symbolSize=4)
        write_csv(csv_file, x,
                  "Figure 3D; {name}; {rise_time:0.3g} ms rise time; simulated PSP amplitude (V)".format(
                      name=name, rise_time=rtime * 1000))
        write_csv(csv_file, y,
                  "Figure 3D; {name}; {rise_time:0.3g} ms rise time; classifier decision probability".format(
                      name=name, rise_time=rtime * 1000))
        c.scatter.sigClicked.connect(clicked)
        pg.QtGui.QApplication.processEvents()
        if new_results:
            pickle.dump(cache, open(cachefile, 'wb'))

    pg.QtGui.QApplication.processEvents()
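# The pickle cache written above can be inspected offline; a minimal sketch
# using the file name and key layout from add_connection_plots:
import pickle
import numpy as np

cache = pickle.load(open('fig_3_cache.pkl', 'rb'))
for (timestamp, pre_id, post_id), pair_cache in cache.items():
    for (rise_time, amp), result in pair_cache.items():
        print(timestamp, rise_time, amp, np.mean(result['confidence']))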
def plot_element_data(self, pre_class, post_class, element, field_name, color='g', trace_plt=None):
    val = element[field_name].mean()
    line = pg.InfiniteLine(val, pen={'color': color, 'width': 2}, movable=False)
    scatter = None
    baseline_window = int(db.default_sample_rate * 5e-3)
    values = []
    tracesA = []
    tracesB = []
    point_data = []
    for pair, value in element[field_name].iteritems():
        latency = self.results.loc[pair]['Latency']
        trace_itemA = None
        trace_itemB = None
        if pair.has_synapse is not True:
            continue
        if np.isnan(value):
            continue
        syn_typ = pair.synapse.synapse_type
        rsf = pair.resting_state_fit
        if rsf is not None:
            nrmse = rsf.vc_nrmse if field_name.startswith('PSC') else rsf.ic_nrmse
            # if nrmse is None or nrmse > 0.8:
            #     continue
            data = rsf.vc_avg_data if field_name.startswith('PSC') else rsf.ic_avg_data
            traceA = TSeries(data=data, sample_rate=db.default_sample_rate)
            if field_name.startswith('PSC'):
                traceA = bessel_filter(traceA, 5000, btype='low', bidir=True)
            start_time = rsf.vc_avg_data_start_time if field_name.startswith('PSC') else rsf.ic_avg_data_start_time
            if latency is not None and start_time is not None:
                if field_name == 'Latency':
                    xoffset = start_time + latency
                else:
                    xoffset = start_time - latency
                baseline_window = [abs(xoffset) - 1e-3, abs(xoffset)]
                traceA = format_trace(traceA, baseline_window, x_offset=xoffset, align='psp')
                trace_itemA = trace_plt[1].plot(traceA.time_values, traceA.data)
                trace_itemA.pair = pair
                trace_itemA.curve.setClickable(True)
                trace_itemA.sigClicked.connect(self.trace_plot_clicked)
                tracesA.append(traceA)
            if field_name == 'Latency' and rsf.vc_nrmse is not None:  # and rsf.vc_nrmse < 0.8:
                traceB = TSeries(data=rsf.vc_avg_data, sample_rate=db.default_sample_rate)
                traceB = bessel_filter(traceB, 5000, btype='low', bidir=True)
                start_time = rsf.vc_avg_data_start_time
                if latency is not None and start_time is not None:
                    xoffset = start_time + latency
                    baseline_window = [abs(xoffset) - 1e-3, abs(xoffset)]
                    traceB = format_trace(traceB, baseline_window, x_offset=xoffset, align='psp')
                    trace_itemB = trace_plt[0].plot(traceB.time_values, traceB.data)
                    trace_itemB.pair = pair
                    trace_itemB.curve.setClickable(True)
                    trace_itemB.sigClicked.connect(self.trace_plot_clicked)
                    tracesB.append(traceB)
            self.pair_items[pair.id] = [trace_itemA, trace_itemB]
        if trace_itemA is not None:
            values.append(value)
            point_data.append(pair)
    y_values = pg.pseudoScatter(np.asarray(values, dtype=float), spacing=1)
    scatter = pg.ScatterPlotItem(symbol='o', brush=(color + (150,)), pen='w', size=12)
    scatter.setData(values, y_values + 10., data=point_data)
    for point in scatter.points():
        pair_id = point.data().id
        self.pair_items[pair_id].extend([point, color])
    scatter.sigClicked.connect(self.scatter_plot_clicked)
    if len(tracesA) > 0:
        if field_name == 'Latency':
            spike_line = pg.InfiniteLine(0, pen={'color': 'w', 'width': 1, 'style': pg.QtCore.Qt.DotLine}, movable=False)
            trace_plt[0].addItem(spike_line)
            x_label = 'Time from presynaptic spike'
        else:
            x_label = 'Response Onset'
        grand_trace = TSeriesList(tracesA).mean()
        name = ('%s->%s, n=%d' % (pre_class, post_class, len(tracesA)))
        trace_plt[1].plot(grand_trace.time_values, grand_trace.data, pen={'color': color, 'width': 3}, name=name)
        units = 'A' if field_name.startswith('PSC') else 'V'
        title = 'Voltage Clamp' if field_name.startswith('PSC') else 'Current Clamp'
        trace_plt[1].setXRange(-5e-3, 20e-3)
        trace_plt[1].setLabels(left=('', units), bottom=(x_label, 's'))
        trace_plt[1].setTitle(title)
    if len(tracesB) > 0:
        trace_plt[1].setLabels(right=('', units))
        trace_plt[1].hideAxis('left')
        spike_line = pg.InfiniteLine(0, pen={'color': 'w', 'width': 1, 'style': pg.QtCore.Qt.DotLine}, movable=False)
        trace_plt[0].addItem(spike_line)
        grand_trace = TSeriesList(tracesB).mean()
        trace_plt[0].plot(grand_trace.time_values, grand_trace.data, pen={'color': color, 'width': 3})
        trace_plt[0].setXRange(-5e-3, 20e-3)
        trace_plt[0].setLabels(left=('', 'A'), bottom=('Time from presynaptic spike', 's'))
        trace_plt[0].setTitle('Voltage Clamp')
    return line, scatter
        offset_dict=pulse_offset_rec, uid=(expt.uid, pre, post))
    for f, freq in enumerate(freqs):
        if freq not in induction_grand.keys():
            print("%d Hz not represented in data set for %s" % (freq, c_type))
            continue
        ind_offsets = pulse_offset_ind[freq]
        qc_plot.clear()
        ind_pass_qc = train_qc(induction_grand[freq], ind_offsets,
                               amp=qc_params[1][c], sign=qc_params[0],
                               plot=qc_plot)
        n_synapses = len(ind_pass_qc[0])
        if n_synapses > 0:
            induction_grand_trace = TSeriesList(ind_pass_qc[0]).mean()
            ind_rec_grand_trace = TSeriesList(ind_pass_qc[1]).mean()
            ind_amp = train_amp(ind_pass_qc, ind_offsets, '+')
            ind_amp_grand = np.nanmean(ind_amp, 0)

            if f == 0:
                ind_plot[f, c].setTitle(connection_types[c])
                type = pg.LabelItem('%s -> %s' % connection_types[c])
                type.setParentItem(summary_plot[c, 0])
                type.setPos(50, 0)
            if c == 0:
                label = pg.LabelItem('%d Hz Induction' % freq)
                label.setParentItem(ind_plot[f, c].vb)
                label.setPos(50, 0)
                summary_plot[c, 0].setTitle('Induction')
            ind_plot[f, c].addLegend()
def plot_prd_ids(self, ids, source, pen=None, trace_list=None, avg=False, qc_filter=None):
    """Plot raw or deconvolved PulseResponse data, given IDs of records in
    a db.PulseResponseStrength table.
    """
    if qc_filter is None:
        qc_filter = self.ui.qc_check.isChecked()
    with pg.BusyCursor():
        recs = self.get_pulse_recs(ids, source)
        if len(recs) == 0:
            return

        if source == 'fg':
            if trace_list is None:
                # assumed default: the original assigned this to a `traces` variable
                # that was immediately shadowed below
                trace_list = self.selected_fg_traces
            plot = self.fg_trace_plot
        else:
            if trace_list is None:
                trace_list = self.selected_bg_traces
            plot = self.bg_trace_plot
        for i in trace_list[:]:
            plot.removeItem(i)
            trace_list.remove(i)

        if pen is None:
            alpha = np.clip(1000 / len(recs), 30, 255)
            pen = (255, 255, 255, alpha)

        pen = pg.mkPen(pen)
        # qc-failed traces are tinted red
        fail_color = pen.color()
        fail_color.setBlue(fail_color.blue() // 2)
        fail_color.setGreen(fail_color.green() // 2)
        qc_fail_pen = pg.mkPen(fail_color)

        traces = []
        spike_times = []
        spike_values = []
        for rec in recs:
            # Filter by QC unless we selected just a single record
            qc_pass = getattr(rec, self.qc_field) is True
            if qc_filter is True and not qc_pass:
                continue

            s = {'fg': 'pulse_response', 'bg': 'baseline'}[source]
            filter_opts = dict(
                deconvolve=self.ui.deconv_check.isChecked(),
                lpf=self.ui.lpf_check.isChecked(),
                remove_artifacts=self.ui.ar_check.isChecked(),
                bsub=self.ui.bsub_check.isChecked(),
            )
            result = analyze_response_strength(rec, source=s, **filter_opts)
            trace = result['dec_trace']
            spike_values.append(trace.value_at([result['spike_time']])[0])
            if self.ui.align_check.isChecked():
                trace.t0 = -result['spike_time']
                spike_times.append(0)
            else:
                spike_times.append(result['spike_time'])

            traces.append(trace)
            trace_list.append(plot.plot(trace.time_values, trace.data,
                                        pen=(pen if qc_pass else qc_fail_pen)))

        if avg and len(traces) > 0:
            mean = TSeriesList(traces).mean()
            trace_list.append(plot.plot(mean.time_values, mean.data, pen='g'))
            trace_list[-1].setZValue(10)

        spike_scatter = pg.ScatterPlotItem(spike_times, spike_values, size=4,
                                           pen=None, brush=(200, 200, 0))
        spike_scatter.setZValue(-100)
        plot.addItem(spike_scatter)
        trace_list.append(spike_scatter)
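# Hedged usage sketch: plot the qc-passing foreground responses for a set of
# PulseResponseStrength ids and overlay their average in green (the ids and
# the `ui` instance are hypothetical):
ui.plot_prd_ids([101, 102, 103], 'fg', avg=True)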
def first_pulse_plot(expt_list, name=None, summary_plot=None, color=None, scatter=0, features=False):
    amp_plots = pg.plot()
    amp_plots.setLabels(left=('Vm', 'V'))
    grand_response = []
    avg_amps = {'amp': [], 'latency': [], 'rise': []}
    for expt in expt_list:
        if expt.connections is not None:
            for pre, post in expt.connections:
                if expt.cells[pre].cre_type == cre_type[0] and expt.cells[post].cre_type == cre_type[1]:
                    all_responses, artifact = get_response(expt, pre, post, analysis_type='pulse')
                    if artifact > 0.03e-3:
                        continue
                    filtered_responses = response_filter(all_responses, freq_range=[0, 50],
                                                         holding_range=[-68, -72], pulse=True)
                    n_sweeps = len(filtered_responses)
                    if n_sweeps >= 10:
                        avg_trace, avg_amp, amp_sign, _ = get_amplitude(filtered_responses)
                        if expt.cells[pre].cre_type in EXCITATORY_CRE_TYPES and avg_amp < 0:
                            continue
                        elif expt.cells[pre].cre_type in INHIBITORY_CRE_TYPES and avg_amp > 0:
                            continue
                        avg_trace.t0 = 0
                        avg_amps['amp'].append(avg_amp)
                        grand_response.append(avg_trace)
                        if features is True:
                            psp_fits = fit_psp(avg_trace, sign=amp_sign, yoffset=0, amp=avg_amp,
                                               method='leastsq', fit_kws={})
                            avg_amps['latency'].append(psp_fits.best_values['xoffset'] - 10e-3)
                            avg_amps['rise'].append(psp_fits.best_values['rise_time'])

                        current_connection_HS = post, pre
                        if len(expt.connections) > 1 and args.recip is True:
                            for i, x in enumerate(expt.connections):
                                if x == current_connection_HS:
                                    # a reciprocal connection was found
                                    amp_plots.plot(avg_trace.time_values, avg_trace.data,
                                                   pen={'color': 'r', 'width': 1})
                                    break
                                elif x != current_connection_HS and i == len(expt.connections) - 1:
                                    # a reciprocal connection was not found
                                    amp_plots.plot(avg_trace.time_values, avg_trace.data)
                        else:
                            amp_plots.plot(avg_trace.time_values, avg_trace.data)
                        app.processEvents()

    if len(grand_response) != 0:
        print(name + ' n = %d' % len(grand_response))
        grand_mean = TSeriesList(grand_response).mean()
        grand_amp = np.mean(np.array(avg_amps['amp']))
        grand_amp_sem = stats.sem(np.array(avg_amps['amp']))
        amp_plots.addLegend()
        amp_plots.plot(grand_mean.time_values, grand_mean.data,
                       pen={'color': 'g', 'width': 3}, name=name)
        amp_plots.addLine(y=grand_amp, pen={'color': 'g'})
        if grand_mean is not None:
            print(legend + ' Grand mean amplitude = %f +- %f' % (grand_amp, grand_amp_sem))
        if features is True:
            feature_list = (avg_amps['amp'], avg_amps['latency'], avg_amps['rise'])
            labels = (['Vm', 'V'], ['t', 's'], ['t', 's'])
            titles = ('Amplitude', 'Latency', 'Rise time')
        else:
            feature_list = [avg_amps['amp']]
            labels = (['Vm', 'V'])
            titles = 'Amplitude'
        summary_plots = summary_plot_pulse(feature_list[0], labels=labels, titles=titles,
                                           i=scatter, grand_trace=grand_mean,
                                           plot=summary_plot, color=color, name=legend)
        return avg_amps, summary_plots
    else:
        print("No TSeries")
        return avg_amps, None
def train_response_plot(expt_list, name=None, summary_plots=[None, None], color=None):
    grand_train = [[], []]
    train_plots = pg.plot()
    train_plots.setLabels(left=('Vm', 'V'))
    tau = 15e-3
    lp = 1000
    for expt in expt_list:
        for pre, post in expt.connections:
            if expt.cells[pre].cre_type == cre_type[0] and expt.cells[post].cre_type == cre_type[1]:
                print('Processing experiment: %s' % (expt.nwb_file))
                train_responses, artifact = get_response(expt, pre, post, analysis_type='train')
                if artifact > 0.03e-3:
                    continue
                train_filter = response_filter(train_responses['responses'], freq_range=[50, 50],
                                               train=0, delta_t=250)
                pulse_offsets = response_filter(train_responses['pulse_offsets'], freq_range=[50, 50],
                                                train=0, delta_t=250)
                if len(train_filter[0]) > 5:
                    ind_avg = TSeriesList(train_filter[0]).mean()
                    rec_avg = TSeriesList(train_filter[1]).mean()
                    rec_avg.t0 = 0.3
                    grand_train[0].append(ind_avg)
                    grand_train[1].append(rec_avg)
                    train_plots.plot(ind_avg.time_values, ind_avg.data)
                    train_plots.plot(rec_avg.time_values, rec_avg.data)
                    app.processEvents()
    if len(grand_train[0]) != 0:
        print(name + ' n = %d' % len(grand_train[0]))
        ind_grand_mean = TSeriesList(grand_train[0]).mean()
        rec_grand_mean = TSeriesList(grand_train[1]).mean()
        ind_grand_mean_dec = bessel_filter(exp_deconvolve(ind_grand_mean, tau), lp)
        train_plots.addLegend()
        train_plots.plot(ind_grand_mean.time_values, ind_grand_mean.data,
                         pen={'color': 'g', 'width': 3}, name=name)
        train_plots.plot(rec_grand_mean.time_values, rec_grand_mean.data,
                         pen={'color': 'g', 'width': 3}, name=name)
        train_amps = train_amp([grand_train[0], grand_train[1]], pulse_offsets, '+')
        if ind_grand_mean is not None:
            train_plots = summary_plot_train(ind_grand_mean, plot=summary_plots[0], color=color,
                                             name=(legend + ' 50 Hz induction'))
            train_plots = summary_plot_train(rec_grand_mean, plot=summary_plots[0], color=color)
            train_plots2 = summary_plot_train(ind_grand_mean_dec, plot=summary_plots[1], color=color,
                                              name=(legend + ' 50 Hz induction'))
            return train_plots, train_plots2, train_amps
    else:
        print("No TSeries")
        return None
def _get_tserieslist(self, ts_name, align, bsub, bsub_win=5e-3, alignment_failure_mode='ignore'):
    tsl = []
    if align is not None and alignment_failure_mode == 'average':
        if align == 'spike':
            average_align_t = np.mean([p.stim_pulse.first_spike_time for p in self.prs
                                       if p.stim_pulse.first_spike_time is not None])
        elif align == 'peak':
            average_align_t = np.mean([p.stim_pulse.spikes[0].peak_time for p in self.prs
                                       if p.stim_pulse.n_spikes == 1
                                       and p.stim_pulse.spikes[0].peak_time is not None])
        elif align == 'pulse':
            # stimulus-onset alignment; this branch accepted 'stim' in the original,
            # renamed here to match the per-response branch and error messages below
            average_align_t = np.mean([p.stim_pulse.onset_time for p in self.prs
                                       if p.stim_pulse.onset_time is not None])
        else:
            raise ValueError("align must be None, 'spike', 'peak', or 'pulse'.")
    for pr in self.prs:
        ts = getattr(pr, ts_name)
        stim_time = pr.stim_pulse.onset_time
        if bsub is True:
            start_time = max(ts.t0, stim_time - bsub_win)
            baseline_data = ts.time_slice(start_time, stim_time).data
            if len(baseline_data) == 0:
                baseline = ts.data[0]
            else:
                baseline = float_mode(baseline_data)
            ts = ts - baseline
        if align is not None:
            if align == 'spike':
                # first_spike_time is the max dv/dt of the spike
                align_t = pr.stim_pulse.first_spike_time
            elif align == 'pulse':
                align_t = stim_time
            elif align == 'peak':
                # peak of the first spike
                align_t = pr.stim_pulse.spikes[0].peak_time if pr.stim_pulse.n_spikes == 1 else None
            else:
                raise ValueError("align must be None, 'spike', 'peak', or 'pulse'.")
            if align_t is None:
                if alignment_failure_mode == 'ignore':
                    # ignore PRs with no known timing
                    continue
                elif alignment_failure_mode == 'average':
                    align_t = average_align_t
                    if np.isnan(align_t):
                        raise Exception("average %s time is None, try another mode" % align)
                elif alignment_failure_mode == 'raise':
                    raise Exception("%s time is not available for pulse %s and can't be aligned" % (align, pr))
            ts = ts.copy(t0=ts.t0 - align_t)
        tsl.append(ts)
    return TSeriesList(tsl)
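# Hedged usage sketch: average responses even when some lack a detected spike
# by falling back to the mean spike time of the rest (`prl` and the
# 'post_tseries' attribute name are assumptions, as above):
tsl = prl._get_tserieslist('post_tseries', align='spike', bsub=True,
                           bsub_win=5e-3, alignment_failure_mode='average')
avg = tsl.mean()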
def plot_features(organism=None, conn_type=None, calcium=None, age=None, sweep_thresh=None, fit_thresh=None):
    s = db.session()

    filters = {
        'organism': organism,
        'conn_type': conn_type,
        'calcium': calcium,
        'age': age,
    }

    selection = [{}]
    for key, value in filters.iteritems():
        if value is not None:
            temp_list = []
            value_list = value.split(',')
            for v in value_list:
                temp = [s1.copy() for s1 in selection]
                for t in temp:
                    t[key] = v
                temp_list = temp_list + temp
            selection = list(temp_list)

    if len(selection) > 0:
        response_grid = PlotGrid()
        response_grid.set_shape(len(selection), 1)
        response_grid.show()
        feature_grid = PlotGrid()
        feature_grid.set_shape(6, 1)
        feature_grid.show()

        for i, select in enumerate(selection):
            pre_cell = aliased(db.Cell)
            post_cell = aliased(db.Cell)
            q_filter = []
            if sweep_thresh is not None:
                q_filter.append(FirstPulseFeatures.n_sweeps >= sweep_thresh)
            species = select.get('organism')
            if species is not None:
                q_filter.append(db.Slice.species == species)
            c_type = select.get('conn_type')
            if c_type is not None:
                pre_type, post_type = c_type.split('-')
                pre_layer, pre_cre = pre_type.split(';')
                if pre_layer == 'None':
                    pre_layer = None
                post_layer, post_cre = post_type.split(';')
                if post_layer == 'None':
                    post_layer = None
                q_filter.extend([pre_cell.cre_type == pre_cre,
                                 pre_cell.target_layer == pre_layer,
                                 post_cell.cre_type == post_cre,
                                 post_cell.target_layer == post_layer])
            calc_conc = select.get('calcium')
            if calc_conc is not None:
                q_filter.append(db.Experiment.acsf.like(calc_conc + '%'))
            age_range = select.get('age')
            if age_range is not None:
                age_lower, age_upper = age_range.split('-')
                q_filter.append(db.Slice.age.between(int(age_lower), int(age_upper)))

            q = s.query(FirstPulseFeatures).join(db.Pair, FirstPulseFeatures.pair_id == db.Pair.id)\
                .join(pre_cell, db.Pair.pre_cell_id == pre_cell.id)\
                .join(post_cell, db.Pair.post_cell_id == post_cell.id)\
                .join(db.Experiment, db.Experiment.id == db.Pair.expt_id)\
                .join(db.Slice, db.Slice.id == db.Experiment.slice_id)

            for filter_arg in q_filter:
                q = q.filter(filter_arg)
            results = q.all()

            trace_list = []
            for pair in results:
                # TODO: set t0 to latency to align to foot of PSP
                trace = TSeries(data=pair.avg_psp, sample_rate=db.default_sample_rate)
                trace_list.append(trace)
                response_grid[i, 0].plot(trace.time_values, trace.data)
            if len(trace_list) > 0:
                grand_trace = TSeriesList(trace_list).mean()
                response_grid[i, 0].plot(grand_trace.time_values, grand_trace.data, pen='b')
                response_grid[i, 0].setTitle('layer %s, %s-> layer %s, %s; n_synapses = %d' %
                                             (pre_layer, pre_cre, post_layer, post_cre, len(trace_list)))
            else:
                print('No synapses for layer %s, %s-> layer %s, %s' %
                      (pre_layer, pre_cre, post_layer, post_cre))

    return response_grid, feature_grid
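# Hedged usage sketch; the argument formats are inferred from the parsing above
# (conn_type as "layer;cre-layer;cre", age as "lower-upper", commas separating
# multiple values), and the specific values are hypothetical:
response_grid, feature_grid = plot_features(organism='mouse',
                                            conn_type='5;sim1-5;sim1',
                                            calcium='2',
                                            age='40-60',
                                            sweep_thresh=10)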
def plot_element_data(self, pre_class, post_class, element, field_name, color='g', trace_plt=None):
    summary = element.agg(self.summary_stat)
    val = summary[field_name]['metric_summary']
    line = pg.InfiniteLine(val, pen={'color': color, 'width': 2}, movable=False)
    scatter = None
    tracesA = []
    tracesB = []
    connections = element[element['Connected'] == True].index.tolist()
    for pair in connections:
        # rsf = pair.resting_state_fit
        synapse = pair.synapse
        if synapse is None:
            continue
        arfs = pair.avg_response_fits
        latency = pair.synapse.latency
        syn_typ = pair.synapse.synapse_type
        self.pair_items[pair.id] = []
        trace_itemA = None
        trace_itemB = None
        # if rsf is not None:
        #     traceA = TSeries(data=rsf.ic_avg_data, sample_rate=db.default_sample_rate)
        #     start_time = rsf.ic_avg_data_start_time
        #     if latency is not None and start_time is not None:
        #         xoffset = start_time - latency
        #         baseline_window = [abs(xoffset)-1e-3, abs(xoffset)]
        #         traceA = format_trace(traceA, baseline_window, x_offset=xoffset, align='psp')
        #         trace_itemA = trace_plt[0].plot(traceA.time_values, traceA.data)
        #         trace_itemA.pair = pair
        #         trace_itemA.curve.setClickable(True)
        #         trace_itemA.sigClicked.connect(self.trace_plot_clicked)
        #         self.pair_items[pair.id].append(trace_itemA)
        #         tracesA.append(traceA)
        if arfs is not None:
            for arf in arfs:
                if arf.holding in syn_typ_holding[syn_typ] and arf.manual_qc_pass is True and latency is not None:
                    if arf.clamp_mode == 'vc' and trace_itemA is None:
                        traceA = TSeries(data=arf.avg_data, sample_rate=db.default_sample_rate)
                        traceA = bessel_filter(traceA, 5000, btype='low', bidir=True)
                        start_time = arf.avg_data_start_time
                        if start_time is not None:
                            xoffset = start_time - latency
                            baseline_window = [abs(xoffset) - 1e-3, abs(xoffset)]
                            traceA = format_trace(traceA, baseline_window, x_offset=xoffset, align='psp')
                            trace_itemA = trace_plt[0].plot(traceA.time_values, traceA.data)
                            trace_itemA.pair = pair
                            trace_itemA.curve.setClickable(True)
                            trace_itemA.sigClicked.connect(self.trace_plot_clicked)
                            self.pair_items[pair.id].append(trace_itemA)
                            tracesA.append(traceA)
                    if arf.clamp_mode == 'ic' and trace_itemB is None:
                        traceB = TSeries(data=arf.avg_data, sample_rate=db.default_sample_rate)
                        start_time = arf.avg_data_start_time
                        if latency is not None and start_time is not None:
                            xoffset = start_time - latency
                            baseline_window = [abs(xoffset) - 1e-3, abs(xoffset)]
                            traceB = format_trace(traceB, baseline_window, x_offset=xoffset, align='psp')
                            trace_itemB = trace_plt[1].plot(traceB.time_values, traceB.data)
                            trace_itemB.pair = pair
                            trace_itemB.curve.setClickable(True)
                            trace_itemB.sigClicked.connect(self.trace_plot_clicked)
                            tracesB.append(traceB)
            self.pair_items[pair.id] = [trace_itemA, trace_itemB]

    if len(tracesA) > 0:
        grand_trace = TSeriesList(tracesA).mean()
        name = ('%s->%s' % (pre_class, post_class))
        # trace_plt[0].addLegend()
        trace_plt[0].plot(grand_trace.time_values, grand_trace.data, pen={'color': color, 'width': 3}, name=name)
        trace_plt[0].setXRange(-5e-3, 20e-3)
        trace_plt[0].setLabels(left=('', 'A'), bottom=('Response Onset', 's'))
        trace_plt[0].setTitle('Voltage Clamp')
    if len(tracesB) > 0:
        grand_trace = TSeriesList(tracesB).mean()
        trace_plt[1].plot(grand_trace.time_values, grand_trace.data, pen={'color': color, 'width': 3})
        trace_plt[1].setLabels(right=('', 'V'), bottom=('Response Onset', 's'))
        trace_plt[1].setTitle('Current Clamp')
    return line, scatter