Пример #1
0
    def _plot_waveforms(self, bunchs, bunchs_set, channel_ids, cluster_ids):
        """Plot every bunch of waveforms, then overlay the waveform sets.

        Parameters
        ----------
        bunchs : list
            Bunches drawn with `self.uplot`, one call per bunch. Each bunch
            provides `data` (unpacked below as
            `(n_spikes, n_samples, n_channels)`) and `channel_ids`, and may
            provide `alpha` (default 0.5) and `masks` (default all ones,
            shape `(n_spikes, n_channels)`).
        bunchs_set : list
            Bunches whose `data` is passed to `plot_wvf_feat_vispy`, one
            call per bunch, colored with the same palette at alpha 0.3.
        channel_ids : array-like
            Channel ids of the view; each bunch's `channel_ids` is mapped
            into it with `_index_of` to build the box index.
        cluster_ids : list
            Selected cluster ids. Not used by this method; kept for the
            interface.

        """
        # Initialize the box scaling the first time.
        if self.box_scaling[1] == 1.:
            M = np.max([np.max(np.abs(b.data)) for b in bunchs])
            # All-zero waveforms would give an infinite scaling: keep the
            # default and retry on the next call instead.
            if M > 0:
                self.box_scaling[1] = 1. / M
                self._update_boxes()
        clu_offsets = _get_clu_offsets(bunchs)
        max_clu_offsets = max(clu_offsets) + 1
        for i, d in enumerate(bunchs):
            wave = d.data
            alpha = d.get('alpha', .5)
            channel_ids_loc = d.channel_ids

            n_channels = len(channel_ids_loc)
            masks = d.get('masks', np.ones((wave.shape[0], n_channels)))

            n_spikes_clu, n_samples = wave.shape[:2]
            assert wave.shape[2] == n_channels
            assert masks.shape == (n_spikes_clu, n_channels)

            # Find the x coordinates.
            t = _get_linear_x(n_spikes_clu * n_channels, n_samples)
            if not self.overlap:
                # Shift this bunch by its cluster offset.
                offset = clu_offsets[i]
                t = t + 2.5 * (offset - (max_clu_offsets - 1) / 2.)
                # The total width should not depend on the number of
                # clusters.
                t /= max_clu_offsets

            # HACK: on the GPU, the actual mask is recovered with
            # fract(masks) since the bunch index is added to it, so a mask
            # must never be exactly 1.0 (it would read as 0). The integer
            # part (the bunch index `i`) is used to compute the depth.
            # Build a new array: updating `masks` in place would modify the
            # bunch's own masks and compound on every redraw.
            m = masks * .99999 + i

            color = tuple(_colormap(i)) + (alpha,)
            assert len(color) == 4

            # Generate the box index (one number per channel).
            box_index = _index_of(channel_ids_loc, channel_ids)
            box_index = np.repeat(box_index, n_samples)
            box_index = np.tile(box_index, n_spikes_clu)
            assert box_index.shape == (n_spikes_clu *
                                       n_channels *
                                       n_samples,)

            # Generate the waveform array: one row per (spike, channel).
            wave = np.transpose(wave, (0, 2, 1))
            wave = wave.reshape((n_spikes_clu * n_channels, n_samples))

            self.uplot(x=t,
                       y=wave,
                       color=color,
                       masks=m,
                       box_index=box_index,
                       data_bounds=None,
                       )

        for i, d in enumerate(bunchs_set):
            # Same palette as above with a fixed alpha of 0.3; this no
            # longer depends on `alpha` leaking out of the previous loop.
            color = tuple(_colormap(i))[:3] + (0.3,)
            # NOTE(review): the original note says `d.data` matches the data
            # used by the CellTypes.py module — confirm.
            plot_wvf_feat_vispy(self, d.data, color, i)
