def _initialize(self, **kargs):
    """Set up the spike overlay on top of the base oscilloscope.

    Runs ``QOscilloscope._initialize``, sizes the 'spikes' input buffer to
    ``self.peak_buffer_size`` (with ``double=False``), creates a
    ``ThreadPollInput`` poller whose ``new_data`` signal feeds
    ``self._on_new_peak``, builds the scatter items through
    ``change_catalogue`` and applies default display parameters. A
    single-shot 500 ms timer then triggers ``auto_scale`` once.
    """
    QOscilloscope._initialize(self, **kargs)

    spikes_input = self.inputs['spikes']
    spikes_input.set_buffer(size=self.peak_buffer_size, double=False)

    # Position of the most recent peak; reset here, updated via the poller.
    self._last_peak = 0
    self.poller_peak = ThreadPollInput(input_stream=spikes_input,
                                       return_data=True)
    self.poller_peak.new_data.connect(self._on_new_peak)

    # Direct reference to the array backing the 'spikes' input buffer.
    self.spikes_array = spikes_input.buffer.buffer

    self.scatters = {}
    self.change_catalogue(self.catalogue)

    # Default display settings, applied in this exact order.
    display_defaults = (
        ('xsize', 1.),
        ('decimation_method', 'min_max'),
        ('mode', 'scan'),
        ('scale_mode', 'same_for_all'),
        ('display_labels', True),
    )
    for param_name, param_value in display_defaults:
        self.params[param_name] = param_value

    # Delay the first auto-scale so some data has been received.
    self.timer_scale = QT.QTimer(singleShot=True, interval=500)
    self.timer_scale.timeout.connect(self.auto_scale)
    self.timer_scale.start()
 def __init__(self, input_stream, output_streams, peeler, in_group_channels,
              timeout=200, parent=None):
     """Create the polling thread that feeds ``peeler``.

     Args:
         input_stream: stream polled by the base ``ThreadPollInput``
             (created with ``return_data=True``).
         output_streams: stored as-is on ``self.output_streams``.
         peeler: stored as-is on ``self.peeler``.
         in_group_channels: stored as-is on ``self.in_group_channels``.
         timeout: polling timeout forwarded to ``ThreadPollInput``.
         parent: optional Qt parent forwarded to ``ThreadPollInput``.
     """
     ThreadPollInput.__init__(self, input_stream, timeout=timeout,
                              return_data=True, parent=parent)

     self.output_streams, self.peeler, self.in_group_channels = (
         output_streams, peeler, in_group_channels)

     self.sample_rate = input_stream.params['sample_rate']
     # shape[1] of the polled stream -- presumably the channel count.
     stream_shape = self.input_stream().params['shape']
     self.total_channel = stream_shape[1]
