def plot_spectrum_averages(subject, channel_groups, name, log_transformed=True):
    """ Plots spectrum averages.

    Each condition's spectrum is averaged to channel groups with
    ``average_to_channel_groups``. One figure is created per channel type,
    with one subplot per channel group and one curve per condition.

    Parameters
    ----------
    subject : object
        Object whose ``spectrum.get(name)`` returns the spectrum item.
    channel_groups : object
        Grouping passed straight through to ``average_to_channel_groups``.
    name : str
        Name of the spectrum item; also used as the figure title prefix.
    log_transformed : bool
        If True, curves are plotted as ``10 * log10(power)``.
    """
    spectrum = subject.spectrum.get(name)
    data = spectrum.content
    freqs = spectrum.freqs
    ch_names = spectrum.ch_names
    info = spectrum.info

    # The curves are drawn in sorted condition order, so the legend must use
    # the same order. The original unsorted dict order could pair legend
    # labels with the wrong colours.
    conditions = sorted(data.keys())
    colors = color_cycle(len(conditions))
    color_of = dict(zip(conditions, colors))

    # (ch_type, ch_group) -> list of (condition, averaged curve)
    averages = {}
    for key in conditions:
        data_labels, averaged_data = average_to_channel_groups(
            data[key], info, ch_names, channel_groups)
        for label_idx, label in enumerate(data_labels):
            averages.setdefault(label, []).append(
                (key, averaged_data[label_idx]))

    ch_types = sorted({label[0] for label in averages})
    for ch_type in ch_types:
        ch_groups = sorted(
            label[1] for label in averages if label[0] == ch_type)

        # Bind loop variables as defaults so the callback stays correct even
        # if it is invoked after this loop iteration has moved on.
        def plot_fun(ax_idx, ax, ch_type=ch_type, ch_groups=ch_groups):
            ch_group = ch_groups[ax_idx]
            ax.set_title(ch_group)
            ax.set_xlabel('Frequency (Hz)')
            ax.set_ylabel('Power ({})'.format(
                get_power_unit(ch_type, log_transformed)))
            for key, curve in averages[(ch_type, ch_group)]:
                if log_transformed:
                    curve = 10 * np.log10(curve)
                # Colour by condition key so it always matches the legend.
                ax.plot(freqs, curve, color=color_of[key])

        title = ' '.join([name, ch_type])
        legend = list(zip(conditions, colors))
        create_channel_average_plot(len(ch_groups), plot_fun, title, legend)
        plt.show()
def plot_evoked_topo(evoked, ch_type):
    """ Plots evoked time courses arranged as a topography.

    Parameters
    ----------
    evoked : object
        Item whose ``content`` maps condition names to MNE evoked objects and
        whose ``name`` is used in window titles.
    ch_type : str
        ``'eeg'`` keeps only EEG channels; any other value keeps only MEG
        channels.
    """
    import logging

    evokeds = []
    labels = []
    for key, evok in sorted(evoked.content.items()):
        info = evok.info
        # Compute the kept indices once per evoked, not once per channel.
        if ch_type == 'eeg':
            kept = set(mne.pick_types(info, eeg=True, meg=False))
        else:
            kept = set(mne.pick_types(info, eeg=False, meg=True))
        dropped_names = [
            ch_name for ch_idx, ch_name in enumerate(info['ch_names'])
            if ch_idx not in kept
        ]
        evokeds.append(evok.copy().drop_channels(dropped_names))
        labels.append(key)

    colors = color_cycle(len(labels))

    # Legend handles reused by the per-channel figure opened on click.
    lines = [
        Line2D([0], [0], color=color, label=label)
        for color, label in zip(colors, labels)
    ]

    def onclick(event):
        # Best-effort decoration of the current axes after a click. This is
        # presumably the single-channel figure MNE opens; it relies on that
        # figure being current, hence "not nice". Failures are logged rather
        # than raised so a click can never break the topography window.
        try:
            ax = plt.gca()
            channel = plt.getp(ax, 'title')
            ax.set_title('')
            ax.legend(handles=lines, loc='upper right')
            title = ' '.join([evoked.name, channel])
            ax.figure.suptitle(title)
            # FigureCanvasBase.set_window_title was removed in matplotlib
            # 3.6; the manager method works on old and new versions.
            ax.figure.canvas.manager.set_window_title(
                title.replace(' ', '_'))
            plt.show()
        except Exception:
            logging.getLogger(__name__).debug(
                'Could not decorate the clicked channel plot', exc_info=True)

    fig = mne.viz.plot_evoked_topo(evokeds, color=colors)
    fig.canvas.mpl_connect('button_press_event', onclick)
    title = "{0}_{1}".format(evoked.name, ch_type)
    fig.canvas.manager.set_window_title(title)
