def plot_ica_overlay_evoked(evoked, evoked_cln, title, show):
    """Plot an evoked response before (red) and after (black) ICA cleaning.

    Workaround for https://github.com/mne-tools/mne-python/issues/1819,
    copied from ``mne.viz.ica._plot_ica_overlay_evoked()``. Both evoked
    objects are drawn on the same axes, one row per channel type out of
    ('mag', 'grad', 'eeg') that is present in ``evoked``.

    Parameters
    ----------
    evoked : instance of mne.Evoked
        The evoked response before ICA cleaning (drawn in red).
    evoked_cln : instance of mne.Evoked
        The evoked response after ICA cleaning. Must contain the same
        number of ('mag', 'grad', 'eeg') channel types as ``evoked``.
    title : str | None
        Currently unused; the figure title is fixed. Kept for signature
        compatibility.
    show : bool
        If True, show the figure once it has been fully laid out.

    Returns
    -------
    fig : instance of matplotlib.figure.Figure
        The figure containing the overlay.

    Raises
    ------
    ValueError
        If ``evoked`` and ``evoked_cln`` contain a different number of
        channel types.
    """
    import matplotlib.pyplot as plt

    ch_types = ['mag', 'grad', 'eeg']
    ch_types_used = [c for c in ch_types if c in evoked]
    ch_types_used_cln = [c for c in ch_types if c in evoked_cln]
    if len(ch_types_used) != len(ch_types_used_cln):
        raise ValueError('Raw and clean evokeds must match. '
                         'Found different channels.')

    fig, axes = plt.subplots(len(ch_types_used), 1)
    fig.suptitle('Average signal before (red) and after (black) ICA')
    axes = axes.flatten() if isinstance(axes, np.ndarray) else axes

    # Draw the uncleaned data first and recolor every line it produced red,
    # so the cleaned data drawn afterwards keeps its own colors.
    evoked.plot(axes=axes, show=False)
    for ax in fig.axes:
        for line in ax.get_lines():
            line.set_color('r')
    fig.canvas.draw()
    # show=False here: showing is deferred until the layout is final, which
    # also avoids a second plt.show() call.
    evoked_cln.plot(axes=axes, show=False)

    tight_layout(fig=fig)
    fig.subplots_adjust(top=0.90)  # leave room for the suptitle
    fig.canvas.draw()
    if show:
        plt.show()
    return fig
