Exemplo n.º 1
0
    def get_features(self, cluster_id, load_all=False):
        """Return the normalized features of the spikes of one cluster.

        Overridden to take the sparse feature storage into account.

        Parameters
        ----------
        cluster_id : int
            Cluster whose spikes are selected with ``self._select_spikes``.
        load_all : bool
            Only consulted when ``self.features_spike_ids`` is None. If
            true, no cap is passed to ``_select_spikes``. Otherwise the
            selection is capped at ``self.n_spikes_features``.

        Returns
        -------
        Bunch
            ``data`` has shape ``(n_spikes, n_channels,
            n_features_per_channel)``. It is passed through ``_normalize``
            with bounds ``(-m, m)``, where ``m = self.get_feature_lim()``.
            The bunch also carries ``spike_ids``, ``spike_clusters`` and
            ``masks`` for the selected spikes.
        """
        subset = self.features_spike_ids
        if subset is None:
            n_max = None if load_all else self.n_spikes_features
            spike_ids = self._select_spikes(cluster_id, n_max)
            # Features exist for every spike: absolute ids are also the
            # row indices into all_features.
            rows = spike_ids
        else:
            # Features only exist for a subset of spikes. Keep the cluster's
            # spikes that belong to it, then translate them into positions
            # within the subset, which is how all_features is indexed.
            spike_ids = np.intersect1d(self._select_spikes(cluster_id),
                                       subset)
            rows = _index_of(spike_ids, subset)

        templates = self.spike_templates[spike_ids]
        n_channels = self.n_channels
        n_feat = self.n_features_per_channel
        dense = _densify(rows, self.all_features,
                         self.features_ind[templates, :], n_channels)
        # _densify yields (spikes, features, channels); callers expect
        # (spikes, channels, features), as the shape check below confirms.
        features = np.transpose(dense, (0, 2, 1))
        assert features.shape == (len(spike_ids), n_channels, n_feat)

        out = Bunch()
        lim = self.get_feature_lim()
        out.data = _normalize(features, -lim, lim)
        out.spike_ids = spike_ids
        out.spike_clusters = self.spike_clusters[spike_ids]
        out.masks = self.all_masks[spike_ids]
        return out
Exemplo n.º 2
0
    def get_features(self, cluster_id, load_all=False):
        """Load, densify and normalize the features of ``cluster_id``.

        Overridden to take the sparse feature storage into account.
        ``self.features_ind`` is indexed by each spike's template and handed
        to ``_densify`` together with ``self.all_features`` and
        ``self.n_channels`` to build a dense feature array.

        Spike selection:

        - When ``self.features_spike_ids`` is set, only the cluster's spikes
          that also appear in it are kept, and ``load_all`` is ignored.
        - Otherwise ``self._select_spikes`` is capped at
          ``self.n_spikes_features`` unless ``load_all`` is true.

        Returns a ``Bunch`` with:

        - ``data`` of shape ``(n_spikes, n_channels,
          n_features_per_channel)``, passed through ``_normalize`` with
          bounds ``(-m, m)``, where ``m = self.get_feature_lim()``;
        - ``spike_ids``, ``spike_clusters`` and ``masks``.
        """
        sparse_ids = self.features_spike_ids
        if sparse_ids is not None:
            candidates = self._select_spikes(cluster_id)
            spike_ids = np.intersect1d(candidates, sparse_ids)
            # all_features only has rows for sparse_ids, so map absolute
            # spike ids to their positions within that array.
            feature_rows = _index_of(spike_ids, sparse_ids)
        else:
            if load_all:
                spike_ids = self._select_spikes(cluster_id, None)
            else:
                spike_ids = self._select_spikes(cluster_id,
                                                self.n_spikes_features)
            feature_rows = spike_ids

        n_spikes = len(spike_ids)
        n_channels = self.n_channels
        n_per_channel = self.n_features_per_channel
        channel_ind = self.features_ind[self.spike_templates[spike_ids], :]
        dense = _densify(feature_rows, self.all_features, channel_ind,
                         n_channels)
        # Swap the last two axes to get (spikes, channels, features).
        data = np.transpose(dense, (0, 2, 1))
        assert data.shape == (n_spikes, n_channels, n_per_channel)

        result = Bunch()
        bound = self.get_feature_lim()
        result.data = _normalize(data, -bound, bound)
        result.spike_ids = spike_ids
        result.spike_clusters = self.spike_clusters[spike_ids]
        result.masks = self.all_masks[spike_ids]
        return result
