Example #1
0
    def get_features(self, spike_ids, channel_ids):
        """Return sparse features for given spikes.

        Parameters
        ----------
        spike_ids : array-like
            Absolute ids of the requested spikes.
        channel_ids : array-like
            Ids of the requested channels (columns of the output).

        Returns
        -------
        features : ndarray
            A `(len(spike_ids), len(channel_ids), n_pcs)` float array. Rows of
            spikes that are absent from `self.features_rows` are NaN.

        """
        data = self.features
        _, n_channels_loc, n_pcs = data.shape
        ns = len(spike_ids)
        nc = len(channel_ids)

        # Initialize the output array: spikes without stored features keep
        # NaN values. NOTE: `np.NAN` was removed in NumPy 2.0; `np.nan` is
        # available in every NumPy version.
        features = np.full((ns, n_channels_loc, n_pcs), np.nan)

        if self.features_rows is not None:
            s = np.intersect1d(spike_ids, self.features_rows)
            # Relative indices of the spikes in self.features_rows, necessary
            # to load features from `data` which only contains that subset of
            # the spikes.
            rows = _index_of(s, self.features_rows)
            # Relative indices of the non-null rows in the output array.
            rows_out = _index_of(s, spike_ids)
        else:
            rows = spike_ids
            rows_out = slice(None, None, None)
        features[rows_out, ...] = data[rows]

        if self.features_cols is not None:
            assert self.features_cols.shape[1] == n_channels_loc
            cols = self.features_cols[self.spike_templates[spike_ids]]
            features = from_sparse(features, cols, channel_ids)

        assert features.shape == (ns, nc, n_pcs)
        return features
Example #2
0
    def get_template_features(self, spike_ids):
        """Return sparse template features for given spikes.

        Parameters
        ----------
        spike_ids : array-like
            Absolute ids of the requested spikes.

        Returns
        -------
        template_features : ndarray
            An array with one row per requested spike, in the requested
            order. When `self.template_features_rows` is set, rows of spikes
            absent from it are NaN (as in `get_features()`).

        """
        data = self.template_features
        _, n_templates_loc = data.shape
        ns = len(spike_ids)

        if self.template_features_rows is not None:
            # Only a subset of the spikes have stored template features.
            # Look them up in `template_features_rows`, which indexes `data`
            # (not in `features_rows`, which belongs to the PC features).
            s = np.intersect1d(spike_ids, self.template_features_rows)
            # Relative indices of the spikes in template_features_rows,
            # necessary to load rows from `data` which only contains that
            # subset of the spikes.
            rows = _index_of(s, self.template_features_rows)
            # Relative indices of the non-null rows in the output array.
            rows_out = _index_of(s, spike_ids)
            # Missing spikes keep NaN values so that the output still has
            # one row per requested spike.
            template_features = np.full((ns, n_templates_loc), np.nan)
            template_features[rows_out, ...] = data[rows]
        else:
            template_features = data[spike_ids]

        if self.template_features_cols is not None:
            assert self.template_features_cols.shape[1] == n_templates_loc
            cols = self.template_features_cols[self.spike_templates[spike_ids]]
            template_features = from_sparse(template_features,
                                            cols,
                                            np.arange(self.n_templates),
                                            )
        assert template_features.shape[0] == ns
        return template_features
