Exemplo n.º 1
0
    def set_plots(self, plt1=None, plt2=None):
        """Attach this detector's display items to up to two PlotWidgets.

        plt1 receives the lowpass-filtered trace together with tick marks
        for detected events; plt2 receives the deconvolved signal together
        with a green, draggable horizontal threshold line. Passing None for
        either plot leaves that plot unpopulated (the corresponding
        ``*_plot`` attribute is still set to None).

        The display items are created the first time a plot is supplied and
        are reused on later calls.
        """
        # ---- filtered-signal plot -------------------------------------
        self.sig_plot = plt1
        if plt1 is not None:
            if self.sig_trace is None:
                # first connection: build the trace and the event tick group
                self.sig_trace = pg.PlotDataItem()
                self.vticks = pg.VTickGroup(yrange=[0.0, 0.05])
            for item in (self.sig_trace, self.vticks):
                plt1.addItem(item)

        # ---- deconvolution plot ---------------------------------------
        self.deconv_plot = plt2
        if plt2 is None:
            return
        if self.deconv_trace is None:
            self.deconv_trace = pg.PlotDataItem()
            # angle=0 -> horizontal line; starts at the current threshold and
            # reports drags back through _threshold_line_moved
            self.threshold_line = line = pg.InfiniteLine(angle=0,
                                                         movable=True,
                                                         pen='g')
            line.setValue(self.params['threshold'])
            line.sigPositionChanged.connect(self._threshold_line_moved)
        for item in (self.deconv_trace, self.threshold_line):
            plt2.addItem(item)
Exemplo n.º 2
0
    def show(self):
        """Open a window summarising a bilateral bushy -> MSO simulation.

        Plots (by title): 'stim' (left-ear stimulus, time scaled by 1000),
        'Bushy Vm' (one trace per ear), 'SGC-BU xmtr left/right' (30
        transmitter traces each), 'MSO Vm', 'BU-MSO xmtr' (30 transmitter
        traces per ear), 'AN spikes' (tick rasters of the left-ear SGC spike
        trains) and 'VS' (phase histogram of MSO spikes). All time-domain
        plots are x-linked to the bushy Vm plot.

        MSO spikes are detected with ``PU.findspikes`` at a -30 threshold;
        spikes between 20% into the stimulus and its end are used to compute
        and print the vector strength at ``self.f0`` and, when
        ``self.beatfreq`` > 0, at the beat frequency.
        """
        self.win = pg.GraphicsWindow()

        p5 = self.win.addPlot(title='stim')
        p5.plot(self.stim['left'][0].time * 1000., self.stim['left'][0].sound)

        p1 = self.win.addPlot(title='Bushy Vm', row=1, col=0)
        for k, ear in enumerate(self.ears.keys()):
            p1.plot(self['t'], self['vm_bu_%s' % ear], pen=(k, 15))

        p2 = self.win.addPlot(title='SGC-BU xmtr left', row=0, col=1)
        for i in range(30):
            p2.plot(self['t'], self['xmtr%d_left' % i], pen=(i, 15))
        p2.setXLink(p1)
        p2r = self.win.addPlot(title='SGC-BU xmtr right', row=1, col=1)
        for i in range(30):
            p2r.plot(self['t'], self['xmtr%d_right' % i], pen=(i, 15))
        p2r.setXLink(p1)

        p3 = self.win.addPlot(title='MSO Vm', row=2, col=0)
        p3.plot(self['t'], self['vm_mso'])
        p3.setXLink(p1)

        p4 = self.win.addPlot(title='BU-MSO xmtr', row=2, col=1)
        for ear in self.ears.keys():
            for i in range(30):
                # BU->MSO transmitter traces belong in this plot (p4), not in
                # the SGC-BU plot p2.
                p4.plot(self['t'], self['mso_xmtr%d_%s' % (i, ear)],
                        pen=(i, 15))
        p4.setXLink(p1)

        p_an = self.win.addPlot(title='AN spikes', row=3, col=0)
        ntrain = len(self.sgc_cells['left'])
        for k in range(ntrain):
            # stack one tick row per spike train, each 0.8 of its slot high
            yr = [k / float(ntrain), (k + 0.8) / float(ntrain)]
            vt = pg.VTickGroup(self.sgc_cells['left'][k]._spiketrain,
                               yrange=yr, pen=(k, 15))
            p_an.addItem(vt)
        p_an.setXLink(p1)
        p5.setXLink(p1)

        # phaselocking calculations: window bounds are scaled by 1e3 to match
        # the spike-time units returned by PU.findspikes
        phasewin = [self.stimdelay + 0.2 * self.stimdur,
                    self.stimdelay + self.stimdur]
        msospk = PU.findspikes(self['t'], self['vm_mso'], -30.)

        spkin = msospk[np.where(msospk > phasewin[0] * 1e3)]
        spikesinwin = spkin[np.where(spkin <= phasewin[1] * 1e3)[0]]

        # set freq for VS calculation
        f0 = self.f0
        fb = self.beatfreq
        vs = PU.vector_strength(spikesinwin, f0)

        print('MSO Vector Strength at %.1f: %7.3f, d=%.2f (us) Rayleigh: %7.3f  p = %.3e  n = %d'
              % (f0, vs['r'], vs['d'] * 1e6, vs['R'], vs['p'], vs['n']))
        if fb > 0:
            vsb = PU.vector_strength(spikesinwin, fb)
            print('MSO Vector Strength to beat at %.1f: %7.3f, d=%.2f (us) Rayleigh: %7.3f  p = %.3e  n = %d'
                  % (fb, vsb['r'], vsb['d'] * 1e6, vsb['R'], vsb['p'], vsb['n']))
        (hist, binedges) = np.histogram(vs['ph'])
        p6 = self.win.addPlot(title='VS', row=3, col=1)
        p6.plot(binedges, hist, stepMode=True, fillBrush=(100, 100, 255, 150),
                fillLevel=0)
        p6.setXRange(0., 2 * np.pi)

        self.win.show()
