def simulate_response(fg_recs, bg_results, amp, rtime, seed=None):
    """Inject simulated PSPs into foreground records and run the connectivity analysis.

    A PSP template (rise time *rtime*, 15 ms decay tau, unit amplitude) is
    added to a copy of each record's data with a randomly drawn amplitude and
    latency. Each modified record is analyzed with
    ``analyze_response_strength(rec, 'baseline')``, and the resulting table is
    passed, together with *bg_results*, to ``analyze_pair_connectivity``.

    Parameters
    ----------
    fg_recs : sequence
        Records whose ``.data`` array receives the simulated response. Each is
        wrapped in ``RecordWrapper`` and its data is copied, so the caller's
        records are not modified.
    bg_results :
        Passed through unchanged as the ``('ic', 'bg')`` entry of the
        connectivity analysis.
    amp : float
        Target mean of the drawn response amplitudes.
    rtime : float
        Rise time of the PSP template (same time units as the 15e-3 decay).
    seed : int or None
        If given, seeds NumPy's global random state before any draws.

    Returns
    -------
    (conn_result, traces)
        The return value of ``analyze_pair_connectivity``, and one ``Trace``
        per record holding the modified data. Each trace has the injected
        amplitude attached as ``.amp``.
    """
    if seed is not None:
        np.random.seed(seed)

    sample_rate = db.default_sample_rate
    dt = 1.0 / sample_rate
    t = np.arange(0, 15e-3, dt)
    template = Psp.psp_func(t, xoffset=0, yoffset=0, rise_time=rtime, decay_tau=15e-3, amp=1, rise_power=2)

    n_recs = len(fg_recs)
    # Amplitudes: binomial(n=24, p=0.2) count times a normal(1, 0.3) factor,
    # then rescaled so that their mean equals `amp`.
    r_amps = scipy.stats.binom.rvs(p=0.2, n=24, size=n_recs) * scipy.stats.norm.rvs(scale=0.3, loc=1, size=n_recs)
    mean_amp = r_amps.mean()
    if mean_amp != 0:
        # Guard: a zero mean (e.g. every binomial draw was 0) would otherwise
        # turn every amplitude into NaN (0 * inf).
        r_amps *= amp / mean_amp
    r_latency = np.random.normal(size=n_recs, scale=200e-6, loc=13e-3)

    fg_results = []
    traces = []
    # can't modify fg_recs, so we wrap records with a mutable shell
    fg_recs = [RecordWrapper(rec) for rec in fg_recs]
    for k, rec in enumerate(fg_recs):
        rec.data = rec.data.copy()
        start = max(0, int(r_latency[k] * sample_rate))
        # Add only the overlap between the template and the rest of the
        # record. The original `data[start:] += template[:len(data)-start]`
        # raised a shape error whenever the record ran past the template's end.
        n = max(0, min(len(rec.data) - start, len(template)))
        rec.data[start:start + n] += template[:n] * r_amps[k]

        fg_results.append(analyze_response_strength(rec, 'baseline'))
        trace = Trace(rec.data, sample_rate=sample_rate)
        trace.amp = r_amps[k]
        traces.append(trace)

    fg_results = str_analysis_result_table(fg_results, fg_recs)
    conn_result = analyze_pair_connectivity({('ic', 'fg'): fg_results, ('ic', 'bg'): bg_results, ('vc', 'fg'): [], ('vc', 'bg'): []}, sign=1)
    return conn_result, traces
def _inject_psp(data, template, start, amp):
    """Add ``template * amp`` to *data* in place, beginning at sample *start*.

    Only the overlap between the template and ``data[start:]`` is added, so
    the template may be longer or shorter than the remaining samples without a
    shape mismatch. A negative *start* is clamped to 0, and a *start* past the
    end of *data* adds nothing.
    """
    start = max(0, start)
    n = max(0, min(len(data) - start, len(template)))
    data[start:start + n] += template[:n] * amp


