def evoked_spikes(self):
    """Given presynaptic Recording, detect action potentials
    evoked by current injection or unclamped spikes evoked by a voltage pulse.

    The result is computed on first call and cached in
    ``self._evoked_spikes``; later calls return the cached list.

    Returns
    -------
    list of dict
        One entry per pulse from ``self.pulses()`` whose amplitude is not
        negative, with keys:

        * ``'pulse_n'``: position of the pulse in ``self.pulses()``
        * ``'pulse_ind'``: first element of the pulse tuple (the "on" edge)
        * ``'spike'``: whatever ``detect_evoked_spike`` returned for the
          ``[on, off]`` window
    """
    if self._evoked_spikes is None:
        spike_info = []
        for i, (on, off, amp) in enumerate(self.pulses()):
            if amp < 0:
                # assume negative pulses do not evoke spikes
                # (todo: should be watching for rebound spikes as well)
                continue
            spike = detect_evoked_spike(self.rec, [on, off])
            spike_info.append({'pulse_n': i, 'pulse_ind': on, 'spike': spike})
        self._evoked_spikes = spike_info
    return self._evoked_spikes
def test_spike_detection():
    """Check evoked-spike detection on simulated current-clamp pulses.

    A small pulse is expected to yield no detected spike and a large pulse
    is expected to yield one.
    """
    # Need to fill this function up with many more tests, especially
    # measuring against real data.
    dt = 10 * us
    start = 5 * ms
    duration = 2 * ms
    # Sample indices of the pulse onset and offset.
    pulse_edges = int(start / dt), int((start + duration) / dt)

    # (pulse amplitude, whether a spike is expected)
    cases = [(100 * pA, False), (1000 * pA, True)]
    for pamp, expect_spike in cases:
        resp = create_test_pulse(start=start, pamp=pamp, pdur=duration, mode='ic', dt=dt)
        spike = detect_evoked_spike(resp, pulse_edges)
        if expect_spike:
            assert spike is not None
        else:
            assert spike is None
def evoked_spikes(self):
    """Given presynaptic Recording, detect action potentials
    evoked by current injection or unclamped spikes evoked by a voltage pulse.

    The result is computed on first call and cached in
    ``self._evoked_spikes``; later calls return the cached list.

    Returns
    -------
    list of dict
        One entry per pulse from ``self.pulses()`` whose amplitude is not
        negative, with keys:

        * ``'pulse_n'``: position of the pulse in ``self.pulses()``
        * ``'pulse_ind'``: first element of the pulse tuple (the "on" edge)
        * ``'pulse_len'``: ``off - on`` for that pulse
        * ``'spike'``: whatever ``detect_evoked_spike`` returned for the
          ``[on, off]`` window
    """
    if self._evoked_spikes is None:
        spike_info = []
        for i, (on, off, amp) in enumerate(self.pulses()):
            if amp < 0:
                # assume negative pulses do not evoke spikes
                # (todo: should be watching for rebound spikes as well)
                continue
            spike = detect_evoked_spike(self.rec, [on, off])
            spike_info.append({
                'pulse_n': i,
                'pulse_ind': on,
                'pulse_len': off - on,
                'spike': spike,
            })
        self._evoked_spikes = spike_info
    return self._evoked_spikes