Exemplo n.º 3
0
    def show(self):
        """Open a window with four vertically stacked, x-linked plots.

        Top to bottom: bushy-cell membrane potential, 30 transmitter traces
        ('xmtr0'..'xmtr29'), tick marks for the presynaptic cell's spike
        train, and the stimulus waveform (its time axis scaled by 1000,
        presumably s -> ms; confirm). Every plot below the first shares the
        first plot's x axis.
        """
        win = self.win = pg.GraphicsWindow()

        vm_plot = win.addPlot(title='Bushy Vm')
        vm_plot.plot(self['t'], self['vm'])

        xmtr_plot = win.addPlot(title='xmtr', row=1, col=0)
        for idx in range(30):
            xmtr_plot.plot(self['t'], self['xmtr%d' % idx], pen=(idx, 15))
        xmtr_plot.setXLink(vm_plot)

        spike_plot = win.addPlot(title='AN spikes', row=2, col=0)
        spike_plot.addItem(pg.VTickGroup(self.pre_cell._spiketrain))
        spike_plot.setXLink(vm_plot)

        stim_plot = win.addPlot(title='stim', row=3, col=0)
        stim_plot.plot(self.stim.time * 1000, self.stim.sound)
        stim_plot.setXLink(vm_plot)

        win.show()
