Esempio n. 1
0
def test_average_to_channel_groups():
    """Averaging into channel groups must preserve the groups channels map to.

    Loads the MNE sample recording and picks the first 20 channel names. It
    averages the first 100 columns of the raw data array into the default
    MEG and EEG channel groups. It then checks that the set of MEG group
    names those 20 channels belong to equals the set of group names in the
    returned labels.
    """
    raw_path = os.path.join(mne.datasets.sample.data_path(), 'MEG', 'sample',
                            'sample_audvis_raw.fif')
    raw = mne.io.read_raw_fif(raw_path, preload=True)
    info = raw.info

    picked_names = info['ch_names'][:20]

    groups = {
        'meg': get_default_channel_groups(info, 'meg'),
        'eeg': get_default_channel_groups(info, 'eeg'),
    }

    # Every MEG group that at least one picked channel is a member of.
    expected_groups = {
        group_name
        for name in picked_names
        for group_name, members in groups['meg'].items()
        if name in members
    }

    labels, _ = average_to_channel_groups(raw._data[:, :100], info,
                                          picked_names, groups)

    # check that the localization of data is same before and after averaging
    assert expected_groups == {label[1] for label in labels}
Esempio n. 2
0
def _create_averages(mne_evoked, channel_groups):
    """Average an evoked response into ``channel_groups``, skipping bad channels.

    The channels listed in ``mne_evoked.info['bads']`` are dropped from a
    copy, so the caller's object is not modified.

    Returns the ``(data_labels, averaged_data)`` pair computed by
    ``average_to_channel_groups``.
    """
    good_only = mne_evoked.copy().drop_channels(mne_evoked.info['bads'])
    good_info = good_only.info

    labels, averages = average_to_channel_groups(
        good_only.data, good_info, good_info['ch_names'], channel_groups)
    return labels, averages
Esempio n. 3
0
def save_tfr_channel_averages(experiment, tfr_name, blmode, blstart, blend,
                              tmin, tmax, fmin, fmax):
    """Save channel-group averaged TFRs of all subjects into one CSV file.

    Each subject that has a TFR named ``tfr_name`` is processed per
    condition in four steps:

    1. Crop the TFR to ``tmin``..``tmax`` and ``fmin``..``fmax``.
    2. Baseline-correct it with ``mne.baseline.rescale`` over
       ``(blstart, blend)`` using mode ``blmode``.
    3. Average it into ``experiment.channel_groups``.
    4. Emit one CSV row per (channel group, frequency) pair, described by
       ``(subject name, condition, channel type, area, frequency)``. The
       columns are the cropped time points.

    The CSV is written once, after all subjects have been processed, into a
    new timestamped folder.

    Raises
    ------
    Exception
        If a condition's data matches none of the channel groups.
    """
    column_names = []
    row_descs = []
    csv_data = []

    channel_groups = experiment.channel_groups

    # accumulate csv contents
    for subject in experiment.subjects.values():
        tfr = subject.tfr.get(tfr_name)
        if not tfr:
            continue

        ch_names = tfr.ch_names
        info = tfr.info

        for key, mne_tfr in tfr.content.items():

            # Crop first, then baseline-correct. The baseline window is
            # therefore looked up on the cropped time axis.
            mne_tfr = mne_tfr.copy().crop(tmin=tmin,
                                          tmax=tmax,
                                          fmin=fmin,
                                          fmax=fmax)
            times = mne_tfr.times
            column_names = format_floats(times)
            freqs = format_floats(mne_tfr.freqs)

            data = mne.baseline.rescale(mne_tfr.data,
                                        times,
                                        baseline=(blstart, blend),
                                        mode=blmode)

            data_labels, averaged_data = average_to_channel_groups(
                data, info, ch_names, channel_groups)

            if not data_labels:
                raise Exception('No channel groups matching the data found')

            for ix in range(averaged_data.shape[0]):
                # The label only depends on the group, not the frequency.
                ch_type, area = data_labels[ix]
                for iy in range(averaged_data.shape[1]):
                    csv_data.append(averaged_data[ix, iy].tolist())
                    row_descs.append(
                        (subject.name, key, ch_type, area, freqs[iy]))

    # Written once, after the subject loop, so the single file contains the
    # rows of every subject and only one timestamped folder is created.
    folder = filemanager.create_timestamped_folder(experiment)
    fname = tfr_name + '_all_subjects_channel_averages_tfr.csv'
    path = os.path.join(folder, fname)

    filemanager.save_csv(path, csv_data, column_names, row_descs)
    logging.getLogger('ui_logger').info('Saved the csv file to ' + path)
