Exemple #1
0
def plot_burst_detect_param(df,
                            sig,
                            fs,
                            burst_param,
                            thresh,
                            xlim=None,
                            ax=None,
                            interp=True,
                            **kwargs):
    """Plot a burst detection parameter and threshold.

    Parameters
    ----------
    df : pandas.DataFrame
        Dataframe output of :func:`~.compute_features`.
    sig : 1d array
        Time series to plot.
    fs : float
        Sampling rate, in Hz.
    burst_param : str
        Column name of the parameter of interest in ``df``.
    thresh : float
        The burst parameter threshold. Parameter values greater
        than ``thresh`` are considered bursts.
    xlim : tuple of (float, float), optional, default: None
        Start and stop times for plot. If None, the full duration of ``sig`` is used.
    ax : matplotlib.Axes, optional
        Figure axes upon which to plot. If None, a new figure is created.
    interp : bool, optional, default: True
        If True, interpolates between the parameter values at each cycle's center.
        Otherwise, plots in a step-wise fashion, spanning each cycle from side to side.
    **kwargs
        Keyword arguments to pass into `plot_time_series`.

    Notes
    -----
    Default keyword arguments include:

    - ``figsize``: tuple of (float, float), default: (15, 3)
    - ``xlabel``: str, default: 'Time (s)'
    - ``ylabel``: str, default: the value of ``burst_param``
    - ``color``: str, default: 'r'.

      - Note: ``color`` here is the fill color, rather than line color.

    Cycles whose parameter value is less than or equal to ``thresh`` are
    highlighted with ``color``.
    """

    # Set default kwargs
    figsize = kwargs.pop('figsize', (15, 3))
    xlabel = kwargs.pop('xlabel', 'Time (s)')
    ylabel = kwargs.pop('ylabel', burst_param)
    color = kwargs.pop('color', 'r')

    # Determine time array and limits. Building times from integer sample indices
    #   guarantees exactly one time point per sample; np.arange with a float step
    #   can return an extra element due to rounding.
    times = np.arange(len(sig)) / fs
    xlim = (times[0], times[-1]) if xlim is None else xlim

    if ax is None:
        _, ax = plt.subplots(figsize=figsize)

    # Determine extrema strings
    center_e, side_e = get_extrema(df)
    last_key = 'sample_last_' + side_e
    next_key = 'sample_next_' + side_e

    # Limit dataframe, sig and times
    df = limit_df(df, fs, start=xlim[0], stop=xlim[1])
    sig, times = limit_signal(times, sig, start=xlim[0], stop=xlim[1])

    # Remove start/end cycles that the limits fall between. Bounding by the number
    #   of remaining time points guarantees every sample index can index ``times``.
    df = df[(df[last_key] >= 0) & (df[next_key] < len(times))]

    # Sample columns may hold floats, which cannot index numpy arrays
    last_times = times[df[last_key].to_numpy(dtype=int)]
    next_times = times[df[next_key].to_numpy(dtype=int)]
    values = df[burst_param].to_numpy()

    if interp:
        # Connect the parameter values located at each cycle's center extrema
        x_vals = times[df['sample_' + center_e].to_numpy(dtype=int)]
        y_vals = values
    else:
        # Create steps, from side to side of each cycle, with the y-value set to
        #   the burst parameter value for that cycle: [last0, next0, last1, ...]
        x_vals = np.column_stack((last_times, next_times)).ravel()
        y_vals = np.repeat(values, 2)

    # Plot burst param alongside a dashed threshold line
    plot_time_series([x_vals, xlim], [y_vals, [thresh] * 2],
                     ax=ax,
                     colors=['k', 'k'],
                     ls=['-', '--'],
                     marker=["o", None],
                     xlim=xlim,
                     xlabel=xlabel,
                     ylabel="{0:s}\nthreshold={1:.2f}".format(
                         ylabel, thresh),
                     **kwargs)

    # Highlight where param falls to or below threshold
    below = values <= thresh
    for span_start, span_stop in zip(last_times[below], next_times[below]):
        ax.axvspan(span_start, span_stop, alpha=0.5, color=color, lw=0)