Example #3
0
    def get_features(self, cluster_id, load_all=False):
        """Return a Bunch with the normalized features of a cluster's spikes.

        Overridden to take into account the sparse structure: when
        `self.features_spike_ids` is set, only the cluster's spikes listed in
        it are kept (and `load_all` is not used).

        Parameters
        ----------
        cluster_id : int
            The cluster whose spikes are loaded.
        load_all : bool
            When `self.features_spike_ids` is None: if False, at most
            `self.n_spikes_features` spikes are requested from
            `_select_spikes()`; if True, no limit is passed.

        Returns
        -------
        bunch : Bunch
            With `data` (features normalized between -get_feature_lim() and
            +get_feature_lim()), `spike_ids`, `spike_clusters` and `masks`.

        """
        subset_ids = self.features_spike_ids
        if subset_ids is None:
            limit = None if load_all else self.n_spikes_features
            spike_ids = self._select_spikes(cluster_id, limit)
            # Without a subset, row indices into all_features are the
            # absolute spike ids.
            rel_ids = spike_ids
        else:
            # Take every spike of the cluster, then keep those with features.
            spike_ids = np.intersect1d(self._select_spikes(cluster_id),
                                       subset_ids)
            # Position of each kept spike in features_spike_ids, i.e. its row
            # in all_features which only contains that subset of spikes.
            rel_ids = _index_of(spike_ids, subset_ids)

        templates = self.spike_templates[spike_ids]
        n_channels = self.n_channels
        n_per_channel = self.n_features_per_channel
        n_spikes = len(spike_ids)
        dense = _densify(rel_ids, self.all_features,
                         self.features_ind[templates, :], n_channels)
        dense = np.transpose(dense, (0, 2, 1))
        assert dense.shape == (n_spikes, n_channels, n_per_channel)
        out = Bunch()

        # Normalize features.
        lim = self.get_feature_lim()
        out.data = _normalize(dense, -lim, lim)
        out.spike_ids = spike_ids
        out.spike_clusters = self.spike_clusters[spike_ids]
        out.masks = self.all_masks[spike_ids]
        return out
Example #4
0
    def get_features(self, cluster_id, load_all=False):
        """Load, densify and normalize the features of one cluster.

        Overridden to take into account the sparse structure of the
        features: if `self.features_spike_ids` is not None, only the
        cluster's spikes that appear in it are returned.

        Parameters
        ----------
        cluster_id : int
            The cluster whose spikes are loaded.
        load_all : bool
            Only used when `self.features_spike_ids` is None. If False, the
            spike selection is limited to `self.n_spikes_features`.

        Returns
        -------
        bunch : Bunch
            With `data`, `spike_ids`, `spike_clusters` and `masks`.

        """
        subset = self.features_spike_ids
        if subset is not None:
            # Keep only the cluster's spikes that have stored features, and
            # locate them in the subset: all_features only has those rows.
            every_spike = self._select_spikes(cluster_id)
            spike_ids = np.intersect1d(every_spike, subset)
            rows = _index_of(spike_ids, subset)
        else:
            spike_ids = self._select_spikes(
                cluster_id, None if load_all else self.n_spikes_features)
            rows = spike_ids

        spike_templates = self.spike_templates[spike_ids]
        expected_shape = (len(spike_ids), self.n_channels,
                          self.n_features_per_channel)
        feats = _densify(rows, self.all_features,
                         self.features_ind[spike_templates, :],
                         self.n_channels)
        feats = np.transpose(feats, (0, 2, 1))
        assert feats.shape == expected_shape
        bunch = Bunch()

        # Scale the features to the shared feature limits.
        bound = self.get_feature_lim()
        feats = _normalize(feats, -bound, bound)

        bunch.data = feats
        bunch.spike_ids = spike_ids
        bunch.spike_clusters = self.spike_clusters[spike_ids]
        bunch.masks = self.all_masks[spike_ids]
        return bunch
