Пример #1
0
    def get_traces(self, interval):
        """Return the traces of ``interval`` and a template-subtracted copy.

        The traces are centred by removing the mean of each column. The
        templates of the spikes whose time falls in ``interval`` (spike
        times assumed sorted, as required by ``searchsorted``) are then
        removed by ``subtract_templates``. The residual is returned second,
        drawn in semi-transparent grey.
        """
        raw = select_traces(self.all_traces, interval,
                            sample_rate=self.sample_rate)
        raw = raw - np.mean(raw, axis=0)

        first, last = self.spike_times.searchsorted(interval)
        in_interval = slice(first, last)
        # Template id of every spike in the interval; it is also passed as
        # the ``spike_clusters`` argument below.
        templates_of_spikes = self.spike_templates[in_interval]

        residual = subtract_templates(
            raw,
            start=interval[0],
            spike_times=self.spike_times[in_interval],
            spike_clusters=templates_of_spikes,
            amplitudes=self.all_amplitudes[in_interval],
            spike_templates=self.templates_unw[templates_of_spikes],
            sample_rate=self.sample_rate,
        )

        grey = (.5, .5, .5, .75)
        return [Bunch(traces=raw), Bunch(traces=residual, color=grey)]
Пример #2
0
def test_gui_state_view(tempdir):
    """A view state is only returned for the exact view name it was stored
    under."""
    stored_view = Bunch(name='MyView0')
    gui_state = GUIState(config_dir=tempdir)
    gui_state.update_view_state(stored_view, dict(hello='world'))
    for other_name in ('MyView', 'MyView1'):
        assert not gui_state.get_view_state(Bunch(name=other_name))
    assert gui_state.get_view_state(stored_view) == Bunch(hello='world')
Пример #3
0
    def get_features(self, cluster_id, load_all=False):
        """Return the normalized dense features of a cluster's spikes.

        Overridden to take into account the sparse structure: for each
        spike, ``all_features`` only holds the channels listed in
        ``features_ind`` for the spike's template. It may also hold only
        the subset of spikes listed in ``features_spike_ids``.

        When ``features_spike_ids`` is set, every spike of the cluster that
        belongs to it is used and ``load_all`` is ignored. Otherwise
        ``n_spikes_features`` is passed to ``_select_spikes`` unless
        ``load_all`` is true.
        """
        if self.features_spike_ids is None:
            n_max = None if load_all else self.n_spikes_features
            spike_ids = self._select_spikes(cluster_id, n_max)
            # Features exist for every spike: rows are the spike ids.
            rows = spike_ids
        else:
            spike_ids = np.intersect1d(self._select_spikes(cluster_id),
                                       self.features_spike_ids)
            # all_features only stores the spikes of features_spike_ids, so
            # look up the position of each spike in that array.
            rows = _index_of(spike_ids, self.features_spike_ids)

        n_spikes = len(spike_ids)
        channels_per_spike = self.features_ind[self.spike_templates[spike_ids], :]
        dense = _densify(rows, self.all_features, channels_per_spike,
                         self.n_channels)
        dense = np.transpose(dense, (0, 2, 1))
        assert dense.shape == (n_spikes, self.n_channels,
                               self.n_features_per_channel)

        lim = self.get_feature_lim()
        out = Bunch()
        out.data = _normalize(dense, -lim, lim)
        out.spike_ids = spike_ids
        out.spike_clusters = self.spike_clusters[spike_ids]
        out.masks = self.all_masks[spike_ids]
        return out
Пример #4
0
    def _get_traces(self, interval):
        """Get the traces in ``interval`` and, per spike, waveform + residual.

        The traces are returned in the column order produced by
        ``select_traces``; no vertical reordering is applied here. For each
        spike yielded by ``_iter_spike_waveforms``, ``out.waveforms`` gets
        the spike's Bunch followed by a copy whose ``data`` is
        ``waveform - amplitude * template``.
        """
        model = self.model
        n_samples = self.model.n_samples_templates
        traces_interval = select_traces(model.traces, interval,
                                        sample_rate=model.sample_rate)
        out = Bunch(data=traces_interval)
        out.waveforms = []

        def best_channels(cluster_id):
            return self.get_best_channels(cluster_id)

        spikes = _iter_spike_waveforms(
            interval=interval,
            traces_interval=traces_interval,
            model=self.model,
            supervisor=self.supervisor,
            color_selector=self.color_selector,
            n_samples_waveforms=n_samples,
            get_best_channels=best_channels,
            show_all_spikes=self._show_all_spikes,
        )
        for spike in spikes:
            spike_id = spike.spike_id
            residual = spike.copy()
            template = model.get_template(
                model.spike_templates[spike_id]).template
            residual.data = (residual.data -
                             model.amplitudes[spike_id] * template)
            out.waveforms.extend([spike, residual])
        return out