Exemplo n.º 3
0
 def _select_data(self, cluster_id, arr, n_max=None):
     """Return rows of ``arr`` for spikes selected from ``cluster_id``.

     Spikes come from ``self._select_spikes(cluster_id, n_max)``. The
     returned Bunch holds those rows of ``arr`` as ``data``, the selected
     ``spike_ids``, and the matching ``spike_clusters`` and ``masks``.
     """
     selected = self._select_spikes(cluster_id, n_max)
     bunch = Bunch()
     bunch.data = arr[selected]
     bunch.spike_ids = selected
     bunch.spike_clusters = self.spike_clusters[selected]
     bunch.masks = self.all_masks[selected]
     return bunch
Exemplo n.º 4
0
 def _select_data(self, cluster_id, arr, n_max=None):
     """Gather per-spike data for a selection of a cluster's spikes.

     ``n_max`` is forwarded to ``self._select_spikes``. The result carries
     ``data`` (rows of ``arr``), ``spike_ids``, ``spike_clusters`` and
     ``masks``, all indexed by the same selected spike ids.
     """
     ids = self._select_spikes(cluster_id, n_max)
     result = Bunch()
     result.data, result.spike_ids = arr[ids], ids
     result.spike_clusters = self.spike_clusters[ids]
     result.masks = self.all_masks[ids]
     return result
Exemplo n.º 5
0
 def get_background_features(self):
     """Return raw features of an evenly strided subset of all spikes.

     One spike out of every ``k`` is kept. ``k`` is
     ``self.n_spikes // self.n_spikes_background_features``, clamped to at
     least 1. ``spike_ids`` in the result is a ``slice`` object, not an
     index array. Features are returned as stored; no normalization is
     applied.
     """
     ratio = self.n_spikes // self.n_spikes_background_features
     step = max(int(ratio), 1)
     every_kth = slice(None, None, step)
     out = Bunch()
     out.data = self.all_features[every_kth]
     out.spike_ids = every_kth
     out.spike_clusters = self.spike_clusters[every_kth]
     out.masks = self.all_masks[every_kth]
     return out