Example #5
0
def from_sparse(data, cols, channel_ids):
    """Convert a sparse structure into a dense one.

    Arguments:

    data : array
        A (n_spikes, n_channels_loc, ...) array with the data.
    cols : array
        A (n_spikes, n_channels_loc) array with the channel indices of
        every row in data.
    channel_ids : array
        List of requested channel ids (columns). Must not contain
        duplicates.

    Returns:

    out : array
        A (n_spikes, len(channel_ids), ...) array with the dtype of `data`.
        Requested channels that do not appear in a spike's row of `cols`
        are zero; data on channels that were not requested is dropped.

    Raises:

    NotImplementedError
        If `channel_ids` contains duplicate values.

    """
    if len(channel_ids) != len(np.unique(channel_ids)):
        raise NotImplementedError("Multiple identical requested channels "
                                  "in from_sparse().")
    # The axis in the data that contains the channels.
    channel_axis = 1
    shape = list(data.shape)
    assert data.ndim >= 2
    assert cols.ndim == 2
    assert data.shape[:2] == cols.shape
    n_spikes, n_channels_loc = shape[:2]
    # NOTE: we ensure here that `cols` contains integers.
    c = cols.flatten().astype(np.int32)
    # Map columns that do not belong to the specified channels to -1.
    # NOTE: `np.isin` replaces `np.in1d`, which is deprecated in NumPy 2.0.
    c[~np.isin(c, channel_ids)] = -1
    # Requested channels plus a trailing -1 entry collecting the dropped
    # columns.
    lookup = np.r_[channel_ids, -1]
    assert np.all(np.isin(c, lookup))
    # Convert column indices to relative indices given the specified
    # channel_ids. `_index_of` is expected to map -1 to its position in
    # `lookup`, i.e. the extra last column removed below.
    cols_loc = _index_of(c, lookup).reshape(cols.shape)
    assert cols_loc.shape == (n_spikes, n_channels_loc)
    n_channels = len(channel_ids)
    # Shape of the output array.
    out_shape = shape
    # The channel dimension contains the number of requested channels.
    # The last column contains irrelevant values.
    out_shape[channel_axis] = n_channels + 1
    out = np.zeros(out_shape, dtype=data.dtype)
    # Row index of every (spike, local channel) entry.
    x = np.tile(np.arange(n_spikes)[:, np.newaxis],
                (1, n_channels_loc))
    assert x.shape == cols_loc.shape == data.shape[:2]
    out[x, cols_loc, ...] = data
    # Remove the last column with values outside the specified
    # channels.
    out = out[:, :-1, ...]
    return out
Example #6
0
    def on_select(self, cluster_ids=None):
        """Scatter-plot the spikes of the selected clusters.

        Returns early when no cluster is selected, and clears the view when
        `self.coords()` returns None.
        """
        super(ScatterView, self).on_select(cluster_ids)
        cluster_ids = self.cluster_ids
        if len(cluster_ids) == 0:
            return

        # Spike coordinates (e.g. times and amplitudes) of the clusters.
        data = self.coords(cluster_ids)
        if data is None:
            self.clear()
            return
        n_spikes = len(data.spike_ids)
        spike_clusters, x, y = data.spike_clusters, data.x, data.y
        assert n_spikes > 0
        for arr in (spike_clusters, x, y):
            assert arr.shape == (n_spikes, )

        # Relative cluster index of every spike.
        rel_clusters = _index_of(spike_clusters, cluster_ids)

        with self.building():
            # One color per spike, all masks set to 1.
            color = _spike_colors(rel_clusters, masks=np.ones(n_spikes))
            assert color.shape == (n_spikes, 4)
            if rel_clusters is None:
                marker_size = 1.
            else:
                marker_size = self._default_marker_size

            self.scatter(
                x=x,
                y=y,
                color=color,
                size=marker_size * np.ones(n_spikes),
            )
Example #7
0
    def on_select(self, cluster_ids=None):
        """Scatter-plot the spikes of the selected clusters.

        The points are drawn with `self.data_bounds`. Nothing is drawn when
        no cluster is selected; the view is cleared when `self.coords()`
        returns None.
        """
        super(ScatterView, self).on_select(cluster_ids)
        cluster_ids = self.cluster_ids
        n_clusters = len(cluster_ids)
        if not n_clusters:
            return

        coords = self.coords(cluster_ids)
        if coords is None:
            self.clear()
            return
        spike_ids = coords.spike_ids
        clusters = coords.spike_clusters
        xs, ys = coords.x, coords.y
        count = len(spike_ids)
        assert count > 0
        assert clusters.shape == (count,)
        assert xs.shape == (count,)
        assert ys.shape == (count,)

        # Index of each spike's cluster within the selection.
        clu_rel = _index_of(clusters, cluster_ids)

        with self.building():
            colors = _spike_colors(clu_rel, masks=np.ones(count))
            assert colors.shape == (count, 4)
            size = 1. if clu_rel is None else self._default_marker_size
            kwargs = dict(x=xs, y=ys, color=colors,
                          data_bounds=self.data_bounds,
                          size=size * np.ones(count))
            self.scatter(**kwargs)