Esempio n. 4
0
def plot_tfr_averages(experiment, subject, tfr_name, tfr_condition, blmode,
                      blstart, blend, tmin, tmax, fmin, fmax):
    """Plot one TFR figure per channel group for a single condition.

    The condition ``tfr_condition`` of the subject's TFR ``tfr_name`` is
    averaged into ``experiment.channel_groups``. Each group average is
    wrapped in a one-channel ``AverageTFR`` and plotted with its ``plot``
    method. Baseline correction is applied over ``(blstart, blend)`` with
    mode ``blmode`` only when ``blmode`` is truthy.

    Raises
    ------
    Exception
        If the data matches none of the channel groups.
    """
    meggie_tfr = subject.tfr[tfr_name]

    if blmode:
        bline = (blstart, blend)
        mode = blmode
    else:
        bline = None
        mode = None

    tfr = meggie_tfr.content.get(tfr_condition)

    logging.getLogger('ui_logger').info("Plotting TFR channel averages...")

    data_labels, averaged_data = average_to_channel_groups(
        tfr.data, meggie_tfr.info, meggie_tfr.ch_names,
        experiment.channel_groups)

    if not data_labels:
        raise Exception('No channel groups matching the data found')

    sfreq = meggie_tfr.info['sfreq']
    times = meggie_tfr.times
    freqs = meggie_tfr.freqs

    # prevent interaction as no topography is involved now
    def onselect(*args, **kwargs):
        pass

    for idx, (ch_type, ch_group) in enumerate(data_labels):
        group_info = mne.create_info(ch_names=['grand_average'],
                                     sfreq=sfreq,
                                     ch_types='mag')
        group_tfr = mne.time_frequency.tfr.AverageTFR(
            group_info, averaged_data[idx][np.newaxis, :], times, freqs, 1)
        group_tfr._onselect = onselect

        title = 'TFR_{0}_{1}'.format(ch_group, ch_type)

        figs = group_tfr.plot(baseline=bline,
                              mode=mode,
                              title=title,
                              fmin=fmin,
                              fmax=fmax,
                              tmin=tmin,
                              tmax=tmax)
        # Older MNE returns a single Figure; newer MNE returns a list.
        if not isinstance(figs, (list, tuple)):
            figs = [figs]
        for fig in figs:
            # FigureCanvasBase.set_window_title was removed in Matplotlib
            # 3.6. The figure manager provides it instead, and is absent for
            # canvases without a window.
            manager = getattr(fig.canvas, 'manager', None)
            if manager is not None:
                manager.set_window_title(title)