def plot_tse_averages(subject, tfr_name, blmode, blstart, blend,
                      tmin, tmax, fmin, fmax, channel_groups):
    """ Plots tse averages.

    The time courses come from ``_compute_tse(meggie_tfr, fmin, fmax)``. They
    are averaged to channel groups and then passed through
    ``_crop_and_correct_to_baseline``, which receives blmode, blstart, blend,
    tmin and tmax unchanged. One figure is drawn per channel type, with one
    subplot per channel group and one curve per condition.
    """
    meggie_tfr = subject.tfr.get(tfr_name)
    tses = _compute_tse(meggie_tfr, fmin, fmax)
    ch_names = meggie_tfr.ch_names
    info = meggie_tfr.info

    # Legend order must match the sorted order the curves are drawn in.
    # The legend previously used the unsorted ``meggie_tfr.content`` keys.
    # NOTE(review): assumes _compute_tse keys equal the tfr condition names.
    conditions = sorted(tses.keys())
    colors = color_cycle(len(conditions))
    color_of = dict(zip(conditions, colors))

    # (ch_type, ch_group) -> list of (condition, times, averaged curve)
    averages = {}
    for key in conditions:
        data_labels, averaged_data = average_to_channel_groups(
            tses[key], info, ch_names, channel_groups)
        times, averaged_data = _crop_and_correct_to_baseline(
            averaged_data, blmode, blstart, blend, tmin, tmax,
            meggie_tfr.times)
        for label_idx, label in enumerate(data_labels):
            averages.setdefault(label, []).append(
                (key, times, averaged_data[label_idx]))

    ch_types = sorted({label[0] for label in averages})
    for ch_type in ch_types:
        ch_groups = sorted(
            label[1] for label in averages if label[0] == ch_type)

        # Defaults bind the current loop values (avoids late binding).
        def plot_fun(ax_idx, ax, ch_type=ch_type, ch_groups=ch_groups):
            ch_group = ch_groups[ax_idx]
            ax.set_title(ch_group)
            ax.set_xlabel('Time (s)')
            ax.set_ylabel('Power ({})'.format(get_power_unit(ch_type, False)))
            for key, times, curve in averages[(ch_type, ch_group)]:
                ax.plot(times, curve, color=color_of[key], label=key)
            ax.axhline(0, color='black')
            ax.axvline(0, color='black')

        title = ' '.join([tfr_name, ch_type])
        legend = list(zip(conditions, colors))
        create_channel_average_plot(len(ch_groups), plot_fun, title, legend)
        plt.show()
def plot_evoked_averages(evoked, channel_groups):
    """ Plots channel averages.

    Each condition in ``evoked.content`` is averaged to channel groups with
    ``_create_averages``. One figure is drawn per channel type, with one
    subplot per channel group and one curve per condition plotted against
    ``evoked.times``.
    """
    # Curves are drawn in sorted condition order; the legend must match.
    conditions = sorted(evoked.content.keys())
    colors = color_cycle(len(conditions))
    color_of = dict(zip(conditions, colors))
    times = evoked.times

    # (ch_type, ch_group) -> list of (condition, averaged curve)
    averages = {}
    for key in conditions:
        data_labels, averaged_data = _create_averages(
            evoked.content[key], channel_groups)
        for label_idx, label in enumerate(data_labels):
            averages.setdefault(label, []).append(
                (key, averaged_data[label_idx]))

    ch_types = sorted({label[0] for label in averages})
    for ch_type in ch_types:
        ch_groups = sorted(
            label[1] for label in averages if label[0] == ch_type)

        # Defaults bind the current loop values (avoids late binding).
        def plot_fun(ax_idx, ax, ch_type=ch_type, ch_groups=ch_groups):
            ch_group = ch_groups[ax_idx]
            ax.set_title(ch_group)
            ax.set_xlabel('Time (s)')
            ax.set_ylabel('Amplitude ({})'.format(get_unit(ch_type)))
            for key, curve in averages[(ch_type, ch_group)]:
                ax.plot(times, curve, color=color_of[key])
            ax.axhline(0, color='black')
            ax.axvline(0, color='black')

        title = ' '.join([evoked.name, ch_type])
        legend = list(zip(conditions, colors))
        create_channel_average_plot(len(ch_groups), plot_fun, title, legend)
        plt.show()
