def test_si_units():
    """Check that every default scaling converts its unit to plain SI.

    A deliberately wrong entry (``csd_bad``) is injected so the test also
    demonstrates that an inconsistent scaling would be detected.
    """
    scalings = _handle_default('scalings', None)
    units = _handle_default('units', None)
    # Inject a wrong scaling to verify that mismatches are actually caught
    assert 'csd_bad' not in scalings
    scalings['csd_bad'] = 1e5
    units['csd_bad'] = 'V/m²'
    assert set(scalings) == set(units)

    known_prefixes = {
        '': 1,
        'm': 1e-3,
        'c': 1e-2,
        'µ': 1e-6,
        'n': 1e-9,
        'f': 1e-15,
    }
    known_SI = {'V', 'T', 'Am', 'm', 'M',
                'rad', 'AU', 'GOF'}  # not really SI but we tolerate them
    powers = '²'
    # Units that must be taken as a whole instead of prefix + base
    whole_units = {
        'nAm': ('n', 'Am'),
        'GOF': ('', 'GOF'),
        'AU': ('', 'AU'),
        'rad': ('', 'rad'),
    }

    def _split_si(text):
        """Split a unit string into its (prefix, SI base) parts."""
        if text in whole_units:
            return whole_units[text]
        if len(text) == 2:
            if text[1] in powers:
                return '', text
            return text[0], text[1]
        assert len(text) in (0, 1), text
        return '', text

    for key, scale in scalings.items():
        unit = units[key]
        # exactly one '/' separates numerator and denominator; otherwise
        # the whole unit is the numerator
        parts = unit.split('/')
        numerator, denominator = parts if len(parts) == 2 else (unit, '')

        num_prefix, num_SI = _split_si(numerator)
        assert num_prefix in known_prefixes
        assert num_SI in known_SI
        den_prefix, den_SI = _split_si(denominator)
        assert den_prefix in known_prefixes
        if den_SI or den_prefix:
            assert den_SI.strip(powers) in known_SI

        # the scale factor implied by the prefixes
        want_scale = known_prefixes[den_prefix] / known_prefixes[num_prefix]
        if key == 'csd_bad':
            assert not np.isclose(scale, want_scale, rtol=10)
        else:
            assert_allclose(scale, want_scale, rtol=1e-12)
def test_consistency(key):
    """Check that the defaults stored under ``key`` match the unit keys.

    Channel types without a physical unit (stim, exci, ...) must be absent
    from ``units``. They must be present in 'color' and 'scalings_plot_raw'
    and absent from every other defaults dict.
    """
    units = set(_handle_default('units'))
    other = set(_handle_default(key))
    au_keys = {'stim', 'exci', 'syst', 'resp', 'ias', 'chpi'}
    assert not (au_keys & units)
    if key not in ('color', 'scalings_plot_raw'):
        assert not (au_keys & other)
    else:
        assert au_keys <= other
        other -= au_keys
    assert units == other, key
def _plot_histogram(params):
    """Plot histograms of per-channel peak-to-peak amplitudes.

    One subplot is drawn for each of the channel types 'eeg', 'mag' and
    'grad' present in ``params['types']``. If ``epochs.reject`` defines a
    threshold for a type, the x range is limited to 110 % of that threshold
    and the threshold is marked with a red vertical line.

    Parameters
    ----------
    params : dict
        Browser state. Reads ``'epochs'``, ``'types'`` (one channel-type
        string per displayed channel), optionally ``'inds'`` (the channel
        index in the data that each entry of ``'types'`` refers to), and
        ``'fig_proj'``. Stores the new figure in ``params['histogram']``.
    """
    import matplotlib.pyplot as plt
    epochs = params['epochs']
    # peak-to-peak over the last (time) axis of the epochs data
    p2p = np.ptp(epochs.get_data(), axis=2)
    ch_types = list(params['types'])
    # 'types' follows the browser's reordered channel list; map each entry
    # back to its row in the data via 'inds' when available.
    chan_idx = params.get('inds')
    if chan_idx is None:
        chan_idx = range(len(ch_types))

    types = list()
    data = list()
    for this_type in ('eeg', 'mag', 'grad'):
        rows = [chan_idx[ii] for ii, x in enumerate(ch_types)
                if x == this_type]
        if not rows:
            continue
        # channel-major order: all epochs of channel 1, then channel 2, ...
        data.append(p2p.T[rows].ravel())
        types.append(this_type)

    fig = plt.figure()
    params['histogram'] = fig
    scalings = _handle_default('scalings')
    units = _handle_default('units')
    titles = _handle_default('titles')
    colors = _handle_default('color')
    reject = epochs.reject
    for idx, (this_type, this_data) in enumerate(zip(types, data)):
        ax = plt.subplot(len(types), 1, idx + 1)
        plt.xlabel(units[this_type])
        plt.ylabel('count')
        rej = None
        rng = None
        if reject is not None and this_type in reject:
            rej = reject[this_type] * scalings[this_type]
            rng = [0., rej * 1.1]
        plt.hist(this_data * scalings[this_type], bins=100,
                 color=colors[this_type], range=rng)
        if rej is not None:
            ax.plot((rej, rej), (0, ax.get_ylim()[1]), color='r')
        plt.title(titles[this_type])
    fig.suptitle('Peak-to-peak histogram', y=0.99)
    fig.subplots_adjust(hspace=0.6)
    try:
        fig.show(warn=False)
    except Exception:
        # Best effort: some backends / matplotlib versions cannot show a
        # figure this way. Unlike a bare except, this does not swallow
        # KeyboardInterrupt or SystemExit.
        pass
    if params['fig_proj'] is not None:
        params['fig_proj'].canvas.draw()
def test_handle_default():
    """Check that _handle_default merges overrides without mutating defaults."""
    reference = deepcopy(_handle_default('scalings'))
    plain = _handle_default('scalings')
    overridden = _handle_default('scalings', dict(mag=1, grad=2))
    empty_override = _handle_default('scalings', {})
    assert set(reference.keys()) == set(plain.keys())
    assert set(reference.keys()) == set(overridden.keys())
    for key, value in reference.items():
        assert value == plain[key]
        assert value == empty_override[key]
        if key not in ('mag', 'grad'):
            assert value == overridden[key]
        else:
            assert value != overridden[key]
def _put_artifact_range(info, evoked, kind):
    """Store the scaled amplitude range of ``evoked`` for each channel type.

    For every (picks, channel type) pair yielded by ``get_data_picks``,
    ``max - min`` over the picked rows of ``evoked.data`` is multiplied by
    the default scaling of that type. The result is written into ``info``
    under the key ``'<kind>_amp_range_<ch_type>'``.
    """
    scalings = _handle_default('scalings')
    for picks, ch_type in get_data_picks(evoked, meg_combined=False):
        segment = evoked.data[picks]
        key = '%s_amp_range_%s' % (kind, ch_type)
        info.update({key: (segment.max() - segment.min()) * scalings[ch_type]})
def test_handle_default():
    """Check that user overrides are merged without touching the defaults."""
    defaults = deepcopy(_handle_default('scalings'))
    again = _handle_default('scalings')
    custom = _handle_default('scalings', dict(mag=1, grad=2))
    empty = _handle_default('scalings', {})
    assert_equal(set(defaults.keys()), set(again.keys()))
    assert_equal(set(defaults.keys()), set(custom.keys()))
    for key in defaults:
        assert_equal(defaults[key], again[key])
        assert_equal(defaults[key], empty[key])
        if key not in ('mag', 'grad'):
            assert_equal(defaults[key], custom[key])
        else:
            assert_true(defaults[key] != custom[key])
