示例#1
0
def compute_pointwise_error_fg(fg,
                               plot_errors=True,
                               return_errors=False,
                               **plt_kwargs):
    """Compute the absolute model error at every frequency, for each fit in a FOOOFGroup.

    Parameters
    ----------
    fg : FOOOFGroup
        Object containing the data and models.
    plot_errors : bool, optional, default: True
        Whether to plot the mean and standard deviation of the errors across frequencies.
    return_errors : bool, optional, default: False
        Whether to return the calculated errors.
    **plt_kwargs
        Keyword arguments to be passed to the plot function.

    Returns
    -------
    errors : 2d array
        Absolute difference between each power spectrum and its regenerated model,
        with one row per fit. Only returned if `return_errors` is True.

    Raises
    ------
    NoDataError
        If there is no data available to calculate model errors from.
    NoModelError
        If there are no model results available to calculate model errors from.
    """

    # Both data and model fits are required before any comparison can be made
    if not np.any(fg.power_spectra):
        raise NoDataError(
            "Data must be available in the object to calculate errors.")
    if not fg.has_model:
        raise NoModelError("No model is available to use, can not proceed.")

    # Output array matches the shape (and dtype) of the input spectra
    all_errors = np.zeros_like(fg.power_spectra)

    # Rebuild each model from its fit parameters and compare to its spectrum
    for row, (fit_result, spectrum) in enumerate(zip(fg, fg.power_spectra)):
        fitted = gen_model(fg.freqs, fit_result.aperiodic_params,
                           fit_result.gaussian_params)
        all_errors[row, :] = np.abs(fitted - spectrum)

    if plot_errors:
        # Summarize across fits: average error per frequency, and its spread
        avg_error = np.mean(all_errors, 0)
        spread = np.std(all_errors, 0)
        plot_spectral_error(fg.freqs, avg_error, spread, **plt_kwargs)

    return all_errors if return_errors else None
示例#2
0
文件: fg.py 项目: ryanhammonds/fooof
def plot_fg(fg, save_fig=False, file_name=None, file_path=None):
    """Plot a figure with subplots visualizing the parameters from a FOOOFGroup object.

    Parameters
    ----------
    fg : FOOOFGroup
        Object containing results from fitting a group of power spectra.
    save_fig : bool, optional, default: False
        Whether to save out a copy of the plot.
    file_name : str, optional
        Name to give the saved out file. Required if `save_fig` is True.
    file_path : str, optional
        Path to directory to save to. If None, saves to current directory.

    Raises
    ------
    NoModelError
        If the FOOOF object does not have model fit data available to plot.
    ValueError
        If `save_fig` is True but no `file_name` is given.
    """

    if not fg.has_model:
        raise NoModelError(
            "No model fit results are available, can not proceed.")

    # Validate the save request before drawing anything, so an invalid request
    #   does not leave a fully drawn, unsaved figure behind
    if save_fig and not file_name:
        raise ValueError(
            "Input 'file_name' is required to save out the plot.")

    fig = plt.figure(figsize=PLT_FIGSIZES['group'])
    gs = gridspec.GridSpec(2,
                           2,
                           wspace=0.4,
                           hspace=0.25,
                           height_ratios=[1, 1.2])

    # Aperiodic parameters plot
    ax0 = fig.add_subplot(gs[0, 0])
    plot_fg_ap(fg, ax0)

    # Goodness of fit plot
    ax1 = fig.add_subplot(gs[0, 1])
    plot_fg_gf(fg, ax1)

    # Center frequencies plot (spans the full bottom row)
    ax2 = fig.add_subplot(gs[1, :])
    plot_fg_peak_cens(fg, ax2)

    if save_fig:
        # Save this specific figure, rather than whichever figure is current
        fig.savefig(fpath(file_path, fname(file_name, 'png')))
示例#3
0
def compute_pointwise_error_fm(fm,
                               plot_errors=True,
                               return_errors=False,
                               **plt_kwargs):
    """Compute the frequency-by-frequency error of a single model fit from a FOOOF object.

    Parameters
    ----------
    fm : FOOOF
        Object containing the data and model.
    plot_errors : bool, optional, default: True
        Whether to plot the errors across frequencies.
    return_errors : bool, optional, default: False
        Whether to return the calculated errors.
    **plt_kwargs
        Keyword arguments to be passed to the plot function.

    Returns
    -------
    errors : 1d array
        Per-frequency error between the model and the data, as computed by
        `compute_pointwise_error`. Only returned if `return_errors` is True.

    Raises
    ------
    NoDataError
        If there is no data available to calculate model error from.
    NoModelError
        If there are no model results available to calculate model error from.
    """

    # A comparison needs both the original spectrum and a fitted model
    if not fm.has_data:
        raise NoDataError(
            "Data must be available in the object to calculate errors.")
    if not fm.has_model:
        raise NoModelError("No model is available to use, can not proceed.")

    pointwise = compute_pointwise_error(fm.fooofed_spectrum_, fm.power_spectrum)

    if plot_errors:
        plot_spectral_error(fm.freqs, pointwise, **plt_kwargs)

    return pointwise if return_errors else None
