def test_average_to_channel_groups():
    """Channel-group averaging must keep channels in their own groups.

    Loads the MNE sample recording, determines which default MEG channel
    group each of the first 20 channels belongs to, and checks that the
    group names reported by ``average_to_channel_groups`` are exactly the
    same set of groups.
    """
    data_root = mne.datasets.sample.data_path()
    raw_path = os.path.join(data_root, 'MEG', 'sample',
                            'sample_audvis_raw.fif')
    raw = mne.io.read_raw_fif(raw_path, preload=True)
    raw_info = raw.info
    picked_names = raw_info['ch_names'][:20]

    groups_by_type = {
        'meg': get_default_channel_groups(raw_info, 'meg'),
        'eeg': get_default_channel_groups(raw_info, 'eeg'),
    }

    # group membership of the picked channels, computed independently
    expected_groups = {
        group_name
        for name in picked_names
        for group_name, members in groups_by_type['meg'].items()
        if name in members
    }

    labels, _ = average_to_channel_groups(
        raw._data[:, 0:100], raw_info, picked_names, groups_by_type)

    # the localization of data must be the same before and after averaging
    assert expected_groups == {label[1] for label in labels}
def _create_averages(mne_evoked, channel_groups):
    """Average an evoked response into the given channel groups.

    Channels marked bad in ``mne_evoked.info['bads']`` are removed from a
    copy before averaging, so ``mne_evoked`` itself is not modified.

    Returns
    -------
    tuple
        ``(data_labels, averaged_data)`` as produced by
        ``average_to_channel_groups``.
    """
    bad_names = mne_evoked.info['bads']
    cleaned = mne_evoked.copy().drop_channels(bad_names)
    cleaned_info = cleaned.info
    labels, averages = average_to_channel_groups(
        cleaned.data, cleaned_info, cleaned_info['ch_names'], channel_groups)
    return labels, averages
def save_tfr_channel_averages(experiment, tfr_name, blmode, blstart, blend,
                              tmin, tmax, fmin, fmax):
    """Save channel-group averaged TFRs of all subjects into one CSV file.

    For every subject that has a TFR called ``tfr_name``, each condition is
    cropped to ``tmin..tmax`` / ``fmin..fmax``, baseline corrected over
    ``(blstart, blend)`` with mode ``blmode`` (skipped when ``blmode`` is
    falsy, the same convention ``plot_tfr_averages`` uses), and averaged to
    the experiment's channel groups.

    CSV layout: the columns are the cropped time points. There is one row per
    (subject, condition, channel type, channel group, frequency).

    Raises
    ------
    Exception
        If none of the channel groups match the data of a condition.
    """
    column_names = []
    row_descs = []
    csv_data = []

    channel_groups = experiment.channel_groups

    # accumulate csv contents
    for subject in experiment.subjects.values():
        tfr = subject.tfr.get(tfr_name)
        if not tfr:
            continue

        ch_names = tfr.ch_names
        info = tfr.info

        for key, mne_tfr in tfr.content.items():
            # crop a copy so the stored TFR is left untouched
            mne_tfr = mne_tfr.copy().crop(tmin=tmin, tmax=tmax,
                                          fmin=fmin, fmax=fmax)
            times = mne_tfr.times
            column_names = format_floats(times)
            freqs = format_floats(mne_tfr.freqs)

            data = mne_tfr.data
            if blmode:
                # rescale returns a corrected copy (copy=True by default)
                data = mne.baseline.rescale(data, times,
                                            baseline=(blstart, blend),
                                            mode=blmode)

            data_labels, averaged_data = average_to_channel_groups(
                data, info, ch_names, channel_groups)

            if not data_labels:
                raise Exception('No channel groups matching the data found')

            # averaged_data is indexed as [group, frequency] -> time series,
            # presumably shaped (groups, freqs, times)
            for ix in range(averaged_data.shape[0]):
                ch_type, area = data_labels[ix]
                for iy in range(averaged_data.shape[1]):
                    csv_data.append(averaged_data[ix, iy].tolist())
                    row_descs.append(
                        (subject.name, key, ch_type, area, freqs[iy]))

    folder = filemanager.create_timestamped_folder(experiment)
    fname = tfr_name + '_all_subjects_channel_averages_tfr.csv'
    path = os.path.join(folder, fname)

    filemanager.save_csv(path, csv_data, column_names, row_descs)
    logging.getLogger('ui_logger').info('Saved the csv file to ' + path)
