Example #1
def test_detect_bursts_cycles():
    """Test amplitude and period consistency burst detection."""

    # Load signal
    signal = np.load(DATA_PATH + 'sim_bursting.npy')
    Fs = 1000
    f_range = (6, 14)

    signal = filt.lowpass_filter(signal, Fs, 30, N_seconds=.3,
                                 remove_edge_artifacts=False)

    # Compute cycle-by-cycle df without burst detection column
    df = features.compute_features(signal, Fs, f_range,
                                   burst_detection_method='amp',
                                   burst_detection_kwargs={'amp_threshes': (1, 2),
                                                           'filter_kwargs': {'N_seconds': .5}})
    df.drop('is_burst', axis=1, inplace=True)

    # Apply consistency burst detection
    df_burst_cycles = burst.detect_bursts_cycles(df, signal)

    # Make sure that burst detection is only boolean
    assert df_burst_cycles.dtypes['is_burst'] == 'bool'
    assert df_burst_cycles['is_burst'].mean() > 0
    assert df_burst_cycles['is_burst'].mean() < 1
    assert np.min([sum(1 for _ in group) for key, group
                   in itertools.groupby(df_burst_cycles['is_burst']) if key]) >= 3
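The 'amp' method used above relies on a dual amplitude threshold. As a rough illustration of that idea (not the library's implementation, which comes from neurodsp), a burst candidate is any run of samples above the low threshold that also reaches the high threshold at some point; the toy amplitude values below are made up:

import numpy as np

# Toy normalized amplitude trace; values are made up for illustration
amp = np.array([0.2, 0.5, 1.1, 2.3, 2.8, 1.6, 1.2, 0.8, 0.3, 0.1])
low, high = 1, 2  # analogous to amp_threshes=(1, 2)

above_low = amp > low
is_burst = np.zeros(len(amp), dtype=bool)
run_start = None

for idx, flag in enumerate(np.append(above_low, False)):
    if flag and run_start is None:
        run_start = idx
    elif not flag and run_start is not None:
        # Keep the run only if it reaches the high threshold at least once
        if np.any(amp[run_start:idx] > high):
            is_burst[run_start:idx] = True
        run_start = None

print(is_burst)  # [False False  True  True  True  True  True False False False]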
Example #2
def test_plot_feature_hist(sim_args):

    df_features = sim_args['df_features']
    threshold_kwargs = sim_args['threshold_kwargs']

    df_features = detect_bursts_cycles(df_features, **threshold_kwargs)

    plot_feature_hist(df_features,
                      'amp_consistency',
                      xlim=(0, 1),
                      save_fig=True,
                      file_name='test_plot_feature_hist',
                      file_path=TEST_PLOTS_PATH)
Example #3
def test_plot_feature_categorical(sim_args):

    df_features = sim_args['df_features']
    threshold_kwargs = sim_args['threshold_kwargs']

    df_features = detect_bursts_cycles(df_features, **threshold_kwargs)

    # Compare first & second halves of the signal - distributions should be the same
    group = np.array(['First Half' for row in range(len(df_features))])
    group[:round(len(group) / 2)] = 'Second Half'
    df_features['group'] = group

    plot_feature_categorical(df_features,
                             'amp_consistency',
                             group_by='group',
                             save_fig=True,
                             file_name='test_plot_feature_categorical',
                             file_path=TEST_PLOTS_PATH)
Example #4
def test_plot_burst_detect_summary(sim_args, plot_only_result):

    burst_detection_kwargs = {
        'amp_fraction_threshold': 1,
        'amp_consistency_threshold': .5,
        'period_consistency_threshold': .5,
        'monotonicity_threshold': .8,
        'min_n_cycles': 3
    }

    df_features = sim_args['df_features']
    sig = sim_args['sig']
    fs = sim_args['fs']

    df_features = detect_bursts_cycles(df_features, **burst_detection_kwargs)

    plot_burst_detect_summary(df_features,
                              sig,
                              fs,
                              burst_detection_kwargs,
                              plot_only_result=plot_only_result,
                              save_fig=True,
                              file_path=TEST_PLOTS_PATH,
                              file_name='test_plot_burst_detect_summary')
Example #5
def recompute_edges(df_features, threshold_kwargs):
    """Recompute the is_burst column for cycles on the edges of bursts.

    Parameters
    ----------
    df_features : pandas.DataFrame
        A dataframe containing shape and burst features for each cycle.
    threshold_kwargs : dict
        Feature thresholds for cycles to be considered bursts, matching keyword arguments for:

        - :func:`~.detect_bursts_cycles` for consistency burst detection
          (i.e. when burst_method == 'cycles')

    Returns
    -------
    df_features_edges : pandas.DataFrame
        A cycle feature dataframe with an updated ``is_burst`` column for edge cycles.

    Notes
    -----

    - `df_features` must be computed using consistency burst detection.

    Examples
    --------
    Lower the amplitude consistency threshold to zero for cycles on the edges of bursts:

    >>> from neurodsp.sim import sim_combined
    >>> from bycycle.features import compute_features
    >>> sig = sim_combined(n_seconds=4, fs=1000, components={'sim_bursty_oscillation': {'freq': 10},
    ...                                                      'sim_powerlaw': {'exp': 2}})
    >>> threshold_kwargs = {'amp_fraction_threshold': 0., 'amp_consistency_threshold': .5,
    ...                     'period_consistency_threshold': .5, 'monotonicity_threshold': .4,
    ...                     'min_n_cycles': 3}
    >>> df_features = compute_features(sig, fs=1000, f_range=(8, 12),
    ...                                threshold_kwargs=threshold_kwargs)
    >>> threshold_kwargs['amp_consistency_threshold'] = 0
    >>> df_features_edges = recompute_edges(df_features, threshold_kwargs)
    """

    # Prevent circular import between burst.utils and burst.cycle
    from bycycle.burst import detect_bursts_cycles

    # Prevent overwriting the original dataframe
    df_features_edges = df_features.copy()

    # Identify all cycles where is_burst changes on the following cycle
    #   Use copy to keep dataframe columns unlinked
    is_burst = deepcopy(df_features_edges['is_burst'].values)
    burst_edges = np.where(is_burst[1:] == ~is_burst[:-1])[0]

    # Adjust odd edges such that all edges fall on is_burst == False
    burst_edges = np.array([edge if idx % 2 == 0 else edge+1 for idx, edge in
                            enumerate(burst_edges)])

    # Recompute is_burst
    df_features_edges = detect_bursts_cycles(
        df_features_edges,
        amp_fraction_threshold=threshold_kwargs['amp_fraction_threshold'],
        amp_consistency_threshold=threshold_kwargs['amp_consistency_threshold'],
        period_consistency_threshold=threshold_kwargs['period_consistency_threshold'],
        monotonicity_threshold=threshold_kwargs['monotonicity_threshold'],
        min_n_cycles=threshold_kwargs['min_n_cycles']
    )

    # Confine recomputed is_burst to edges
    is_burst[burst_edges] = df_features_edges.iloc[burst_edges]['is_burst'].values

    df_features_edges['is_burst'] = is_burst

    return df_features_edges
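To make the edge bookkeeping concrete, the transition-finding lines above can be run on a toy ``is_burst`` array (values made up for illustration):

import numpy as np

is_burst = np.array([False, False, False, True, True, True, False, False])

# Cycles where is_burst changes on the following cycle
burst_edges = np.where(is_burst[1:] == ~is_burst[:-1])[0]
print(burst_edges)  # [2 5] -> cycle 2 precedes the burst, cycle 5 is its last cycle

# Shift odd-indexed edges so every edge falls on a non-burst cycle
burst_edges = np.array([edge if idx % 2 == 0 else edge + 1
                        for idx, edge in enumerate(burst_edges)])
print(burst_edges)  # [2 6] -> both edges now satisfy is_burst == False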
Example #6
def compute_features(sig,
                     fs,
                     f_range,
                     center_extrema='P',
                     burst_detection_method='cycles',
                     burst_detection_kwargs=None,
                     find_extrema_kwargs=None,
                     hilbert_increase_n=False):
    """Segment a recording into individual cycles and compute features for each cycle.

    Parameters
    ----------
    sig : 1d array
        Voltage time series.
    fs : float
        Sampling rate, in Hz.
    f_range : tuple of (float, float)
        Frequency range for narrowband signal of interest (Hz).
    center_extrema : {'P', 'T'}
        The center extrema in the cycle.

        - 'P' : cycles are defined trough-to-trough
        - 'T' : cycles are defined peak-to-peak

    burst_detection_method : {'cycles', 'amp'}
        Method for detecting bursts.

        - 'cycles': detect bursts based on the consistency of consecutive periods & amplitudes
        - 'amp': detect bursts using an amplitude threshold

    burst_detection_kwargs : dict, optional
        Keyword arguments for the function used to label cycles as being in or out of an oscillation.
    find_extrema_kwargs : dict, optional
        Keyword arguments for the function used to find peaks and troughs (:func:`~.find_extrema`),
        to change filter parameters or boundary. By default, the filter length is set to three
        cycles of the low cutoff frequency (``f_range[0]``).
    hilbert_increase_n : bool, optional, default: False
        Corresponding kwarg for :func:`~neurodsp.timefrequency.hilbert.amp_by_time`.
        If true, this zero-pads the signal when computing the Fourier transform, which can be
        necessary for computing it in a reasonable amount of time.

    Returns
    -------
    df : pandas.DataFrame
        Dataframe containing features and identifiers for each cycle. Each row is one cycle.
        Columns (listed for peak-centered cycles):

        - ``sample_peak`` : sample of 'sig' at which the peak occurs
        - ``sample_zerox_decay`` : sample of the decaying zero-crossing
        - ``sample_zerox_rise`` : sample of the rising zero-crossing
        - ``sample_last_trough`` : sample of the last trough
        - ``sample_next_trough`` : sample of the next trough
        - ``period`` : period of the cycle
        - ``time_decay`` : time between peak and next trough
        - ``time_rise`` : time between peak and previous trough
        - ``time_peak`` : time between rise and decay zero-crosses
        - ``time_trough`` : duration of previous trough estimated by zero-crossings
        - ``volt_decay`` : voltage change between peak and next trough
        - ``volt_rise`` : voltage change between peak and previous trough
        - ``volt_amp`` : average of rise and decay voltage
        - ``volt_peak`` : voltage at the peak
        - ``volt_trough`` : voltage at the last trough
        - ``time_rdsym`` : fraction of cycle in the rise period
        - ``time_ptsym`` : fraction of cycle in the peak period
        - ``band_amp`` : average analytic amplitude of the oscillation
          computed using narrowband filtering and the Hilbert
          transform. Filter length is 3 cycles of the low
          cutoff frequency. Average taken across all time points
          in the cycle.
        - ``is_burst`` : True if the cycle is part of a detected oscillatory burst
        - ``amp_fraction`` : normalized amplitude
        - ``amp_consistency`` : difference in the rise and decay voltage within a cycle
        - ``period_consistency`` : difference between a cycle’s period and the period of the
          adjacent cycles
        - ``monotonicity`` : fraction of instantaneous voltage changes between consecutive
          samples that are positive during the rise phase and negative during the decay phase

    Notes
    -----
    Peak vs trough centering
        - By default, the first extrema analyzed will be a peak, and the final one a trough.
        - In order to switch the preference, the signal is simply inverted and columns are renamed.
        - Columns are slightly different depending on if ``center_extrema`` is set to 'P' or 'T'.
    """

    # Set defaults if user input is None
    if burst_detection_kwargs is None:
        burst_detection_kwargs = {}
        warnings.warn('''
            No burst detection parameters are provided. This is not recommended.
            Check your data and choose appropriate parameters for "burst_detection_kwargs".
            Default burst detection parameters are likely not well suited for the data.
            ''')
    if find_extrema_kwargs is None:
        find_extrema_kwargs = {'filter_kwargs': {'n_cycles': 3}}
    else:
        # Raise an error if the user tries to switch from peak start to trough start
        if 'first_extrema' in find_extrema_kwargs.keys():
            raise ValueError('''
                This function assumes that the first extrema identified will be a peak.
                This cannot be overwritten at this time.''')

    # Negate signal to analyze trough-centered cycles
    if center_extrema == 'P':
        pass
    elif center_extrema == 'T':
        sig = -sig
    else:
        raise ValueError(
            'Parameter "center_extrema" must be either "P" or "T"')

    # Find peak and trough locations in the signal
    ps, ts = find_extrema(sig, fs, f_range, **find_extrema_kwargs)

    # Find zero-crossings
    zerox_rise, zerox_decay = find_zerox(sig, ps, ts)

    # For each cycle, identify the sample of each extrema and zero-crossing
    shape_features = {}
    shape_features['sample_peak'] = ps[1:]
    shape_features['sample_zerox_decay'] = zerox_decay[1:]
    shape_features['sample_zerox_rise'] = zerox_rise
    shape_features['sample_last_trough'] = ts[:-1]
    shape_features['sample_next_trough'] = ts[1:]

    # Compute duration of period
    shape_features['period'] = shape_features['sample_next_trough'] - \
        shape_features['sample_last_trough']

    # Compute duration of peak
    shape_features['time_peak'] = shape_features['sample_zerox_decay'] - \
        shape_features['sample_zerox_rise']

    # Compute duration of last trough
    shape_features['time_trough'] = zerox_rise - zerox_decay[:-1]

    # Determine extrema voltage
    shape_features['volt_peak'] = sig[ps[1:]]
    shape_features['volt_trough'] = sig[ts[:-1]]

    # Determine rise and decay characteristics
    shape_features['time_decay'] = (ts[1:] - ps[1:])
    shape_features['time_rise'] = (ps[1:] - ts[:-1])

    shape_features['volt_decay'] = sig[ps[1:]] - sig[ts[1:]]
    shape_features['volt_rise'] = sig[ps[1:]] - sig[ts[:-1]]
    shape_features['volt_amp'] = (shape_features['volt_decay'] +
                                  shape_features['volt_rise']) / 2

    # Compute rise-decay symmetry features
    shape_features[
        'time_rdsym'] = shape_features['time_rise'] / shape_features['period']

    # Compute peak-trough symmetry features
    shape_features['time_ptsym'] = shape_features['time_peak'] / \
        (shape_features['time_peak'] + shape_features['time_trough'])

    # Compute average oscillatory amplitude estimate during cycle
    amp = amp_by_time(sig,
                      fs,
                      f_range,
                      hilbert_increase_n=hilbert_increase_n,
                      n_cycles=3)
    shape_features['band_amp'] = [
        np.mean(amp[ts[sig_idx]:ts[sig_idx + 1]])
        for sig_idx in range(len(shape_features['sample_peak']))
    ]

    # Convert feature dictionary into a DataFrame
    df = pd.DataFrame.from_dict(shape_features)

    # Define whether or not each cycle is part of a burst
    if burst_detection_method == 'cycles':
        df = detect_bursts_cycles(df, sig, **burst_detection_kwargs)
    elif burst_detection_method == 'amp':
        df = detect_bursts_df_amp(df, sig, fs, f_range,
                                  **burst_detection_kwargs)
    else:
        raise ValueError('Invalid entry for "burst_detection_method"')

    # Rename columns if they are actually trough-centered
    if center_extrema == 'T':
        rename_dict = {
            'sample_peak': 'sample_trough',
            'sample_zerox_decay': 'sample_zerox_rise',
            'sample_zerox_rise': 'sample_zerox_decay',
            'sample_last_trough': 'sample_last_peak',
            'sample_next_trough': 'sample_next_peak',
            'time_peak': 'time_trough',
            'time_trough': 'time_peak',
            'volt_peak': 'volt_trough',
            'volt_trough': 'volt_peak',
            'time_rise': 'time_decay',
            'time_decay': 'time_rise',
            'volt_rise': 'volt_decay',
            'volt_decay': 'volt_rise'
        }
        df.rename(columns=rename_dict, inplace=True)

        # Need to reverse symmetry measures
        df['volt_peak'] = -df['volt_peak']
        df['volt_trough'] = -df['volt_trough']
        df['time_rdsym'] = 1 - df['time_rdsym']
        df['time_ptsym'] = 1 - df['time_ptsym']

    return df
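A minimal usage sketch for this version of the API, assuming a simulated input from neurodsp; the call relies only on the defaults documented above, so it will emit the no-burst-parameters warning:

from neurodsp.sim import sim_bursty_oscillation

fs = 1000
sig = sim_bursty_oscillation(n_seconds=10, fs=fs, freq=10)  # placeholder data

# Warns that no burst detection parameters were provided (defaults are used)
df = compute_features(sig, fs, f_range=(8, 12))

# Summarize burst cycles using columns documented above
df_bursts = df[df['is_burst']]
print(len(df_bursts), 'cycles flagged as bursting')
print('mean rise-decay symmetry:', df_bursts['time_rdsym'].mean())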
Example #7
def compute_features(sig, fs, f_range, center_extrema='peak', burst_method='cycles',
                     burst_kwargs=None, threshold_kwargs=None, find_extrema_kwargs=None,
                     return_samples=True):
    """Compute shape and burst features for each cycle.

    Parameters
    ----------
    sig : 1d array
        Time series.
    fs : float
        Sampling rate, in Hz.
    f_range : tuple of (float, float)
        Frequency range for narrowband signal of interest (Hz).
    center_extrema : {'peak', 'trough'}
        The center extrema in the cycle.

        - 'peak' : cycles are defined trough-to-trough
        - 'trough' : cycles are defined peak-to-peak

    burst_method : {'cycles', 'amp'}
        Method for detecting bursts.

        - 'cycles': detect bursts based on the consistency of consecutive periods & amplitudes
        - 'amp': detect bursts using an amplitude threshold

    burst_kwargs : dict, optional, default: None
        Additional keyword arguments defined in :func:`~.compute_burst_fraction` for dual
        amplitude threshold burst detection (i.e. when burst_method='amp').
    threshold_kwargs : dict, optional, default: None
        Feature thresholds for cycles to be considered bursts, matching keyword arguments for:

        - :func:`~.detect_bursts_cycles` for consistency burst detection
          (i.e. when burst_method='cycles')
        - :func:`~.detect_bursts_amp` for amplitude threshold burst detection
          (i.e. when burst_method='amp').

    find_extrema_kwargs : dict, optional, default: None
        Keyword arguments for the function used to find peaks and troughs (:func:`~.find_extrema`),
        to change filter parameters or boundary. By default, the filter length is set to three
        cycles of the low cutoff frequency (``f_range[0]``).
    return_samples : bool, optional, default: True
        If True, returns the sample indices of cyclepoints used for determining features.

    Returns
    -------
    df_features : pandas.DataFrame
        A dataframe containing shape and burst features for each cycle. Columns:

        - ``period`` : period of the cycle
        - ``time_decay`` : time between peak and next trough
        - ``time_rise`` : time between peak and previous trough
        - ``time_peak`` : time between rise and decay zero-crosses
        - ``time_trough`` : duration of previous trough estimated by zero-crossings
        - ``volt_decay`` : voltage change between peak and next trough
        - ``volt_rise`` : voltage change between peak and previous trough
        - ``volt_amp`` : average of rise and decay voltage
        - ``volt_peak`` : voltage at the peak
        - ``volt_trough`` : voltage at the last trough
        - ``time_rdsym`` : fraction of cycle in the rise period
        - ``time_ptsym`` : fraction of cycle in the peak period
        - ``band_amp`` : average analytic amplitude of the oscillation

        When consistency burst detection is used (i.e. burst_method='cycles'):

        - ``amp_fraction`` : normalized amplitude
        - ``amp_consistency`` : difference in the rise and decay voltage within a cycle
        - ``period_consistency`` : difference between a cycle’s period and the period of the
          adjacent cycles
        - ``monotonicity`` : fraction of monotonic voltage changes in rise and decay phases
          (positive going in rise and negative going in decay)

        When dual threshold burst detection is used (i.e. burst_method='amp'):

        - ``burst_fraction`` : fraction of a cycle that is bursting

        When cyclepoints are returned (i.e. default, return_samples=True):

        - ``sample_peak`` : sample at which the peak occurs
        - ``sample_zerox_decay`` : sample of the decaying zero-crossing
        - ``sample_zerox_rise`` : sample of the rising zero-crossing
        - ``sample_last_trough`` : sample of the last trough
        - ``sample_next_trough`` : sample of the next trough

    Examples
    --------
    Compute shape and burst features:

    >>> from neurodsp.sim import sim_bursty_oscillation
    >>> fs = 500
    >>> sig = sim_bursty_oscillation(10, fs, freq=10)
    >>> df_features = compute_features(sig, fs, f_range=(8, 12))
    """

    # Ensure arguments are within valid range
    check_param(fs, 'fs', (0, np.inf))

    # Compute shape features for each cycle
    df_shape_features = compute_shape_features(sig, fs, f_range, center_extrema=center_extrema,
                                               find_extrema_kwargs=find_extrema_kwargs)

    # Ensure kwargs are dictionaries
    if burst_method == 'amp' and not isinstance(burst_kwargs, dict):
        burst_kwargs = {}

    if not isinstance(threshold_kwargs, dict):
        threshold_kwargs = {}
        warnings.warn("""
            No burst detection thresholds are provided. This is not recommended. Please
            inspect your data and choose appropriate parameters for 'threshold_kwargs'.
            Default burst detection parameters are likely not well suited for your
            desired application.
            """)

    # Ensure required kwargs are set for amplitude burst detection
    if burst_method == 'amp':
        burst_kwargs['fs'] = fs
        burst_kwargs['f_range'] = f_range

    # Compute burst features for each cycle
    df_burst_features = compute_burst_features(df_shape_features, sig, burst_method=burst_method,
                                               burst_kwargs=burst_kwargs)

    # Concatenate shape and burst features
    df_features = pd.concat((df_burst_features, df_shape_features), axis=1)

    # Define whether or not each cycle is part of a burst
    if burst_method == 'cycles':
        df_features = detect_bursts_cycles(df_features, **threshold_kwargs)
    elif burst_method == 'amp':
        df_features = detect_bursts_amp(df_features, **threshold_kwargs)
    else:
        raise ValueError('Invalid argument for "burst_method". '
                         'Either "cycles" or "amp" must be specified.')

    df_features = drop_samples_df(df_features) if return_samples is False else df_features

    return df_features
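A sketch of the same call with explicit consistency thresholds; the threshold values are copied from the recompute_edges example earlier in this section, not a recommendation:

from neurodsp.sim import sim_bursty_oscillation

fs = 500
sig = sim_bursty_oscillation(10, fs, freq=10)

threshold_kwargs = {'amp_fraction_threshold': 0., 'amp_consistency_threshold': .5,
                    'period_consistency_threshold': .5, 'monotonicity_threshold': .4,
                    'min_n_cycles': 3}

df_features = compute_features(sig, fs, f_range=(8, 12), threshold_kwargs=threshold_kwargs)
print(df_features['is_burst'].mean())  # fraction of cycles flagged as bursting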
Example #8
def compute_features_2d(sigs,
                        fs,
                        f_range,
                        compute_features_kwargs=None,
                        axis=0,
                        return_samples=True,
                        n_jobs=-1,
                        progress=None):
    """Compute shape and burst features for a 2 dimensional array of signals.

    Parameters
    ----------
    sigs : 2d array
        Voltage time series, i.e. (n_channels, n_samples) or (n_epochs, n_samples).
    fs : float
        Sampling rate, in Hz.
    f_range : tuple of (float, float)
        Frequency range for narrowband signal of interest, in Hz.
    compute_features_kwargs : dict or list of dict
        Keyword arguments used in :func:`~.compute_features`.
    axis : {0, None}
        Which axis to calculate features across:

        - ``axis=0`` : Iterates over each row/signal in an array independently (i.e. for each
          channel in (n_channels, n_timepoints)).
        - ``axis=None`` : Flattens rows/signals prior to computing features (i.e. across flattened
          epochs in (n_epochs, n_timepoints)).

    return_samples : bool, optional, default: True
        Whether to return a dataframe of cyclepoint sample indices.
    n_jobs : int, optional, default: -1
        The number of jobs to compute features in parallel.
    progress : {None, 'tqdm', 'tqdm.notebook'}
        Specify whether to display a progress bar. Uses 'tqdm', if installed.

    Returns
    -------
    dfs_features : list of pandas.DataFrame
        Dataframes containing shape and burst features for each cycle.
        Each dataframe is computed using the :func:`~.compute_features` function.

    Notes
    -----

    - The order of ``dfs_features`` corresponds to the order of ``sigs``. This list of dataframes
      may be reorganized into a single dataframe using :func:`~.flatten_dfs`.
    - When ``axis=None``, parallel computation is not performed, since the array must first be
      flattened into one dimension.
    - If ``compute_features_kwargs`` is a dictionary, the same kwargs are applied across
      the first axis of ``sigs``. Otherwise, a list of dictionaries equal in length to the
      first axis of ``sigs`` is required to apply unique kwargs to each signal.
    - ``return_samples`` is controlled from the kwargs passed in this function. If
      ``return_samples`` is a key in ``compute_features_kwargs``, its value will be ignored.

    Examples
    --------
    Compute the features of a 2d array (n_epochs=10, n_samples=5000) containing epoched data:

    >>> import numpy as np
    >>> from neurodsp.sim import sim_bursty_oscillation
    >>> fs = 500
    >>> sigs = np.array([sim_bursty_oscillation(10, fs, 10) for i in range(10)])
    >>> compute_kwargs = {'burst_method': 'amp', 'threshold_kwargs':{'burst_fraction_threshold': 1}}
    >>> dfs_features = compute_features_2d(sigs, fs, f_range=(8, 12), axis=None,
    ...                                   compute_features_kwargs=compute_kwargs)

    Compute the features of a 2d array in parallel using the same compute_features kwargs. Note that
    each signal's features are computed separately in this case, which is recommended for
    (n_channels, n_samples) arrays:

    >>> compute_kwargs = {'burst_method': 'amp', 'threshold_kwargs':{'burst_fraction_threshold': 1}}
    >>> dfs_features = compute_features_2d(sigs, fs, f_range=(8, 12), n_jobs=2, axis=0,
    ...                                   compute_features_kwargs=compute_kwargs)

    Compute the features of a 2d array in parallel, using individualized settings per signal to
    examine the effect of various amplitude consistency thresholds:

    >>> sigs =  np.array([sim_bursty_oscillation(10, fs, freq=10)] * 10)
    >>> compute_kwargs = [{'threshold_kwargs': {'amp_consistency_threshold': thresh*.1}}
    ...                   for thresh in range(1, 11)]
    >>> dfs_features = compute_features_2d(sigs, fs, f_range=(8, 12), return_samples=False,
    ...                                   n_jobs=2, compute_features_kwargs=compute_kwargs, axis=0)
    """

    # Check compute_features_kwargs
    kwargs = deepcopy(compute_features_kwargs)
    kwargs = np.array(kwargs) if isinstance(kwargs, list) else kwargs

    check_kwargs_shape(sigs, kwargs, axis)

    kwargs = {} if kwargs is None else kwargs
    kwargs = [kwargs] if isinstance(kwargs, dict) else list(kwargs)

    # Drop return_samples argument, as it is set directly in the function call
    for kwarg in kwargs:
        kwarg.pop('return_samples', None)

    n_jobs = cpu_count() if n_jobs == -1 else n_jobs

    if axis == 0:
        # Compute each signal independently and in parallel
        with Pool(processes=n_jobs) as pool:

            if len(kwargs) > 1:
                # Map iterable sigs and kwargs together
                mapping = pool.imap(
                    partial(_proxy_2d,
                            fs=fs,
                            f_range=f_range,
                            return_samples=return_samples), zip(sigs, kwargs))

            else:
                # Only map sigs, kwargs are the same for each mapping
                mapping = pool.imap(
                    partial(compute_features,
                            fs=fs,
                            f_range=f_range,
                            return_samples=return_samples,
                            **kwargs[0]), sigs)

            dfs_features = list(progress_bar(mapping, progress, len(sigs)))

    elif axis is None:
        # Compute features after flattening the 2d array (i.e. calculated across a 1d signal)
        sig_flat = sigs.flatten()

        center_extrema = kwargs[0].pop('center_extrema', 'peak')

        df_flat = compute_features(sig_flat,
                                   fs=fs,
                                   f_range=f_range,
                                   return_samples=True,
                                   center_extrema=center_extrema,
                                   **kwargs[0])

        dfs_features = epoch_df(df_flat, len(sig_flat), len(sigs[0]))

        # Apply different thresholds if specified
        if len(kwargs) > 0:

            for idx, compute_kwargs in enumerate(kwargs):

                burst_method = compute_kwargs.pop('burst_method', 'cycles')
                thresholds = compute_kwargs.pop('threshold_kwargs', {})
                center_extrema_next = compute_kwargs.pop(
                    'center_extrema', None)

                if idx > 0 and center_extrema_next is not None \
                    and center_extrema_next != center_extrema:

                    warnings.warn('''
                        The same center extrema must be used when axis is None and
                        compute_features_kwargs is a list. Proceeding using the first
                        center_extrema: {extrema}.'''.format(
                        extrema=center_extrema))

                if burst_method == 'cycles':
                    dfs_features[idx] = detect_bursts_cycles(
                        dfs_features[idx], **thresholds)

                elif burst_method == 'amp':
                    dfs_features[idx] = detect_bursts_amp(
                        dfs_features[idx], **thresholds)

    else:
        raise ValueError("The axis kwarg must be either 0 or None.")

    return dfs_features
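A short sketch of consuming the returned list of dataframes, reusing the dual-threshold kwargs from the docstring examples above; one dataframe is returned per row of ``sigs``:

import numpy as np
from neurodsp.sim import sim_bursty_oscillation

fs = 500
sigs = np.array([sim_bursty_oscillation(10, fs, freq=10) for _ in range(5)])

compute_kwargs = {'burst_method': 'amp',
                  'threshold_kwargs': {'burst_fraction_threshold': 1}}
dfs_features = compute_features_2d(sigs, fs, f_range=(8, 12), axis=0, n_jobs=1,
                                   compute_features_kwargs=compute_kwargs)

# One dataframe per signal; summarize the amount of bursting in each
for idx, df_features in enumerate(dfs_features):
    print(f"signal {idx}: {df_features['is_burst'].mean():.2f} of cycles in a burst")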
Example #9
def compute_features(x, Fs, f_range,
                     center_extrema='P',
                     burst_detection_method='cycles',
                     burst_detection_kwargs=None,
                     find_extrema_kwargs=None,
                     hilbert_increase_N=False):
    """
    Segment a recording into individual cycles and compute
    features for each cycle

    Parameters
    ----------
    x : 1d array
        voltage time series
    Fs : float
        sampling rate (Hz)
    f_range : tuple of (float, float)
        frequency range for narrowband signal of interest (Hz)
    center_extrema : {'P', 'T'}
        The center extrema in the cycle
        'P' : cycles are defined trough-to-trough
        'T' : cycles are defined peak-to-peak
    burst_detection_method : {'cycles', 'amp'}
        Method for detecting bursts
        'cycles': detect bursts based on the consistency of consecutive periods and amplitudes
        'amp': detect bursts using an amplitude threshold
    burst_detection_kwargs : dict | None
        Keyword arguments for function to find label cycles
        as in or not in an oscillation
    find_extrema_kwargs : dict | None
        Keyword arguments for function to find peaks and
        troughs (cyclepoints.find_extrema) to change filter
        parameters or boundary.
        By default, it sets the filter length to three cycles
        of the low cutoff frequency (`f_range[0]`)
    hilbert_increase_N : bool
        corresponding kwarg for filt.amp_by_time
        If True, this zero-pads the signal when computing the
        Fourier transform, which can be necessary for
        computing it in a reasonable amount of time.

    Returns
    -------
    df : pandas.DataFrame
        dataframe containing several features and identifiers
        for each cycle. Each row is one cycle.
        Columns (listed for peak-centered cycles):
            - sample_peak : sample of 'x' at which the peak occurs
            - sample_zerox_decay : sample of the decaying zerocrossing
            - sample_zerox_rise : sample of the rising zerocrossing
            - sample_last_trough : sample of the last trough
            - sample_next_trough : sample of the next trough
            - period : period of the cycle
            - time_decay : time between peak and next trough
            - time_rise : time between peak and previous trough
            - time_peak : time between rise and decay zerocrosses
            - time_trough : duration of previous trough estimated by zerocrossings
            - volt_decay : voltage change between peak and next trough
            - volt_rise : voltage change between peak and previous trough
            - volt_amp : average of rise and decay voltage
            - volt_peak : voltage at the peak
            - volt_trough : voltage at the last trough
            - time_rdsym : fraction of cycle in the rise period
            - time_ptsym : fraction of cycle in the peak period
            - band_amp : average analytic amplitude of the oscillation
              computed using narrowband filtering and the Hilbert
              transform. Filter length is 3 cycles of the low
              cutoff frequency. Average taken across all time points
              in the cycle.
            - is_burst : True if the cycle is part of a detected oscillatory burst

    Notes
    -----
    Peak vs trough centering
        - By default, the first extrema analyzed will be a peak,
          and the final one a trough. In order to switch the preference,
          the signal is simply inverted and columns are renamed.
        - Columns are slightly different depending on if 'center_extrema'
          is set to 'P' or 'T'.
    """

    # Set defaults if user input is None
    if burst_detection_kwargs is None:
        burst_detection_kwargs = {}
        warnings.warn('''
            No burst detection parameters are provided.
            This is very much not recommended.
            Please inspect your data and choose appropriate
            parameters for "burst_detection_kwargs".
            Default burst detection parameters are likely
            not well suited for your desired application.
            ''')
    if find_extrema_kwargs is None:
        find_extrema_kwargs = {'filter_kwargs': {'N_cycles': 3}}
    else:
        # Raise an error if the user tries to switch from peak start to trough start
        if 'first_extrema' in find_extrema_kwargs.keys():
            raise ValueError('''This function has been designed
                to assume that the first extrema identified will be a peak.
                This cannot be overwritten at this time.''')

    # Negate signal to analyze trough-centered cycles
    if center_extrema == 'P':
        pass
    elif center_extrema == 'T':
        x = -x
    else:
        raise ValueError('Parameter "center_extrema" must be either "P" or "T"')

    # Find peak and trough locations in the signal
    Ps, Ts = find_extrema(x, Fs, f_range, **find_extrema_kwargs)

    # Find zero-crossings
    zeroxR, zeroxD = find_zerox(x, Ps, Ts)

    # For each cycle, identify the sample of each extrema and zerocrossing
    shape_features = {}
    shape_features['sample_peak'] = Ps[1:]
    shape_features['sample_zerox_decay'] = zeroxD[1:]
    shape_features['sample_zerox_rise'] = zeroxR
    shape_features['sample_last_trough'] = Ts[:-1]
    shape_features['sample_next_trough'] = Ts[1:]

    # Compute duration of period
    shape_features['period'] = shape_features['sample_next_trough'] - \
        shape_features['sample_last_trough']

    # Compute duration of peak
    shape_features['time_peak'] = shape_features['sample_zerox_decay'] - \
        shape_features['sample_zerox_rise']

    # Compute duration of last trough
    shape_features['time_trough'] = zeroxR - zeroxD[:-1]

    # Determine extrema voltage
    shape_features['volt_peak'] = x[Ps[1:]]
    shape_features['volt_trough'] = x[Ts[:-1]]

    # Determine rise and decay characteristics
    shape_features['time_decay'] = (Ts[1:] - Ps[1:])
    shape_features['time_rise'] = (Ps[1:] - Ts[:-1])

    shape_features['volt_decay'] = x[Ps[1:]] - x[Ts[1:]]
    shape_features['volt_rise'] = x[Ps[1:]] - x[Ts[:-1]]
    shape_features['volt_amp'] = (shape_features['volt_decay'] + shape_features['volt_rise']) / 2

    # Compute rise-decay symmetry features
    shape_features['time_rdsym'] = shape_features['time_rise'] / shape_features['period']

    # Compute peak-trough symmetry features
    shape_features['time_ptsym'] = shape_features['time_peak'] / (shape_features['time_peak'] + shape_features['time_trough'])

    # Compute average oscillatory amplitude estimate during cycle
    amp = amp_by_time(x, Fs, f_range, hilbert_increase_N=hilbert_increase_N, filter_kwargs={'N_cycles': 3})
    shape_features['band_amp'] = [np.mean(amp[Ts[i]:Ts[i + 1]]) for i in range(len(shape_features['sample_peak']))]

    # Convert feature dictionary into a DataFrame
    df = pd.DataFrame.from_dict(shape_features)

    # Define whether or not each cycle is part of a burst
    if burst_detection_method == 'cycles':
        df = detect_bursts_cycles(df, x, **burst_detection_kwargs)
    elif burst_detection_method == 'amp':
        df = detect_bursts_df_amp(df, x, Fs, f_range, **burst_detection_kwargs)
    else:
        raise ValueError('Invalid entry for "burst_detection_method"')

    # Rename columns if they are actually trough-centered
    if center_extrema == 'T':
        rename_dict = {'sample_peak': 'sample_trough',
                       'sample_zerox_decay': 'sample_zerox_rise',
                       'sample_zerox_rise': 'sample_zerox_decay',
                       'sample_last_trough': 'sample_last_peak',
                       'sample_next_trough': 'sample_next_peak',
                       'time_peak': 'time_trough',
                       'time_trough': 'time_peak',
                       'volt_peak': 'volt_trough',
                       'volt_trough': 'volt_peak',
                       'time_rise': 'time_decay',
                       'time_decay': 'time_rise',
                       'volt_rise': 'volt_decay',
                       'volt_decay': 'volt_rise'}
        df.rename(columns=rename_dict, inplace=True)

        # Need to reverse symmetry measures
        df['volt_peak'] = -df['volt_peak']
        df['volt_trough'] = -df['volt_trough']
        df['time_rdsym'] = 1 - df['time_rdsym']
        df['time_ptsym'] = 1 - df['time_ptsym']

    return df
Example #10
def recompute_edges(df_features,
                    threshold_kwargs,
                    burst_method='cycles',
                    burst_kwargs=None):
    """Recompute the is_burst column for cycles on the edges of bursts.

    Parameters
    ----------
    df_features : pandas.DataFrame
        A dataframe containing shape and burst features for each cycle.
    threshold_kwargs : dict
        Feature thresholds for cycles to be considered bursts, matching keyword arguments for:

        - :func:`~.detect_bursts_cycles` for consistency burst detection
          (i.e. when burst_method == 'cycles')

    burst_method : {'cycles', 'amp'}, optional, default: 'cycles'
        Method for detecting bursts.
    burst_kwargs : dict, optional, default: None
        Additional keyword arguments for amplitude threshold burst detection
        (i.e. when burst_method == 'amp').

    Returns
    -------
    df_features_edges : pandas.DataFrame
        A cycle feature dataframe with an updated ``is_burst`` column for edge cycles.

    Notes
    -----

    - `df_features` must be computed using consistency burst detection.

    Examples
    --------
    Lower the amplitude consistency threshold to zero for cycles on the edges of bursts:

    >>> from neurodsp.sim import sim_combined
    >>> from bycycle.features import compute_features
    >>> sig = sim_combined(n_seconds=4, fs=1000, components={'sim_bursty_oscillation': {'freq': 10},
    ...                                                      'sim_powerlaw': {'exp': 2}})
    >>> threshold_kwargs = {'amp_fraction_threshold': 0., 'amp_consistency_threshold': .5,
    ...                     'period_consistency_threshold': .5, 'monotonicity_threshold': .4,
    ...                     'min_n_cycles': 3}
    >>> df_features = compute_features(sig, fs=1000, f_range=(8, 12),
    ...                                threshold_kwargs=threshold_kwargs)
    >>> threshold_kwargs['amp_consistency_threshold'] = 0
    >>> df_features_edges = recompute_edges(df_features, threshold_kwargs)
    """

    # Prevent circular imports between burst.utils and burst.cycle
    from bycycle.burst import detect_bursts_cycles

    # Prevent overwriting the original dataframe
    df_features_edges = df_features.copy()

    # Identify all cycles where is_burst changes on the following cycle
    #   Use copy to keep dataframe columns unlinked
    is_burst = deepcopy(df_features_edges['is_burst'].values)
    burst_edges = np.where(is_burst[1:] == ~is_burst[:-1])[0]

    # Get cycles outside of bursts
    burst_starts = np.array(
        [edge for idx, edge in enumerate(burst_edges) if idx % 2 == 0])
    burst_ends = np.array(
        [edge + 1 for idx, edge in enumerate(burst_edges) if idx % 2 == 1])

    # Recompute is_burst for cycles at the edge
    for start_idx, end_idx in zip(burst_starts, burst_ends):

        df_features_edges = recompute_edge(df_features_edges, start_idx, 'next')
        df_features_edges = recompute_edge(df_features_edges, end_idx, 'last')

    # Apply thresholding
    df_features_edges = detect_bursts_cycles(df_features_edges,
                                             **threshold_kwargs)

    return df_features_edges
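The start/end pairing used above can also be illustrated on a toy ``is_burst`` array (values made up); each start is the non-burst cycle just before a burst and each end the non-burst cycle just after it:

import numpy as np

is_burst = np.array([False, False, True, True, True, False, False, True, True, False])

burst_edges = np.where(is_burst[1:] == ~is_burst[:-1])[0]
burst_starts = np.array([edge for idx, edge in enumerate(burst_edges) if idx % 2 == 0])
burst_ends = np.array([edge + 1 for idx, edge in enumerate(burst_edges) if idx % 2 == 1])

print(burst_starts)  # [1 6] -> non-burst cycles just before each burst
print(burst_ends)    # [5 9] -> non-burst cycles just after each burst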