def test_si_units():
    """Verify that each default scaling matches the one derived from its unit.

    A bogus ``csd_bad`` entry is injected to check that ``_get_scaling``
    rejects unknown channel types with a KeyError.
    """
    scalings = _handle_default('scalings', None)
    units = _handle_default('units', None)
    assert 'csd_bad' not in scalings
    scalings['csd_bad'] = 1e5
    units['csd_bad'] = 'V/m²'
    assert set(scalings) == set(units)

    for key, scale in scalings.items():
        if key != 'csd_bad':
            want_scale = _get_scaling(key, units[key])
            assert_allclose(scale, want_scale, rtol=1e-12)
            continue
        with pytest.raises(KeyError, match='is not a channel type'):
            want_scale = _get_scaling(key, units[key])
def epochs_compute_cnv(epochs, tmin=None, tmax=None):
    """Compute contingent negative variation (CNV).

    For each epoch and each MEG/EEG channel, an ordinary least-squares line
    (intercept + slope over time) is fitted to the signal within
    ``[tmin, tmax]``. Before fitting, the signal is multiplied by the
    default scaling of its channel type.

    Parameters
    ----------
    epochs : instance of Epochs
        The input data.
    tmin : float | None
        The first time point to include. If None, fitting starts at the
        first sample of the epoch. Defaults to None.
    tmax : float | None
        The last time point to include. If None, fitting runs up to the
        last sample of the epoch. Defaults to None.

    Returns
    -------
    slopes : ndarray of float, shape (n_epochs, n_channels)
        The regression slopes (betas) representing the contingent negative
        variation.
    intercepts : ndarray of float, shape (n_epochs, n_channels)
        The regression intercepts, i.e. the fitted value at ``tmin``.
    """
    picks = mne.pick_types(epochs.info, meg=True, eeg=True)
    n_epochs = len(epochs.events)
    n_channels = len(picks)
    slopes = np.zeros((n_epochs, n_channels))
    intercepts = np.zeros((n_epochs, n_channels))
    if tmax is None:
        tmax = epochs.times[-1]
    if tmin is None:
        tmin = epochs.times[0]
    fit_range = np.where(_time_mask(epochs.times, tmin, tmax))[0]

    # design: intercept + time elapsed since tmin
    design_matrix = np.c_[np.ones(len(fit_range)),
                          epochs.times[fit_range] - tmin]

    # per-channel scale factors, looked up once outside the epoch loop
    default_scalings = _handle_default('scalings')
    scales = np.zeros(n_channels)
    info_ = pick_info(epochs.info, picks)
    for this_type, this_picks in _picks_by_type(info_):
        scales[this_picks] = default_scalings[this_type]

    # estimate single-trial regressions over time samples
    n_done = 0
    for ii, epoch in enumerate(epochs):
        y = epoch[picks][:, fit_range].T  # rows are time samples
        betas, _, _, _ = linalg.lstsq(a=design_matrix, b=y * scales)
        intercepts[ii] = betas[0]
        slopes[ii] = betas[1]
        n_done = ii + 1
    # Iteration may yield fewer epochs than len(epochs.events), e.g. when
    # bad epochs are dropped on the fly. Trim the rows that were never
    # filled instead of returning spurious zeros.
    return slopes[:n_done], intercepts[:n_done]