def plot_tfr_averages(experiment, subject, tfr_name, tfr_condition, blmode,
                      blstart, blend, tmin, tmax, fmin, fmax):
    """Plot one TFR figure per channel group for a single condition.

    The condition ``tfr_condition`` of the subject's TFR ``tfr_name`` is
    averaged to the experiment's channel groups. Each group average is
    wrapped in a one-channel ``AverageTFR`` so MNE can draw it.
    Baseline correction (mode ``blmode`` over ``(blstart, blend)``) is only
    requested when ``blmode`` is truthy.

    Raises
    ------
    Exception
        If none of the channel groups match the data.
    """
    meggie_tfr = subject.tfr[tfr_name]

    baseline, baseline_mode = None, None
    if blmode:
        baseline, baseline_mode = (blstart, blend), blmode

    condition_tfr = meggie_tfr.content.get(tfr_condition)

    logging.getLogger('ui_logger').info("Plotting TFR channel averages...")

    condition_data = condition_tfr.data
    names = meggie_tfr.ch_names
    groups = experiment.channel_groups

    labels, group_data = average_to_channel_groups(
        condition_data, meggie_tfr.info, names, groups)

    if not labels:
        raise Exception('No channel groups matching the data found')

    sfreq = meggie_tfr.info['sfreq']
    times = meggie_tfr.times
    freqs = meggie_tfr.freqs

    def _ignore_selection(*args, **kwargs):
        # no topography is involved, so clicks must not trigger anything
        pass

    for idx, label in enumerate(labels):
        single_info = mne.create_info(ch_names=['grand_average'],
                                      sfreq=sfreq, ch_types='mag')
        group_tfr = mne.time_frequency.tfr.AverageTFR(
            single_info, group_data[idx][np.newaxis, :], times, freqs, 1)

        title = 'TFR_{0}_{1}'.format(label[1], label[0])

        group_tfr._onselect = _ignore_selection
        fig = group_tfr.plot(baseline=baseline, mode=baseline_mode,
                             title=title, fmin=fmin, fmax=fmax,
                             tmin=tmin, tmax=tmax)
        fig.canvas.set_window_title(title)
def plot_spectrum_averages(experiment, name, log_transformed=True):
    """Plot channel-group averaged spectra of the active subject.

    One figure is created per channel group. Each figure shows one curve
    per condition of spectrum ``name``. When ``log_transformed`` is true
    the curves are drawn as ``10 * log10(power)``.

    Raises
    ------
    Exception
        If none of the channel groups match the data of a condition.
    """
    subject = experiment.active_subject
    spectrum = subject.spectrum.get(name)
    conditions = spectrum.content
    freqs = spectrum.freqs
    ch_names = spectrum.ch_names
    channel_groups = experiment.channel_groups
    info = subject.get_raw().info

    colors = color_cycle(len(conditions))

    logging.getLogger('ui_logger').info('Plotting spectrum channel averages..')

    averages = {}
    for key, psd in conditions.items():
        labels, group_data = average_to_channel_groups(
            psd, info, ch_names, channel_groups)
        if not labels:
            raise Exception('No channel groups matching the data found.')
        averages[key] = labels, group_data

    # every condition uses the same channel groups, so the last one
    # determines how many group figures to draw
    n_groups = group_data.shape[0]

    for group_idx in range(n_groups):
        fig, ax = plt.subplots()
        ax.set_xlabel('Frequency (Hz)')
        for color_idx, (key, (labels, group_data)) in enumerate(
                averages.items()):
            unit = get_power_unit(labels[group_idx][0], log_transformed)
            ax.set_ylabel('Power ({})'.format(unit))

            curve = group_data[group_idx]
            if log_transformed:
                curve = 10 * np.log10(curve)
            ax.plot(freqs, curve, color=colors[color_idx], label=key)

        ax.legend()

        ch_type, ch_group = labels[group_idx]
        title = 'spectrum_{0}_{1}_{2}'.format(name, ch_type, ch_group)
        fig.canvas.set_window_title(title)
        fig.suptitle(title)

    plt.show()