Exemplo n.º 6
0
 def get_background_features(self):
     """Subsample all spikes with a fixed stride and return their features.

     The stride is ``self.n_spikes // self.n_spikes_background_features``
     (never below 1). The same slice selects ``data`` from
     ``self.all_features`` and the matching ``spike_clusters`` and
     ``masks``. It is also stored as ``spike_ids``.
     """
     stride = int(self.n_spikes // self.n_spikes_background_features)
     if stride < 1:
         stride = 1
     selection = slice(None, None, stride)
     bunch = Bunch()
     bunch.data = self.all_features[selection]
     bunch.spike_ids = selection
     bunch.spike_clusters = self.spike_clusters[selection]
     bunch.masks = self.all_masks[selection]
     return bunch
Exemplo n.º 7
0
 def get_background_features(self):
     """Return normalized features of an evenly strided subset of spikes.

     One spike out of every ``k`` is kept. ``k`` is
     ``self.n_spikes // self.n_spikes_background_features``, clamped to at
     least 1. ``data`` is a copy of the selected features passed through
     ``_normalize`` with bounds ``(-m, +m)``, where
     ``m = self.get_feature_lim()``. ``spike_ids`` is the ``slice`` used
     for the selection.
     """
     stride = max(1, int(self.n_spikes // self.n_spikes_background_features))
     sel = slice(None, None, stride)
     raw = self.all_features[sel]
     lim = self.get_feature_lim()
     bunch = Bunch()
     # Copy first: slicing may return a view of all_features, and
     # _normalize presumably works in place -- the copy keeps the stored
     # features untouched (NOTE(review): confirm _normalize semantics).
     bunch.data = _normalize(raw.copy(), -lim, +lim)
     bunch.spike_ids = sel
     bunch.spike_clusters = self.spike_clusters[sel]
     bunch.masks = self.all_masks[sel]
     return bunch
Exemplo n.º 8
0
def extract_spikes(traces,
                   interval,
                   sample_rate=None,
                   spike_times=None,
                   spike_clusters=None,
                   cluster_groups=None,
                   all_masks=None,
                   n_samples_waveforms=None):
    """Extract the waveforms of all spikes falling within a time interval.

    Parameters
    ----------
    traces : array-like
        Traces extracted for ``interval``. Sample 0 corresponds to time
        ``interval[0]``. Forwarded to ``_extract_wave``.
    interval : pair of float
        ``(start, end)`` time window, in the same unit as ``spike_times``.
    sample_rate : float
        Converts times relative to ``interval[0]`` into sample indices.
    spike_times : array
        Spike times. Must be sorted, since they are searched with
        ``searchsorted``.
    spike_clusters : array
        Cluster of every spike, aligned with ``spike_times``.
    cluster_groups : dict, optional
        Maps a cluster id to its group. Clusters not in the dict get None.
    all_masks : array
        Per-spike channel masks, aligned with ``spike_times``.
    n_samples_waveforms : int or pair of int
        Either ``(n_before, n_after)`` samples around each spike, or a
        total length ``n``. An int is split as ``(n // 2, n - n // 2)``,
        so waveforms always have exactly ``n`` samples, even when ``n`` is
        odd.

    Returns
    -------
    list of Bunch
        One entry per successfully extracted spike, with ``waveforms``,
        ``channels``, ``masks`` (restricted to ``channels``),
        ``spike_time``, ``spike_cluster``, ``cluster_group`` and
        ``offset_samples``. Spikes for which ``_extract_wave`` returns
        None are skipped.
    """
    cluster_groups = cluster_groups or {}
    sr = sample_rate
    ns = n_samples_waveforms
    if not isinstance(ns, (tuple, list)):
        # Give the extra sample to the right side for odd lengths, so the
        # waveform length is exactly n_samples_waveforms.
        half = ns // 2
        ns = (half, ns - half)
    offset_samples = ns[0]
    wave_len = ns[0] + ns[1]

    # Find spikes within the interval (spike_times must be sorted).
    start, stop = spike_times.searchsorted(interval)
    st = spike_times[start:stop]
    sc = spike_clusters[start:stop]
    m = all_masks[start:stop]
    n = len(st)
    assert len(sc) == n
    assert m.shape[0] == n

    # Extract waveforms.
    spikes = []
    for i in range(n):
        # Find the start of the waveform in the extracted traces.
        sample_start = int(round((st[i] - interval[0]) * sr))
        sample_start -= offset_samples
        o = _extract_wave(traces, sample_start, m[i], wave_len)
        if o is None:  # pragma: no cover
            logger.debug("Unable to extract spike %d.", i)
            continue
        spike = Bunch()
        spike.waveforms, spike.channels = o
        # Masks on unmasked channels.
        spike.masks = m[i, spike.channels]
        spike.spike_time = st[i]
        spike.spike_cluster = sc[i]
        spike.cluster_group = cluster_groups.get(spike.spike_cluster, None)
        spike.offset_samples = offset_samples

        spikes.append(spike)
    return spikes
Exemplo n.º 9
0
def extract_spikes(traces, interval, sample_rate=None,
                   spike_times=None, spike_clusters=None,
                   all_masks=None,
                   n_samples_waveforms=None):
    """Extract the waveforms of all spikes falling within a time interval.

    Parameters
    ----------
    traces : array-like
        Traces extracted for ``interval``. Sample 0 corresponds to time
        ``interval[0]``. Forwarded to ``_extract_wave``.
    interval : pair of float
        ``(start, end)`` time window, in the same unit as ``spike_times``.
    sample_rate : float
        Converts times relative to ``interval[0]`` into sample indices.
    spike_times : array
        Spike times. Must be sorted, since they are searched with
        ``searchsorted``.
    spike_clusters : array
        Cluster of every spike, aligned with ``spike_times``.
    all_masks : array
        Per-spike channel masks, aligned with ``spike_times``.
    n_samples_waveforms : int or pair of int
        Either ``(n_before, n_after)`` samples around each spike, or a
        total length ``n``. An int is split as ``(n // 2, n - n // 2)``,
        so waveforms always have exactly ``n`` samples, even when ``n`` is
        odd.

    Returns
    -------
    list of Bunch
        One entry per successfully extracted spike, with ``waveforms``,
        ``channels``, ``masks`` (restricted to ``channels``),
        ``spike_time``, ``spike_cluster`` and ``offset_samples``. Spikes
        for which ``_extract_wave`` returns None are skipped.
    """
    sr = sample_rate
    ns = n_samples_waveforms
    if not isinstance(ns, (tuple, list)):
        # Give the extra sample to the right side for odd lengths, so the
        # waveform length is exactly n_samples_waveforms.
        half = ns // 2
        ns = (half, ns - half)
    offset_samples = ns[0]
    wave_len = ns[0] + ns[1]

    # Find spikes within the interval (spike_times must be sorted).
    start, stop = spike_times.searchsorted(interval)
    st = spike_times[start:stop]
    sc = spike_clusters[start:stop]
    m = all_masks[start:stop]
    n = len(st)
    assert len(sc) == n
    assert m.shape[0] == n

    # Extract waveforms.
    spikes = []
    for i in range(n):
        # Find the start of the waveform in the extracted traces.
        sample_start = int(round((st[i] - interval[0]) * sr))
        sample_start -= offset_samples
        o = _extract_wave(traces, sample_start, m[i], wave_len)
        if o is None:  # pragma: no cover
            logger.debug("Unable to extract spike %d.", i)
            continue
        spike = Bunch()
        spike.waveforms, spike.channels = o
        # Masks on unmasked channels.
        spike.masks = m[i, spike.channels]
        spike.spike_time = st[i]
        spike.spike_cluster = sc[i]
        spike.offset_samples = offset_samples

        spikes.append(spike)
    return spikes