コード例 #1
0
def plot_cyclepoints_df(df_samples,
                        sig,
                        fs,
                        plot_sig=True,
                        plot_extrema=True,
                        plot_zerox=True,
                        xlim=None,
                        ax=None,
                        **kwargs):
    """Plot extrema and/or zero-crossings from a DataFrame.

    Parameters
    ----------
    df_samples : pandas.DataFrame
        Dataframe output of :func:`~.compute_cyclepoints`.
    sig : 1d array
        Time series to plot.
    fs : float
        Sampling rate, in Hz. Must be positive.
    plot_sig : bool, optional, default: True
        Whether to also plot the raw signal.
    plot_extrema : bool, optional, default: True
        Whether to plot the peaks and troughs.
    plot_zerox : bool, optional, default: True
        Whether to plot the zero-crossings.
    xlim : tuple of (float, float), optional
        Start and stop times.
    ax : matplotlib.Axes, optional
        Figure axes upon which to plot.
    **kwargs
        Keyword arguments passed through to :func:`plot_cyclepoints_array`
        (and from there, presumably, to `plot_time_series`).

    Notes
    -----
    Default keyword arguments include:

    - ``figsize``: tuple of (float, float), default: (15, 3)
    - ``xlabel``: str, default: 'Time (s)'
    - ``ylabel``: str, default: 'Voltage (uV)'

    The cycle-center extrema are passed on as ``peaks`` and the deduplicated side extrema
    as ``troughs``, whichever extremum type happens to be the cycle center.

    Examples
    --------
    Plot cyclepoints using a dataframe from :func:`~.compute_cyclepoints`:

    >>> from bycycle.features import compute_cyclepoints
    >>> from neurodsp.sim import sim_bursty_oscillation
    >>> fs = 500
    >>> sig = sim_bursty_oscillation(10, fs, freq=10)
    >>> df_samples = compute_cyclepoints(sig, fs, f_range=(8, 12))
    >>> plot_cyclepoints_df(df_samples, sig, fs)
    """

    # Reject non-positive sampling rates before doing any work
    check_param_range(fs, 'fs', (0, np.inf))

    # Which extremum type sits at the cycle center and which bounds each cycle
    center_e, side_e = get_extrema_df(df_samples)

    peaks = troughs = rises = decays = None

    if plot_extrema:
        peaks = df_samples['sample_' + center_e].values
        # Side extrema can repeat across neighbouring cycles (presumably the next side of
        #   one cycle is the last side of the following one), so merge both and deduplicate
        side_samples = np.append(df_samples['sample_last_' + side_e].values,
                                 df_samples['sample_next_' + side_e].values)
        troughs = np.unique(side_samples)

    if plot_zerox:
        rises = df_samples['sample_zerox_rise'].values
        decays = df_samples['sample_zerox_decay'].values

    plot_cyclepoints_array(sig, fs,
                           peaks=peaks, troughs=troughs,
                           rises=rises, decays=decays,
                           plot_sig=plot_sig, xlim=xlim, ax=ax,
                           **kwargs)