def _plot_topomap(data, pos, vmin=None, vmax=None, cmap=None, sensors=True,
                  res=64, axes=None, names=None, show_names=False, mask=None,
                  mask_params=None, outlines='head',
                  contours=6, image_interp='bilinear', show=True,
                  head_pos=None, onselect=None, extrapolate='box',
                  border=0):
    """Draw an interpolated topographic map of ``data`` into an axes.

    Parameters
    ----------
    data : array-like, shape (n_sensors,)
        One value per sensor. A ValueError is raised if it is not 1D.
    pos : instance of Info | array, shape (n_sensors, 2) or (n_sensors, 4)
        Sensor positions, or an Info from which they are derived. Only data
        channels are used, and they must all share one channel type. For
        gradiometer types, paired channels are merged first.
    vmin, vmax : float | None
        Color limits, resolved by ``_setup_vmin_vmax``.
    cmap : str | None
        Colormap. If None: 'Reds' when all data are >= 0, else 'RdBu_r'.
    sensors : bool | str
        Whether/how to draw sensor markers (passed to ``_plot_sensors``).
    res : int
        Resolution of the square interpolation grid.
    axes : matplotlib Axes | None
        Target axes. If None, ``plt.gca()`` is used.
    names, show_names : list of str | None, bool | callable
        Sensor labels. If ``show_names`` is callable, it transforms each
        name before drawing.
    mask : ndarray of bool | None
        Sensors drawn with ``mask_params`` instead of the normal markers.
    outlines : str | dict
        Head outlines, resolved by ``_check_outlines``.
    contours : int | list | ndarray
        Number of contour levels, or precomputed levels. 0 draws a single
        invisible contour.
    extrapolate, border :
        Passed to ``_GridData``. With ``extrapolate == 'local'``, no head
        radius is passed.
    onselect : callable | None
        If given, a RectangleSelector is attached to the axes.

    Returns
    -------
    im : the image returned by ``ax.imshow``.
    cont : the contour set, or None if the interpolated map is constant.
    interp : the ``_GridData`` interpolator holding the data.
    patch_ : the clipping patch, or None.
    """
    import matplotlib.pyplot as plt
    from matplotlib.widgets import RectangleSelector
    data = np.asarray(data)

    if isinstance(pos, Info):  # infer pos from Info object
        picks = _pick_data_channels(pos)  # pick only data channels
        pos = pick_info(pos, picks)

        # check if there is only 1 channel type, and n_chans matches the data
        ch_type = {channel_type(pos, idx)
                   for idx, _ in enumerate(pos["chs"])}
        info_help = ("Pick Info with e.g. mne.pick_info and "
                     "mne.io.pick.channel_indices_by_type.")
        if len(ch_type) > 1:
            raise ValueError("Multiple channel types in Info structure. " +
                             info_help)
        elif len(pos["chs"]) != data.shape[0]:
            raise ValueError("Number of channels in the Info object and "
                             "the data array does not match. " + info_help)
        else:
            ch_type = ch_type.pop()

        if any(type_ in ch_type for type_ in ('planar', 'grad')):
            # deal with grad pairs: positions come from the layout pairs and
            # the data of each pair are merged into a single value
            from mne.channels.layout import (_merge_grad_data, find_layout,
                                             _pair_grad_sensors)
            picks, pos = _pair_grad_sensors(pos, find_layout(pos))
            data = _merge_grad_data(data[picks]).reshape(-1)
        else:
            picks = list(range(data.shape[0]))
            pos = _find_topomap_coords(pos, picks=picks)

    if data.ndim > 1:
        raise ValueError("Data needs to be array of shape (n_sensors,); got "
                         "shape %s." % str(data.shape))

    # Give a helpful error message for common mistakes regarding the position
    # matrix.
    pos_help = ("Electrode positions should be specified as a 2D array with "
                "shape (n_channels, 2). Each row in this matrix contains the "
                "(x, y) position of an electrode.")
    if pos.ndim != 2:
        error = ("{ndim}D array supplied as electrode positions, where a 2D "
                 "array was expected").format(ndim=pos.ndim)
        raise ValueError(error + " " + pos_help)
    elif pos.shape[1] == 3:
        error = ("The supplied electrode positions matrix contains 3 columns. "
                 "Are you trying to specify XYZ coordinates? Perhaps the "
                 "mne.channels.create_eeg_layout function is useful for you.")
        raise ValueError(error + " " + pos_help)
    # No error is raised in case of pos.shape[1] == 4. In this case, it is
    # assumed the position matrix contains both (x, y) and (width, height)
    # values, such as Layout.pos.
    elif pos.shape[1] == 1 or pos.shape[1] > 4:
        raise ValueError(pos_help)

    if len(data) != len(pos):
        raise ValueError("Data and pos need to be of same length. Got data of "
                         "length %s, pos of length %s" % (len(data), len(pos)))

    # all-non-negative data use a sequential colormap, otherwise diverging
    norm = min(data) >= 0
    vmin, vmax = _setup_vmin_vmax(data, vmin, vmax, norm)
    if cmap is None:
        cmap = 'Reds' if norm else 'RdBu_r'

    pos, outlines = _check_outlines(pos, outlines, head_pos)
    assert isinstance(outlines, dict)

    ax = axes if axes else plt.gca()
    _prepare_topomap(pos, ax)

    # outlines with 'head*' keys are the built-in head drawings
    _use_default_outlines = any(k.startswith('head') for k in outlines)

    if _use_default_outlines:
        # prepare masking
        _autoshrink(outlines, pos, res)

    mask_params = _handle_default('mask_params', mask_params)

    # find mask limits: extent of the grid is the bounding box of mask_pos
    xlim = np.inf, -np.inf,
    ylim = np.inf, -np.inf,
    mask_ = np.c_[outlines['mask_pos']]
    xmin, xmax = (np.min(np.r_[xlim[0], mask_[:, 0]]),
                  np.max(np.r_[xlim[1], mask_[:, 0]]))
    ymin, ymax = (np.min(np.r_[ylim[0], mask_[:, 1]]),
                  np.max(np.r_[ylim[1], mask_[:, 1]]))

    # interpolate the data, we multiply clip radius by 1.06 so that pixelated
    # edges of the interpolated image would appear under the mask
    head_radius = (None if extrapolate == 'local' else
                   outlines['clip_radius'][0] * 1.06)
    xi = np.linspace(xmin, xmax, res)
    yi = np.linspace(ymin, ymax, res)
    Xi, Yi = np.meshgrid(xi, yi)
    interp = _GridData(pos, extrapolate, head_radius, border).set_values(data)
    Zi = interp.set_locations(Xi, Yi)()

    # plot outline
    patch_ = None
    if 'patch' in outlines:
        patch_ = outlines['patch']
        patch_ = patch_() if callable(patch_) else patch_
        patch_.set_clip_on(False)
        ax.add_patch(patch_)
        ax.set_transform(ax.transAxes)
        ax.set_clip_path(patch_)
    if _use_default_outlines:
        # default outlines clip to an ellipse of the clip radii; this
        # replaces any 'patch' set just above
        from matplotlib import patches
        patch_ = patches.Ellipse((0, 0),
                                 2 * outlines['clip_radius'][0],
                                 2 * outlines['clip_radius'][1],
                                 clip_on=True,
                                 transform=ax.transData)

    # plot interpolated map
    im = ax.imshow(Zi, cmap=cmap, vmin=vmin, vmax=vmax, origin='lower',
                   aspect='equal', extent=(xmin, xmax, ymin, ymax),
                   interpolation=image_interp)

    # This tackles an incomprehensible matplotlib bug if no contours are
    # drawn. To avoid rescalings, we will always draw contours.
    # But if no contours are desired we only draw one and make it invisible.
    linewidth = mask_params['markeredgewidth']
    no_contours = False
    if isinstance(contours, (np.ndarray, list)):
        pass  # contours precomputed
    elif contours == 0:
        contours, no_contours = 1, True
    if (Zi == Zi[0, 0]).all():
        cont = None  # can't make contours for constant-valued functions
    else:
        with warnings.catch_warnings(record=True):
            warnings.simplefilter('ignore')
            cont = ax.contour(Xi, Yi, Zi, contours, colors='k',
                              linewidths=linewidth / 2.)
    if no_contours and cont is not None:
        # NOTE(review): ContourSet.collections is deprecated in matplotlib
        # 3.8 and removed in 3.10 — confirm the supported matplotlib range.
        for col in cont.collections:
            col.set_visible(False)

    if patch_ is not None:
        im.set_clip_path(patch_)
        if cont is not None:
            for col in cont.collections:
                col.set_clip_path(patch_)

    pos_x, pos_y = pos.T
    if sensors is not False and mask is None:
        _plot_sensors(pos_x, pos_y, sensors=sensors, ax=ax)
    elif sensors and mask is not None:
        # masked sensors get mask_params markers, the rest normal markers
        idx = np.where(mask)[0]
        ax.plot(pos_x[idx], pos_y[idx], **mask_params)
        idx = np.where(~mask)[0]
        _plot_sensors(pos_x[idx], pos_y[idx], sensors=sensors, ax=ax)
    elif not sensors and mask is not None:
        idx = np.where(mask)[0]
        ax.plot(pos_x[idx], pos_y[idx], **mask_params)

    if isinstance(outlines, dict):
        _draw_outlines(ax, outlines)

    if show_names:
        if names is None:
            raise ValueError("To show names, a list of names must be provided"
                             " (see `names` keyword).")
        if show_names is True:
            def _show_names(x):
                return x
        else:
            _show_names = show_names
        # with a mask, only the masked sensors are labelled
        show_idx = np.arange(len(names)) if mask is None else np.where(mask)[0]
        for ii, (p, ch_id) in enumerate(zip(pos, names)):
            if ii not in show_idx:
                continue
            ch_id = _show_names(ch_id)
            ax.text(p[0], p[1], ch_id, horizontalalignment='center',
                    verticalalignment='center', size='x-small')

    if onselect is not None:
        ax.RS = RectangleSelector(ax, onselect=onselect)
    plt_show(show)
    return im, cont, interp, patch_
