Example 1
File: color.py Project: zsong30/phy
def _selected_cluster_idx(selected_clusters, cluster_ids):
    """Return the indices, within `cluster_ids`, of the selected clusters that
    appear in it, together with their positions within `selected_clusters`."""
    selected_clusters = np.asarray(selected_clusters, dtype=np.int32)
    cluster_ids = np.asarray(cluster_ids, dtype=np.int32)
    kept = np.isin(selected_clusters, cluster_ids)
    clu_idx = _index_of(selected_clusters[kept], cluster_ids)
    cmap_idx = np.arange(len(selected_clusters))[kept]
    return clu_idx, cmap_idx
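
All of these snippets revolve around `_index_of(arr, lookup)`, which maps every value of `arr` to its position within `lookup`. A minimal sketch of the behavior the examples rely on (assuming non-negative integer ids; the actual phylib implementation may differ):

import numpy as np

def _index_of_sketch(arr, lookup):
    # Build a table mapping each id in `lookup` to its position, then
    # gather: assumes every value of `arr` occurs in `lookup`.
    table = np.zeros(lookup.max() + 1, dtype=np.int64)
    table[lookup] = np.arange(len(lookup))
    return table[arr]

# _index_of_sketch(np.array([5, 3, 5]), np.array([3, 5])) -> array([1, 0, 1])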
Example 2
 def _get_box_index(self):
     """Return, for every spike, its row in the raster plot. This depends on
     the ordering in self.all_cluster_ids."""
     cl = self.spike_clusters[self.spike_ids]
     # Sanity check.
     # assert np.all(np.isin(cl, self.all_cluster_ids))
     return _index_of(cl, self.all_cluster_ids)
Example 3
def _add_selected_clusters_colors(selected_clusters, cluster_ids,
                                  cluster_colors):
    """Take an array with colors of clusters as input, and add colors of selected clusters."""
    # Find the index of the selected clusters within the self.cluster_ids.
    clu_idx = _index_of(selected_clusters, cluster_ids)
    # Get the colors of the selected clusters.
    colormap = _categorical_colormap(colormaps.default, clu_idx)
    # Inject those colors in cluster_colors.
    cluster_colors[clu_idx] = add_alpha(colormap, 1)
    return cluster_colors
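
A hypothetical usage sketch (the arrays below are illustrative, not from phy):

import numpy as np

cluster_ids = np.array([2, 3, 5, 7])
# Four clusters, all initially grey (RGBA rows).
cluster_colors = np.tile([[.5, .5, .5, 1.]], (4, 1))
# Selecting clusters 3 and 7 would overwrite rows 1 and 3 of
# `cluster_colors` with the first two colors of the default colormap:
# _add_selected_clusters_colors(np.array([3, 7]), cluster_ids, cluster_colors)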
Example 4
 def _get_box_index(self, bunch):
     """Get the box_index array for a cluster."""
     # Generate the box index (channel_idx, cluster_idx) per vertex.
     n_samples, nc = bunch.template.shape
     box_index = _index_of(bunch.channel_ids, self.channel_ids)
     box_index = np.repeat(box_index, n_samples)
     box_index = np.c_[box_index.reshape((-1, 1)),
                       bunch.cluster_idx * np.ones(
                           (n_samples * len(bunch.channel_ids), 1))]
     assert box_index.shape == (len(bunch.channel_ids) * n_samples, 2)
     assert box_index.size == bunch.template.size * 2
     return box_index
Example 5
def _categorical_colormap(colormap, values, vmin=None, vmax=None):
    """Convert integer values into colors, cycling through a categorical colormap."""
    assert np.issubdtype(values.dtype, np.integer)
    assert colormap.shape[1] == 3
    n = colormap.shape[0]
    if vmin is None and vmax is None:
        # Find unique values and keep the order.
        _, idx = np.unique(values, return_index=True)
        lookup = values[np.sort(idx)]
        x = _index_of(values, lookup)
    else:
        x = values
    return colormap[x % n, :]
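
For instance, with a hypothetical 3-color map, unordered integer values are first re-indexed by order of first appearance, then wrapped with `% n`:

import numpy as np

colormap = np.array([[1., 0., 0.],
                     [0., 1., 0.],
                     [0., 0., 1.]])
values = np.array([10, 10, 42, 7, 10])
# lookup (order of first appearance) = [10, 42, 7]
# x = _index_of(values, lookup)     -> [0, 0, 1, 2, 0]
# colormap[x % 3]                   -> red, red, green, blue, red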
Example 6
    def _plot_cluster(self, bunch):
        wave = bunch.data
        if wave is None or not wave.size:
            return
        channel_ids_loc = bunch.channel_ids

        n_channels = len(channel_ids_loc)
        masks = bunch.get('masks', np.ones((wave.shape[0], n_channels)))
        # By default, this is 0, 1, 2 for the first 3 clusters.
        # But it can be customized when displaying several sets
        # of waveforms per cluster.

        n_spikes_clu, n_samples = wave.shape[:2]
        assert wave.shape[2] == n_channels
        assert masks.shape == (n_spikes_clu, n_channels)

        # Find the x coordinates.
        t = get_linear_x(n_spikes_clu * n_channels, n_samples)
        t = _overlap_transform(t,
                               offset=bunch.offset,
                               n=bunch.n_clu,
                               overlap=self.overlap)
        # HACK: on the GPU, we get the actual masks with fract(masks)
        # since we add the relative cluster index. We need to ensure
        # that the masks is never 1.0, otherwise it is interpreted as
        # 0.
        masks *= .99999
        # NOTE: we add the cluster index which is used for the
        # computation of the depth on the GPU.
        masks += bunch.index

        # Generate the box index (one number per channel).
        box_index = _index_of(channel_ids_loc, self.channel_ids)
        box_index = np.repeat(box_index, n_samples)
        box_index = np.tile(box_index, n_spikes_clu)
        assert box_index.shape == (n_spikes_clu * n_channels * n_samples, )

        # Generate the waveform array.
        wave = np.transpose(wave, (0, 2, 1))
        wave = wave.reshape((n_spikes_clu * n_channels, n_samples))

        self.waveform_visual.add_batch_data(x=t,
                                            y=wave,
                                            color=bunch.color,
                                            masks=masks,
                                            box_index=box_index,
                                            data_bounds=self.data_bounds)
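
The mask/depth packing used above amounts to storing two numbers in one float; a rough NumPy illustration (not phy's shader code):

import numpy as np

masks = np.array([0., .5, 1.])  # per-channel masks in [0, 1]
cluster_index = 2               # relative index of the cluster

encoded = masks * .99999 + cluster_index
# On the GPU, fract(encoded) recovers the mask (never exactly 1.0, hence
# the .99999 factor) and floor(encoded) recovers the cluster index.
mask_back, index_back = np.modf(encoded)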
Example 7
File: color.py Project: zsong30/phy
def _categorical_colormap(colormap,
                          values,
                          vmin=None,
                          vmax=None,
                          categorize=None):
    """Convert values into colors given a specified categorical colormap."""
    assert np.issubdtype(values.dtype, np.integer)
    assert colormap.shape[1] == 3
    n = colormap.shape[0]
    if categorize is True or (categorize is None and vmin is None
                              and vmax is None):
        # Find unique values and keep the order.
        _, idx = np.unique(values, return_index=True)
        lookup = values[np.sort(idx)]
        x = _index_of(values, lookup)
    else:
        x = values
    return colormap[x % n, :]
Example 8
File: ccg.py Project: CINPLA/phylib
def firing_rate(spike_clusters,
                cluster_ids=None,
                bin_size=None,
                duration=None):
    """Compute the average number of spikes per cluster per bin."""

    # Take the cluster order into account.
    if cluster_ids is None:
        clusters = _unique(spike_clusters)
    else:
        clusters = _as_array(cluster_ids)

    # Like spike_clusters, but with 0..n_clusters-1 indices.
    spike_clusters_i = _index_of(spike_clusters, clusters)

    assert duration > 0
    assert bin_size > 0
    bc = np.bincount(spike_clusters_i)
    # Outer product of the per-cluster spike counts, scaled by bin_size / duration.
    return bc * np.c_[bc] * (bin_size / duration)
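
A small worked example of the arithmetic (values made up): with bin counts bc = [3, 2, 1], the result is the outer product of bc with itself, scaled by bin_size / duration:

import numpy as np

bc = np.array([3, 2, 1])
bin_size, duration = .001, 10.
rate = bc * np.c_[bc] * (bin_size / duration)
# rate[i, j] == bc[i] * bc[j] * bin_size / duration; rate.shape == (3, 3)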
Example 9
File: color.py Project: zsong30/phy
def spike_colors(spike_clusters, cluster_ids):
    """Return the colors of spikes according to the index of their cluster within `cluster_ids`.

    Parameters
    ----------

    spike_clusters : array-like
        The spike-cluster assignments.
    cluster_ids : array-like
        The set of unique selected cluster ids appearing in spike_clusters, in a given order.

    Returns
    -------

    spike_colors : array-like
        For each spike, the RGBA color (in [0,1]) depending on the index of the cluster within
        `cluster_ids`.

    """
    spike_clusters_idx = _index_of(spike_clusters, cluster_ids)
    return add_alpha(colormaps.default[np.mod(spike_clusters_idx,
                                              colormaps.default.shape[0])])
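
A hypothetical call (three selected clusters, five spikes):

import numpy as np

spike_clusters = np.array([5, 2, 5, 7, 2])
cluster_ids = np.array([2, 5, 7])
# colors = spike_colors(spike_clusters, cluster_ids)
# colors.shape == (5, 4): one RGBA row per spike; spikes of cluster 2 get
# colormap color 0, cluster 5 color 1, cluster 7 color 2.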
Example 10
def firing_rate(spike_clusters,
                cluster_ids=None,
                bin_size=None,
                duration=None):
    """Compute the average number of spikes per cluster per bin."""

    # Take the cluster order into account.
    if cluster_ids is None:
        cluster_ids = _unique(spike_clusters)
    else:
        cluster_ids = _as_array(cluster_ids)

    # Like spike_clusters, but with 0..n_clusters-1 indices.
    spike_clusters_i = _index_of(spike_clusters, cluster_ids)

    assert bin_size > 0
    bc = np.bincount(spike_clusters_i)
    # Handle the case where the last cluster(s) are empty.
    if len(bc) < len(cluster_ids):
        n = len(cluster_ids) - len(bc)
        bc = np.concatenate((bc, np.zeros(n, dtype=bc.dtype)))
    assert bc.shape == (len(cluster_ids), )
    return bc * np.c_[bc] * (bin_size / (duration or 1.))
Example 11
    def make_depths(self):
        """Make spikes.depths.npy and clusters.depths.npy."""
        channel_positions = self.model.channel_positions
        assert channel_positions.ndim == 2

        spike_clusters = self.model.spike_clusters
        assert spike_clusters.ndim == 1
        n_spikes = spike_clusters.shape[0]
        self.cluster_ids = _unique(self.model.spike_clusters)

        cluster_channels = np.load(self.out_path / 'clusters.peakChannel.npy')
        assert cluster_channels.ndim == 1
        n_clusters = cluster_channels.shape[0]

        clusters_depths = channel_positions[cluster_channels, 1]
        assert clusters_depths.shape == (n_clusters, )

        spike_clusters_rel = _index_of(spike_clusters, self.cluster_ids)
        assert spike_clusters_rel.max() < clusters_depths.shape[0]
        spikes_depths = clusters_depths[spike_clusters_rel]
        assert spikes_depths.shape == (n_spikes, )

        self._save_npy('spikes.depths.npy', spikes_depths)
        self._save_npy('clusters.depths.npy', clusters_depths)
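
The indexing chain above reduces to two gathers; sketched here with toy arrays:

import numpy as np

clusters_depths = np.array([120., 340., 580.])  # one depth per cluster
spike_clusters_rel = np.array([0, 2, 2, 1])     # per-spike cluster index
spikes_depths = clusters_depths[spike_clusters_rel]
# -> array([120., 580., 580., 340.])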
Example 12
    def _plot_cluster(self, bunch):
        wave = bunch.data
        if wave is None or not wave.size:
            return
        channel_ids_loc = bunch.channel_ids

        n_channels = len(channel_ids_loc)
        masks = bunch.get('masks', np.ones((wave.shape[0], n_channels)))
        # By default, this is 0, 1, 2 for the first 3 clusters.
        # But it can be customized when displaying several sets
        # of waveforms per cluster.

        n_spikes_clu, n_samples = wave.shape[:2]
        assert wave.shape[2] == n_channels
        assert masks.shape == (n_spikes_clu, n_channels)

        # Find the x coordinates.
        t = get_linear_x(n_spikes_clu * n_channels, n_samples)
        t = _overlap_transform(t,
                               offset=bunch.offset,
                               n=bunch.n_clu,
                               overlap=self.overlap)
        # HACK: on the GPU, we get the actual masks with fract(masks)
        # since we add the relative cluster index. We need to ensure
        # that the masks is never 1.0, otherwise it is interpreted as
        # 0.
        eps = .001
        masks = eps + (1 - 2 * eps) * masks
        # NOTE: we add the cluster index which is used for the
        # computation of the depth on the GPU.
        masks += bunch.index

        # Generate the box index (one number per channel).
        box_index = _index_of(channel_ids_loc, self.channel_ids)
        box_index = np.tile(box_index, n_spikes_clu)

        # Find the correct number of vertices depending on the current waveform visual.
        if self._current_visual == self.waveform_visual:
            # PlotVisual
            box_index = np.repeat(box_index, n_samples)
            assert box_index.size == n_spikes_clu * n_channels * n_samples
        else:
            # PlotAggVisual
            box_index = np.repeat(box_index, 2 * (n_samples + 2))
            assert box_index.size == n_spikes_clu * n_channels * 2 * (
                n_samples + 2)

        # Generate the waveform array.
        wave = np.transpose(wave, (0, 2, 1))
        nw = n_spikes_clu * n_channels
        wave = wave.reshape((nw, n_samples))

        assert self.data_bounds is not None
        self._current_visual.add_batch_data(x=t,
                                            y=wave,
                                            color=bunch.color,
                                            masks=masks,
                                            box_index=box_index,
                                            data_bounds=self.data_bounds)

        # Waveform axes.
        # --------------

        # Horizontal y=0 lines.
        ax_db = self.data_bounds
        a, b = _overlap_transform(np.array([-1, 1]),
                                  offset=bunch.offset,
                                  n=bunch.n_clu,
                                  overlap=self.overlap)
        box_index = _index_of(channel_ids_loc, self.channel_ids)
        box_index = np.repeat(box_index, 2)
        box_index = np.tile(box_index, n_spikes_clu)
        hpos = np.tile([[a, 0, b, 0]], (nw, 1))
        assert box_index.size == hpos.shape[0] * 2
        self.line_visual.add_batch_data(
            pos=hpos,
            color=self.ax_color,
            data_bounds=ax_db,
            box_index=box_index,
        )

        # Vertical ticks every millisecond.
        steps = np.arange(np.round(self.wave_duration * 1000))
        # A vline every millisecond.
        x = .001 * steps
        # Scale to [-1, 1], same coordinates as the waveform points.
        x = -1 + 2 * x / self.wave_duration
        # Take overlap into account.
        x = _overlap_transform(x,
                               offset=bunch.offset,
                               n=bunch.n_clu,
                               overlap=self.overlap)
        x = np.tile(x, len(channel_ids_loc))
        # Generate the box index.
        box_index = _index_of(channel_ids_loc, self.channel_ids)
        box_index = np.repeat(box_index, x.size // len(box_index))
        assert x.size == box_index.size
        self.tick_visual.add_batch_data(
            x=x,
            y=np.zeros_like(x),
            data_bounds=ax_db,
            box_index=box_index,
        )
Example 13
def merge_probes(ses_path):
    """
    Merge spike sorting output from 2 probes and output in the session ALF folder the combined
    output in IBL format
    :param ses_path: session containing probes to be merged
    :return: None
    """
    def _sr(ap_file):
        md = spikeglx.read_meta_data(ap_file.with_suffix('.meta'))
        return spikeglx._get_fs_from_meta(md)

    ses_path = Path(ses_path)
    out_dir = ses_path.joinpath('alf').joinpath('tmp_merge')
    ephys_files = glob_ephys_files(ses_path)
    subdirs, labels, efiles_sorted, srates = zip(
        *sorted([(ep.ap.parent, ep.label, ep, _sr(ep.ap)) for ep in ephys_files if ep.get('ap')]))

    # If there is only one file, just convert the output to IBL format and stop there.
    if len(subdirs) == 1:
        ks2_to_alf(subdirs[0], ses_path / 'alf')
        return
    else:
        _logger.info('converting individual spike-sorting outputs to ALF')
        for subdir, label, ef, sr in zip(subdirs, labels, efiles_sorted, srates):
            ks2alf_path = subdir / 'ks2_alf'
            if ks2alf_path.exists():
                shutil.rmtree(ks2alf_path, ignore_errors=True)
            ks2_to_alf(subdir, ks2alf_path, label=label, sr=sr, force=True)

    probe_info = [{'label': lab} for lab in labels]
    mt = merge.Merger(subdirs=subdirs, out_dir=out_dir, probe_info=probe_info).merge()
    # Create the cluster channels file, this should go in the model template as 2 methods
    tmp = mt.sparse_templates.data
    n_templates, n_samples, n_channels = tmp.shape
    template_peak_channels = np.argmax(tmp.max(axis=1) - tmp.min(axis=1), axis=1)
    cluster_probes = mt.channel_probes[template_peak_channels]
    spike_clusters_rel = _index_of(mt.spike_clusters, _unique(mt.spike_clusters))
    spike_probes = cluster_probes[spike_clusters_rel]

    # sync spikes according to the probes
    # how do you make sure they match the files:
    for ind, probe in enumerate(efiles_sorted):
        assert(labels[ind] == probe.label)  # paranoid, make sure they are sorted
        if not probe.get('ap'):
            continue
        sync_file = probe.ap.parent.joinpath(probe.ap.name.replace('.ap.', '.sync.')
                                             ).with_suffix('.npy')
        if not sync_file.exists():
            error_msg = f'No synchronisation file for {sync_file}'
            _logger.error(error_msg)
            raise FileNotFoundError(error_msg)
        sync_points = np.load(sync_file)
        fcn = interp1d(sync_points[:, 0] * srates[ind],
                       sync_points[:, 1], fill_value='extrapolate')
        mt.spike_times[spike_probes == ind] = fcn(mt.spike_samples[spike_probes == ind])

    # And convert to ALF
    ac = alf.EphysAlfCreator(mt)
    ac.convert(ses_path / 'alf', force=True)
    # remove the temporary directory
    shutil.rmtree(out_dir)
Example 14
File: ccg.py Project: CINPLA/phylib
def correlograms(
    spike_times,
    spike_clusters,
    cluster_ids=None,
    sample_rate=1.,
    bin_size=None,
    window_size=None,
    symmetrize=True,
):
    """Compute all pairwise cross-correlograms among the clusters appearing
    in `spike_clusters`.

    Parameters
    ----------

    spike_times : array-like
        Spike times in seconds.
    spike_clusters : array-like
        Spike-cluster mapping.
    cluster_ids : array-like
        The list of *all* unique clusters, in any order. That order will be used
        in the output array.
    bin_size : float
        Size of the bin, in seconds.
    window_size : float
        Size of the window, in seconds.

    Returns
    -------

    correlograms : array
        A `(n_clusters, n_clusters, winsize_samples)` array with all pairwise
        CCGs.

    """
    assert sample_rate > 0.
    assert np.all(np.diff(spike_times) >= 0), ("The spike times must be "
                                               "increasing.")

    # Get the spike samples.
    spike_times = np.asarray(spike_times, dtype=np.float64)
    spike_samples = (spike_times * sample_rate).astype(np.int64)

    spike_clusters = _as_array(spike_clusters)

    assert spike_samples.ndim == 1
    assert spike_samples.shape == spike_clusters.shape

    # Find `binsize`.
    bin_size = np.clip(bin_size, 1e-5, 1e5)  # in seconds
    binsize = int(sample_rate * bin_size)  # in samples
    assert binsize >= 1

    # Find `winsize_bins`.
    window_size = np.clip(window_size, 1e-5, 1e5)  # in seconds
    winsize_bins = 2 * int(.5 * window_size / bin_size) + 1

    assert winsize_bins >= 1
    assert winsize_bins % 2 == 1

    # Take the cluster order into account.
    if cluster_ids is None:
        clusters = _unique(spike_clusters)
    else:
        clusters = _as_array(cluster_ids)
    n_clusters = len(clusters)

    # Like spike_clusters, but with 0..n_clusters-1 indices.
    spike_clusters_i = _index_of(spike_clusters, clusters)

    # Shift between the two copies of the spike trains.
    shift = 1

    # At a given shift, the mask specifies which spikes have matching spikes
    # within the correlogram time window.
    mask = np.ones_like(spike_samples, dtype=bool)

    correlograms = _create_correlograms_array(n_clusters, winsize_bins)

    # The loop continues as long as there is at least one spike with
    # a matching spike.
    while mask[:-shift].any():
        # Number of time samples between spike i and spike i+shift.
        spike_diff = _diff_shifted(spike_samples, shift)

        # Binarize the delays between spike i and spike i+shift.
        spike_diff_b = spike_diff // binsize

        # Spikes with no matching spikes are masked.
        mask[:-shift][spike_diff_b > (winsize_bins // 2)] = False

        # Cache the masked spike delays.
        m = mask[:-shift].copy()
        d = spike_diff_b[m]

        # Find the indices in the raveled correlograms array that need
        # to be incremented, taking into account the spike clusters.
        indices = np.ravel_multi_index(
            (spike_clusters_i[:-shift][m], spike_clusters_i[+shift:][m], d),
            correlograms.shape)

        # Increment the matching spikes in the correlograms array.
        _increment(correlograms.ravel(), indices)

        shift += 1

    # Remove ACG peaks.
    correlograms[np.arange(n_clusters), np.arange(n_clusters), 0] = 0

    if symmetrize:
        return _symmetrize_correlograms(correlograms)
    else:
        return correlograms
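
A hypothetical call, with made-up argument values:

import numpy as np

# spike_times must be sorted; spike_clusters is aligned with it.
spike_times = np.array([0.002, 0.005, 0.007, 0.010])
spike_clusters = np.array([2, 3, 2, 3])
# ccg = correlograms(spike_times, spike_clusters, cluster_ids=[2, 3],
#                    sample_rate=30000., bin_size=1e-3, window_size=20e-3)
# ccg.shape == (2, 2, n_bins): one cross-correlogram per cluster pair.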