示例#4
0
def plot_annotated_model(fm,
                         plt_log=False,
                         annotate_peaks=True,
                         annotate_aperiodic=True,
                         ax=None,
                         plot_style=style_spectrum_plot):
    """Plot an annotated power spectrum and model, from a FOOOF object.

    Parameters
    ----------
    fm : FOOOF
        FOOOF object, with model fit, data and settings available.
    plt_log : boolean, optional, default: False
        Whether to plot the frequency values in log10 spacing.
    annotate_peaks : boolean, optional, default: True
        Whether to annotate the periodic (peak) components of the model fit.
        The annotation is skipped if no valid peak can be extracted.
    annotate_aperiodic : boolean, optional, default: True
        Whether to annotate the aperiodic components of the model fit.
    ax : matplotlib.Axes, optional
        Figure axes upon which to plot.
    plot_style : callable, optional, default: style_spectrum_plot
        A function to call to apply styling & aesthetics to the plots.

    Raises
    ------
    NoModelError
        If there are no model results available to plot.
    """

    # Check that model is available
    if not fm.has_model:
        raise NoModelError("No model is available to plot, can not proceed.")

    # Settings
    fontsize = 15
    lw1 = 4.0
    lw2 = 3.0
    ms1 = 12

    # Create the baseline figure
    ax = check_ax(ax, PLT_FIGSIZES['spectral'])
    fm.plot(plot_peaks='dot-shade-width',
            plt_log=plt_log,
            ax=ax,
            plot_style=None,
            data_kwargs={
                'lw': lw1,
                'alpha': 0.6
            },
            aperiodic_kwargs={
                'lw': lw1,
                'zorder': 10
            },
            model_kwargs={
                'lw': lw1,
                'alpha': 0.5
            },
            peak_kwargs={
                'dot': {
                    'color': PLT_COLORS['periodic'],
                    'ms': ms1,
                    'lw': lw2
                },
                'shade': {
                    'color': PLT_COLORS['periodic']
                },
                'width': {
                    'color': PLT_COLORS['periodic'],
                    'alpha': 0.75,
                    'lw': lw2
                }
            })

    # Get freqs for plotting, and convert to log if needed
    freqs = fm.freqs if not plt_log else np.log10(fm.freqs)

    ## Buffers: for spacing things out on the plot (scaled by plot values)
    x_buff1 = max(freqs) * 0.1
    x_buff2 = max(freqs) * 0.25
    y_buff1 = 0.15 * np.ptp(ax.get_ylim())
    shrink = 0.1

    # There is a bug in annotations for some perpendicular lines, so add small offset
    #   See: https://github.com/matplotlib/matplotlib/issues/12820. Fixed in 3.2.1.
    bug_buff = 0.000001

    # Track whether peak annotations were drawn, so the legend matches the plot
    peaks_annotated = False

    if annotate_peaks:

        # Extract largest peak, to annotate, grabbing gaussian params
        gauss = get_band_peak_fm(fm,
                                 fm.freq_range,
                                 attribute='gaussian_params')

        # NOTE(review): presumably the extracted params are NaN when the model has
        #   no peak (average_fg uses the same NaN check) - verify against
        #   get_band_peak_fm. NaN values cannot serve as annotation coordinates,
        #   so only annotate a fully-defined peak.
        if not np.any(np.isnan(gauss)):

            peaks_annotated = True

            peak_ctr, peak_hgt, peak_wid = gauss
            bw_freqs = [
                peak_ctr - 0.5 * compute_fwhm(peak_wid),
                peak_ctr + 0.5 * compute_fwhm(peak_wid)
            ]

            if plt_log:
                peak_ctr = np.log10(peak_ctr)
                bw_freqs = np.log10(bw_freqs)

            peak_top = fm.power_spectrum[nearest_ind(freqs, peak_ctr)]

            # Annotate Peak CF
            ax.annotate('Center Frequency',
                        xy=(peak_ctr, peak_top),
                        xytext=(peak_ctr, peak_top + np.abs(0.6 * peak_hgt)),
                        verticalalignment='center',
                        horizontalalignment='center',
                        arrowprops=dict(facecolor=PLT_COLORS['periodic'],
                                        shrink=shrink),
                        color=PLT_COLORS['periodic'],
                        fontsize=fontsize)

            # Annotate Peak PW
            ax.annotate('Power',
                        xy=(peak_ctr, peak_top - 0.3 * peak_hgt),
                        xytext=(peak_ctr + x_buff1, peak_top - 0.3 * peak_hgt),
                        verticalalignment='center',
                        arrowprops=dict(facecolor=PLT_COLORS['periodic'],
                                        shrink=shrink),
                        color=PLT_COLORS['periodic'],
                        fontsize=fontsize)

            # Annotate Peak BW
            bw_buff = (peak_ctr - bw_freqs[0]) / 2
            ax.annotate('Bandwidth',
                        xy=(peak_ctr - bw_buff + bug_buff,
                            peak_top - (0.5 * peak_hgt)),
                        xytext=(peak_ctr - bw_buff, peak_top - (1.5 * peak_hgt)),
                        verticalalignment='center',
                        horizontalalignment='right',
                        arrowprops=dict(facecolor=PLT_COLORS['periodic'],
                                        shrink=shrink),
                        color=PLT_COLORS['periodic'],
                        fontsize=fontsize,
                        zorder=20)

    if annotate_aperiodic:

        # Annotate Aperiodic Offset
        #   Add a line to indicate offset, without adjusting plot limits below it
        ax.set_autoscaley_on(False)
        ax.plot([freqs[0], freqs[0]],
                [ax.get_ylim()[0], fm.fooofed_spectrum_[0]],
                color=PLT_COLORS['aperiodic'],
                linewidth=lw2,
                alpha=0.5)
        ax.annotate('Offset',
                    xy=(freqs[0] + bug_buff, fm.power_spectrum[0] - y_buff1),
                    xytext=(freqs[0] - x_buff1,
                            fm.power_spectrum[0] - y_buff1),
                    verticalalignment='center',
                    horizontalalignment='center',
                    arrowprops=dict(facecolor=PLT_COLORS['aperiodic'],
                                    shrink=shrink),
                    color=PLT_COLORS['aperiodic'],
                    fontsize=fontsize)

        # Annotate Aperiodic Knee
        if fm.aperiodic_mode == 'knee':

            # Find the knee frequency point to annotate
            knee_freq = compute_knee_frequency(
                fm.get_params('aperiodic', 'knee'),
                fm.get_params('aperiodic', 'exponent'))
            knee_freq = np.log10(knee_freq) if plt_log else knee_freq
            knee_pow = fm.power_spectrum[nearest_ind(freqs, knee_freq)]

            # Add a dot to the plot indicating the knee frequency
            ax.plot(knee_freq,
                    knee_pow,
                    'o',
                    color=PLT_COLORS['aperiodic'],
                    ms=ms1 * 1.5,
                    alpha=0.7)

            ax.annotate('Knee',
                        xy=(knee_freq, knee_pow),
                        xytext=(knee_freq - x_buff2, knee_pow - y_buff1),
                        verticalalignment='center',
                        arrowprops=dict(facecolor=PLT_COLORS['aperiodic'],
                                        shrink=shrink),
                        color=PLT_COLORS['aperiodic'],
                        fontsize=fontsize)

        # Annotate Aperiodic Exponent, at the middle frequency index
        mid_ind = len(freqs) // 2
        ax.annotate('Exponent',
                    xy=(freqs[mid_ind], fm.power_spectrum[mid_ind]),
                    xytext=(freqs[mid_ind] - x_buff2,
                            fm.power_spectrum[mid_ind] - y_buff1),
                    verticalalignment='center',
                    arrowprops=dict(facecolor=PLT_COLORS['aperiodic'],
                                    shrink=shrink),
                    color=PLT_COLORS['aperiodic'],
                    fontsize=fontsize)

    # Apply style to plot & tune grid styling
    check_n_style(plot_style, ax, plt_log, True)
    ax.grid(True, alpha=0.5)

    # Add labels to plot in the legend
    da_patch = mpatches.Patch(color=PLT_COLORS['data'], label='Original Data')
    ap_patch = mpatches.Patch(color=PLT_COLORS['aperiodic'],
                              label='Aperiodic Parameters')
    pe_patch = mpatches.Patch(color=PLT_COLORS['periodic'],
                              label='Peak Parameters')
    mo_patch = mpatches.Patch(color=PLT_COLORS['model'], label='Full Model')

    # Only list annotation groups that were actually drawn
    handles = [
        da_patch, ap_patch if annotate_aperiodic else None,
        pe_patch if peaks_annotated else None, mo_patch
    ]
    handles = [el for el in handles if el is not None]

    ax.legend(handles=handles, handlelength=1, fontsize='x-large')