Esempio n. 5
0
def plot_spectrum_averages(experiment, name, log_transformed=True):
    """Plot channel-group averaged spectra of the active subject.

    Every condition of the active subject's spectrum ``name`` is averaged
    into ``experiment.channel_groups``. One figure is drawn per channel
    group, with one curve per condition. When ``log_transformed`` is true
    the curves are plotted as ``10 * log10(power)``.

    Raises
    ------
    Exception
        If a condition matches none of the channel groups, or if the
        spectrum has no conditions at all.
    """
    subject = experiment.active_subject

    spectrum = subject.spectrum.get(name)

    data = spectrum.content
    freqs = spectrum.freqs
    ch_names = spectrum.ch_names

    channel_groups = experiment.channel_groups

    info = subject.get_raw().info

    colors = color_cycle(len(data))

    logging.getLogger('ui_logger').info('Plotting spectrum channel averages..')

    averages = {}
    for key, psd in data.items():

        data_labels, averaged_data = average_to_channel_groups(
            psd, info, ch_names, channel_groups)

        if not data_labels:
            raise Exception('No channel groups matching the data found.')

        averages[key] = data_labels, averaged_data

    if not averages:
        # Without this check the group count below would be undefined.
        raise Exception('No spectrum data found to plot.')

    # The number of groups is taken from the last condition. All conditions
    # are expected to share the same channel groups.
    n_groups = averaged_data.shape[0]

    for ii in range(n_groups):
        fig, ax = plt.subplots()
        ax.set_xlabel('Frequency (Hz)')
        for color_idx, (key, (labels, group_avgs)) in enumerate(
                averages.items()):
            ax.set_ylabel('Power ({})'.format(
                get_power_unit(labels[ii][0], log_transformed)))

            curve = group_avgs[ii]
            if log_transformed:
                curve = 10 * np.log10(curve)

            ax.plot(freqs, curve, color=colors[color_idx], label=key)

        ax.legend()
        ch_type, ch_group = labels[ii]
        title = 'spectrum_{0}_{1}_{2}'.format(name, ch_type, ch_group)
        # FigureCanvasBase.set_window_title was removed in Matplotlib 3.6.
        manager = getattr(fig.canvas, 'manager', None)
        if manager is not None:
            manager.set_window_title(title)
        fig.suptitle(title)

    plt.show()
Esempio n. 6
0
def plot_spectrum_averages(subject,
                           channel_groups,
                           name,
                           log_transformed=True):
    """ Plots spectrum averages.

    Every condition of the subject's spectrum ``name`` is averaged into
    ``channel_groups``. For each channel type, one multi-axes figure is
    created via ``create_channel_average_plot``, with one axis per channel
    group and one curve per condition. When ``log_transformed`` is true the
    curves are plotted as ``10 * log10(power)``.
    """
    spectrum = subject.spectrum.get(name)

    data = spectrum.content
    freqs = spectrum.freqs
    ch_names = spectrum.ch_names
    info = spectrum.info

    colors = color_cycle(len(data))
    # Curves and legend use one sorted condition order and a
    # condition -> color map. The legend therefore always matches the
    # curves, even when a condition lacks some channel group.
    conditions = sorted(data.keys())
    condition_colors = dict(zip(conditions, colors))

    averages = {}
    for key in conditions:

        data_labels, averaged_data = average_to_channel_groups(
            data[key], info, ch_names, channel_groups)

        for label_idx, label in enumerate(data_labels):
            averages.setdefault(label, []).append(
                (key, averaged_data[label_idx]))

    ch_types = sorted({label[0] for label in averages})
    for ch_type in ch_types:

        ch_groups = sorted(
            label[1] for label in averages if label[0] == ch_type)

        def plot_fun(ax_idx, ax):
            ch_group = ch_groups[ax_idx]
            ax.set_title(ch_group)
            ax.set_xlabel('Frequency (Hz)')
            ax.set_ylabel('Power ({})'.format(
                get_power_unit(ch_type, log_transformed)))

            for key, curve in averages[(ch_type, ch_group)]:
                if log_transformed:
                    curve = 10 * np.log10(curve)
                ax.plot(freqs, curve, color=condition_colors[key])

        title = ' '.join([name, ch_type])
        legend = list(zip(conditions, colors))
        create_channel_average_plot(len(ch_groups), plot_fun, title, legend)

    plt.show()