Exemplo n.º 4
0
    # Sample indices where `diff` is positive (rising edges -> pulse onsets)
    # and negative (falling edges -> pulse offsets). `diff` is computed before
    # this excerpt -- presumably np.diff of the stimulus command; TODO confirm.
    on_times = np.argwhere(diff > 0)[:, 0]
    off_times = np.argwhere(diff < 0)[:, 0]

    # decide on the region of the trace to focus on:
    # from 1000 samples before the 2nd onset (index 1; index 0 is skipped,
    # presumably a test pulse -- confirm) to 1000 samples after the 9th offset
    start = on_times[1] - 1000
    stop = off_times[8] + 1000
    chunk = trace[start:stop]

    # plot the selected chunk: first the derivative of the gaussian-smoothed
    # chunk (np.diff is one sample shorter, hence t[:-1]), then the raw chunk
    t = np.arange(chunk.shape[0]) * dt
    plot.plot(t[:-1], np.diff(ndi.gaussian_filter(chunk, sigma)), pen=0.5)
    plot.plot(t, chunk)

    # detect spike times
    peak_inds = []
    rise_inds = []
    for j in range(8):  # loop over pulses
        # pulse edges re-expressed relative to the start of `chunk`
        pstart = on_times[j + 1] - start
        pstop = off_times[j + 1] - start
        spike_info = detect_vc_evoked_spike(Trace(chunk, dt=dt),
                                            pulse_edges=(pstart, pstop))
        # pulses for which no spike was detected contribute no tick
        if spike_info is not None:
            peak_inds.append(spike_info['peak_index'])
            rise_inds.append(spike_info['rise_index'])

    # display spike rise and peak times as ticks
    # (sample indices converted to time via dt; red = peak, yellow = rise)
    pticks = pg.VTickGroup(np.array(peak_inds) * dt, yrange=[0, 0.3], pen='r')
    rticks = pg.VTickGroup(np.array(rise_inds) * dt, yrange=[0, 0.3], pen='y')
    plot.addItem(pticks)
    plot.addItem(rticks)
