Exemplo n.º 1
0
def plot_oscillations(alphas, save_fig=False, save_name=None):
    """Plot a group of (flattened) oscillation definitions.

    Each row of `alphas` is evaluated with `gaussian_function` across a fixed
    frequency axis (6 to 16, in steps of 0.1) and drawn as a faint line. The
    NaN-ignoring mean across all rows is drawn on top as a thick black line.

    Parameters
    ----------
    alphas : 2d array
        Oscillation parameters, one row per subject. Each row is unpacked into
        `gaussian_function` (presumably as [center, height, width] values --
        confirm against that function).
    save_fig : bool, optional, default: False
        Whether to save the figure; forwarded to `_save_fig`.
    save_name : str, optional
        Name to save the figure under; forwarded to `_save_fig`.
    """

    n_subjs = alphas.shape[0]

    # Initialize figure (only the axes handle is needed here)
    _, ax = plt.subplots(figsize=[6, 6])

    # Get frequency axis (x-axis)
    fs = np.arange(6, 16, 0.1)

    # Create the oscillation model from parameters, one row per subject
    osc_psds = np.empty(shape=[n_subjs, len(fs)])
    for ind, params in enumerate(alphas):
        osc_psds[ind, :] = gaussian_function(fs, *params)

    # Plot each individual subject
    for osc_psd in osc_psds:
        ax.plot(fs, osc_psd, alpha=0.3, linewidth=1.5)

    # Plot the average across all subjects
    avg = np.nanmean(osc_psds, 0)
    ax.plot(fs, avg, 'k', linewidth=3)

    ax.set_ylim([0, 1.75])

    ax.set_xlabel('Frequency', {'fontsize': 14})
    ax.set_ylabel('Power', {'fontsize': 14})

    # Set tick font-sizes
    plt.setp(ax.get_xticklabels(), fontsize=12)
    plt.setp(ax.get_yticklabels(), fontsize=12)

    _set_lr_spines(ax, 2)
    _save_fig(save_fig, save_name)
Exemplo n.º 2
0
def plot_peak_iter(fm):
    """Plot a series of plots illustrating the peak search from a flattened spectrum.

    A new axes is created for each iteration. Each of the first n_gauss
    iterations shows:

    - the current flattened spectrum;
    - the relative and absolute thresholds;
    - the current maximum point;
    - the gaussian that is then subtracted from the spectrum.

    A final iteration shows the residual spectrum after all gaussians have been
    removed.

    Parameters
    ----------
    fm : FOOOF() object
        FOOOF object, with model fit, data and settings available.
    """

    flatspec = fm._spectrum_flat
    n_gauss = fm._gaussian_params.shape[0]

    # Y-limits padded by 10% of each extreme's magnitude. abs() is used on both
    #   bounds so the padding always widens the range, even if the max is negative.
    spec_min, spec_max = np.min(flatspec), np.max(flatspec)
    ylims = [
        spec_min - 0.1 * np.abs(spec_min),
        spec_max + 0.1 * np.abs(spec_max)
    ]

    for ind in range(n_gauss + 1):

        # Note: this forces to create a new plotting axes per iteration
        ax = check_ax(None)

        plot_spectrum(fm.freqs,
                      flatspec,
                      linewidth=2.0,
                      label='Flattened Spectrum',
                      ax=ax)
        # The relative threshold is recomputed from the current (partially
        #   peak-subtracted) spectrum on every iteration
        plot_spectrum(fm.freqs,
                      [fm.peak_threshold * np.std(flatspec)] * len(fm.freqs),
                      color='orange',
                      linestyle='dashed',
                      label='Relative Threshold',
                      ax=ax)
        plot_spectrum(fm.freqs, [fm.min_peak_height] * len(fm.freqs),
                      color='red',
                      linestyle='dashed',
                      label='Absolute Threshold',
                      ax=ax)

        # Mark the current candidate peak (the spectrum maximum)
        maxi = np.argmax(flatspec)
        ax.plot(fm.freqs[maxi], flatspec[maxi], '.', markersize=24)

        ax.set_ylim(ylims)
        ax.set_title('Iteration #' + str(ind + 1), fontsize=16)

        if ind < n_gauss:

            gauss = gaussian_function(fm.freqs, *fm._gaussian_params[ind, :])
            plot_spectrum(fm.freqs,
                          gauss,
                          label='Gaussian Fit',
                          linestyle=':',
                          linewidth=2.0,
                          ax=ax)

            # Remove this gaussian before the next iteration
            flatspec = flatspec - gauss