Exemple #2
0
def plot_cyclepoints_array(sig,
                           fs,
                           peaks=None,
                           troughs=None,
                           rises=None,
                           decays=None,
                           plot_sig=True,
                           xlim=None,
                           ax=None,
                           **kwargs):
    """Plot extrema and/or zero-crossings from arrays.

    Parameters
    ----------
    sig : 1d array
        Time series to plot.
    fs : float
        Sampling rate, in Hz.
    peaks : 1d array, optional
        Peak signal indices from :func:`.find_extrema`.
    troughs : 1d array, optional
        Trough signal indices from :func:`.find_extrema`.
    rises : 1d array, optional
        Zero-crossing rise indices from :func:`~.find_zerox`.
    decays : 1d array, optional
        Zero-crossing decay indices from :func:`~.find_zerox`.
    plot_sig : bool, optional, default: True
        Whether to also plot the raw signal.
    xlim : tuple of (float, float), optional
        Start and stop times. Cyclepoints outside the limited window are not plotted.
    ax : matplotlib.Axes, optional, default: None
        Figure axes upon which to plot.
    **kwargs
        Keyword arguments to pass into `plot_time_series`.

    Notes
    -----
    Default keyword arguments include:

    - ``figsize``: tuple of (float, float), default: (15, 3)
    - ``xlabel``: str, default: 'Time (s)'
    - ``ylabel``: str, default: 'Voltage (uV)'
    - ``colors``: list, default: 'k' for the signal (only when ``plot_sig`` is True),
      followed by 'b', 'r', 'g', 'm' for whichever of peaks, troughs, rises and decays
      are given. When custom ``colors`` are passed and ``plot_sig`` is True, the first
      color is used for the signal.

    Examples
    --------
    Plot cyclepoints using arrays from :func:`.find_extrema` and  :func:`~.find_zerox`:

    >>> from bycycle.cyclepoints import find_extrema, find_zerox
    >>> from neurodsp.sim import sim_bursty_oscillation
    >>> fs = 500
    >>> sig = sim_bursty_oscillation(10, fs, freq=10)
    >>> peaks, troughs = find_extrema(sig, fs, f_range=(8, 12), boundary=0)
    >>> rises, decays = find_zerox(sig, peaks, troughs)
    >>> plot_cyclepoints_array(sig, fs, peaks=peaks, troughs=troughs, rises=rises, decays=decays)
    """

    # Ensure arguments are within valid range
    check_param_range(fs, 'fs', (0, np.inf))

    # One time point per sample (np.arange with a float step may add an extra point)
    times = np.arange(len(sig)) / fs

    # Restrict sig and times to xlim
    if xlim is not None:
        sig, times = limit_signal(times, sig, start=xlim[0], stop=xlim[1])

    # Set default kwargs
    figsize = kwargs.pop('figsize', (15, 3))
    xlabel = kwargs.pop('xlabel', 'Time (s)')
    ylabel = kwargs.pop('ylabel', 'Voltage (uV)')
    default_colors = ['b', 'r', 'g', 'm']

    # Sample index of the first retained time point. Rounding avoids the truncation
    #   error of int() on float products (e.g. int(0.29 * 100) == 28).
    offset = int(round(times[0] * fs)) if len(times) > 0 else 0

    # Extend plotting based on given arguments
    x_values = []
    y_values = []
    point_colors = []

    for points, point_color in zip([peaks, troughs, rises, decays], default_colors):

        if points is None:
            continue

        # Shift cyclepoints (cps) into the retained window and drop those outside it
        cps = np.asarray(points, dtype=int) - offset
        cps = cps[(cps >= 0) & (cps < len(times))]

        y_values.append(sig[cps])
        x_values.append(times[cps])
        point_colors.append(point_color)

    # Allow custom colors to overwrite default. The signal color is only prepended
    #   when the signal is plotted, keeping default point colors aligned with their type.
    colors = kwargs.pop('colors', None)
    if colors is None:
        colors = ['k'] + point_colors if plot_sig else point_colors

    if ax is None:
        _, ax = plt.subplots(figsize=figsize)

    if plot_sig:
        plot_time_series(times, sig, colors=colors[0], ax=ax)
        colors = colors[1:]

    plot_time_series(x_values,
                     y_values,
                     ax=ax,
                     xlabel=xlabel,
                     ylabel=ylabel,
                     colors=colors,
                     marker='o',
                     ls='',
                     **kwargs)