Exemplo n.º 5
0
    def _update_plots(self):
        """Redraw pre/post traces, per-pulse averaged responses and PSP fits.

        For every selected sweep this plots the presynaptic trace and the
        filtered postsynaptic trace, detects an evoked presynaptic spike on
        each stimulus pulse (marking rise times with ticks), then, for each
        pulse index, averages the postsynaptic responses that follow that
        pulse's spike across sweeps and fits the average with
        ``self.fit_psp``. The average of all per-pulse averages is fitted as
        well. Fit parameters are stored in ``self.current_event_set``.

        Returns early (with the pre/post plots cleared) when no sweeps are
        selected or the pre/post channels are equal or not available.
        """
        sweeps = self.sweeps
        self.current_event_set = None
        self.event_table.clear()

        # clear all plots
        self.pre_plot.clear()
        self.post_plot.clear()

        pre = self.params['pre']
        post = self.params['post']

        # If there are no selected sweeps or channels have not been set, return
        if (len(sweeps) == 0 or pre == post or pre not in self.channels
                or post not in self.channels):
            return

        pre_mode = sweeps[0][pre].clamp_mode
        post_mode = sweeps[0][post].clamp_mode
        for ch, mode, plot in [(pre, pre_mode, self.pre_plot),
                               (post, post_mode, self.post_plot)]:
            # 'vc' channels are labelled in amps, all others in volts
            units = 'A' if mode == 'vc' else 'V'
            plot.setLabels(left=("Channel %d" % ch, units),
                           bottom=("Time", 's'))

        # Iterate over selected channels of all sweeps, plotting traces one at
        # a time. Collect information about pulses and spikes.
        pulses = []       # per sweep: pulse onset indices
        spikes = []       # per sweep: one spike dict (or None) per pulse
        post_traces = []  # per sweep: filtered postsynaptic trace
        for i, sweep in enumerate(sweeps):
            pre_trace = sweep[pre]['primary']
            post_trace = sweep[post]['primary']

            # Detect pulse times
            stim = sweep[pre]['command'].data
            sdiff = np.diff(stim)
            on_times = np.argwhere(sdiff > 0)[1:, 0]  # 1: skips test pulse
            off_times = np.argwhere(sdiff < 0)[1:, 0]
            pulses.append(on_times)

            # filter data
            post_filt = self.artifact_remover.process(
                post_trace,
                list(on_times) + list(off_times))
            post_filt = self.baseline_remover.process(post_filt)
            post_filt = self.filter.process(post_filt)
            post_traces.append(post_filt)

            # plot raw data
            color = pg.intColor(i, hues=len(sweeps) * 1.3, sat=128)
            color.setAlpha(128)
            for trace, plot in [(pre_trace, self.pre_plot),
                                (post_filt, self.post_plot)]:
                plot.plot(trace.time_values,
                          trace.data,
                          pen=color,
                          antialias=False)

            # detect spike times (None for pulses without a detected spike)
            spike_info = [detect_evoked_spike(sweep[pre], [on, off])
                          for on, off in zip(on_times, off_times)]
            spikes.append(spike_info)
            rise_inds = [None if s is None else s['rise_index']
                         for s in spike_info]

            dt = pre_trace.dt
            vticks = pg.VTickGroup(
                [x * dt for x in rise_inds if x is not None],
                yrange=[0.0, 0.2],
                pen=color)
            self.pre_plot.addItem(vticks)

        # Iterate over pulses, plotting the average response to each
        avg_responses = []
        fits = []
        fit_pulse_inds = []  # pulse index that each entry of `fits` belongs to

        npulses = max(map(len, pulses))
        self.response_plots.clear()
        self.response_plots.set_shape(1, npulses + 1)  # 1 extra for global average
        self.response_plots.setYLink(self.response_plots[0, 0])
        for i in range(1, npulses + 1):
            self.response_plots[0, i].hideAxis('left')
        units = 'A' if post_mode == 'vc' else 'V'
        self.response_plots[0, 0].setLabels(
            left=("Averaged events (Channel %d)" % post, units))

        fit_pen = {'color': (30, 30, 255), 'width': 2, 'dash': [1, 1]}
        # don't take more than 40 ms for any response
        # (dt is the sample interval of the last sweep's presynaptic trace)
        max_len = int(40e-3 / dt)
        for i in range(npulses):
            # get the chunk of each sweep between spikes
            responses = []
            for sweep_spikes, trace in zip(spikes, post_traces):
                # sweeps may have fewer pulses than the longest one
                if i >= len(sweep_spikes):
                    continue
                spike = sweep_spikes[i]
                if spike is None:
                    continue

                # find the next detected spike in this sweep
                next_spike = next(
                    (sp for sp in sweep_spikes[i + 1:] if sp is not None),
                    None)

                # determine time range for response
                start = spike['rise_index']
                if next_spike is not None:
                    stop = min(start + max_len, next_spike['rise_index'])
                else:
                    stop = start + max_len

                # collect data from this trace
                responses.append(trace.data[start:stop].copy())

            if len(responses) == 0:
                continue

            # extend all responses to the same length and take nanmean
            avg = ragged_mean(responses, method='clip')
            # subtract the offset estimated (via float_mode) from the first 1 ms
            avg -= float_mode(avg[:int(1e-3 / dt)])
            avg_responses.append(avg)

            # plot average response for this pulse
            t = np.arange(len(avg)) * dt
            self.response_plots[0, i].plot(t, avg, pen='w', antialias=True)

            # fit!
            fit = self.fit_psp(avg, t, dt, post_mode)
            fits.append(fit)
            fit_pulse_inds.append(i)
            if fit is None:
                continue

            # let the user mess with this fit
            curve = self.response_plots[0, i].plot(t,
                                                   fit.eval(),
                                                   pen=fit_pen,
                                                   antialias=True).curve
            curve.setClickable(True)
            curve.fit = fit
            curve.sigClicked.connect(self.fit_curve_clicked)

        # display global average (only possible if some pulse had responses)
        global_fit = None
        if avg_responses:
            global_avg = ragged_mean(avg_responses, method='clip')
            t = np.arange(len(global_avg)) * dt
            self.response_plots[0, -1].plot(t, global_avg, pen='w',
                                            antialias=True)
            global_fit = self.fit_psp(global_avg, t, dt, post_mode)
            if global_fit is not None:
                self.response_plots[0, -1].plot(t,
                                                global_fit.eval(),
                                                pen=fit_pen,
                                                antialias=True)

        def fit_row(ident, spike_time, spike_stdev, fit):
            # one table row: identity + spike timing + all fitted parameters
            row = OrderedDict([('id', ident), ('spike_time', spike_time),
                               ('spike_stdev', spike_stdev)])
            row.update((k, fit.best_values[k]) for k in fit.params.keys())
            return row

        # display fit parameters in table; 'id' is the pulse index the fit
        # belongs to (pulses without responses have no fit)
        events = []
        for pulse_ind, f in zip(fit_pulse_inds, fits):
            if f is None:
                continue
            spt = [s[pulse_ind]['peak_index'] * dt for s in spikes
                   if pulse_ind < len(s) and s[pulse_ind] is not None]
            events.append(fit_row(pulse_ind, np.mean(spt), np.std(spt), f))
        if global_fit is not None:
            events.append(fit_row('avg', np.nan, np.nan, global_fit))

        self.current_event_set = (pre, post, events, sweeps)
        self.event_set_list.setCurrentRow(0)
        self.event_set_selected()