Exemplo n.º 3
0
def gen_peaks(xs, gauss_params):
    """Generate peak values, from parameter definition.

    Parameters
    ----------
    xs : 1d array
        Frequency vector to create peak values from.
    gauss_params : list of float
        Parameters to create peaks, as a flat sequence of length n_peaks * 3.
        The sequence is unpacked element-wise into `gaussian_function`, so each
        peak's parameters are presumably consecutive triplets (confirm against
        `gaussian_function`).

    Returns
    -------
    1d array
        Generated peak values.
    """

    return gaussian_function(xs, *gauss_params)
Exemplo n.º 4
0
Arquivo: gen.py Projeto: vlitvak/fooof
def gen_peaks(freqs, gaussian_params):
    """Generate peak values, from parameter definition.

    Parameters
    ----------
    freqs : 1d array
        Frequency vector to create peak values from.
    gaussian_params : list of float
        Parameters to create peaks, as a flat sequence of length n_peaks * 3.
        The sequence is unpacked element-wise into `gaussian_function`.

    Returns
    -------
    peak_vals : 1d array
        Generated peak values.
    """

    peak_vals = gaussian_function(freqs, *gaussian_params)

    return peak_vals
Exemplo n.º 5
0
def plot_oscillations(alphas, save_fig=False, save_name=None):
    """Plot a group of (flattened) oscillation definitions.

    Each row of `alphas` is evaluated with `gaussian_function` across a fixed
    frequency axis (4 to 18, in steps of 0.1) and drawn as a faint line. The
    NaN-ignoring mean across all rows is drawn on top as a thick black line.

    Note: plot taken & adapted from EEGFOOOF.

    Parameters
    ----------
    alphas : 2d array
        Oscillation parameters, one row per subject. Each row is unpacked into
        `gaussian_function` (presumably as [center, height, width] values --
        confirm against that function).
    save_fig : bool, optional, default: False
        Whether to save the figure; forwarded to `save_figure`.
    save_name : str, optional
        Name to save the figure under; forwarded to `save_figure`.
    """

    n_subjs = alphas.shape[0]

    # Initialize figure (only the axes handle is needed here)
    _, ax = plt.subplots(figsize=[6, 6])

    # Get frequency axis (x-axis)
    fs = np.arange(4, 18, 0.1)

    # Create the oscillation model from parameters, one row per subject
    osc_psds = np.empty(shape=[n_subjs, len(fs)])
    for ind, params in enumerate(alphas):
        osc_psds[ind, :] = gaussian_function(fs, *params)

    # Plot each individual subject
    for osc_psd in osc_psds:
        ax.plot(fs, osc_psd, alpha=0.3, linewidth=1.5)

    # Plot the average across all subjects
    avg = np.nanmean(osc_psds, 0)
    ax.plot(fs, avg, 'k', linewidth=3)

    ax.set_ylim([0, 2.2])

    ax.set_xlabel('Frequency')
    ax.set_ylabel('Power')

    # Set the top and right side frame & ticks off, and set font sizes
    _set_lr_spines(ax, 2)
    _set_tick_sizes(ax)
    _set_label_sizes(ax)

    save_figure(save_fig, save_name)