Пример #2
0
    def on_select(self, cluster_ids=None):
        """Replot the waveforms of the clusters that were just selected.

        Calls the parent `on_select`, then draws one batch of waveforms per
        bunch returned by `self.waveforms(...)`, optionally writes a label
        once per displayed channel, and finally zooms on the best channels
        when `self.do_zoom_on_channels` is set.

        Parameters
        ----------
        cluster_ids : list, optional
            Forwarded to the parent `on_select`; afterwards the method works
            on `self.cluster_ids`.

        """
        super(WaveformView, self).on_select(cluster_ids)
        selected = self.cluster_ids
        n_selected = len(selected)
        if n_selected == 0:
            return

        # Waveform subset for the selection.
        bunches = self.waveforms(selected)

        with self.building():
            labeled_channels = set()
            for bunch in bunches:
                # Skip bunches whose tag is filtered out (if a filter is set).
                if (self.filtered_tags and
                        bunch.get('tag') not in self.filtered_tags):
                    continue  # pragma: no cover
                alpha = bunch.get('alpha', .5)
                waves = bunch.data
                spike_masks = bunch.masks
                # Rank of this bunch's cluster within the selection; it
                # drives the horizontal offset, the color and the depth.
                rank = selected.index(bunch.cluster_id)

                n_spikes_clu, n_samples, n_unmasked = waves.shape
                assert spike_masks.shape[0] == n_spikes_clu

                # Channels carried by this bunch (all channels by default).
                shown = bunch.get('channels', np.arange(self.n_channels))
                assert n_unmasked == len(shown)
                assert n_unmasked > 0

                n_lines = n_spikes_clu * n_unmasked

                # x coordinates, shifted per cluster unless overlapping.
                x = _get_linear_x(n_lines, n_samples)
                if not self.overlap:
                    x = x + 2.5 * (rank - (n_selected - 1) / 2.)
                    # Keep the total width independent of the cluster count.
                    x /= n_selected

                # HACK: the GPU recovers the mask with fract(), so a mask
                # must never reach 1.0 (it would read as 0); the integer
                # part carries the cluster rank used for the depth.
                gpu_masks = spike_masks[:, shown].reshape((-1, 1))
                gpu_masks *= .999
                gpu_masks += rank

                color = tuple(_colormap(rank)) + (alpha,)
                assert len(color) == 4

                # One box (channel) index per plotted sample.
                boxes = np.tile(np.repeat(shown, n_samples), n_spikes_clu)
                assert boxes.shape == (n_lines * n_samples,)

                # (spikes, samples, channels) -> one row per (spike, channel).
                y = np.transpose(waves, (0, 2, 1))
                y = y.reshape((n_lines, n_samples))

                self.plot(x=x, y=y, color=color, masks=gpu_masks,
                          box_index=boxes, data_bounds=None, uniform=True)

                if not self.do_show_labels:
                    continue
                # Label each channel only once across all bunches.
                for ch in shown:
                    if ch not in labeled_channels:
                        labeled_channels.add(ch)
                        label = '%d' % self.channel_order[ch]
                        self[ch].text(pos=[x[0, 0], 0.],
                                      text=label,
                                      anchor=[-1.01, -.25],
                                      data_bounds=None,
                                      )

        # Zoom on the best channels of the selection.
        best = self.best_channels(selected)
        if best is not None and self.do_zoom_on_channels:
            self.zoom_on_channels(best)