Exemplo n.º 6
0
# Alternative model classes that were tried (kept disabled):
# test_model_class = SphereIntersectionModel   # doesn't work well with the current minimization algo
# test_model_class = LinearModel   # doesn't work well with the current minimization algo
# test_model_class = GaussianModel

# NOTE(review): `true_model` is used below but defined before this excerpt;
# it is assumed to provide .pdf(x) and .generate(x) -- confirm.

# How many connections probed per experiment
n_probes = 100

plt = pg.plot(labels={'bottom': ('distance', 'm')})

# model distance sampling as lognormal: median exp(log(150e-6)) = 150 um,
# log-space sigma 0.6 (distances are in meters)
x_probed = np.random.lognormal(size=n_probes, sigma=.6, mean=np.log(150e-6))
# 40 um-wide bin edges from 0 up to (not including) 500 um; x_vals holds
# the bin centers
x_bins = np.arange(0, 500e-6, 40e-6)
x_vals = 0.5 * (x_bins[1:] + x_bins[:-1])

# plot connections probed: one translucent white tick per probed distance,
# spanning y-fraction [0, 0.05] of the view
probed_ticks = pg.VTickGroup(x_probed, [0, 0.05], pen=(255, 255, 255, 128))
plt.addItem(probed_ticks)

# plot the ground-truth probability distribution (solid green)
plt.plot(x_vals,
         true_model.pdf(x_vals),
         pen={
             'width': 2,
             'color': (0, 255, 0, 100)
         })

# run the experiment (measure connectivity at the chosen distances);
# `conn` is used to index x_probed, presumably a boolean mask -- confirm
conn = true_model.generate(x_probed)

