Esempio n. 1
0
###################################################################################################

# Burst settings
# NOTE: `peak_cf`, `sig`, `fs` and `times` are not defined in this section; they are expected
# to come from earlier in the script.
# Pair of amplitude thresholds handed to the dual-threshold burst detector below
# (presumably expressed relative to the signal's typical amplitude -- confirm in the detector docs).
amp_dual_thresh = (1., 1.5)
# Frequency band of interest: from 2 below to 2 above `peak_cf` (presumably in Hz -- confirm).
f_range = (peak_cf - 2, peak_cf + 2)

###################################################################################################

# Detect bursts of high amplitude oscillations in the extracted signal
# The result is reused below as the burst marker passed to `plot_bursts`.
bursting = detect_bursts_dual_threshold(sig, fs, amp_dual_thresh, f_range)

###################################################################################################

# Plot original signal and burst activity
# The two labels are given in the order: signal first, detected bursts second.
plot_bursts(times, sig, bursting, labels=['Raw Data', 'Detected Bursts'])

###################################################################################################
# Measure Rhythmicity with Lagged Coherence
# -----------------------------------------
#
# So far, in an example channel of data, we have explored some rhythmic properties of
# the EEG data. We did so by finding the main frequency of rhythmic power, and checking
# this frequency for bursting.
#
# Next, we can try applying similar analyses across all channels, to measure
# rhythmic properties across the whole dataset.
#
# To do so, let's use the lagged coherence measure. Using lagged coherence, we can measure,
# per channel, the frequency which displays the most rhythmic activity, as well as the score
# of how rhythmic this frequency is. We can do so per location, to map rhythmic activity
Esempio n. 2
0
def plot_spikes(df_features,
                sig,
                fs,
                spikes=None,
                index=None,
                xlim=None,
                ax=None):
    """Plot a group of spikes or the cyclepoints for an individual spike.

    One of three plots is drawn, selected by ``index`` and ``spikes``:

    - ``index`` is given: the single spike at that row, with its cyclepoints marked.
    - ``index`` is None and ``spikes`` is given: all split spikes stacked on one axis.
    - both are None: the continuous signal, with spike samples marked via ``plot_bursts``.

    Parameters
    ----------
    df_features : pandas.DataFrame
        Dataframe containing shape and burst features for each spike. The 'sample_start',
        'sample_end' and 'period' columns, plus the cyclepoint columns, are read here.
    sig : 1d or 2d array
        Voltage timeseries. May be 2d if spikes are split.
    fs : float
        Sampling rate, in Hz.
    spikes : 2d array, optional, default: None
        Spikes that have been split into a 2d array. Ignored if ``index`` is passed.
    index : int, optional, default: None
        The positional index in ``df_features`` to plot. If None, plot all spikes.
    xlim : tuple of (float, float), optional, default: None
        Lower and upper time limits. Only used for the continuous plot, i.e. ignored if
        ``spikes`` or ``index`` is passed.
    ax : matplotlib.Axes, optional, default: None
        Figure axes upon which to plot.
    """

    plot_single = index is not None
    plot_stack = not plot_single and spikes is not None

    # Pick the default figure size before the axes are created, so that the wider size
    # intended for the continuous plot is actually applied.
    # NOTE(review): assumes check_ax only builds a new figure when ``ax`` is None -- confirm.
    figsize = (10, 4) if plot_single or plot_stack else (15, 3)
    ax = check_ax(ax, figsize)

    center_e, _ = get_extrema_df(df_features)

    # Plot a single spike
    if plot_single:

        # Build times from integer sample indices: a float step in np.arange can yield
        # one sample too many, which would desynchronise times and sig
        times = np.arange(len(sig)) / fs

        # Get where spike starts/ends, as plain ints so they are valid slice bounds
        row = df_features.iloc[index]
        start = int(row['sample_start'])
        end = int(row['sample_end'])

        # Plot the spike waveform (end sample inclusive)
        plot_time_series(times[start:end + 1], sig[start:end + 1], ax=ax)

        # Plot cyclepoints, one marker per key
        labels, keys = _infer_labels(center_e)
        colors = ['C0', 'C1', 'C2', 'C3']

        for idx, key in enumerate(keys):

            sample = int(row[key])

            plot_time_series(np.array([times[sample]]),
                             np.array([sig[sample]]),
                             colors=colors[idx],
                             labels=labels[idx],
                             ls='',
                             marker='o',
                             ax=ax)

    # Plot as stack of spikes
    elif plot_stack:

        # The time axis is sized from the first spike
        times = np.arange(len(spikes[0])) / fs

        plot_time_series(times, spikes, ax=ax)

    # Plot as continuous timeseries
    else:

        times = np.arange(len(sig)) / fs

        plot_time_series(times, sig, ax=ax, xlim=xlim)

        if xlim is None:
            sig_lim, times_lim, df_lim = sig, times, df_features
            offset = 0
        else:
            # Keep only spikes that lie entirely inside the requested window
            in_window = (df_features['sample_start'].values >= xlim[0] * fs) & \
                (df_features['sample_end'].values <= xlim[1] * fs)

            df_lim = df_features.iloc[in_window]

            sig_lim, times_lim = limit_signal(times,
                                              sig,
                                              start=xlim[0],
                                              stop=xlim[1])

            # Sample indices become relative to the start of the limited signal
            offset = int(fs * xlim[0])

        # Integer arrays, so the values below are valid slice bounds
        starts = df_lim['sample_start'].values.astype(int) - offset
        ends = starts + df_lim['period'].values.astype(int)

        is_spike = np.zeros(len(sig_lim), dtype='bool')

        for start, end in zip(starts, ends):
            is_spike[start:end] = True

        plot_bursts(times_lim, sig_lim, is_spike, ax=ax)
