Example #1
    def _do_assign(self, spike_ids, new_spike_clusters):
        """Make spike-cluster assignments after the spike selection has
        been extended to full clusters."""

        # Ensure spike_clusters has the right shape.
        spike_ids = _as_array(spike_ids)
        if len(new_spike_clusters) == 1 and len(spike_ids) > 1:
            new_spike_clusters = (np.ones(len(spike_ids), dtype=np.int64) *
                                  new_spike_clusters[0])
        old_spike_clusters = self._spike_clusters[spike_ids]

        assert len(spike_ids) == len(old_spike_clusters)
        assert len(new_spike_clusters) == len(spike_ids)

        # Update the spikes per cluster structure.
        old_clusters = _unique(old_spike_clusters)

        # NOTE: shortcut to a merge if this assignment is effectively a merge
        # i.e. if all spikes are assigned to a single cluster.
        # The fact that spike selection has been previously extended to
        # whole clusters is critical here.
        new_clusters = _unique(new_spike_clusters)
        if len(new_clusters) == 1:
            return self._do_merge(spike_ids, old_clusters, new_clusters[0])

        # We return the UpdateInfo structure.
        up = _assign_update_info(spike_ids, old_spike_clusters,
                                 new_spike_clusters)

        # We update the new cluster id (strictly increasing during a session).
        self._new_cluster_id = max(self._new_cluster_id, max(up.added) + 1)

        # We make the assignments.
        self._spike_clusters[spike_ids] = new_spike_clusters
        return up
Example #2
def _assign_update_info(spike_ids, old_spike_clusters, new_spike_clusters):
    old_clusters = _unique(old_spike_clusters)
    new_clusters = _unique(new_spike_clusters)
    descendants = list(set(zip(old_spike_clusters, new_spike_clusters)))
    update_info = UpdateInfo(
        description='assign',
        spike_ids=spike_ids,
        added=list(new_clusters),
        deleted=list(old_clusters),
        descendants=descendants,
    )
    return update_info
Example #3
File: clustering.py Project: ablot/phy
def _assign_update_info(spike_ids, old_spike_clusters, new_spike_clusters):
    old_clusters = _unique(old_spike_clusters)
    new_clusters = _unique(new_spike_clusters)
    descendants = list(set(zip(old_spike_clusters,
                               new_spike_clusters)))
    update_info = UpdateInfo(description='assign',
                             spike_ids=spike_ids,
                             added=list(new_clusters),
                             deleted=list(old_clusters),
                             descendants=descendants,
                             )
    return update_info
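
For illustration, here is a minimal standalone sketch of what _assign_update_info computes on toy data. It is a sketch only: np.unique stands in for phy's _unique helper and a plain dict stands in for the UpdateInfo structure, so the real class may handle the fields differently.

import numpy as np

# Toy data: three spikes moving from old clusters {2, 3} to new clusters {5, 6}.
spike_ids = np.array([10, 11, 12])
old_spike_clusters = np.array([2, 2, 3])
new_spike_clusters = np.array([5, 6, 6])

# Same steps as _assign_update_info, with np.unique standing in for _unique.
added = list(np.unique(new_spike_clusters))       # [5, 6]
deleted = list(np.unique(old_spike_clusters))     # [2, 3]
descendants = sorted(set(zip(old_spike_clusters, new_spike_clusters)))
# descendants == [(2, 5), (2, 6), (3, 6)]: (old cluster, new cluster) pairs.

# Plain dict used here in place of UpdateInfo (assumption for illustration).
update_info = dict(description='assign', spike_ids=spike_ids,
                   added=added, deleted=deleted, descendants=descendants)
print(update_info)
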
Example #4
    def _do_assign(self, spike_ids, new_spike_clusters):
        """Make spike-cluster assignments after the spike selection has
        been extended to full clusters."""

        # Ensure spike_clusters has the right shape.
        spike_ids = _as_array(spike_ids)
        if len(new_spike_clusters) == 1 and len(spike_ids) > 1:
            new_spike_clusters = (np.ones(len(spike_ids), dtype=np.int64) *
                                  new_spike_clusters[0])
        old_spike_clusters = self._spike_clusters[spike_ids]

        assert len(spike_ids) == len(old_spike_clusters)
        assert len(new_spike_clusters) == len(spike_ids)

        # Update the spikes per cluster structure.
        old_clusters = _unique(old_spike_clusters)

        # NOTE: shortcut to a merge if this assignment is effectively a merge
        # i.e. if all spikes are assigned to a single cluster.
        # The fact that spike selection has been previously extended to
        # whole clusters is critical here.
        new_clusters = _unique(new_spike_clusters)
        if len(new_clusters) == 1:
            return self._do_merge(spike_ids, old_clusters, new_clusters[0])

        # We return the UpdateInfo structure.
        up = _assign_update_info(spike_ids,
                                 old_spike_clusters,
                                 new_spike_clusters)

        # We update the new cluster id (strictly increasing during a session).
        self._new_cluster_id = max(self._new_cluster_id, max(up.added) + 1)

        # We make the assignments.
        self._spike_clusters[spike_ids] = new_spike_clusters
        # OPTIM: we update spikes_per_cluster manually.
        new_spc = _spikes_per_cluster(new_spike_clusters, spike_ids)
        self._update_cluster_ids(to_remove=old_clusters, to_add=new_spc)
        return up
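
Two preparatory steps of _do_assign can be seen in isolation: broadcasting a single target cluster id over every selected spike, and detecting when the assignment collapses to a merge. Below is a minimal sketch with toy arrays, using np.unique in place of phy's _unique; it does not call _do_merge or touch any class state.

import numpy as np

spike_ids = np.array([3, 4, 8, 9])        # selection, already extended to full clusters
new_spike_clusters = np.array([7])        # a single target cluster id

# Broadcast the scalar assignment over every selected spike, as in _do_assign.
if len(new_spike_clusters) == 1 and len(spike_ids) > 1:
    new_spike_clusters = (np.ones(len(spike_ids), dtype=np.int64) *
                          new_spike_clusters[0])
print(new_spike_clusters)                 # [7 7 7 7]

# If all spikes end up in a single cluster, the assignment is effectively a merge.
new_clusters = np.unique(new_spike_clusters)
print(len(new_clusters) == 1)             # True
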
Example #5
File: clustering.py Project: ablot/phy
def _extend_spikes(spike_ids, spike_clusters):
    """Return all spikes belonging to the clusters containing the specified
    spikes."""
    # We find the spikes belonging to modified clusters.
    # What are the old clusters that are modified by the assignment?
    old_spike_clusters = spike_clusters[spike_ids]
    unique_clusters = _unique(old_spike_clusters)
    # Now we take all spikes from these clusters.
    changed_spike_ids = _spikes_in_clusters(spike_clusters, unique_clusters)
    # These are the new spikes that need to be reassigned.
    extended_spike_ids = np.setdiff1d(changed_spike_ids, spike_ids,
                                      assume_unique=True)
    return extended_spike_ids
Example #6
def _extend_spikes(spike_ids, spike_clusters):
    """Return all spikes belonging to the clusters containing the specified
    spikes."""
    # We find the spikes belonging to modified clusters.
    # What are the old clusters that are modified by the assignment?
    old_spike_clusters = spike_clusters[spike_ids]
    unique_clusters = _unique(old_spike_clusters)
    # Now we take all spikes from these clusters.
    changed_spike_ids = _spikes_in_clusters(spike_clusters, unique_clusters)
    # These are the new spikes that need to be reassigned.
    extended_spike_ids = np.setdiff1d(changed_spike_ids,
                                      spike_ids,
                                      assume_unique=True)
    return extended_spike_ids
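
As a rough standalone illustration of the extension step, the sketch below applies the same logic to a toy spike_clusters array. np.flatnonzero(np.in1d(...)) is used as a stand-in for phy's _spikes_in_clusters helper and np.unique for _unique; the real helpers may differ in detail.

import numpy as np

# Cluster of each spike, indexed by spike id (10 spikes, clusters 0-2).
spike_clusters = np.array([0, 1, 1, 2, 0, 2, 1, 0, 2, 1])
spike_ids = np.array([2, 3])              # the spikes being reassigned

# Clusters touched by the selection.
unique_clusters = np.unique(spike_clusters[spike_ids])                 # [1, 2]
# All spikes belonging to those clusters (stand-in for _spikes_in_clusters).
changed_spike_ids = np.flatnonzero(np.in1d(spike_clusters, unique_clusters))
# Spikes that must be reassigned in addition to the original selection.
extended_spike_ids = np.setdiff1d(changed_spike_ids, spike_ids, assume_unique=True)
print(extended_spike_ids)                 # [1 5 6 8 9]
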
Example #7
 def _update_cluster_ids(self, to_remove=None, to_add=None):
     # Update the list of non-empty cluster ids.
     self._cluster_ids = _unique(self._spike_clusters)
     # Clusters to remove.
     if to_remove is not None:
         for clu in to_remove:
             self._spikes_per_cluster.pop(clu, None)
     # Clusters to add.
     if to_add:
         for clu, spk in to_add.items():
             self._spikes_per_cluster[clu] = spk
     # If spikes_per_cluster is invalid, recompute the entire
     # spikes_per_cluster dictionary.
     coherent = np.all(np.in1d(self._cluster_ids,
                               sorted(self._spikes_per_cluster),
                               ))
     if not coherent:
         logger.debug("Recompute spikes_per_cluster manually: "
                      "this is long.")
         sc = self._spike_clusters
         self._spikes_per_cluster = _spikes_per_cluster(sc)
Example #8
 def _update_cluster_ids(self, to_remove=None, to_add=None):
     # Update the list of non-empty cluster ids.
     self._cluster_ids = _unique(self._spike_clusters)
     # Clusters to remove.
     if to_remove is not None:
         for clu in to_remove:
             self._spikes_per_cluster.pop(clu, None)
     # Clusters to add.
     if to_add:
         for clu, spk in to_add.items():
             self._spikes_per_cluster[clu] = spk
     # If spikes_per_cluster is invalid, recompute the entire
     # spikes_per_cluster dictionary.
     coherent = np.all(
         np.in1d(
             self._cluster_ids,
             sorted(self._spikes_per_cluster),
         ))
     if not coherent:
         logger.debug("Recompute spikes_per_cluster manually: "
                      "this is long.")
         sc = self._spike_clusters
         self._spikes_per_cluster = _spikes_per_cluster(sc)
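
The coherence check at the end of _update_cluster_ids simply verifies that every non-empty cluster id still has an entry in the spikes-per-cluster mapping. The sketch below reproduces just that check with a plain dict and toy arrays; the surrounding bookkeeping is omitted.

import numpy as np

spike_clusters = np.array([0, 0, 3, 3, 5])
cluster_ids = np.unique(spike_clusters)            # [0, 3, 5]

# Suppose cluster 5 appeared in spike_clusters but the dict was never updated.
spikes_per_cluster = {0: np.array([0, 1]), 3: np.array([2, 3])}

coherent = np.all(np.in1d(cluster_ids, sorted(spikes_per_cluster)))
print(coherent)   # False -> the whole mapping would be recomputed from spike_clusters
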
Example #9
 def cluster_ids(self):
     """Ordered list of ids of all non-empty clusters."""
     return _unique(self._spike_clusters)
Example #10
File: ccg.py Project: ablot/phy
def correlograms(spike_times,
                 spike_clusters,
                 cluster_ids=None,
                 sample_rate=1.,
                 bin_size=None,
                 window_size=None,
                 symmetrize=True,
                 ):
    """Compute all pairwise cross-correlograms among the clusters appearing
    in `spike_clusters`.

    Parameters
    ----------

    spike_times : array-like
        Spike times in seconds.
    spike_clusters : array-like
        Spike-cluster mapping.
    cluster_ids : array-like
        The list of unique clusters, in any order. That order will be used
        in the output array.
    sample_rate : float
        The sampling rate, in Hz.
    bin_size : float
        Size of the bin, in seconds.
    window_size : float
        Size of the window, in seconds.
    symmetrize : bool
        Whether to symmetrize the output array.

    Returns
    -------

    correlograms : array
        A `(n_clusters, n_clusters, winsize_bins)` array with all pairwise
        CCGs.

    """
    assert sample_rate > 0.
    assert np.all(np.diff(spike_times) >= 0), ("The spike times must be "
                                               "increasing.")

    # Get the spike samples.
    spike_times = np.asarray(spike_times, dtype=np.float64)
    spike_samples = (spike_times * sample_rate).astype(np.int64)

    spike_clusters = _as_array(spike_clusters)

    assert spike_samples.ndim == 1
    assert spike_samples.shape == spike_clusters.shape

    # Find `binsize`.
    bin_size = np.clip(bin_size, 1e-5, 1e5)  # in seconds
    binsize = int(sample_rate * bin_size)  # in samples
    assert binsize >= 1

    # Find `winsize_bins`.
    window_size = np.clip(window_size, 1e-5, 1e5)  # in seconds
    winsize_bins = 2 * int(.5 * window_size / bin_size) + 1

    assert winsize_bins >= 1
    assert winsize_bins % 2 == 1

    # Take the cluster order into account.
    if cluster_ids is None:
        clusters = _unique(spike_clusters)
    else:
        clusters = _as_array(cluster_ids)
    n_clusters = len(clusters)

    # Like spike_clusters, but with 0..n_clusters-1 indices.
    spike_clusters_i = _index_of(spike_clusters, clusters)

    # Shift between the two copies of the spike trains.
    shift = 1

    # At a given shift, the mask indicates which spikes have a matching spike
    # within the correlogram time window.
    mask = np.ones_like(spike_samples, dtype=bool)

    correlograms = _create_correlograms_array(n_clusters, winsize_bins)

    # The loop continues as long as there is at least one spike with
    # a matching spike.
    while mask[:-shift].any():
        # Number of time samples between spike i and spike i+shift.
        spike_diff = _diff_shifted(spike_samples, shift)

        # Binarize the delays between spike i and spike i+shift.
        spike_diff_b = spike_diff // binsize

        # Spikes with no matching spikes are masked.
        mask[:-shift][spike_diff_b > (winsize_bins // 2)] = False

        # Cache the masked spike delays.
        m = mask[:-shift].copy()
        d = spike_diff_b[m]

        # Find the indices in the raveled correlograms array that need
        # to be incremented, taking into account the spike clusters.
        indices = np.ravel_multi_index((spike_clusters_i[:-shift][m],
                                        spike_clusters_i[+shift:][m],
                                        d),
                                       correlograms.shape)

        # Increment the matching spikes in the correlograms array.
        _increment(correlograms.ravel(), indices)

        shift += 1

    # Remove ACG peaks.
    correlograms[np.arange(n_clusters),
                 np.arange(n_clusters),
                 0] = 0

    if symmetrize:
        return _symmetrize_correlograms(correlograms)
    else:
        return correlograms
Example #11
File: clustering.py Project: ablot/phy
 def cluster_ids(self):
     """Ordered list of ids of all non-empty clusters."""
     return _unique(self._spike_clusters)
Example #12
def correlograms(
    spike_times,
    spike_clusters,
    cluster_ids=None,
    sample_rate=1.,
    bin_size=None,
    window_size=None,
    symmetrize=True,
):
    """Compute all pairwise cross-correlograms among the clusters appearing
    in `spike_clusters`.

    Parameters
    ----------

    spike_times : array-like
        Spike times in seconds.
    spike_clusters : array-like
        Spike-cluster mapping.
    cluster_ids : array-like
        The list of unique clusters, in any order. That order will be used
        in the output array.
    sample_rate : float
        The sampling rate, in Hz.
    bin_size : float
        Size of the bin, in seconds.
    window_size : float
        Size of the window, in seconds.
    symmetrize : bool
        Whether to symmetrize the output array.

    Returns
    -------

    correlograms : array
        A `(n_clusters, n_clusters, winsize_bins)` array with all pairwise
        CCGs.

    """
    assert sample_rate > 0.
    assert np.all(np.diff(spike_times) >= 0), ("The spike times must be "
                                               "increasing.")

    # Get the spike samples.
    spike_times = np.asarray(spike_times, dtype=np.float64)
    spike_samples = (spike_times * sample_rate).astype(np.int64)

    spike_clusters = _as_array(spike_clusters)

    assert spike_samples.ndim == 1
    assert spike_samples.shape == spike_clusters.shape

    # Find `binsize`.
    bin_size = np.clip(bin_size, 1e-5, 1e5)  # in seconds
    binsize = int(sample_rate * bin_size)  # in samples
    assert binsize >= 1

    # Find `winsize_bins`.
    window_size = np.clip(window_size, 1e-5, 1e5)  # in seconds
    winsize_bins = 2 * int(.5 * window_size / bin_size) + 1

    assert winsize_bins >= 1
    assert winsize_bins % 2 == 1

    # Take the cluster order into account.
    if cluster_ids is None:
        clusters = _unique(spike_clusters)
    else:
        clusters = _as_array(cluster_ids)
    n_clusters = len(clusters)

    # Like spike_clusters, but with 0..n_clusters-1 indices.
    spike_clusters_i = _index_of(spike_clusters, clusters)

    # Shift between the two copies of the spike trains.
    shift = 1

    # At a given shift, the mask indicates which spikes have a matching spike
    # within the correlogram time window.
    mask = np.ones_like(spike_samples, dtype=bool)

    correlograms = _create_correlograms_array(n_clusters, winsize_bins)

    # The loop continues as long as there is at least one spike with
    # a matching spike.
    while mask[:-shift].any():
        # Number of time samples between spike i and spike i+shift.
        spike_diff = _diff_shifted(spike_samples, shift)

        # Binarize the delays between spike i and spike i+shift.
        spike_diff_b = spike_diff // binsize

        # Spikes with no matching spikes are masked.
        mask[:-shift][spike_diff_b > (winsize_bins // 2)] = False

        # Cache the masked spike delays.
        m = mask[:-shift].copy()
        d = spike_diff_b[m]

        # Find the indices in the raveled correlograms array that need
        # to be incremented, taking into account the spike clusters.
        indices = np.ravel_multi_index(
            (spike_clusters_i[:-shift][m], spike_clusters_i[+shift:][m], d),
            correlograms.shape)

        # Increment the matching spikes in the correlograms array.
        _increment(correlograms.ravel(), indices)

        shift += 1

    # Remove ACG peaks.
    correlograms[np.arange(n_clusters), np.arange(n_clusters), 0] = 0

    if symmetrize:
        return _symmetrize_correlograms(correlograms)
    else:
        return correlograms
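
A usage sketch for correlograms with synthetic spike data is shown below. The import path is an assumption (the function is defined in ccg.py as shown above; adjust the import to wherever that module lives in your checkout). The parameter values are arbitrary but satisfy the assertions in the function body: sorted spike times and positive bin and window sizes.

import numpy as np
# Assumed import path for the function defined in ccg.py above.
from phy.stats.ccg import correlograms

rng = np.random.RandomState(0)
n_spikes = 1000
spike_times = np.sort(rng.uniform(0., 10., n_spikes))    # in seconds, increasing
spike_clusters = rng.randint(0, 3, n_spikes)             # three clusters: 0, 1, 2

ccg = correlograms(spike_times,
                   spike_clusters,
                   cluster_ids=[0, 1, 2],
                   sample_rate=20000.,
                   bin_size=1e-3,          # 1 ms bins
                   window_size=50e-3,      # 50 ms window
                   symmetrize=True)
# A (n_clusters, n_clusters, n_bins) array of pairwise CCGs.
print(ccg.shape)
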
Example #13
    def _init_data(self):
        if op.exists(self.dat_path):
            logger.debug("Loading traces at `%s`.", self.dat_path)
            traces = _dat_to_traces(
                self.dat_path,
                n_channels=self.n_channels_dat,
                dtype=self.dtype or np.int16,
                offset=self.offset,
            )
            n_samples_t, _ = traces.shape
            assert _ == self.n_channels_dat
        else:
            if self.dat_path is not None:
                logger.warning("Error while loading data: File %s not found.",
                               self.dat_path)
            traces = None
            n_samples_t = 0

        logger.debug("Loading amplitudes.")
        amplitudes = read_array('amplitudes').squeeze()
        n_spikes, = amplitudes.shape
        self.n_spikes = n_spikes

        # Create spike_clusters if the file doesn't exist.
        if not op.exists(filenames['spike_clusters']):
            shutil.copy(filenames['spike_templates'],
                        filenames['spike_clusters'])
        logger.debug("Loading %d spike clusters.", self.n_spikes)
        spike_clusters = read_array('spike_clusters').squeeze()
        spike_clusters = spike_clusters.astype(np.int32)
        assert spike_clusters.shape == (n_spikes, )
        self.spike_clusters = spike_clusters

        logger.debug("Loading spike templates.")
        spike_templates = read_array('spike_templates').squeeze()
        spike_templates = spike_templates.astype(np.int32)
        assert spike_templates.shape == (n_spikes, )
        self.spike_templates = spike_templates

        logger.debug("Loading spike samples.")
        spike_samples = read_array('spike_samples').squeeze()
        assert spike_samples.shape == (n_spikes, )

        logger.debug("Loading templates.")
        templates = read_array('templates')
        templates[np.isnan(templates)] = 0
        # templates = np.transpose(templates, (2, 1, 0))

        # Unwhiten the templates.
        logger.debug("Loading the whitening matrix.")
        self.whitening_matrix = read_array('whitening_matrix')

        if op.exists(filenames['templates_unw']):
            logger.debug("Loading unwhitened templates.")
            templates_unw = read_array('templates_unw')
            templates_unw[np.isnan(templates_unw)] = 0
        else:
            logger.debug("Couldn't find unwhitened templates, computing them.")
            logger.debug("Inversing the whitening matrix %s.",
                         self.whitening_matrix.shape)
            wmi = np.linalg.inv(self.whitening_matrix)
            logger.debug("Unwhitening the templates %s.", templates.shape)
            templates_unw = np.dot(np.ascontiguousarray(templates),
                                   np.ascontiguousarray(wmi))
            # Save the unwhitened templates.
            write_array('templates_unw.npy', templates_unw)

        n_templates, n_samples_templates, n_channels = templates.shape
        self.n_templates = n_templates

        logger.debug("Loading similar templates.")
        self.similar_templates = read_array('similar_templates')
        assert self.similar_templates.shape == (self.n_templates,
                                                self.n_templates)

        logger.debug("Loading channel mapping.")
        channel_mapping = read_array('channel_mapping').squeeze()
        channel_mapping = channel_mapping.astype(np.int32)
        assert channel_mapping.shape == (n_channels, )
        # Ensure that the mapping maps to valid columns in the dat file.
        assert np.all(channel_mapping <= self.n_channels_dat - 1)

        logger.debug("Loading channel positions.")
        channel_positions = read_array('channel_positions')
        assert channel_positions.shape == (n_channels, 2)

        if op.exists(filenames['features']):
            logger.debug("Loading features.")
            all_features = np.load(filenames['features'], mmap_mode='r')
            features_ind = read_array('features_ind').astype(np.int32)
            # Feature subset.
            if op.exists(filenames['features_spike_ids']):
                features_spike_ids = read_array('features_spike_ids') \
                    .astype(np.int32)
                assert len(features_spike_ids) == len(all_features)
                self.features_spike_ids = features_spike_ids
                ns = len(features_spike_ids)
            else:
                ns = self.n_spikes
                self.features_spike_ids = None

            assert all_features.ndim == 3
            n_loc_chan = all_features.shape[2]
            self.n_features_per_channel = all_features.shape[1]
            assert all_features.shape == (
                ns,
                self.n_features_per_channel,
                n_loc_chan,
            )
            # Check sparse features arrays shapes.
            assert features_ind.shape == (self.n_templates, n_loc_chan)
        else:
            all_features = None
            features_ind = None

        self.all_features = all_features
        self.features_ind = features_ind

        if op.exists(filenames['template_features']):
            logger.debug("Loading template features.")
            template_features = np.load(filenames['template_features'],
                                        mmap_mode='r')
            template_features_ind = read_array('template_features_ind'). \
                astype(np.int32)
            template_features_ind = template_features_ind.copy()
            n_sim_tem = template_features.shape[1]
            assert template_features.shape == (n_spikes, n_sim_tem)
            assert template_features_ind.shape == (n_templates, n_sim_tem)
        else:
            template_features = None
            template_features_ind = None

        self.template_features_ind = template_features_ind
        self.template_features = template_features

        self.n_channels = n_channels
        # Take dead channels into account.
        if traces is not None:
            # Find the scaling factor for the traces.
            scaling = 1. / self._data_lim(traces[:10000])
            traces = _concatenate_virtual_arrays(
                [traces],
                channel_mapping,
                scaling=scaling,
            )
        else:
            scaling = 1.

        # Amplitudes
        self.all_amplitudes = amplitudes
        self.amplitudes_lim = np.max(self.all_amplitudes)

        # Templates
        self.templates = templates
        self.templates_unw = templates_unw
        assert self.templates.shape == self.templates_unw.shape
        self.n_samples_templates = n_samples_templates
        self.n_samples_waveforms = n_samples_templates
        self.template_lim = np.max(np.abs(self.templates))

        self.duration = n_samples_t / float(self.sample_rate)

        self.spike_times = spike_samples / float(self.sample_rate)
        assert np.all(np.diff(self.spike_times) >= 0)

        self.cluster_ids = _unique(self.spike_clusters)
        # n_clusters = len(self.cluster_ids)

        self.channel_positions = channel_positions
        self.all_traces = traces

        # Only filter the data for the waveforms if the traces
        # are not already filtered.
        if not getattr(self, 'hp_filtered', False):
            logger.debug("HP filtering the data for waveforms")
            filter_order = 3
        else:
            filter_order = None

        n_closest_channels = getattr(self, 'max_n_unmasked_channels', 16)
        mask_threshold = getattr(self, 'waveform_mask_threshold', None)
        self.closest_channels = get_closest_channels(
            self.channel_positions,
            n_closest_channels,
        )
        self.template_masks = get_masks(self.templates, self.closest_channels)
        self.all_masks = MaskLoader(self.template_masks, self.spike_templates)

        # Fetch waveforms from traces.
        nsw = self.n_samples_waveforms
        if traces is not None:
            waveforms = WaveformLoader(
                traces=traces,
                masks=self.all_masks,
                spike_samples=spike_samples,
                n_samples_waveforms=nsw,
                filter_order=filter_order,
                sample_rate=self.sample_rate,
                mask_threshold=mask_threshold,
            )
        else:
            waveforms = None
        self.all_waveforms = waveforms

        # Read the cluster groups.
        logger.debug("Loading the cluster groups.")
        self.cluster_groups = {}
        if op.exists(filenames['cluster_groups']):
            with open(filenames['cluster_groups'], 'r') as f:
                reader = csv.reader(f, delimiter='\t')
                # Skip the header.
                for row in reader:
                    break
                for row in reader:
                    cluster, group = row
                    cluster = int(cluster)
                    self.cluster_groups[cluster] = group
        for cluster_id in self.cluster_ids:
            if cluster_id not in self.cluster_groups:
                self.cluster_groups[cluster_id] = None
Example #14
File: gui.py Project: kwikteam/phy-contrib
    def _init_data(self):
        if op.exists(self.dat_path):
            logger.debug("Loading traces at `%s`.", self.dat_path)
            traces = _dat_to_traces(self.dat_path,
                                    n_channels=self.n_channels_dat,
                                    dtype=self.dtype or np.int16,
                                    offset=self.offset,
                                    )
            n_samples_t, _ = traces.shape
            assert _ == self.n_channels_dat
        else:
            if self.dat_path is not None:
                logger.warning("Error while loading data: File %s not found.",
                               self.dat_path)
            traces = None
            n_samples_t = 0

        logger.debug("Loading amplitudes.")
        amplitudes = read_array('amplitudes').squeeze()
        n_spikes, = amplitudes.shape
        self.n_spikes = n_spikes

        # Create spike_clusters if the file doesn't exist.
        if not op.exists(filenames['spike_clusters']):
            shutil.copy(filenames['spike_templates'],
                        filenames['spike_clusters'])
        logger.debug("Loading %d spike clusters.", self.n_spikes)
        spike_clusters = read_array('spike_clusters').squeeze()
        spike_clusters = spike_clusters.astype(np.int32)
        assert spike_clusters.shape == (n_spikes,)
        self.spike_clusters = spike_clusters

        logger.debug("Loading spike templates.")
        spike_templates = read_array('spike_templates').squeeze()
        spike_templates = spike_templates.astype(np.int32)
        assert spike_templates.shape == (n_spikes,)
        self.spike_templates = spike_templates

        logger.debug("Loading spike samples.")
        spike_samples = read_array('spike_samples').squeeze()
        assert spike_samples.shape == (n_spikes,)

        logger.debug("Loading templates.")
        templates = read_array('templates')
        templates[np.isnan(templates)] = 0
        # templates = np.transpose(templates, (2, 1, 0))

        # Unwhiten the templates.
        logger.debug("Loading the whitening matrix.")
        self.whitening_matrix = read_array('whitening_matrix')

        if op.exists(filenames['templates_unw']):
            logger.debug("Loading unwhitened templates.")
            templates_unw = read_array('templates_unw')
            templates_unw[np.isnan(templates_unw)] = 0
        else:
            logger.debug("Couldn't find unwhitened templates, computing them.")
            logger.debug("Inversing the whitening matrix %s.",
                         self.whitening_matrix.shape)
            wmi = np.linalg.inv(self.whitening_matrix)
            logger.debug("Unwhitening the templates %s.",
                         templates.shape)
            templates_unw = np.dot(np.ascontiguousarray(templates),
                                   np.ascontiguousarray(wmi))
            # Save the unwhitened templates.
            write_array('templates_unw.npy', templates_unw)

        n_templates, n_samples_templates, n_channels = templates.shape
        self.n_templates = n_templates

        logger.debug("Loading similar templates.")
        self.similar_templates = read_array('similar_templates')
        assert self.similar_templates.shape == (self.n_templates,
                                                self.n_templates)

        logger.debug("Loading channel mapping.")
        channel_mapping = read_array('channel_mapping').squeeze()
        channel_mapping = channel_mapping.astype(np.int32)
        assert channel_mapping.shape == (n_channels,)
        # Ensure that the mapping maps to valid columns in the dat file.
        assert np.all(channel_mapping <= self.n_channels_dat - 1)
        self.channel_order = channel_mapping

        logger.debug("Loading channel positions.")
        channel_positions = read_array('channel_positions')
        assert channel_positions.shape == (n_channels, 2)

        if op.exists(filenames['features']):
            logger.debug("Loading features.")
            all_features = np.load(filenames['features'], mmap_mode='r')
            features_ind = read_array('features_ind').astype(np.int32)
            # Feature subset.
            if op.exists(filenames['features_spike_ids']):
                features_spike_ids = read_array('features_spike_ids') \
                    .astype(np.int32)
                assert len(features_spike_ids) == len(all_features)
                self.features_spike_ids = features_spike_ids
                ns = len(features_spike_ids)
            else:
                ns = self.n_spikes
                self.features_spike_ids = None

            assert all_features.ndim == 3
            n_loc_chan = all_features.shape[2]
            self.n_features_per_channel = all_features.shape[1]
            assert all_features.shape == (ns,
                                          self.n_features_per_channel,
                                          n_loc_chan,
                                          )
            # Check sparse features arrays shapes.
            assert features_ind.shape == (self.n_templates, n_loc_chan)
        else:
            all_features = None
            features_ind = None

        self.all_features = all_features
        self.features_ind = features_ind

        if op.exists(filenames['template_features']):
            logger.debug("Loading template features.")
            template_features = np.load(filenames['template_features'],
                                        mmap_mode='r')
            template_features_ind = read_array('template_features_ind'). \
                astype(np.int32)
            template_features_ind = template_features_ind.copy()
            n_sim_tem = template_features.shape[1]
            assert template_features.shape == (n_spikes, n_sim_tem)
            assert template_features_ind.shape == (n_templates, n_sim_tem)
        else:
            template_features = None
            template_features_ind = None

        self.template_features_ind = template_features_ind
        self.template_features = template_features

        self.n_channels = n_channels
        # Take dead channels into account.
        if traces is not None:
            # Find the scaling factor for the traces.
            scaling = 1. / self._data_lim(traces[:10000])
            traces = _concatenate_virtual_arrays([traces],
                                                 channel_mapping,
                                                 scaling=scaling,
                                                 )
        else:
            scaling = 1.

        # Amplitudes
        self.all_amplitudes = amplitudes
        self.amplitudes_lim = np.max(self.all_amplitudes)

        # Templates
        self.templates = templates
        self.templates_unw = templates_unw
        assert self.templates.shape == self.templates_unw.shape
        self.n_samples_templates = n_samples_templates
        self.n_samples_waveforms = n_samples_templates
        self.template_lim = np.max(np.abs(self.templates))

        self.duration = n_samples_t / float(self.sample_rate)

        self.spike_times = spike_samples / float(self.sample_rate)
        assert np.all(np.diff(self.spike_times) >= 0)

        self.cluster_ids = _unique(self.spike_clusters)
        # n_clusters = len(self.cluster_ids)

        self.channel_positions = channel_positions
        self.all_traces = traces

        # Only filter the data for the waveforms if the traces
        # are not already filtered.
        if not getattr(self, 'hp_filtered', False):
            logger.debug("HP filtering the data for waveforms")
            filter_order = 3
        else:
            filter_order = None

        n_closest_channels = getattr(self, 'max_n_unmasked_channels', 16)
        self.closest_channels = get_closest_channels(self.channel_positions,
                                                     n_closest_channels,
                                                     )
        self.template_masks = get_masks(self.templates, self.closest_channels)
        self.all_masks = MaskLoader(self.template_masks, self.spike_templates)

        # Fetch waveforms from traces.
        nsw = self.n_samples_waveforms
        if traces is not None:
            waveforms = WaveformLoader(traces=traces,
                                       spike_samples=spike_samples,
                                       n_samples_waveforms=nsw,
                                       filter_order=filter_order,
                                       sample_rate=self.sample_rate,
                                       )
        else:
            waveforms = None
        self.all_waveforms = waveforms

        # Read the cluster groups.
        logger.debug("Loading the cluster groups.")
        self.cluster_groups = load_metadata(filenames['cluster_groups'],
                                            self.cluster_ids)
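
The unwhitening step in _init_data is just a matrix product along the channel axis. The following standalone sketch reproduces it with random arrays of the same shapes; the sizes are arbitrary and plain NumPy arrays stand in for the data returned by read_array.

import numpy as np

n_templates, n_samples_templates, n_channels = 4, 82, 32
templates = np.random.randn(n_templates, n_samples_templates, n_channels)
whitening_matrix = np.eye(n_channels) + 0.01 * np.random.randn(n_channels, n_channels)

# Invert the whitening matrix and apply it along the last (channel) axis,
# as in the else branch above.
wmi = np.linalg.inv(whitening_matrix)
templates_unw = np.dot(np.ascontiguousarray(templates),
                       np.ascontiguousarray(wmi))
assert templates_unw.shape == templates.shape
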
Example #15
 def _update_cluster_ids(self):
     # Update the list of non-empty cluster ids.
     self._cluster_ids = _unique(self._spike_clusters)