def _plot_evoked(evoked, plot_type, colorbar=True, hline=None, ylim=None,
                 picks=None, exclude='bads', unit=True, show=True, clim=None,
                 proj=False, xlim='tight', units=None, scalings=None,
                 titles=None, axes=None, cmap='RdBu_r'):
    """Aux function for plot_evoked and plot_evoked_image (cf. docstrings).

    Extra param is:

    plot_type : str, value ('butterfly' | 'image' | 'mean')
        The type of graph to plot: 'butterfly' plots each channel as a line
        (x axis: time, y axis: amplitude); bad channels that are not excluded
        are drawn in red, all others in black. 'image' plots a 2D image where
        color depicts the amplitude of each channel at a given time point
        (x axis: time, y axis: channel). 'mean' plots the mean absolute
        amplitude over the channels of each type. In 'image' mode, the plot
        is not interactive.

    Notes
    -----
    ``clim`` is accepted for signature compatibility but is not used; image
    color limits are taken from ``ylim``. ``proj='interactive'`` only
    triggers the single-axis check; no SSP selector is drawn here.
    """
    import matplotlib.pyplot as plt
    # Validate up front, before any figure is created.
    if plot_type not in ('butterfly', 'image', 'mean'):
        raise ValueError("plot_type has to be 'butterfly', 'image' or "
                         "'mean'. Got %s." % plot_type)
    if axes is not None and proj == 'interactive':
        raise RuntimeError('Currently only single axis figures are supported'
                           ' for interactive SSP selection.')

    scalings = _handle_default('scalings', scalings)
    titles = _handle_default('titles', titles)
    units = _handle_default('units', units)
    # union of all known channel types, sorted to guarantee a consistent order
    channel_types = sorted(set(key for d in [scalings, titles, units]
                               for key in d))

    if picks is None:
        picks = list(range(evoked.info['nchan']))

    bad_ch_idx = [evoked.ch_names.index(ch) for ch in evoked.info['bads']
                  if ch in evoked.ch_names]
    if len(exclude) > 0:
        if isinstance(exclude, string_types) and exclude == 'bads':
            exclude = bad_ch_idx
        elif (isinstance(exclude, list) and
              all(isinstance(ch, string_types) for ch in exclude)):
            exclude = [evoked.ch_names.index(ch) for ch in exclude]
        else:
            raise ValueError('exclude has to be a list of channel names or '
                             '"bads"')

        picks = list(set(picks).difference(exclude))

    types = [channel_type(evoked.info, idx) for idx in picks]
    ch_types_used = [t for t in channel_types if t in types]
    n_channel_types = len(ch_types_used)

    axes_init = axes  # remember if axes were given as input

    fig = None
    if axes is None:
        fig, axes = plt.subplots(n_channel_types, 1)

    if isinstance(axes, plt.Axes):
        axes = [axes]
    elif isinstance(axes, np.ndarray):
        axes = list(axes)

    if axes_init is not None:
        fig = axes[0].get_figure()

    if not len(axes) == n_channel_types:
        raise ValueError('Number of axes (%g) must match number of channel '
                         'types (%g)' % (len(axes), n_channel_types))

    # instead of projecting during each iteration let's use the mixin here.
    if proj is True and evoked.proj is not True:
        evoked = evoked.copy()
        evoked.apply_proj()

    times = 1e3 * evoked.times  # time in milliseconds
    is_line_plot = plot_type in ('butterfly', 'mean')
    for ax, t in zip(axes, ch_types_used):
        ch_unit = units[t]
        this_scaling = scalings[t]
        if unit is False:
            this_scaling = 1.0
            ch_unit = 'NA'  # no unit
        idx = [picks[i] for i in range(len(picks)) if types[i] == t]
        if len(idx) > 0:
            # Set amplitude scaling
            D = this_scaling * evoked.data[idx, :]
            if plot_type == 'butterfly':
                lines = ax.plot(times, D.T)
                # Color each line explicitly: the private
                # ax._get_lines.color_cycle attribute is ignored by modern
                # matplotlib. Bad channels are red, all others black.
                for line, ch_idx in zip(lines, idx):
                    line.set_color('r' if ch_idx in bad_ch_idx else 'k')
            elif plot_type == 'image':
                im = ax.imshow(D, interpolation='nearest', origin='lower',
                               extent=[times[0], times[-1], 0, D.shape[0]],
                               aspect='auto', cmap=cmap)
                if colorbar:
                    cbar = plt.colorbar(im, ax=ax)
                    cbar.ax.set_title(ch_unit)
            else:  # 'mean': average absolute amplitude over channels
                ax.plot(times, np.abs(D).mean(axis=0))

            if xlim is not None:
                if xlim == 'tight':
                    xlim = (times[0], times[-1])
                ax.set_xlim(xlim)
            if ylim is not None and t in ylim:
                if is_line_plot:
                    ax.set_ylim(ylim[t])
                else:
                    im.set_clim(ylim[t])
            ax.set_title(titles[t] + ' (%d channel%s)' % (
                         len(D), 's' if len(D) > 1 else ''))
            ax.set_xlabel('time (ms)')
            if is_line_plot:
                ax.set_ylabel('data (%s)' % ch_unit)
            else:
                ax.set_ylabel('channels (%s)' % 'index')

            if is_line_plot and hline is not None:
                for h in hline:
                    ax.axhline(h, color='r', linestyle='--', linewidth=2)

    if axes_init is None:
        plt.subplots_adjust(0.175, 0.08, 0.94, 0.94, 0.2, 0.63)

    # matplotlib reports the Agg backend as 'Agg' or 'agg' depending on version
    if show and plt.get_backend().lower() != 'agg':
        plt.show()

    fig.canvas.draw()  # for axes plots update axes.
    tight_layout(fig=fig)

    return fig
def plot(self, fmin=0, fmax=None, proj=False, picks=None, ax=None,
         color='black', xscale='linear', area_mode='std', area_alpha=0.33,
         dB=True, estimate='auto', show=True, n_jobs=1, average=False,
         line_alpha=None, spatial_colors=True, verbose=None, sphere=None):
    """Plot the spectra stored in this object with MNE's ``_plot_psd``.

    The plotting parameters are prepared differently depending on the
    installed MNE version (>= 0.22, 0.20-0.21, or older), because the
    private MNE helpers used here changed signature between versions.

    Parameters
    ----------
    fmin, fmax : float | None
        Frequency range. The data are cropped unless ``fmin == 0`` and
        ``fmax is None``. If None, ``fmax`` defaults to ``self.freqs[-1]``.
    proj : bool
        Only forwarded to ``_set_psd_plot_params`` for MNE < 0.22.
    picks : list | None
        Channels to plot.
    ax : matplotlib Axes | None
        Target axes. With MNE >= 0.22, a new figure is made when None.
    color, xscale, area_mode, area_alpha, dB, estimate, average, \
line_alpha, spatial_colors, sphere :
        Forwarded to ``_plot_psd`` (``sphere`` only for MNE >= 0.20).
    show : bool
        Whether to show the figure via ``plt_show``.
    n_jobs, verbose :
        Not used in this body.

    Returns
    -------
    fig : matplotlib Figure
        The figure returned by ``_plot_psd``.
    """
    from mne.viz.utils import _plot_psd, plt_show

    # set up default vars
    from packaging import version
    mne_version = version.parse(mne.__version__)
    has_new_mne = mne_version >= version.parse('0.22.0')
    has_20_mne = (mne_version >= version.parse('0.20.0') and
                  mne_version < version.parse('0.22.0'))

    if has_new_mne:
        # MNE >= 0.22: build the per-channel-type lists ourselves
        from mne.defaults import _handle_default
        from mne.io.pick import _picks_to_idx
        from mne.viz._figure import _split_picks_by_type
        if ax is None:
            import matplotlib.pyplot as plt
            fig, ax = plt.subplots()
        else:
            fig = ax.figure
        ax_list = [ax]

        units = _handle_default('units', None)
        picks = _picks_to_idx(self.info, picks)
        titles = _handle_default('titles', None)
        scalings = _handle_default('scalings', None)

        make_label = len(ax_list) == len(fig.axes)
        # only the last axes gets an x label
        xlabels_list = [False] * (len(ax_list) - 1) + [True]

        (picks_list, units_list, scalings_list,
         titles_list) = _split_picks_by_type(self, picks, units, scalings,
                                             titles)
    elif has_20_mne:
        # MNE 0.20/0.21: the helper also returns xlabels_list
        from mne.viz.utils import _set_psd_plot_params
        fig, picks_list, titles_list, units_list, scalings_list, \
            ax_list, make_label, xlabels_list = _set_psd_plot_params(
                self.info, proj, picks, ax, area_mode)
    else:
        # older MNE: no xlabels_list
        from mne.viz.utils import _set_psd_plot_params
        fig, picks_list, titles_list, units_list, scalings_list, ax_list, \
            make_label = _set_psd_plot_params(self.info, proj, picks, ax,
                                              area_mode)
    del ax

    crop_inst = not (fmin == 0 and fmax is None)
    fmax = self.freqs[-1] if fmax is None else fmax

    inst = self.copy()
    if crop_inst:
        inst.crop(fmin=fmin, fmax=fmax)
    # NOTE(review): the return value of average() is discarded, so this only
    # has an effect if average() works in place — confirm against the class.
    inst.average()

    # create list of psd's (one element for each channel type)
    psd_list = list()
    for picks in picks_list:
        psd_list.append(inst.data[picks])

    # positional arguments of the private _plot_psd; newer MNE versions
    # append sphere and xlabels_list
    args = [
        inst, fig, inst.freqs, psd_list, picks_list, titles_list,
        units_list, scalings_list, ax_list, make_label, color, area_mode,
        area_alpha, dB, estimate, average, spatial_colors, xscale,
        line_alpha
    ]
    if has_20_mne or has_new_mne:
        args += [sphere, xlabels_list]

    fig = _plot_psd(*args)
    plt_show(show)
    return fig