コード例 #2
0
def plot_burst_detect_summary(df_features,
                              sig,
                              fs,
                              threshold_kwargs,
                              xlim=None,
                              figsize=(15, 3),
                              plot_only_result=False,
                              interp=True):
    """Plot the cycle-by-cycle burst detection parameters and burst detection summary.

    Parameters
    ----------
    df_features : pandas.DataFrame
        Dataframe output of :func:`~.compute_features`. The df must contain sample indices (i.e.
        when ``return_samples = True``).
    sig : 1d array
        Time series to plot.
    fs : float
        Sampling rate, in Hz. Must be positive.
    threshold_kwargs : dict
        Burst parameter keys and threshold value pairs, as defined in the 'threshold_kwargs'
        argument of :func:`.compute_features`. The dict is not modified.
    xlim : tuple of (float, float), optional, default: None
        Start and stop times for plot. Defaults to the full signal.
    figsize : tuple of (float, float), optional, default: (15, 3)
        Size of each plot.
    plot_only_result : bool, optional, default: False
        Plot only the signal and bursts, excluding burst parameter plots.
    interp : bool, optional, default: True
        If True, interpolates between given values. Otherwise, plots in a step-wise fashion.

    Notes
    -----

    - Nothing is returned; the plots are drawn on a newly created figure.

    - If plot_only_result = True: a single plot of the burst detection is drawn, in which
      periods with bursts are denoted in red.

    - If plot_only_result = False: below the summary, one subplot is drawn per threshold in
      ``threshold_kwargs`` (``min_n_cycles`` is skipped since it has no per-cycle value).

    - In the top plot, the z-scored signal is plotted in black, and the red line indicates
      periods defined as oscillatory bursts. The highlighted regions indicate cycles whose
      value was at or below a burst requirement's threshold, color-coded consistently with
      the plots below. Colors are assigned in the order of the ``threshold_kwargs`` keys,
      cycling through blue, red, yellow, green, cyan, magenta and orange. For the
      ``threshold_kwargs`` used in the example below this gives:

      - blue: amp_fraction_threshold
      - red: amp_consistency_threshold
      - yellow: period_consistency_threshold
      - green: monotonicity_threshold

    Examples
    --------
    Plot the burst detection summary of a bursting signal:

    >>> from bycycle.features import compute_features
    >>> from neurodsp.sim import sim_bursty_oscillation
    >>> fs = 500
    >>> sig = sim_bursty_oscillation(10, fs, freq=10)
    >>> threshold_kwargs = {'amp_fraction_threshold': 0., 'amp_consistency_threshold': .5,
    ...                     'period_consistency_threshold': .5, 'monotonicity_threshold': .8}
    >>> df_features = compute_features(sig, fs, f_range=(8, 12), threshold_kwargs=threshold_kwargs)
    >>> plot_burst_detect_summary(df_features, sig, fs, threshold_kwargs)
    """

    # Ensure arguments are within valid range
    check_param_range(fs, 'fs', (0, np.inf))

    # Normalize signal
    sig = zscore(sig)

    # Determine time array and limits. Dividing integer sample indices by fs guarantees
    #   exactly len(sig) points; a float-step np.arange can produce one extra point.
    times = np.arange(len(sig)) / fs
    xlim = (times[0], times[-1]) if xlim is None else xlim

    # Determine if peaks or troughs are the sides of an oscillation
    _, side_e = get_extrema_df(df_features)

    # Remove min_n_cycles since it isn't stored cycle by cycle in the df (nothing to plot).
    #   Work on a copy so the caller's dict is left untouched.
    thresholds = dict(threshold_kwargs)
    thresholds.pop('min_n_cycles', None)

    n_kwargs = len(thresholds)

    # Create figure and subplots
    if plot_only_result:
        fig, axes = plt.subplots(figsize=figsize, nrows=1)
        axes = [axes]
    else:
        fig, axes = plt.subplots(figsize=(figsize[0],
                                          figsize[1] * (n_kwargs + 1)),
                                 nrows=n_kwargs + 1,
                                 sharex=True)

    # Determine which samples are defined as bursting
    is_osc = np.zeros(len(sig), dtype=bool)
    df_osc = df_features.loc[df_features['is_burst']]

    for _, cyc in df_osc.iterrows():

        samp_start_burst = int(cyc['sample_last_' + side_e])
        samp_end_burst = int(cyc['sample_next_' + side_e] + 1)
        is_osc[samp_start_burst:samp_end_burst] = True

    # Only the bottom axes gets the time label
    xlabel = 'Time (s)' if len(axes) == 1 else ''

    # Plot bursts with extrema points
    plot_bursts(times,
                sig,
                is_osc,
                ax=axes[0],
                xlim=xlim,
                lw=2,
                labels=['Signal', 'Bursts'],
                xlabel='',
                ylabel='')

    plot_cyclepoints_df(df_features,
                        sig,
                        fs,
                        ax=axes[0],
                        xlim=xlim,
                        plot_zerox=False,
                        plot_sig=False,
                        xlabel=xlabel,
                        ylabel='Voltage\n(normalized)',
                        colors=['m', 'c'])

    # Plot each burst param
    colors = cycle(
        ['blue', 'red', 'yellow', 'green', 'cyan', 'magenta', 'orange'])

    for idx, (osc_key, thresh) in enumerate(thresholds.items()):

        column = osc_key.replace('_threshold', '')

        color = next(colors)

        # Highlight cycles that fail this criterion. A cycle only passes when its value is
        #   strictly greater than the threshold, so values equal to it are highlighted too
        #   (matching plot_burst_detect_param).
        df_fail = df_features[df_features[column] <= thresh]

        for _, cyc in df_fail.iterrows():
            axes[0].axvspan(times[int(cyc['sample_last_' + side_e])],
                            times[int(cyc['sample_next_' + side_e])],
                            alpha=0.5,
                            color=color,
                            lw=0)

        # Plot each burst param on separate axes
        if not plot_only_result:

            ylabel = column.replace('_', ' ').capitalize()
            xlabel = 'Time (s)' if idx == n_kwargs - 1 else ''

            plot_burst_detect_param(df_features,
                                    sig,
                                    fs,
                                    column,
                                    thresh,
                                    figsize=figsize,
                                    ax=axes[idx + 1],
                                    xlim=xlim,
                                    xlabel=xlabel,
                                    ylabel=ylabel,
                                    color=color,
                                    interp=interp)