def plot_spectrum_averages(subject, channel_groups, name,
                           log_transformed=True):
    """Plot channel-group averaged spectra of a subject.

    Creates one figure per channel type. The figure gets one axis per
    channel group of that type, and every axis shows one curve per
    condition of spectrum ``name``.

    Parameters
    ----------
    subject : subject whose spectrum ``name`` is plotted
    channel_groups : channel group definitions passed on to
        ``average_to_channel_groups``
    name : str
        Name of the spectrum item.
    log_transformed : bool
        If true, curves are drawn as ``10 * log10(power)``.
    """
    spectrum = subject.spectrum.get(name)
    data = spectrum.content
    freqs = spectrum.freqs
    ch_names = spectrum.ch_names
    info = spectrum.info

    colors = color_cycle(len(data))

    # Curves and legend must use the same condition order, otherwise the
    # legend colors do not match the plotted curves.
    conditions = sorted(data.keys())
    color_of = dict(zip(conditions, colors))

    # (ch_type, ch_group) -> [(condition, averaged psd), ...]
    averages = {}
    for key in conditions:
        data_labels, averaged_data = average_to_channel_groups(
            data[key], info, ch_names, channel_groups)
        for label_idx, label in enumerate(data_labels):
            averages.setdefault(label, []).append(
                (key, averaged_data[label_idx]))

    ch_types = sorted({label[0] for label in averages})
    for ch_type in ch_types:
        ch_groups = sorted(
            label[1] for label in averages if label[0] == ch_type)

        def plot_fun(ax_idx, ax):
            ch_group = ch_groups[ax_idx]
            ax.set_title(ch_group)
            ax.set_xlabel('Frequency (Hz)')
            ax.set_ylabel('Power ({})'.format(
                get_power_unit(ch_type, log_transformed)))

            for key, curve in averages[(ch_type, ch_group)]:
                if log_transformed:
                    curve = 10 * np.log10(curve)
                # color is tied to the condition, not to the list position
                ax.plot(freqs, curve, color=color_of[key])

        title = ' '.join([name, ch_type])
        legend = list(zip(conditions, colors))
        create_channel_average_plot(len(ch_groups), plot_fun, title, legend)

    plt.show()
def create_averages(experiment, mne_evoked):
    """Average an evoked response to the experiment's channel groups.

    Bad channels listed in ``mne_evoked.info['bads']`` are dropped from a
    copy first, so the caller's evoked object is left untouched.

    Returns
    -------
    tuple
        ``(data_labels, averaged_data)`` as returned by
        ``average_to_channel_groups``.
    """
    groups = experiment.channel_groups
    without_bads = mne_evoked.copy().drop_channels(mne_evoked.info['bads'])
    labels, averages = average_to_channel_groups(
        without_bads.data,
        without_bads.info,
        without_bads.info['ch_names'],
        groups,
    )
    return labels, averages
def plot_tse_averages(subject, tfr_name, blmode, blstart, blend,
                      tmin, tmax, fmin, fmax, channel_groups):
    """Plot channel-group averaged time series of estimates (TSE).

    TSEs are computed from the subject's TFR ``tfr_name`` over
    ``fmin..fmax`` and averaged to ``channel_groups``. They are then passed
    through ``_crop_and_correct_to_baseline`` with the given crop window
    and baseline settings. One figure is created per channel type, with one
    axis per channel group and one curve per condition.
    """
    meggie_tfr = subject.tfr.get(tfr_name)
    tses = _compute_tse(meggie_tfr, fmin, fmax)
    ch_names = meggie_tfr.ch_names
    info = meggie_tfr.info

    colors = color_cycle(len(tses))

    # Curves and legend must use the same condition order, otherwise the
    # legend colors do not match the plotted curves.
    conditions = sorted(tses.keys())
    color_of = dict(zip(conditions, colors))

    # (ch_type, ch_group) -> [(condition, times, averaged tse), ...]
    averages = {}
    for key in conditions:
        data_labels, averaged_data = average_to_channel_groups(
            tses[key], info, ch_names, channel_groups)
        times, averaged_data = _crop_and_correct_to_baseline(
            averaged_data, blmode, blstart, blend, tmin, tmax,
            meggie_tfr.times)
        for label_idx, label in enumerate(data_labels):
            averages.setdefault(label, []).append(
                (key, times, averaged_data[label_idx]))

    ch_types = sorted({label[0] for label in averages})
    for ch_type in ch_types:
        ch_groups = sorted(
            label[1] for label in averages if label[0] == ch_type)

        def plot_fun(ax_idx, ax):
            ch_group = ch_groups[ax_idx]
            ax.set_title(ch_group)
            ax.set_xlabel('Time (s)')
            ax.set_ylabel('Power ({})'.format(get_power_unit(ch_type, False)))

            for key, curve_times, curve in averages[(ch_type, ch_group)]:
                # color is tied to the condition, not to the list position
                ax.plot(curve_times, curve, color=color_of[key], label=key)

            ax.axhline(0, color='black')
            ax.axvline(0, color='black')

        title = ' '.join([tfr_name, ch_type])
        legend = list(zip(conditions, colors))
        create_channel_average_plot(len(ch_groups), plot_fun, title, legend)

    plt.show()