Пример #5
0
    def _get_traces(self, interval):
        """Get the traces in ``interval`` and the waveforms of its spikes.

        Trace columns are reordered with ``self.channel_vertical_order``.
        Each waveform Bunch yielded by ``_iter_spike_waveforms`` gets
        ``channel_labels`` from ``model.channel_order`` indexed by its
        ``channel_ids``.
        """
        n_samples = self.model.n_samples_waveforms
        model = self.model
        vertical_order = self.channel_vertical_order

        raw = select_traces(model.traces, interval,
                            sample_rate=model.sample_rate)
        traces_interval = raw[:, vertical_order]

        def best_channels(cluster_id):
            return self.get_best_channels(cluster_id)

        out = Bunch(data=traces_interval)
        out.waveforms = []
        spike_waveforms = _iter_spike_waveforms(
            interval=interval,
            traces_interval=traces_interval,
            model=self.model,
            supervisor=self.supervisor,
            color_selector=self.color_selector,
            n_samples_waveforms=n_samples,
            get_best_channels=best_channels,
            show_all_spikes=self._show_all_spikes,
        )
        for waveform in spike_waveforms:
            waveform.channel_labels = model.channel_order[waveform.channel_ids]
            out.waveforms.append(waveform)
        return out
Пример #6
0
def state(tempdir):
    """Return a test GUI state as a nested Bunch.

    The result maps view names to per-view settings. Nothing is saved to
    disk here: ``tempdir`` is not used in this function body and the state
    is only built in memory and returned.
    """
    state = Bunch()
    state.WaveformView0 = Bunch(overlap=False)
    state.TraceView0 = Bunch(scaling=1.)
    state.FeatureView0 = Bunch(feature_scaling=.5)
    state.CorrelogramView0 = Bunch(uniform_normalization=True)
    return state
Пример #7
0
def state(tempdir):
    """Return a test GUI state as a nested Bunch.

    The result maps view names to per-view settings. Nothing is saved to
    disk here: ``tempdir`` is not used in this function body and the state
    is only built in memory and returned.
    """
    state = Bunch()
    state.WaveformView0 = Bunch(overlap=False)
    state.TraceView0 = Bunch(scaling=1.)
    state.FeatureView0 = Bunch(feature_scaling=.5)
    state.CorrelogramView0 = Bunch(uniform_normalization=True)
    return state
Пример #8
0
 def _select_data(self, cluster_id, arr, n_max=None, batch_size=None):
     """Select spikes of a cluster and gather their rows of ``arr``.

     ``n_max`` and ``batch_size`` are forwarded to ``_select_spikes``.
     Returns a Bunch with ``data``, ``spike_ids`` and ``masks``.
     """
     selected = self._select_spikes(cluster_id, n_max,
                                    batch_size=batch_size)
     return Bunch(data=arr[selected],
                  spike_ids=selected,
                  masks=self.all_masks[selected])
Пример #9
0
 def _select_data(self, cluster_id, arr, n_max=None):
     """Select spikes of a cluster and gather their rows of ``arr``.

     ``n_max`` is forwarded to ``_select_spikes``. Returns a Bunch with
     ``data``, ``spike_ids``, ``spike_clusters`` and ``masks``.
     """
     selected = self._select_spikes(cluster_id, n_max)
     return Bunch(data=arr[selected],
                  spike_ids=selected,
                  spike_clusters=self.spike_clusters[selected],
                  masks=self.all_masks[selected])
Пример #10
0
 def get_amplitudes(self, cluster_id):
     """Return spike times (``x``) and amplitudes (``y``) of a cluster.

     ``n_spikes_features`` is forwarded to ``_select_spikes``. Every
     returned spike is labelled with ``cluster_id`` in an int32
     ``spike_clusters`` array.
     """
     ids = self._select_spikes(cluster_id, self.n_spikes_features)
     labels = cluster_id * np.ones(len(ids), dtype=np.int32)
     return Bunch(spike_ids=ids,
                  spike_clusters=labels,
                  x=self.spike_times[ids],
                  y=self.all_amplitudes[ids])
