def test_append_sortings():
    """Appending a 3-segment and a 2-segment sorting yields 5 segments."""
    fs = 30000.0
    spike_times = np.arange(0, 1000, 10)
    # Unit labels cycle 0, 1, 2, 0, 1, 2, ... over the spikes.
    spike_labels = np.arange(spike_times.size, dtype='int64') % 3

    three_segment_sorting = NumpySorting.from_times_labels(
        [spike_times] * 3, [spike_labels] * 3, fs)
    two_segment_sorting = NumpySorting.from_times_labels(
        [spike_times] * 2, [spike_labels] * 2, fs)

    appended = append_sortings([three_segment_sorting, two_segment_sorting])
    assert appended.get_num_segments() == 3 + 2
# Example no. 2
# 0
def test_NumpySorting():
    """Construct NumpySorting objects from several kinds of input."""
    fs = 30000

    # A sorting without any unit must be constructible.
    empty_sorting = NumpySorting(fs, [])

    # Flat spike-time / unit-label arrays give a single segment.
    spike_times = np.arange(0, 1000, 10)
    # Unit labels cycle 0, 1, 2, 0, 1, 2, ... over the spikes.
    spike_labels = np.arange(spike_times.size, dtype='int64') % 3
    single_segment = NumpySorting.from_times_labels(spike_times, spike_labels, fs)
    print(single_segment)
    assert single_segment.get_num_segments() == 1

    # Lists of arrays give one segment per list element.
    multi_segment = NumpySorting.from_times_labels(
        [spike_times] * 3, [spike_labels] * 3, fs)
    assert multi_segment.get_num_segments() == 3

    # Conversion from another extractor. The path is relative, so the NPZ file
    # presumably written by create_sorting_npz lands in the working directory.
    n_segments = 2
    npz_path = 'test_NpzSortingExtractor.npz'
    create_sorting_npz(n_segments, npz_path)
    npz_sorting = NpzSortingExtractor(npz_path)

    converted = NumpySorting.from_extractor(npz_sorting)
def detect_peaks(recording,
                 method='by_channel',
                 peak_sign='neg',
                 detect_threshold=5,
                 n_shifts=2,
                 local_radius_um=50,
                 noise_levels=None,
                 random_chunk_kwargs=None,
                 outputs='numpy_compact',
                 localization_dict=None,
                 **job_kwargs):
    """Peak detection based on threshold crossing in terms of k x MAD.

    Parameters
    ----------
    recording: RecordingExtractor
        The recording extractor object.
    method: 'by_channel', 'locally_exclusive'
        Method to use. Options:
            * 'by_channel' : peaks are detected in each channel independently
            * 'locally_exclusive' : a single best peak is taken from a set of
              neighboring channels (requires numba)
    peak_sign: 'neg', 'pos', 'both'
        Sign of the peak.
    detect_threshold: float
        Threshold, in median absolute deviations (MAD), to use to detect peaks.
        The absolute per-channel threshold is noise_levels * detect_threshold.
    n_shifts: int
        Number of shifts to find peak.
        For example, if `n_shift` is 2, a peak is detected if a sample crosses the threshold,
        and the two samples before and after are above the sample.
    local_radius_um: float
        The radius to use for detection across local channels.
        Only used with method='locally_exclusive'.
    noise_levels: array, optional
        Estimated noise levels to use, if already computed.
        If not provided, they are estimated with get_noise_levels() from random
        chunks of the data.
    random_chunk_kwargs: dict, optional
        Extra keyword arguments forwarded to get_noise_levels() to control how
        random chunks are picked. Only used if noise_levels is None.
        None (the default) means no extra arguments.
    outputs: 'numpy_compact', 'numpy_split', 'sorting'
        The type of the output. By default, "numpy_compact" returns an array with complex dtype.
        In case of 'sorting', each unit corresponds to a recording channel.
        'numpy_split' is accepted but not implemented yet and raises NotImplementedError.
    localization_dict : dict, optional
        Can optionally do peak localization at the same time as detection.
        This avoids running `localize_peaks` separately and re-reading the entire dataset.
        Must contain a 'method' key. After normalization by init_kwargs_dict, its
        'ms_before' and 'ms_after' entries determine the extra margin passed to workers.
    {}

    Returns
    -------
    peaks: array or NumpySorting
        Detected peaks: the concatenated per-chunk results for 'numpy_compact',
        or a NumpySorting built from them for 'sorting'.

    Raises
    ------
    AssertionError
        If method, peak_sign, outputs or localization_dict['method'] is invalid.
    ModuleNotFoundError
        If method='locally_exclusive' and numba is not installed.
    NotImplementedError
        If outputs='numpy_split'.

    Notes
    -----
    This peak detection was ported from tridesclous into spikeinterface.
    """

    assert method in ('by_channel', 'locally_exclusive')
    assert peak_sign in ('both', 'neg', 'pos')
    assert outputs in ('numpy_compact', 'numpy_split', 'sorting')
    if outputs == 'numpy_split':
        # Accepted above, but no output branch below produces it. Fail fast instead of
        # running the whole (expensive) detection and then silently returning None.
        raise NotImplementedError(
            "outputs='numpy_split' is not implemented; use 'numpy_compact' or 'sorting'")

    if method == 'locally_exclusive' and not HAVE_NUMBA:
        raise ModuleNotFoundError(
            '"locally_exclusive" need numba which is not installed')

    # None instead of a shared mutable {} default.
    if random_chunk_kwargs is None:
        random_chunk_kwargs = {}

    if noise_levels is None:
        noise_levels = get_noise_levels(recording,
                                        return_scaled=False,
                                        **random_chunk_kwargs)

    abs_thresholds = noise_levels * detect_threshold

    if method == 'locally_exclusive':
        assert local_radius_um is not None
        channel_distance = get_channel_distances(recording)
        neighbours_mask = channel_distance < local_radius_um
    else:
        neighbours_mask = None

    sampling_frequency = recording.get_sampling_frequency()

    # Deal with margin: extra_margin is forwarded to the workers. It is presumably
    # used to read chunks wide enough to cover the localization window.
    if localization_dict is None:
        extra_margin = 0
    else:
        assert isinstance(localization_dict, dict)
        assert localization_dict['method'] in dtype_localize_by_method.keys()
        localization_dict = init_kwargs_dict(localization_dict['method'],
                                             localization_dict)

        nbefore = int(localization_dict['ms_before'] * sampling_frequency / 1000.)
        nafter = int(localization_dict['ms_after'] * sampling_frequency / 1000.)
        extra_margin = max(nbefore, nafter)

    # And run, chunk by chunk.
    init_args = (recording.to_dict(), method, peak_sign, abs_thresholds,
                 n_shifts, neighbours_mask, extra_margin, localization_dict)
    processor = ChunkRecordingExecutor(recording,
                                       _detect_peaks_chunk,
                                       _init_worker_detect_peaks,
                                       init_args,
                                       handle_returns=True,
                                       job_name='detect peaks',
                                       **job_kwargs)
    # With handle_returns=True, run() yields the per-chunk results; stitch them together.
    peaks = np.concatenate(processor.run())

    if outputs == 'numpy_compact':
        return peaks
    elif outputs == 'sorting':
        return NumpySorting.from_peaks(peaks, sampling_frequency=sampling_frequency)