def plot_tse_averages(experiment, subject, tfr_name, blmode, blstart, blend,
                      tmin, tmax, fmin, fmax):
    """Plot channel-group averaged time series of estimates (TSE).

    TSEs of the subject's TFR ``tfr_name`` are obtained from
    ``_compute_tse`` (given the frequency range, crop window and baseline
    settings) and averaged to the experiment's channel groups. One figure
    is drawn per channel group, with one curve per condition.

    Raises
    ------
    Exception
        If none of the channel groups match the data of a condition.
    """
    meggie_tfr = subject.tfr.get(tfr_name)
    times, tses = _compute_tse(meggie_tfr, fmin, fmax, tmin, tmax,
                               blmode, blstart, blend)
    ch_names = meggie_tfr.ch_names
    info = meggie_tfr.info
    colors = color_cycle(len(tses))
    channel_groups = experiment.channel_groups

    logging.getLogger('ui_logger').info('Plotting TSE channel averages..')

    averages = {}
    for key, tse in tses.items():
        labels, group_data = average_to_channel_groups(
            tse, info, ch_names, channel_groups)
        if not labels:
            raise Exception('No channel groups matching the data found.')
        averages[key] = labels, group_data

    # all conditions share the channel groups; the last one sets the count
    n_groups = group_data.shape[0]

    for group_idx in range(n_groups):
        fig, ax = plt.subplots()
        ax.set_xlabel('Time (s)')
        for color_idx, (key, (labels, group_data)) in enumerate(
                averages.items()):
            unit = get_power_unit(labels[group_idx][0], False)
            ax.set_ylabel('Power ({})'.format(unit))
            ax.plot(times, group_data[group_idx],
                    color=colors[color_idx], label=key)

        ax.axhline(0)
        ax.legend()

        ch_type, ch_group = labels[group_idx]
        title = 'TSE_{0}_{1}_{2}'.format(tfr_name, ch_type, ch_group)
        fig.canvas.set_window_title(title)
        fig.suptitle(title)

    plt.show()
def save_channel_averages(experiment, selected_name, log_transformed=False):
    """Save channel-group averaged spectra of all subjects into one CSV file.

    For every subject that has a spectrum called ``selected_name``, each
    condition's PSD is averaged to the experiment's channel groups.

    CSV layout: the columns are the frequencies. There is one row per
    (subject, condition, channel type, channel group). When
    ``log_transformed`` is true, the values are written as
    ``10 * log10(power)``.

    Raises
    ------
    Exception
        If none of the channel groups match the data of a condition.
    """
    column_names = []
    row_descs = []
    csv_data = []

    channel_groups = experiment.channel_groups

    # accumulate csv contents
    for subject in experiment.subjects.values():
        spectrum = subject.spectrum.get(selected_name)
        if not spectrum:
            continue

        ch_names = spectrum.ch_names
        freqs = spectrum.freqs
        info = subject.get_raw().info

        for key, psd in spectrum.content.items():
            data_labels, averaged_data = average_to_channel_groups(
                psd, info, ch_names, channel_groups)

            if not data_labels:
                raise Exception('No channel groups matching the data found.')

            if log_transformed:
                # transform the array first so that, like the linear case,
                # the rows appended to csv_data are plain Python lists
                averaged_data = 10 * np.log10(averaged_data)
            csv_data.extend(averaged_data.tolist())

            column_names = format_floats(freqs)

            for ch_type, area in data_labels:
                row_descs.append((subject.name, key, ch_type, area))

    folder = filemanager.create_timestamped_folder(experiment)
    fname = selected_name + '_all_subjects_channel_averages_spectrum.csv'
    path = os.path.join(folder, fname)

    filemanager.save_csv(path, csv_data, column_names, row_descs)
    logging.getLogger('ui_logger').info('Saved the csv file to ' + path)