def _update_plots(self):
    """Re-plot the selected sweeps and rebuild the averaged-response fits.

    Clears the event table and the pre/post plots, then for every sweep in
    ``self.sweeps``:

    * plots the presynaptic trace and the filtered postsynaptic trace,
    * detects an evoked spike for each stimulus pulse with
      ``detect_evoked_spike`` and marks detected spikes with tick marks.

    Afterwards the postsynaptic data following each pulse's spike is
    averaged across sweeps, each average (and the average of the averages)
    is fit with ``self.fit_psp``, and the fit parameters are stored in
    ``self.current_event_set`` as ``(pre, post, events, sweeps)``.

    Returns early, leaving ``self.current_event_set`` as None, when no
    sweeps are selected, the pre and post channels are the same, or either
    channel is not in ``self.channels``.
    """
    sweeps = self.sweeps
    self.current_event_set = None
    self.event_table.clear()

    # clear all plots
    self.pre_plot.clear()
    self.post_plot.clear()

    pre = self.params['pre']
    post = self.params['post']

    # If there are no selected sweeps or channels have not been set, return
    if len(sweeps) == 0 or pre == post or pre not in self.channels or post not in self.channels:
        return

    pre_mode = sweeps[0][pre].clamp_mode
    post_mode = sweeps[0][post].clamp_mode
    for ch, mode, plot in [(pre, pre_mode, self.pre_plot), (post, post_mode, self.post_plot)]:
        units = 'A' if mode == 'vc' else 'V'
        plot.setLabels(left=("Channel %d" % ch, units), bottom=("Time", 's'))

    # Iterate over selected channels of all sweeps, plotting traces one at a time
    # Collect information about pulses and spikes
    pulses = []       # per sweep: array of pulse onset indices
    spikes = []       # per sweep: list of detect_evoked_spike results (None = no spike)
    post_traces = []  # per sweep: filtered postsynaptic trace
    for i, sweep in enumerate(sweeps):
        pre_trace = sweep[pre]['primary']
        post_trace = sweep[post]['primary']

        # Detect pulse times from the edges of the command waveform
        stim = sweep[pre]['command'].data
        sdiff = np.diff(stim)
        on_times = np.argwhere(sdiff > 0)[1:, 0]  # 1: skips test pulse
        off_times = np.argwhere(sdiff < 0)[1:, 0]
        pulses.append(on_times)

        # filter data
        post_filt = self.artifact_remover.process(post_trace, list(on_times) + list(off_times))
        post_filt = self.baseline_remover.process(post_filt)
        post_filt = self.filter.process(post_filt)
        post_traces.append(post_filt)

        # plot raw data
        color = pg.intColor(i, hues=len(sweeps) * 1.3, sat=128)
        color.setAlpha(128)
        for trace, plot in [(pre_trace, self.pre_plot), (post_filt, self.post_plot)]:
            plot.plot(trace.time_values, trace.data, pen=color, antialias=False)

        # detect spike times
        spike_info = [detect_evoked_spike(sweep[pre], [on, off])
                      for on, off in zip(on_times, off_times)]
        spikes.append(spike_info)

        # mark each detected spike on the presynaptic plot
        dt = pre_trace.dt
        vticks = pg.VTickGroup(
            [sp['rise_index'] * dt for sp in spike_info if sp is not None],
            yrange=[0.0, 0.2], pen=color)
        self.pre_plot.addItem(vticks)

    # Iterate over spikes, plotting average response
    avg_responses = []
    # Each entry is (pulse index, fit). The pulse index is stored because
    # pulses with no detected spike are skipped, so a fit's position in this
    # list is not necessarily its pulse number.
    fits = []

    npulses = max(map(len, pulses))
    self.response_plots.clear()
    self.response_plots.set_shape(1, npulses + 1)  # 1 extra for global average
    self.response_plots.setYLink(self.response_plots[0, 0])
    for i in range(1, npulses + 1):
        self.response_plots[0, i].hideAxis('left')
    units = 'A' if post_mode == 'vc' else 'V'
    self.response_plots[0, 0].setLabels(left=("Averaged events (Channel %d)" % post, units))

    fit_pen = {'color': (30, 30, 255), 'width': 2, 'dash': [1, 1]}
    # dt is the sample period of the last sweep's presynaptic trace
    max_len = int(40e-3 / dt)      # don't take more than 40ms for any response
    baseline_len = int(1e-3 / dt)  # first 1ms of the average is used as baseline
    for i in range(npulses):
        # get the chunk of each sweep between spikes
        responses = []
        for j, sweep_spikes in enumerate(spikes):
            # get the current spike; sweeps may have fewer pulses than npulses
            if i >= len(sweep_spikes):
                continue
            spike = sweep_spikes[i]
            if spike is None:
                continue

            # find next spike
            next_spike = next((sp for sp in sweep_spikes[i + 1:] if sp is not None), None)

            # determine time range for response: from this spike up to the
            # next detected spike, capped at max_len samples
            start = spike['rise_index']
            stop = start + max_len
            if next_spike is not None:
                stop = min(stop, next_spike['rise_index'])

            # collect data from this trace
            responses.append(post_traces[j].data[start:stop].copy())

        if len(responses) == 0:
            continue

        # combine the (possibly unequal-length) responses into one average
        avg = ragged_mean(responses, method='clip')
        avg -= float_mode(avg[:baseline_len])
        avg_responses.append(avg)

        # plot average response for this pulse
        t = np.arange(len(avg)) * dt
        self.response_plots[0, i].plot(t, avg, pen='w', antialias=True)

        # fit!
        fit = self.fit_psp(avg, t, dt, post_mode)
        fits.append((i, fit))

        # let the user mess with this fit
        curve = self.response_plots[0, i].plot(t, fit.eval(), pen=fit_pen, antialias=True).curve
        curve.setClickable(True)
        curve.fit = fit
        curve.sigClicked.connect(self.fit_curve_clicked)

    # display global average
    # NOTE(review): avg_responses is empty when no spike was detected in any
    # sweep; behavior of ragged_mean on an empty list is not visible here.
    global_avg = ragged_mean(avg_responses, method='clip')
    t = np.arange(len(global_avg)) * dt
    self.response_plots[0, -1].plot(t, global_avg, pen='w', antialias=True)
    global_fit = self.fit_psp(global_avg, t, dt, post_mode)
    self.response_plots[0, -1].plot(t, global_fit.eval(), pen=fit_pen, antialias=True)

    # display fit parameters in table
    events = []
    for pulse_ind, f in fits + [(None, global_fit)]:
        if f is None:
            continue
        if pulse_ind is None:
            # global average: no single spike time applies
            vals = OrderedDict([('id', 'avg'), ('spike_time', np.nan), ('spike_stdev', np.nan)])
        else:
            # spike peak times for this pulse, over the sweeps that have a
            # detected spike for it
            spt = [s[pulse_ind]['peak_index'] * dt for s in spikes
                   if pulse_ind < len(s) and s[pulse_ind] is not None]
            vals = OrderedDict([('id', pulse_ind), ('spike_time', np.mean(spt)), ('spike_stdev', np.std(spt))])
        vals.update(OrderedDict([(k, f.best_values[k]) for k in f.params.keys()]))
        events.append(vals)

    self.current_event_set = (pre, post, events, sweeps)
    self.event_set_list.setCurrentRow(0)
    self.event_set_selected()
