示例#1
0
def plot_tse_topo(experiment, subject, tfr_name, blmode, blstart, blend, tmin,
                  tmax, fmin, fmax):
    """Plot the temporal spectral evolution (TSE) of every channel as a topography.

    The TSE is computed from the TFR ``tfr_name`` of ``subject`` with
    ``_compute_tse`` using the given frequency band, time window and baseline
    settings. Each condition returned by ``_compute_tse`` gets its own color.
    ``individual_plot`` is handed to ``iterate_topography`` as a callback that
    draws a detailed, labelled plot of a single channel (presumably invoked
    when a channel's subplot is selected — see ``iterate_topography``).

    Parameters
    ----------
    experiment : object
        Not used by this function; kept for interface compatibility.
    subject : object
        Subject whose ``tfr`` container holds ``tfr_name``.
    tfr_name : str
        Name of the TFR to plot; also used in window titles.
    blmode, blstart, blend :
        Baseline mode and window, passed through to ``_compute_tse``.
    tmin, tmax :
        Time window, passed through to ``_compute_tse``.
    fmin, fmax :
        Frequency band, passed through to ``_compute_tse``.
    """
    meggie_tfr = subject.tfr.get(tfr_name)

    times, tses = _compute_tse(meggie_tfr, fmin, fmax, tmin, tmax, blmode,
                               blstart, blend)

    ch_names = meggie_tfr.ch_names
    info = meggie_tfr.info
    colors = color_cycle(len(tses))

    logging.getLogger('ui_logger').info('Plotting TSE from all channels..')

    def individual_plot(ax, info_idx, names_idx):
        """Draw a detailed TSE plot of one channel on ``ax`` and show it."""
        ch_name = ch_names[names_idx]

        title = 'TSE_{0}_{1}'.format(tfr_name, ch_name)
        ax.figure.canvas.set_window_title(title)
        ax.figure.suptitle(title)
        ax.set_title('')

        for color_idx, (key, tse) in enumerate(tses.items()):
            ax.plot(times, tse[names_idx], color=colors[color_idx], label=key)

        # Reference lines are drawn once, not once per condition.
        ax.axhline(0)
        ax.axvline(0)

        ax.legend()

        ax.set_xlabel('Time (s)')
        ax.set_ylabel('Power ({})'.format(
            get_power_unit(mne.io.pick.channel_type(info, info_idx), False)))

        plt.show()

    fig = plt.figure()

    # Bound before the loop so an empty topography cannot leave ``handles``
    # undefined when the legend is built below.
    handles = []
    for ax, info_idx, names_idx in iterate_topography(fig, info, ch_names,
                                                      individual_plot):
        # Only the last channel's handles are kept. Every channel uses the same
        # condition labels and colors, which is all the figure legend needs.
        handles = []
        for color_idx, (key, tse) in enumerate(tses.items()):
            handles.append(
                ax.plot(tse[names_idx],
                        linewidth=0.2,
                        color=colors[color_idx],
                        label=key)[0])

    if not handles:
        # Nothing was drawn; do not leave an empty figure open.
        plt.close(fig)
        return

    fig.legend(handles=handles)
    title = 'TSE_{0}'.format(tfr_name)
    fig.canvas.set_window_title(title)
    fig.suptitle(title)

    plt.show()
示例#2
0
def plot_spectrum_averages(experiment, name, log_transformed=True):
    """Plot a spectrum averaged within channel groups, one figure per group.

    Each condition in the spectrum's ``content`` is averaged with
    ``average_to_channel_groups`` over ``experiment.channel_groups``. One
    figure is then drawn per group, overlaying one curve per condition.

    Parameters
    ----------
    experiment : object
        Provides ``active_subject`` and ``channel_groups``.
    name : str
        Name of the spectrum of the active subject to plot.
    log_transformed : bool
        If True, curves are plotted as ``10 * log10(power)``.

    Raises
    ------
    Exception
        If the spectrum has no conditions, or if no channel group matches
        the data.
    """
    subject = experiment.active_subject

    spectrum = subject.spectrum.get(name)

    data = spectrum.content
    freqs = spectrum.freqs
    ch_names = spectrum.ch_names

    # Fail explicitly instead of hitting an unbound variable further down.
    if not data:
        raise Exception('No spectrum data found.')

    channel_groups = experiment.channel_groups

    info = subject.get_raw().info

    colors = color_cycle(len(data))

    logging.getLogger('ui_logger').info('Plotting spectrum channel averages..')

    averages = {}
    for key, psd in data.items():

        data_labels, averaged_data = average_to_channel_groups(
            psd, info, ch_names, channel_groups)

        if not data_labels:
            raise Exception('No channel groups matching the data found.')

        averages[key] = data_labels, averaged_data

    # Number of channel groups, taken from the last condition's averages.
    # All conditions are averaged with the same info, names and groups.
    n_groups = averaged_data.shape[0]

    for ii in range(n_groups):
        fig, ax = plt.subplots()
        for color_idx, (key, (data_labels, group_data)) in enumerate(
                averages.items()):

            if log_transformed:
                curve = 10 * np.log10(group_data[ii])
            else:
                curve = group_data[ii]

            ax.plot(freqs, curve, color=colors[color_idx], label=key)

        # The group label comes from the last condition, as before. It is
        # presumably identical for all conditions (same grouping inputs).
        ch_type, ch_group = data_labels[ii]

        ax.set_xlabel('Frequency (Hz)')
        ax.set_ylabel('Power ({})'.format(
            get_power_unit(ch_type, log_transformed)))
        ax.legend()

        title = 'spectrum_{0}_{1}_{2}'.format(name, ch_type, ch_group)
        fig.canvas.set_window_title(title)
        fig.suptitle(title)

    plt.show()