def plot_tse_topo(subject, tfr_name, blmode, blstart, blend,
                  tmin, tmax, fmin, fmax, ch_type):
    """ Plots a tse topography.

    Time courses from ``_compute_tse(meggie_tfr, fmin, fmax)`` are passed
    through ``_crop_and_correct_to_baseline``, with the baseline/crop
    parameters forwarded unchanged. They are drawn as a topography of
    ``ch_type`` (``'meg'``; anything else means EEG) channels. Clicking a
    subplot opens a larger plot of that channel.
    """
    meggie_tfr = subject.tfr.get(tfr_name)
    tses = _compute_tse(meggie_tfr, fmin, fmax)

    info = meggie_tfr.info
    # Compute the picks once instead of once per channel.
    if ch_type == 'meg':
        picks = set(mne.pick_types(info, meg=True, eeg=False))
    else:
        picks = set(mne.pick_types(info, meg=False, eeg=True))
    picked_channels = [
        ch_name for ch_idx, ch_name in enumerate(info['ch_names'])
        if ch_idx in picks
    ]
    info = info.copy().pick_channels(picked_channels)

    ch_names = meggie_tfr.ch_names
    colors = color_cycle(len(tses))

    # Crop and baseline-correct every condition once. Previously this ran on
    # the full array for every channel subplot and again on every click.
    # NOTE(review): assumes _crop_and_correct_to_baseline is a pure function
    # of its inputs — confirm.
    corrected = []
    for key, tse in sorted(tses.items()):
        times, tse = _crop_and_correct_to_baseline(
            tse, blmode, blstart, blend, tmin, tmax, meggie_tfr.times)
        corrected.append((key, times, tse))

    def individual_plot(ax, info_idx, names_idx):
        """ Draws the enlarged plot of one clicked channel. """
        ch_name = ch_names[names_idx]
        title = ' '.join([tfr_name, ch_name])
        # canvas.set_window_title was removed in matplotlib 3.6.
        ax.figure.canvas.manager.set_window_title(title.replace(' ', '_'))
        ax.figure.suptitle(title)
        ax.set_title('')
        for color_idx, (key, times, tse) in enumerate(corrected):
            ax.plot(times, tse[names_idx], color=colors[color_idx], label=key)
        ax.axhline(0, color='black')
        ax.axvline(0, color='black')
        ax.legend()
        ax.set_xlabel('Time (s)')
        ax.set_ylabel('Power ({})'.format(
            get_power_unit(mne.io.pick.channel_type(info, info_idx), False)))
        plt.show()

    fig = plt.figure()
    # Initialised up front so an empty topography does not raise NameError.
    handles = []
    for ax, info_idx, names_idx in iterate_topography(fig, info, ch_names,
                                                      individual_plot):
        handles = []
        for color_idx, (key, times, tse) in enumerate(corrected):
            handles.append(
                ax.plot(times, tse[names_idx], linewidth=0.2,
                        color=colors[color_idx], label=key)[0])
    if handles:
        fig.legend(handles=handles)

    title = '{0}_{1}'.format(tfr_name, ch_type)
    fig.canvas.manager.set_window_title(title)
    plt.show()