示例#5
0
def gen_results_fg_str(fg, concise=False):
    """Build a formatted, human-readable report of FOOOFGroup fit results.

    Parameters
    ----------
    fg : FOOOFGroup
        Object to access results from.
    concise : bool, optional, default: False
        Whether to print the report in concise mode.

    Returns
    -------
    output : str
        Formatted string of results.

    Raises
    ------
    NoModelError
        If no model fit data is available to report.
    """

    if not fg.has_model:
        raise NoModelError(
            "No model fit results are available, can not proceed.")

    def _min_max_mean(values):
        """Return the NaN-ignoring minimum, maximum and mean of `values`."""
        return np.nanmin(values), np.nanmax(values), np.nanmean(values)

    has_knee = fg.aperiodic_mode == 'knee'

    # Pull out the group-level results to summarize
    n_peaks = len(fg.get_params('peak_params'))
    r2s = fg.get_params('r_squared')
    errors = fg.get_params('error')
    exps = fg.get_params('aperiodic_params', 'exponent')

    # Spectra that failed to fit are counted via NaN exponent values
    n_failed = sum(np.isnan(exps))

    # Header, and group information
    lines = [
        '=',
        '',
        ' FOOOF - GROUP RESULTS',
        '',
        'Number of power spectra in the Group: {}'.format(len(fg.group_results)),
    ]
    if n_failed:
        lines.append('{} power spectra failed to fit'.format(n_failed))
    lines.append('')

    # Frequency range and resolution
    freq_lo = int(np.floor(fg.freq_range[0]))
    freq_hi = int(np.ceil(fg.freq_range[1]))
    lines.extend([
        'The model was run on the frequency range {} - {} Hz'.format(freq_lo, freq_hi),
        'Frequency Resolution is {:1.2f} Hz'.format(fg.freq_res),
        '',
    ])

    # Aperiodic parameters: knee fit status, then value summaries
    knee_word = 'with' if has_knee else 'without'
    lines.extend([
        'Power spectra were fit {} a knee.'.format(knee_word),
        '',
        'Aperiodic Fit Values:',
    ])
    if has_knee:
        kns = fg.get_params('aperiodic_params', 'knee')
        lines.append('    Knees - Min: {:6.2f}, Max: {:6.2f}, Mean: {:5.2f}'
                     .format(*_min_max_mean(kns)))
    lines.append('Exponents - Min: {:6.3f}, Max: {:6.3f}, Mean: {:5.3f}'
                 .format(*_min_max_mean(exps)))
    lines.append('')

    # Peak parameters
    lines.append('In total {} peaks were extracted from the group'.format(n_peaks))
    lines.append('')

    # Goodness of fit, then footer
    lines.extend([
        'Goodness of fit metrics:',
        '   R2s -  Min: {:6.3f}, Max: {:6.3f}, Mean: {:5.3f}'
        .format(*_min_max_mean(r2s)),
        'Errors -  Min: {:6.3f}, Max: {:6.3f}, Mean: {:5.3f}'
        .format(*_min_max_mean(errors)),
        '',
        '=',
    ])

    return _format(lines, concise)