コード例 #3
0
ファイル: bycycle.py プロジェクト: voytekresearch/ndspflow
def plot_bg(dfs_features, sigs, fs, titles=None, btn=True, xlim=None):
    """Plot 2D bycycle results.

    Parameters
    ----------
    dfs_features : list of pandas.DataFrame
        Dataframes containing shape and burst features for each cycle, one per signal.
    sigs : 2d array
        Time series, indexed by signal along the first axis.
    fs : float
        Sampling rate, in Hz.
    titles : list, optional, default: None
        The titles for each figure (signals are grouped 10 per figure). Missing titles
        default to "Indices: {start}-{end}". The caller's list is not modified.
    btn : bool, optional, default: True
        Adds a recompute bursts button when True. Omits when False.
    xlim : tuple of (float, float), optional, default: None
        Start and stop times for plot. When None, each signal uses its own
        ``(times[0], times[-1])``.

    Returns
    -------
    graph : str
        The bycycle plot as a string containing html.

    Notes
    -----
    The returned html calls the javascript functions ``recolorBursts`` and, when ``btn`` is
    True, ``rewriteBursts``; these must be defined by the page that embeds it.
    """

    # Initialize figures in groups of 10
    #   Plotly doesn't render single figures well with 100+ plots
    n_per_fig = 10
    n_figs = int(np.ceil(len(sigs) / n_per_fig))

    # Copy so default titles appended below never leak into the caller's list
    titles = [] if titles is None else list(titles)

    figs = []
    n_rows = []
    for idx in range(n_figs):

        # Figure titles
        start = idx * n_per_fig
        end = start + len(sigs[start:start + n_per_fig])

        if len(titles) == idx:
            titles.append("Indices: {start}-{end}".format(start=start,
                                                          end=end))

        # Create subplots, one row per signal in this group
        n_rows.append(end - start)
        figs.append(make_subplots(rows=n_rows[idx],
                                  cols=1,
                                  vertical_spacing=0.005,
                                  shared_xaxes=True))

    for idx, df_features in enumerate(dfs_features):

        # Normalize signal
        sig = zscore(sigs[idx])

        # Time array from integer sample indices, so it always has len(sig) points
        #   (a float-step np.arange can yield one extra point and then disagree with sig)
        times = np.arange(len(sig)) / fs

        # Resolve limits per signal; never overwrite the xlim argument inside the loop
        sig_xlim = (times[0], times[-1]) if xlim is None else xlim

        # Limit signal and dataframe
        sig, times = limit_signal(times, sig, start=sig_xlim[0], stop=sig_xlim[1])
        df_features = limit_df(df_features, fs, start=sig_xlim[0], stop=sig_xlim[1])

        # Get extrema
        center_e, side_e = get_extrema_df(df_features)

        # Figure group and row within that figure for this signal
        fig_idx, row_idx = divmod(idx, n_per_fig)

        # Plot bursts
        _plot_bursts(df_features,
                     sig,
                     times,
                     center_e,
                     side_e,
                     figs[fig_idx],
                     plot_cps=False,
                     row=row_idx + 1,
                     col=1)

    graphs = []
    for idx, fig in enumerate(figs):

        # The size of plotly subplots don't scale properly, this is a workaround
        height = (60 * n_rows[idx]) + ((n_per_fig - n_rows[idx]) * 18)

        # Update the figures
        fig.update_layout(autosize=False,
                          width=1000,
                          height=height,
                          showlegend=False,
                          margin_autoexpand=False,
                          title_text=titles[idx])
        fig.update_yaxes(showticklabels=False, showgrid=False)
        fig.update_xaxes(showgrid=False)
        fig.update_xaxes(rangeslider={'visible': True},
                         row=n_rows[idx],
                         col=1)

        # Convert to html; later figures keep only their plot div
        html = fig.to_html(include_plotlyjs=False, full_html=False)

        if idx == 0:
            graphs.append(html)
        else:
            div = re.search("<div>.*</div>", html)[0]
            graphs.append(div + "\n")

    # Custom js callback
    js_callback = ["<script type=\"text/javascript\">\n"]

    # Get plot div ids ('<div id="' is 9 characters; the trailing quote is dropped)
    div_ids = [re.search("<div id=.+?\"", graph)[0][9:-1] for graph in graphs]

    # Recolor (non)bursts on click
    for div_id in div_ids:

        js_callback.append("""
        recolorBursts('{plotID}');
        """.format(plotID=div_id))

    js_callback.append("</script>\n")

    # Flatten lists into single string
    graphs.append("".join(js_callback))

    if btn:

        rewrite_call = "rewriteBursts({div_ids})".format(div_ids=div_ids)

        # Add a button
        btn_html = "\n\t\t<p><center><button onclick=\"" + rewrite_call + "\" class=\"btn\" "
        btn_html = btn_html + "title=\"update is_burst column\">Update Bursts</button></center></p>"
        graphs.append(btn_html)

    graphs[0] = re.sub("</body>\n</html>", "", graphs[0])
    graphs = "".join(graphs)

    return graphs
