Ejemplo n.º 1
0
Archivo: fm.py Proyecto: parhalje/fooof
def _add_peaks_width(fm, plt_log, ax, **plot_kwargs):
    """Draw a horizontal line spanning the width of each fitted peak.

    Parameters
    ----------
    fm : FOOOF
        FOOOF object containing results from fitting.
    plt_log : boolean
        Whether to plot the frequency values in log10 spacing.
    ax : matplotlib.Axes
        Figure axes upon which to plot.
    **plot_kwargs
        Keyword arguments to pass into the plot call.

    Notes
    -----
    The line stands for the peak bandwidth (the gaussian standard deviation),
    but the extent actually drawn is the full-width half-max. It is placed half
    the peak's height below the spectrum value at the peak center.
    """

    line_kwargs = check_plot_kwargs(plot_kwargs, {
        'color': PLT_COLORS['periodic'],
        'alpha': 0.6,
        'lw': 2.5,
        'ms': 6,
    })

    for peak in fm.gaussian_params_:
        center, height, std = peak[0], peak[1], peak[2]

        # Spectrum value at the frequency bin closest to the peak center,
        #   looked up in linear frequency space (before any log conversion)
        top = fm.power_spectrum[nearest_ind(fm.freqs, center)]
        level = top - (0.5 * height)

        half_span = 0.5 * compute_fwhm(std)
        xs = [center - half_span, center + half_span]
        if plt_log:
            xs = np.log10(xs)

        ax.plot(xs, [level, level], **line_kwargs)