示例#3
0
def plot_tse_averages(experiment, subject, tfr_name, blmode, blstart, blend,
                      tmin, tmax, fmin, fmax):
    """Plot the TSE averaged within channel groups, one figure per group.

    The TSE is computed from the TFR ``tfr_name`` of ``subject`` with
    ``_compute_tse``. Each condition is then averaged with
    ``average_to_channel_groups`` over ``experiment.channel_groups``, and one
    figure per group overlays one curve per condition.

    Parameters
    ----------
    experiment : object
        Provides ``channel_groups``.
    subject : object
        Subject whose ``tfr`` container holds ``tfr_name``.
    tfr_name : str
        Name of the TFR to plot; also used in window titles.
    blmode, blstart, blend :
        Baseline mode and window, passed through to ``_compute_tse``.
    tmin, tmax :
        Time window, passed through to ``_compute_tse``.
    fmin, fmax :
        Frequency band, passed through to ``_compute_tse``.

    Raises
    ------
    Exception
        If no TSE conditions were computed, or if no channel group matches
        the data.
    """
    meggie_tfr = subject.tfr.get(tfr_name)

    times, tses = _compute_tse(meggie_tfr, fmin, fmax, tmin, tmax, blmode,
                               blstart, blend)

    # Fail explicitly instead of hitting an unbound variable further down.
    if not tses:
        raise Exception('No TSE data found.')

    ch_names = meggie_tfr.ch_names
    info = meggie_tfr.info
    colors = color_cycle(len(tses))

    channel_groups = experiment.channel_groups

    logging.getLogger('ui_logger').info('Plotting TSE channel averages..')

    averages = {}
    for key, tse in tses.items():

        data_labels, averaged_data = average_to_channel_groups(
            tse, info, ch_names, channel_groups)

        if not data_labels:
            raise Exception('No channel groups matching the data found.')

        averages[key] = data_labels, averaged_data

    # Number of channel groups, taken from the last condition's averages.
    # All conditions are averaged with the same info, names and groups.
    n_groups = averaged_data.shape[0]

    for ii in range(n_groups):
        fig, ax = plt.subplots()
        for color_idx, (key, (data_labels, group_data)) in enumerate(
                averages.items()):
            ax.plot(times,
                    group_data[ii],
                    color=colors[color_idx],
                    label=key)

        # The group label comes from the last condition, as before. It is
        # presumably identical for all conditions (same grouping inputs).
        ch_type, ch_group = data_labels[ii]

        ax.set_xlabel('Time (s)')
        ax.set_ylabel('Power ({})'.format(get_power_unit(ch_type, False)))
        ax.axhline(0)
        ax.legend()

        title = 'TSE_{0}_{1}_{2}'.format(tfr_name, ch_type, ch_group)
        fig.canvas.set_window_title(title)
        fig.suptitle(title)

    plt.show()