# create epochs objects ep_low = EpochsArray(X[y == LOW_CONF_EPOCH, ...], info, tmin=-1) ep_high = EpochsArray(X[y == HIGH_CONF_EPOCH, ...], info, tmin=-1) # psds_high, freqs = psd_multitaper(ep_high, tmin=0, tmax=1, fmax=50) # psds_low, freqs = psd_multitaper(ep_low, tmin=0, tmax=1, fmax=50) psds_high, freqs = psd_welch(ep_high, tmin=0.3, tmax=0.9, fmax=50) psds_low, freqs = psd_welch(ep_low, tmin=0.3, tmax=0.9, fmax=50) # normalize # psds_high /= psds_high.mean(axis=2, keepdims=True) # psds_low /= psds_low.mean(axis=2, keepdims=True) psds = (psds_high.mean(axis=0) - psds_low.mean(axis=0)) / psds_low.mean(axis=0) ch_type = _get_ch_type(ep_high, None) units = _handle_default("units", None) unit = units[ch_type] ( picks, pos, merge_channels, names, ch_type, sphere, clip_origin, ) = _prepare_topomap_plot(ep_low, ch_type, sphere=None) outlines = _make_head_outlines(sphere, pos, "head", clip_origin) if merge_channels: psds_merge, names = _merge_ch_data(psds, ch_type, names, method="mean")
def plot_topomap(data, pos, vmin=None, vmax=None, cmap=None, sensors=True,
                 res=64, axes=None, names=None, show_names=False, mask=None,
                 mask_params=None, outlines='head', image_mask=None,
                 contours=6, image_interp='bilinear', show=True,
                 head_pos=None, onselect=None, axis=None):
    """Plot a topographic map (modified copy of ``mne.viz.plot_topomap``).

    See the docstring of ``mne.viz.plot_topomap`` for the parameters. This
    copy differs in what it returns.

    Returns
    -------
    ax : matplotlib Axes
        The axes the map was drawn into.
    im : the image returned by ``ax.imshow``.
    cont : the contour set returned by ``ax.contour``.
    pos_x, pos_y : ndarray
        Sensor coordinates used for plotting.

    Raises
    ------
    ValueError
        If ``data``/``pos`` have inconsistent or unsupported shapes, if an
        Info mixes channel types, or if ``show_names`` is set without names.
    """
    import matplotlib.pyplot as plt
    from matplotlib.widgets import RectangleSelector
    from mne.io.pick import (channel_type, pick_info, _pick_data_channels)
    from mne.utils import warn
    from mne.viz.utils import (_setup_vmin_vmax, plt_show)
    from mne.defaults import _handle_default
    from mne.channels.layout import _find_topomap_coords
    from mne.io.meas_info import Info
    from mne.viz.topomap import (_check_outlines, _prepare_topomap, _griddata,
                                 _make_image_mask, _plot_sensors,
                                 _draw_outlines)

    data = np.asarray(data)

    if isinstance(pos, Info):  # infer pos from Info object
        picks = _pick_data_channels(pos)  # pick only data channels
        pos = pick_info(pos, picks)

        # check if there is only 1 channel type, and n_chans matches the data
        ch_type = set(channel_type(pos, idx)
                      for idx, _ in enumerate(pos["chs"]))
        info_help = ("Pick Info with e.g. mne.pick_info and "
                     "mne.channels.channel_indices_by_type.")
        if len(ch_type) > 1:
            raise ValueError("Multiple channel types in Info structure. " +
                             info_help)
        elif len(pos["chs"]) != data.shape[0]:
            raise ValueError("Number of channels in the Info object and "
                             "the data array does not match. " + info_help)
        else:
            ch_type = ch_type.pop()

        if any(type_ in ch_type for type_ in ('planar', 'grad')):
            # deal with grad pairs. Absolute import: the original relative
            # '..channels.layout' only resolves inside the mne package.
            from mne.channels.layout import (_merge_grad_data, find_layout,
                                             _pair_grad_sensors)
            picks, pos = _pair_grad_sensors(pos, find_layout(pos))
            data = _merge_grad_data(data[picks]).reshape(-1)
        else:
            picks = list(range(data.shape[0]))
            pos = _find_topomap_coords(pos, picks=picks)

    if data.ndim > 1:
        raise ValueError("Data needs to be array of shape (n_sensors,); got "
                         "shape %s." % str(data.shape))

    # Give a helpful error message for common mistakes regarding the position
    # matrix.
    pos_help = ("Electrode positions should be specified as a 2D array with "
                "shape (n_channels, 2). Each row in this matrix contains the "
                "(x, y) position of an electrode.")
    if pos.ndim != 2:
        error = ("{ndim}D array supplied as electrode positions, where a 2D "
                 "array was expected").format(ndim=pos.ndim)
        raise ValueError(error + " " + pos_help)
    elif pos.shape[1] == 3:
        error = ("The supplied electrode positions matrix contains 3 columns. "
                 "Are you trying to specify XYZ coordinates? Perhaps the "
                 "mne.channels.create_eeg_layout function is useful for you.")
        raise ValueError(error + " " + pos_help)
    # No error is raised in case of pos.shape[1] == 4. In this case, it is
    # assumed the position matrix contains both (x, y) and (width, height)
    # values, such as Layout.pos.
    elif pos.shape[1] == 1 or pos.shape[1] > 4:
        raise ValueError(pos_help)

    if len(data) != len(pos):
        raise ValueError("Data and pos need to be of same length. Got data of "
                         "length %s, pos of length %s" % (len(data), len(pos)))

    norm = min(data) >= 0
    vmin, vmax = _setup_vmin_vmax(data, vmin, vmax, norm)
    if cmap is None:
        cmap = 'Reds' if norm else 'RdBu_r'

    pos, outlines = _check_outlines(pos, outlines, head_pos)

    if axis is not None:
        axes = axis
        warn('axis parameter is deprecated and will be removed in 0.13. '
             'Use axes instead.', DeprecationWarning)
    ax = axes if axes else plt.gca()
    pos_x, pos_y = _prepare_topomap(pos, ax)
    if outlines is None:
        xmin, xmax = pos_x.min(), pos_x.max()
        ymin, ymax = pos_y.min(), pos_y.max()
    else:
        # grid extent is the bounding box of the outline mask positions
        xlim = np.inf, -np.inf,
        ylim = np.inf, -np.inf,
        mask_ = np.c_[outlines['mask_pos']]
        xmin, xmax = (np.min(np.r_[xlim[0], mask_[:, 0]]),
                      np.max(np.r_[xlim[1], mask_[:, 0]]))
        ymin, ymax = (np.min(np.r_[ylim[0], mask_[:, 1]]),
                      np.max(np.r_[ylim[1], mask_[:, 1]]))

    # interpolate data
    xi = np.linspace(xmin, xmax, res)
    yi = np.linspace(ymin, ymax, res)
    Xi, Yi = np.meshgrid(xi, yi)
    Zi = _griddata(pos_x, pos_y, data, Xi, Yi)

    # outlines with 'head*' keys are the built-in head drawings
    _is_default_outlines = (isinstance(outlines, dict) and
                            any(k.startswith('head') for k in outlines))

    if _is_default_outlines and image_mask is None:
        # prepare masking
        image_mask, pos = _make_image_mask(outlines, pos, res)

    mask_params = _handle_default('mask_params', mask_params)

    # plot outline
    linewidth = mask_params['markeredgewidth']
    patch = None
    if outlines is not None and 'patch' in outlines:
        patch = outlines['patch']
        patch_ = patch() if callable(patch) else patch
        patch_.set_clip_on(False)
        ax.add_patch(patch_)
        ax.set_transform(ax.transAxes)
        ax.set_clip_path(patch_)

    # plot map and contour
    im = ax.imshow(Zi, cmap=cmap, vmin=vmin, vmax=vmax, origin='lower',
                   aspect='equal', extent=(xmin, xmax, ymin, ymax),
                   interpolation=image_interp)

    # This tackles an incomprehensible matplotlib bug if no contours are
    # drawn. To avoid rescalings, we will always draw contours.
    # But if no contours are desired we only draw one and make it invisible.
    # Precomputed levels given as an ndarray must not go through the
    # membership test, which would compare element-wise and raise.
    no_contours = False
    if (not isinstance(contours, np.ndarray) and
            contours in (False, None)):
        contours, no_contours = 1, True
    cont = ax.contour(Xi, Yi, Zi, contours, colors='k',
                      linewidths=linewidth)
    if no_contours is True:
        for col in cont.collections:
            col.set_visible(False)

    if _is_default_outlines:
        # default outlines clip to an ellipse of the clip radii
        from matplotlib import patches
        patch_ = patches.Ellipse((0, 0),
                                 2 * outlines['clip_radius'][0],
                                 2 * outlines['clip_radius'][1],
                                 clip_on=True,
                                 transform=ax.transData)
    if _is_default_outlines or patch is not None:
        im.set_clip_path(patch_)
        if cont is not None:
            for col in cont.collections:
                col.set_clip_path(patch_)

    if sensors is not False and mask is None:
        _plot_sensors(pos_x, pos_y, sensors=sensors, ax=ax)
    elif sensors and mask is not None:
        idx = np.where(mask)[0]
        ax.plot(pos_x[idx], pos_y[idx], **mask_params)
        idx = np.where(~mask)[0]
        _plot_sensors(pos_x[idx], pos_y[idx], sensors=sensors, ax=ax)
    elif not sensors and mask is not None:
        idx = np.where(mask)[0]
        ax.plot(pos_x[idx], pos_y[idx], **mask_params)

    if isinstance(outlines, dict):
        _draw_outlines(ax, outlines)

    if show_names:
        if names is None:
            raise ValueError("To show names, a list of names must be provided"
                             " (see `names` keyword).")
        if show_names is True:
            def _show_names(x):
                return x
        else:
            _show_names = show_names
        show_idx = np.arange(len(names)) if mask is None else np.where(mask)[0]
        for ii, (p, ch_id) in enumerate(zip(pos, names)):
            if ii not in show_idx:
                continue
            ch_id = _show_names(ch_id)
            ax.text(p[0], p[1], ch_id, horizontalalignment='center',
                    verticalalignment='center', size='x-small')

    plt.subplots_adjust(top=.95)

    if onselect is not None:
        ax.RS = RectangleSelector(ax, onselect=onselect)
    plt_show(show)
    return ax, im, cont, pos_x, pos_y
