Example #1
0
def plot_spectrum_averages(subject,
                           channel_groups,
                           name,
                           log_transformed=True):
    """ Plots spectrum averages.

    For every channel type found in the channel-group averages, creates one
    figure (via ``create_channel_average_plot``) with one subplot per channel
    group and one curve per condition of the spectrum item.

    Parameters
    ----------
    subject : Subject
        Subject whose ``spectrum`` collection holds the item.
    channel_groups : dict
        Channel group definitions, passed unchanged to
        ``average_to_channel_groups``.
    name : str
        Name of the spectrum item in ``subject.spectrum``.
    log_transformed : bool
        If True, power is plotted as ``10 * log10(power)``.
    """
    spectrum = subject.spectrum.get(name)

    data = spectrum.content
    freqs = spectrum.freqs
    ch_names = spectrum.ch_names
    info = spectrum.info

    # Conditions are processed in sorted order; the legend and the curve
    # colors must use this same order or they will disagree.
    conditions = sorted(data.keys())
    colors = color_cycle(len(conditions))
    # Color is tied to the condition itself, not to its position within a
    # particular channel group's list, so a missing group cannot shift colors.
    color_of = dict(zip(conditions, colors))

    # (ch_type, ch_group) label -> list of (condition, averaged psd)
    averages = {}
    for key in conditions:
        data_labels, averaged_data = average_to_channel_groups(
            data[key], info, ch_names, channel_groups)

        for label_idx, label in enumerate(data_labels):
            averages.setdefault(label, []).append(
                (key, averaged_data[label_idx]))

    ch_types = sorted({label[0] for label in averages})
    for ch_type in ch_types:

        ch_groups = sorted(
            label[1] for label in averages if label[0] == ch_type)

        def plot_fun(ax_idx, ax):
            ch_group = ch_groups[ax_idx]
            ax.set_title(ch_group)
            ax.set_xlabel('Frequency (Hz)')
            ax.set_ylabel('Power ({})'.format(
                get_power_unit(ch_type, log_transformed)))

            for key, curve in averages[(ch_type, ch_group)]:
                if log_transformed:
                    curve = 10 * np.log10(curve)
                ax.plot(freqs, curve, color=color_of[key])

        title = ' '.join([name, ch_type])
        legend = list(zip(conditions, colors))
        create_channel_average_plot(len(ch_groups), plot_fun, title, legend)

    plt.show()
Example #2
0
def plot_evoked_topo(evoked, ch_type):
    """ Plots evoked time courses arranged as a topography.

    Each condition of ``evoked.content`` (in sorted key order) becomes one
    colored set of curves in ``mne.viz.plot_evoked_topo``.

    Parameters
    ----------
    evoked : Evoked item
        Item whose ``content`` maps condition names to mne Evoked objects
        and whose ``name`` is used in titles.
    ch_type : str
        ``'eeg'`` keeps only EEG channels; any other value keeps only MEG
        channels.
    """
    import logging

    evokeds = []
    labels = []
    for key, evok in sorted(evoked.content.items()):

        info = evok.info
        # Indices of channels to keep; computed once per condition instead of
        # once per channel name.
        if ch_type == 'eeg':
            kept = set(mne.pick_types(info, eeg=True, meg=False))
        else:
            kept = set(mne.pick_types(info, eeg=False, meg=True))
        dropped_names = [
            ch_name for ch_idx, ch_name in enumerate(info['ch_names'])
            if ch_idx not in kept
        ]

        evokeds.append(evok.copy().drop_channels(dropped_names))
        labels.append(key)

    colors = color_cycle(len(labels))

    # setup legend for subplots
    lines = [
        Line2D([0], [0], color=color, label=label)
        for color, label in zip(colors, labels)
    ]

    def onclick(event):
        # NOTE(review): relies on the current axes being the single-channel
        # plot opened for the clicked subplot (presumably by mne) -- fragile.
        try:
            ax = plt.gca()

            channel = plt.getp(ax, 'title')
            ax.set_title('')

            ax.legend(handles=lines, loc='upper right')

            title = ' '.join([evoked.name, channel])
            ax.figure.suptitle(title)
            ax.figure.canvas.set_window_title(title.replace(' ', '_'))
            plt.show()
        except Exception:
            # Decoration is best effort: never break the interactive figure,
            # but leave a trace instead of swallowing the error silently.
            logging.getLogger(__name__).debug(
                'Could not decorate the clicked evoked subplot',
                exc_info=True)

    fig = mne.viz.plot_evoked_topo(evokeds, color=colors)
    fig.canvas.mpl_connect('button_press_event', onclick)
    title = "{0}_{1}".format(evoked.name, ch_type)
    fig.canvas.set_window_title(title)