# Example #3
# 0
 def __init__(self, input_stream, output_streams, peeler, in_group_channels,
              timeout=200, parent=None):
     """Create the polling thread that feeds ``peeler``, plus a mutex.

     Args:
         input_stream: stream polled by the base ``ThreadPollInput``
             (created with ``return_data=True``).
         output_streams: stored as-is on ``self.output_streams``.
         peeler: stored as-is on ``self.peeler``.
         in_group_channels: stored as-is on ``self.in_group_channels``.
         timeout: polling timeout forwarded to ``ThreadPollInput``.
         parent: optional Qt parent forwarded to ``ThreadPollInput``.
     """
     ThreadPollInput.__init__(self, input_stream, timeout=timeout,
                              return_data=True, parent=parent)

     self.output_streams, self.peeler, self.in_group_channels = (
         output_streams, peeler, in_group_channels)

     self.sample_rate = input_stream.params['sample_rate']
     # shape[1] of the polled stream -- presumably the channel count.
     stream_shape = self.input_stream().params['shape']
     self.total_channel = stream_shape[1]

     # NOTE(review): what this mutex protects is not visible here.
     self.mutex = Mutex()
    def _initialize(self, **kargs):
        """Prepare input buffers, pollers, histograms and the refresh timer.

        Reads the sample rate and dtype from the 'signals' input, sizes the
        'spikes' buffer to ``self.peak_buffer_size`` and the 'signals' buffer
        to 3 seconds of samples, creates one ``ThreadPollInput`` per input,
        resets the per-label histograms via ``change_catalogue`` and creates
        a timer connected to ``self.refresh``.
        """
        signals_params = self.inputs['signals'].params
        self.sample_rate = signals_params['sample_rate']
        self.wf_dtype = signals_params['dtype']

        self.inputs['spikes'].set_buffer(size=self.peak_buffer_size,
                                         double=False)
        # 3 seconds of signal history (presumably enough to cut waveforms
        # around spikes that arrive with some latency).
        buffer_sigs_size = int(self.sample_rate * 3.)
        self.inputs['signals'].set_buffer(size=buffer_sigs_size, double=False)

        # pollers
        self.poller_sigs = ThreadPollInput(input_stream=self.inputs['signals'],
                                           return_data=True)
        self.poller_spikes = ThreadPollInput(
            input_stream=self.inputs['spikes'], return_data=True)

        self.histogram_2d = {}
        self.last_waveform = {}
        self.change_catalogue(self.catalogue)

        # Use the configurable refresh period rather than a hard-coded 100 ms,
        # so a value set before initialization is honored.
        self.timer = QT.QTimer(interval=self.params['refresh_interval'])
        self.timer.timeout.connect(self.refresh)
 def _initialize(self, **kargs):
     """Set up the spike overlay on top of the base oscilloscope.

     Runs ``QOscilloscope._initialize``, sizes the 'spikes' input buffer,
     creates a ``ThreadPollInput`` poller feeding ``self._on_new_peak`` and
     adds one ``pg.ScatterPlotItem`` per catalogue cluster label. The color
     comes from ``catalogue['cluster_colors']`` (white when missing), drawn
     with alpha 150.
     """
     QOscilloscope._initialize(self, **kargs)

     spikes_input = self.inputs['spikes']
     spikes_input.set_buffer(size=self.peak_buffer_size, double=False)

     # Position of the most recent peak; reset here, updated via the poller.
     self._last_peak = 0
     self.poller_peak = ThreadPollInput(input_stream=spikes_input,
                                        return_data=True)
     self.poller_peak.new_data.connect(self._on_new_peak)

     # Direct reference to the array backing the 'spikes' input buffer.
     self.spikes_array = spikes_input.buffer.buffer

     self._default_color = QtGui.QColor('#FFFFFF')  # TODO
     self.scatters = {}
     cluster_colors = self.catalogue['cluster_colors']
     for label in self.catalogue['cluster_labels']:
         # Colors are floats in [0, 1]; QColor expects 0-255 components.
         red, green, blue = cluster_colors.get(label, (1, 1, 1))
         brush = QtGui.QColor(red * 255, green * 255, blue * 255)
         brush.setAlpha(150)
         scatter = pg.ScatterPlotItem(x=[], y=[], pen=None, brush=brush,
                                      size=10, pxMode=True)
         self.scatters[label] = scatter
         self.plot.addItem(scatter)