def plot_cyclepoints_array(sig,
                           fs,
                           peaks=None,
                           troughs=None,
                           rises=None,
                           decays=None,
                           plot_sig=True,
                           xlim=None,
                           ax=None,
                           **kwargs):
    """Plot extrema and/or zero-crossings using arrays to define points.

    Parameters
    ----------
    sig : 1d array
        Time series to plot.
    fs : float
        Sampling rate, in Hz.
    peaks : 1d array, optional, default: None
        Peak signal indices from :func:`.find_extrema`.
    troughs : 1d array, optional, default: None
        Trough signal indices from :func:`.find_extrema`.
    rises : 1d array, optional, default: None
        Zero-crossing rise indices from :func:`~.find_zerox`.
    decays : 1d array, optional, default: None
        Zero-crossing decay indices from :func:`~.find_zerox`.
    plot_sig : bool, optional, default: True
        Whether to also plot the raw signal.
    xlim : tuple of (float, float), optional, default: None
        Start and stop times. If None, the full duration of ``sig`` is used.
        Cyclepoints outside the limited window are not plotted.
    ax : matplotlib.Axes, optional, default: None
        Figure axes upon which to plot.
    **kwargs
        Keyword arguments to pass into `plot_time_series`.

    Notes
    -----
    Default keyword arguments include:

    - ``figsize``: tuple of (float, float), default: (15, 3)
    - ``xlabel``: str, default: 'Time (s)'
    - ``ylabel``: str, default: 'Voltage (uV)'
    - ``colors``: list, default: 'k' for the signal (only when ``plot_sig`` is True),
      followed by 'b', 'r', 'g', 'm' for whichever of peaks, troughs, rises and decays
      are given. When custom ``colors`` are passed and ``plot_sig`` is True, the first
      color is used for the signal.

    """

    # Set times and limits; one time point per sample
    times = np.arange(len(sig)) / fs
    xlim = (times[0], times[-1]) if xlim is None else xlim

    # Restrict sig and times to xlim
    sig, times = limit_signal(times, sig, start=xlim[0], stop=xlim[1])

    # Set default kwargs
    figsize = kwargs.pop('figsize', (15, 3))
    xlabel = kwargs.pop('xlabel', 'Time (s)')
    ylabel = kwargs.pop('ylabel', 'Voltage (uV)')
    default_colors = ['b', 'r', 'g', 'm']

    # Sample index of the first retained time point; derived from the limited times
    #   (rounded, not truncated) so the shift matches what limit_signal kept
    offset = int(round(times[0] * fs)) if len(times) > 0 else 0

    # Extend plotting based on given arguments
    x_values = []
    y_values = []
    point_colors = []

    for points, point_color in zip([peaks, troughs, rises, decays], default_colors):

        if points is None:
            continue

        # Shift cyclepoints (cps) into the retained window and drop those outside it,
        #   so every index is valid for the limited ``sig`` and ``times``
        cps = np.asarray(points, dtype=int) - offset
        cps = cps[(cps >= 0) & (cps < len(times))]

        y_values.append(sig[cps])
        x_values.append(times[cps])
        point_colors.append(point_color)

    # Allow custom colors to overwrite default. The signal color is only prepended
    #   when the signal is plotted, keeping default point colors aligned with their type.
    colors = kwargs.pop('colors', None)
    if colors is None:
        colors = ['k'] + point_colors if plot_sig else point_colors

    if ax is None:
        _, ax = plt.subplots(figsize=figsize)

    if plot_sig:
        plot_time_series(times, sig, colors=colors[0], ax=ax)
        colors = colors[1:]

    plot_time_series(x_values,
                     y_values,
                     ax=ax,
                     xlabel=xlabel,
                     ylabel=ylabel,
                     colors=colors,
                     marker='o',
                     ls='',
                     **kwargs)
