Example #1: building an artificial dataset fixture for tests
    def _init_data(self):
        self.cache_dir = self.config_dir
        self.n_samples_waveforms = 31
        self.n_samples_t = 20000
        self.n_channels = 11
        self.n_clusters = 4
        self.n_spikes_per_cluster = 50
        n_spikes_total = self.n_clusters * self.n_spikes_per_cluster
        n_features_per_channel = 4

        self.n_spikes = n_spikes_total
        self.sample_rate = 20000.
        self.duration = self.n_samples_t / float(self.sample_rate)
        self.spike_times = np.arange(0, self.duration, 100. / self.sample_rate)
        self.spike_clusters = np.repeat(np.arange(self.n_clusters),
                                        self.n_spikes_per_cluster)
        assert len(self.spike_times) == len(self.spike_clusters)
        self.cluster_ids = np.unique(self.spike_clusters)
        self.channel_positions = staggered_positions(self.n_channels)

        sc = self.spike_clusters
        self.spikes_per_cluster = lambda c: _spikes_in_clusters(sc, [c])
        self.spike_count = lambda c: len(self.spikes_per_cluster(c))
        self.n_features_per_channel = n_features_per_channel
        self.cluster_groups = {c: None for c in range(self.n_clusters)}

        self.all_traces = artificial_traces(self.n_samples_t, self.n_channels)
        self.all_masks = artificial_masks(n_spikes_total, self.n_channels)
        self.all_waveforms = artificial_waveforms(n_spikes_total,
                                                  self.n_samples_waveforms,
                                                  self.n_channels)
        self.all_features = artificial_features(n_spikes_total,
                                                self.n_channels,
                                                self.n_features_per_channel)
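
Every example in this section leans on the helper _spikes_in_clusters. A minimal sketch of what such a helper can look like, assuming it simply selects spike indices by cluster membership (the library's actual implementation may differ):

import numpy as np

def _spikes_in_clusters(spike_clusters, clusters):
    """Return the indices of all spikes whose cluster is in `clusters`."""
    if not len(clusters):
        return np.array([], dtype=np.int64)
    return np.nonzero(np.isin(spike_clusters, clusters))[0]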
Example #2: _extend_spikes, finding all spikes affected by a reassignment
def _extend_spikes(spike_ids, spike_clusters):
    """Return all spikes belonging to the clusters containing the specified
    spikes."""
    # We find the spikes belonging to modified clusters.
    # What are the old clusters that are modified by the assignment?
    old_spike_clusters = spike_clusters[spike_ids]
    unique_clusters = _unique(old_spike_clusters)
    # Now we take all spikes from these clusters.
    changed_spike_ids = _spikes_in_clusters(spike_clusters, unique_clusters)
    # These are the new spikes that need to be reassigned.
    extended_spike_ids = np.setdiff1d(changed_spike_ids, spike_ids,
                                      assume_unique=True)
    return extended_spike_ids
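
A small worked example of _extend_spikes, using hypothetical values and assuming _unique behaves like np.unique and _spikes_in_clusters works as sketched above:

import numpy as np

spike_clusters = np.array([0, 0, 1, 1, 2])
# Spike 0 belongs to cluster 0, which also contains spike 1.
# Spike 1 is therefore part of a modified cluster and must be
# reassigned as well, so it is returned as an "extended" spike.
extended = _extend_spikes(np.array([0]), spike_clusters)
print(extended)  # [1]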
Example #3: testing the selection after a split in ManualClustering
def test_manual_clustering_split_2(gui, quality, similarity):
    spike_clusters = np.array([0, 0, 1])

    mc = ManualClustering(spike_clusters,
                          lambda c: _spikes_in_clusters(spike_clusters, [c]),
                          similarity=similarity,
                          )
    mc.attach(gui)

    mc.add_column(quality, name='quality', default=True)
    mc.set_default_sort('quality', 'desc')

    mc.split([0])
    assert mc.selected == [3, 2]
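
Why [3, 2]: splitting spike 0 out of cluster 0 leaves spike 1 behind, so both spikes are moved to fresh clusters (2 and 3, assuming new ids are allocated past the highest existing id); with the default descending sort on quality, the two new clusters come back selected as [3, 2].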
Example #4: Clustering.merge, merging clusters without going through assign
    def merge(self, cluster_ids, to=None):
        """Merge several clusters to a new cluster.

        Parameters
        ----------

        cluster_ids : array-like
            List of clusters to merge.
        to : integer or None
            The id of the new cluster. By default, this is `new_cluster_id()`.

        Returns
        -------

        up : UpdateInfo instance

        """

        if not _is_array_like(cluster_ids):
            raise ValueError("The first argument should be a list or "
                             "an array.")

        cluster_ids = sorted(cluster_ids)
        if not set(cluster_ids) <= set(self.cluster_ids):
            raise ValueError("Some clusters do not exist.")

        # Find the new cluster number.
        if to is None:
            to = self.new_cluster_id()
        if to < self.new_cluster_id():
            raise ValueError("The new cluster id should be at least "
                             "{0}.".format(self.new_cluster_id()))

        # NOTE: we could have called self.assign() here, but we don't.
        # We circumvent self.assign() for performance reasons.
        # assign() is a relatively costly operation, whereas merging is a much
        # cheaper operation.

        # Find all spikes in the specified clusters.
        spike_ids = _spikes_in_clusters(self.spike_clusters, cluster_ids)

        up = self._do_merge(spike_ids, cluster_ids, to)
        undo_state = self.emit('request_undo_state', up)

        # Add to stack.
        self._undo_stack.add((spike_ids, [to], undo_state))

        self.emit('cluster', up)
        return up
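
A minimal usage sketch of merge, assuming the surrounding class is a Clustering-like object built from a spike-to-cluster assignment array (hypothetical values):

import numpy as np

clustering = Clustering(np.array([0, 0, 1, 1, 2]))
up = clustering.merge([0, 1])     # merge clusters 0 and 1
# new_cluster_id() lies past the highest existing id, so the merged
# cluster gets id 3 and every spike of clusters 0 and 1 moves there.
print(clustering.spike_clusters)  # expected: [3 3 3 3 2]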
Example #5: a denser variant of the artificial dataset fixture
    def _init_data(self):
        self.cache_dir = self.config_dir
        self.n_samples_waveforms = 31
        self.n_samples_t = 20000
        self.n_channels = 11
        self.n_clusters = 4
        self.n_spikes_per_cluster = 200
        n_spikes_total = self.n_clusters * self.n_spikes_per_cluster
        n_features_per_channel = 4

        self.n_spikes = n_spikes_total
        self.sample_rate = 20000.
        self.duration = self.n_samples_t / float(self.sample_rate)
        self.spike_times = np.arange(
            0, self.duration,
            5000. / (self.sample_rate * self.n_spikes_per_cluster))
        self.spike_clusters = np.repeat(np.arange(self.n_clusters),
                                        self.n_spikes_per_cluster)
        assert len(self.spike_times) == len(self.spike_clusters)
        self.cluster_ids = np.unique(self.spike_clusters)
        self.channel_positions = staggered_positions(self.n_channels)
        self.channel_order = np.arange(self.n_channels)

        sc = self.spike_clusters
        self.spikes_per_cluster = lambda c: _spikes_in_clusters(sc, [c])
        self.spike_count = lambda c: len(self.spikes_per_cluster(c))
        self.n_features_per_channel = n_features_per_channel
        self.cluster_groups = {c: None for c in range(self.n_clusters)}

        self.all_traces = artificial_traces(self.n_samples_t, self.n_channels)
        self.all_masks = artificial_masks(n_spikes_total, self.n_channels)
        self.all_waveforms = artificial_waveforms(n_spikes_total,
                                                  self.n_samples_waveforms,
                                                  self.n_channels)
        self.all_features = artificial_features(n_spikes_total,
                                                self.n_channels,
                                                self.n_features_per_channel)
Example #6: small helpers built on _spikes_in_clusters
    def spikes_in_template(self, template_id):
        return _spikes_in_clusters(self.spike_templates, [template_id])

    def _assert_spikes(clusters):
        ae(info.spike_ids, _spikes_in_clusters(spike_clusters, clusters))
Example #7: the public spikes_in_clusters method
    def spikes_in_clusters(self, clusters):
        """Return the array of spike ids belonging to a list of clusters."""
        return _spikes_in_clusters(self.spike_clusters, clusters)