예제 #1
0
파일: views.py 프로젝트: stephenlenzi/phy
def extract_spikes(traces,
                   interval,
                   sample_rate=None,
                   spike_times=None,
                   spike_clusters=None,
                   cluster_groups=None,
                   all_masks=None,
                   n_samples_waveforms=None):
    """Extract the waveforms of all spikes falling within a time interval.

    Parameters
    ----------
    traces : array-like
        Trace samples passed to ``_extract_wave``. Sample offsets are computed
        relative to ``interval[0]``, so ``traces`` is expected to start at that
        time (assumption -- TODO confirm with callers).
    interval : pair of numbers
        ``(start, end)`` window in the same units as ``spike_times``. Spikes
        with ``start <= t < end`` are selected (``searchsorted`` left side).
    sample_rate : float
        Factor converting a time difference in ``spike_times`` units into a
        number of samples.
    spike_times : ndarray
        Spike times; must be sorted for ``searchsorted`` to be meaningful.
    spike_clusters : ndarray
        Cluster id of each spike, aligned with ``spike_times``.
    cluster_groups : dict, optional
        Mapping from cluster id to a group label. Clusters missing from the
        mapping get ``None``.
    all_masks : ndarray
        2D per-spike, per-channel masks, with rows aligned with
        ``spike_times``.
    n_samples_waveforms : int or (int, int)
        Either a total waveform length, split as ``(n // 2, n // 2)``, or an
        explicit ``(n_before, n_after)`` pair (tuple or list).

    Returns
    -------
    list of Bunch
        One entry per successfully extracted spike, with the attributes
        ``waveforms``, ``channels``, ``masks``, ``spike_time``,
        ``spike_cluster``, ``cluster_group`` and ``offset_samples``.
        Spikes for which ``_extract_wave`` returns ``None`` are skipped.

    Raises
    ------
    TypeError
        If ``n_samples_waveforms`` is ``None``.
    """
    cluster_groups = cluster_groups or {}
    ns = n_samples_waveforms
    if ns is None:
        raise TypeError("n_samples_waveforms must be an int or a "
                        "(n_before, n_after) pair.")
    if not isinstance(ns, (tuple, list)):
        # NOTE(review): for odd ``ns`` this yields ``ns - 1`` samples in
        # total. Kept as-is so existing callers see the same waveform length.
        ns = (ns // 2, ns // 2)
    offset_samples = ns[0]
    wave_len = ns[0] + ns[1]

    # Select the spikes with interval[0] <= t < interval[1].
    start_idx, end_idx = spike_times.searchsorted(interval)
    st = spike_times[start_idx:end_idx]
    sc = spike_clusters[start_idx:end_idx]
    m = all_masks[start_idx:end_idx]
    n = len(st)
    assert len(sc) == n
    assert m.shape[0] == n

    spikes = []
    for i, (spike_time, spike_cluster) in enumerate(zip(st, sc)):
        # Sample index of the spike relative to interval[0], shifted back so
        # the extracted window starts ``offset_samples`` before the spike.
        sample_start = int(round((spike_time - interval[0]) * sample_rate))
        sample_start -= offset_samples
        extracted = _extract_wave(traces, sample_start, m[i], wave_len)
        if extracted is None:  # pragma: no cover
            # Report the index into ``spike_times`` rather than the index
            # within the interval, so the failing spike can be located.
            logger.debug("Unable to extract spike %d.", start_idx + i)
            continue
        spike = Bunch()
        spike.waveforms, spike.channels = extracted
        # Keep only the masks of the channels returned by _extract_wave.
        spike.masks = m[i, spike.channels]
        spike.spike_time = spike_time
        spike.spike_cluster = spike_cluster
        spike.cluster_group = cluster_groups.get(spike_cluster)
        spike.offset_samples = offset_samples
        spikes.append(spike)
    return spikes