def save_tse_channel_averages(experiment, tfr_name, blmode, blstart, blend,
                              tmin, tmax, fmin, fmax):
    """Save channel-group averaged TSEs of all subjects into one CSV file.

    For every subject with a TFR called ``tfr_name``, time series of
    estimates are obtained from ``_compute_tse`` (given the frequency range,
    crop window and baseline settings). They are then averaged to the
    experiment's channel groups.

    CSV layout: the columns are the time points returned by ``_compute_tse``.
    There is one row per (subject, condition, channel type, channel group).
    The file uses a ``_tse.csv`` suffix so that it is not mistaken for a
    TFR channel-average export.

    Raises
    ------
    Exception
        If none of the channel groups match the data of a condition.
    """
    column_names = []
    row_descs = []
    csv_data = []

    channel_groups = experiment.channel_groups

    # accumulate csv contents
    for subject in experiment.subjects.values():
        tfr = subject.tfr.get(tfr_name)
        if not tfr:
            continue

        times, tses = _compute_tse(tfr, fmin, fmax, tmin, tmax,
                                   blmode, blstart, blend)
        column_names = format_floats(times)

        for key, tse in tses.items():
            data_labels, averaged_data = average_to_channel_groups(
                tse, tfr.info, tfr.ch_names, channel_groups)

            if not data_labels:
                raise Exception('No channel groups matching the data found.')

            csv_data.extend(averaged_data.tolist())

            for ch_type, area in data_labels:
                row_descs.append((subject.name, key, ch_type, area))

    folder = filemanager.create_timestamped_folder(experiment)
    fname = tfr_name + '_all_subjects_channel_averages_tse.csv'
    path = os.path.join(folder, fname)

    filemanager.save_csv(path, csv_data, column_names, row_descs)
    logging.getLogger('ui_logger').info('Saved the csv file to ' + path)
def plot_tfr_averages(subject, tfr_name, tfr_condition, blmode, blstart,
                      blend, tmin, tmax, fmin, fmax, channel_groups):
    """Plot channel-group averaged TFRs of a single condition.

    Unlike spectra, evoked responses and TSEs, a TFR is plotted for one
    condition only. Each channel-group average is wrapped in a one-channel
    ``AverageTFR`` that does the actual drawing. One figure is created per
    channel type, and ``create_channel_average_plot`` is asked for one axis
    per channel group of that type. Baseline correction is only requested
    when ``blmode`` is truthy.
    """
    meggie_tfr = subject.tfr[tfr_name]

    if not blmode:
        bline, mode = None, None
    else:
        bline, mode = (blstart, blend), blmode

    condition_data = meggie_tfr.content.get(tfr_condition).data
    ch_names = meggie_tfr.ch_names
    sfreq = meggie_tfr.info['sfreq']
    times = meggie_tfr.times
    freqs = meggie_tfr.freqs

    data_labels, averaged_data = average_to_channel_groups(
        condition_data, meggie_tfr.info, ch_names, channel_groups)

    # (ch_type, ch_group) -> averaged TFR of that group
    averages = {label: averaged_data[idx]
                for idx, label in enumerate(data_labels)}

    def _ignore_selection(*args, **kwargs):
        # no topography is involved, so clicks must not trigger anything
        pass

    for ch_type in sorted({label[0] for label in data_labels}):
        ch_groups = sorted(
            label[1] for label in data_labels if label[0] == ch_type)

        def plot_fun(ax_idx, ax):
            ch_group = ch_groups[ax_idx]
            ax.set_title(ch_group)

            group_info = mne.create_info(ch_names=['grand_average'],
                                         sfreq=sfreq, ch_types='mag')
            group_tfr = mne.time_frequency.tfr.AverageTFR(
                group_info, averages[(ch_type, ch_group)][np.newaxis, :],
                times, freqs, 1)
            group_tfr._onselect = _ignore_selection
            group_tfr.plot(baseline=bline, mode=mode, title='',
                           fmin=fmin, fmax=fmax, tmin=tmin, tmax=tmax,
                           axes=ax)

        title = ' '.join([tfr_name, tfr_condition, ch_type])
        create_channel_average_plot(len(ch_groups), plot_fun, title)