コード例 #1
0
def nearest_neighbors_noise_overlap(waveform_extractor: si.WaveformExtractor,
                                    this_unit_id: int,
                                    max_spikes_for_nn: int = 1000,
                                    n_neighbors: int = 5,
                                    n_components: int = 10,
                                    radius_um: float = 100,
                                    seed: int = 0):
    """Calculates unit noise overlap based on NearestNeighbors search in PCA space.

    Based on noise overlap metric described in Chung et al. (2017) Neuron 95: 1381-1394.

    Rough logic:
    ------------
    1) Generate a noise cluster by randomly sampling voltage snippets from recording.
    2) Subtract projection onto the weighted average of noise snippets
       of both the target and noise clusters to correct for bias in sampling.
    3) Compute the isolation score between the noise cluster and the target cluster.

    Implementation details:
    -----------------------
    As with nn_isolation, the clusters that are compared (target and noise clusters)
    have the same number of snippets: min(number of target spikes, number of noise
    snippets drawn, max_spikes_for_nn). Larger clusters are randomly subsampled.
    Only channels within `radius_um` of the target unit's peak channel are used.

    See docstring for `_compute_isolation` for the definition of isolation score.

    Parameters:
    -----------
    waveform_extractor: si.WaveformExtractor
        provides the recording (for noise snippets), the target unit's
        waveforms and its median template
    this_unit_id: int
        ID of unit for which this metric will be calculated
    max_spikes_for_nn: int
        max number of spikes to use per cluster; also passed as the number of
        random noise chunks to draw per recording segment
    n_neighbors: int
        number of neighbors to check membership of
    n_components: int
        number of PC components to project the snippets; clipped to the number
        of snippets / features available so small clusters do not make the
        PCA fit fail
    radius_um: float
        only the channels within this radius of the peak channel
        are used to compute the metric
    seed: int
        seed for random subsampling of spikes and noise snippets

    Outputs:
    --------
    nn_noise_overlap : float
        1 - isolation score between the target cluster and the noise cluster
    """
    rng = np.random.default_rng(seed=seed)
    n_samples = waveform_extractor.nsamples

    # get random snippets from the recording to create a noise cluster.
    # The chunks are returned concatenated along the first (time) axis; since
    # `max_spikes_for_nn` chunks are drawn *per segment*, a multi-segment
    # recording yields more than `max_spikes_for_nn` snippets, so the snippet
    # count is inferred instead of hard-coded (which would fold extra segments
    # into the channel axis).
    # NOTE(review): assumes the concatenated traces are (time, channels).
    recording = waveform_extractor.recording
    noise_traces = get_random_data_chunks(
        recording,
        return_scaled=waveform_extractor.return_scaled,
        num_chunks_per_segment=max_spikes_for_nn,
        chunk_size=n_samples,
        seed=seed)
    n_channels = noise_traces.shape[-1]
    noise_cluster = np.reshape(noise_traces, (-1, n_samples, n_channels))

    # get waveforms for target cluster
    waveforms = waveform_extractor.get_waveforms(unit_id=this_unit_id)

    # adjust the size of the target and noise clusters to be equal
    n_waveforms = waveforms.shape[0]
    n_noise = noise_cluster.shape[0]
    n_snippets = min(n_waveforms, n_noise, max_spikes_for_nn)
    if n_waveforms > n_snippets:
        wf_ind = rng.choice(n_waveforms, n_snippets, replace=False)
        waveforms = waveforms[wf_ind]
    if n_noise > n_snippets:
        noise_ind = rng.choice(n_noise, n_snippets, replace=False)
        noise_cluster = noise_cluster[noise_ind]

    # restrict to channels with significant signal
    closest_chans_idx = get_template_channel_sparsity(waveform_extractor,
                                                      method='radius',
                                                      outputs='index',
                                                      peak_sign='both',
                                                      radius_um=radius_um)
    unit_chans = closest_chans_idx[this_unit_id]
    waveforms = waveforms[:, :, unit_chans]
    noise_cluster = noise_cluster[:, :, unit_chans]

    # the projection subtraction below writes its result back in place; make
    # sure integer (unscaled) snippets are not silently truncated
    if not np.issubdtype(waveforms.dtype, np.floating):
        waveforms = waveforms.astype(np.float32)
    if not np.issubdtype(noise_cluster.dtype, np.floating):
        noise_cluster = noise_cluster.astype(np.float32)

    # compute weighted noise snippet (Z): each noise snippet is weighted by its
    # value at the location (time, channel) of the template's largest |amplitude|
    median_waveform = waveform_extractor.get_template(unit_id=this_unit_id,
                                                      mode='median')
    median_waveform = median_waveform[:, unit_chans]
    tmax, chmax = np.unravel_index(np.argmax(np.abs(median_waveform)),
                                   median_waveform.shape)
    weights = noise_cluster[:, tmax, chmax]
    weights = weights / np.sum(weights)
    # weighted sum over the snippet axis -> (n_samples, n_unit_channels)
    weighted_noise_snippet = np.tensordot(weights, noise_cluster, axes=(0, 0))

    # subtract projection onto weighted noise snippet
    for snippet in range(n_snippets):
        waveforms[snippet, :, :] = _subtract_clip_component(
            waveforms[snippet, :, :], weighted_noise_snippet)
        noise_cluster[snippet, :, :] = _subtract_clip_component(
            noise_cluster[snippet, :, :], weighted_noise_snippet)

    # compute principal components after concatenation
    all_snippets = np.concatenate([
        waveforms.reshape((n_snippets, -1)),
        noise_cluster.reshape((n_snippets, -1)),
    ], axis=0)
    # IncrementalPCA cannot fit more components than samples or features
    pca = IncrementalPCA(n_components=min(n_components, *all_snippets.shape))
    pca.partial_fit(all_snippets)
    projected_snippets = pca.transform(all_snippets)

    # compute overlap: first n_snippets rows are the target, the rest is noise
    nn_noise_overlap = 1 - _compute_isolation(
        projected_snippets[:n_snippets, :], projected_snippets[n_snippets:, :],
        n_neighbors)
    return nn_noise_overlap
