Example #1
 def get_clusters_data(self, load_all=None):
     """Return a list of Bunch instances, with attributes pos and spike_ids."""
     if not len(self.cluster_ids):
         return
     cluster_ids = list(self.cluster_ids)
     # Don't need the background when splitting.
     if not load_all:
         # Add None cluster which means background spikes.
         cluster_ids = [None] + cluster_ids
     bunchs = self.amplitudes[self.amplitude_name](cluster_ids, load_all=load_all) or ()
     # Add a pos attribute in bunchs in addition to x and y.
     for i, (cluster_id, bunch) in enumerate(zip(cluster_ids, bunchs)):
         spike_ids = _as_array(bunch.spike_ids)
         spike_times = _as_array(bunch.spike_times)
         amplitudes = _as_array(bunch.amplitudes)
         assert spike_ids.shape == spike_times.shape == amplitudes.shape
         # Ensure that bunch.pos exists, as it is used by the LassoMixin.
         bunch.pos = np.c_[spike_times, amplitudes]
         assert bunch.pos.ndim == 2
         bunch.cluster_id = cluster_id
         bunch.color = (
             selected_cluster_color(i - 1, self.marker_alpha)
             # Background amplitude color.
             if cluster_id is not None else (.5, .5, .5, .5))
     return bunchs
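A note on the `pos` construction above: `np.c_[spike_times, amplitudes]` column-stacks the two 1D arrays into the `(n_spikes, 2)` array that the lasso selection expects. A minimal, self-contained sketch with made-up values:

import numpy as np

spike_times = np.array([0.1, 0.5, 0.9])   # illustrative values only
amplitudes = np.array([12.0, 8.5, 10.2])
pos = np.c_[spike_times, amplitudes]
print(pos.shape)  # (3, 2): one (time, amplitude) row per spike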
Example #2
def _extend_assignment(spike_ids, old_spike_clusters, spike_clusters_rel,
                       new_cluster_id):
    # 1. Add spikes that belong to modified clusters.
    # 2. Find new cluster ids for all changed clusters.

    old_spike_clusters = _as_array(old_spike_clusters)
    spike_ids = _as_array(spike_ids)

    assert isinstance(spike_clusters_rel, (list, np.ndarray))
    spike_clusters_rel = _as_array(spike_clusters_rel)
    assert spike_clusters_rel.min() >= 0

    # We renumber the new cluster indices.
    new_spike_clusters = (spike_clusters_rel +
                          (new_cluster_id - spike_clusters_rel.min()))

    # We find the spikes belonging to modified clusters.
    extended_spike_ids = _extend_spikes(spike_ids, old_spike_clusters)
    if len(extended_spike_ids) == 0:
        return spike_ids, new_spike_clusters

    # We take their clusters.
    extended_spike_clusters = old_spike_clusters[extended_spike_ids]
    # Use relative numbers in extended_spike_clusters.
    _, extended_spike_clusters = np.unique(extended_spike_clusters,
                                           return_inverse=True)
    # Generate new cluster numbers.
    k = new_spike_clusters.max() + 1
    extended_spike_clusters += (k - extended_spike_clusters.min())

    # Finally, we concatenate spike_ids and extended_spike_ids.
    return _concatenate_spike_clusters(
        (spike_ids, new_spike_clusters),
        (extended_spike_ids, extended_spike_clusters))
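The renumbering step can be checked in isolation; this is only a sketch of the arithmetic with assumed values, not a call to `_extend_assignment` itself (which also needs `_extend_spikes` and `_concatenate_spike_clusters`):

import numpy as np

spike_clusters_rel = np.array([0, 0, 1, 2])  # assumed relative labels
new_cluster_id = 10                          # assumed next free cluster id
# Shift the relative labels so that the smallest one becomes new_cluster_id.
new_spike_clusters = spike_clusters_rel + (new_cluster_id - spike_clusters_rel.min())
print(new_spike_clusters)  # [10 10 11 12]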
Example #3
def _concatenate_spike_clusters(*pairs):
    """Concatenate a list of pairs (spike_ids, spike_clusters)."""
    pairs = [(_as_array(x), _as_array(y)) for (x, y) in pairs]
    concat = np.vstack([np.hstack((x[:, None], y[:, None])) for x, y in pairs])
    reorder = np.argsort(concat[:, 0])
    concat = concat[reorder, :]
    return concat[:, 0].astype(np.int64), concat[:, 1].astype(np.int64)
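What the helper computes, shown with plain NumPy on toy data (so the snippet runs without `_as_array`): the pairs are stacked into a two-column array and re-sorted by spike id.

import numpy as np

pairs = [(np.array([3, 1]), np.array([10, 10])),
         (np.array([2]), np.array([11]))]
# Stack (spike_id, cluster) rows from every pair, then sort by spike id.
concat = np.vstack([np.hstack((x[:, None], y[:, None])) for x, y in pairs])
concat = concat[np.argsort(concat[:, 0]), :]
print(concat[:, 0])  # [1 2 3]    sorted spike ids
print(concat[:, 1])  # [10 11 10] matching cluster ids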
Example #4
def select_spikes_from_chunked(spike_times,
                               chunk_bounds,
                               max_n_spikes,
                               skip_chunks=0):
    """Select a maximum number of spikes among the specified ones so as to minimize the
    number of chunks that contain those spikes."""
    if len(spike_times) <= max_n_spikes:
        return spike_times
    spike_times = _as_array(spike_times)
    chunk_bounds = _as_array(chunk_bounds)
    spike_chunks = np.searchsorted(chunk_bounds, spike_times, side='right') - 1
    chunk_sizes = np.bincount(spike_chunks)
    best_chunks = np.argsort(chunk_sizes)[::-1][skip_chunks:]
    keep = np.zeros(len(spike_times), dtype=bool)
    total = 0
    for chunk_idx in best_chunks:
        in_chunks = np.isin(spike_chunks, chunk_idx)
        n_spikes_chunk = np.sum(in_chunks)
        if total + n_spikes_chunk > max_n_spikes:
            # Truncate to get the exact number of requested spikes.
            last = np.nonzero(
                np.cumsum(in_chunks) <= max_n_spikes - total)[0][-1]
            in_chunks[last + 1:] = False
        keep[in_chunks] = True
        total += n_spikes_chunk
        if total >= max_n_spikes:
            break
    return spike_times[keep]
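A sketch of the chunk bookkeeping above on toy data (values are assumptions): each spike time is mapped to its chunk with `searchsorted`, and chunks are then ranked by spike count.

import numpy as np

spike_times = np.array([0.1, 0.2, 1.5, 2.1, 2.2, 2.9])
chunk_bounds = np.array([0., 1., 2., 3.])  # three chunks: [0,1), [1,2), [2,3)
spike_chunks = np.searchsorted(chunk_bounds, spike_times, side='right') - 1
chunk_sizes = np.bincount(spike_chunks)
print(spike_chunks)                   # [0 0 1 2 2 2]
print(chunk_sizes)                    # [2 1 3] -> chunk 2 holds the most spikes
print(np.argsort(chunk_sizes)[::-1])  # [2 0 1] -> chunks visited in that order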
Example #5
def _increment(arr, indices):
    """Increment some indices in a 1D vector of non-negative integers.
    Repeated indices are taken into account."""
    arr = _as_array(arr)
    indices = _as_array(indices)
    bbins = np.bincount(indices)
    arr[:len(bbins)] += bbins
    return arr
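A worked toy example of the bincount-based increment: index 1 appears twice in `indices`, so `arr[1]` is incremented twice.

import numpy as np

arr = np.array([0, 0, 0, 0])
indices = np.array([1, 1, 3])
bbins = np.bincount(indices)  # [0 2 0 1]
arr[:len(bbins)] += bbins
print(arr)                    # [0 2 0 1]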
Example #6
def _in_polygon(points, polygon):
    """Return the points that are inside a polygon."""
    from matplotlib.path import Path
    points = _as_array(points)
    polygon = _as_array(polygon)
    assert points.ndim == 2
    assert polygon.ndim == 2
    if len(polygon):
        polygon = np.vstack((polygon, polygon[0]))
    path = Path(polygon, closed=True)
    return path.contains_points(points)
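A small runnable check of the containment test, with an assumed unit-square polygon, one point inside and one outside:

import numpy as np
from matplotlib.path import Path

polygon = np.array([[0., 0.], [1., 0.], [1., 1.], [0., 1.], [0., 0.]])  # closed square
points = np.array([[0.5, 0.5], [2.0, 2.0]])
path = Path(polygon, closed=True)
print(path.contains_points(points))  # [ True False]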
Example #7
def apply_filter(x, filter=None, axis=0):
    """Apply a filter to an array."""
    x = _as_array(x)
    if x.shape[axis] == 0:
        return x
    b, a = filter
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        return signal.filtfilt(b, a, x, axis=axis)
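A sketch of a typical call, assuming `filter` is a `(b, a)` pair produced by a standard SciPy design routine (here a 3rd-order high-pass Butterworth at 0.1 of the Nyquist frequency; the cutoff is illustrative only):

import numpy as np
from scipy import signal

b, a = signal.butter(3, 0.1, btype='highpass')
x = np.random.randn(1000, 4)          # samples x channels
y = signal.filtfilt(b, a, x, axis=0)  # the same call apply_filter makes
print(y.shape)                        # (1000, 4)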
Example #8
 def _check_positions(self, positions):
     if positions is None:
         return
     positions = _as_array(positions)
     if positions.shape[0] != self.n_channels:
         raise ValueError("'positions' "
                          "(shape {0:s})".format(str(positions.shape)) +
                          " and 'n_channels' "
                          "({0:d})".format(self.n_channels) +
                          " do not match.")
Example #9
def test_as_array():
    ae(_as_array(3), [3])
    ae(_as_array([3]), [3])
    ae(_as_array(3.), [3.])
    ae(_as_array([3.]), [3.])

    with raises(ValueError):
        _as_array(map)
Example #10
def _unique(x):
    """Faster version of np.unique().

    This version is restricted to 1D arrays of non-negative integers.

    It is only faster if len(x) >> len(unique(x)).

    """
    if x is None or len(x) == 0:
        return np.array([], dtype=np.int64)
    # WARNING: only keep non-negative values.
    # cluster=-1 means "unclustered".
    x = _as_array(x)
    x = x[x >= 0]
    bc = np.bincount(x)
    return np.nonzero(bc)[0]
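A worked example of the bincount trick: negative labels (unclustered spikes) are dropped, and the sorted unique values are the indices of the non-zero bins.

import numpy as np

x = np.array([3, 1, 1, -1, 5, 3])
x = x[x >= 0]             # [3 1 1 5 3]
bc = np.bincount(x)       # [0 2 0 2 0 1]
print(np.nonzero(bc)[0])  # [1 3 5]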
Example #11
def firing_rate(spike_clusters,
                cluster_ids=None,
                bin_size=None,
                duration=None):
    """Compute the average number of spikes per cluster per bin."""

    # Take the cluster order into account.
    if cluster_ids is None:
        clusters = _unique(spike_clusters)
    else:
        clusters = _as_array(cluster_ids)

    # Like spike_clusters, but with 0..n_clusters-1 indices.
    spike_clusters_i = _index_of(spike_clusters, clusters)

    assert duration > 0
    assert bin_size > 0
    bc = np.bincount(spike_clusters_i)
    return bc * np.c_[bc] * (bin_size / duration)
Example #12
 def __init__(self,
              spike_clusters,
              new_cluster_id=None,
              spikes_per_cluster=None):
     super(Clustering, self).__init__()
     self._undo_stack = History(base_item=(None, None, None))
     # Spike -> cluster mapping.
     self._spike_clusters = _as_array(spike_clusters)
     self._spikes_per_cluster = {}
     self._n_spikes = len(self._spike_clusters)
     self._spike_ids = np.arange(self._n_spikes).astype(np.int64)
     # We can pass the precomputed spikes_per_cluster dictionary for
     # performance reasons.
     self._update_cluster_ids(to_add=spikes_per_cluster)
     self._new_cluster_id_0 = int(new_cluster_id
                                  or self._spike_clusters.max() + 1)
     self._new_cluster_id = self._new_cluster_id_0
     assert self._new_cluster_id >= 0
     assert np.all(self._spike_clusters < self._new_cluster_id)
     # Keep a copy of the original spike clusters assignment.
     self._spike_clusters_base = self._spike_clusters.copy()
Example #13
    def _do_assign(self, spike_ids, new_spike_clusters):
        """Make spike-cluster assignments after the spike selection has
        been extended to full clusters."""

        # Ensure spike_clusters has the right shape.
        spike_ids = _as_array(spike_ids)
        if len(new_spike_clusters) == 1 and len(spike_ids) > 1:
            new_spike_clusters = np.ones(
                len(spike_ids), dtype=np.int64) * new_spike_clusters[0]
        old_spike_clusters = self._spike_clusters[spike_ids]

        assert len(spike_ids) == len(old_spike_clusters)
        assert len(new_spike_clusters) == len(spike_ids)

        # Update the spikes per cluster structure.
        old_clusters = _unique(old_spike_clusters)

        # NOTE: shortcut to a merge if this assignment is effectively a merge
        # i.e. if all spikes are assigned to a single cluster.
        # The fact that spike selection has been previously extended to
        # whole clusters is critical here.
        new_clusters = _unique(new_spike_clusters)
        if len(new_clusters) == 1:
            return self._do_merge(spike_ids, old_clusters, new_clusters[0])

        # We return the UpdateInfo structure.
        up = _assign_update_info(spike_ids, old_spike_clusters,
                                 new_spike_clusters)

        # We update the new cluster id (strictly increasing during a session).
        self._new_cluster_id = max(self._new_cluster_id, max(up.added) + 1)

        # We make the assignments.
        self._spike_clusters[spike_ids] = new_spike_clusters
        # OPTIM: we update spikes_per_cluster manually.
        new_spc = _spikes_per_cluster(new_spike_clusters, spike_ids)
        self._update_cluster_ids(to_remove=old_clusters, to_add=new_spc)
        up.all_cluster_ids = list(self.cluster_ids)
        return up
Example #14
def firing_rate(spike_clusters,
                cluster_ids=None,
                bin_size=None,
                duration=None):
    """Compute the average number of spikes per cluster per bin."""

    # Take the cluster order into account.
    if cluster_ids is None:
        cluster_ids = _unique(spike_clusters)
    else:
        cluster_ids = _as_array(cluster_ids)

    # Like spike_clusters, but with 0..n_clusters-1 indices.
    spike_clusters_i = _index_of(spike_clusters, cluster_ids)

    assert bin_size > 0
    bc = np.bincount(spike_clusters_i)
    # Handle the case where the last cluster(s) are empty.
    if len(bc) < len(cluster_ids):
        n = len(cluster_ids) - len(bc)
        bc = np.concatenate((bc, np.zeros(n, dtype=bc.dtype)))
    assert bc.shape == (len(cluster_ids), )
    return bc * np.c_[bc] * (bin_size / (duration or 1.))
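The return value is an outer product: entry `(i, j)` is `bc[i] * bc[j]` scaled by `bin_size / duration`. A toy illustration with assumed counts and parameters:

import numpy as np

bc = np.array([2, 3])            # assumed spike counts per cluster
bin_size, duration = 0.1, 10.0   # assumed parameters, in seconds
print(bc * np.c_[bc] * (bin_size / duration))
# [[0.04 0.06]
#  [0.06 0.09]]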
Example #15
 def _normalize(self, x):
     x = _as_array(x)
     tw = self._thresholds['weak']
     ts = self._thresholds['strong']
     return np.clip((x - tw) / (ts - tw), 0, 1)
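A worked example with assumed thresholds weak=2 and strong=4: values are mapped linearly onto [0, 1] and clipped.

import numpy as np

x = np.array([1., 3., 5.])
tw, ts = 2., 4.  # assumed weak and strong thresholds
print(np.clip((x - tw) / (ts - tw), 0, 1))  # [0.  0.5 1. ]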
Example #16
    def get(self, spike_ids, channels=None):
        """Load the waveforms of the specified spikes."""
        if isinstance(spike_ids, slice):
            spike_ids = _range_from_slice(spike_ids,
                                          start=0,
                                          stop=self.n_spikes)
        if not hasattr(spike_ids, '__len__'):
            spike_ids = [spike_ids]
        if channels is None:
            channels = slice(None, None, None)
            nc = self.n_channels
        else:
            channels = np.asarray(channels, dtype=np.int32)
            assert np.all(channels < self.n_channels)
            nc = len(channels)

        # Ensure a list of time samples are being requested.
        spike_ids = _as_array(spike_ids)
        n_spikes = len(spike_ids)

        # Initialize the array.
        # NOTE: last dimension is time to simplify things.
        shape = (n_spikes, nc, self._n_samples_extract)
        waveforms = np.zeros(shape, dtype=np.float32)

        # No traces: return null arrays.
        if self.n_samples_trace == 0:
            return np.transpose(waveforms, (0, 2, 1))

        # Load all spikes.
        for i, spike_id in enumerate(spike_ids):
            assert 0 <= spike_id < self.n_spikes
            time = self._spike_samples[spike_id]

            # Extract the waveforms on the unmasked channels.
            try:
                w = self._load_at(time, channels)
            except ValueError as e:  # pragma: no cover
                logger.warning("Error while loading waveform: %s", str(e))
                continue

            assert w.shape == (self._n_samples_extract, nc)

            waveforms[i, :, :] = w.T

        # Filter the waveforms.
        waveforms_f = waveforms.reshape((-1, self._n_samples_extract))
        # Only filter the non-zero waveforms.
        unmasked = waveforms_f.max(axis=1) != 0
        waveforms_f[unmasked] = self._filter(waveforms_f[unmasked], axis=1)
        waveforms_f = waveforms_f.reshape(
            (n_spikes, nc, self._n_samples_extract))

        # Remove the margin.
        margin_before, margin_after = self._filter_margin
        if margin_after > 0:
            assert margin_before >= 0
            waveforms_f = waveforms_f[:, :, margin_before:-margin_after]

        assert waveforms_f.shape == (n_spikes, nc, self.n_samples_waveforms)

        # NOTE: we transpose before returning the array.
        return np.transpose(waveforms_f, (0, 2, 1))
Example #17
def _diff_shifted(arr, steps=1):
    arr = _as_array(arr)
    return arr[steps:] - arr[:len(arr) - steps]
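A toy example of the shifted difference (shown with plain slicing so it runs without `_as_array`): with `steps=2` each element is compared with the element two positions earlier.

import numpy as np

arr = np.array([1, 4, 9, 16, 25])
print(arr[1:] - arr[:-1])  # steps=1 -> [3 5 7 9]
print(arr[2:] - arr[:-2])  # steps=2 -> [ 8 12 16]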
Example #18
def correlograms(
    spike_times,
    spike_clusters,
    cluster_ids=None,
    sample_rate=1.,
    bin_size=None,
    window_size=None,
    symmetrize=True,
):
    """Compute all pairwise cross-correlograms among the clusters appearing
    in `spike_clusters`.

    Parameters
    ----------

    spike_times : array-like
        Spike times in seconds.
    spike_clusters : array-like
        Spike-cluster mapping.
    cluster_ids : array-like
        The list of *all* unique clusters, in any order. That order will be used
        in the output array.
    bin_size : float
        Size of the bin, in seconds.
    window_size : float
        Size of the window, in seconds.

    Returns
    -------

    correlograms : array
        A `(n_clusters, n_clusters, winsize_samples)` array with all pairwise
        CCGs.

    """
    assert sample_rate > 0.
    assert np.all(np.diff(spike_times) >= 0), ("The spike times must be "
                                               "increasing.")

    # Get the spike samples.
    spike_times = np.asarray(spike_times, dtype=np.float64)
    spike_samples = (spike_times * sample_rate).astype(np.int64)

    spike_clusters = _as_array(spike_clusters)

    assert spike_samples.ndim == 1
    assert spike_samples.shape == spike_clusters.shape

    # Find `binsize`.
    bin_size = np.clip(bin_size, 1e-5, 1e5)  # in seconds
    binsize = int(sample_rate * bin_size)  # in samples
    assert binsize >= 1

    # Find `winsize_bins`.
    window_size = np.clip(window_size, 1e-5, 1e5)  # in seconds
    winsize_bins = 2 * int(.5 * window_size / bin_size) + 1

    assert winsize_bins >= 1
    assert winsize_bins % 2 == 1

    # Take the cluster order into account.
    if cluster_ids is None:
        clusters = _unique(spike_clusters)
    else:
        clusters = _as_array(cluster_ids)
    n_clusters = len(clusters)

    # Like spike_clusters, but with 0..n_clusters-1 indices.
    spike_clusters_i = _index_of(spike_clusters, clusters)

    # Shift between the two copies of the spike trains.
    shift = 1

    # At a given shift, the mask specifies which spikes have matching spikes
    # within the correlogram time window.
    mask = np.ones_like(spike_samples, dtype=bool)

    correlograms = _create_correlograms_array(n_clusters, winsize_bins)

    # The loop continues as long as there is at least one spike with
    # a matching spike.
    while mask[:-shift].any():
        # Number of time samples between spike i and spike i+shift.
        spike_diff = _diff_shifted(spike_samples, shift)

        # Binarize the delays between spike i and spike i+shift.
        spike_diff_b = spike_diff // binsize

        # Spikes with no matching spikes are masked.
        mask[:-shift][spike_diff_b > (winsize_bins // 2)] = False

        # Cache the masked spike delays.
        m = mask[:-shift].copy()
        d = spike_diff_b[m]

        # Find the indices in the raveled correlograms array that need
        # to be incremented, taking into account the spike clusters.
        indices = np.ravel_multi_index(
            (spike_clusters_i[:-shift][m], spike_clusters_i[+shift:][m], d),
            correlograms.shape)

        # Increment the matching spikes in the correlograms array.
        _increment(correlograms.ravel(), indices)

        shift += 1

    # Remove ACG peaks.
    correlograms[np.arange(n_clusters), np.arange(n_clusters), 0] = 0

    if symmetrize:
        return _symmetrize_correlograms(correlograms)
    else:
        return correlograms
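A sketch of the core of one loop iteration on toy data (all values are assumptions), showing how the delay between each spike and the spike `shift` positions later is computed and binned:

import numpy as np

spike_samples = np.array([10, 12, 13, 40])  # assumed spike times in samples
binsize = 2                                 # assumed bin width in samples
shift = 1
spike_diff = spike_samples[shift:] - spike_samples[:-shift]
print(spike_diff)             # [ 2  1 27]
print(spike_diff // binsize)  # [ 1  0 13]
# A delay of 13 bins falls outside a small window and would be masked out.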
Example #19
 def _zoom_aspect(self, zoom=None):
     zoom = zoom if zoom is not None else self._zoom
     zoom = _as_array(zoom)
     aspect = (self._canvas_aspect * self._aspect if self._aspect is not None else 1.)
     return zoom * aspect
Example #20
    def assign(self, spike_ids, spike_clusters_rel=0):
        """Make new spike cluster assignments.

        Parameters
        ----------

        spike_ids : array-like
            List of spike ids.
        spike_clusters_rel : array-like
            Relative cluster ids of the spikes in `spike_ids`. This
            must have the same size as `spike_ids`.

        Returns
        -------

        up : UpdateInfo instance

        Note
        ----

        `spike_clusters_rel` contains *relative* cluster indices. Their values
        don't matter: what matters is whether two given spikes
        should end up in the same cluster or not. Adding a constant number
        to all elements in `spike_clusters_rel` results in exactly the same
        operation.

        The final cluster ids are automatically generated by the `Clustering`
        class. This is because we must ensure that all modified clusters
        get brand new ids. The whole library is based on the assumption that
        cluster ids are unique and "disposable". Changing a cluster always
        results in a new cluster id being assigned.

        If a spike is assigned to a new cluster, then all other spikes
        belonging to the same cluster are assigned to a brand new cluster,
        even if they were not changed explicitly by the `assign()` method.

        In other words, the list of spikes affected by an `assign()` is almost
        always a strict superset of the `spike_ids` parameter. The only case
        where this is not true is when whole clusters change: this is called
        a merge. It is implemented in a separate `merge()` method because it
        is logically much simpler, and faster to execute.

        """

        assert not isinstance(spike_ids, slice)

        # Ensure `spike_clusters_rel` is an array-like.
        if not hasattr(spike_clusters_rel, '__len__'):
            spike_clusters_rel = spike_clusters_rel * np.ones(len(spike_ids),
                                                              dtype=np.int64)

        spike_ids = _as_array(spike_ids)
        if len(spike_ids) == 0:
            return UpdateInfo()
        assert len(spike_ids) == len(spike_clusters_rel)
        assert spike_ids.min() >= 0
        assert spike_ids.max() < self._n_spikes, "Some spikes don't exist."

        # Normalize the spike-cluster assignment such that
        # there are only new or dead clusters, not modified clusters.
        # This implies that spikes not explicitly selected, but that
        # belong to clusters affected by the operation, will be assigned
        # to brand new clusters.
        spike_ids, cluster_ids = _extend_assignment(spike_ids,
                                                    self._spike_clusters,
                                                    spike_clusters_rel,
                                                    self.new_cluster_id())

        up = self._do_assign(spike_ids, cluster_ids)
        undo_state = emit('request_undo_state', self, up)

        # Add the assignment to the undo stack.
        self._undo_stack.add((spike_ids, cluster_ids, undo_state))

        emit('cluster', self, up)
        return up
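A toy illustration of the "relative" labels described in the docstring: only the grouping matters, so `[0, 0, 1]` and `[7, 7, 8]` describe the same split of three spikes.

import numpy as np

a = np.unique(np.array([0, 0, 1]), return_inverse=True)[1]
b = np.unique(np.array([7, 7, 8]), return_inverse=True)[1]
print(np.array_equal(a, b))  # True: both reduce to the same grouping [0 0 1]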