Example #3
0
def plot_tse_averages(subject, tfr_name, blmode, blstart, blend, tmin, tmax,
                      fmin, fmax, channel_groups):
    """ Plots tse averages.

    Computes tses of the TFR item for the ``fmin``-``fmax`` range with
    ``_compute_tse``, averages them to channel groups, crops and
    baseline-corrects them with ``_crop_and_correct_to_baseline``, and creates
    one figure per channel type with one subplot per channel group.

    Parameters
    ----------
    subject : Subject
        Subject whose ``tfr`` collection holds the item.
    tfr_name : str
        Name of the TFR item in ``subject.tfr``.
    blmode, blstart, blend : baseline parameters
        Passed unchanged to ``_crop_and_correct_to_baseline``.
    tmin, tmax : float
        Crop window, passed unchanged to ``_crop_and_correct_to_baseline``.
    fmin, fmax : float
        Frequency range, passed unchanged to ``_compute_tse``.
    channel_groups : dict
        Passed unchanged to ``average_to_channel_groups``.
    """
    meggie_tfr = subject.tfr.get(tfr_name)

    tses = _compute_tse(meggie_tfr, fmin, fmax)

    ch_names = meggie_tfr.ch_names
    info = meggie_tfr.info

    # Conditions are processed in sorted order; the legend must use the same
    # order, and each curve's color is tied to its condition.
    conditions = sorted(tses.keys())
    colors = color_cycle(len(conditions))
    color_of = dict(zip(conditions, colors))

    # (ch_type, ch_group) label -> list of (condition, times, averaged tse)
    averages = {}
    for key in conditions:
        data_labels, averaged_data = average_to_channel_groups(
            tses[key], info, ch_names, channel_groups)

        times, averaged_data = _crop_and_correct_to_baseline(
            averaged_data, blmode, blstart, blend, tmin, tmax,
            meggie_tfr.times)

        for label_idx, label in enumerate(data_labels):
            averages.setdefault(label, []).append(
                (key, times, averaged_data[label_idx]))

    ch_types = sorted({label[0] for label in averages})
    for ch_type in ch_types:

        ch_groups = sorted(
            label[1] for label in averages if label[0] == ch_type)

        def plot_fun(ax_idx, ax):
            ch_group = ch_groups[ax_idx]
            ax.set_title(ch_group)
            ax.set_xlabel('Time (s)')
            ax.set_ylabel('Power ({})'.format(get_power_unit(ch_type, False)))

            for key, times, curve in averages[(ch_type, ch_group)]:
                ax.plot(times, curve, color=color_of[key], label=key)

            ax.axhline(0, color='black')
            ax.axvline(0, color='black')

        title = ' '.join([tfr_name, ch_type])
        legend = list(zip(conditions, colors))
        create_channel_average_plot(len(ch_groups), plot_fun, title, legend)

    plt.show()
Example #4
0
def plot_evoked_averages(evoked, channel_groups):
    """ Plots channel averages.

    Averages every condition of ``evoked.content`` to channel groups with
    ``_create_averages`` and creates one figure per channel type with one
    subplot per channel group and one curve per condition.

    Parameters
    ----------
    evoked : Evoked item
        Item providing ``content`` (condition -> mne Evoked), ``times`` and
        ``name``.
    channel_groups : dict
        Passed unchanged to ``_create_averages``.
    """
    # Conditions are processed in sorted order; the legend must use the same
    # order, and each curve's color is tied to its condition.
    conditions = sorted(evoked.content.keys())
    colors = color_cycle(len(conditions))
    color_of = dict(zip(conditions, colors))
    times = evoked.times

    # (ch_type, ch_group) label -> list of (condition, averaged data)
    averages = {}
    for key in conditions:
        data_labels, averaged_data = _create_averages(evoked.content[key],
                                                      channel_groups)

        for label_idx, label in enumerate(data_labels):
            averages.setdefault(label, []).append(
                (key, averaged_data[label_idx]))

    ch_types = sorted({label[0] for label in averages})
    for ch_type in ch_types:

        ch_groups = sorted(
            label[1] for label in averages if label[0] == ch_type)

        def plot_fun(ax_idx, ax):
            ch_group = ch_groups[ax_idx]
            ax.set_title(ch_group)

            ax.set_xlabel('Time (s)')
            ax.set_ylabel('Amplitude ({})'.format(get_unit(ch_type)))

            for key, curve in averages[(ch_type, ch_group)]:
                ax.plot(times, curve, color=color_of[key])

            ax.axhline(0, color='black')
            ax.axvline(0, color='black')

        title = ' '.join([evoked.name, ch_type])
        legend = list(zip(conditions, colors))
        create_channel_average_plot(len(ch_groups), plot_fun, title, legend)

    plt.show()