示例#4
0
def plot_channel_averages(experiment, evoked):
    """Plot evoked responses averaged within channel groups.

    Every condition in ``evoked.content`` is averaged with
    ``create_averages``. One figure is drawn per resulting group label,
    overlaying one curve per condition that has data for that group. Each
    curve is labelled with the condition's ``comment``.

    Parameters
    ----------
    experiment : object
        Passed through to ``create_averages``.
    evoked : object
        Container whose ``content`` maps condition keys to MNE evoked objects.

    Raises
    ------
    Exception
        If no channel group matches the data.
    """
    # Group averaged traces by label. Each trace keeps the index and object of
    # the condition it came from, so colors and names stay aligned even if a
    # condition lacks some group.
    averages = {}
    for evoked_idx, mne_evoked in enumerate(evoked.content.values()):
        data_labels, averaged_data = create_averages(experiment, mne_evoked)
        for idx, label in enumerate(data_labels):
            averages.setdefault(label, []).append(
                (evoked_idx, mne_evoked, averaged_data[idx]))

    if not averages:
        raise Exception('No channel groups matching the data found')

    # One color per condition, indexed by the condition's position.
    colors = color_cycle(len(evoked.content))

    for type_key, traces in averages.items():
        fig, ax = plt.subplots()
        for evoked_idx, mne_evoked, evoked_data in traces:
            ax.plot(mne_evoked.times,
                    evoked_data,
                    color=colors[evoked_idx],
                    label=mne_evoked.comment)

        # Axis decorations are added once per figure, not once per curve.
        ax.set_xlabel('Time (s)')
        ax.set_ylabel('Amplitude ({})'.format(get_unit(type_key[0])))
        ax.axhline(0)
        ax.axvline(0)
        ax.legend()

        title = 'evoked_{0}_{1}'.format(type_key[0], type_key[1])
        fig.canvas.set_window_title(title)
        fig.suptitle(title)

    plt.show()
示例#5
0
def _plot_evoked_topo(experiment, evoked):
    """Plot all conditions of ``evoked`` with ``mne.viz.plot_evoked_topo``.

    Conditions are colored with ``color_cycle``. A ``button_press_event``
    handler retitles the current axes and adds a condition legend to it.
    This presumably targets the single-channel plot MNE opens on click.

    Parameters
    ----------
    experiment : object
        Not used by this function; kept for interface compatibility.
    evoked : object
        Container whose ``content`` maps condition keys to MNE evoked
        objects and whose ``name`` is used in window titles.
    """
    labels = list(evoked.content.keys())
    evokeds = list(evoked.content.values())

    colors = color_cycle(len(labels))

    # Proxy artists used as the legend of the enlarged subplots.
    lines = [
        Line2D([0], [0], color=colors[idx], label=label)
        for idx, label in enumerate(labels)
    ]

    def onclick(event):
        """Retitle the current axes and add the condition legend to it."""
        try:
            # NOTE: relies on plt.gca() being the axes that should be
            # decorated, which is fragile.
            ax = plt.gca()

            channel = plt.getp(ax, 'title')
            ax.set_title('')

            title = "evoked_{0}_{1}".format(evoked.name, channel)

            ax.legend(handles=lines, loc='upper right')

            ax.figure.canvas.set_window_title(title)
            ax.figure.suptitle(title)
            plt.show()
        except Exception:
            # Best effort: a failed decoration must not break the UI.
            # Record it instead of discarding it silently.
            logging.getLogger('ui_logger').debug(
                'Could not decorate evoked topography subplot',
                exc_info=True)

    fig = mne.viz.plot_evoked_topo(evokeds, color=colors)
    fig.canvas.mpl_connect('button_press_event', onclick)
    title = "evoked_{0}".format(evoked.name)
    fig.canvas.set_window_title(title)
    fig.suptitle(title)