Example #8
0
    def _plot_waveforms(self, bunchs, bunchs_set, channel_ids, cluster_ids):
        """Plot the waveforms of every bunch, then the extra waveform sets.

        Parameters
        ----------
        bunchs : list of Bunch
            Each with a `data` array of shape
            `(n_spikes, n_samples, n_channels_loc)`, the `channel_ids` of
            those channels, and optional `alpha` (default .5) and `masks`
            (default all ones, shape `(n_spikes, n_channels_loc)`).
        bunchs_set : list of Bunch
            Additional waveform sets; the `data` of each one is passed to
            `plot_wvf_feat_vispy()` with a translucent color.
        channel_ids : array-like
            All displayed channel ids; used to compute the box index.
        cluster_ids : list
            Selected cluster ids. Not used by this method; kept for API
            compatibility.

        """
        # Initialize the box scaling the first time.
        if self.box_scaling[1] == 1.:
            M = np.max([np.max(np.abs(b.data)) for b in bunchs])
            # Skip all-zero data: 1 / 0 would give an infinite scaling that
            # would never be reset since box_scaling[1] would no longer be 1.
            if M > 0:
                self.box_scaling[1] = 1. / M
                self._update_boxes()
        clu_offsets = _get_clu_offsets(bunchs)
        max_clu_offsets = max(clu_offsets) + 1
        for i, d in enumerate(bunchs):
            wave = d.data
            alpha = d.get('alpha', .5)
            channel_ids_loc = d.channel_ids

            n_channels = len(channel_ids_loc)
            masks = d.get('masks', np.ones((wave.shape[0], n_channels)))
            # By default, this is 0, 1, 2 for the first 3 clusters.
            # But it can be customized when displaying several sets
            # of waveforms per cluster.

            n_spikes_clu, n_samples = wave.shape[:2]
            assert wave.shape[2] == n_channels
            assert masks.shape == (n_spikes_clu, n_channels)

            # Find the x coordinates.
            t = _get_linear_x(n_spikes_clu * n_channels, n_samples)
            if not self.overlap:

                # Determine the cluster offset.
                offset = clu_offsets[i]
                t = t + 2.5 * (offset - (max_clu_offsets - 1) / 2.)
                # The total width should not depend on the number of
                # clusters.
                t /= max_clu_offsets

            # Get the spike masks.
            # HACK: on the GPU, we get the actual masks with fract(masks)
            # since we add the relative cluster index. We need to ensure
            # that the masks is never 1.0, otherwise it is interpreted as
            # 0.
            # NOTE: we add the cluster index which is used for the
            # computation of the depth on the GPU.
            # A new array is built on purpose: updating `masks` in place
            # would corrupt the masks stored in the bunch (shifting them
            # further at every redraw) and fails on integer arrays.
            m = masks * .99999 + i

            color = tuple(_colormap(i)) + (alpha,)
            assert len(color) == 4

            # Generate the box index (one number per channel).
            box_index = _index_of(channel_ids_loc, channel_ids)
            box_index = np.repeat(box_index, n_samples)
            box_index = np.tile(box_index, n_spikes_clu)
            assert box_index.shape == (n_spikes_clu *
                                       n_channels *
                                       n_samples,)

            # Generate the waveform array.
            wave = np.transpose(wave, (0, 2, 1))
            wave = wave.reshape((n_spikes_clu * n_channels, n_samples))

            self.uplot(x=t,
                       y=wave,
                       color=color,
                       masks=m,
                       box_index=box_index,
                       data_bounds=None,
                       )

        for i, d in enumerate(bunchs_set):
            # Translucent version of the colormap color; this no longer
            # depends on `alpha` leaking out of the loop above (which
            # raised NameError when `bunchs` was empty).
            color = tuple(_colormap(i))[:3] + (0.3,)
            plot_wvf_feat_vispy(self, d.data, color, i)