def _update_plots(self):
    """Re-plot the selected sweeps and rebuild the averaged-response fits.

    Clears the event table and the pre/post plots, then for every sweep in
    ``self.sweeps``:

    * plots the presynaptic trace and the filtered postsynaptic trace,
    * detects an evoked spike for each stimulus pulse with
      ``detect_evoked_spike`` and marks detected spikes with tick marks.

    Afterwards the postsynaptic data following each pulse's spike is
    averaged across sweeps, each average (and the average of the averages)
    is fit with ``self.fit_psp``, and the fit parameters are stored in
    ``self.current_event_set`` as ``(pre, post, events, sweeps)``.

    Returns early, leaving ``self.current_event_set`` as None, when no
    sweeps are selected, the pre and post channels are the same, or either
    channel is not in ``self.channels``.
    """
    sweeps = self.sweeps
    self.current_event_set = None
    self.event_table.clear()

    # clear all plots
    self.pre_plot.clear()
    self.post_plot.clear()

    pre = self.params['pre']
    post = self.params['post']

    # If there are no selected sweeps or channels have not been set, return
    if len(sweeps) == 0 or pre == post or pre not in self.channels or post not in self.channels:
        return

    pre_mode = sweeps[0][pre].clamp_mode
    post_mode = sweeps[0][post].clamp_mode
    for ch, mode, plot in [(pre, pre_mode, self.pre_plot), (post, post_mode, self.post_plot)]:
        units = 'A' if mode == 'vc' else 'V'
        plot.setLabels(left=("Channel %d" % ch, units), bottom=("Time", 's'))

    # Iterate over selected channels of all sweeps, plotting traces one at a time
    # Collect information about pulses and spikes
    pulses = []       # per sweep: array of pulse onset indices
    spikes = []       # per sweep: list of detect_evoked_spike results (None = no spike)
    post_traces = []  # per sweep: filtered postsynaptic trace
    for i, sweep in enumerate(sweeps):
        pre_trace = sweep[pre]['primary']
        post_trace = sweep[post]['primary']

        # Detect pulse times from the edges of the command waveform
        stim = sweep[pre]['command'].data
        sdiff = np.diff(stim)
        on_times = np.argwhere(sdiff > 0)[1:, 0]  # 1: skips test pulse
        off_times = np.argwhere(sdiff < 0)[1:, 0]
        pulses.append(on_times)

        # filter data
        post_filt = self.artifact_remover.process(post_trace, list(on_times) + list(off_times))
        post_filt = self.baseline_remover.process(post_filt)
        post_filt = self.filter.process(post_filt)
        post_traces.append(post_filt)

        # plot raw data
        color = pg.intColor(i, hues=len(sweeps) * 1.3, sat=128)
        color.setAlpha(128)
        for trace, plot in [(pre_trace, self.pre_plot), (post_filt, self.post_plot)]:
            plot.plot(trace.time_values, trace.data, pen=color, antialias=False)

        # detect spike times
        spike_info = [detect_evoked_spike(sweep[pre], [on, off])
                      for on, off in zip(on_times, off_times)]
        spikes.append(spike_info)

        # mark each detected spike on the presynaptic plot
        dt = pre_trace.dt
        vticks = pg.VTickGroup(
            [sp['rise_index'] * dt for sp in spike_info if sp is not None],
            yrange=[0.0, 0.2], pen=color)
        self.pre_plot.addItem(vticks)

    # Iterate over spikes, plotting average response
    avg_responses = []
    # Each entry is (pulse index, fit). The pulse index is stored because
    # pulses with no detected spike are skipped, so a fit's position in this
    # list is not necessarily its pulse number.
    fits = []

    npulses = max(map(len, pulses))
    self.response_plots.clear()
    self.response_plots.set_shape(1, npulses + 1)  # 1 extra for global average
    self.response_plots.setYLink(self.response_plots[0, 0])
    for i in range(1, npulses + 1):
        self.response_plots[0, i].hideAxis('left')
    units = 'A' if post_mode == 'vc' else 'V'
    self.response_plots[0, 0].setLabels(left=("Averaged events (Channel %d)" % post, units))

    fit_pen = {'color': (30, 30, 255), 'width': 2, 'dash': [1, 1]}
    # dt is the sample period of the last sweep's presynaptic trace
    max_len = int(40e-3 / dt)      # don't take more than 40ms for any response
    baseline_len = int(1e-3 / dt)  # first 1ms of the average is used as baseline
    for i in range(npulses):
        # get the chunk of each sweep between spikes
        responses = []
        for j, sweep_spikes in enumerate(spikes):
            # get the current spike; sweeps may have fewer pulses than npulses
            if i >= len(sweep_spikes):
                continue
            spike = sweep_spikes[i]
            if spike is None:
                continue

            # find next spike
            next_spike = next((sp for sp in sweep_spikes[i + 1:] if sp is not None), None)

            # determine time range for response: from this spike up to the
            # next detected spike, capped at max_len samples
            start = spike['rise_index']
            stop = start + max_len
            if next_spike is not None:
                stop = min(stop, next_spike['rise_index'])

            # collect data from this trace
            responses.append(post_traces[j].data[start:stop].copy())

        if len(responses) == 0:
            continue

        # combine the (possibly unequal-length) responses into one average
        avg = ragged_mean(responses, method='clip')
        avg -= float_mode(avg[:baseline_len])
        avg_responses.append(avg)

        # plot average response for this pulse
        t = np.arange(len(avg)) * dt
        self.response_plots[0, i].plot(t, avg, pen='w', antialias=True)

        # fit!
        fit = self.fit_psp(avg, t, dt, post_mode)
        fits.append((i, fit))

        # let the user mess with this fit
        curve = self.response_plots[0, i].plot(t, fit.eval(), pen=fit_pen, antialias=True).curve
        curve.setClickable(True)
        curve.fit = fit
        curve.sigClicked.connect(self.fit_curve_clicked)

    # display global average
    # NOTE(review): avg_responses is empty when no spike was detected in any
    # sweep; behavior of ragged_mean on an empty list is not visible here.
    global_avg = ragged_mean(avg_responses, method='clip')
    t = np.arange(len(global_avg)) * dt
    self.response_plots[0, -1].plot(t, global_avg, pen='w', antialias=True)
    global_fit = self.fit_psp(global_avg, t, dt, post_mode)
    self.response_plots[0, -1].plot(t, global_fit.eval(), pen=fit_pen, antialias=True)

    # display fit parameters in table
    events = []
    for pulse_ind, f in fits + [(None, global_fit)]:
        if f is None:
            continue
        if pulse_ind is None:
            # global average: no single spike time applies
            vals = OrderedDict([('id', 'avg'), ('spike_time', np.nan), ('spike_stdev', np.nan)])
        else:
            # spike peak times for this pulse, over the sweeps that have a
            # detected spike for it
            spt = [s[pulse_ind]['peak_index'] * dt for s in spikes
                   if pulse_ind < len(s) and s[pulse_ind] is not None]
            vals = OrderedDict([('id', pulse_ind), ('spike_time', np.mean(spt)), ('spike_stdev', np.std(spt))])
        vals.update(OrderedDict([(k, f.best_values[k]) for k in f.params.keys()]))
        events.append(vals)

    self.current_event_set = (pre, post, events, sweeps)
    self.event_set_list.setCurrentRow(0)
    self.event_set_selected()