示例#6
0
def average_fg(fg, bands, avg_method='mean', regenerate=True):
    """Average across model fits in a FOOOFGroup object.

    Parameters
    ----------
    fg : FOOOFGroup
        Object with model fit results to average across.
    bands : Bands
        Bands object that defines the frequency bands to collapse peaks across.
    avg_method : {'mean', 'median'}, optional, default: 'mean'
        Averaging function to use. NaN values are ignored when averaging.
    regenerate : bool, optional, default: True
        Whether to regenerate the model for the averaged parameters.

    Returns
    -------
    fm : FOOOF
        Object containing the average model results.

    Raises
    ------
    ValueError
        If the requested averaging method is not understood.
    NoModelError
        If there are no model fit results available to average across.
    """

    if avg_method not in ['mean', 'median']:
        raise ValueError("Requested average method not understood.")
    if not fg.has_model:
        raise NoModelError(
            "No model fit results are available, can not proceed.")

    # NaN-aware averaging functions, so missing values are skipped
    avg_func = {'mean': np.nanmean, 'median': np.nanmedian}[avg_method]

    # Aperiodic parameters: extract & average
    ap_params = avg_func(fg.get_params('aperiodic_params'), 0)

    # Periodic parameters: extract & average, per band
    peak_params = []
    gauss_params = []

    for band_def in bands.definitions:

        peaks = get_band_peak_fg(fg, band_def, attribute='peak_params')
        gauss = get_band_peak_fg(fg, band_def, attribute='gaussian_params')

        # Check if there are any extracted peaks - if not, don't add
        #   Note that we only check peaks, but gauss should be the same
        if not np.all(np.isnan(peaks)):
            peak_params.append(avg_func(peaks, 0))
            gauss_params.append(avg_func(gauss, 0))

    # Stack into 2d (n_bands_with_peaks, 3) arrays. If no band had any peaks,
    #   use an empty (0, 3) array rather than np.array([]) (shape (0,)), so that
    #   column-wise indexing of the results still works
    peak_params = np.array(peak_params) if peak_params else np.empty([0, 3])
    gauss_params = np.array(gauss_params) if gauss_params else np.empty([0, 3])

    # Goodness of fit measures: extract & average
    r2 = avg_func(fg.get_params('r_squared'))
    error = avg_func(fg.get_params('error'))

    # Collect all results together, to be added to FOOOF object
    results = FOOOFResults(ap_params, peak_params, r2, error, gauss_params)

    # Create the new FOOOF object, with settings, data info & results
    fm = FOOOF()
    fm.add_settings(fg.get_settings())
    fm.add_meta_data(fg.get_meta_data())
    fm.add_results(results)

    # Generate the average model from the parameters
    if regenerate:
        fm._regenerate_model()

    return fm