Example #9
0
File: ccg.py Project: ablot/phy
def correlograms(spike_times,
                 spike_clusters,
                 cluster_ids=None,
                 sample_rate=1.,
                 bin_size=None,
                 window_size=None,
                 symmetrize=True,
                 ):
    """Compute all pairwise cross-correlograms among the clusters appearing
    in `spike_clusters`.

    Parameters
    ----------

    spike_times : array-like
        Spike times in seconds. Must be sorted in increasing order.
    spike_clusters : array-like
        Spike-cluster mapping.
    cluster_ids : array-like
        The list of unique clusters, in any order. That order will be used
        in the output array. Defaults to `_unique(spike_clusters)`.
    sample_rate : float
        Positive sampling rate, used to convert spike times and the bin size
        to integer numbers of samples.
    bin_size : float
        Size of the bin, in seconds. Required despite the `None` default;
        clipped to [1e-5, 1e5].
    window_size : float
        Size of the window, in seconds. Required despite the `None` default;
        clipped to [1e-5, 1e5].
    symmetrize : bool
        If True (default), return the output of
        `_symmetrize_correlograms()`; otherwise return the one-sided
        correlograms (non-negative delays only).

    Returns
    -------

    correlograms : array
        A `(n_clusters, n_clusters, winsize_samples)` array with all pairwise
        CCGs.

    """
    assert sample_rate > 0.
    assert np.all(np.diff(spike_times) >= 0), ("The spike times must be "
                                               "increasing.")

    # Get the spike samples.
    spike_times = np.asarray(spike_times, dtype=np.float64)
    spike_samples = (spike_times * sample_rate).astype(np.int64)

    spike_clusters = _as_array(spike_clusters)

    assert spike_samples.ndim == 1
    assert spike_samples.shape == spike_clusters.shape

    # Find `binsize`.
    bin_size = np.clip(bin_size, 1e-5, 1e5)  # in seconds
    binsize = int(sample_rate * bin_size)  # in samples
    assert binsize >= 1

    # Find `winsize_bins`.
    window_size = np.clip(window_size, 1e-5, 1e5)  # in seconds
    winsize_bins = 2 * int(.5 * window_size / bin_size) + 1

    assert winsize_bins >= 1
    assert winsize_bins % 2 == 1

    # Take the cluster order into account.
    if cluster_ids is None:
        clusters = _unique(spike_clusters)
    else:
        clusters = _as_array(cluster_ids)
    n_clusters = len(clusters)

    # Like spike_clusters, but with 0..n_clusters-1 indices.
    spike_clusters_i = _index_of(spike_clusters, clusters)

    # Shift between the two copies of the spike trains.
    shift = 1

    # At a given shift, the mask precises which spikes have matching spikes
    # within the correlogram time window.
    # NOTE: the builtin `bool` is used since `np.bool` was removed in
    # NumPy 1.24.
    mask = np.ones_like(spike_samples, dtype=bool)

    correlograms = _create_correlograms_array(n_clusters, winsize_bins)

    # The loop continues as long as there is at least one spike with
    # a matching spike.
    while mask[:-shift].any():
        # Number of time samples between spike i and spike i+shift.
        spike_diff = _diff_shifted(spike_samples, shift)

        # Binarize the delays between spike i and spike i+shift.
        spike_diff_b = spike_diff // binsize

        # Spikes with no matching spikes are masked.
        mask[:-shift][spike_diff_b > (winsize_bins // 2)] = False

        # Cache the masked spike delays.
        m = mask[:-shift].copy()
        d = spike_diff_b[m]

        # Find the indices in the raveled correlograms array that need
        # to be incremented, taking into account the spike clusters.
        indices = np.ravel_multi_index((spike_clusters_i[:-shift][m],
                                        spike_clusters_i[+shift:][m],
                                        d),
                                       correlograms.shape)

        # Increment the matching spikes in the correlograms array.
        _increment(correlograms.ravel(), indices)

        shift += 1

    # Remove ACG peaks.
    correlograms[np.arange(n_clusters),
                 np.arange(n_clusters),
                 0] = 0

    if symmetrize:
        return _symmetrize_correlograms(correlograms)
    else:
        return correlograms
Example #10
0
def correlograms(
    spike_times,
    spike_clusters,
    cluster_ids=None,
    sample_rate=1.,
    bin_size=None,
    window_size=None,
    symmetrize=True,
):
    """Compute all pairwise cross-correlograms among the clusters appearing
    in `spike_clusters`.

    Parameters
    ----------

    spike_times : array-like
        Spike times in seconds. Must be sorted in increasing order.
    spike_clusters : array-like
        Spike-cluster mapping.
    cluster_ids : array-like
        The list of unique clusters, in any order. That order will be used
        in the output array. Defaults to `_unique(spike_clusters)`.
    sample_rate : float
        Positive sampling rate, used to convert spike times and the bin size
        to integer numbers of samples.
    bin_size : float
        Size of the bin, in seconds. Required despite the `None` default;
        clipped to [1e-5, 1e5].
    window_size : float
        Size of the window, in seconds. Required despite the `None` default;
        clipped to [1e-5, 1e5].
    symmetrize : bool
        If True (default), return the output of
        `_symmetrize_correlograms()`; otherwise return the one-sided
        correlograms (non-negative delays only).

    Returns
    -------

    correlograms : array
        A `(n_clusters, n_clusters, winsize_samples)` array with all pairwise
        CCGs.

    """
    assert sample_rate > 0.
    assert np.all(np.diff(spike_times) >= 0), ("The spike times must be "
                                               "increasing.")

    # Get the spike samples.
    spike_times = np.asarray(spike_times, dtype=np.float64)
    spike_samples = (spike_times * sample_rate).astype(np.int64)

    spike_clusters = _as_array(spike_clusters)

    assert spike_samples.ndim == 1
    assert spike_samples.shape == spike_clusters.shape

    # Find `binsize`.
    bin_size = np.clip(bin_size, 1e-5, 1e5)  # in seconds
    binsize = int(sample_rate * bin_size)  # in samples
    assert binsize >= 1

    # Find `winsize_bins`.
    window_size = np.clip(window_size, 1e-5, 1e5)  # in seconds
    winsize_bins = 2 * int(.5 * window_size / bin_size) + 1

    assert winsize_bins >= 1
    assert winsize_bins % 2 == 1

    # Take the cluster order into account.
    if cluster_ids is None:
        clusters = _unique(spike_clusters)
    else:
        clusters = _as_array(cluster_ids)
    n_clusters = len(clusters)

    # Like spike_clusters, but with 0..n_clusters-1 indices.
    spike_clusters_i = _index_of(spike_clusters, clusters)

    # Shift between the two copies of the spike trains.
    shift = 1

    # At a given shift, the mask precises which spikes have matching spikes
    # within the correlogram time window.
    # NOTE: the builtin `bool` is used since `np.bool` was removed in
    # NumPy 1.24.
    mask = np.ones_like(spike_samples, dtype=bool)

    correlograms = _create_correlograms_array(n_clusters, winsize_bins)

    # The loop continues as long as there is at least one spike with
    # a matching spike.
    while mask[:-shift].any():
        # Number of time samples between spike i and spike i+shift.
        spike_diff = _diff_shifted(spike_samples, shift)

        # Binarize the delays between spike i and spike i+shift.
        spike_diff_b = spike_diff // binsize

        # Spikes with no matching spikes are masked.
        mask[:-shift][spike_diff_b > (winsize_bins // 2)] = False

        # Cache the masked spike delays.
        m = mask[:-shift].copy()
        d = spike_diff_b[m]

        # Find the indices in the raveled correlograms array that need
        # to be incremented, taking into account the spike clusters.
        indices = np.ravel_multi_index(
            (spike_clusters_i[:-shift][m], spike_clusters_i[+shift:][m], d),
            correlograms.shape)

        # Increment the matching spikes in the correlograms array.
        _increment(correlograms.ravel(), indices)

        shift += 1

    # Remove ACG peaks.
    correlograms[np.arange(n_clusters), np.arange(n_clusters), 0] = 0

    if symmetrize:
        return _symmetrize_correlograms(correlograms)
    else:
        return correlograms
Example #11
0
    def on_select(self, cluster_ids=None):
        """Plot the waveforms of the selected clusters, one subplot per channel.

        One of the subsets returned by `self.waveforms()` is picked using
        `self.data_index`. Afterwards the view zooms on the clusters' best
        channels if `self.do_zoom_on_channels` is set.
        """
        super(WaveformView, self).on_select(cluster_ids)
        cluster_ids = self.cluster_ids
        n_clusters = len(cluster_ids)
        if n_clusters == 0:
            return

        # Load the waveform subsets and pick one of them.
        subsets = self.waveforms(cluster_ids)
        data = subsets[self.data_index % len(subsets)]
        alpha = data.get('alpha', .5)
        n_spikes = len(data.spike_ids)
        spike_clusters = data.spike_clusters
        waves = data.data
        masks = data.masks
        assert waves.ndim == 3
        n_samples = waves.shape[1]
        assert waves.shape == (n_spikes, n_samples, self.n_channels)
        assert masks.shape == (n_spikes, self.n_channels)

        # Index of each spike's cluster within the selection.
        clu_rel = _index_of(spike_clusters, cluster_ids)
        assert clu_rel.shape == (n_spikes,)

        # x coordinates; clusters are laid side by side unless they overlap.
        t = _get_linear_x(n_spikes, n_samples)
        if not self.overlap:
            shift = 2.5 * (clu_rel[:, np.newaxis] - (n_clusters - 1) / 2.)
            # The total width should not depend on the number of clusters.
            t = (t + shift) / n_clusters

        # OPTIM: avoid the per-channel loop.
        with self.building():
            for channel in range(self.n_channels):
                channel_masks = masks[:, channel]
                depth = _get_depth(channel_masks,
                                   spike_clusters_rel=clu_rel,
                                   n_clusters=n_clusters)
                color = _spike_colors(clu_rel,
                                      masks=channel_masks,
                                      alpha=alpha,
                                      )
                subplot = self[channel]
                subplot.plot(x=t, y=waves[:, :, channel],
                             color=color,
                             depth=depth,
                             data_bounds=self.data_bounds,
                             )
                # Channel label.
                subplot.text(pos=[[t[0, 0], 0.]],
                             text=str(channel),
                             anchor=[-1.01, -.25],
                             data_bounds=self.data_bounds,
                             )

        # Zoom on the best channels when selecting clusters.
        best = self.best_channels(cluster_ids)
        if best is not None and self.do_zoom_on_channels:
            self.zoom_on_channels(best)
Example #12
0
    def on_select(self, cluster_ids=None):
        """Plot the features of the selected clusters in an n_cols x n_cols grid.

        When `self.background_features` is not None, the background spikes
        are drawn before the cluster spikes in every subplot. The channels
        are chosen automatically unless `self.fixed_channels` is set and
        channels are already defined.
        """
        super(FeatureView, self).on_select(cluster_ids)
        cluster_ids = self.cluster_ids
        n_clusters = len(cluster_ids)
        if n_clusters == 0:
            return

        # Spikes, features and masks of the selected clusters.
        sel = self.features(cluster_ids)
        spike_ids = sel.spike_ids
        spike_clusters = sel.spike_clusters
        feats = sel.data
        masks = sel.masks
        assert feats.ndim == 3
        assert masks.ndim == 2
        assert spike_ids.shape[0] == feats.shape[0] == masks.shape[0]

        # Index of each spike's cluster within the selection.
        clu_rel = _index_of(spike_clusters, cluster_ids)

        # Optional background spikes.
        bg = self.background_features
        has_bg = bg is not None
        if has_bg:
            bg_spike_ids = bg.spike_ids
            bg_feats = bg.data
            bg_masks = bg.masks

        # Choose the channels automatically unless fixed_channels is set.
        if not self.fixed_channels or self.channels is None:
            self.channels = self._get_channel_dims(cluster_ids)
        top_left = self.top_left_attribute
        assert self.channels
        x_dim, y_dim = _dimensions_matrix(self.channels,
                                          n_cols=self.n_cols,
                                          top_left_attribute=top_left)

        # Status message listing the displayed channels.
        channel_list = ', '.join(str(c) for c in self.channels)
        self.set_status('Channels: {}'.format(channel_list))

        # Use a non-time attribute as y coordinate in the top-left subplot.
        attrs = sorted(self.attributes)
        attrs.remove('time')
        if attrs:
            y_dim[0, 0] = attrs[0]

        with self.building():
            for i in range(self.n_cols):
                for j in range(self.n_cols):
                    xd, yd = x_dim[i, j], y_dim[i, j]
                    x = self._get_feature(xd, spike_ids, feats)
                    y = self._get_feature(yd, spike_ids, feats)

                    if has_bg:
                        x_bg = self._get_feature(xd, bg_spike_ids, bg_feats)
                        y_bg = self._get_feature(yd, bg_spike_ids, bg_feats)
                        self._plot_features(i, j, x_dim, y_dim, x_bg, y_bg,
                                            masks=bg_masks)

                    self._plot_features(i, j, x_dim, y_dim, x, y,
                                        masks=masks,
                                        spike_clusters_rel=clu_rel)

                    # Horizontal and vertical axes through the origin.
                    self[i, j].lines(pos=[[-1., 0., +1., 0.],
                                          [0., -1., 0., +1.]],
                                     color=(.25, .25, .25, .5))

            self.grid.add_boxes(self, self.shape)
Example #13
0
    def _plot_waveforms(self, bunchs, channel_ids):
        """Plot the waveforms of every bunch with one `uplot()` call each.

        Parameters
        ----------
        bunchs : list of Bunch
            Each with a `data` array of shape
            `(n_spikes, n_samples, n_channels_loc)`, the `channel_ids` of
            those channels, and optional `alpha` (default .5) and `masks`
            (default all ones, shape `(n_spikes, n_channels_loc)`).
        channel_ids : array-like
            All displayed channel ids; used to compute the box index.

        """
        # Initialize the box scaling the first time.
        if self.box_scaling[1] == 1.:
            M = np.max([np.max(np.abs(b.data)) for b in bunchs])
            # Skip all-zero data: 1 / 0 would give an infinite scaling that
            # would never be reset since box_scaling[1] would no longer be 1.
            if M > 0:
                self.box_scaling[1] = 1. / M
                self._update_boxes()
        clu_offsets = _get_clu_offsets(bunchs)
        max_clu_offsets = max(clu_offsets) + 1
        for i, d in enumerate(bunchs):
            wave = d.data
            alpha = d.get('alpha', .5)
            channel_ids_loc = d.channel_ids

            n_channels = len(channel_ids_loc)
            masks = d.get('masks', np.ones((wave.shape[0], n_channels)))
            # By default, this is 0, 1, 2 for the first 3 clusters.
            # But it can be customized when displaying several sets
            # of waveforms per cluster.

            n_spikes_clu, n_samples = wave.shape[:2]
            assert wave.shape[2] == n_channels
            assert masks.shape == (n_spikes_clu, n_channels)

            # Find the x coordinates.
            t = _get_linear_x(n_spikes_clu * n_channels, n_samples)
            if not self.overlap:

                # Determine the cluster offset.
                offset = clu_offsets[i]
                t = t + 2.5 * (offset - (max_clu_offsets - 1) / 2.)
                # The total width should not depend on the number of
                # clusters.
                t /= max_clu_offsets

            # Get the spike masks.
            # HACK: on the GPU, we get the actual masks with fract(masks)
            # since we add the relative cluster index. We need to ensure
            # that the masks is never 1.0, otherwise it is interpreted as
            # 0.
            # NOTE: we add the cluster index which is used for the
            # computation of the depth on the GPU.
            # A new array is built on purpose: updating `masks` in place
            # would corrupt the masks stored in the bunch (shifting them
            # further at every redraw) and fails on integer arrays.
            m = masks * .99999 + i

            color = tuple(_colormap(i)) + (alpha,)
            assert len(color) == 4

            # Generate the box index (one number per channel).
            box_index = _index_of(channel_ids_loc, channel_ids)
            box_index = np.repeat(box_index, n_samples)
            box_index = np.tile(box_index, n_spikes_clu)
            assert box_index.shape == (n_spikes_clu *
                                       n_channels *
                                       n_samples,)

            # Generate the waveform array.
            wave = np.transpose(wave, (0, 2, 1))
            wave = wave.reshape((n_spikes_clu * n_channels, n_samples))

            self.uplot(x=t,
                       y=wave,
                       color=color,
                       masks=m,
                       box_index=box_index,
                       data_bounds=None,
                       )