def plot_spectrum_topo(subject, name, log_transformed=True, ch_type='meg'):
    """ Plots spectrum topography.

    Parameters
    ----------
    subject : object
        Object whose ``spectrum.get(name)`` returns the spectrum item.
    name : str
        Name of the spectrum item; also used in window titles.
    log_transformed : bool
        If True, curves are plotted as ``10 * log10(power)``.
    ch_type : str
        ``'meg'`` plots MEG channels; any other value plots EEG channels.

    Clicking a subplot opens a larger per-channel plot against ``freqs``.
    """
    spectrum = subject.spectrum.get(name)
    data = spectrum.content
    freqs = spectrum.freqs
    ch_names = spectrum.ch_names
    info = spectrum.info

    # Compute the picks once instead of once per channel.
    if ch_type == 'meg':
        picks = set(mne.pick_types(info, meg=True, eeg=False))
    else:
        picks = set(mne.pick_types(info, eeg=True, meg=False))
    picked_channels = [
        ch_name for ch_idx, ch_name in enumerate(info['ch_names'])
        if ch_idx in picks
    ]
    info = info.copy().pick_channels(picked_channels)

    items = sorted(data.items())
    colors = color_cycle(len(data))

    def get_curve(psd, names_idx):
        """ Returns one channel's spectrum, dB-scaled if log_transformed. """
        if log_transformed:
            return 10 * np.log10(psd[names_idx])
        return psd[names_idx]

    def individual_plot(ax, info_idx, names_idx):
        """ Draws the enlarged plot of one clicked channel. """
        ch_name = ch_names[names_idx]
        for color_idx, (key, psd) in enumerate(items):
            ax.plot(freqs, get_curve(psd, names_idx),
                    color=colors[color_idx], label=key)
        title = ' '.join([name, ch_name])
        # canvas.set_window_title was removed in matplotlib 3.6.
        ax.figure.canvas.manager.set_window_title(title.replace(' ', '_'))
        ax.figure.suptitle(title)
        ax.set_title('')
        ax.legend()
        ax.set_xlabel('Frequency (Hz)')
        ax.set_ylabel('Power ({})'.format(
            get_power_unit(mne.io.pick.channel_type(info, info_idx),
                           log_transformed)))
        plt.show()

    fig = plt.figure()
    # Initialised up front: the emptiness check below previously raised
    # NameError when the topography yielded no channels.
    handles = []
    for ax, info_idx, names_idx in iterate_topography(fig, info, ch_names,
                                                      individual_plot):
        handles = []
        for color_idx, (key, psd) in enumerate(items):
            handles.append(
                ax.plot(get_curve(psd, names_idx), color=colors[color_idx],
                        linewidth=0.5, label=key)[0])

    if not handles:
        return

    fig.legend(handles=handles)
    title = '{0}_{1}'.format(name, ch_type)
    fig.canvas.manager.set_window_title(title)
    plt.show()
def test_color_cycle():
    """ color_cycle(30) returns a list of 30 colors with 8 distinct values. """
    colors = color_cycle(30)
    assert isinstance(colors, list)
    assert len(colors) == 30
    # Requesting more colors than distinct ones exist repeats the palette.
    assert len(set(colors)) == 8
def plot_topo_fit(subject, report_item):
    """ Plot topography where by clicking subplots you can check the fit
    parameters of specific channels.

    Parameters
    ----------
    subject : object
        Object whose ``get_raw().info`` provides the channel layout.
    report_item : object
        Item whose ``content`` maps condition keys to reports, whose
        ``params['ch_names']`` lists the channels, and whose ``name`` is
        used for titles.
    """
    reports = report_item.content
    ch_names = report_item.params['ch_names']
    colors = color_cycle(len(reports))

    raw = subject.get_raw()
    info = raw.info

    def on_pick(ax, info_idx, names_idx):
        """ When a subplot representing a specific channel is clicked on the
        main topography plot, show a new figure containing FOOOF fit plot for
        every condition
        """
        fig = ax.figure
        fig.delaxes(ax)
        for idx, (report_key, report) in enumerate(reports.items()):
            report_ax = fig.add_subplot(1, len(reports), idx + 1)
            fooof = report.get_fooof(names_idx)

            # Use plot function from fooof
            fooof.plot(
                ax=report_ax,
                plot_peaks='dot',
                add_legend=False,
            )

            # Add information about the fit to the axis title: condition,
            # R squared and one line per fitted peak.
            text_lines = [
                'Condition: ' + str(report_key),
                'R squared: ' + format_float(fooof.r_squared_),
                'Peaks: ',
            ]
            text_lines.extend(
                '{0} ({1}, {2})'.format(*format_floats(peak_params))
                for peak_params in fooof.peak_params_)
            report_ax.set_title('\n'.join(text_lines) + '\n')

        fig.tight_layout()

    # Create a topography where one can inspect fits by clicking subplots
    fig = plt.figure()
    # Initialised up front so an empty topography does not raise NameError.
    handles = []
    for ax, info_idx, names_idx in iterate_topography(fig, info, ch_names,
                                                      on_pick):
        handles = []
        for color_idx, (key, report) in enumerate(reports.items()):
            curve = report.power_spectra[names_idx]
            handles.append(
                ax.plot(curve, color=colors[color_idx], linewidth=0.5,
                        label=key)[0])

    if handles:
        fig.legend(handles=handles)
    # canvas.set_window_title was removed in matplotlib 3.6.
    fig.canvas.manager.set_window_title(report_item.name)
    fig.suptitle(report_item.name)
    plt.show()