class OnlineWaveformHistViewer(WidgetNode):
    """Live 2D histogram of spike waveforms, accumulated per cluster label.

    Inputs:
      * ``signals``: signals stream from which waveforms are cut.
      * ``spikes``: events stream of ``_dtype_spike`` records; the fields
        ``index`` and ``cluster_label`` are read.

    For each new spike, the samples in ``[index + n_left, index + n_right)``
    on every channel are concatenated channel after channel, binned along the
    amplitude axis and accumulated into ``histogram_2d[label]``. The histogram
    of the label selected in the combo box is shown as an image, together with
    the last waveform of that label. Double-clicking the plot toggles the
    settings window.
    """

    _input_specs = {
        'signals': dict(streamtype='signals'),
        'spikes': dict(streamtype='events', shape=(-1, ), dtype=_dtype_spike),
    }

    _params = [
        {
            'name': 'colormap',
            'type': 'list',
            'values': [
                'hot',
                'viridis',
                'jet',
                'gray',
            ]
        },
        {
            'name': 'bin_min',
            'type': 'float',
            'value': -20.
        },
        {
            'name': 'bin_max',
            'type': 'float',
            'value': 8.
        },
        {
            'name': 'bin_size',
            'type': 'float',
            'value': .1
        },
        {
            'name': 'refresh_interval',
            'type': 'int',
            'value': 100,
            'limits': [5, 1000]
        },
    ]

    def __init__(self, **kargs):
        """Build the widget: label selector, clear button, plot and settings."""
        WidgetNode.__init__(self, **kargs)

        self.layout = QT.QVBoxLayout()
        self.setLayout(self.layout)

        h = QT.QHBoxLayout()
        self.layout.addLayout(h)

        # Selects which label's histogram is displayed.
        self.combobox = QT.QComboBox()
        h.addWidget(self.combobox)

        but = QT.QPushButton('clear')
        h.addWidget(but)
        but.clicked.connect(self.on_clear)

        # Shows the spike count of the displayed label.
        self.label = QT.QLabel('')
        h.addWidget(self.label)

        self.graphicsview = pg.GraphicsView()
        self.layout.addWidget(self.graphicsview)

        # Settings tree shown as a separate window (see open_settings).
        self.params = pg.parametertree.Parameter.create(name='settings',
                                                        type='group',
                                                        children=self._params)
        self.tree_params = pg.parametertree.ParameterTree(parent=self)
        self.tree_params.header().hide()
        self.tree_params.setParameters(self.params, showTop=True)
        self.tree_params.setWindowTitle('Options for waveforms hist viewer')
        self.tree_params.setWindowFlags(QT.Qt.Window)
        self.params.sigTreeStateChanged.connect(self.on_params_changed)

        self.initialize_plot()

        # Guards catalogue/histogram state shared by refresh, on_clear and
        # change_catalogue.
        self.mutex = Mutex()

    def _configure(self, peak_buffer_size=100000, catalogue=None, **kargs):
        """Store the configuration.

        Args:
            peak_buffer_size: size of the 'spikes' input buffer.
            catalogue: catalogue mapping; this class reads 'clusters',
                'cluster_labels', 'centers0', 'n_left' and 'n_right'.
        """
        self.peak_buffer_size = peak_buffer_size
        self.catalogue = catalogue

    def _initialize(self, **kargs):
        """Prepare input buffers, pollers, histograms and the refresh timer."""
        signals_params = self.inputs['signals'].params
        self.sample_rate = signals_params['sample_rate']
        self.wf_dtype = signals_params['dtype']

        self.inputs['spikes'].set_buffer(size=self.peak_buffer_size,
                                         double=False)
        # 3 seconds of signal history to cut waveforms around recent spikes.
        buffer_sigs_size = int(self.sample_rate * 3.)
        self.inputs['signals'].set_buffer(size=buffer_sigs_size, double=False)

        # pollers
        self.poller_sigs = ThreadPollInput(input_stream=self.inputs['signals'],
                                           return_data=True)
        self.poller_spikes = ThreadPollInput(
            input_stream=self.inputs['spikes'], return_data=True)

        self.histogram_2d = {}
        self.last_waveform = {}
        self.change_catalogue(self.catalogue)

        # Honor the configurable period (on_params_changed updates it later).
        self.timer = QT.QTimer(interval=self.params['refresh_interval'])
        self.timer.timeout.connect(self.refresh)

    def _start(self, **kargs):
        """Reset read positions, flush queued data, start timer and pollers."""
        self.last_head_sigs = None
        self.last_head_spikes = None
        self.timer.start()
        self.inputs['signals'].empty_queue()
        self.inputs['spikes'].empty_queue()
        self.poller_sigs.start()
        self.poller_spikes.start()

    def _stop(self, **kargs):
        """Stop the timer and both pollers, waiting for the poller threads."""
        self.timer.stop()
        self.poller_sigs.stop()
        self.poller_sigs.wait()
        self.poller_spikes.stop()
        self.poller_spikes.wait()

    def _close(self, **kargs):
        pass

    def open_settings(self):
        """Toggle the visibility of the settings window."""
        if not self.tree_params.isVisible():
            self.tree_params.show()
        else:
            self.tree_params.hide()

    def on_params_changed(self, params, changes):
        """Apply a settings change: colormap, bins/histograms and timer."""
        self.change_lut()
        # NOTE(review): change_catalogue recomputes bin_min/bin_max from the
        # catalogue centers whenever there are any, so user edits of those
        # two parameters are overwritten here -- confirm this is intended.
        self.change_catalogue(self.catalogue)
        self.timer.setInterval(self.params['refresh_interval'])

    def initialize_plot(self):
        """Create the view box, plot, histogram image and overlay curves."""
        self.viewBox = MyViewBox()
        self.viewBox.doubleclicked.connect(self.open_settings)
        self.viewBox.gain_zoom.connect(self.gain_zoom)
        self.viewBox.disableAutoRange()

        self.plot = pg.PlotItem(viewBox=self.viewBox)
        self.graphicsview.setCentralItem(self.plot)
        self.plot.hideButtons()

        self.image = pg.ImageItem()
        self.plot.addItem(self.image)

        # Last waveform of the displayed label.
        self.curve_spike = pg.PlotCurveItem()
        self.plot.addItem(self.curve_spike)

        # Vertical separators between channels.
        self.curve_limit = pg.PlotCurveItem()
        self.plot.addItem(self.curve_limit)

        self.change_lut()

    def change_lut(self):
        """Rebuild the 512-entry uint8 RGB lookup table from the colormap."""
        N = 512
        cmap_name = self.params['colormap']
        try:
            # Colormap registry API; matplotlib.cm.get_cmap was removed in
            # matplotlib 3.9.
            cmap = matplotlib.colormaps[cmap_name].resampled(N)
        except AttributeError:
            # Older matplotlib without the registry or Colormap.resampled().
            cmap = matplotlib.cm.get_cmap(cmap_name, N)
        # Calling a colormap with an int array returns (N, 4) RGBA in [0, 1].
        rgba = cmap(np.arange(N))
        self.lut = np.array(rgba[:, :3] * 255, dtype='uint8')

    def change_catalogue(self, catalogue):
        """Switch to ``catalogue`` and reset bins, histograms and separators.

        Parameter-tree signals are blocked while bin limits are rewritten, so
        this does not re-enter on_params_changed. They are unblocked even if
        an error occurs.
        """
        self.params.blockSignals(True)
        try:
            with self.mutex:
                self.catalogue = catalogue

                colors = make_color_dict(self.catalogue['clusters'])
                self.qcolors = {}
                for k, color in colors.items():
                    r, g, b = color
                    self.qcolors[k] = QT.QColor(r * 255, g * 255, b * 255)

                self.all_plotted_labels = (
                    self.catalogue['cluster_labels'].tolist() +
                    [LABEL_UNCLASSIFIED])

                # Amplitude range: min and max of the centers, each scaled
                # by 1.5.
                centers0 = self.catalogue['centers0']
                if centers0.shape[0] > 0:
                    self.params['bin_min'] = np.min(centers0) * 1.5
                    self.params['bin_max'] = np.max(centers0) * 1.5

                self.bins = np.arange(self.params['bin_min'],
                                      self.params['bin_max'],
                                      self.params['bin_size'])

                self.combobox.clear()
                self.combobox.addItems(
                    [str(k) for k in self.all_plotted_labels])

            self.on_clear()
            self._max = 10

            # Separators between the channel segments of the concatenated
            # waveform (each segment is peak_width samples long).
            _, peak_width, nb_chan = self.catalogue['centers0'].shape
            x, y = [], []
            for c in range(1, nb_chan):
                x.extend([c * peak_width, c * peak_width, np.nan])
                y.extend([-1000, 1000, np.nan])

            self.curve_limit.setData(x=x, y=y, connect='finite')
        finally:
            self.params.blockSignals(False)

    def on_clear(self):
        """Zero every label's histogram, last waveform and spike count."""
        with self.mutex:
            shape = self.catalogue['centers0'].shape

            # One x position per (channel, sample) of the concatenated waveform.
            self.indexes0 = np.arange(shape[1] * shape[2], dtype='int64')

            self.histogram_2d = {}
            self.last_waveform = {}
            self.nb_spikes = {}
            for k in self.all_plotted_labels:
                self.histogram_2d[k] = np.zeros(
                    (shape[1] * shape[2], self.bins.size), dtype='int64')
                self.last_waveform[k] = np.zeros((shape[1] * shape[2], ),
                                                 dtype=self.wf_dtype)
                self.nb_spikes[k] = 0

        self.plot.setXRange(0, self.indexes0[-1] + 1)
        self.plot.setYRange(self.params['bin_min'], self.params['bin_max'])

    def auto_scale(self):
        """No-op; kept for interface compatibility."""
        pass

    def gain_zoom(self, v):
        """Multiply the upper color level by ``v`` (view box zoom signal)."""
        self._max *= v
        self.image.setLevels([0, self._max], update=True)

    def refresh(self):
        """Accumulate newly available spikes and redraw the selected label.

        Driven by ``self.timer``. Spikes whose full waveform is not yet in the
        signal buffer are deferred to a later call rather than dropped.
        """
        head_sigs = self.poller_sigs.pos()
        head_spikes = self.poller_spikes.pos()

        if self.last_head_sigs is None:
            self.last_head_sigs = head_sigs

        if self.last_head_spikes is None:
            self.last_head_spikes = head_spikes

        # Still None when a poller's pos() returned None.
        if self.last_head_spikes is None or self.last_head_sigs is None:
            return

        n_right, n_left = self.catalogue['n_right'], self.catalogue['n_left']
        bin_min = self.params['bin_min']
        bin_max = self.params['bin_max']
        bin_size = self.params['bin_size']

        # Never read back more than 90% of the spike buffer: older entries
        # presumably have been overwritten already.
        if (head_spikes - self.last_head_spikes) >= (0.9 *
                                                     self.peak_buffer_size):
            self.last_head_spikes = head_spikes - int(
                0.9 * self.peak_buffer_size)
        new_spikes = self.inputs['spikes'].get_data(self.last_head_spikes,
                                                    head_spikes)

        right_indexes = new_spikes['index'] + n_right
        not_ready = right_indexes > head_sigs
        if np.any(not_ready):
            # The signal buffer does not yet hold the full waveform of some
            # spikes: keep only those before the first such spike and rewind
            # head_spikes so the remaining ones are read again next time.
            first_out = np.nonzero(not_ready)[0][0]
            head_spikes = head_spikes - (new_spikes.size - first_out)
            new_spikes = new_spikes[:first_out]

        for k in self.all_plotted_labels:
            mask = new_spikes['cluster_label'] == k
            indexes = new_spikes[mask]['index']
            for ind in indexes:
                wf = self.inputs['signals'].get_data(ind + n_left,
                                                     ind + n_right)
                # Concatenate channels one after the other (channel-major).
                wf = wf.T.reshape(-1)
                wf_bined = np.floor((wf - bin_min) / bin_size).astype('int32')
                wf_bined = wf_bined.clip(0, self.bins.size - 1)

                with self.mutex:
                    self.histogram_2d[k][self.indexes0, wf_bined] += 1
                    self.last_waveform[k] = wf
                    self.nb_spikes[k] += 1

        self.last_head_sigs = head_sigs
        self.last_head_spikes = head_spikes

        if self.combobox.currentIndex() == -1:
            return

        if self.visibleRegion().isEmpty():
            # when several tabs not need to refresh
            return

        # refresh plot, update image
        k = self.all_plotted_labels[self.combobox.currentIndex()]
        hist2d = self.histogram_2d[k]

        self.image.setImage(hist2d, lut=self.lut, levels=[0, self._max])
        self.image.setRect(
            QT.QRectF(-0.5, bin_min, hist2d.shape[0], bin_max - bin_min))
        self.image.show()

        self.curve_spike.setData(x=self.indexes0,
                                 y=self.last_waveform[k],
                                 pen=pg.mkPen(self.qcolors[k], width=1.5))

        txt = 'nbs_pike = {}'.format(self.nb_spikes[k])
        self.label.setText(txt)