Exemple #4
0
def plot_burst_detect_param(df_features,
                            sig,
                            fs,
                            burst_param,
                            thresh,
                            xlim=None,
                            interp=True,
                            ax=None,
                            **kwargs):
    """Plot a burst detection parameter and threshold.

    Parameters
    ----------
    df_features : pandas.DataFrame
        Dataframe output of :func:`~.compute_features`.
    sig : 1d array
        Time series to plot.
    fs : float
        Sampling rate, in Hz.
    burst_param : str
        Column name of the parameter of interest in ``df_features``.
    thresh : float
        The burst parameter threshold. Parameter values greater
        than ``thresh`` are considered bursts.
    xlim : tuple of (float, float), optional, default: None
        Start and stop times for plot. If None, the full duration of ``sig`` is used.
    interp : bool, optional, default: True
        If True, interpolates between the parameter values at each cycle's center.
        Otherwise, plots in a step-wise fashion, spanning each cycle from side to side.
    ax : matplotlib.Axes, optional
        Figure axes upon which to plot. If None, a new figure is created.
    **kwargs
        Keyword arguments to pass into `plot_time_series`.

    Notes
    -----
    Default keyword arguments include:

    - ``figsize``: tuple of (float, float), default: (15, 3)
    - ``xlabel``: str, default: 'Time (s)'
    - ``ylabel``: str, default: the value of ``burst_param``
    - ``color``: str, default: 'r'.

      - Note: ``color`` here is the fill color, rather than line color.

    Cycles whose parameter value is less than or equal to ``thresh`` are
    highlighted with ``color``.

    Examples
    --------
    Plot the monotonicity of a bursting signal:

    >>> from bycycle.features import compute_features
    >>> from neurodsp.sim import sim_bursty_oscillation
    >>> fs = 500
    >>> sig = sim_bursty_oscillation(10, fs, freq=10)
    >>> threshold_kwargs = {'amp_fraction_threshold': 0., 'amp_consistency_threshold': .5,
    ...                     'period_consistency_threshold': .5, 'monotonicity_threshold': .8}
    >>> df_features = compute_features(sig, fs, f_range=(8, 12),
    ...                                threshold_kwargs=threshold_kwargs)
    >>> plot_burst_detect_param(df_features, sig, fs, 'monotonicity', .8)
    """

    # Ensure arguments are within valid range
    check_param_range(fs, 'fs', (0, np.inf))

    # Set default kwargs
    figsize = kwargs.pop('figsize', (15, 3))
    xlabel = kwargs.pop('xlabel', 'Time (s)')
    ylabel = kwargs.pop('ylabel', burst_param)
    color = kwargs.pop('color', 'r')

    # Determine time array and limits. Building times from integer sample indices
    #   guarantees exactly one time point per sample; np.arange with a float step
    #   can return an extra element due to rounding.
    times = np.arange(len(sig)) / fs
    xlim = (times[0], times[-1]) if xlim is None else xlim

    if ax is None:
        _, ax = plt.subplots(figsize=figsize)

    # Determine extrema strings
    center_e, side_e = get_extrema_df(df_features)
    last_key = 'sample_last_' + side_e
    next_key = 'sample_next_' + side_e

    # Limit dataframe, sig and times
    df = limit_df(df_features, fs, start=xlim[0], stop=xlim[1])

    sig, times = limit_signal(times, sig, start=xlim[0], stop=xlim[1])

    # Remove start / end cycles that the limits fall between. Bounding by the number
    #   of remaining time points guarantees every sample index can index ``times``.
    df = df[(df[last_key] >= 0) & (df[next_key] < len(times))]

    # Sample columns may hold floats, which cannot index numpy arrays
    last_times = times[df[last_key].to_numpy(dtype=int)]
    next_times = times[df[next_key].to_numpy(dtype=int)]
    values = df[burst_param].to_numpy()

    if interp:
        # Connect the parameter values located at each cycle's center extrema
        x_vals = times[df['sample_' + center_e].to_numpy(dtype=int)]
        y_vals = values
    else:
        # Create steps, from side to side of each cycle, and set the y-value
        #   to the burst parameter value for that cycle: [last0, next0, last1, ...]
        x_vals = np.column_stack((last_times, next_times)).ravel()
        y_vals = np.repeat(values, 2)

    # Plot burst param alongside a dashed threshold line; both modes share xlim
    plot_time_series([x_vals, xlim], [y_vals, [thresh] * 2],
                     ax=ax,
                     colors=['k', 'k'],
                     ls=['-', '--'],
                     marker=["o", None],
                     xlim=xlim,
                     xlabel=xlabel,
                     ylabel="{0:s}\nthreshold={1:.2f}".format(
                         ylabel, thresh),
                     **kwargs)

    # Highlight where param falls to or below threshold
    below = values <= thresh
    for span_start, span_stop in zip(last_times[below], next_times[below]):
        ax.axvspan(span_start, span_stop, alpha=0.5, color=color, lw=0)