def _plot_evoked(evoked, plot_type, colorbar=True, hline=None, ylim=None,
                 picks=None, exclude='bads', unit=True, show=True,
                 clim=None, proj=False, xlim='tight', units=None,
                 scalings=None, titles=None, axes=None, cmap='RdBu_r'):
    """Aux function for plot_evoked and plot_evoked_image (cf. docstrings).

    Extra / noteworthy params are:

    plot_type : str, value ('butterfly' | 'image' | 'mean')
        The type of graph to plot. 'butterfly' plots each channel as a line
        (x axis: time, y axis: amplitude); bad channels are drawn in red,
        all others in black. 'image' plots a 2D image where color depicts
        the amplitude of each channel at a given time point (x axis: time,
        y axis: channel). 'mean' plots, per channel type, the mean over
        channels of the absolute scaled amplitude. In 'image' mode, the
        plot is not interactive.
    ylim : dict | None
        Per channel type limits: y limits for 'butterfly' / 'mean', color
        limits for 'image'.
    exclude : 'bads' | list of str
        Channels to leave out, either the bad channels or a list of names.
    clim : None
        Not used by this function; kept for signature compatibility.

    Returns
    -------
    fig : instance of matplotlib.figure.Figure
        The figure holding the plotted axes.

    Raises
    ------
    RuntimeError
        If ``axes`` are given together with ``proj='interactive'``.
    ValueError
        If ``plot_type`` is unknown, ``exclude`` is malformed, or the number
        of axes does not match the number of channel types to plot.
    """
    import matplotlib.pyplot as plt

    if axes is not None and proj == 'interactive':
        raise RuntimeError('Currently only single axis figures are supported'
                           ' for interactive SSP selection.')
    # Validate up front so an invalid type never leaves a half-drawn figure.
    if plot_type not in ('butterfly', 'image', 'mean'):
        raise ValueError("plot_type has to be 'butterfly', 'image' or "
                         "'mean'. Got %s." % plot_type)

    scalings = _handle_default('scalings', scalings)
    titles = _handle_default('titles', titles)
    units = _handle_default('units', units)
    # Every channel type known to any lookup table, sorted so the axes order
    # is consistent between calls.
    channel_types = sorted({key for d in (scalings, titles, units)
                            for key in d})

    if picks is None:
        picks = list(range(evoked.info['nchan']))

    bad_ch_idx = [evoked.ch_names.index(ch) for ch in evoked.info['bads']
                  if ch in evoked.ch_names]
    if len(exclude) > 0:
        if isinstance(exclude, string_types) and exclude == 'bads':
            exclude = bad_ch_idx
        elif (isinstance(exclude, list) and
              all(isinstance(ch, string_types) for ch in exclude)):
            exclude = [evoked.ch_names.index(ch) for ch in exclude]
        else:
            raise ValueError('exclude has to be a list of channel names or '
                             '"bads"')
        picks = sorted(set(picks).difference(exclude))

    types = [channel_type(evoked.info, idx) for idx in picks]
    ch_types_used = [t for t in channel_types if t in types]
    n_channel_types = len(ch_types_used)

    axes_init = axes  # remember whether axes were given as input
    fig = None
    if axes is None:
        fig, axes = plt.subplots(n_channel_types, 1)
    if isinstance(axes, plt.Axes):
        axes = [axes]
    elif isinstance(axes, np.ndarray):
        axes = list(axes)
    if len(axes) != n_channel_types:
        raise ValueError('Number of axes (%g) must match number of channel '
                         'types (%g)' % (len(axes), n_channel_types))
    if axes_init is not None:
        fig = axes[0].get_figure()

    # Instead of projecting during each iteration, use the mixin once here.
    if proj is True and evoked.proj is not True:
        evoked = evoked.copy()
        evoked.apply_proj()

    bad_set = set(bad_ch_idx)
    times = 1e3 * evoked.times  # time in milliseconds
    for ax, t in zip(axes, ch_types_used):
        ch_unit = units[t]
        this_scaling = scalings[t]
        if unit is False:
            this_scaling = 1.0
            ch_unit = 'NA'  # no unit
        idx = [pick for pick, pick_type in zip(picks, types)
               if pick_type == t]
        if not idx:
            continue

        # Set amplitude scaling
        D = this_scaling * evoked.data[idx, :]
        if plot_type == 'butterfly':
            # Color each returned line directly (bad channels red, others
            # black); the private ``ax._get_lines.color_cycle`` hook used
            # previously is ignored by current matplotlib.
            lines = ax.plot(times, D.T)
            for line, ch_idx in zip(lines, idx):
                line.set_color('r' if ch_idx in bad_set else 'k')
        elif plot_type == 'image':
            im = ax.imshow(D, interpolation='nearest', origin='lower',
                           extent=[times[0], times[-1], 0, D.shape[0]],
                           aspect='auto', cmap=cmap)
            if colorbar:
                # fig.colorbar (not plt.colorbar) so the colorbar lands in
                # the figure that owns ``ax``, even if it is not current.
                cbar = fig.colorbar(im, ax=ax)
                cbar.ax.set_title(ch_unit)
        else:  # 'mean'
            ax.plot(times, np.abs(D).mean(axis=0))

        if xlim is not None:
            if xlim == 'tight':
                xlim = (times[0], times[-1])
            ax.set_xlim(xlim)
        if ylim is not None and t in ylim:
            if plot_type == 'image':
                im.set_clim(ylim[t])
            else:
                ax.set_ylim(ylim[t])

        ax.set_title(titles[t] + ' (%d channel%s)' % (
                     len(D), 's' if len(D) > 1 else ''))
        ax.set_xlabel('time (ms)')
        if plot_type == 'image':
            ax.set_ylabel('channels (%s)' % 'index')
        else:
            ax.set_ylabel('data (%s)' % ch_unit)
            if hline is not None:
                for h in hline:
                    ax.axhline(h, color='r', linestyle='--', linewidth=2)

    if axes_init is None:
        fig.subplots_adjust(0.175, 0.08, 0.94, 0.94, 0.2, 0.63)

    # NOTE: interactive SSP selection (proj='interactive') is not wired up
    # here; only the multi-axes guard at the top is enforced.

    # Lay out before showing so a blocking show displays the final layout.
    tight_layout(fig=fig)
    if show and plt.get_backend().lower() != 'agg':
        plt.show()
        fig.canvas.draw()  # for axes plots update axes.
    return fig