# plot ticks for connected pairs
# NOTE(review): conn_ticks is not added to `plt` within this excerpt --
# confirm it is added (plt.addItem) further on.
conn_ticks = pg.VTickGroup(x_probed[conn], [0, 0.1], pen='w')
Exemplo n.º 7
0
    def _update_plots(self, autoRange=False):
        """Redraw the channel x channel response matrix for the selected sweeps.

        Matrix cell (i, j) shows channel i's recording in the time window
        spanning pulses ``params['first pulse']`` .. ``params['last pulse']``
        of channel j's stimulus, padded by ``params['window']`` on both
        sides. Depending on ``self.params`` the data are first cleaned of
        artifacts from nearby headstages, lowpass filtered and baseline
        subtracted. Then individual sweeps, the sweep average and/or the
        average over the selected pulses are drawn. Off-diagonal averages
        are coloured by the size of their positive/negative deflection
        relative to the baseline noise.

        Parameters
        ----------
        autoRange : bool
            If True, reset the y ranges of plots (0, 0) and (0, 1), and so of
            every plot y-linked to them, to preset limits. The limits are
            chosen from the clamp mode of channel 0, whose data fill row 0.
            The x range is reset to the displayed time window.
        """
        sweeps = self.sweeps
        chans = self.channels
        self.plots.clear()
        if len(sweeps) == 0 or len(chans) == 0:
            return

        # collect data
        data = MiesNwb.pack_sweep_data(sweeps)
        data, stim = data[..., 0], data[..., 1]  # unpack stim and recordings
        # list() keeps this working where dict.values() is a view (Python 3)
        dt = list(sweeps[0].traces().values())[0].sample_rate / 1000.

        # mask for selected channels
        mask = np.array([ch in chans for ch in sweeps[0].channels()])
        data = data[:, mask]
        stim = stim[:, mask]
        chans = np.array(sweeps[0].channels())[mask]
        if len(chans) == 0:
            # none of the selected channels were recorded in these sweeps
            return

        # clamp mode per channel: 0 is labelled in A, anything else in V
        modes = [sweeps[0].traces()[ch].meta()['Clamp Mode'] for ch in chans]
        headstages = [sweeps[0].traces()[ch].headstage_id for ch in chans]

        # get pulse times for each channel (from the first sweep's stimulus)
        stim = stim[0]
        diff = stim[:, 1:] - stim[:, :-1]
        # note: the [1:] here skips the test pulse
        on_times = [np.argwhere(d > 0)[1:, 0] for d in diff]
        off_times = [np.argwhere(d < 0)[1:, 0] for d in diff]

        # remove capacitive artifacts from adjacent electrodes
        if self.params['remove artifacts']:
            npts = int(self.params['remove artifacts', 'window'] / dt)
            for i in range(stim.shape[0]):
                for j in range(stim.shape[0]):
                    # only distinct headstages within 3 of each other
                    if i == j or abs(headstages[j] - headstages[i]) > 3:
                        continue

                    # replace the npts samples after each edge of channel
                    # i's pulses with the mean of the npts samples before it
                    for on, off in zip(on_times[i], off_times[i]):
                        data[:, j, on:on + npts] = data[
                            :, j, max(0, on - npts):on].mean(axis=1)[:, None]
                        data[:, j, off:off + npts] = data[
                            :, j, max(0, off - npts):off].mean(axis=1)[:, None]

        # lowpass filter (along time only)
        if self.params['lowpass']:
            data = gaussian_filter(
                data, (0, 0, self.params['lowpass', 'sigma'] / dt))

        # prepare to plot
        window = int(self.params['window'] / dt)
        n_sweeps = data.shape[0]
        n_channels = data.shape[1]
        self.plots.set_shape(n_channels, n_channels)
        self.plots.setClipToView(True)
        self.plots.setDownsampling(True, True, 'peak')
        self.plots.enableAutoRange(False, False)

        show_sweeps = 'sweeps' in self.params['show']
        show_sweep_avg = 'sweep avg' in self.params['show']
        show_pulse_avg = self.params['show'] == 'pulse avg'
        first_pulse = self.params['first pulse']
        last_pulse = self.params['last pulse']

        t = None  # time axis of the most recently drawn curve (for autoRange)
        for i in range(n_channels):
            for j in range(n_channels):
                plt = self.plots[i, j]
                start = on_times[j][first_pulse] - window
                if start < 0:
                    frontpad = -start
                    start = 0
                else:
                    frontpad = 0
                stop = on_times[j][last_pulse] + window

                # select the data segment to be displayed in this matrix cell
                # add padding (repeating the first sample) if necessary
                if frontpad == 0:
                    seg = data[:, i, start:stop].copy()
                else:
                    seg = np.empty((data.shape[0], stop + frontpad),
                                   data.dtype)
                    seg[:, frontpad:] = data[:, i, start:stop]
                    seg[:, :frontpad] = seg[:, frontpad:frontpad + 1]

                # subtract off baseline for each sweep
                if self.params['remove baseline']:
                    seg -= seg[:, :window].mean(axis=1)[:, None]

                if show_sweeps:
                    alpha = 100 if show_sweep_avg else 200
                    color = (255, 255, 255, alpha)
                    t = np.arange(seg.shape[1]) * dt
                    for k in range(n_sweeps):
                        plt.plot(t,
                                 seg[k],
                                 pen={
                                     'color': color,
                                     'width': 1
                                 },
                                 antialias=True)

                if show_sweep_avg or show_pulse_avg:
                    # average selected segments over all sweeps
                    segm = seg.mean(axis=0)

                    if show_pulse_avg:
                        # average over all selected pulses; each piece starts
                        # `window` samples before its pulse onset
                        pulses = []
                        for k in range(first_pulse, last_pulse + 1):
                            pstart = on_times[j][k] - on_times[j][first_pulse]
                            pstop = pstart + (window * 2)
                            pulses.append(segm[pstart:pstop])
                        segm = np.vstack(pulses).mean(axis=0)

                    t = np.arange(segm.shape[0]) * dt

                    if i == j:
                        color = (80, 80, 80)
                    else:
                        # deflection above/below baseline in units of
                        # baseline standard deviation
                        dif = segm - segm[:window].mean()
                        qe = 30 * np.clip(dif, 0,
                                          1e20).mean() / segm[:window].std()
                        qi = 30 * np.clip(-dif, 0,
                                          1e20).mean() / segm[:window].std()
                        if modes[i] == 1:
                            qi, qe = qe, qi  # invert color metric for current clamp
                        g = 100
                        r = np.clip(g + max(qi, 0), 0, 255)
                        b = np.clip(g + max(qe, 0), 0, 255)
                        color = (r, g, b)

                    plt.plot(t,
                             segm,
                             pen={
                                 'color': color,
                                 'width': 1
                             },
                             antialias=True)

                if self.params['show ticks']:
                    # tick positions are relative to the start of the
                    # (possibly front-padded) displayed segment
                    vt = pg.VTickGroup((on_times[j] - start + frontpad) * dt,
                                       [0, 0.15],
                                       pen=0.4)
                    plt.addItem(vt)

                # Link all plots along x axis
                plt.setXLink(self.plots[0, 0])

                if i == j:
                    # link y axes of all diagonal plots
                    plt.setYLink(self.plots[0, 0])
                else:
                    # link y axes of all plots within a row;
                    # (i+1)%2 just avoids linking to the diagonal plot
                    plt.setYLink(self.plots[i, (i + 1) % 2])

                if i < n_channels - 1:
                    plt.getAxis('bottom').setVisible(False)
                if j > 0:
                    plt.getAxis('left').setVisible(False)

                if i == n_channels - 1:
                    plt.setLabels(
                        bottom=('CH%d' %
                                sweeps[0].traces()[chans[j]].headstage_id,
                                's'))
                if j == 0:
                    plt.setLabels(
                        left=('CH%d' %
                              sweeps[0].traces()[chans[i]].headstage_id,
                              'A' if modes[i] == 0 else 'V'))

        if autoRange:
            # row 0 displays channel 0's recording, so its mode picks the units
            mode = modes[0]
            if n_channels > 1:
                r = 14e-12 if mode == 0 else 5e-3
                self.plots[0, 1].setYRange(-r, r)
            r = 2e-9 if mode == 0 else 100e-3
            self.plots[0, 0].setYRange(-r, r)

            if t is not None:
                self.plots[0, 0].setXRange(t[0], t[-1])
    def show(self):
        """Display stimulus, spikes, Vm, transmitter and phase-locking plots.

        Left column (x-linked to the stimulus plot): stimulus waveform (time
        scaled by 1000), tick marks for the first presynaptic cell's spike
        train, this cell's spikes (``PU.findspikes`` at a -30 threshold) and
        its membrane potential. Right column: transmitter traces
        ('xmtr%03d') for each of ``synapse.terminal.n_rzones`` zones of
        every SGC synapse, and phase histograms of the presynaptic and
        postsynaptic spikes.

        Spikes from 25% into the first pip until its end are used for the
        vector strength, computed at a reference frequency that depends on
        ``self.stimulus``. Results are printed.

        Raises
        ------
        ValueError
            If ``self.stimulus`` is not 'tone', 'SAM' or 'clicks', since no
            reference frequency is defined for it.
        """
        self.win = pg.GraphicsWindow()
        # p1 is the reference x axis that the time-domain plots link to
        p1 = self.win.addPlot(title='stim', row=0, col=0)
        p1.plot(self.stim.time * 1000, self.stim.sound)

        p2 = self.win.addPlot(title='AN spikes', row=1, col=0)
        vt = pg.VTickGroup(self.pre_cells[0]._spiketrain)
        p2.addItem(vt)
        p2.setXLink(p1)

        p3 = self.win.addPlot(title='%s Spikes' % self.cell, row=2, col=0)
        bspk = PU.findspikes(self['t'], self['vm'], -30.)
        bspktick = pg.VTickGroup(bspk)
        p3.addItem(bspktick)
        p3.setXLink(p1)

        p4 = self.win.addPlot(title='%s Vm' % self.cell, row=3, col=0)
        p4.plot(self['t'], self['vm'])
        p4.setXLink(p1)

        p5 = self.win.addPlot(title='xmtr', row=0, col=1)
        j = 0  # running index over all zones of all synapses
        for k in range(self.n_sgc):
            synapse = self.synapses[k]
            for i in range(synapse.terminal.n_rzones):
                p5.plot(self['t'], self['xmtr%03d' % j], pen=(i, 15))
                j = j + 1
        p5.setXLink(p1)

        p6 = self.win.addPlot(title='AN phase', row=1, col=1)
        # analysis window bounds are scaled by 1e3 to match spike-time units
        phasewin = [self.pip_start[0] + 0.25 * self.pip_duration,
                    self.pip_start[0] + self.pip_duration]
        prespk = self.pre_cells[0]._spiketrain  # just sample one...
        spkin = prespk[np.where(prespk > phasewin[0] * 1e3)]
        spikesinwin = spkin[np.where(spkin <= phasewin[1] * 1e3)]
        print("\nCell type: %s" % self.cell)
        print("Stimulus: ")

        # set freq for VS calculation
        if self.stimulus == 'tone':
            f0 = self.f0
            print("Tone: f0=%.3f at %3.1f dbSPL, cell CF=%.3f"
                  % (self.f0, self.dbspl, self.cf))
        elif self.stimulus == 'SAM':
            f0 = self.fMod
            print("SAM Tone: f0=%.3f at %3.1f dbSPL, fMod=%3.1f  dMod=%5.2f, cell CF=%.3f"
                  % (self.f0, self.dbspl, self.fMod, self.dMod, self.cf))
        elif self.stimulus == 'clicks':
            f0 = 1. / self.click_rate
            print("Clicks: interval %.3f at %3.1f dbSPL, cell CF=%.3f "
                  % (self.click_rate, self.dbspl, self.cf))
        else:
            raise ValueError("no vector-strength reference frequency for "
                             "stimulus type %r" % (self.stimulus,))
        vs = PU.vector_strength(spikesinwin, f0)

        print('AN Vector Strength at %.1f: %7.3f, d=%.2f (us) Rayleigh: %7.3f  p = %.3e  n = %d'
              % (f0, vs['r'], vs['d'] * 1e6, vs['R'], vs['p'], vs['n']))
        (hist, binedges) = np.histogram(vs['ph'])
        p6.plot(binedges, hist, stepMode=True, fillBrush=(100, 100, 255, 150),
                fillLevel=0)
        p6.setXRange(0., 2 * np.pi)

        p7 = self.win.addPlot(title='%s phase' % self.cell, row=2, col=1)
        spkin = bspk[np.where(bspk > phasewin[0] * 1e3)]
        spikesinwin = spkin[np.where(spkin <= phasewin[1] * 1e3)]
        vs = PU.vector_strength(spikesinwin, f0)
        print('%s Vector Strength: %7.3f, d=%.2f (us) Rayleigh: %7.3f  p = %.3e  n = %d'
              % (self.cell, vs['r'], vs['d'] * 1e6, vs['R'], vs['p'], vs['n']))
        (hist, binedges) = np.histogram(vs['ph'])
        p7.plot(binedges, hist, stepMode=True, fillBrush=(100, 100, 255, 150),
                fillLevel=0)
        p7.setXRange(0., 2 * np.pi)
        p7.setXLink(p6)

        self.win.show()