Exemplo n.º 6
0
    def _fit_peaks(self, flat_iter):
        """Iteratively fit peaks to flattened spectrum.

        Candidate peaks are taken one at a time from the maximum of the
        flattened spectrum. An initial guess gaussian is subtracted after each
        candidate. The search stops when any of these holds:

        - `max_n_peaks` guesses have been collected;
        - the candidate is at or below `peak_threshold` standard deviations;
        - the candidate is not above `min_peak_height`.

        The guesses are then filtered (`_drop_peak_cf`, `_drop_peak_overlap`)
        and refined with `_fit_peak_guess`.

        Parameters
        ----------
        flat_iter : 1d array
            Flattened power spectrum values.

        Returns
        -------
        gaussian_params : 2d array
            Parameters that define the gaussian fit(s).
            Each row is a gaussian, as [mean, height, standard deviation],
            sorted by mean. Shape is (0, 3) if no peaks are kept.
        """

        # Initialize matrix of guess parameters for gaussian fitting.
        guess = np.empty([0, 3])

        # Find peak: Loop through, finding a candidate peak, and fitting with a gaussian.
        #  Stopping procedure based on either # of peaks, or the relative or absolute height thresholds.
        while len(guess) < self.max_n_peaks:

            # Find candidate peak - the maximum point of the flattened spectrum.
            max_ind = np.argmax(flat_iter)
            max_height = flat_iter[max_ind]

            # Stop searching for peaks once the candidate drops below the relative threshold.
            if max_height <= self.peak_threshold * np.std(flat_iter):
                break

            # Set the guess parameters for gaussian fitting - mean and height.
            guess_freq = self.freqs[max_ind]
            guess_height = max_height

            # Halt fitting process if candidate peak drops below minimum height.
            if not guess_height > self.min_peak_height:
                break

            # Data-driven first guess at standard deviation
            #  Find half height index on each side of the center frequency.
            #  The left search runs down to and including index 0.
            half_height = 0.5 * max_height
            le_ind = next((x for x in range(max_ind - 1, -1, -1)
                           if flat_iter[x] <= half_height), None)
            ri_ind = next((x for x in range(max_ind + 1, len(flat_iter), 1)
                           if flat_iter[x] <= half_height), None)

            # Distances to the half-height point on each side where it was found.
            #  The shortest side is used, to avoid estimating a very large std
            #  from overlapping peaks.
            side_lengths = [abs(ind - max_ind) for ind in (le_ind, ri_ind)
                            if ind is not None]

            if side_lengths:
                # Estimate std from FWHM: full width (converted to Hz) -> gaussian std.
                fwhm = min(side_lengths) * 2 * self.freq_res
                guess_std = fwhm / (2 * np.sqrt(2 * np.log(2)))
            else:
                # Half height was not reached on either side, so no width can be
                #  estimated; fall back to the middle of the allowed std range
                #  instead of failing on min() of an empty list.
                guess_std = np.mean(self._gauss_std_limits)

            # Check that guess std isn't outside preset std limits; restrict if so.
            #  Note: without this, curve_fitting fails if given guess > or < bounds.
            if guess_std < self._gauss_std_limits[0]:
                guess_std = self._gauss_std_limits[0]
            if guess_std > self._gauss_std_limits[1]:
                guess_std = self._gauss_std_limits[1]

            # Collect guess parameters.
            guess = np.vstack((guess, (guess_freq, guess_height, guess_std)))

            # Subtract best-guess gaussian.
            peak_gauss = gaussian_function(self.freqs, guess_freq,
                                           guess_height, guess_std)
            flat_iter = flat_iter - peak_gauss

        # Check peaks based on edges, and on overlap
        #  Drop any that violate requirements.
        guess = self._drop_peak_cf(guess)
        guess = self._drop_peak_overlap(guess)

        # If there are peak guesses, fit the peaks, and sort results by center frequency.
        if len(guess) > 0:
            gaussian_params = self._fit_peak_guess(guess)
            gaussian_params = gaussian_params[gaussian_params[:, 0].argsort()]
        else:
            gaussian_params = np.empty([0, 3])

        return gaussian_params
Exemplo n.º 7
0
def plot_annotated_peak_search(fm, plot_style=style_spectrum_plot):
    """Plot a series of plots illustrating the peak search from a flattened spectrum.

    A new axes is created for each iteration. Each of the first `fm.n_peaks_`
    iterations shows:

    - the current flattened spectrum;
    - the relative and absolute thresholds;
    - the current maximum;
    - the fitted gaussian that is then subtracted.

    A final iteration shows the residual spectrum.

    Parameters
    ----------
    fm : FOOOF
        FOOOF object, with model fit, data and settings available.
    plot_style : callable, optional, default: style_spectrum_plot
        A function to call to apply styling & aesthetics to the plots.
    """

    # Recalculate the initial aperiodic fit and flattened spectrum that
    #   is the same as the one that is used in the peak fitting procedure
    flatspec = fm.power_spectrum - \
        gen_aperiodic(fm.freqs, fm._robust_ap_fit(fm.freqs, fm.power_spectrum))

    # Calculate ylims of the plot that are scaled to the range of the data.
    #   abs() is used on both bounds so the 10% padding always widens the range,
    #   even if the maximum is negative.
    spec_min, spec_max = np.min(flatspec), np.max(flatspec)
    ylims = [
        spec_min - 0.1 * np.abs(spec_min),
        spec_max + 0.1 * np.abs(spec_max)
    ]

    # Loop through the iterative search for each peak
    for ind in range(fm.n_peaks_ + 1):

        # This forces the creation of a new plotting axes per iteration
        ax = check_ax(None, PLT_FIGSIZES['spectral'])

        plot_spectrum(fm.freqs,
                      flatspec,
                      ax=ax,
                      plot_style=None,
                      label='Flattened Spectrum',
                      color=PLT_COLORS['data'],
                      linewidth=2.5)
        # The relative threshold is recomputed from the current (partially
        #   peak-subtracted) spectrum on every iteration
        plot_spectrum(fm.freqs,
                      [fm.peak_threshold * np.std(flatspec)] * len(fm.freqs),
                      ax=ax,
                      plot_style=None,
                      label='Relative Threshold',
                      color='orange',
                      linewidth=2.5,
                      linestyle='dashed')
        plot_spectrum(fm.freqs, [fm.min_peak_height] * len(fm.freqs),
                      ax=ax,
                      plot_style=None,
                      label='Absolute Threshold',
                      color='red',
                      linewidth=2.5,
                      linestyle='dashed')

        # Mark the current candidate peak (the spectrum maximum)
        maxi = np.argmax(flatspec)
        ax.plot(fm.freqs[maxi],
                flatspec[maxi],
                '.',
                color=PLT_COLORS['periodic'],
                alpha=0.75,
                markersize=30)

        ax.set_ylim(ylims)
        ax.set_title('Iteration #' + str(ind + 1), fontsize=16)

        if ind < fm.n_peaks_:

            gauss = gaussian_function(fm.freqs, *fm.gaussian_params_[ind, :])
            plot_spectrum(fm.freqs,
                          gauss,
                          ax=ax,
                          plot_style=None,
                          label='Gaussian Fit',
                          color=PLT_COLORS['periodic'],
                          linestyle=':',
                          linewidth=3.0)

            # Remove this peak before the next iteration
            flatspec = flatspec - gauss

        check_n_style(plot_style, ax, False, True)