def _prepare_mne_browse_epochs(params, projs, n_channels, n_epochs, scalings,
                               title, picks, order=None):
    """Set up the mne_browse_epochs window.

    Builds the figure (main trace axes, epoch scrollbar, channel scrollbar
    and help button), stores the plotting state in ``params``, connects
    the event callbacks and draws the initial view.

    Parameters
    ----------
    params : dict
        Browser state, updated in place. Must contain at least 'epochs',
        'info', 'bad_color', 'bads' (indices of bad epochs) and 'fix_log'.
        'fix_log' is None or an array of shape (n_epochs,
        n_picked_channels) whose entries 1 / 2 color a segment red / blue.
    projs : list
        Projectors, stored in ``params['projs']``.
    n_channels : int
        Channels per view (clipped to the number of picks).
    n_epochs : int
        Epochs per view (clipped to the number of epochs).
    scalings : dict
        Stored in ``params['scalings']``.
    title : str | None
        Title drawn above the traces. If None, the epochs' name (if any) is
        used.
    picks : array-like of int | None
        Channels to show. If None, ``_handle_picks(epochs)`` is used.
    order : list of str | None
        Order of the non-MEG channel types. MEG 'grad' and 'mag' always come
        first.

    Raises
    ------
    RuntimeError
        If no channel is picked or some picked channels cannot be classified.
    ValueError
        If ``params['fix_log']`` is given with the wrong shape.
    """
    import matplotlib.pyplot as plt
    import matplotlib as mpl
    from matplotlib.collections import LineCollection
    from matplotlib.colors import colorConverter

    epochs = params['epochs']

    if picks is None:
        picks = _handle_picks(epochs)
    if len(picks) < 1:
        raise RuntimeError('No appropriate channels found. Please'
                           ' check your picks')
    picks = sorted(picks)

    # Reorganize channels: MEG grad/mag first, then the types in `order`
    inds = list()
    types = list()
    for t in ['grad', 'mag']:
        idxs = pick_types(params['info'], meg=t, ref_meg=False, exclude=[])
        if len(idxs) < 1:
            continue
        mask = np.isin(idxs, picks, assume_unique=True)
        inds.append(idxs[mask])
        types += [t] * len(inds[-1])
    base_kwargs = dict(meg=False, ref_meg=False, exclude=[])
    if order is None:
        order = ['eeg', 'seeg', 'ecog', 'eog', 'ecg', 'emg', 'ref_meg', 'stim',
                 'resp', 'misc', 'chpi', 'syst', 'ias', 'exci']
    for ch_type in order:
        # enable exactly one channel type per query
        pick_kwargs = dict(base_kwargs)
        pick_kwargs[ch_type] = True
        idxs = pick_types(params['info'], **pick_kwargs)
        if len(idxs) < 1:
            continue
        mask = np.isin(idxs, picks, assume_unique=True)
        inds.append(idxs[mask])
        types += [ch_type] * len(inds[-1])
    if not inds:
        raise RuntimeError('Some channels not classified. Please'
                           ' check your picks')
    inds = np.concatenate(inds).astype(int)
    if not len(inds) == len(picks):
        raise RuntimeError('Some channels not classified. Please'
                           ' check your picks')
    ch_names = [params['info']['ch_names'][x] for x in inds]

    # set up plotting
    size = get_config('MNE_BROWSE_RAW_SIZE')
    n_epochs = min(n_epochs, len(epochs.events))
    duration = len(epochs.times) * n_epochs
    n_channels = min(n_channels, len(picks))
    if size is not None:
        size = size.split(',')
        size = tuple(float(s) for s in size)
    if title is None:
        # fall back to the epochs' name; newer Epochs may not have one
        title = getattr(epochs, 'name', None) or ''
    fig = figure_nobar(facecolor='w', figsize=size, dpi=80)
    # FigureCanvasBase.set_window_title was removed in matplotlib 3.6; the
    # figure manager provides the same method.
    manager = getattr(fig.canvas, 'manager', None)
    if manager is not None:
        manager.set_window_title('mne_browse_epochs')
    ax = plt.subplot2grid((10, 15), (0, 1), colspan=13, rowspan=9)

    ax.annotate(title, xy=(0.5, 1), xytext=(0, ax.get_ylim()[1] + 15),
                ha='center', va='bottom', size=12, xycoords='axes fraction',
                textcoords='offset points')
    color = _handle_default('color', None)

    ax.axis([0, duration, 0, 200])
    ax2 = ax.twiny()
    ax2.set_zorder(-1)
    ax2.axis([0, duration, 0, 200])
    ax_hscroll = plt.subplot2grid((10, 15), (9, 1), colspan=13)
    ax_hscroll.get_yaxis().set_visible(False)
    ax_hscroll.set_xlabel('Epochs')
    ax_vscroll = plt.subplot2grid((10, 15), (0, 14), rowspan=9)
    ax_vscroll.set_axis_off()
    ax_vscroll.add_patch(mpl.patches.Rectangle((0, 0), 1, len(picks),
                                               facecolor='w', zorder=3))

    ax_help_button = plt.subplot2grid((10, 15), (9, 0), colspan=1)
    help_button = mpl.widgets.Button(ax_help_button, 'Help')
    help_button.on_clicked(partial(_onclick_help, params=params))

    # populate vertical and horizontal scrollbars
    for ci in range(len(picks)):
        if ch_names[ci] in params['info']['bads']:
            this_color = params['bad_color']
        else:
            this_color = color[types[ci]]
        ax_vscroll.add_patch(mpl.patches.Rectangle((0, ci), 1, 1,
                                                   facecolor=this_color,
                                                   edgecolor=this_color,
                                                   zorder=4))

    vsel_patch = mpl.patches.Rectangle((0, 0), 1, n_channels, alpha=0.5,
                                       edgecolor='w', facecolor='w', zorder=5)
    ax_vscroll.add_patch(vsel_patch)

    ax_vscroll.set_ylim(len(types), 0)
    ax_vscroll.set_title('Ch.')

    # populate colors list: one color per (channel, epoch)
    type_colors = [colorConverter.to_rgba(color[c]) for c in types]
    colors = [[type_color] * len(epochs.events) for type_color in type_colors]
    lines = list()
    n_times = len(epochs.times)

    for _ in range(min(n_channels, len(colors))):
        lc = LineCollection(list(), antialiased=False, linewidths=0.5,
                            zorder=3, picker=3.)
        ax.add_collection(lc)
        lines.append(lc)

    times = epochs.times
    data = np.zeros((params['info']['nchan'], len(times) * n_epochs))

    ylim = (25., 0.)  # Hardcoded 25 because butterfly has max 5 rows (5*5=25).
    # make shells for plotting traces
    offset = ylim[0] / n_channels
    offsets = np.arange(n_channels) * offset + (offset / 2.)

    times = np.arange(len(times) * len(epochs.events))
    epoch_times = np.arange(0, len(times), n_times)

    ax.set_yticks(offsets)
    ax.set_ylim(ylim)
    ticks = epoch_times + 0.5 * n_times
    ax.set_xticks(ticks)
    ax2.set_xticks(ticks[:n_epochs])
    labels = list(range(1, len(ticks) + 1))  # epoch numbers
    ax.set_xticklabels(labels)
    ax2.set_xticklabels(labels)
    xlim = epoch_times[-1] + len(epochs.times)
    ax_hscroll.set_xlim(0, xlim)
    vertline_t = ax_hscroll.text(0, 1, '', color='y', va='bottom', ha='right')

    # fit horizontal scroll bar ticks
    hscroll_ticks = np.arange(0, xlim, xlim / 7.0)
    hscroll_ticks = np.append(hscroll_ticks, epoch_times[-1])
    hticks = list()
    for tick in hscroll_ticks:
        hticks.append(epoch_times.flat[np.abs(epoch_times - tick).argmin()])
    # hticks are epoch onsets, so integer division gives 1-based epoch numbers
    hlabels = [x // n_times + 1 for x in hticks]
    ax_hscroll.set_xticks(hticks)
    ax_hscroll.set_xticklabels(hlabels)

    for epoch_idx in range(len(epoch_times)):
        ax_hscroll.add_patch(mpl.patches.Rectangle((epoch_idx * n_times, 0),
                                                   n_times, 1,
                                                   facecolor=(0.8, 0.8, 0.8),
                                                   edgecolor=(0.8, 0.8, 0.8),
                                                   alpha=0.5))
    hsel_patch = mpl.patches.Rectangle((0, 0), duration, 1,
                                       edgecolor='k',
                                       facecolor=(0.5, 0.5, 0.5),
                                       alpha=0.25, linewidth=1, clip_on=False)
    ax_hscroll.add_patch(hsel_patch)
    text = ax.text(0, 0, 'blank', zorder=3, verticalalignment='baseline',
                   ha='left', fontweight='bold')
    text.set_visible(False)

    params.update({'fig': fig,
                   'ax': ax,
                   'ax2': ax2,
                   'ax_hscroll': ax_hscroll,
                   'ax_vscroll': ax_vscroll,
                   'vsel_patch': vsel_patch,
                   'hsel_patch': hsel_patch,
                   'lines': lines,
                   'projs': projs,
                   'ch_names': ch_names,
                   'n_channels': n_channels,
                   'n_epochs': n_epochs,
                   'scalings': scalings,
                   'duration': duration,
                   'ch_start': 0,
                   'colors': colors,
                   'def_colors': type_colors,  # don't change at runtime
                   'picks': picks,
                   'data': data,
                   'times': times,
                   'epoch_times': epoch_times,
                   'offsets': offsets,
                   'labels': labels,
                   'scale_factor': 1.0,
                   'butterfly_scale': 1.0,
                   'fig_proj': None,
                   'types': np.array(types),
                   'inds': inds,
                   'vert_lines': list(),
                   'vertline_t': vertline_t,
                   'butterfly': False,
                   'text': text,
                   'ax_help_button': ax_help_button,  # needed for positioning
                   'help_button': help_button,  # reference needed for clicks
                   'fig_options': None,
                   'settings': [True, True, True, True],
                   'image_plot': None})

    params['plot_fun'] = partial(_plot_traces, params=params)

    # callbacks
    callback_scroll = partial(_plot_onscroll, params=params)
    fig.canvas.mpl_connect('scroll_event', callback_scroll)
    callback_click = partial(_mouse_click, params=params)
    fig.canvas.mpl_connect('button_press_event', callback_click)
    callback_key = partial(_plot_onkey, params=params)
    fig.canvas.mpl_connect('key_press_event', callback_key)
    callback_resize = partial(_resize_event, params=params)
    fig.canvas.mpl_connect('resize_event', callback_resize)
    fig.canvas.mpl_connect('pick_event', partial(_onpick, params=params))
    params['callback_key'] = callback_key

    # Draw event lines for the first time.
    _plot_vert_lines(params)

    # Plot bad epochs: whole epoch red in the scrollbar and in all channels
    for epoch_idx in params['bads']:
        params['ax_hscroll'].patches[epoch_idx].set_color((1., 0., 0., 1.))
        params['ax_hscroll'].patches[epoch_idx].set_zorder(3)
        params['ax_hscroll'].patches[epoch_idx].set_edgecolor('w')
        for ch_idx in range(len(params['ch_names'])):
            params['colors'][ch_idx][epoch_idx] = (1., 0., 0., 1.)

    # Plot bad segments. The shape is validated only when a log is given;
    # checking it unconditionally crashed for the default fix_log=None.
    fix_log = params['fix_log']
    if fix_log is not None:
        n_events = len(epochs.events)
        expected_shape = (n_events, len(params['ch_names']))
        if fix_log.shape != expected_shape:
            raise ValueError('fix_log must have shape %s (n_epochs, '
                             'n_channels), got %s'
                             % (expected_shape, fix_log.shape))
        for ch_idx in range(len(params['ch_names'])):
            for epoch_idx in range(n_events):
                if epoch_idx in params['bads']:
                    continue  # already fully marked as a bad epoch
                this_log = fix_log[epoch_idx, ch_idx]
                if this_log == 1:
                    params['colors'][ch_idx][epoch_idx] = (1., 0., 0., 1.)
                elif this_log == 2:
                    params['colors'][ch_idx][epoch_idx] = (0., 0., 1., 1.)

    params['plot_fun']()
def plot_epochs(epochs, picks=None, scalings=None, n_epochs=20,
                n_channels=20, title=None, show=True, block=False,
                bad_epochs_idx=None, fix_log=None):
    """Visualize epochs.

    Bad epochs can be marked with a left click on top of the epoch. Bad
    channels can be selected by clicking the channel name on the left side of
    the main axes. Calling this function drops all the selected bad epochs as
    well as bad epochs marked beforehand with rejection parameters.

    Parameters
    ----------
    epochs : instance of Epochs
        The epochs object. ``epochs.drop_bad()`` is called on it before
        plotting.
    picks : array-like of int | None
        Channels to be included. If None only good data channels are used.
        Defaults to None
    scalings : dict | 'auto' | None
        Scaling factors for the traces. If any fields in scalings are 'auto',
        the scaling factor is set to match the 99.5th percentile of a subset
        of the corresponding data. If scalings == 'auto', all scalings fields
        are set to 'auto'. If any fields are 'auto' and data is not preloaded,
        a subset of epochs up to 100mb will be loaded. If None, defaults to::

            dict(mag=1e-12, grad=4e-11, eeg=20e-6, eog=150e-6, ecg=5e-4,
                 emg=1e-3, ref_meg=1e-12, misc=1e-3, stim=1, resp=1,
                 chpi=1e-4)

    n_epochs : int
        The number of epochs per view. Defaults to 20.
    n_channels : int
        The number of channels per view. Defaults to 20.
    title : str | None
        The title of the window. If None, epochs name will be displayed.
        Defaults to None.
    show : bool
        Show figure if True. Defaults to True
    block : bool
        Whether to halt program execution until the figure is closed.
        Useful for rejecting bad trials on the fly by clicking on an epoch.
        Defaults to False.
    bad_epochs_idx : array-like of int | None
        Indices of bad epochs to show; these epochs are drawn entirely in
        red. The indices are applied after ``epochs.drop_bad()`` has been
        called. No bad epochs are visualized if None.
    fix_log : array-like, shape (n_epochs, n_channels) | None
        Per-epoch, per-channel annotation of segments, where ``n_epochs`` is
        the number of epochs left after ``epochs.drop_bad()`` and
        ``n_channels`` is the number of plotted channels. A value of 1 marks
        a bad segment (shown in red) and a value of 2 marks an interpolated
        segment (shown in blue); other values leave the default color.
        Epochs listed in ``bad_epochs_idx`` are not affected. No segments are
        visualized if None.

    Returns
    -------
    fig : Instance of matplotlib.figure.Figure
        The figure.

    Notes
    -----
    The arrow keys (up/down/left/right) can be used to navigate between
    channels and epochs and the scaling can be adjusted with - and + (or =)
    keys, but this depends on the backend matplotlib is configured to use
    (e.g., mpl.use(``TkAgg``) should work). Full screen mode can be toggled
    with f11 key. The number of epochs and channels per view can be adjusted
    with home/end and page down/page up keys. Butterfly plot can be toggled
    with ``b`` key. Right mouse click adds a vertical line to the plot.
    """
    epochs.drop_bad()
    scalings = _compute_scalings(scalings, epochs)
    scalings = _handle_default('scalings_plot_raw', scalings)

    projs = epochs.info['projs']
    bads = np.array(list(), dtype=int)
    if bad_epochs_idx is not None:
        bads = np.array(bad_epochs_idx).astype(int)
    if fix_log is not None:
        # The browser reads ``fix_log.shape`` and indexes it as
        # ``fix_log[epoch_idx, ch_idx]``, so normalize any 2-D array-like
        # (e.g. nested lists) to an ndarray. ndarrays pass through uncopied.
        fix_log = np.asarray(fix_log)

    params = {'epochs': epochs,
              'info': copy.deepcopy(epochs.info),
              'bad_color': (0.8, 0.8, 0.8),
              't_start': 0,
              'histogram': None,
              'bads': bads,
              'fix_log': fix_log}
    params['label_click_fun'] = partial(_pick_bad_channels, params=params)
    _prepare_mne_browse_epochs(params, projs, n_channels, n_epochs, scalings,
                               title, picks)
    _prepare_projectors(params)
    _layout_figure(params)

    callback_close = partial(_close_event, params=params)
    params['fig'].canvas.mpl_connect('close_event', callback_close)
    try:
        plt_show(show, block=block)
    except TypeError:  # not all versions have this
        plt_show(show)

    return params['fig']