Пример #11
0
 def get_background_features(self):
     """Return the raw features of a regular subsample of all spikes.

     One spike out of ``k = max(1, n_spikes // n_spikes_background_features)``
     is kept. Note that ``spike_ids`` is the slice object itself, not an
     array of indices.
     """
     step = max(1, int(self.n_spikes // self.n_spikes_background_features))
     every_kth = slice(None, None, step)
     return Bunch(data=self.all_features[every_kth],
                  spike_ids=every_kth,
                  spike_clusters=self.spike_clusters[every_kth],
                  masks=self.all_masks[every_kth])
Пример #12
0
 def get_amplitudes(self, cluster_id):
     """Return spike times (``x``), amplitudes (``y``) and plot bounds.

     ``n_spikes_features`` is forwarded to ``_select_spikes``.
     ``data_bounds`` is ``[0, 0, self.duration, M]``, where ``M`` is the
     largest selected amplitude. When the selection is empty, ``M`` falls
     back to 1 instead of raising on ``max()`` of an empty array.
     """
     spike_ids = self._select_spikes(cluster_id, self.n_spikes_features)
     d = Bunch()
     d.spike_ids = spike_ids
     d.x = self.spike_times[spike_ids]
     d.y = self.all_amplitudes[spike_ids]
     # ndarray.max() raises ValueError on a zero-size array.
     M = d.y.max() if d.y.size > 0 else 1.
     d.data_bounds = [0, 0, self.duration, M]
     return d
Пример #13
0
 def wrapped(cluster_ids, **kwargs):
     """Apply ``f`` to one cluster, or to several and merge the results.

     A ``cluster_ids`` without ``__len__`` is treated as a single cluster
     and forwarded to ``f`` unchanged. Otherwise ``f`` is called once per
     cluster and the results are merged with ``_accumulate``. If ``f``
     returns lists, the i-th items of all lists are merged together, and
     every list is assumed to have the length of the first one.
     """
     if not hasattr(cluster_ids, '__len__'):
         return f(cluster_ids, **kwargs)
     results = [f(cluster_id, **kwargs) for cluster_id in cluster_ids]
     if not (results and isinstance(results[0], list)):
         return Bunch(_accumulate(results))
     n_items = len(results[0])
     return [Bunch(_accumulate([result[i] for result in results]))
             for i in range(n_items)]
Пример #14
0
    def validate(pos=None, text=None, anchor=None,
                 data_bounds=None,
                 ):
        """Normalize and check the arguments of a text visual.

        ``text`` may be None (no text), a single string, or a list of
        strings. ``pos`` defaults to one ``(0, 0)`` row per string and must
        become an ``(n, 2)`` array with ``n == len(text)``. ``anchor``
        defaults to ``(0, 0)``, and a single anchor row is repeated for
        every string. ``data_bounds``, when given, is resolved with
        ``_get_data_bounds`` and converted to an ``(n, 4)`` float64 array.
        """
        if text is None:
            text = []
        elif isinstance(text, string_types):
            text = [text]

        pos = np.zeros((len(text), 2)) if pos is None else pos
        pos = np.atleast_2d(pos)
        assert pos.ndim == 2
        assert pos.shape[1] == 2
        count = pos.shape[0]
        assert len(text) == count

        anchor = np.atleast_2d((0., 0.) if anchor is None else anchor)
        if anchor.shape[0] == 1:
            anchor = np.repeat(anchor, count, axis=0)
        assert anchor.ndim == 2
        assert anchor.shape == (count, 2)

        if data_bounds is not None:
            data_bounds = _get_data_bounds(data_bounds, pos)
            assert data_bounds.shape[0] == count
            data_bounds = data_bounds.astype(np.float64)
            assert data_bounds.shape == (count, 4)

        return Bunch(pos=pos, text=text, anchor=anchor,
                     data_bounds=data_bounds)
Пример #15
0
 def _get_amplitudes(self, cluster_id):
     """Return times (``x``) and amplitudes (``y``) of a cluster's spikes.

     ``self.n_spikes_amplitudes`` is passed to ``selector.select_spikes``.
     ``data_bounds`` is ``(0, 0, model.duration, max amplitude)``. For an
     empty selection the upper y bound falls back to 1, because ``max()``
     of an empty array raises.
     """
     n = self.n_spikes_amplitudes
     m = self.model
     spike_ids = self.selector.select_spikes([cluster_id], n)
     x = m.spike_times[spike_ids]
     y = m.amplitudes[spike_ids]
     y_max = y.max() if y.size > 0 else 1.
     return Bunch(x=x, y=y, data_bounds=(0., 0., m.duration, y_max))
Пример #16
0
 def state(self):
     """Return the view state as a Bunch.

     ``box_scaling`` and ``probe_scaling`` are copied into tuples.
     """
     snapshot = Bunch()
     snapshot.box_scaling = tuple(self.box_scaling)
     snapshot.probe_scaling = tuple(self.probe_scaling)
     snapshot.overlap = self.overlap
     snapshot.do_zoom_on_channels = self.do_zoom_on_channels
     snapshot.do_show_labels = self.do_show_labels
     return snapshot
Пример #17
0
 def get_traces(self, interval):
     """Return the raw traces of ``interval`` wrapped in a one-item list."""
     traces = select_traces(self.all_traces, interval,
                            sample_rate=self.sample_rate)
     return [Bunch(traces=traces)]
Пример #18
0
    def validate(
        x=None,
        y=None,
        pos=None,
        color=None,
        size=None,
        depth=None,
        data_bounds='auto',
    ):
        """Normalize and check the arguments of a scatter visual.

        Positions come from ``pos`` (an ``(n, 2)`` array) or, when ``pos``
        is None, from ``x`` and ``y`` stacked column-wise. ``color``,
        ``size`` and ``depth`` are passed through ``_get_array`` with
        shapes ``(n, 4)``, ``(n, 1)`` and ``(n, 1)`` and the visual's
        defaults. ``data_bounds`` is resolved by ``_get_data_bounds``
        unless it is None.
        """
        if pos is None:
            xs, ys = _get_pos(x, y)
            pos = np.c_[xs, ys]
        pos = np.asarray(pos)
        assert pos.ndim == 2 and pos.shape[1] == 2
        n_points = pos.shape[0]

        color = _get_array(color, (n_points, 4),
                           ScatterVisual._default_color,
                           dtype=np.float32)
        size = _get_array(size, (n_points, 1),
                          ScatterVisual._default_marker_size)
        depth = _get_array(depth, (n_points, 1), 0)
        if data_bounds is not None:
            data_bounds = _get_data_bounds(data_bounds, pos)
            assert data_bounds.shape[0] == n_points

        return Bunch(pos=pos, color=color, size=size, depth=depth,
                     data_bounds=data_bounds)
Пример #19
0
 def state(self):
     """Return the view state as a Bunch.

     NOTE(review): the key is ``time_range`` while the attribute read is
     ``self.time_ranges``. If the state is restored by setting attributes
     named after the keys, these will not match; confirm this is intended.
     """
     snapshot = Bunch()
     snapshot.speed_threshold = self.speed_threshold
     snapshot.speed_threshold_mode = self.speed_threshold_mode
     snapshot.time_range = self.time_ranges
     snapshot.n_rate_map_contours = self.n_rate_map_contours
     snapshot.rate_map_contour_mode = self.rate_map_contour_mode
     snapshot.spike_pos_shift = self.spike_pos_shift
     return snapshot
Пример #20
0
    def validate(
        x=None,
        y=None,
        pos=None,
        masks=None,
        data_bounds='auto',
    ):
        """Normalize and check positions, masks and data bounds.

        Positions come from ``pos`` (an ``(n, 2)`` array) or, when ``pos``
        is None, from ``x`` and ``y`` stacked column-wise. ``masks`` is
        passed to ``_get_array`` with shape ``(n, 1)``, value ``1.`` and
        dtype float32. ``data_bounds`` is resolved by ``_get_data_bounds``
        unless it is None.
        """
        if pos is None:
            xs, ys = _get_pos(x, y)
            pos = np.c_[xs, ys]
        pos = np.asarray(pos)
        assert pos.ndim == 2 and pos.shape[1] == 2
        n_points = pos.shape[0]

        masks = _get_array(masks, (n_points, 1), 1., np.float32)
        assert masks.shape == (n_points, 1)

        if data_bounds is not None:
            data_bounds = _get_data_bounds(data_bounds, pos)
            assert data_bounds.shape[0] == n_points

        return Bunch(pos=pos, masks=masks, data_bounds=data_bounds)
Пример #21
0
    def validate(hist=None, color=None, ylim=None):
        """Normalize and check the arguments of a histogram visual.

        ``hist`` becomes a 2-D float64 array ``(n_hists, n_bins)``; a single
        1-D histogram is promoted to one row. ``color`` is passed through
        ``_get_array`` with shape ``(n_hists, 4)``. ``ylim`` defaults to the
        maximum of ``hist`` (1 when ``hist`` is empty) and becomes an
        ``(n_hists, 1)`` column, with a single value repeated for every
        histogram.
        """
        assert hist is not None
        hist = np.asarray(hist, np.float64)
        if hist.ndim == 1:
            hist = hist[None, :]
        assert hist.ndim == 2
        n_hists, _ = hist.shape

        color = _get_array(color, (n_hists, 4),
                           HistogramVisual._default_color,
                           dtype=np.float32)

        if ylim is None:
            ylim = hist.max() if hist.size else 1.
        limits = np.atleast_1d(ylim)
        if len(limits) == 1:
            limits = np.tile(limits, n_hists)
        if limits.ndim == 1:
            limits = limits[:, np.newaxis]
        assert limits.shape == (n_hists, 1)

        return Bunch(hist=hist, ylim=limits, color=color)
Пример #22
0
 def _get_template_waveforms(self, cluster_id):
     """Return the waveforms of the templates a cluster stems from.

     Every template with a nonzero count for the cluster is multiplied by
     the cluster's mean amplitude and expanded with ``from_sparse`` onto
     the cluster's best channels. ``masks`` gives each template its count
     divided by the largest count, repeated on every channel.
     """
     positions = self.model.channel_positions
     counts = self.get_template_counts(cluster_id)
     template_ids = np.nonzero(counts)[0]
     counts = counts[template_ids]
     channel_ids = self.get_best_channels(cluster_id)

     # Relative weight of each template: one row per template, one column
     # per best channel.
     weights = (counts / float(counts.max())).reshape((-1, 1))
     masks = np.tile(weights, (1, len(channel_ids)))

     mean_amp = self._get_amplitudes(cluster_id).y.mean()
     templates = [self.model.get_template(tid) for tid in template_ids]
     data = np.stack([tpl.template * mean_amp for tpl in templates], axis=0)
     cols = np.stack([tpl.channel_ids for tpl in templates], axis=0)

     # from_sparse wants the channels on axis 1: swap the last two axes,
     # densify onto channel_ids, then swap them back.
     data = data.transpose((0, 2, 1))
     assert data.ndim == 3
     assert data.shape[1] == cols.shape[1]
     waveforms = from_sparse(data, cols, channel_ids).transpose((0, 2, 1))

     return Bunch(data=waveforms,
                  channel_ids=channel_ids,
                  channel_positions=positions[channel_ids],
                  masks=masks,
                  alpha=1.,
                  )
Пример #23
0
    def _get_template_features(self, cluster_ids):
        """Return the template features of the spikes of a pair of clusters.

        For the spikes of each cluster, the template features are averaged
        over axis 1 twice. The average weighted by ``clu0``'s template
        counts gives the x coordinate, and the one weighted by ``clu1``'s
        counts gives y.

        Returns a Bunch with ``x0, y0`` (spikes of ``clu0``), ``x1, y1``
        (spikes of ``clu1``) and ``data_bounds = (xmin, ymin, xmax, ymax)``
        covering both clusters.
        """
        assert len(cluster_ids) == 2
        clu0, clu1 = cluster_ids

        s0 = self._get_spike_ids(clu0)
        s1 = self._get_spike_ids(clu1)

        n0 = self.get_template_counts(clu0)
        n1 = self.get_template_counts(clu1)

        t0 = self.model.get_template_features(s0)
        t1 = self.model.get_template_features(s1)

        x0 = np.average(t0, weights=n0, axis=1)
        y0 = np.average(t0, weights=n1, axis=1)

        x1 = np.average(t1, weights=n0, axis=1)
        y1 = np.average(t1, weights=n1, axis=1)

        # (xmin, ymin, xmax, ymax) over both clusters; xmax must come from
        # the x values, not the y values.
        return Bunch(x0=x0, y0=y0, x1=x1, y1=y1,
                     data_bounds=(min(x0.min(), x1.min()),
                                  min(y0.min(), y1.min()),
                                  max(x0.max(), x1.max()),
                                  max(y0.max(), y1.max()),
                                  ),
                     )
Пример #24
0
    def _get_axis_data(self, bunch, dim, cluster_id=None, load_all=None):
        """Extract the points of the data on a given dimension.

        ``bunch`` is returned by the features() function. ``dim`` is a key
        of ``self.attributes`` (delegated to that callable), or a channel
        index followed by a principal-component letter such as ``'0A'``.
        The channel index is relative to ``self.channel_ids``, wrapped
        modulo their number, and the letter selects the component
        (A=0, B=1, ...). Returns None when the channel is not in
        ``bunch.channel_ids``.
        """
        if dim in self.attributes:
            return self.attributes[dim](cluster_id, load_all=load_all)
        masks = bunch.get('masks', None)
        # Channel relative index, wrapped around the selected channels.
        relative = int(dim[:-1])
        channel_id = self.channel_ids[relative % len(self.channel_ids)]
        # Skip the plot if the channel id is not displayed.
        if channel_id not in bunch.channel_ids:  # pragma: no cover
            return None
        column = list(bunch.channel_ids).index(channel_id)
        component = 'ABCDEFGHIJ'.index(dim[-1])
        if masks is not None:
            masks = masks[:, column]
        return Bunch(data=bunch.data[:, column, component], masks=masks)
Пример #25
0
 def state(self):
     """Return the view state: scaling, origin, interval and label flag."""
     snapshot = Bunch()
     snapshot.scaling = self.scaling
     snapshot.origin = self.origin
     snapshot.interval = self._interval
     snapshot.do_show_labels = self.do_show_labels
     return snapshot
Пример #26
0
def waveform_loader(request):
    """Yield a WaveformLoader over artificial traces, plus test metadata.

    ``request.param`` is a ``(scale_factor, dc_offset)`` pair. The function
    also checks that building a WaveformLoader from traces alone raises
    ValueError.
    """
    scale_factor, dc_offset = request.param

    n_samples_trace = 1000
    n_channels = 5
    half_width = 10
    n_samples_waveforms = 2 * half_width
    n_spikes = n_samples_trace // (2 * n_samples_waveforms)

    traces = artificial_traces(n_samples_trace, n_channels)
    spike_samples = artificial_spike_samples(
        n_spikes, max_isi=2 * n_samples_waveforms)

    with raises(ValueError):
        WaveformLoader(traces)

    loader = WaveformLoader(traces=traces,
                            n_samples_waveforms=n_samples_waveforms,
                            scale_factor=scale_factor,
                            dc_offset=dc_offset)
    yield Bunch(loader=loader,
                n_samples_waveforms=n_samples_waveforms,
                n_spikes=n_spikes,
                spike_samples=spike_samples)
Пример #27
0
 def state(self):
     """Return the view state: bin/window sizes, excerpts, normalization."""
     snapshot = Bunch()
     snapshot.bin_size = self.bin_size
     snapshot.window_size = self.window_size
     snapshot.excerpt_size = self.excerpt_size
     snapshot.n_excerpts = self.n_excerpts
     snapshot.uniform_normalization = self.uniform_normalization
     return snapshot
Пример #28
0
 def _get_amplitudes(self, cluster_id):
     """Return times (``x``) and best-channel amplitudes (``y``) of a cluster.

     ``self.n_spikes_amplitudes`` is passed to ``selector.select_spikes``.
     Amplitudes are read on the cluster's best channel. ``data_bounds`` is
     ``(0, min, model.duration, max)`` of the amplitudes. For an empty
     selection it falls back to ``(0, 0, duration, 1)``, because
     ``min()``/``max()`` of an empty array raise.
     """
     n = self.n_spikes_amplitudes
     m = self.model
     spike_ids = self.selector.select_spikes([cluster_id], n)
     channel_id = self.get_best_channel(cluster_id)
     x = m.spike_times[spike_ids]
     y = m.amplitudes[spike_ids, channel_id]
     if y.size > 0:
         y_min, y_max = y.min(), y.max()
     else:
         y_min, y_max = 0., 1.
     return Bunch(x=x, y=y, data_bounds=(0., y_min, m.duration, y_max))
Пример #29
0
    def state(self):
        """Return the view state; the base implementation is empty.

        The returned Bunch is automatically persisted in the GUI state when
        the GUI is closed. Subclasses override this to store their settings.
        """
        empty_state = Bunch()
        return empty_state
Пример #30
0
 def get_traces(interval):
     """Return the traces in ``interval`` and one waveform per spike in it.

     Closure over ``traces``, ``sr`` (sample rate), ``st``/``sc`` (spike
     times and clusters; ``st`` must be sorted for ``searchsorted``) and
     ``cs`` (queried with ``get(cluster)`` for a color). Each waveform takes
     40 rows of ``traces`` around the spike sample, with ``channel_ids``
     0..4.

     NOTE(review): for a spike closer than 20 samples to the start of
     ``traces``, the slice start is negative and wraps around; confirm
     this cannot happen with the test data.
     """
     out = Bunch(data=select_traces(traces, interval, sample_rate=sr),
                 color=(.75,) * 4,
                 )
     first, last = st.searchsorted(interval)
     out.waveforms = []
     half = 20
     for spike_idx in range(first, last):
         spike_time = st[spike_idx]
         cluster = sc[spike_idx]
         center = int(round(spike_time * sr))
         out.waveforms.append(Bunch(
             data=traces[center - half:center + half, :],
             start_time=spike_time - half / sr,
             color=cs.get(cluster),
             channel_ids=np.arange(5),
             cluster_id=cluster,
         ))
     return out
Пример #31
0
    def get_cluster_pair_features(self, ci, cj):
        """Return template-feature coordinates of the spikes of two clusters.

        For each cluster's selected spikes, ``x`` is the template-feature
        average over axis 1 weighted by ``get_cluster_templates(ci)``, and
        ``y`` the one weighted by ``get_cluster_templates(cj)``. Returns one
        Bunch per cluster, ``ci`` first.
        """
        si = self._select_spikes(ci, self.n_spikes_features)
        sj = self._select_spikes(cj, self.n_spikes_features)

        weights_i = self.get_cluster_templates(ci)
        weights_j = self.get_cluster_templates(cj)

        def project(spike_ids):
            feats = self._get_template_features(spike_ids)
            return Bunch(x=np.average(feats, weights=weights_i, axis=1),
                         y=np.average(feats, weights=weights_j, axis=1),
                         spike_ids=spike_ids)

        return [project(si), project(sj)]
Пример #32
0
 def get_traces(interval):
     """Return the traces in ``interval`` and one waveform per spike in it.

     Closure over ``traces``, ``sr`` (sample rate), ``st``/``sc`` (spike
     times and clusters; ``st`` must be sorted for ``searchsorted``) and
     ``cs`` (queried with ``get(cluster)`` for a color). Each waveform takes
     40 rows of ``traces`` around the spike sample, with ``channel_ids``
     0..4, and records the spike index and cluster.

     NOTE(review): for a spike closer than 20 samples to the start of
     ``traces``, the slice start is negative and wraps around; confirm
     this cannot happen with the test data.
     """
     out = Bunch(data=select_traces(traces, interval, sample_rate=sr),
                 color=(.75,) * 4,
                 )
     first, last = st.searchsorted(interval)
     out.waveforms = []
     half = 20
     for spike_idx in range(first, last):
         spike_time = st[spike_idx]
         cluster = sc[spike_idx]
         center = int(round(spike_time * sr))
         out.waveforms.append(Bunch(
             data=traces[center - half:center + half, :],
             start_time=spike_time - half / sr,
             color=cs.get(cluster),
             channel_ids=np.arange(5),
             spike_id=spike_idx,
             spike_cluster=cluster,
         ))
     return out
Пример #33
0
 def get_waveforms(self, cluster_id):
     """Return the waveforms (when available) and the templates of a cluster.

     Returns ``[waveforms_b, template_b]`` when ``self.all_waveforms`` is
     not None, otherwise ``[template_b]``.

     * ``waveforms_b`` comes from ``_select_data`` (with
       ``n_spikes_waveforms``). It is restricted to the channels whose mean
       is nonzero, each spike's mean is removed, and it is rescaled with
       ``_normalize`` between the limits from ``get_waveform_lims()``.
     * ``template_b`` holds the templates with a nonzero entry in
       ``get_cluster_templates(cluster_id)``, multiplied by the cluster's
       mean amplitude and by ``2 / (M - m)``. Templates are not
       mean-centred.
     """
     m, M = self.get_waveform_lims()
     if self.all_waveforms is not None:
         # Waveforms.
         waveforms_b = self._select_data(
             cluster_id,
             self.all_waveforms,
             self.n_spikes_waveforms,
         )
         w = waveforms_b.data
         # Sparsify: keep the channels (last axis) whose mean over the two
         # other axes is nonzero. Presumably w is
         # (n_spikes, n_samples, n_channels) -- TODO confirm.
         channels = np.nonzero(w.mean(axis=1).mean(axis=0))[0]
         w = w[:, :, channels]
         waveforms_b.channels = channels
         # Normalize: subtract each spike's mean (over samples and kept
         # channels) in float64, then rescale between m and M.
         mean = w.mean(axis=1).mean(axis=1)
         w = w.astype(np.float64)
         w -= mean[:, np.newaxis, np.newaxis]
         w = _normalize(w, m, M)
         waveforms_b.data = w
         waveforms_b.cluster_id = cluster_id
         waveforms_b.tag = 'waveforms'
     else:
         waveforms_b = None
     # Find the templates corresponding to the cluster (nonzero entries).
     template_ids = np.nonzero(self.get_cluster_templates(cluster_id))[0]
     # Templates.
     templates = self.templates_unw[template_ids]
     assert templates.ndim == 3
     # Masks: one row per template.
     masks = self.template_masks[template_ids]
     assert masks.ndim == 2
     assert templates.shape[0] == masks.shape[0]
     # Find mean amplitude (n_spikes_waveforms_lim is passed to
     # _select_spikes).
     spike_ids = self._select_spikes(cluster_id,
                                     self.n_spikes_waveforms_lim)
     mean_amp = self.all_amplitudes[spike_ids].mean()
     # Scale the templates. Unlike the waveforms above, they are not
     # mean-centred. The 2 / (M - m) factor presumably matches the range
     # _normalize maps the waveforms to -- TODO confirm.
     templates = templates.astype(np.float64).copy()
     templates *= mean_amp
     templates *= 2. / (M - m)
     template_b = Bunch(
         data=templates,
         masks=masks,
         alpha=1.,
         cluster_id=cluster_id,
         tag='templates',
     )
     if waveforms_b is not None:
         return [waveforms_b, template_b]
     else:
         return [template_b]
Пример #34
0
    def get_cluster_pair_features(self, ci, cj):
        """Return stacked template-feature coordinates of two clusters.

        For each cluster's selected spikes, ``x`` is the template-feature
        sum over axis 1 weighted by ``get_cluster_templates(ci)`` and
        divided by the weights' sum. ``y`` is the same using
        ``get_cluster_templates(cj)``. The spikes of ``ci`` come first in
        every stacked array.
        """
        si = self._select_spikes(ci, self.n_spikes_features)
        sj = self._select_spikes(cj, self.n_spikes_features)

        ni = self.get_cluster_templates(ci)
        nj = self.get_cluster_templates(cj)

        def weighted_mean(feats, weights):
            return (np.sum(feats * weights[np.newaxis, :], axis=1) /
                    weights.sum())

        feats_i = self._get_template_features(si)
        x0 = weighted_mean(feats_i, ni)
        y0 = weighted_mean(feats_i, nj)

        feats_j = self._get_template_features(sj)
        x1 = weighted_mean(feats_j, ni)
        y1 = weighted_mean(feats_j, nj)

        all_spikes = np.hstack((si, sj))
        return Bunch(x=np.hstack((x0, x1)),
                     y=np.hstack((y0, y1)),
                     spike_ids=all_spikes,
                     spike_clusters=self.spike_clusters[all_spikes])
Пример #35
0
    def get_features(self, cluster_id, load_all=False):
        """Return the normalized dense features of a cluster's spikes.

        Overridden to take into account the sparse structure: for each
        spike, ``all_features`` only holds the channels listed in
        ``features_ind`` for the spike's template. It may also hold only
        the subset of spikes listed in ``features_spike_ids``.

        When ``features_spike_ids`` is set, every spike of the cluster that
        belongs to it is used and ``load_all`` is ignored. Otherwise
        ``n_spikes_features`` is passed to ``_select_spikes`` unless
        ``load_all`` is true.
        """
        if self.features_spike_ids is None:
            n_max = None if load_all else self.n_spikes_features
            spike_ids = self._select_spikes(cluster_id, n_max)
            # Features exist for every spike: rows are the spike ids.
            rows = spike_ids
        else:
            spike_ids = np.intersect1d(self._select_spikes(cluster_id),
                                       self.features_spike_ids)
            # all_features only stores the spikes of features_spike_ids, so
            # look up the position of each spike in that array.
            rows = _index_of(spike_ids, self.features_spike_ids)

        n_spikes = len(spike_ids)
        channels_per_spike = self.features_ind[self.spike_templates[spike_ids], :]
        dense = _densify(rows, self.all_features, channels_per_spike,
                         self.n_channels)
        dense = np.transpose(dense, (0, 2, 1))
        assert dense.shape == (n_spikes, self.n_channels,
                               self.n_features_per_channel)

        lim = self.get_feature_lim()
        out = Bunch()
        out.data = _normalize(dense, -lim, lim)
        out.spike_ids = spike_ids
        out.spike_clusters = self.spike_clusters[spike_ids]
        out.masks = self.all_masks[spike_ids]
        return out
Пример #36
0
 def get_background_features(self):
     """Return normalized features of a regular subsample of all spikes.

     One spike out of ``k = max(1, n_spikes // n_spikes_background_features)``
     is kept. Features are normalized on a copy between
     ``-get_feature_lim()`` and ``+get_feature_lim()``. ``spike_ids`` is the
     slice object itself, not an array of indices.
     """
     step = max(1, int(self.n_spikes // self.n_spikes_background_features))
     every_kth = slice(None, None, step)
     features = self.all_features[every_kth]
     lim = self.get_feature_lim()
     # Normalize a copy: a basic slice of an ndarray is a view of it.
     return Bunch(data=_normalize(features.copy(), -lim, +lim),
                  spike_ids=every_kth,
                  spike_clusters=self.spike_clusters[every_kth],
                  masks=self.all_masks[every_kth])
Пример #37
0
 def get_amplitudes(self, cluster_id):
     """Return spike times (``x``) and amplitudes (``y``) of a cluster.

     ``n_spikes_features`` is forwarded to ``_select_spikes``. Every
     returned spike is labelled with ``cluster_id`` in an int32
     ``spike_clusters`` array.
     """
     ids = self._select_spikes(cluster_id, self.n_spikes_features)
     labels = cluster_id * np.ones(len(ids), dtype=np.int32)
     return Bunch(spike_ids=ids,
                  spike_clusters=labels,
                  x=self.spike_times[ids],
                  y=self.all_amplitudes[ids])
Пример #38
0
 def _get_waveforms(self, cluster_id):
     """Return a selection of waveforms of a cluster on its best channels.

     Spikes are chosen by ``self.selector`` using ``n_spikes_waveforms``
     and ``batch_size_waveforms``. A single scalar, the mean of the whole
     extracted array, is subtracted (not a per-channel mean).
     """
     positions = self.model.channel_positions
     spike_ids = self.selector.select_spikes(
         [cluster_id],
         self.n_spikes_waveforms,
         self.batch_size_waveforms,
     )
     channel_ids = self.get_best_channels(cluster_id)
     waveforms = self.model.get_waveforms(spike_ids, channel_ids)
     waveforms = waveforms - waveforms.mean()
     return Bunch(data=waveforms,
                  channel_ids=channel_ids,
                  channel_positions=positions[channel_ids])
Пример #39
0
def extract_spikes(traces, interval, sample_rate=None,
                   spike_times=None, spike_clusters=None,
                   all_masks=None,
                   n_samples_waveforms=None):
    """Extract the waveforms of the spikes whose time falls in ``interval``.

    ``traces`` is indexed in samples relative to ``interval[0]``.
    ``n_samples_waveforms`` is either a ``(before, after)`` tuple or an int
    ``n``, which is split as ``(n // 2, n // 2)``. ``spike_times`` must be
    sorted, as required by ``searchsorted``.

    Returns a list of Bunch, one per spike that ``_extract_wave`` could
    extract, with ``waveforms``, ``channels``, ``masks`` (restricted to
    those channels), ``spike_time``, ``spike_cluster`` and
    ``offset_samples``.
    """
    if isinstance(n_samples_waveforms, tuple):
        before, after = n_samples_waveforms[0], n_samples_waveforms[1]
    else:
        before = after = n_samples_waveforms // 2
    offset_samples = before
    wave_len = before + after

    # Spikes inside the interval.
    start, stop = spike_times.searchsorted(interval)
    times = spike_times[start:stop]
    clusters = spike_clusters[start:stop]
    masks = all_masks[start:stop]
    n = len(times)
    assert len(clusters) == n
    assert masks.shape[0] == n

    spikes = []
    for i, spike_time in enumerate(times):
        # Waveform start in the extracted traces, in samples.
        center = int(round((spike_time - interval[0]) * sample_rate))
        sample_start = center - offset_samples
        extracted = _extract_wave(traces, sample_start, masks[i], wave_len)
        if extracted is None:  # pragma: no cover
            logger.debug("Unable to extract spike %d.", i)
            continue
        waveforms, channels = extracted
        spikes.append(Bunch(waveforms=waveforms,
                            channels=channels,
                            masks=masks[i, channels],
                            spike_time=spike_time,
                            spike_cluster=clusters[i],
                            offset_samples=offset_samples))
    return spikes