Ejemplo n.º 2
0
def plot_annotated_model(fm,
                         plt_log=False,
                         annotate_peaks=True,
                         annotate_aperiodic=True,
                         ax=None,
                         plot_style=style_spectrum_plot):
    """Plot an annotated power spectrum and model, from a FOOOF object.

    Parameters
    ----------
    fm : FOOOF
        FOOOF object, with model fit, data and settings available.
    plt_log : boolean, optional, default: False
        Whether to plot the frequency values in log10 spacing.
    annotate_peaks : boolean, optional, default: True
        Whether to annotate the periodic parameters (center frequency, power,
        bandwidth) of the largest peak. Skipped if the model fit no peaks.
    annotate_aperiodic : boolean, optional, default: True
        Whether to annotate the aperiodic parameters (offset, exponent, and the
        knee when the model was fit in 'knee' mode).
    ax : matplotlib.Axes, optional
        Figure axes upon which to plot.
    plot_style : callable, optional, default: style_spectrum_plot
        A function to call to apply styling & aesthetics to the plots.

    Raises
    ------
    NoModelError
        If there are no model results available to plot.
    """

    # Check that model is available
    if not fm.has_model:
        raise NoModelError("No model is available to plot, can not proceed.")

    # Settings
    fontsize = 15
    lw1 = 4.0
    lw2 = 3.0
    ms1 = 12

    # Create the baseline figure
    ax = check_ax(ax, PLT_FIGSIZES['spectral'])
    fm.plot(plot_peaks='dot-shade-width',
            plt_log=plt_log,
            ax=ax,
            plot_style=None,
            data_kwargs={
                'lw': lw1,
                'alpha': 0.6
            },
            aperiodic_kwargs={
                'lw': lw1,
                'zorder': 10
            },
            model_kwargs={
                'lw': lw1,
                'alpha': 0.5
            },
            peak_kwargs={
                'dot': {
                    'color': PLT_COLORS['periodic'],
                    'ms': ms1,
                    'lw': lw2
                },
                'shade': {
                    'color': PLT_COLORS['periodic']
                },
                'width': {
                    'color': PLT_COLORS['periodic'],
                    'alpha': 0.75,
                    'lw': lw2
                }
            })

    # Get freqs for plotting, and convert to log if needed
    freqs = fm.freqs if not plt_log else np.log10(fm.freqs)

    ## Buffers: for spacing things out on the plot (scaled by plot values)
    x_buff1 = max(freqs) * 0.1
    x_buff2 = max(freqs) * 0.25
    y_buff1 = 0.15 * np.ptp(ax.get_ylim())
    shrink = 0.1

    # There is a bug in annotations for some perpendicular lines, so add small offset
    #   See: https://github.com/matplotlib/matplotlib/issues/12820. Fixed in 3.2.1.
    bug_buff = 0.000001

    # Only annotate a peak if the model actually fit one: with no peaks, the
    #   extracted parameters are not real values and annotations would be
    #   placed at meaningless (NaN) coordinates
    if annotate_peaks and len(fm.gaussian_params_) > 0:

        # Extract largest peak, to annotate, grabbing gaussian params
        gauss = get_band_peak_fm(fm,
                                 fm.freq_range,
                                 attribute='gaussian_params')

        peak_ctr, peak_hgt, peak_wid = gauss
        half_fwhm = 0.5 * compute_fwhm(peak_wid)
        bw_freqs = [peak_ctr - half_fwhm, peak_ctr + half_fwhm]

        if plt_log:
            peak_ctr = np.log10(peak_ctr)
            bw_freqs = np.log10(bw_freqs)

        # `freqs` and `peak_ctr` are in the same (linear or log) space here
        peak_top = fm.power_spectrum[nearest_ind(freqs, peak_ctr)]

        # Annotate Peak CF
        ax.annotate('Center Frequency',
                    xy=(peak_ctr, peak_top),
                    xytext=(peak_ctr, peak_top + np.abs(0.6 * peak_hgt)),
                    verticalalignment='center',
                    horizontalalignment='center',
                    arrowprops=dict(facecolor=PLT_COLORS['periodic'],
                                    shrink=shrink),
                    color=PLT_COLORS['periodic'],
                    fontsize=fontsize)

        # Annotate Peak PW
        ax.annotate('Power',
                    xy=(peak_ctr, peak_top - 0.3 * peak_hgt),
                    xytext=(peak_ctr + x_buff1, peak_top - 0.3 * peak_hgt),
                    verticalalignment='center',
                    arrowprops=dict(facecolor=PLT_COLORS['periodic'],
                                    shrink=shrink),
                    color=PLT_COLORS['periodic'],
                    fontsize=fontsize)

        # Annotate Peak BW: point at the middle of the left half of the width
        #   line, which sits at half the peak height below the peak top
        bw_buff = (peak_ctr - bw_freqs[0]) / 2
        ax.annotate('Bandwidth',
                    xy=(peak_ctr - bw_buff + bug_buff,
                        peak_top - (0.5 * peak_hgt)),
                    xytext=(peak_ctr - bw_buff, peak_top - (1.5 * peak_hgt)),
                    verticalalignment='center',
                    horizontalalignment='right',
                    arrowprops=dict(facecolor=PLT_COLORS['periodic'],
                                    shrink=shrink),
                    color=PLT_COLORS['periodic'],
                    fontsize=fontsize,
                    zorder=20)

    if annotate_aperiodic:

        # Annotate Aperiodic Offset
        #   Add a line to indicate offset, without adjusting plot limits below it
        ax.set_autoscaley_on(False)
        ax.plot([freqs[0], freqs[0]],
                [ax.get_ylim()[0], fm.fooofed_spectrum_[0]],
                color=PLT_COLORS['aperiodic'],
                linewidth=lw2,
                alpha=0.5)
        ax.annotate('Offset',
                    xy=(freqs[0] + bug_buff, fm.power_spectrum[0] - y_buff1),
                    xytext=(freqs[0] - x_buff1,
                            fm.power_spectrum[0] - y_buff1),
                    verticalalignment='center',
                    horizontalalignment='center',
                    arrowprops=dict(facecolor=PLT_COLORS['aperiodic'],
                                    shrink=shrink),
                    color=PLT_COLORS['aperiodic'],
                    fontsize=fontsize)

        # Annotate Aperiodic Knee
        if fm.aperiodic_mode == 'knee':

            # Find the knee frequency point to annotate
            knee_freq = compute_knee_frequency(
                fm.get_params('aperiodic', 'knee'),
                fm.get_params('aperiodic', 'exponent'))
            knee_freq = np.log10(knee_freq) if plt_log else knee_freq
            knee_pow = fm.power_spectrum[nearest_ind(freqs, knee_freq)]

            # Add a dot to the plot indicating the knee frequency
            ax.plot(knee_freq,
                    knee_pow,
                    'o',
                    color=PLT_COLORS['aperiodic'],
                    ms=ms1 * 1.5,
                    alpha=0.7)

            ax.annotate('Knee',
                        xy=(knee_freq, knee_pow),
                        xytext=(knee_freq - x_buff2, knee_pow - y_buff1),
                        verticalalignment='center',
                        arrowprops=dict(facecolor=PLT_COLORS['aperiodic'],
                                        shrink=shrink),
                        color=PLT_COLORS['aperiodic'],
                        fontsize=fontsize)

        # Annotate Aperiodic Exponent, at the middle frequency bin
        mid_ind = len(freqs) // 2
        ax.annotate('Exponent',
                    xy=(freqs[mid_ind], fm.power_spectrum[mid_ind]),
                    xytext=(freqs[mid_ind] - x_buff2,
                            fm.power_spectrum[mid_ind] - y_buff1),
                    verticalalignment='center',
                    arrowprops=dict(facecolor=PLT_COLORS['aperiodic'],
                                    shrink=shrink),
                    color=PLT_COLORS['aperiodic'],
                    fontsize=fontsize)

    # Apply style to plot & tune grid styling
    check_n_style(plot_style, ax, plt_log, True)
    ax.grid(True, alpha=0.5)

    # Add labels to plot in the legend
    da_patch = mpatches.Patch(color=PLT_COLORS['data'], label='Original Data')
    ap_patch = mpatches.Patch(color=PLT_COLORS['aperiodic'],
                              label='Aperiodic Parameters')
    pe_patch = mpatches.Patch(color=PLT_COLORS['periodic'],
                              label='Peak Parameters')
    mo_patch = mpatches.Patch(color=PLT_COLORS['model'], label='Full Model')

    handles = [
        da_patch, ap_patch if annotate_aperiodic else None,
        pe_patch if annotate_peaks else None, mo_patch
    ]
    handles = [el for el in handles if el is not None]

    ax.legend(handles=handles, handlelength=1, fontsize='x-large')