class OnlineTraceViewer(QOscilloscope):
    """Oscilloscope that overlays detected spikes as colored scatter points.

    Each spike of a catalogue cluster is drawn on that cluster's
    ``max_on_channel`` channel, at the signal value of that channel at the
    spike sample (after the per-channel gain and offset).
    """

    _input_specs = {'signals': dict(streamtype='signals'),
                    'spikes': dict(streamtype='events', shape=(-1, ), dtype=_dtype_spike),
                    }

    _default_params = QOscilloscope._default_params

    def __init__(self, **kargs):
        QOscilloscope.__init__(self, **kargs)

    def _configure(self, peak_buffer_size=10000, catalogue=None, **kargs):
        """Store the spike buffer size and the (mandatory) catalogue."""
        QOscilloscope._configure(self, **kargs)
        self.peak_buffer_size = peak_buffer_size
        self.catalogue = catalogue
        assert catalogue is not None

    def _initialize(self, **kargs):
        """Set up the spike poller and one scatter item per cluster label."""
        QOscilloscope._initialize(self, **kargs)

        self.inputs['spikes'].set_buffer(size=self.peak_buffer_size, double=False)

        # poller onpeak
        self._last_peak = 0
        self.poller_peak = ThreadPollInput(input_stream=self.inputs['spikes'], return_data=True)
        self.poller_peak.new_data.connect(self._on_new_peak)

        # Direct reference to the array backing the 'spikes' input buffer.
        self.spikes_array = self.inputs['spikes'].buffer.buffer

        self._default_color = QtGui.QColor('#FFFFFF')  # TODO
        self.scatters = {}
        for k in self.catalogue['cluster_labels']:
            # Colors are floats in [0, 1]; QColor expects 0-255 components.
            color = self.catalogue['cluster_colors'].get(k, (1, 1, 1))
            r, g, b = color
            qcolor = QtGui.QColor(r * 255, g * 255, b * 255)
            qcolor.setAlpha(150)
            scatter = pg.ScatterPlotItem(x=[], y=[], pen=None, brush=qcolor, size=10, pxMode=True)
            self.scatters[k] = scatter
            self.plot.addItem(scatter)

    def _start(self, **kargs):
        QOscilloscope._start(self, **kargs)
        self._last_peak = 0
        self.poller_peak.start()

    def _stop(self, **kargs):
        QOscilloscope._stop(self, **kargs)
        self.poller_peak.stop()
        self.poller_peak.wait()

    def _close(self, **kargs):
        QOscilloscope._close(self, **kargs)

    def reset_curves_data(self):
        """Reset curves and build a time vector ending at 0 over full_size."""
        QOscilloscope.reset_curves_data(self)
        self.t_vect_full = np.arange(0, self.full_size, dtype=float) / self.sample_rate
        self.t_vect_full -= self.t_vect_full[-1]

    def _on_new_peak(self, pos, data):
        """Remember the latest spike-stream position reported by the poller."""
        self._last_peak = pos

    def autoestimate_scales(self):
        # in our case preprocesssed signal is supposed to be normalized
        self.all_mean = np.zeros(self.nb_channel,)
        self.all_sd = np.ones(self.nb_channel,)
        return self.all_mean, self.all_sd

    def _refresh(self, **kargs):
        """Redraw traces, then place one scatter point per displayed spike."""
        QOscilloscope._refresh(self, **kargs)

        mode = self.params['mode']
        channel_params = self.by_channel_params.children()
        gains = np.array([p['gain'] for p in channel_params])
        offsets = np.array([p['offset'] for p in channel_params])
        visibles = np.array([p['visible'] for p in channel_params], dtype=bool)

        head = self._head
        full_arr = self.inputs['signals'].get_data(head - self.full_size, head)
        if self._last_peak == 0:
            return

        # Spikes whose sample lies inside the displayed window.
        keep = (self.spikes_array['index'] > head - self.full_size) & (self.spikes_array['index'] < head)
        spikes = self.spikes_array[keep]

        spikes_ind = spikes['index'] - (head - self.full_size)
        # Drop spikes beyond the signal chunk actually returned (indexing
        # full_arr with them would raise IndexError). Filter ``spikes`` with
        # the same mask so all per-spike arrays below stay aligned.
        in_range = spikes_ind < full_arr.shape[0]
        spikes = spikes[in_range]
        spikes_ind = spikes_ind[in_range]

        spikes_amplitude = full_arr[spikes_ind, :]
        spikes_amplitude[:, visibles] *= gains[visibles]
        spikes_amplitude[:, visibles] += offsets[visibles]

        if mode == 'scroll':
            peak_times = self.t_vect_full[spikes_ind]
        elif mode == 'scan':
            # some trick to play with fake time: shift times on either side
            # of the sweep front
            front = head % self.full_size
            ind1 = (spikes['index'] % self.full_size) < front
            ind2 = (spikes['index'] % self.full_size) > front
            peak_times = self.t_vect_full[spikes_ind]
            peak_times[ind1] += (self.t_vect_full[front] - self.t_vect_full[-1])
            peak_times[ind2] += (self.t_vect_full[front] - self.t_vect_full[0])

        for i, k in enumerate(self.catalogue['cluster_labels']):
            keep = k == spikes['label']
            if np.any(keep):
                chan = self.catalogue['max_on_channel'][i]
                if visibles[chan]:
                    self.scatters[k].setData(peak_times[keep], spikes_amplitude[keep, chan])
                else:
                    self.scatters[k].setData([], [])
            else:
                self.scatters[k].setData([], [])