示例#7
0
文件: group.py 项目: anc211/fooof
    def get_params(self, name, col=None):
        """Return model fit parameters for specified feature(s).

        Parameters
        ----------
        name : {'aperiodic_params', 'peak_params', 'gaussian_params', 'error', 'r_squared'}
            Name of the data field to extract across the group.
            The shortcuts 'aperiodic', 'peak' and 'gaussian' are also accepted.
        col : {'CF', 'PW', 'BW', 'offset', 'knee', 'exponent'} or int, optional
            Column name / index to extract from selected data, if requested.
            Only used for name of {'aperiodic_params', 'peak_params', 'gaussian_params'}.

        Returns
        -------
        out : ndarray
            Requested data. For 'peak_params' and 'gaussian_params', an extra last
            column holds the index of the model fit that each peak comes from.

        Raises
        ------
        NoModelError
            If there are no model fit results available.
        ValueError
            If the input for the `col` input is not understood.

        Notes
        -----
        For further description of the data you can extract, check the FOOOFResults documentation.
        """

        if not self.has_model:
            raise NoModelError(
                "No model fit results are available, can not proceed.")

        # Allow for shortcut alias, without adding `_params`
        if name in ['aperiodic', 'peak', 'gaussian']:
            name = name + '_params'

        # If col specified as string, get mapping back to integer
        if isinstance(col, str):
            col = get_indices(self.aperiodic_mode)[col]
        elif isinstance(col, int):
            if col not in [0, 1, 2]:
                raise ValueError("Input value for `col` not valid.")

        if name in ('peak_params', 'gaussian_params'):

            # Each fit can hold a different number of peaks, so stack the per-fit
            #   2d arrays row-wise, appending a column with the index of the fit.
            #   np.vstack handles differing row counts (including fits with no
            #   peaks); np.array on such a ragged list raises in NumPy >= 1.24.
            out = np.vstack([
                np.insert(getattr(data, name), 3, index, axis=1)
                for index, data in enumerate(self.group_results)
            ])

            # Select the requested column together with the last column,
            #   which is the 'index' column (FOOOF object source)
            if col is not None:
                col = [col, -1]

        else:
            values = [getattr(data, name) for data in self.group_results]

            # Array-valued fields become one row per fit (a 2d array);
            #   scalar-valued fields are collected into a 1d array
            if isinstance(values[0], np.ndarray):
                out = np.vstack(values)
            else:
                out = np.array(values)

        # Select out a specific column, if requested
        if col is not None:
            out = out[:, col]

        return out