def test_model_spike_detection():
    """Check detect_evoked_spikes against simulated current-clamp pulses.

    A 100 pA, 2 ms pulse is expected to evoke no spikes, and a 1000 pA pulse
    with the same timing exactly one.
    """
    # TODO: add many more cases, especially comparisons against real data.
    dt = 10 * us
    start = 5 * ms
    duration = 2 * ms

    # (pulse amplitude, expected number of detected spikes)
    cases = [
        (100 * pA, 0),
        (1000 * pA, 1),
    ]
    for amp, expected in cases:
        resp = create_test_pulse(start=start,
                                 pamp=amp,
                                 pdur=duration,
                                 mode='ic',
                                 dt=dt)
        # Shift the pulse window by the primary channel's t0 to get the
        # edges passed to the detector (presumably `start` is relative to t0).
        t0 = resp['primary'].t0
        pulse_edges = t0 + start, t0 + start + duration
        spikes = detect_evoked_spikes(resp, pulse_edges)
        assert len(spikes) == expected, (
            "pamp=%r: expected %d spike(s), detected %d"
            % (amp, expected, len(spikes)))
# Example no. 3
    def evoked_spikes(self):
        """Given presynaptic Recording, detect action potentials
        evoked by current injection or unclamped spikes evoked by a voltage pulse.

        The result is computed on the first call and cached in
        ``self._evoked_spikes``; later calls return the cached list unchanged.

        Returns
        -------
        spikes : list
            [{'pulse_n', 'pulse_start', 'pulse_end', 'spikes': [...]}, ...]
            One entry per chunk yielded by ``self.pulse_chunks()``, in order.
        """
        if self._evoked_spikes is None:
            spike_info = []
            for chunk in self.pulse_chunks():
                # (start, end) of the stimulus pulse for this chunk
                pulse_edges = chunk.meta['pulse_edges']
                spikes = detect_evoked_spikes(chunk, pulse_edges)
                spike_info.append({
                    'pulse_n': chunk.meta['pulse_n'],
                    'pulse_start': pulse_edges[0],
                    'pulse_end': pulse_edges[1],
                    'spikes': spikes,
                })
            self._evoked_spikes = spike_info
        return self._evoked_spikes
def load_next():
    """Run spike detection on the next pulse and stage it as a test case.

    Pulls the next ``(expt_id, cell_id, sweep, channel, chunk)`` tuple from the
    global ``all_pulses`` iterator. When the iterator is exhausted, the UI
    widget is hidden and the function returns without touching
    ``last_result``. Otherwise, the detected spikes are shown via
    ``ui.show_result``. A ``SpikeDetectTestCase`` built from a trimmed copy of
    the chunk is then stored in the global ``last_result``.
    """
    global all_pulses, ui, last_result
    try:
        expt_id, cell_id, sweep, channel, chunk = next(all_pulses)
    except StopIteration:
        # nothing left to review
        ui.widget.hide()
        return

    edges = chunk.meta['pulse_edges']
    ui.show_result(detect_evoked_spikes(chunk, edges, ui=ui))

    # Copy only the parts of each channel needed for export to file:
    # its data, t0 and sample rate.
    trimmed = {}
    for name in chunk.channels:
        src = chunk[name]
        trimmed[name] = TSeries(src.data, t0=src.t0, sample_rate=src.sample_rate)
    export_chunk = PatchClampRecording(channels=trimmed)
    export_chunk.meta.update(chunk.meta)

    case = SpikeDetectTestCase()
    case._meta = {
        'expt_id': expt_id,
        'cell_id': cell_id,
        'device_id': channel,
        'sweep_id': sweep.key,
    }
    case._input_args = {
        'data': export_chunk,
        'pulse_edges': edges,
    }
    last_result = case