Example #5
0
def plot_tse_topo(subject, tfr_name, blmode, blstart, blend, tmin, tmax, fmin,
                  fmax, ch_type):
    """ Plots a tse topography.

    Every condition's tse is cropped and baseline-corrected once, then drawn
    in a topography of the requested channel type. Clicking a subplot
    (``individual_plot``) draws that channel's conditions in a larger plot.

    Parameters
    ----------
    subject : Subject
        Subject whose ``tfr`` collection holds the item.
    tfr_name : str
        Name of the TFR item in ``subject.tfr``.
    blmode, blstart, blend, tmin, tmax : baseline / crop parameters
        Passed unchanged to ``_crop_and_correct_to_baseline``.
    fmin, fmax : float
        Frequency range, passed unchanged to ``_compute_tse``.
    ch_type : str
        ``'meg'`` keeps only MEG channels; any other value keeps only EEG.
    """
    meggie_tfr = subject.tfr.get(tfr_name)

    tses = _compute_tse(meggie_tfr, fmin, fmax)

    info = meggie_tfr.info
    # Indices of the channels to keep; computed once instead of once per
    # channel name.
    if ch_type == 'meg':
        picks = set(mne.pick_types(info, meg=True, eeg=False))
    else:
        picks = set(mne.pick_types(info, meg=False, eeg=True))
    picked_channels = [
        ch_name for ch_idx, ch_name in enumerate(info['ch_names'])
        if ch_idx in picks
    ]
    info = info.copy().pick_channels(picked_channels)

    ch_names = meggie_tfr.ch_names
    colors = color_cycle(len(tses))

    # Crop and baseline-correct every condition once, instead of redoing it
    # for every channel subplot and on every click.
    corrected = []
    for key, tse in sorted(tses.items()):
        times, tse_corrected = _crop_and_correct_to_baseline(
            tse, blmode, blstart, blend, tmin, tmax, meggie_tfr.times)
        corrected.append((key, times, tse_corrected))

    def individual_plot(ax, info_idx, names_idx):
        """ Plots all conditions of one channel on the clicked axes. """
        ch_name = ch_names[names_idx]

        title = ' '.join([tfr_name, ch_name])
        ax.figure.canvas.set_window_title(title.replace(' ', '_'))
        ax.figure.suptitle(title)
        ax.set_title('')

        for color_idx, (key, times, tse) in enumerate(corrected):
            ax.plot(times, tse[names_idx], color=colors[color_idx], label=key)

        # Zero reference lines are drawn once, not once per condition.
        ax.axhline(0, color='black')
        ax.axvline(0, color='black')

        ax.legend()

        ax.set_xlabel('Time (s)')
        ax.set_ylabel('Power ({})'.format(
            get_power_unit(mne.io.pick.channel_type(info, info_idx), False)))

        plt.show()

    fig = plt.figure()
    # Initialised up front so an empty topography does not raise NameError.
    handles = []
    for ax, info_idx, names_idx in iterate_topography(fig, info, ch_names,
                                                      individual_plot):

        handles = []
        for color_idx, (key, times, tse) in enumerate(corrected):
            handles.append(
                ax.plot(times,
                        tse[names_idx],
                        linewidth=0.2,
                        color=colors[color_idx],
                        label=key)[0])

    if not handles:
        # Nothing was drawn; do not leave an empty figure open.
        plt.close(fig)
        return

    fig.legend(handles=handles)
    title = '{0}_{1}'.format(tfr_name, ch_type)
    fig.canvas.set_window_title(title)
    plt.show()