Пример #3
0
    def on_select(self, cluster_ids=None):
        """Replot one waveform set of the selected clusters, channel by channel.

        Calls the parent `on_select`, picks element
        `self.data_index % len(data)` of `self.waveforms(...)`, plots each
        channel with per-spike colors and depth, labels every channel, then
        zooms on the best channels when `self.do_zoom_on_channels` is set.

        Parameters
        ----------
        cluster_ids : list, optional
            Forwarded to the parent `on_select`; afterwards the method works
            on `self.cluster_ids`.

        """
        super(WaveformView, self).on_select(cluster_ids)
        cluster_ids = self.cluster_ids
        n_clusters = len(cluster_ids)
        if n_clusters == 0:
            return

        # Load the waveform subset.
        data = self.waveforms(cluster_ids)
        # Nothing to display: without this guard the modulo below raises
        # ZeroDivisionError.
        if len(data) == 0:
            return
        # Take one element in the list.
        data = data[self.data_index % len(data)]
        alpha = data.get('alpha', .5)
        spike_ids = data.spike_ids
        spike_clusters = data.spike_clusters
        w = data.data
        masks = data.masks
        n_spikes = len(spike_ids)
        assert w.ndim == 3
        n_samples = w.shape[1]
        # w is (n_spikes, n_samples, n_channels); masks is
        # (n_spikes, n_channels).
        assert w.shape == (n_spikes, n_samples, self.n_channels)
        assert masks.shape == (n_spikes, self.n_channels)

        # Relative spike clusters (position of each spike's cluster within
        # the selection).
        spike_clusters_rel = _index_of(spike_clusters, cluster_ids)
        assert spike_clusters_rel.shape == (n_spikes,)

        # Fetch the waveforms.
        t = _get_linear_x(n_spikes, n_samples)
        # Overlap.
        if not self.overlap:
            t = t + 2.5 * (spike_clusters_rel[:, np.newaxis] -
                           (n_clusters - 1) / 2.)
            # The total width should not depend on the number of clusters.
            t /= n_clusters

        # Plot all waveforms.
        # OPTIM: avoid the loop.
        with self.building():
            for ch in range(self.n_channels):
                m = masks[:, ch]
                depth = _get_depth(m,
                                   spike_clusters_rel=spike_clusters_rel,
                                   n_clusters=n_clusters)
                color = _spike_colors(spike_clusters_rel,
                                      masks=m,
                                      alpha=alpha,
                                      )
                self[ch].plot(x=t, y=w[:, :, ch],
                              color=color,
                              depth=depth,
                              data_bounds=self.data_bounds,
                              )
                # Add channel labels.
                self[ch].text(pos=[[t[0, 0], 0.]],
                              text=str(ch),
                              anchor=[-1.01, -.25],
                              data_bounds=self.data_bounds,
                              )

        # Zoom on the best channels when selecting clusters.
        channels = self.best_channels(cluster_ids)
        if channels is not None and self.do_zoom_on_channels:
            self.zoom_on_channels(channels)
Пример #4
0
    def on_select(self, cluster_ids=None):
        """Replot one waveform set of the selected clusters, cluster by cluster.

        Calls the parent `on_select`, picks element
        `self.data_index % len(data)` of `self.waveforms(...)`, draws the
        spikes of each selected cluster as one batch, labels the channels
        once (with the first cluster's x coordinates), then zooms on the
        best channels when `self.do_zoom_on_channels` is set.

        Parameters
        ----------
        cluster_ids : list, optional
            Forwarded to the parent `on_select`; afterwards the method works
            on `self.cluster_ids`.

        """
        super(WaveformView, self).on_select(cluster_ids)
        cluster_ids = self.cluster_ids
        n_clusters = len(cluster_ids)
        if n_clusters == 0:
            return

        # Load the waveform subset.
        data = self.waveforms(cluster_ids)
        # Nothing to display: without this guard the modulo below raises
        # ZeroDivisionError.
        if len(data) == 0:
            return
        # Take one element in the list.
        data = data[self.data_index % len(data)]
        alpha = data.get('alpha', .5)
        spike_ids = data.spike_ids
        spike_clusters = data.spike_clusters
        w = data.data
        masks = data.masks
        n_spikes = len(spike_ids)
        assert w.ndim == 3
        n_samples = w.shape[1]
        # w is (n_spikes, n_samples, n_channels); masks is
        # (n_spikes, n_channels).
        assert w.shape == (n_spikes, n_samples, self.n_channels)
        assert masks.shape == (n_spikes, self.n_channels)

        # Plot all waveforms.
        # OPTIM: avoid the loop.
        with self.building():
            for i, cl in enumerate(cluster_ids):

                # Select the spikes corresponding to a given cluster.
                idx = spike_clusters == cl
                n_spikes_clu = idx.sum()  # number of spikes in the cluster.

                # Find the x coordinates.
                t = _get_linear_x(n_spikes_clu * self.n_channels, n_samples)
                if not self.overlap:
                    t = t + 2.5 * (i - (n_clusters - 1) / 2.)
                    # The total width should not depend on the number of
                    # clusters.
                    t /= n_clusters

                # Get the spike masks (boolean indexing returns a copy, so
                # the in-place updates below do not touch `data.masks`).
                m = masks[idx, :].reshape((n_spikes_clu * self.n_channels, 1))
                # HACK: on the GPU, we get the actual masks with fract(masks)
                # since we add the relative cluster index. We need to ensure
                # that the masks is never 1.0, otherwise it is interpreted as
                # 0.
                m *= .999
                # NOTE: we add the cluster index which is used for the
                # computation of the depth on the GPU.
                m += i

                color = tuple(_colormap(i)) + (alpha, )
                assert len(color) == 4

                # Generate the box index (one number per channel).
                box_index = np.arange(self.n_channels)
                box_index = np.repeat(box_index, n_samples)
                box_index = np.tile(box_index, n_spikes_clu)
                # The trailing comma makes the right-hand side a 1-tuple;
                # without it the comparison is tuple == int and always fails.
                assert box_index.shape == (n_spikes_clu * self.n_channels *
                                           n_samples,)

                # Generate the waveform array: one row per (spike, channel).
                wave = w[idx, :, :]
                wave = np.transpose(wave, (0, 2, 1))
                wave = wave.reshape(
                    (n_spikes_clu * self.n_channels, n_samples))

                self.plot(
                    x=t,
                    y=wave,
                    color=color,
                    masks=m,
                    box_index=box_index,
                    data_bounds=None,
                    uniform=True,
                )
                # Add channel labels (once, using the first cluster's x).
                if self.do_show_labels and i == 0:
                    for ch in range(self.n_channels):
                        self[ch].text(
                            pos=[t[0, 0], 0.],
                            # TODO: use real channel labels.
                            text=str(ch),
                            anchor=[-1.01, -.25],
                            data_bounds=None,
                        )

        # Zoom on the best channels when selecting clusters.
        channels = self.best_channels(cluster_ids)
        if channels is not None and self.do_zoom_on_channels:
            self.zoom_on_channels(channels)