示例#6
0
def plot_single_channel(experiment, data, window):
    """Open a dialog for plotting single channels of the selected evoked.

    The evoked named by ``data['outputs']['evoked'][0]`` is looked up on the
    active subject. For every channel that ``get_channels_by_type`` assigns
    a type, the function precomputes y-limits, scaling and unit. It then
    shows a ``SingleChannelDialog`` whose ``handler`` plots the requested
    channel with ``mne.viz.plot_compare_evokeds``, optionally smoothed.

    Parameters
    ----------
    experiment : object
        Provides ``active_subject``.
    data : dict
        Selection data. Returns silently if ``data['outputs']['evoked']`` is
        empty.
    window :
        Parent widget for the dialog and for error message boxes.
    """
    try:
        selected_name = data['outputs']['evoked'][0]
    except IndexError:
        return

    # ``handler`` below has its own ``window`` parameter (the smoothing
    # window), which shadows this one. Keep a separate reference to the
    # parent widget so error dialogs get the right parent.
    parent = window

    subject = experiment.active_subject
    evokeds = subject.evoked.get(selected_name)
    content = evokeds.content

    info = list(content.values())[0].info

    conditions = list(content.keys())

    title = evokeds.name

    ch_names = info['ch_names']
    chs_by_type = get_channels_by_type(info)

    ylims = {}
    scalings = {}
    units = {}
    ch_types = {}
    for idx, ch_name in enumerate(ch_names):
        ch_type = None

        # If a channel appears under several types, the last one wins.
        for key, values in chs_by_type.items():
            if ch_name in values:
                ch_type = key

        if not ch_type:
            continue

        ch_types[ch_name] = ch_type

        # Limits start at zero, so the range always includes zero.
        ymin, ymax = 0, 0
        for mne_evoked in content.values():
            if np.max(mne_evoked.data[idx]) > ymax:
                ymax = np.max(mne_evoked.data[idx])
            if np.min(mne_evoked.data[idx]) < ymin:
                ymin = np.min(mne_evoked.data[idx])

        ylims[ch_name] = (ymin, ymax)
        scalings[ch_name] = get_scaling(ch_type)
        units[ch_name] = get_unit(ch_type)

    colors = color_cycle(len(content.keys()))

    def handler(ch_name, title, legend, ylim, window, window_len):
        """Plot ``ch_name`` for every condition, smoothed if ``window`` is set."""
        try:
            ch_idx = ch_names.index(ch_name)

            # Build copies so the stored evokeds are never modified.
            new_evokeds = {}
            for key, mne_evoked in content.items():
                new_evoked = mne_evoked.copy()

                if window:
                    try:
                        new_evoked.data[ch_idx] = smooth_signal(
                            new_evoked.data[ch_idx],
                            window_len=window_len,
                            window=window)
                    except ValueError as exc:
                        # Report once and abort. Do not show one dialog per
                        # condition and then plot unsmoothed data.
                        exc_messagebox(parent, exc)
                        return

                new_evoked.comment = legend[key]
                new_evokeds[legend[key]] = new_evoked

            mne.viz.plot_compare_evokeds(new_evokeds,
                                         title=title,
                                         picks=[ch_idx],
                                         colors=colors,
                                         ylim={ch_types[ch_name]: ylim},
                                         show_sensors=False)
        except Exception as exc:
            exc_messagebox(parent, exc)

    dialog = SingleChannelDialog(window, handler, title, ch_names, scalings,
                                 units, ylims, conditions)
    dialog.show()
示例#7
0
def plot_spectrum_topo(experiment, name, log_transformed=True):
    """Plot the spectrum of every channel of the active subject as a topography.

    Each condition in the spectrum's ``content`` gets its own color.
    ``individual_plot`` is handed to ``iterate_topography`` as a callback that
    draws a detailed, labelled plot of one channel (presumably invoked when a
    channel's subplot is selected — see ``iterate_topography``).

    Parameters
    ----------
    experiment : object
        Provides ``active_subject``.
    name : str
        Name of the spectrum of the active subject to plot.
    log_transformed : bool
        If True, curves are plotted as ``10 * log10(power)``.
    """
    subject = experiment.active_subject

    spectrum = subject.spectrum.get(name)

    data = spectrum.content
    freqs = spectrum.freqs
    ch_names = spectrum.ch_names

    info = subject.get_raw().info

    colors = color_cycle(len(data))

    logging.getLogger('ui_logger').info(
        'Plotting spectrum from all channels..')

    def channel_curve(psd, names_idx):
        """Return one channel's spectrum, in dB when ``log_transformed``."""
        if log_transformed:
            return 10 * np.log10(psd[names_idx])
        return psd[names_idx]

    def individual_plot(ax, info_idx, names_idx):
        """Draw a detailed spectrum plot of one channel on ``ax`` and show it."""
        ch_name = ch_names[names_idx]
        for color_idx, (key, psd) in enumerate(data.items()):
            ax.plot(freqs, channel_curve(psd, names_idx),
                    color=colors[color_idx], label=key)

        title = 'spectrum_{0}_{1}'.format(name, ch_name)
        ax.figure.canvas.set_window_title(title)
        ax.figure.suptitle(title)
        ax.set_title('')

        ax.legend()

        ax.set_xlabel('Frequency (Hz)')
        ax.set_ylabel('Power ({})'.format(
            get_power_unit(mne.io.pick.channel_type(info, info_idx),
                           log_transformed)))

        plt.show()

    fig = plt.figure()

    # Bound before the loop so the emptiness check below cannot raise
    # NameError when the topography yields no channels.
    handles = []
    for ax, info_idx, names_idx in iterate_topography(fig, info, ch_names,
                                                      individual_plot):
        # Only the last channel's handles are kept. Every channel uses the same
        # condition labels and colors, which is all the figure legend needs.
        handles = []
        for color_idx, (key, psd) in enumerate(data.items()):
            handles.append(
                ax.plot(channel_curve(psd, names_idx),
                        color=colors[color_idx],
                        linewidth=0.5,
                        label=key)[0])

    if not handles:
        # Nothing was drawn; do not leave an empty figure open.
        plt.close(fig)
        return

    fig.legend(handles=handles)
    title = 'spectrum_{0}'.format(name)
    fig.canvas.set_window_title(title)
    fig.suptitle(title)
    plt.show()