# Example no. 5
    def _update_plots(self):
        sweeps = self.sweeps
        self.current_event_set = None
        self.event_table.clear()
        
        # clear all plots
        self.pre_plot.clear()
        self.post_plot.clear()

        pre = self.params['pre']
        post = self.params['post']
        
        # If there are no selected sweeps or channels have not been set, return
        if len(sweeps) == 0 or pre == post or pre not in self.channels or post not in self.channels:
            return

        pre_mode = sweeps[0][pre].clamp_mode
        post_mode = sweeps[0][post].clamp_mode
        for ch, mode, plot in [(pre, pre_mode, self.pre_plot), (post, post_mode, self.post_plot)]:
            units = 'A' if mode == 'vc' else 'V'
            plot.setLabels(left=("Channel %d" % ch, units), bottom=("Time", 's'))
        
        # Iterate over selected channels of all sweeps, plotting traces one at a time
        # Collect information about pulses and spikes
        pulses = []
        spikes = []
        post_traces = []
        for i,sweep in enumerate(sweeps):
            pre_trace = sweep[pre]['primary']
            post_trace = sweep[post]['primary']
            
            # Detect pulse times
            stim = sweep[pre]['command'].data
            sdiff = np.diff(stim)
            on_times = np.argwhere(sdiff > 0)[1:, 0]  # 1: skips test pulse
            off_times = np.argwhere(sdiff < 0)[1:, 0]
            pulses.append(on_times)

            # filter data
            post_filt = self.artifact_remover.process(post_trace, list(on_times) + list(off_times))
            post_filt = self.baseline_remover.process(post_filt)
            post_filt = self.filter.process(post_filt)
            post_traces.append(post_filt)
            
            # plot raw data
            color = pg.intColor(i, hues=len(sweeps)*1.3, sat=128)
            color.setAlpha(128)
            for trace, plot in [(pre_trace, self.pre_plot), (post_filt, self.post_plot)]:
                plot.plot(trace.time_values, trace.data, pen=color, antialias=False)

            # detect spike times
            spike_times = []
            spike_info = []
            for on, off in zip(on_times, off_times):
                spikes = detect_evoked_spikes(sweep[pre], [on, off])
                spike_info.append(spikes)
                for spike in spikes:
                    if spike['max_slope_time'] is None:
                        continue
                    spike_times.append(spike['max_slope_time'])
            spikes.append(spike_info)
                    
            dt = pre_trace.dt
            vticks = pg.VTickGroup(spike_times, yrange=[0.0, 0.2], pen=color)
            self.pre_plot.addItem(vticks)

        # Iterate over spikes, plotting average response
        all_responses = []
        avg_responses = []
        fits = []
        fit = None
        
        npulses = max(map(len, pulses))
        self.response_plots.clear()
        self.response_plots.set_shape(1, npulses+1) # 1 extra for global average
        self.response_plots.setYLink(self.response_plots[0,0])
        for i in range(1, npulses+1):
            self.response_plots[0,i].hideAxis('left')
        units = 'A' if post_mode == 'vc' else 'V'
        self.response_plots[0, 0].setLabels(left=("Averaged events (Channel %d)" % post, units))
        
        fit_pen = {'color':(30, 30, 255), 'width':2, 'dash': [1, 1]}
        for i in range(npulses):
            # get the chunk of each sweep between spikes
            responses = []
            all_responses.append(responses)
            for j, sweep in enumerate(sweeps):
                # get the current spike
                if i >= len(spikes[j]):
                    continue
                spike = spikes[j][i]
                if spike is None:
                    continue
                
                # find next spike
                next_spike = None
                for sp in spikes[j][i+1:]:
                    if sp is not None:
                        next_spike = sp
                        break
                    
                # determine time range for response
                max_len = int(40e-3 / dt)  # don't take more than 50ms for any response
                start = spike['rise_index']
                if next_spike is not None:
                    stop = min(start + max_len, next_spike['rise_index'])
                else:
                    stop = start + max_len
                    
                # collect data from this trace
                trace = post_traces[j]
                d = trace.data[start:stop].copy()
                responses.append(d)

            if len(responses) == 0:
                continue
                
            # extend all responses to the same length and take nanmean
            avg = ragged_mean(responses, method='clip')
            avg -= float_mode(avg[:int(1e-3/dt)])
            avg_responses.append(avg)
            
            # plot average response for this pulse
            start = np.median([sp[i]['rise_index'] for sp in spikes if sp[i] is not None]) * dt
            t = np.arange(len(avg)) * dt
            self.response_plots[0,i].plot(t, avg, pen='w', antialias=True)

            # fit!
            fit = self.fit_psp(avg, t, dt, post_mode)
            fits.append(fit)
            
            # let the user mess with this fit
            curve = self.response_plots[0,i].plot(t, fit.eval(), pen=fit_pen, antialias=True).curve
            curve.setClickable(True)
            curve.fit = fit
            curve.sigClicked.connect(self.fit_curve_clicked)
            
        # display global average
        global_avg = ragged_mean(avg_responses, method='clip')
        t = np.arange(len(global_avg)) * dt
        self.response_plots[0,-1].plot(t, global_avg, pen='w', antialias=True)
        global_fit = self.fit_psp(global_avg, t, dt, post_mode)
        self.response_plots[0,-1].plot(t, global_fit.eval(), pen=fit_pen, antialias=True)
            
        # display fit parameters in table
        events = []
        for i,f in enumerate(fits + [global_fit]):
            if f is None:
                continue
            if i >= len(fits):
                vals = OrderedDict([('id', 'avg'), ('spike_time', np.nan), ('spike_stdev', np.nan)])
            else:
                spt = [s[i]['peak_index'] * dt for s in spikes if s[i] is not None]
                vals = OrderedDict([('id', i), ('spike_time', np.mean(spt)), ('spike_stdev', np.std(spt))])
            vals.update(OrderedDict([(k,f.best_values[k]) for k in f.params.keys()]))
            events.append(vals)
            
        self.current_event_set = (pre, post, events, sweeps)
        self.event_set_list.setCurrentRow(0)
        self.event_set_selected()