Exemple #5
0
def plot_burst_detect_summary(df_features,
                              sig,
                              fs,
                              threshold_kwargs,
                              xlim=None,
                              figsize=(15, 3),
                              plot_only_result=False,
                              interp=True):
    """Plot the cycle-by-cycle burst detection parameters and burst detection summary.

    Parameters
    ----------
    df_features : pandas.DataFrame
        Dataframe output of :func:`~.compute_features`. The df must contain sample indices (i.e.
        when ``return_samples = True``).
    sig : 1d array
        Time series to plot.
    fs : float
        Sampling rate, in Hz.
    threshold_kwargs : dict
        Burst parameter keys and threshold value pairs, as defined in the 'threshold_kwargs'
        argument of :func:`.compute_features`. The dict is not modified.
    xlim : tuple of (float, float), optional, default: None
        Start and stop times for plot.
    figsize : tuple of (float, float), optional, default: (15, 3)
        Size of each plot.
    plot_only_result : bool, optional, default: False
        Plot only the signal and bursts, excluding burst parameter plots.
    interp : bool, optional, default: True
        If True, interpolates between given values. Otherwise, plots in a step-wise fashion.

    Notes
    -----

    - If plot_only_result = True: a single plot of the burst detection is drawn.

    - If plot_only_result = False: one additional subplot is drawn below the burst detection
      plot for each key of ``threshold_kwargs`` (``min_n_cycles`` excluded), showing that
      parameter and its threshold.

    - Nothing is returned; the figure is created with ``plt.subplots``.

    - In the top plot, the raw signal is plotted in black, and the red line indicates periods
      defined as oscillatory bursts. The highlighted regions indicate when each burst requirement
      was violated (parameter value less than or equal to its threshold), color-coded
      consistently with the plots below, in the order the keys appear in ``threshold_kwargs``:

      - blue: amp_fraction_threshold
      - red: amp_consistency_threshold
      - yellow: period_consistency_threshold
      - green: monotonicity_threshold

    Examples
    --------
    Plot the burst detection summary of a bursting signal:

    >>> from bycycle.features import compute_features
    >>> from neurodsp.sim import sim_bursty_oscillation
    >>> fs = 500
    >>> sig = sim_bursty_oscillation(10, fs, freq=10)
    >>> threshold_kwargs = {'amp_fraction_threshold': 0., 'amp_consistency_threshold': .5,
    ...                     'period_consistency_threshold': .5, 'monotonicity_threshold': .8}
    >>> df_features = compute_features(sig, fs, f_range=(8, 12), threshold_kwargs=threshold_kwargs)
    >>> plot_burst_detect_summary(df_features, sig, fs, threshold_kwargs)
    """

    # Ensure arguments are within valid range
    check_param_range(fs, 'fs', (0, np.inf))

    # Normalize signal; build exactly one time point per sample
    sig_full = zscore(sig)
    times_full = np.arange(len(sig_full)) / fs

    # Limit arrays and dataframe. Dataframe sample indices stay absolute
    #   (reset_indices=False) and are shifted by ``offset`` below.
    if xlim is not None:
        sig, times = limit_signal(times_full,
                                  sig_full,
                                  start=xlim[0],
                                  stop=xlim[1])
        df_features = limit_df(df_features,
                               fs,
                               start=xlim[0],
                               stop=xlim[1],
                               reset_indices=False)
    else:
        sig, times = sig_full, times_full

    # Determine if peaks or troughs are the sides of an oscillation
    _, side_e = get_extrema_df(df_features)

    # Remove this kwarg since it isn't stored cycle by cycle in the df (nothing to plot).
    #   Work on a copy so the caller's dict is left untouched.
    thresholds = threshold_kwargs.copy()
    thresholds.pop('min_n_cycles', None)

    n_kwargs = len(thresholds)

    # Create figure and subplots. squeeze=False keeps ``axes`` indexable even when
    #   only a single row is created (e.g. an empty ``thresholds``).
    nrows = 1 if plot_only_result else n_kwargs + 1
    _, axes = plt.subplots(figsize=(figsize[0], figsize[1] * nrows),
                           nrows=nrows,
                           sharex=True,
                           squeeze=False)
    axes = axes[:, 0]

    # Sample index of the first plotted point: converts absolute dataframe sample
    #   indices into indices of ``sig`` / ``times``. Rounding avoids int() truncation
    #   and follows whatever start limit_signal actually kept.
    offset = int(round(times[0] * fs)) if len(times) > 0 else 0

    last_samples = df_features['sample_last_' + side_e].to_numpy(dtype=int) - offset
    next_samples = df_features['sample_next_' + side_e].to_numpy(dtype=int) - offset

    # Determine which samples are defined as bursting
    is_osc = np.zeros(len(sig), dtype=bool)
    is_burst = df_features['is_burst'].to_numpy(dtype=bool)

    for samp_start_burst, samp_next in zip(last_samples[is_burst], next_samples[is_burst]):
        # Clip at 0 so a negative start can't wrap around to the end of the array
        is_osc[max(samp_start_burst, 0):samp_next + 1] = True

    # Plot bursts with extrema points
    xlabel = 'Time (s)' if len(axes) == 1 else ''

    plot_bursts(times,
                sig,
                is_osc,
                ax=axes[0],
                lw=2,
                labels=['Signal', 'Bursts'],
                xlabel='',
                ylabel='')

    plot_cyclepoints_df(df_features,
                        sig_full,
                        fs,
                        xlim=xlim,
                        ax=axes[0],
                        plot_zerox=False,
                        plot_sig=False,
                        xlabel=xlabel,
                        ylabel='Voltage\n(normalized)',
                        colors=['m', 'c'])

    # Only cycles fully inside the plotted window can be highlighted
    in_window = (last_samples >= 0) & (next_samples < len(times))

    # Plot each burst param
    colors = cycle(
        ['blue', 'red', 'yellow', 'green', 'cyan', 'magenta', 'orange'])

    for idx, osc_key in enumerate(thresholds):

        column = osc_key.replace('_threshold', '')

        color = next(colors)

        # Highlight where a burst param does not exceed its threshold; values
        #   strictly greater than the threshold pass burst detection
        violated = (df_features[column].to_numpy() <= thresholds[osc_key]) & in_window

        for last_cyc, next_cyc in zip(last_samples[violated], next_samples[violated]):
            axes[0].axvspan(times[last_cyc],
                            times[next_cyc],
                            alpha=0.5,
                            color=color,
                            lw=0)

        # Plot each burst param on separate axes
        if not plot_only_result:

            ylabel = column.replace('_', ' ').capitalize()
            xlabel = 'Time (s)' if idx == n_kwargs - 1 else ''

            plot_burst_detect_param(df_features,
                                    sig_full,
                                    fs,
                                    column,
                                    thresholds[osc_key],
                                    xlim=xlim,
                                    figsize=figsize,
                                    ax=axes[idx + 1],
                                    xlabel=xlabel,
                                    ylabel=ylabel,
                                    color=color,
                                    interp=interp)