Example #6
0
def plot_spectrum_topo(subject, name, log_transformed=True, ch_type='meg'):
    """ Plots spectrum topography.

    Draws every condition of the spectrum item in a topography of the
    requested channel type. Clicking a subplot (``individual_plot``) draws
    that channel's conditions in a larger plot.

    Parameters
    ----------
    subject : Subject
        Subject whose ``spectrum`` collection holds the item.
    name : str
        Name of the spectrum item in ``subject.spectrum``.
    log_transformed : bool
        If True, power is plotted as ``10 * log10(power)``.
    ch_type : str
        ``'meg'`` keeps only MEG channels; any other value keeps only EEG.
    """
    spectrum = subject.spectrum.get(name)
    data = spectrum.content
    freqs = spectrum.freqs
    ch_names = spectrum.ch_names
    info = spectrum.info
    # Indices of the channels to keep; computed once instead of once per
    # channel name.
    if ch_type == 'meg':
        picks = set(mne.pick_types(info, meg=True, eeg=False))
    else:
        picks = set(mne.pick_types(info, eeg=True, meg=False))
    picked_channels = [
        ch_name for ch_idx, ch_name in enumerate(info['ch_names'])
        if ch_idx in picks
    ]
    info = info.copy().pick_channels(picked_channels)

    colors = color_cycle(len(data))
    sorted_items = sorted(data.items())

    def get_curve(psd, names_idx):
        """ Returns the spectrum of one channel, in dB if requested. """
        if log_transformed:
            return 10 * np.log10(psd[names_idx])
        return psd[names_idx]

    def individual_plot(ax, info_idx, names_idx):
        """ Plots all conditions of one channel on the clicked axes. """
        ch_name = ch_names[names_idx]
        for color_idx, (key, psd) in enumerate(sorted_items):
            ax.plot(freqs, get_curve(psd, names_idx),
                    color=colors[color_idx], label=key)

        title = ' '.join([name, ch_name])
        ax.figure.canvas.set_window_title(title.replace(' ', '_'))
        ax.figure.suptitle(title)
        ax.set_title('')

        ax.legend()
        ax.set_xlabel('Frequency (Hz)')
        ax.set_ylabel('Power ({})'.format(
            get_power_unit(mne.io.pick.channel_type(info, info_idx),
                           log_transformed)))

        plt.show()

    fig = plt.figure()

    # Initialised up front so the emptiness check below cannot raise
    # NameError when no channel is iterated.
    handles = []
    for ax, info_idx, names_idx in iterate_topography(fig, info, ch_names,
                                                      individual_plot):

        handles = []
        for color_idx, (key, psd) in enumerate(sorted_items):
            # Plot against freqs, like the individual plot, rather than
            # against the sample index.
            handles.append(
                ax.plot(freqs,
                        get_curve(psd, names_idx),
                        color=colors[color_idx],
                        linewidth=0.5,
                        label=key)[0])

    if not handles:
        # Nothing was drawn; do not leave an empty figure open.
        plt.close(fig)
        return

    fig.legend(handles=handles)
    title = '{0}_{1}'.format(name, ch_type)
    fig.canvas.set_window_title(title)
    plt.show()
Example #7
0
def test_color_cycle():
    """ color_cycle(30) returns a list of 30 colors containing exactly 8
    distinct values. """
    colors = color_cycle(30)
    assert isinstance(colors, list)
    assert len(colors) == 30
    assert len(set(colors)) == 8
Example #8
0
def plot_topo_fit(subject, report_item):
    """ Plots a topography of the power spectra stored in a FOOOF report item.

    Clicking a subplot replaces it with one FOOOF fit plot per condition,
    titled with the condition, the fit's R squared and its peak parameters.

    Parameters
    ----------
    subject : Subject
        Subject whose raw data provides the channel ``info`` for the layout.
    report_item : report item
        Item whose ``content`` maps condition -> report, whose
        ``params['ch_names']`` lists the channels, and whose ``name`` is used
        as the title.
    """
    reports = report_item.content
    ch_names = report_item.params['ch_names']

    colors = color_cycle(len(reports))

    raw = subject.get_raw()
    info = raw.info

    def on_pick(ax, info_idx, names_idx):
        """ When a subplot representing a specific channel is clicked on the
        main topography plot, show a new figure containing FOOOF fit plot
        for every condition """

        fig = ax.figure
        fig.delaxes(ax)

        for idx, (report_key, report) in enumerate(reports.items()):
            report_ax = fig.add_subplot(1, len(reports), idx + 1)
            fooof = report.get_fooof(names_idx)

            # Use plot function from fooof
            fooof.plot(
                ax=report_ax,
                plot_peaks='dot',
                add_legend=False,
            )
            # Add information about the fit to the axis title
            title_lines = [
                "Condition: " + str(report_key),
                "R squared: " + format_float(fooof.r_squared_),
                "Peaks: ",
            ]
            title_lines.extend(
                '{0} ({1}, {2})'.format(*format_floats(peak_params))
                for peak_params in fooof.peak_params_)

            report_ax.set_title('\n'.join(title_lines) + '\n')

        fig.tight_layout()

    # Create a topography where one can inspect fits by clicking subplots
    fig = plt.figure()
    # Initialised up front so an empty topography does not raise NameError.
    handles = []
    for ax, info_idx, names_idx in iterate_topography(fig, info, ch_names,
                                                      on_pick):

        handles = []
        for color_idx, (key, report) in enumerate(reports.items()):
            curve = report.power_spectra[names_idx]
            # Plot against the report's own frequency axis, not the index.
            handles.append(
                ax.plot(report.freqs,
                        curve,
                        color=colors[color_idx],
                        linewidth=0.5,
                        label=key)[0])

    if handles:
        fig.legend(handles=handles)
    fig.canvas.set_window_title(report_item.name)
    fig.suptitle(report_item.name)

    plt.show()