if __name__ == '__main__':
    # Visual check of spike detection: simulate current-clamp responses to
    # pulses of increasing amplitude and plot each one, colored by whether
    # detect_evoked_spike found a spike.
    import pyqtgraph as pg

    plt = pg.plot(labels={'left': ('Vm', 'V'), 'bottom': ('time', 's')})
    dt = 10 * us
    start = 5 * ms
    duration = 2 * ms
    # Sample indices of the pulse onset and offset, derived from dt.
    pulse_edges = int(start / dt), int((start + duration) / dt)

    # Iterate over a series of increasing pulse amplitudes
    for amp in np.arange(50 * pA, 500 * pA, 50 * pA):
        # Simulate pulse response. dt is passed explicitly so the simulated
        # trace uses the same sample period that pulse_edges was computed with.
        resp = create_test_pulse(start=start, pamp=amp, pdur=duration, mode='ic',
                                 r_access=100 * MOhm, dt=dt)

        # Test spike detection
        spike = detect_evoked_spike(resp, pulse_edges)
        print(spike)
        pen = 'r' if spike is None else 'g'  # plot in green if a spike was detected

        pri = resp['primary']
        pri.t0 = 0
        plt.plot(pri.time_values, pri.data, pen=pen)

        # redraw after every new test; mkQApp() returns the application
        # instance regardless of which Qt module provides QApplication
        pg.mkQApp().processEvents()