コード例 #2
0
def nearest_neighbors_isolation(waveform_extractor: si.WaveformExtractor,
                                this_unit_id: int,
                                max_spikes_for_nn: int = 1000,
                                n_neighbors: int = 5,
                                n_components: int = 10,
                                radius_um: float = 100,
                                seed: int = 0):
    """Calculates unit isolation based on NearestNeighbors search in PCA space

    Based on isolation metric described in Chung et al. (2017) Neuron 95: 1381-1394.

    Rough logic:
    ------------
    1) Choose a cluster
    2) Compute the isolation score with every other cluster
    3) Isolation score is defined as the min of (2) (i.e. 'worst-case measure')

    Implementation details:
    -----------------------
    Let A and B be two clusters from sorting.

    We set |A| = |B|:
        If max_spikes_for_nn < |A| and max_spikes_for_nn < |B|, then randomly subsample max_spikes_for_nn samples from A and B.
        If max_spikes_for_nn > min(|A|, |B|) (e.g. |A| > max_spikes_for_nn > |B|), then randomly subsample min(|A|, |B|) samples from A and B.
        This is because the metric is affected by the size of the clusters being compared independently of how well-isolated they are.

    Only units whose sparsity channels overlap those of the target unit are
    compared, and both clusters are restricted to the target unit's channels
    (those within `radius_um` of its peak channel).

    See docstring for `_compute_isolation` for the definition of isolation score.

    Parameters:
    -----------
    waveform_extractor: si.WaveformExtractor
        provides the sorting (unit ids) and the waveforms of every unit
    this_unit_id: int
        ID of unit for which this metric will be calculated
    max_spikes_for_nn: int
        max number of spikes to use per cluster
    n_neighbors: int
        number of neighbors to check membership of
    n_components: int
        number of PC components to project the snippets; clipped to the number
        of snippets / features available so small clusters do not make the
        PCA fit fail
    radius_um: float
        only the channels within this radius of the target unit's peak
        channel are used to compute the metric
    seed: int
        seed for random subsampling of spikes

    Outputs:
    --------
    nn_isolation : float
        minimum isolation score over all neighboring units (1.0 if there is
        no neighboring unit)
    """
    rng = np.random.default_rng(seed=seed)

    all_units_ids = waveform_extractor.sorting.get_unit_ids()
    other_units_ids = np.setdiff1d(all_units_ids, this_unit_id)

    # get waveforms for target cluster
    waveforms_target_unit = waveform_extractor.get_waveforms(
        unit_id=this_unit_id)
    n_spikes_target_unit = waveforms_target_unit.shape[0]

    # find units whose closest channels overlap with closest channels of target cluster
    closest_chans_all = get_template_channel_sparsity(waveform_extractor,
                                                      method='radius',
                                                      outputs='index',
                                                      peak_sign='both',
                                                      radius_um=radius_um)
    closest_chans_target_unit = closest_chans_all[this_unit_id]
    other_units_ids = [
        unit_id for unit_id in other_units_ids if np.any(
            np.isin(closest_chans_all[unit_id], closest_chans_target_unit))
    ]

    # if no unit is within neighborhood of target unit, then just say isolation is 1 (best possible)
    if not other_units_ids:
        return 1.0

    # compute isolation with each neighboring unit; results are stored by
    # position (not by comparing ids against the list, which fails silently
    # for non-numeric unit ids)
    isolation = np.zeros(len(other_units_ids))
    for unit_index, other_unit_id in enumerate(other_units_ids):
        waveforms_other_unit = waveform_extractor.get_waveforms(
            unit_id=other_unit_id)
        n_spikes_other_unit = waveforms_other_unit.shape[0]

        n_snippets = min(n_spikes_target_unit, n_spikes_other_unit,
                         max_spikes_for_nn)

        # make the two clusters equal in terms of:
        # - number of spikes
        # - channels with signal (both use the target unit's channels)
        waveforms_target_unit_idx = rng.choice(n_spikes_target_unit,
                                               size=n_snippets,
                                               replace=False)
        waveforms_target_unit_sampled = waveforms_target_unit[
            waveforms_target_unit_idx][:, :, closest_chans_target_unit]

        waveforms_other_unit_idx = rng.choice(n_spikes_other_unit,
                                              size=n_snippets,
                                              replace=False)
        waveforms_other_unit_sampled = waveforms_other_unit[
            waveforms_other_unit_idx][:, :, closest_chans_target_unit]

        # compute principal components after concatenation
        all_snippets = np.concatenate([
            waveforms_target_unit_sampled.reshape((n_snippets, -1)),
            waveforms_other_unit_sampled.reshape((n_snippets, -1)),
        ], axis=0)
        # IncrementalPCA cannot fit more components than samples or features
        pca = IncrementalPCA(
            n_components=min(n_components, *all_snippets.shape))
        pca.partial_fit(all_snippets)
        projected_snippets = pca.transform(all_snippets)

        # compute isolation: first n_snippets rows are the target unit
        isolation[unit_index] = _compute_isolation(
            projected_snippets[:n_snippets, :],
            projected_snippets[n_snippets:, :], n_neighbors)

    return float(np.min(isolation))