Esempio n. 7
0
def create_averages(experiment, mne_evoked):
    """Average an evoked response into the experiment's channel groups.

    The channels listed in ``mne_evoked.info['bads']`` are dropped from a
    copy before averaging, so the caller's object is not modified.

    Parameters
    ----------
    experiment : object
        Supplies ``channel_groups`` for ``average_to_channel_groups``.
    mne_evoked : object
        Presumably an ``mne.Evoked``. It must provide ``copy``, ``info``
        and ``data``.

    Returns
    -------
    tuple
        ``(data_labels, averaged_data)`` from ``average_to_channel_groups``.
    """
    groups = experiment.channel_groups

    good_only = mne_evoked.copy().drop_channels(mne_evoked.info['bads'])
    good_info = good_only.info

    labels, averages = average_to_channel_groups(
        good_only.data, good_info, good_info['ch_names'], groups)
    return labels, averages
Esempio n. 8
0
def plot_tse_averages(subject, tfr_name, blmode, blstart, blend, tmin, tmax,
                      fmin, fmax, channel_groups):
    """ Plots tse averages.

    TSEs are computed from the subject's TFR ``tfr_name`` over
    ``fmin``..``fmax``. Each condition is averaged into ``channel_groups``,
    then cropped and baseline-corrected by ``_crop_and_correct_to_baseline``.
    For each channel type, one multi-axes figure is created with one axis
    per channel group and one curve per condition.
    """
    meggie_tfr = subject.tfr.get(tfr_name)

    tses = _compute_tse(meggie_tfr, fmin, fmax)

    ch_names = meggie_tfr.ch_names
    info = meggie_tfr.info
    colors = color_cycle(len(tses))
    # Curves and legend use one sorted condition order and a
    # condition -> color map. The legend therefore always matches the
    # curves, even when a condition lacks some channel group.
    conditions = sorted(tses.keys())
    condition_colors = dict(zip(conditions, colors))

    averages = {}
    for key in conditions:
        data_labels, averaged_data = average_to_channel_groups(
            tses[key], info, ch_names, channel_groups)

        times, averaged_data = _crop_and_correct_to_baseline(
            averaged_data, blmode, blstart, blend, tmin, tmax,
            meggie_tfr.times)

        for label_idx, label in enumerate(data_labels):
            averages.setdefault(label, []).append(
                (key, times, averaged_data[label_idx]))

    ch_types = sorted({label[0] for label in averages})
    for ch_type in ch_types:

        ch_groups = sorted(
            label[1] for label in averages if label[0] == ch_type)

        def plot_fun(ax_idx, ax):
            ch_group = ch_groups[ax_idx]
            ax.set_title(ch_group)
            ax.set_xlabel('Time (s)')
            ax.set_ylabel('Power ({})'.format(get_power_unit(ch_type, False)))

            for key, curve_times, curve in averages[(ch_type, ch_group)]:
                ax.plot(curve_times, curve, color=condition_colors[key],
                        label=key)

            ax.axhline(0, color='black')
            ax.axvline(0, color='black')

        title = ' '.join([tfr_name, ch_type])
        legend = list(zip(conditions, colors))
        create_channel_average_plot(len(ch_groups), plot_fun, title, legend)

    plt.show()