Exemplo n.º 8
0
def plot_peak_fits(peaks,
                   freq_range=None,
                   colors=None,
                   labels=None,
                   ax=None,
                   **plot_kwargs):
    """Plot reconstructions of model peak fits.

    Parameters
    ----------
    peaks : 2d array or list of 2d array
        Peak data. Each row is a peak, as [CF, PW, BW]. A list of arrays is
        plotted group by group through `recursive_plot`.
    freq_range : list of [float, float] , optional
        The frequency range to plot the peak fits across, as [f_min, f_max].
        If not provided, defaults to +/- 4 around given peak center frequencies,
        clipped at 0.
    colors : str or list of str, optional
        Color(s) to plot data.
    labels : list of str, optional
        Label(s) for plotted data, to be added in a legend.
    ax : matplotlib.Axes, optional
        Figure axes upon which to plot.
    **plot_kwargs
        Keyword arguments; `figsize` is used when creating the axes, and the
        rest are forwarded through `recursive_plot` for list input.

    Raises
    ------
    ValueError
        If `freq_range` is not given and `peaks` contains no non-NaN center
        frequencies to derive it from.
    """

    ax = check_ax(ax, plot_kwargs.pop('figsize', PLT_FIGSIZES['params']))

    if isinstance(peaks, list):

        if not colors:
            colors = cycle(plt.rcParams['axes.prop_cycle'].by_key()['color'])

        recursive_plot(
            peaks,
            plot_function=plot_peak_fits,
            ax=ax,
            freq_range=tuple(freq_range) if freq_range else freq_range,
            colors=colors,
            labels=labels,
            **plot_kwargs)

    else:

        if not freq_range:

            # Extract all the CF values, excluding any NaNs
            cfs = peaks[~np.isnan(peaks[:, 0]), 0]
            if cfs.size == 0:
                raise ValueError('Cannot infer freq_range: no non-NaN peak '
                                 'center frequencies were provided.')

            # Define the frequency range as +/- buffer around the data range
            #   This also doesn't let the plot range drop below 0
            f_buffer = 4
            freq_range = [
                max(cfs.min() - f_buffer, 0),
                cfs.max() + f_buffer
            ]

        # Create the frequency axis, which will be the plot x-axis
        freqs = gen_freqs(freq_range, 0.1)

        colors = colors[0] if isinstance(colors, list) else colors

        # Create & plot each peak model from parameters, keeping values for averaging
        all_peak_vals = []
        for peak_params in peaks:
            peak_vals = gaussian_function(freqs, *peak_params)
            ax.plot(freqs, peak_vals, color=colors, alpha=0.35, linewidth=1.25)
            all_peak_vals.append(peak_vals)

        # Average across all components. NaN values (e.g. from peaks with NaN
        #   parameters) are excluded from both the sum and the count, so they do
        #   not drag the average toward zero. Points with no valid values are 0.
        if all_peak_vals:
            stacked = np.vstack(all_peak_vals)
            n_valid = np.sum(~np.isnan(stacked), axis=0)
            avg = np.nansum(stacked, axis=0) / np.maximum(n_valid, 1)
        else:
            avg = np.zeros(shape=[len(freqs)])

        avg_color = 'black' if not colors else colors
        ax.plot(freqs, avg, color=avg_color, linewidth=3.75, label=labels)

    # Add axis labels
    ax.set_xlabel('Frequency')
    ax.set_ylabel('log(Power)')

    # Set plot limits
    ax.set_xlim(freq_range)
    ax.set_ylim([0, ax.get_ylim()[1]])

    # Apply plot style
    style_param_plot(ax)