コード例 #4
0
def plot_burst_detect_param(df_features,
                            sig,
                            fs,
                            burst_param,
                            thresh,
                            xlim=None,
                            interp=True,
                            ax=None,
                            **kwargs):
    """Plot a burst detection parameter and threshold.

    Parameters
    ----------
    df_features : pandas.DataFrame
        Dataframe output of :func:`~.compute_features`.
    sig : 1d array
        Time series to plot.
    fs : float
        Sampling rate, in Hz. Must be positive.
    burst_param : str
        Column name of the parameter of interest in ``df``.
    thresh : float
        The burst parameter threshold. Parameter values greater
        than ``thresh`` are considered bursts.
    xlim : tuple of (float, float), optional, default: None
        Start and stop times for plot. Defaults to the full signal.
    interp : bool, optional, default: True
        If True, plots one point per cycle (at the cycle center) connected by lines.
        Otherwise, plots the value as a step spanning each cycle from side to side.
    ax : matplotlib.Axes, optional
        Figure axes upon which to plot. A new figure is created when None.
    **kwargs
        Keyword arguments to pass into `plot_time_series`.

    Notes
    -----
    Default keyword arguments include:

    - ``figsize``: tuple of (float, float), default: (15, 3)
    - ``xlabel``: str, default: 'Time (s)'
    - ``ylabel``: str, default: ``burst_param``
    - ``color``: str, default: 'r'.

      - Note: ``color`` here is the fill color, rather than line color.

    Cycles whose value is at or below ``thresh`` are shaded with ``color``.

    Examples
    --------
    Plot the monotonicity of a bursting signal:

    >>> from bycycle.features import compute_features
    >>> from neurodsp.sim import sim_bursty_oscillation
    >>> fs = 500
    >>> sig = sim_bursty_oscillation(10, fs, freq=10)
    >>> threshold_kwargs = {'amp_fraction_threshold': 0., 'amp_consistency_threshold': .5,
    ...                     'period_consistency_threshold': .5, 'monotonicity_threshold': .8}
    >>> df_features = compute_features(sig, fs, f_range=(8, 12),
    ...                                            threshold_kwargs=threshold_kwargs)
    >>> plot_burst_detect_param(df_features, sig, fs, 'monotonicity', .8)
    """

    # Ensure arguments are within valid range
    check_param_range(fs, 'fs', (0, np.inf))

    # Set default kwargs; whatever remains is forwarded to plot_time_series
    figsize = kwargs.pop('figsize', (15, 3))
    xlabel = kwargs.pop('xlabel', 'Time (s)')
    ylabel = kwargs.pop('ylabel', burst_param)
    color = kwargs.pop('color', 'r')

    # Determine time array and limits. Dividing integer sample indices by fs guarantees
    #   exactly len(sig) points; a float-step np.arange can produce one extra point.
    times = np.arange(len(sig)) / fs
    xlim = (times[0], times[-1]) if xlim is None else xlim

    if ax is None:
        _, ax = plt.subplots(figsize=figsize)

    # Determine extrema strings
    center_e, side_e = get_extrema_df(df_features)

    # Limit dataframe, sig and times
    df = limit_df(df_features, fs, start=xlim[0], stop=xlim[1])

    sig, times = limit_signal(times, sig, start=xlim[0], stop=xlim[1])

    # Remove start / end cycles that xlim falls between: keep only cycles whose side
    #   extrema both index into the limited time array
    last_key = 'sample_last_' + side_e
    next_key = 'sample_next_' + side_e
    df = df[(df[last_key] >= 0) & (df[next_key] < len(times))]

    last_sides = df[last_key].values.astype(int)
    next_sides = df[next_key].values.astype(int)

    ylabel_thresh = "{0:s}\nthreshold={1:.2f}".format(ylabel, thresh)

    # Plot burst param
    if interp:

        centers = df['sample_' + center_e].values.astype(int)

        plot_time_series([times[centers], xlim],
                         [df[burst_param], [thresh] * 2],
                         ax=ax,
                         colors=['k', 'k'],
                         ls=['-', '--'],
                         marker=["o", None],
                         xlim=xlim,
                         xlabel=xlabel,
                         ylabel=ylabel_thresh,
                         **kwargs)

    else:

        # Create steps, from side to side of each cycle, with the y-value set to the
        #   burst parameter value for that cycle: [last_0, next_0, last_1, next_1, ...]
        side_times = np.column_stack((times[last_sides], times[next_sides])).ravel()
        side_param = np.repeat(df[burst_param].values.astype(float), 2)

        plot_time_series([side_times, xlim], [side_param, [thresh] * 2],
                         ax=ax,
                         colors=['k', 'k'],
                         ls=['-', '--'],
                         marker=["o", None],
                         xlim=xlim,
                         xlabel=xlabel,
                         ylabel=ylabel_thresh,
                         **kwargs)

    # Highlight where param falls at or below threshold
    below = (df[burst_param] <= thresh).values

    for start_samp, stop_samp in zip(last_sides[below], next_sides[below]):

        ax.axvspan(times[start_samp],
                   times[stop_samp],
                   alpha=0.5,
                   color=color,
                   lw=0)