Esempio n. 9
0
def plot_tse_averages(experiment, subject, tfr_name, blmode, blstart, blend,
                      tmin, tmax, fmin, fmax):
    """Plot channel-group averaged TSE curves of one subject.

    ``_compute_tse`` is called with the subject's TFR ``tfr_name``, the
    band ``fmin``..``fmax``, the window ``tmin``..``tmax`` and the baseline
    settings. Each condition's TSE is then averaged into
    ``experiment.channel_groups``. One figure is drawn per channel group,
    with one curve per condition.

    Raises
    ------
    Exception
        If a condition matches none of the channel groups, or if there are
        no conditions to plot.
    """
    meggie_tfr = subject.tfr.get(tfr_name)

    times, tses = _compute_tse(meggie_tfr, fmin, fmax, tmin, tmax, blmode,
                               blstart, blend)

    ch_names = meggie_tfr.ch_names
    info = meggie_tfr.info
    colors = color_cycle(len(tses))

    channel_groups = experiment.channel_groups

    logging.getLogger('ui_logger').info('Plotting TSE channel averages..')

    averages = {}
    for key, tse in tses.items():

        data_labels, averaged_data = average_to_channel_groups(
            tse, info, ch_names, channel_groups)

        if not data_labels:
            raise Exception('No channel groups matching the data found.')

        averages[key] = data_labels, averaged_data

    if not averages:
        # Without this check the group count below would be undefined.
        raise Exception('No TSE data found to plot.')

    # The number of groups is taken from the last condition. All conditions
    # are expected to share the same channel groups.
    n_groups = averaged_data.shape[0]

    for ii in range(n_groups):
        fig, ax = plt.subplots()
        ax.set_xlabel('Time (s)')
        for color_idx, (key, (labels, group_avgs)) in enumerate(
                averages.items()):
            ax.set_ylabel('Power ({})'.format(
                get_power_unit(labels[ii][0], False)))

            ax.plot(times,
                    group_avgs[ii],
                    color=colors[color_idx],
                    label=key)
        ax.axhline(0)
        ax.legend()
        ch_type, ch_group = labels[ii]
        title = 'TSE_{0}_{1}_{2}'.format(tfr_name, ch_type, ch_group)
        # FigureCanvasBase.set_window_title was removed in Matplotlib 3.6.
        manager = getattr(fig.canvas, 'manager', None)
        if manager is not None:
            manager.set_window_title(title)
        fig.suptitle(title)

    plt.show()
Esempio n. 10
0
def save_channel_averages(experiment, selected_name, log_transformed=False):
    """Save channel-group averaged spectra of all subjects into one CSV file.

    For each subject that has a spectrum named ``selected_name``, every
    condition's PSD is averaged into ``experiment.channel_groups``. Each
    channel group becomes one CSV row, described by
    ``(subject name, condition, channel type, area)``. When
    ``log_transformed`` is true, the rows are stored as
    ``10 * log10(power)``. The columns are the spectrum frequencies.

    Raises
    ------
    Exception
        If a condition's data matches none of the channel groups.
    """
    column_names = []
    row_descs = []
    csv_data = []

    channel_groups = experiment.channel_groups

    # accumulate csv contents
    for subject in experiment.subjects.values():
        spectrum = subject.spectrum.get(selected_name)
        if not spectrum:
            continue

        ch_names = spectrum.ch_names
        freqs = spectrum.freqs

        info = subject.get_raw().info

        for key, psd in spectrum.content.items():

            data_labels, averaged_data = average_to_channel_groups(
                psd, info, ch_names, channel_groups)

            if not data_labels:
                raise Exception('No channel groups matching the data found.')

            if log_transformed:
                # Take the log on an array and convert back with tolist().
                # Both branches then store plain lists. Extending with the
                # raw log10 result would add ndarray rows instead.
                log_rows = 10 * np.log10(np.asarray(averaged_data,
                                                    dtype=float))
                csv_data.extend(log_rows.tolist())
            else:
                csv_data.extend(averaged_data.tolist())

            column_names = format_floats(freqs)

            for ch_type, area in data_labels:
                row_descs.append((subject.name, key, ch_type, area))

    folder = filemanager.create_timestamped_folder(experiment)
    fname = selected_name + '_all_subjects_channel_averages_spectrum.csv'
    path = os.path.join(folder, fname)

    filemanager.save_csv(path, csv_data, column_names, row_descs)
    logging.getLogger('ui_logger').info('Saved the csv file to ' + path)