Esempio n. 3
0
def plot_burst_detect_summary(df,
                              sig,
                              fs,
                              burst_detection_kwargs,
                              xlim=None,
                              figsize=(15, 3),
                              plot_only_result=False,
                              interp=True):
    """Plot how the cycle-by-cycle burst detection algorithm labels bursting periods.

    Parameters
    ----------
    df : pandas.DataFrame
        Dataframe output of :func:`~.compute_features`. The 'is_burst' column, the
        'sample_last_*' / 'sample_next_*' columns and one column per threshold are read here.
    sig : 1d array
        Time series to plot.
    fs : float
        Sampling rate, in Hz.
    burst_detection_kwargs : dict
        Burst parameter keys and threshold value pairs, as defined in the 'burst_detection_kwargs'
        argument of :func:`.compute_features`. The dict is not modified.
    xlim : tuple of (float, float), optional, default: None
        Start and stop times for plot. Defaults to the full extent of the signal.
    figsize : tuple of (float, float), optional, default: (15, 3)
        Size of each plot.
    plot_only_result : bool, optional, default: False
        Plot only the signal and bursts, excluding burst parameter plots.
    interp : bool, optional, default: True
        If True, interpolates between given values. Otherwise, plots in a step-wise fashion.

    Notes
    -----

    - The figure is drawn as a side effect; this function has no return statement.

    - If plot_only_result = True: a single axis is created, showing the z-scored signal
      with the detected bursts and extrema.

    - If plot_only_result = False: one additional axis is created per threshold key,
      sharing the x-axis with the top plot.

    - The 'n_cycles_min' key is ignored, since it isn't stored cycle by cycle in the df.

    - In the top plot, the highlighted regions indicate the cycles in which a burst parameter
      fell below its threshold, color-coded consistently with the plots below. Colors are
      assigned in the order the keys appear in ``burst_detection_kwargs``, cycling through
      blue, red, yellow, green, cyan, magenta, orange. For the key order below this gives:

      - blue: amp_fraction_threshold
      - red: amp_consistency_threshold
      - yellow: period_consistency_threshold
      - green: monotonicity_threshold

    """

    # Normalize signal
    sig = zscore(sig)

    # Determine time array and limits. Times are built from integer sample indices because
    # a float step in np.arange can yield one sample too many.
    times = np.arange(len(sig)) / fs
    xlim = (times[0], times[-1]) if xlim is None else xlim

    # Determine if peaks or troughs are the sides of an oscillation
    _, side_e = get_extrema(df)
    last_col = 'sample_last_' + side_e
    next_col = 'sample_next_' + side_e

    # Drop this kwarg since it isn't stored cycle by cycle in the df (nothing to plot).
    # A filtered copy is used so the caller's dict is left untouched.
    thresholds = {key: value for key, value in burst_detection_kwargs.items()
                  if key != 'n_cycles_min'}

    n_kwargs = len(thresholds)

    # Create figure and subplots
    if plot_only_result:
        fig, axes = plt.subplots(figsize=figsize, nrows=1)
        axes = [axes]
    else:
        fig, axes = plt.subplots(figsize=(figsize[0],
                                          figsize[1] * (n_kwargs + 1)),
                                 nrows=n_kwargs + 1,
                                 sharex=True)

    # Determine which samples are defined as bursting. Sample indices are cast to int so
    # they are valid slice bounds whatever dtype the dataframe stores them as.
    is_osc = np.zeros(len(sig), dtype=bool)
    burst_bounds = df.loc[df['is_burst'], [last_col, next_col]].values.astype(int)

    for samp_start_burst, samp_last_burst in burst_bounds:
        # The closing extremum is part of the burst, hence the + 1
        is_osc[samp_start_burst:samp_last_burst + 1] = True

    # Plot bursts with extrema points
    xlabel = 'Time (s)' if len(axes) == 1 else ''

    plot_bursts(times,
                sig,
                is_osc,
                ax=axes[0],
                xlim=xlim,
                lw=2,
                labels=['Signal', 'Bursts'],
                xlabel='',
                ylabel='')

    plot_cyclepoints_df(df,
                        sig,
                        fs,
                        ax=axes[0],
                        xlim=xlim,
                        plot_zerox=False,
                        plot_sig=False,
                        xlabel=xlabel,
                        ylabel='Voltage\n(normalized)',
                        colors=['m', 'c'])

    # Plot each burst param
    colors = cycle(
        ['blue', 'red', 'yellow', 'green', 'cyan', 'magenta', 'orange'])

    # Cycle boundaries are the same for every threshold, so extract them once
    last_samps = df[last_col].values.astype(int)
    next_samps = df[next_col].values.astype(int)

    for idx, (osc_key, threshold) in enumerate(thresholds.items()):

        column = osc_key.replace('_threshold', '')

        color = next(colors)

        # Highlight where a burst param falls below threshold
        below = df[column].values < threshold

        for samp_last, samp_next in zip(last_samps[below], next_samps[below]):
            axes[0].axvspan(times[samp_last],
                            times[samp_next],
                            alpha=0.5,
                            color=color,
                            lw=0)

        # Plot each burst param on separate axes
        if not plot_only_result:

            ylabel = column.replace('_', ' ').capitalize()
            # Only the bottom axis carries the x-axis label
            xlabel = 'Time (s)' if idx == n_kwargs - 1 else ''

            plot_burst_detect_param(df,
                                    sig,
                                    fs,
                                    column,
                                    threshold,
                                    figsize=figsize,
                                    ax=axes[idx + 1],
                                    xlim=xlim,
                                    xlabel=xlabel,
                                    ylabel=ylabel,
                                    color=color,
                                    interp=interp)