Exemple #6
0
def plot_bg(dfs_features, sigs, fs, titles=None, btn=True, xlim=None):
    """Plot 2D bycycle results.

    Parameters
    ----------
    dfs_features : list of pandas.DataFrame
        Dataframes containing shape and burst features for each cycle, one per signal.
    sigs : 2d array
        Time series, one signal per row.
    fs : float
        Sampling rate, in Hz.
    titles : list, optional, default: None
        The titles for each figure (one figure per group of 10 signals). Missing titles
        are filled with the index range of the group. The list passed in is not modified.
    btn : bool, optional, default: True
        Adds a recompute bursts button when True. Omits when False.
    xlim : tuple of (float, float), optional, default: None
        Start and stop times for plot. If None, each signal's full duration is used.

    Returns
    -------
    graph : str
        The bycycle plot as a string containing html.
    """

    # Initialize figures in groups of 10
    #   Plotly doesn't render single figures well with 100+ plots
    n_per_fig = 10
    n_figs = int(np.ceil(len(sigs) / n_per_fig))

    # Copy so that default titles appended below don't leak into the caller's list
    titles = [] if titles is None else list(titles)

    figs = []
    n_rows = []

    for idx in range(n_figs):

        # Subplot titles
        start = idx * n_per_fig
        end = start + len(sigs[start:start + n_per_fig])

        if len(titles) == idx:
            titles.append("Indices: {start}-{end}".format(start=start,
                                                          end=end))

        # Create subplots, one row per signal in this group
        n_rows.append(end - start)
        figs.append(make_subplots(rows=n_rows[idx],
                                  cols=1,
                                  vertical_spacing=0.005,
                                  shared_xaxes=True))

    for idx, df_features in enumerate(dfs_features):

        # Normalize signal
        sig = zscore(sigs[idx])

        # Determine time array (one point per sample) and this signal's limits.
        #   A separate name keeps a default derived from one signal from being
        #   reused for the next one.
        times = np.arange(len(sig)) / fs
        sig_xlim = (times[0], times[-1]) if xlim is None else xlim

        # Limit signal and dataframe
        sig, times = limit_signal(times, sig, start=sig_xlim[0], stop=sig_xlim[1])
        df_features = limit_df(df_features, fs, start=sig_xlim[0], stop=sig_xlim[1])

        # Get extrema
        center_e, side_e = get_extrema_df(df_features)

        # Plot bursts: signal ``idx`` goes to figure idx // n_per_fig, row idx % n_per_fig
        fig_idx, row_idx = divmod(idx, n_per_fig)

        _plot_bursts(df_features,
                     sig,
                     times,
                     center_e,
                     side_e,
                     figs[fig_idx],
                     plot_cps=False,
                     row=row_idx + 1,
                     col=1)

    graphs = []
    for idx, fig in enumerate(figs):

        # The size of plotly subplots don't scale properly, this is a workaround
        height = (60 * n_rows[idx]) + ((n_per_fig - n_rows[idx]) * 18)

        # Update the figures
        fig.update_layout(autosize=False,
                          width=1000,
                          height=height,
                          showlegend=False,
                          margin_autoexpand=False,
                          title_text=titles[idx])
        fig.update_yaxes(showticklabels=False, showgrid=False)
        fig.update_xaxes(showgrid=False)
        fig.update_xaxes(rangeslider={'visible': True},
                         row=n_rows[idx],
                         col=1)

        # Convert to html; only the first figure keeps its full html fragment,
        #   later ones contribute just their <div>
        html = fig.to_html(include_plotlyjs=False, full_html=False)

        if idx == 0:
            graphs.append(html)
        else:
            div = re.search("<div>.*</div>", html)[0]
            graphs.append(div + "\n")

    # Custom js callback
    js_callback = ["<script type=\"text/javascript\">\n"]

    # Get plot div ids ('<div id="' is 9 characters; drop it and the closing quote)
    div_ids = [re.search("<div id=.+?\"", graph)[0][9:-1] for graph in graphs]

    # Recolor (non)bursts on click
    for div_id in div_ids:

        js_callback.append("""
        recolorBursts('{plotID}');
        """.format(plotID=div_id))

    js_callback.append("</script>\n")

    # Flatten lists into single string
    graphs.append("".join(js_callback))

    if btn:

        rewrite_call = "rewriteBursts({div_ids})".format(div_ids=div_ids)

        # Add a button
        btn_html = "\n\t\t<p><center><button onclick=\"" + rewrite_call + "\" class=\"btn\" "
        btn_html = btn_html + "title=\"update is_burst column\">Update Bursts</button></center></p>"
        graphs.append(btn_html)

    graphs[0] = re.sub("</body>\n</html>", "", graphs[0])
    graphs = "".join(graphs)

    return graphs