Esempio n. 11
0
def save_tse_channel_averages(experiment, tfr_name, blmode, blstart, blend,
                              tmin, tmax, fmin, fmax):
    """Write channel-group averaged TSEs of every subject into one CSV file.

    For each subject that has a TFR named ``tfr_name``, ``_compute_tse`` is
    called with the band ``fmin``..``fmax``, the window ``tmin``..``tmax``
    and the baseline settings. Every condition's TSE is averaged into
    ``experiment.channel_groups``. Each group becomes one row, described by
    ``(subject name, condition, channel type, area)``, and the columns are
    the TSE time points.

    Raises
    ------
    Exception
        If a condition's data matches none of the channel groups.
    """
    rows = []
    row_labels = []
    header = []

    groups = experiment.channel_groups

    for subject in experiment.subjects.values():
        meggie_tfr = subject.tfr.get(tfr_name)
        if not meggie_tfr:
            continue

        times, tses = _compute_tse(meggie_tfr, fmin, fmax, tmin, tmax, blmode,
                                   blstart, blend)
        header = format_floats(times)

        for condition, tse in tses.items():
            labels, averaged = average_to_channel_groups(
                tse, meggie_tfr.info, meggie_tfr.ch_names, groups)

            if not labels:
                raise Exception('No channel groups matching the data found.')

            rows.extend(averaged.tolist())
            row_labels.extend((subject.name, condition, ch_type, area)
                              for ch_type, area in labels)

    # NOTE(review): the file name reuses the '_tfr' suffix of the TFR export.
    # Confirm whether '_tse' was intended before relying on this name.
    out_dir = filemanager.create_timestamped_folder(experiment)
    out_path = os.path.join(
        out_dir, tfr_name + '_all_subjects_channel_averages_tfr.csv')

    filemanager.save_csv(out_path, rows, header, row_labels)
    logging.getLogger('ui_logger').info('Saved the csv file to ' + out_path)
Esempio n. 12
0
def plot_tfr_averages(subject, tfr_name, tfr_condition, blmode, blstart, blend,
                      tmin, tmax, fmin, fmax, channel_groups):
    """ Plots tfr averages.

    The condition ``tfr_condition`` of the subject's TFR ``tfr_name`` is
    averaged into ``channel_groups``. For each channel type, one multi-axes
    figure is created via ``create_channel_average_plot``, with one axis per
    channel group. Each axis is drawn by a one-channel ``AverageTFR``.
    Baseline correction over ``(blstart, blend)`` with mode ``blmode`` is
    applied only when ``blmode`` is truthy.
    """
    meggie_tfr = subject.tfr[tfr_name]

    baseline, rescale_mode = ((blstart, blend), blmode) if blmode else (None,
                                                                        None)

    condition_tfr = meggie_tfr.content.get(tfr_condition)

    tfr_data = condition_tfr.data
    ch_names = meggie_tfr.ch_names

    sfreq = meggie_tfr.info['sfreq']
    times = meggie_tfr.times
    freqs = meggie_tfr.freqs

    # Unlike spectra, evoked responses and TSEs, a TFR is plotted for a single
    # condition only, which keeps plotting simple. A one-channel AverageTFR
    # object is (mis)used to do the actual drawing.
    data_labels, averaged_data = average_to_channel_groups(
        tfr_data, meggie_tfr.info, ch_names, channel_groups)

    averages = {
        label: averaged_data[idx]
        for idx, label in enumerate(data_labels)
    }

    # Replaces the AverageTFR click handler: no topography is involved, so
    # interaction is disabled.
    def _no_interaction(*args, **kwargs):
        pass

    for ch_type in sorted({label[0] for label in data_labels}):

        ch_groups = sorted(
            label[1] for label in data_labels if label[0] == ch_type)

        def plot_fun(ax_idx, ax):
            ch_group = ch_groups[ax_idx]
            ax.set_title(ch_group)

            single_info = mne.create_info(ch_names=['grand_average'],
                                          sfreq=sfreq,
                                          ch_types='mag')
            group_tfr = mne.time_frequency.tfr.AverageTFR(
                single_info, averages[(ch_type, ch_group)][np.newaxis, :],
                times, freqs, 1)
            group_tfr._onselect = _no_interaction

            group_tfr.plot(baseline=baseline,
                           mode=rescale_mode,
                           title='',
                           fmin=fmin,
                           fmax=fmax,
                           tmin=tmin,
                           tmax=tmax,
                           axes=ax)

        figure_title = ' '.join([tfr_name, tfr_condition, ch_type])
        create_channel_average_plot(len(ch_groups), plot_fun, figure_title)