def plot_chpi_snr_raw(raw, win_length, n_harmonics=None, show=True,
                      verbose=True):
    """Compute and plot cHPI SNR from raw data

    Parameters
    ----------
    raw : instance of mne.io.Raw
        The raw data; its MEG channels and cHPI info are used.
    win_length : float
        Length of window to use for SNR estimates (seconds). A longer window
        will naturally include more low frequency power, resulting in lower
        SNR.
    n_harmonics : int or None
        Number of line frequency harmonics to include in the model, in
        addition to the fundamental (``n_harmonics + 1`` line frequencies in
        total). If None, use all harmonics up to the MEG analog lowpass
        corner.
    show : bool
        Show figure if True.
    verbose : bool
        If True, print the model settings and per-window progress.

    Returns
    -------
    fig : instance of matplotlib.figure.Figure
        cHPI SNR as function of time, residual variance.

    Raises
    ------
    ValueError
        If ``win_length`` is shorter than one sample or longer than the
        recording.

    Notes
    -----
    A general linear model including cHPI and line frequencies is fit into
    each data window. The cHPI power obtained from the model is then divided
    by the residual variance (variance of signal unexplained by the model) to
    obtain the SNR.

    The SNR may decrease either due to decrease of cHPI amplitudes (e.g.
    head moving away from the helmet), or due to increase in the residual
    variance. In case of broadband interference that overlaps with the cHPI
    frequencies, the resulting decreased SNR accurately reflects the true
    situation. However, increased narrowband interference outside the cHPI
    and line frequencies would also cause an increase in the residual
    variance, even though it wouldn't necessarily affect estimation of the
    cHPI amplitudes. Thus, this method is intended for a rough overview of
    cHPI signal quality. A more accurate picture of cHPI quality (at an
    increased computational cost) can be obtained by examining the
    goodness-of-fit of the cHPI coil fits.
    """
    import matplotlib.pyplot as plt
    from mne.chpi import _get_hpi_info

    # plotting parameters
    legend_fontsize = 6
    title_fontsize = 10
    tick_fontsize = 10
    label_fontsize = 10

    # get some info from fiff
    sfreq = raw.info['sfreq']
    linefreq = raw.info['line_freq']
    if n_harmonics is not None:
        # fundamental plus n_harmonics harmonics
        linefreqs = (np.arange(n_harmonics + 1) + 1) * linefreq
    else:
        linefreqs = np.arange(linefreq, raw.info['lowpass'], linefreq)
    buflen = int(win_length * sfreq)
    if buflen <= 0:
        raise ValueError('Window length should be >0')
    cfreqs = _get_hpi_info(raw.info, verbose=False)[0]
    if verbose:
        print('Nominal cHPI frequencies: %s Hz' % cfreqs)
        print('Sampling frequency: %s Hz' % sfreq)
        print('Using line freqs: %s Hz' % linefreqs)
        print('Using buffers of %s samples = %s seconds\n'
              % (buflen, buflen / sfreq))

    pick_meg = pick_types(raw.info, meg=True, exclude=[])
    pick_mag = pick_types(raw.info, meg='mag', exclude=[])
    pick_grad = pick_types(raw.info, meg='grad', exclude=[])
    nchan = len(pick_meg)
    # grad and mag indices into an array that already has meg channels only
    pick_mag_ = np.isin(pick_meg, pick_mag).nonzero()[0]
    pick_grad_ = np.isin(pick_meg, pick_grad).nonzero()[0]

    # General linear model: linear trend, constant offset, then a
    # cosine/sine column pair for every line and cHPI frequency.
    t = np.arange(buflen) / float(sfreq)
    allfreqs = np.concatenate([linefreqs, cfreqs])
    model = np.empty((len(t), 2 + 2 * len(allfreqs)))
    model[:, 0] = t
    model[:, 1] = 1.
    phase = 2 * np.pi * t[:, np.newaxis] * allfreqs
    model[:, 2::2] = np.cos(phase)
    model[:, 3::2] = np.sin(phase)
    inv_model = linalg.pinv(model)
    first_hpi_coeff = 2 + 2 * len(linefreqs)

    # Start sample of every window that fits entirely inside the data. The
    # previous ``arange(0, n_times, buflen)[:-1]`` also dropped the last
    # full window whenever n_times was a multiple of buflen.
    bufs = np.arange(0, raw.n_times - buflen + 1, buflen)
    if len(bufs) == 0:
        raise ValueError('Window length (%s samples) exceeds the data length '
                         '(%s samples)' % (buflen, raw.n_times))
    tvec = bufs / sfreq
    n_bufs = len(bufs)
    snr_avg_grad = np.zeros((len(cfreqs), n_bufs))
    hpi_pow_grad = np.zeros((len(cfreqs), n_bufs))
    snr_avg_mag = np.zeros((len(cfreqs), n_bufs))
    resid_vars = np.zeros((nchan, n_bufs))
    for ind, buf0 in enumerate(bufs):
        if verbose:
            print('Buffer %s/%s' % (ind + 1, n_bufs))
        megbuf = raw[pick_meg, buf0:buf0 + buflen][0].T
        coeffs = np.dot(inv_model, megbuf)
        coeffs_hpi = coeffs[first_hpi_coeff:]
        resid_vars[:, ind] = np.var(megbuf - np.dot(model, coeffs), 0)
        # Total power combines the sine and cosine terms: a sinusoid of
        # amplitude A has power A**2 / 2.
        hpi_pow = (coeffs_hpi[0::2, :] ** 2 + coeffs_hpi[1::2, :] ** 2) / 2
        hpi_pow_grad[:, ind] = hpi_pow[:, pick_grad_].mean(1)
        # divide average HPI power by average residual variance
        snr_avg_grad[:, ind] = (hpi_pow_grad[:, ind] /
                                resid_vars[pick_grad_, ind].mean())
        snr_avg_mag[:, ind] = (hpi_pow[:, pick_mag_].mean(1) /
                               resid_vars[pick_mag_, ind].mean())

    cfreqs_legend = ['%s Hz' % fre for fre in cfreqs]
    fig, axs = plt.subplots(4, 1, sharex=True)
    xlim = [tvec.min(), tvec.max()]

    def _to_db(x):
        return 10 * np.log10(x)

    def _identity(x):
        return x

    # (values, transform, ylabel, title) for the three per-frequency panels
    panels = [
        (snr_avg_grad, _to_db, 'SNR (dB)',
         'Mean cHPI power / mean residual variance, gradiometers'),
        (snr_avg_mag, _to_db, 'SNR (dB)',
         'Mean cHPI power / mean residual variance, magnetometers'),
        (hpi_pow_grad, _identity, 'Power (T/m)$^2$',
         'Mean cHPI power, gradiometers'),
    ]
    legend_entries = []
    for ax, (values, transform, ylabel, title) in zip(axs, panels):
        lines = ax.plot(tvec, transform(values.T))
        median_line = ax.plot(tvec, transform(np.median(values, axis=0)),
                              lw=2, ls=':', color='k')[0]
        ax.set_xlim(xlim)
        ax.set(ylabel=ylabel)
        ax.yaxis.label.set_fontsize(label_fontsize)
        ax.set_title(title, fontsize=title_fontsize)
        ax.tick_params(axis='both', which='major', labelsize=tick_fontsize)
        legend_entries.append((ax, lines, median_line, values))

    # residual (unexplained) variance as function of time
    ax = axs[3]
    cls = plt.get_cmap('plasma')(np.linspace(0., 0.7, len(pick_meg)))
    ax.set_prop_cycle(color=cls)
    ax.semilogy(tvec, resid_vars[pick_grad_, :].T, alpha=.4)
    ax.set_xlim(xlim)
    ax.set(ylabel='Var. (T/m)$^2$', xlabel='Time (s)')
    ax.xaxis.label.set_fontsize(label_fontsize)
    ax.yaxis.label.set_fontsize(label_fontsize)
    ax.set_title('Residual (unexplained) variance, all gradiometer channels',
                 fontsize=title_fontsize)
    ax.tick_params(axis='both', which='major', labelsize=tick_fontsize)

    tight_layout(pad=.5, w_pad=.1, h_pad=.2, fig=fig)  # from mne.viz

    # tight_layout would undo this, so shrink every axes afterwards to make
    # room for the legends placed on their right.
    for ax in axs:
        box = ax.get_position()
        ax.set_position([box.x0, box.y0, box.width * 0.8, box.height])

    leg_kwargs = dict(
        prop={'size': legend_fontsize},
        bbox_to_anchor=(1.02, 0.5),
        loc='center left',
        borderpad=1,
        handlelength=1,
    )
    for ax, lines, median_line, values in legend_entries:
        # order curve legends by decreasing mean of the (linear) data
        order = np.argsort(values.mean(axis=1))[::-1]
        handles = [lines[i] for i in order] + [median_line]
        labels = [cfreqs_legend[i] for i in order] + ['Median']
        ax.legend(handles, labels, **leg_kwargs)

    if show:
        plt.show()

    return fig