コード例 #5
0
ファイル: bycycle.py プロジェクト: voytekresearch/ndspflow
def plot_bm(df_features,
            sig,
            fs,
            threshold_kwargs,
            df_idx,
            xlim=None,
            plot_only_result=True):
    """Plot a single bycycle fit using plotly.

    Parameters
    ----------
    df_features : pandas.DataFrame
        A dataframe containing shape and burst features for each cycle.
    sig : 1d array
        Time series.
    fs : float
        Sampling rate, in Hz.
    threshold_kwargs : dict
        Feature thresholds for cycles to be considered bursts. ``min_n_cycles`` is ignored
        since it has no per-cycle value. The dict is not modified.
    df_idx : int
        The index of the dataframe in the javascript array. This is only used for fetching data
        for the relabel js callback.
    xlim : tuple of (float, float), optional, default: None
        Start and stop times for plot. Defaults to the full signal.
    plot_only_result : bool, optional, default: True
        Plot only the signal and bursts, excluding burst parameter plots.

    Notes
    -----
    The df_idx argument is used by the generated click callback (``relabel1DBursts``) to fetch
    this dataframe's results from a javascript array. This is done since javascript can't
    access the local filesystem to load csv files. The javascript functions
    ``relabel1DBursts`` and ``rewriteBursts`` must be defined by the page embedding the html.

    Returns
    -------
    graph : str
        The bycycle plot as a string containing html.
    """

    # Normalize signal
    sig = zscore(sig)

    # Determine time array and limits. Dividing integer sample indices by fs guarantees
    #   exactly len(sig) points; a float-step np.arange can produce one extra point.
    times = np.arange(len(sig)) / fs
    xlim = (times[0], times[-1]) if xlim is None else xlim

    # Limit signal and dataframe
    sig, times = limit_signal(times, sig, start=xlim[0], stop=xlim[1])
    df_features = limit_df(df_features, fs, start=xlim[0], stop=xlim[1])

    # Determine if peaks or troughs are the sides of an oscillation
    center_e, side_e = get_extrema_df(df_features)

    # Remove min_n_cycles since it isn't stored cycle by cycle in the df (nothing to plot).
    #   Work on a copy so the caller's dict is left untouched.
    thresholds = dict(threshold_kwargs)
    thresholds.pop('min_n_cycles', None)

    # Burst parameters actually present in the dataframe; each gets its own row below
    #   the signal
    burst_params = [
        param for param in ('amp_fraction', 'amp_consistency', 'period_consistency',
                            'monotonicity', 'burst_fraction')
        if param in df_features.columns
    ]

    # Create figure and subplots
    if plot_only_result:
        n_rows = 1
        fig = go.Figure(make_subplots(rows=1, cols=1))
    else:
        n_rows = len(burst_params) + 1
        fig = go.Figure(make_subplots(rows=n_rows, cols=1, vertical_spacing=0.01))

    # Plot bursts
    fig = _plot_bursts(df_features,
                       sig,
                       times,
                       center_e,
                       side_e,
                       fig,
                       row=1,
                       col=1)

    # Plot params, starting on the row below the signal
    if not plot_only_result:

        fig = _plot_params(df_features,
                           sig,
                           fs,
                           times,
                           center_e,
                           side_e,
                           burst_params,
                           thresholds,
                           fig,
                           row=2,
                           col=1)

    # Update axes and layout
    if not plot_only_result:

        # Add time label to the last subplot (row 1 is the signal, params follow it)
        fig.update_xaxes(title_text='Time',
                         showticklabels=True,
                         row=n_rows,
                         col=1)

        # Update signal axes
        fig.update_xaxes(showticklabels=False, row=1, col=1)
        fig.update_yaxes(title_text="Voltage<br>(normalized)", row=1, col=1)

        # Zoom link across all subplots
        fig.update_xaxes(matches="x")

        # Update layout across all subplots
        fig.update_layout(width=1000, height=1000)

        # Range slider under the bottom subplot
        fig.update_xaxes(rangeslider={'visible': True}, row=n_rows, col=1)

    else:

        fig.update_layout(width=1000,
                          height=325,
                          xaxis_title="Time",
                          yaxis_title="Voltage<br>(normalized)")

    fig.update_layout(
        autosize=True,
        showlegend=False,
        title_text="Burst Detection Plots",
    )

    graph = fig.to_html(include_plotlyjs=False, full_html=False)

    # Get burst traces
    trace_id = len(fig.data) - 2
    burst_traces = list(range(len(df_features)))

    # Get the div id containing the plot ('<div id="' is 9 characters; drop trailing quote)
    div_id = re.search("<div id=.+?\"", graph)[0]
    div_id = div_id[9:-1]

    # Create js callback
    js_callback = ["<script type=\"text/javascript\">\n"]
    js_callback.append("""
    var burstPlot = document.getElementById('{plot_id}');
    var burstTraces = {burst_traces};
    var traceId = {trace_id};
    burstPlot.on('plotly_click', function(data){{
        relabel1DBursts(data, burstPlot, {idx}, burstTraces, traceId);
    }});
    """.format(trace_id=trace_id,
               burst_traces=str(burst_traces),
               plot_id=div_id,
               idx=df_idx))
    js_callback.append("</script>\n")
    js_callback = "".join(js_callback)

    # Strip any closing tags so the script and button can be appended
    graph = re.sub("</body>\n</html>", "\n", graph)

    # Add recompute burst btn below the plots
    rewrite_call = "rewriteBursts({div_id})".format(div_id=str([div_id]))
    btn = "\n\t\t<p><center><button onclick=\"" + rewrite_call + "\" class=\"btn\" "
    btn = btn + "title=\"update is_burst column\">Update Bursts</button></center></p>"

    graph = graph + js_callback + btn + "\n</body>\n</html>"

    return graph