Example 1
def test_apply_filter():
    """Test bandpass filtering on a combination of two sinusoids."""
    rate = 10000.
    low, high = 100., 200.
    # Create a signal with frequencies below and above the passband.
    t = np.linspace(0., 1., int(rate))
    x = np.sin(2*np.pi*low/2*t) + np.cos(2*np.pi*high*2*t)
    # Filter the signal.
    filter = bandpass_filter(filter_low=low,
        filter_high=high, filter_butter_order=4, sample_rate=rate)
    x_filtered = apply_filter(x, filter=filter)
    # Check that the bandpass-filtered signal is weak.
    assert np.abs(x[int(2./low*rate):-int(2./low*rate)]).max() >= .9
    assert np.abs(x_filtered[int(2./low*rate):-int(2./low*rate)]).max() <= .1
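The helpers `bandpass_filter` and `apply_filter` are not shown in these excerpts. Below is a minimal sketch of what they might look like, assuming a Butterworth design built with scipy.signal; the keyword names are taken from the test above, but the choice of `signal.butter` and zero-phase `filtfilt` is an assumption, not the module's confirmed implementation.

import numpy as np
from scipy import signal

def bandpass_filter(filter_low=None, filter_high=None,
                    filter_butter_order=3, sample_rate=None, **prm):
    # Assumed implementation: Butterworth bandpass coefficients, with the
    # cutoffs normalized by the Nyquist frequency (sample_rate / 2).
    return signal.butter(filter_butter_order,
                         (filter_low / (sample_rate / 2.),
                          filter_high / (sample_rate / 2.)),
                         'bandpass')

def apply_filter(x, filter=None):
    # Assumed implementation: zero-phase forward-backward filtering
    # along the time axis.
    b, a = filter
    return signal.filtfilt(b, a, np.asarray(x), axis=0)

With these assumed definitions the test should pass: both components (50 Hz and 400 Hz) lie about an octave outside the 100-200 Hz passband, where a fourth-order Butterworth applied forward and backward attenuates by roughly 48 dB.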
Example 2
def test_apply_filter():
    """Test bandpass filtering on a combination of two sinusoids."""
    rate = 10000.
    low, high = 100., 200.
    # Create a signal with frequencies below and above the passband.
    t = np.linspace(0., 1., int(rate))
    x = np.sin(2 * np.pi * low / 2 * t) + np.cos(2 * np.pi * high * 2 * t)
    # Filter the signal.
    filter = bandpass_filter(filter_low=low,
                             filter_high=high,
                             filter_butter_order=4,
                             sample_rate=rate)
    x_filtered = apply_filter(x, filter=filter)
    # Check that the bandpass-filtered signal is weak.
    assert np.abs(x[int(2. / low * rate):-int(2. / low * rate)]).max() >= .9
    assert np.abs(
        x_filtered[int(2. / low * rate):-int(2. / low * rate)]).max() <= .1
Example 3
def get_threshold(raw_data, filter=None, channels=slice(None), **prm):
    """Compute the threshold from the standard deviation of the filtered signal
    across many uniformly scattered excerpts of data.
    
    threshold_std_factor can be a tuple, in which case multiple thresholds
    are returned.
    
    """
    nexcerpts = prm.get('nexcerpts', None)
    excerpt_size = prm.get('excerpt_size', None)
    use_single_threshold = prm.get('use_single_threshold', True)
    threshold_strong_std_factor = prm.get('threshold_strong_std_factor', None)
    threshold_weak_std_factor = prm.get('threshold_weak_std_factor', None)
    threshold_std_factor = prm.get(
        'threshold_std_factor',
        (threshold_strong_std_factor, threshold_weak_std_factor))

    if isinstance(threshold_std_factor, tuple):
        # Fix bug with use_single_threshold=False: ensure that
        # threshold_std_factor has 2 dimensions (strong/weak, channel).
        threshold_std_factor = np.array(threshold_std_factor)[:, None]

    # We compute the standard deviation of the signal across the excerpts.
    # WARNING: this may use a lot of RAM.
    excerpts = np.vstack(
        # Filter each excerpt (np.vstack needs a sequence, not a generator).
        [apply_filter(excerpt.data[:, :], filter=filter)
         for excerpt in raw_data.excerpts(nexcerpts=nexcerpts,
                                          excerpt_size=excerpt_size)])

    # Get the median of all samples in all excerpts,
    # on all channels...
    if use_single_threshold:
        median = np.median(np.abs(excerpts))
    # ...or independently for each channel.
    else:
        median = np.median(np.abs(excerpts), axis=0)

    # Compute the threshold from the median.
    std = median / .6745
    threshold = threshold_std_factor * std

    if isinstance(threshold, np.ndarray):
        return DoubleThreshold(strong=threshold[0], weak=threshold[1])
    else:
        return threshold
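The `median / .6745` step is the usual robust estimator of the noise standard deviation: for zero-mean Gaussian data, the median of |x| is about 0.6745 times the standard deviation, and unlike np.std the median is barely affected by the rare, large spike samples. A standalone sanity check:

import numpy as np

rng = np.random.RandomState(0)
noise = rng.normal(0., 10., size=100000)  # Gaussian noise with std = 10

# Robust estimate: median(|x|) / 0.6745 recovers the std for Gaussian data.
print(np.median(np.abs(noise)) / .6745)   # ~10.0

# Contaminate 1% of the samples with large "spikes": np.std is biased
# upwards, while the median-based estimate barely moves.
noise[::100] += 500.
print(noise.std())                        # ~51
print(np.median(np.abs(noise)) / .6745)   # still ~10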
Example 4
def get_threshold(raw_data, filter=None, channels=slice(None), **prm):
    """Compute the threshold from the standard deviation of the filtered signal
    across many uniformly scattered excerpts of data.
    
    threshold_std_factor can be a tuple, in which case multiple thresholds
    are returned.
    
    """
    nexcerpts = prm.get("nexcerpts", None)
    excerpt_size = prm.get("excerpt_size", None)
    use_single_threshold = prm.get("use_single_threshold", True)
    threshold_strong_std_factor = prm.get("threshold_strong_std_factor", None)
    threshold_weak_std_factor = prm.get("threshold_weak_std_factor", None)
    threshold_std_factor = prm.get("threshold_std_factor", (threshold_strong_std_factor, threshold_weak_std_factor))

    if isinstance(threshold_std_factor, tuple):
        # Fix bug with use_single_threshold=False: ensure that
        # threshold_std_factor has 2 dimensions (strong/weak, channel).
        threshold_std_factor = np.array(threshold_std_factor)[:, None]

    # We compute the standard deviation of the signal across the excerpts.
    # WARNING: this may use a lot of RAM.
    excerpts = np.vstack(
        # Filter each excerpt (np.vstack needs a sequence, not a generator).
        [apply_filter(excerpt.data[:, :], filter=filter)
         for excerpt in raw_data.excerpts(nexcerpts=nexcerpts, excerpt_size=excerpt_size)]
    )

    # Get the median of all samples in all excerpts,
    # on all channels...
    if use_single_threshold:
        median = np.median(np.abs(excerpts))
    # ...or independently for each channel.
    else:
        median = np.median(np.abs(excerpts), axis=0)

    # Compute the threshold from the median.
    std = median / 0.6745
    threshold = threshold_std_factor * std

    if isinstance(threshold, np.ndarray):
        return DoubleThreshold(strong=threshold[0], weak=threshold[1])
    else:
        return threshold
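`DoubleThreshold` is not defined in these excerpts; given how it is constructed here and accessed later (threshold.strong, threshold.weak), it is presumably a simple two-field container along these lines (an assumption, not the module's actual definition):

from collections import namedtuple

# Hypothetical stand-in for the DoubleThreshold used above.
DoubleThreshold = namedtuple('DoubleThreshold', ('strong', 'weak'))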
Example 5
def run(raw_data=None, experiment=None, prm=None, probe=None,
        _debug=False, convert_only=False):
    """This main function takes raw data (either as a RawReader, or a path
    to a filename, or an array) and executes the main algorithm (filtering,
    spike detection, extraction...)."""
    assert experiment is not None, ("An Experiment instance needs to be "
        "provided in order to write the output.")

    # Create file logger for the experiment.
    LOGGER_FILE = create_file_logger(experiment.gen_filename('log'))

    # Get parameters from the PRM dictionary.
    chunk_size = prm.get('chunk_size', None)
    chunk_overlap = prm.get('chunk_overlap', 0)
    nchannels = prm.get('nchannels', None)

    # Ensure a RawDataReader is instantiated.
    if raw_data is not None:
        if not isinstance(raw_data, BaseRawDataReader):
            raw_data = read_raw(raw_data, nchannels=nchannels)
    else:
        raw_data = read_raw(experiment)

    # Log.
    if convert_only:
        info("Starting file conversion only. Klusta version {1:s}, on {0:s}".format(str(raw_data), spikedetekt2.__version__))
        info("Running spike detection on a single chunk only, so that some spike information is available.")
        first_chunk_detected = False  # horrible hack: detect spikes on one chunk only so KV doesn't complain
    else:
        info("Starting SpikeDetekt version {1:s} on {0:s}".format((str(raw_data)), spikedetekt2.__version__))
    debug("Parameters: \n" + (display_params(prm)))

    # Get the bandpass filter.
    filter = bandpass_filter(**prm)

    if not (convert_only and first_chunk_detected):
        # Compute the strong and weak thresholds across excerpts uniformly
        # scattered across the whole recording.
        threshold = get_threshold(raw_data, filter=filter,
                                  channels=probe.channels, **prm)
        assert not np.isnan(threshold.weak).any()
        assert not np.isnan(threshold.strong).any()
        debug("Threshold: " + str(threshold))

        # Debug module.
        diagnostics_path = prm.get('diagnostics_path', None)
        if diagnostics_path:
            diagnostics_mod = _import_module(diagnostics_path)
            if not hasattr(diagnostics_mod, 'diagnostics'):
                raise ValueError("The diagnostics module must implement a "
                                 "'diagnostics()' function.")
            diagnostics_fun = diagnostics_mod.diagnostics
        else:
            diagnostics_fun = None


    # Progress bar.
    progress_bar = ProgressReporter(period=30.)
    nspikes = 0

    # Loop through all chunks with overlap.
    for chunk in raw_data.chunks(chunk_size=chunk_size,
                                 chunk_overlap=chunk_overlap):
        # Log.
        debug("Processing chunk {0:s}...".format(chunk))

        nsamples = chunk.nsamples
        rec = chunk.recording
        nrecs = chunk.nrecordings
        s_end = chunk.s_end

        # Filter the (full) chunk.
        chunk_raw = chunk.data_chunk_full  # shape: (nsamples, nchannels)
        chunk_fil = apply_filter(chunk_raw, filter=filter)

        i = chunk.keep_start - chunk.s_start
        j = chunk.keep_end - chunk.s_start

        # Add the data to the KWD files.
        if prm.get('save_raw', False):
            # Do not append the raw data to the .kwd file if we're already reading
            # from the .kwd file.
            if not isinstance(raw_data, (KwdRawDataReader, ExperimentRawDataReader)):
                # Save raw data.
                experiment.recordings[chunk.recording].raw.append(convert_dtype(chunk.data_chunk_keep, np.int16))

        if prm.get('save_high', False):
            # Save high-pass filtered data: need to remove the overlapping
            # sections.
            chunk_fil_keep = chunk_fil[i:j,:]
            experiment.recordings[chunk.recording].high.append(convert_dtype(chunk_fil_keep, np.int16))

        if prm.get('save_low', True):
            # Save LFP.
            chunk_low = decimate(chunk_raw)
            chunk_low_keep = chunk_low[i//16:j//16,:]
            experiment.recordings[chunk.recording].low.append(convert_dtype(chunk_low_keep, np.int16))

        if not (convert_only and first_chunk_detected):
            # Apply thresholds.
            chunk_detect, chunk_threshold = apply_threshold(chunk_fil,
                threshold=threshold, **prm)

            # Remove dead channels.
            dead = np.setdiff1d(np.arange(nchannels), probe.channels)
            chunk_detect[:,dead] = 0
            chunk_threshold.strong[:,dead] = 0
            chunk_threshold.weak[:,dead] = 0

            # Find connected components (strong threshold). Returns a list of
            # Component instances.
            components = connected_components(
                chunk_strong=chunk_threshold.strong,
                chunk_weak=chunk_threshold.weak,
                probe_adjacency_list=probe.adjacency_list,
                chunk=chunk, **prm)

            # Now we extract the spike waveform from each component.
            waveforms = extract_waveforms(chunk_detect=chunk_detect,
                threshold=threshold, chunk_fil=chunk_fil, chunk_raw=chunk_raw,
                probe=probe, components=components, **prm)

            # DEBUG module.
            # Execute the debug script.
            if diagnostics_fun:
                try:
                    diagnostics_fun(**locals())
                except Exception as e:
                    warn("The diagnostics module failed: " + e.message)

            # Log number of spikes in the chunk.
            nspikes += len(waveforms)

            # We sort waveforms by increasing order of fractional time.
            for waveform in sorted(waveforms):
                add_waveform(experiment, waveform)

            first_chunk_detected = True

        # Update the progress bar.
        progress_bar.update(rec/float(nrecs) + (float(s_end) / (nsamples*nrecs)),
            '%d spikes found.' % (nspikes))

        # DEBUG: process only the first chunk.
        if _debug:
            break

    # Feature extraction.
    save_features(experiment, **prm)

    close_file_logger(LOGGER_FILE)
    progress_bar.finish()
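The keep_start/keep_end bookkeeping in the loop above serves a specific purpose: chunks are read with an overlap, each full chunk (including the overlap) is filtered, and only the interior "keep" region is saved, so that filter edge artifacts land in the discarded margins. Below is a self-contained sketch of that pattern; the chunk sizes and filter parameters are illustrative, not the pipeline's actual values.

import numpy as np
from scipy import signal

def filter_in_chunks(x, chunk_size=1000, overlap=200):
    # Illustrative zero-phase bandpass filter.
    b, a = signal.butter(3, (.1, .4), 'bandpass')
    out = np.empty(len(x))
    for start in range(0, len(x), chunk_size):
        # Extend the chunk by the overlap on both sides (clipped at the
        # array boundaries) and filter the extended chunk...
        s0 = max(start - overlap, 0)
        s1 = min(start + chunk_size + overlap, len(x))
        filtered = signal.filtfilt(b, a, x[s0:s1])
        # ...then keep only the interior samples, where the filter's edge
        # artifacts have died out. This mirrors the conversion to
        # chunk-local indices above: i = keep_start - s_start,
        # j = keep_end - s_start.
        i = start - s0
        n = min(chunk_size, len(x) - start)
        out[start:start + n] = filtered[i:i + n]
    return out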