Esempio n. 4
0
def plot_burst_detect_summary(df_features,
                              sig,
                              fs,
                              threshold_kwargs,
                              xlim=None,
                              figsize=(15, 3),
                              plot_only_result=False,
                              interp=True):
    """Plot the cycle-by-cycle burst detection parameters and burst detection summary.

    Parameters
    ----------
    df_features : pandas.DataFrame
        Dataframe output of :func:`~.compute_features`. The df must contain sample indices (i.e.
        when ``return_samples = True``).
    sig : 1d array
        Time series to plot.
    fs : float
        Sampling rate, in Hz.
    threshold_kwargs : dict
        Burst parameter keys and threshold value pairs, as defined in the 'threshold_kwargs'
        argument of :func:`.compute_features`. The dict is not modified.
    xlim : tuple of (float, float), optional, default: None
        Start and stop times for plot.
    figsize : tuple of (float, float), optional, default: (15, 3)
        Size of each plot.
    plot_only_result : bool, optional, default: False
        Plot only the signal and bursts, excluding burst parameter plots.
    interp : bool, optional, default: True
        If True, interpolates between given values. Otherwise, plots in a step-wise fashion.

    Notes
    -----

    - The figure is drawn as a side effect; this function has no return statement.

    - If plot_only_result = True: a single axis is created, showing the z-scored signal
      with the detected bursts and extrema.

    - If plot_only_result = False: one additional axis is created per threshold key,
      sharing the x-axis with the top plot.

    - The 'min_n_cycles' key is ignored, since it isn't stored cycle by cycle in the df.

    - When ``xlim`` is given, cycles that only partly overlap the window are clipped to the
      window edges, for both the burst marking and the highlighted regions.

    - In the top plot, the highlighted regions indicate the cycles in which a burst parameter
      fell below its threshold, color-coded consistently with the plots below. Colors are
      assigned in the order the keys appear in ``threshold_kwargs``, cycling through
      blue, red, yellow, green, cyan, magenta, orange. For the key order below this gives:

      - blue: amp_fraction_threshold
      - red: amp_consistency_threshold
      - yellow: period_consistency_threshold
      - green: monotonicity_threshold

    Examples
    --------
    Plot the burst detection summary of a bursting signal:

    >>> from bycycle.features import compute_features
    >>> from neurodsp.sim import sim_bursty_oscillation
    >>> fs = 500
    >>> sig = sim_bursty_oscillation(10, fs, freq=10)
    >>> threshold_kwargs = {'amp_fraction_threshold': 0., 'amp_consistency_threshold': .5,
    ...                     'period_consistency_threshold': .5, 'monotonicity_threshold': .8}
    >>> df_features = compute_features(sig, fs, f_range=(8, 12), threshold_kwargs=threshold_kwargs)
    >>> plot_burst_detect_summary(df_features, sig, fs, threshold_kwargs)
    """

    # Ensure arguments are within valid range
    check_param_range(fs, 'fs', (0, np.inf))

    # Normalize signal. Times are built from integer sample indices because a float step
    # in np.arange can yield one sample too many.
    sig_full = zscore(sig)
    times_full = np.arange(len(sig_full)) / fs

    # Limit arrays and dataframe
    if xlim is not None:
        sig, times = limit_signal(times_full,
                                  sig_full,
                                  start=xlim[0],
                                  stop=xlim[1])
        df_features = limit_df(df_features,
                               fs,
                               start=xlim[0],
                               stop=xlim[1],
                               reset_indices=False)
        # The df keeps absolute sample indices (reset_indices=False); this offset maps
        # them onto the limited signal
        offset = int(fs * xlim[0])
    else:
        sig, times = sig_full, times_full
        offset = 0

    # Determine if peaks or troughs are the sides of an oscillation
    _, side_e = get_extrema_df(df_features)
    last_col = 'sample_last_' + side_e
    next_col = 'sample_next_' + side_e

    # Drop this kwarg since it isn't stored cycle by cycle in the df (nothing to plot).
    # A filtered copy is used so the caller's dict is left untouched.
    thresholds = {key: value for key, value in threshold_kwargs.items()
                  if key != 'min_n_cycles'}

    n_kwargs = len(thresholds)

    # Create figure and subplots
    if plot_only_result:
        fig, axes = plt.subplots(figsize=figsize, nrows=1)
        axes = [axes]
    else:
        fig, axes = plt.subplots(figsize=(figsize[0],
                                          figsize[1] * (n_kwargs + 1)),
                                 nrows=n_kwargs + 1,
                                 sharex=True)

    # Determine which samples are defined as bursting
    n_samples = len(sig)
    is_osc = np.zeros(n_samples, dtype=bool)
    burst_bounds = df_features.loc[df_features['is_burst'],
                                   [last_col, next_col]].values.astype(int) - offset

    for samp_start_burst, samp_last_burst in burst_bounds:
        # Clamp to zero: a negative bound would be read as "count from the end" by the
        # slice, leaving a burst that starts before the window unmarked. The closing
        # extremum is part of the burst, hence the + 1.
        is_osc[max(samp_start_burst, 0):max(samp_last_burst + 1, 0)] = True

    # Plot bursts with extrema points
    xlabel = 'Time (s)' if len(axes) == 1 else ''

    plot_bursts(times,
                sig,
                is_osc,
                ax=axes[0],
                lw=2,
                labels=['Signal', 'Bursts'],
                xlabel='',
                ylabel='')

    plot_cyclepoints_df(df_features,
                        sig_full,
                        fs,
                        xlim=xlim,
                        ax=axes[0],
                        plot_zerox=False,
                        plot_sig=False,
                        xlabel=xlabel,
                        ylabel='Voltage\n(normalized)',
                        colors=['m', 'c'])

    # Plot each burst param
    colors = cycle(
        ['blue', 'red', 'yellow', 'green', 'cyan', 'magenta', 'orange'])

    # Cycle boundaries relative to the plotted signal; identical for every threshold
    last_samps = df_features[last_col].values.astype(int) - offset
    next_samps = df_features[next_col].values.astype(int) - offset

    for idx, (osc_key, threshold) in enumerate(thresholds.items()):

        column = osc_key.replace('_threshold', '')

        color = next(colors)

        # Highlight where a burst param falls below threshold
        below = df_features[column].values < threshold

        for last_cyc, next_cyc in zip(last_samps[below], next_samps[below]):

            # Clip the span to the plotted samples; skip cycles entirely outside of them
            span_start = max(last_cyc, 0)
            span_stop = min(next_cyc, n_samples - 1)
            if span_start > span_stop:
                continue

            axes[0].axvspan(times[span_start],
                            times[span_stop],
                            alpha=0.5,
                            color=color,
                            lw=0)

        # Plot each burst param on separate axes
        if not plot_only_result:

            ylabel = column.replace('_', ' ').capitalize()
            # Only the bottom axis carries the x-axis label
            xlabel = 'Time (s)' if idx == n_kwargs - 1 else ''

            plot_burst_detect_param(df_features,
                                    sig_full,
                                    fs,
                                    column,
                                    threshold,
                                    xlim=xlim,
                                    figsize=figsize,
                                    ax=axes[idx + 1],
                                    xlabel=xlabel,
                                    ylabel=ylabel,
                                    color=color,
                                    interp=interp)