class OnlineTraceViewer(QOscilloscope):
    """Oscilloscope that overlays detected spikes as colored scatter points.

    Inputs:
      * ``signals``: traces to display.
      * ``spikes``: events stream of ``_dtype_spike`` records; the fields
        ``index`` and ``cluster_label`` are read.

    Spikes with a label >= 0 are drawn on the ``extremum_channel`` of their
    cluster. Spikes with a negative plotted label (presumably including
    ``LABEL_UNCLASSIFIED``) are drawn on the channel where the raw signal has
    the largest absolute value at the spike sample.
    """

    _input_specs = {
        'signals': dict(streamtype='signals'),
        'spikes': dict(streamtype='events', shape=(-1, ), dtype=_dtype_spike),
    }

    _default_params = QOscilloscope._default_params

    def __init__(self, **kargs):
        QOscilloscope.__init__(self, **kargs)
        # Guards scatter items/catalogue between _refresh and change_catalogue.
        self.mutex = Mutex()

    def _configure(self, peak_buffer_size=100000, catalogue=None, **kargs):
        """Store the spike buffer size and the (mandatory) catalogue."""
        QOscilloscope._configure(self, **kargs)
        self.peak_buffer_size = peak_buffer_size
        self.catalogue = catalogue
        assert catalogue is not None

    def _initialize(self, **kargs):
        """Set up the spike poller, scatter items and default display params."""
        QOscilloscope._initialize(self, **kargs)

        self.inputs['spikes'].set_buffer(size=self.peak_buffer_size,
                                         double=False)

        # poller onpeak
        self._last_peak = 0
        self.poller_peak = ThreadPollInput(input_stream=self.inputs['spikes'],
                                           return_data=True)
        self.poller_peak.new_data.connect(self._on_new_peak)

        # Direct reference to the array backing the 'spikes' input buffer.
        self.spikes_array = self.inputs['spikes'].buffer.buffer

        self.scatters = {}
        self.change_catalogue(self.catalogue)

        self.params['xsize'] = 1.
        self.params['decimation_method'] = 'min_max'
        self.params['mode'] = 'scan'
        self.params['scale_mode'] = 'same_for_all'
        self.params['display_labels'] = True

        # One-shot auto-scale 500 ms after initialization.
        self.timer_scale = QT.QTimer(singleShot=True, interval=500)
        self.timer_scale.timeout.connect(self.auto_scale)
        self.timer_scale.start()

    def _start(self, **kargs):
        QOscilloscope._start(self, **kargs)
        self._last_peak = 0
        self.poller_peak.start()

    def _stop(self, **kargs):
        QOscilloscope._stop(self, **kargs)
        self.poller_peak.stop()
        self.poller_peak.wait()

    def _close(self, **kargs):
        QOscilloscope._close(self, **kargs)

    def reset_curves_data(self):
        """Reset curves and build a time vector ending at 0 over full_size."""
        QOscilloscope.reset_curves_data(self)
        self.t_vect_full = np.arange(0, self.full_size,
                                     dtype=float) / self.sample_rate
        self.t_vect_full -= self.t_vect_full[-1]

    def _on_new_peak(self, pos, data):
        """Remember the latest spike-stream position reported by the poller."""
        self._last_peak = pos

    def autoestimate_scales(self):
        # in our case preprocesssed signal is supposed to be normalized
        self.all_mean = np.zeros(self.nb_channel, )
        self.all_sd = np.ones(self.nb_channel, )
        return self.all_mean, self.all_sd

    def change_catalogue(self, catalogue):
        """Replace the catalogue and rebuild one scatter item per label."""
        with self.mutex:

            for k, v in self.scatters.items():
                self.plot.removeItem(v)
            self.scatters = {}

            self.catalogue = catalogue

            colors = make_color_dict(self.catalogue['clusters'])

            self.qcolors = {}
            for k, color in colors.items():
                r, g, b = color
                self.qcolors[k] = QT.QColor(r * 255, g * 255, b * 255)

            self.all_plotted_labels = self.catalogue['cluster_labels'].tolist(
            ) + [LABEL_UNCLASSIFIED]

            for k in self.all_plotted_labels:
                qcolor = self.qcolors[k]
                qcolor.setAlpha(150)
                scatter = pg.ScatterPlotItem(x=[],
                                             y=[],
                                             pen=None,
                                             brush=qcolor,
                                             size=10,
                                             pxMode=True)
                self.scatters[k] = scatter
                self.plot.addItem(scatter)

    def _refresh(self, **kargs):
        """Redraw traces, then place one scatter point per displayed spike.

        Skipped while the widget is not visible (e.g. in a hidden tab).
        """
        if self.visibleRegion().isEmpty():
            # when several tabs not need to refresh
            return

        with self.mutex:
            QOscilloscope._refresh(self, **kargs)

            mode = self.params['mode']
            gains = np.array(
                [p['gain'] for p in self.by_channel_params.children()])
            offsets = np.array(
                [p['offset'] for p in self.by_channel_params.children()])
            visibles = np.array(
                [p['visible'] for p in self.by_channel_params.children()],
                dtype=bool)

            head = self._head
            full_arr = self.inputs['signals'].get_data(head - self.full_size,
                                                       head)
            if self._last_peak == 0:
                return

            # Spikes whose sample lies inside the displayed window.
            keep = (self.spikes_array['index'] > head - self.full_size) & (
                self.spikes_array['index'] < head)
            spikes = self.spikes_array[keep]

            spikes_ind = spikes['index'] - (head - self.full_size)
            # Drop spikes beyond the signal chunk actually returned (to avoid
            # a bug if the last peak is ahead of the signals). ``spikes`` is
            # filtered with the same mask so that spikes_ind, peak_times,
            # amplitudes and the per-spike masks below stay aligned.
            in_range = spikes_ind < full_arr.shape[0]
            spikes = spikes[in_range]
            spikes_ind = spikes_ind[in_range]

            real_spikes_amplitude = full_arr[spikes_ind, :]
            spikes_amplitude = real_spikes_amplitude.copy()
            spikes_amplitude[:, visibles] *= gains[visibles]
            spikes_amplitude[:, visibles] += offsets[visibles]

            if mode == 'scroll':
                peak_times = self.t_vect_full[spikes_ind]
            elif mode == 'scan':
                # some trick to play with fake time: shift times on either
                # side of the sweep front
                front = head % self.full_size
                ind1 = (spikes['index'] % self.full_size) < front
                ind2 = (spikes['index'] % self.full_size) > front
                peak_times = self.t_vect_full[spikes_ind]
                peak_times[ind1] += (self.t_vect_full[front] -
                                     self.t_vect_full[-1])
                peak_times[ind2] += (self.t_vect_full[front] -
                                     self.t_vect_full[0])

            for i, k in enumerate(self.all_plotted_labels):
                keep = k == spikes['cluster_label']
                if np.any(keep):
                    if k >= 0:
                        chan = self.catalogue['extremum_channel'][i]
                        if visibles[chan]:
                            times, amps = peak_times[keep], spikes_amplitude[
                                keep, chan]
                        else:
                            times, amps = [], []

                    else:
                        # No reference channel: use the channel with the
                        # largest absolute raw value, keeping visible ones.
                        chan_max = np.argmax(np.abs(
                            real_spikes_amplitude[keep, :]),
                                             axis=1)
                        keep2 = visibles[chan_max]
                        chan_max = chan_max[keep2]
                        keep[keep] &= keep2
                        times, amps = peak_times[keep], spikes_amplitude[
                            keep, chan_max]

                    self.scatters[k].setData(times, amps)

                else:
                    self.scatters[k].setData([], [])

    def auto_scale(self, spacing_factor=25.):
        """Recompute gains/offsets via the params controller, then refresh."""
        self.params_controller.compute_rescale(spacing_factor=spacing_factor)
        self.refresh()