Пример #5
0
    def _plot_waveforms(self, bunchs, channel_ids):
        """Plot every bunch of waveforms with `self.uplot`.

        Parameters
        ----------
        bunchs : list
            One bunch per set of waveforms. Each bunch provides `data`
            (unpacked below as `(n_spikes, n_samples, n_channels)`) and
            `channel_ids`, and may provide `alpha` (default 0.5) and
            `masks` (default all ones, shape `(n_spikes, n_channels)`).
        channel_ids : array-like
            Channel ids of the view; each bunch's `channel_ids` is mapped
            into it with `_index_of` to build the box index.

        """
        # Initialize the box scaling the first time.
        if self.box_scaling[1] == 1.:
            M = np.max([np.max(np.abs(b.data)) for b in bunchs])
            # All-zero waveforms would give an infinite scaling: keep the
            # default and retry on the next call instead.
            if M > 0:
                self.box_scaling[1] = 1. / M
                self._update_boxes()
        clu_offsets = _get_clu_offsets(bunchs)
        max_clu_offsets = max(clu_offsets) + 1
        for i, d in enumerate(bunchs):
            wave = d.data
            alpha = d.get('alpha', .5)
            channel_ids_loc = d.channel_ids

            n_channels = len(channel_ids_loc)
            masks = d.get('masks', np.ones((wave.shape[0], n_channels)))

            n_spikes_clu, n_samples = wave.shape[:2]
            assert wave.shape[2] == n_channels
            assert masks.shape == (n_spikes_clu, n_channels)

            # Find the x coordinates.
            t = _get_linear_x(n_spikes_clu * n_channels, n_samples)
            if not self.overlap:
                # Shift this bunch by its cluster offset.
                offset = clu_offsets[i]
                t = t + 2.5 * (offset - (max_clu_offsets - 1) / 2.)
                # The total width should not depend on the number of
                # clusters.
                t /= max_clu_offsets

            # HACK: on the GPU, the actual mask is recovered with
            # fract(masks) since the bunch index is added to it, so a mask
            # must never be exactly 1.0 (it would read as 0). The integer
            # part (the bunch index `i`) is used to compute the depth.
            # Build a new array: updating `masks` in place would modify the
            # bunch's own masks and compound on every redraw.
            m = masks * .99999 + i

            color = tuple(_colormap(i)) + (alpha,)
            assert len(color) == 4

            # Generate the box index (one number per channel).
            box_index = _index_of(channel_ids_loc, channel_ids)
            box_index = np.repeat(box_index, n_samples)
            box_index = np.tile(box_index, n_spikes_clu)
            assert box_index.shape == (n_spikes_clu *
                                       n_channels *
                                       n_samples,)

            # Generate the waveform array: one row per (spike, channel).
            wave = np.transpose(wave, (0, 2, 1))
            wave = wave.reshape((n_spikes_clu * n_channels, n_samples))

            self.uplot(x=t,
                       y=wave,
                       color=color,
                       masks=m,
                       box_index=box_index,
                       data_bounds=None,
                       )