def add_connection_plots(i, name, timestamp, pre_id, post_id):
    """Add a row of plots for one connection and run a PSP detectability simulation.

    Four plots are added to row *i* of the global ``win`` layout (columns 1-4):

    - example pulse-response traces;
    - their deconvolved versions;
    - a histogram of ``pos_dec_amp`` for foreground vs. baseline events;
    - a log-log "limit" plot of the simulation result.

    The pair is looked up in the global ``filtered`` table by
    ``abs(acq_timestamp - timestamp) < 1`` and matching pre/post cell ids. If
    it is not found, a message is printed and the function returns, after the
    empty plots have already been added.

    The simulation injects PSP templates of several amplitudes and rise times
    into baseline recordings. For each case it computes the statistic
    ``median(pos_dec_amp) / std(pos_dec_latency)`` over N records and compares
    it with the same statistic computed from shuffled, unmodified baseline
    amplitudes. Clicking a point in the limit plot opens a new window showing
    the traces of the last simulated trial for that point.

    Parameters
    ----------
    i : int
        Row of ``win`` that receives the plots.
    name : str
        Label drawn on the trace plot.
    timestamp :
        Experiment acquisition timestamp, compared against
        ``filtered['acq_timestamp']``.
    pre_id, post_id :
        Compared against ``filtered['pre_cell_id']`` / ``filtered['post_cell_id']``.

    Relies on module-level names defined elsewhere: ``session``, ``win``,
    ``filtered``, ``trace_plots``, ``deconv_plots``, ``hist_plots``,
    ``scatter_plot``, ``background`` and ``signal``.
    """
    global session, win, filtered
    p = pg.debug.Profiler(disabled=True, delayed=False)

    trace_plot = win.addPlot(i, 1)
    trace_plots.append(trace_plot)
    deconv_plot = win.addPlot(i, 2)
    deconv_plots.append(deconv_plot)
    hist_plot = win.addPlot(i, 3)
    hist_plots.append(hist_plot)
    limit_plot = win.addPlot(i, 4)
    limit_plot.addLegend()
    limit_plot.setLogMode(True, True)

    # Find this connection in the pair list
    idx = np.argwhere((abs(filtered['acq_timestamp'] - timestamp) < 1) & (filtered['pre_cell_id'] == pre_id) & (filtered['post_cell_id'] == post_id))
    if idx.size == 0:
        print("not in filtered connections")
        return
    idx = idx[0, 0]
    p()

    # Mark the point in scatter plot
    scatter_plot.plot([background[idx]], [signal[idx]], pen='k', symbol='o', size=10, symbolBrush='r', symbolPen=None)

    # Plot example traces and histograms
    for plts in [trace_plots, deconv_plots]:
        plt = plts[-1]
        plt.setXLink(plts[0])
        plt.setYLink(plts[0])
        plt.setXRange(-10e-3, 17e-3, padding=0)
        plt.hideAxis('left')
        plt.hideAxis('bottom')
        plt.addLine(x=0)
        plt.setDownsampling(auto=True, mode='peak')
        plt.setClipToView(True)
        # scale bars
        hbar = pg.QtGui.QGraphicsLineItem(0, 0, 2e-3, 0)
        hbar.setPen(pg.mkPen(color='k', width=5))
        plt.addItem(hbar)
        vbar = pg.QtGui.QGraphicsLineItem(0, 0, 0, 100e-6)
        vbar.setPen(pg.mkPen(color='k', width=5))
        plt.addItem(vbar)
    hist_plot.setXLink(hist_plots[0])

    pair = session.query(db.Pair).filter(db.Pair.id == filtered[idx]['pair_id']).all()[0]
    p()
    amps = strength_analysis.get_amps(session, pair)
    p()
    base_amps = strength_analysis.get_baseline_amps(session, pair)
    p()

    q = strength_analysis.response_query(session)
    p()
    q = q.join(strength_analysis.PulseResponseStrength)
    q = q.filter(strength_analysis.PulseResponseStrength.id.in_(amps['id']))
    q = q.join(db.Recording, db.Recording.id == db.PulseResponse.recording_id).join(db.PatchClampRecording).join(db.MultiPatchProbe)
    q = q.filter(db.MultiPatchProbe.induction_frequency < 100)
    # Alternative filtering by experiment / cell ids:
    # pre_cell = db.aliased(db.Cell)
    # post_cell = db.aliased(db.Cell)
    # q = q.join(db.Pair).join(db.Experiment).join(pre_cell, db.Pair.pre_cell_id==pre_cell.id).join(post_cell, db.Pair.post_cell_id==post_cell.id)
    # q = q.filter(db.Experiment.id==filtered[idx]['experiment_id'])
    # q = q.filter(pre_cell.ext_id==pre_id)
    # q = q.filter(post_cell.ext_id==post_id)
    fg_recs = q.all()
    p()

    traces = []
    deconvs = []
    for rec in fg_recs[:100]:
        result = strength_analysis.analyze_response_strength(rec, source='pulse_response', lpf=True, lowpass=2000, remove_artifacts=False, bsub=True)

        # Align on the spike time and subtract the median of the +/-0.5 ms window around it.
        trace = result['raw_trace']
        trace.t0 = -result['spike_time']
        trace = trace - np.median(trace.time_slice(-0.5e-3, 0.5e-3).data)
        traces.append(trace)
        trace_plot.plot(trace.time_values, trace.data, pen=(0, 0, 0, 20))

        trace = result['dec_trace']
        trace.t0 = -result['spike_time']
        trace = trace - np.median(trace.time_slice(-0.5e-3, 0.5e-3).data)
        deconvs.append(trace)
        deconv_plot.plot(trace.time_values, trace.data, pen=(0, 0, 0, 20))

    # plot average trace
    mean = TraceList(traces).mean()
    trace_plot.plot(mean.time_values, mean.data, pen={'color': 'g', 'width': 2}, shadowPen={'color': 'k', 'width': 3}, antialias=True)
    mean = TraceList(deconvs).mean()
    deconv_plot.plot(mean.time_values, mean.data, pen={'color': 'g', 'width': 2}, shadowPen={'color': 'k', 'width': 3}, antialias=True)

    # add label
    label = pg.LabelItem(name)
    label.setParentItem(trace_plot)

    p("analyze_response_strength")

    # Alternative: bins = np.arange(-0.0005, 0.002, 0.0001); field = 'pos_amp'
    bins = np.arange(-0.001, 0.015, 0.0005)
    field = 'pos_dec_amp'
    # Use the same number of foreground and baseline events in both histograms.
    n = min(len(amps), len(base_amps))
    hist_y, hist_bins = np.histogram(base_amps[:n][field], bins=bins)
    hist_plot.plot(hist_bins, hist_y, stepMode=True, pen=None, brush=(200, 0, 0, 150), fillLevel=0)
    hist_y, hist_bins = np.histogram(amps[:n][field], bins=bins)
    hist_plot.plot(hist_bins, hist_y, stepMode=True, pen='k', brush=(0, 150, 150, 100), fillLevel=0)
    p()

    pg.QtGui.QApplication.processEvents()

    # Plot detectability analysis
    q = strength_analysis.baseline_query(session)
    q = q.join(strength_analysis.BaselineResponseStrength)
    q = q.filter(strength_analysis.BaselineResponseStrength.id.in_(base_amps['id']))
    # q = q.limit(100)
    bg_recs = q.all()

    def clicked(sp, pts):
        # Show the traces stored for the clicked (amplitude, rise time) point.
        traces = pts[0].data()['traces']
        print([t.amp for t in traces])
        plt = pg.plot()
        bsub = [t.copy(data=t.data - np.median(t.time_slice(0, 1e-3).data)) for t in traces]
        for t in bsub:
            plt.plot(t.time_values, t.data, pen=(0, 0, 0, 50))
        mean = TraceList(bsub).mean()
        plt.plot(mean.time_values, mean.data, pen='g')

    # first measure background a few times
    N = 50  # temporary for testing (presumably meant to become len(fg_recs))
    print("Testing %d trials" % N)
    bg_results = []
    M = 500
    print(" Grinding on %d background trials" % len(bg_recs))
    for trial in range(M):
        shuffled = base_amps.copy()
        np.random.shuffle(shuffled)
        bg_results.append(np.median(shuffled[:N]['pos_dec_amp']) / np.std(shuffled[:N]['pos_dec_latency']))
        print(" %d/%d \r" % (trial, M),)
    print(" done. ")
    print(" ", bg_results)

    # now measure foreground simulated under different conditions
    sim_amps = 5e-6 * 2**np.arange(6)
    sim_amps[0] = 0
    rtimes = 1e-3 * 1.71**np.arange(4)
    # 1.0 (not 1) keeps dt a float even if the sample rate is an int under
    # integer division; this matches simulate_response.
    dt = 1.0 / db.default_sample_rate
    t = np.arange(0, 15e-3, dt)
    # np.zeros rather than np.empty, so that no field is ever read
    # uninitialized (the sanity assert below reads whole rows).
    results = np.zeros((len(sim_amps), len(rtimes)), dtype=[
        ('pos_dec_amp', float),
        ('latency_stdev', float),
        ('result', float),
        ('percentile', float),
        ('traces', object),
    ])
    print(" Simulating synaptic events..")
    for j, rtime in enumerate(rtimes):
        # The template depends only on the rise time, so build it once per rtime.
        template = Psp.psp_func(t, xoffset=0, yoffset=0, rise_time=rtime, decay_tau=15e-3, amp=1, rise_power=2)
        for ai, amp in enumerate(sim_amps):
            trial_results = []
            trial_amps = []
            trial_latency_std = []
            traces = []
            for _trial in range(20):
                print(" %d/%d %d/%d \r" % (ai, len(sim_amps), j, len(rtimes)),)
                r_amps = amp * 2**np.random.normal(size=N, scale=0.5)
                r_latency = np.random.normal(size=N, scale=600e-6, loc=12.5e-3)
                fg_results = []
                traces = []
                np.random.shuffle(bg_recs)
                for k, rec in enumerate(bg_recs[:N]):
                    # can't modify rec, so we have to muck with the array (and clean up afterward) instead
                    data = rec.data.copy()
                    try:
                        _inject_psp(rec.data, template, int(r_latency[k] / dt), r_amps[k])
                        fg_result = strength_analysis.analyze_response_strength(rec, 'baseline')
                        trace = Trace(rec.data.copy(), dt=dt)
                    finally:
                        # Always restore, even if the analysis raises, so the
                        # baseline record is never left holding injected signal.
                        rec.data[:] = data
                    fg_results.append((fg_result['pos_dec_amp'], fg_result['pos_dec_latency']))
                    trace.amp = r_amps[k]
                    traces.append(trace)
                fg_amp = np.array([r[0] for r in fg_results])
                fg_latency = np.array([r[1] for r in fg_results])
                trial_amps.append(np.median(fg_amp))
                trial_latency_std.append(np.std(fg_latency))
                trial_results.append(np.median(fg_amp) / np.std(fg_latency))

            fg_stat = np.median(trial_results)
            results[ai, j]['pos_dec_amp'] = np.median(trial_amps)
            results[ai, j]['latency_stdev'] = np.median(trial_latency_std)
            results[ai, j]['result'] = fg_stat / np.median(bg_results)
            # Compare like with like: the raw foreground statistic against the
            # raw background distribution (the normalised ratio is on another scale).
            results[ai, j]['percentile'] = stats.percentileofscore(bg_results, fg_stat)
            # Only the last trial's traces are kept, for the click-to-inspect callback.
            results[ai, j]['traces'] = traces
            assert all(np.isfinite(results[ai]['pos_dec_amp']))
            print(ai, results[ai]['result'])
            print(ai, results[ai]['percentile'])

        c = limit_plot.plot(sim_amps, results[:, j]['result'], pen=(j, len(rtimes)*1.3), symbol='o', antialias=True, name="%dus" % (rtime*1e6), data=results[:, j], symbolSize=4)
        c.scatter.sigClicked.connect(clicked)
        pg.QtGui.QApplication.processEvents()
    pg.QtGui.QApplication.processEvents()