Exemple #7
0
def plot_bm(df_features,
            sig,
            fs,
            threshold_kwargs,
            df_idx,
            xlim=None,
            plot_only_result=True):
    """Plot a single bycycle fit using plotly.

    Parameters
    ----------
    df_features : pandas.DataFrame
        A dataframe containing shape and burst features for each cycle.
    sig : 1d array
        Time series.
    fs : float
        Sampling rate, in Hz.
    threshold_kwargs : dict
        Feature thresholds for cycles to be considered bursts. The dict is not modified.
    df_idx : int
        The index of the dataframe in the javascript array. This is only used for fetching data
        for the relabel js callback.
    xlim : tuple of (float, float), optional, default: None
        Start and stop times for plot. If None, the full duration of ``sig`` is used.
    plot_only_result : bool, optional, default: True
        Plot only the signal and bursts, excluding burst parameter plots.

    Notes
    -----
    The df_idx argument is passed to the js relabel callback, which uses it to fetch a
    javascript array containing the dataframe results. This is done since javascript can't
    access the local filesystem to load csv files.

    Returns
    -------
    graph : str
        The bycycle plot as a string containing html.
    """

    # Normalize signal
    sig = zscore(sig)

    # Determine time array (one point per sample) and limits
    times = np.arange(len(sig)) / fs
    xlim = (times[0], times[-1]) if xlim is None else xlim

    # Limit signal and dataframe
    sig, times = limit_signal(times, sig, start=xlim[0], stop=xlim[1])
    df_features = limit_df(df_features, fs, start=xlim[0], stop=xlim[1])

    # Determine if peaks or troughs are the sides of an oscillation
    center_e, side_e = get_extrema_df(df_features)

    # Remove this kwarg since it isn't stored cycle by cycle in the df (nothing to plot).
    #   Work on a copy so the caller's dict is left untouched.
    threshold_kwargs = threshold_kwargs.copy()
    threshold_kwargs.pop('min_n_cycles', None)

    # Create figure and subplots
    if plot_only_result:
        fig = go.Figure(make_subplots(rows=1, cols=1))
    else:
        fig = go.Figure(make_subplots(rows=5, cols=1, vertical_spacing=0.01))

    # Plot bursts
    fig = _plot_bursts(df_features,
                       sig,
                       times,
                       center_e,
                       side_e,
                       fig,
                       row=1,
                       col=1)

    # Plot params
    if not plot_only_result:

        burst_params = [
            'amp_fraction', 'amp_consistency', 'period_consistency',
            'monotonicity', 'burst_fraction'
        ]

        # Only plot params that were actually computed
        burst_params = [
            param for param in burst_params if param in df_features.columns
        ]

        fig = _plot_params(df_features,
                           sig,
                           fs,
                           times,
                           center_e,
                           side_e,
                           burst_params,
                           threshold_kwargs,
                           fig,
                           row=2,
                           col=1)

    # Update axes and layout
    if not plot_only_result:

        # Add time label to last subplot
        fig.update_xaxes(title_text='Time',
                         showticklabels=True,
                         row=len(burst_params) + 2,
                         col=1)

        # Update signal axes
        fig.update_xaxes(showticklabels=False, row=1, col=1)
        fig.update_yaxes(title_text="Voltage<br>(normalized)", row=1, col=1)

        # Zoom link across all subplots
        fig.update_xaxes(matches="x")

        # Update layout across all subplots
        fig.update_layout(width=1000, height=1000)

    else:

        fig.update_layout(width=1000,
                          height=325,
                          xaxis_title="Time",
                          yaxis_title="Voltage<br>(normalized)")

    fig.update_layout(
        autosize=True,
        showlegend=False,
        title_text="Burst Detection Plots",
    )

    # NOTE(review): with plot_only_result=True the figure has a single row, so row=5
    #   presumably selects no axis here — confirm whether a rangeslider is intended then.
    fig.update_xaxes(rangeslider={'visible': True}, row=5, col=1)

    graph = fig.to_html(include_plotlyjs=False, full_html=False)

    # Get burst traces
    trace_id = len(fig.data) - 2
    burst_traces = list(range(len(df_features)))

    # Get the div id containing the plot ('<div id="' is 9 characters)
    div_id = re.search("<div id=.+?\"", graph)[0]
    div_id = div_id[9:-1]

    # Create js callback
    js_callback = ["<script type=\"text/javascript\">\n"]
    js_callback.append("""
    var burstPlot = document.getElementById('{plot_id}');
    var burstTraces = {burst_traces};
    var traceId = {trace_id};
    burstPlot.on('plotly_click', function(data){{
        relabel1DBursts(data, burstPlot, {idx}, burstTraces, traceId);
    }});
    """.format(trace_id=trace_id,
               burst_traces=str(burst_traces),
               plot_id=div_id,
               idx=df_idx))
    js_callback.append("</script>\n")
    js_callback = "".join(js_callback)

    # Flatten lists into single string
    graph = re.sub("</body>\n</html>", "\n", graph)

    # Add recompute burst btn below the plots
    rewrite_call = "rewriteBursts({div_id})".format(div_id=str([div_id]))
    btn = "\n\t\t<p><center><button onclick=\"" + rewrite_call + "\" class=\"btn\" "
    btn = btn + "title=\"update is_burst column\">Update Bursts</button></center></p>"

    graph = graph + js_callback + btn + "\n</body>\n</html>"

    return graph