# Example No. 1
# 0
def test_si_units():
    """Verify that every default scaling converts its data to SI units.

    A deliberately wrong ``csd_bad`` entry is injected first, so the test
    also demonstrates that the check is able to flag a bad scaling.
    """
    scalings = _handle_default('scalings', None)
    units = _handle_default('units', None)
    # Inject a known-bad entry to prove the check can actually fail
    assert 'csd_bad' not in scalings
    scalings['csd_bad'] = 1e5
    units['csd_bad'] = 'V/m²'
    assert set(scalings) == set(units)
    prefix_factors = dict([
        ('', 1),
        ('m', 1e-3),
        ('c', 1e-2),
        ('µ', 1e-6),
        ('n', 1e-9),
        ('f', 1e-15),
    ])
    # 'AU' and 'GOF' are not really SI units, but they are tolerated
    allowed_bases = {'V', 'T', 'Am', 'm', 'M', 'rad', 'AU', 'GOF'}
    powers = '²'
    # Unit strings that must not be split character-by-character
    fixed_splits = {
        'nAm': ('n', 'Am'),
        'GOF': ('', 'GOF'),
        'AU': ('', 'AU'),
        'rad': ('', 'rad'),
    }

    def _split_si(text):
        """Return ``(prefix, base_unit)`` for one side of a unit string."""
        if text in fixed_splits:
            return fixed_splits[text]
        if len(text) != 2:
            assert len(text) in (0, 1), text
            return '', text
        if text[1] in powers:
            # e.g. 'm²': a squared base unit without any prefix
            return '', text
        return text[0], text[1]

    for key, scale in scalings.items():
        parts = units[key].split('/')
        if len(parts) == 2:
            num, denom = parts
        else:  # no (or more than one) '/' in the unit string
            num, denom = units[key], ''
        # validate numerator and denominator separately
        num_prefix, num_SI = _split_si(num)
        assert num_prefix in prefix_factors
        assert num_SI in allowed_bases
        den_prefix, den_SI = _split_si(denom)
        assert den_prefix in prefix_factors
        if den_SI != '' or den_prefix != '':
            assert den_SI.strip(powers) in allowed_bases
        # rebuild the scale factor implied by the prefixes
        expected = prefix_factors[den_prefix] / prefix_factors[num_prefix]
        if key != 'csd_bad':
            assert_allclose(scale, expected, rtol=1e-12)
        else:
            assert not np.isclose(scale, expected, rtol=10)
# Example No. 2
# 0
def test_consistency(key):
    """Check that the ``key`` defaults cover exactly the unit channel types."""
    unit_types = set(_handle_default('units'))
    key_types = set(_handle_default(key))
    # Miscellaneous channel types: never present in the units, and only
    # expected in the 'color' and 'scalings_plot_raw' defaults.
    au_keys = {'stim', 'exci', 'syst', 'resp', 'ias', 'chpi'}
    assert not au_keys & unit_types
    if key not in ('color', 'scalings_plot_raw'):
        assert au_keys.intersection(key_types) == set()
    else:
        assert au_keys <= key_types
        key_types -= au_keys
    assert unit_types == key_types, key
# Example No. 3
# 0
def _plot_histogram(params):
    """Plot histograms of peak-to-peak amplitudes, one subplot per type.

    Only the 'eeg', 'mag' and 'grad' channel types are considered, in that
    order. If ``epochs.reject`` holds a threshold for a type, the histogram
    range is limited to 110% of the (scaled) threshold and the threshold is
    marked with a red vertical line.

    Parameters
    ----------
    params : dict
        Plotting state. Reads 'epochs', 'types' (one channel-type string per
        channel, aligned with the channel axis of the epochs data) and
        'fig_proj'. The new figure is stored in ``params['histogram']``.
    """
    import matplotlib.pyplot as plt
    epochs = params['epochs']
    ch_types = params['types']
    # reduce the last axis (presumably time): one value per epoch/channel
    p2p = np.ptp(epochs.get_data(), axis=2)
    types = list()
    data = list()
    for this_type in ('eeg', 'mag', 'grad'):
        if this_type not in ch_types:
            continue
        values = np.array([p2p.T[ii] for ii, ch_type in enumerate(ch_types)
                           if ch_type == this_type])
        data.append(values.ravel())
        types.append(this_type)
    params['histogram'] = plt.figure()
    scalings = _handle_default('scalings')
    units = _handle_default('units')
    titles = _handle_default('titles')
    colors = _handle_default('color')
    for idx, (this_type, values) in enumerate(zip(types, data)):
        ax = plt.subplot(len(types), 1, idx + 1)
        plt.xlabel(units[this_type])
        plt.ylabel('count')
        rej = None
        rng = None
        if epochs.reject is not None and this_type in epochs.reject:
            rej = epochs.reject[this_type] * scalings[this_type]
            rng = [0., rej * 1.1]
        plt.hist(values * scalings[this_type], bins=100,
                 color=colors[this_type], range=rng)
        if rej is not None:
            ax.plot((rej, rej), (0, ax.get_ylim()[1]), color='r')
        plt.title(titles[this_type])
    fig = params['histogram']
    fig.suptitle('Peak-to-peak histogram', y=0.99)
    fig.subplots_adjust(hspace=0.6)
    try:
        fig.show(warn=False)
    except Exception:
        # Best effort only: showing the figure can fail depending on the
        # backend / matplotlib version, and that must not abort plotting.
        pass
    if params['fig_proj'] is not None:
        params['fig_proj'].canvas.draw()
# Example No. 4
# 0
def test_handle_default():
    """Test that overrides passed to _handle_default do not alter defaults."""
    reference = deepcopy(_handle_default('scalings'))
    plain = _handle_default('scalings')
    overridden = _handle_default('scalings', dict(mag=1, grad=2))
    empty = _handle_default('scalings', {})
    assert set(reference) == set(plain)
    assert set(reference) == set(overridden)
    for key, ref_value in reference.items():
        assert ref_value == plain[key]
        assert ref_value == empty[key]
        if key not in ('mag', 'grad'):
            assert ref_value == overridden[key]
        else:
            assert ref_value != overridden[key]
# Example No. 5
# 0
def test_handle_default():
    """Test that overrides passed to _handle_default do not alter defaults."""
    reference = deepcopy(_handle_default('scalings'))
    plain = _handle_default('scalings')
    overridden = _handle_default('scalings', dict(mag=1, grad=2))
    empty = _handle_default('scalings', {})
    assert set(reference) == set(plain)
    assert set(reference) == set(overridden)
    for key, ref_value in reference.items():
        assert ref_value == plain[key]
        assert ref_value == empty[key]
        if key not in ('mag', 'grad'):
            assert ref_value == overridden[key]
        else:
            assert ref_value != overridden[key]
# Example No. 6
# 0
def _put_artifact_range(info, evoked, kind):
    """Store the scaled amplitude range of ``evoked`` per data channel type.

    For each channel type returned by ``get_data_picks``, the key
    ``'<kind>_amp_range_<ch_type>'`` of ``info`` is set to (max - min) of the
    evoked data over those channels, multiplied by the default scaling.
    """
    scalings = _handle_default('scalings')
    for picks, ch_type in get_data_picks(evoked, meg_combined=False):
        values = evoked.data[picks]
        peak_to_peak = values.max() - values.min()
        key = '%s_amp_range_%s' % (kind, ch_type)
        info.update({key: peak_to_peak * scalings[ch_type]})
# Example No. 7
# 0
def test_handle_default():
    """Test that overrides passed to _handle_default do not alter defaults.

    Uses plain ``assert`` statements rather than the nose-era
    ``assert_equal`` / ``assert_true`` helpers.
    """
    x = deepcopy(_handle_default('scalings'))
    y = _handle_default('scalings')
    z = _handle_default('scalings', dict(mag=1, grad=2))
    w = _handle_default('scalings', {})
    assert set(x.keys()) == set(y.keys())
    assert set(x.keys()) == set(z.keys())
    for key in x.keys():
        assert x[key] == y[key]
        assert x[key] == w[key]
        if key in ('mag', 'grad'):
            assert x[key] != z[key]
        else:
            assert x[key] == z[key]
def _put_artifact_range(info, evoked, kind):
    """Store the scaled amplitude range of ``evoked`` per data channel type.

    For each channel type returned by ``get_data_picks``, the key
    ``'<kind>_amp_range_<ch_type>'`` of ``info`` is set to (max - min) of the
    evoked data over those channels, multiplied by the default scaling.
    """
    scalings = _handle_default('scalings')
    for picks, ch_type in get_data_picks(evoked, meg_combined=False):
        values = evoked.data[picks]
        peak_to_peak = values.max() - values.min()
        key = '%s_amp_range_%s' % (kind, ch_type)
        info.update({key: peak_to_peak * scalings[ch_type]})
# Example No. 9
# 0
def test_si_units():
    """Check each default scaling against the factor implied by its unit."""
    scalings = _handle_default('scalings', None)
    units = _handle_default('units', None)
    # Inject a known-bad entry to prove the check can actually fail
    assert 'csd_bad' not in scalings
    scalings['csd_bad'] = 1e5
    units['csd_bad'] = 'V/m²'
    assert set(scalings) == set(units)

    for ch_type in scalings:
        scale = scalings[ch_type]
        if ch_type != 'csd_bad':
            expected = _get_scaling(ch_type, units[ch_type])
            assert_allclose(scale, expected, rtol=1e-12)
            continue
        # the injected type is unknown, so looking it up must fail
        with pytest.raises(KeyError, match='is not a channel type'):
            _get_scaling(ch_type, units[ch_type])
# Example No. 10
# 0
def epochs_compute_cnv(epochs, tmin=None, tmax=None):
    """Compute the contingent negative variation (CNV) of each epoch.

    For every epoch and every MEG/EEG channel, a straight line
    ``y = intercept + slope * (t - tmin)`` is fitted by least squares to the
    signal between ``tmin`` and ``tmax``. Before fitting, each channel is
    multiplied by the default scaling of its channel type. The slopes
    represent the CNV.

    Parameters
    ----------
    epochs : instance of Epochs
        The input data.
    tmin : float | None
        The first time point to include. If None, the fit starts at the first
        sample of the epoch. Defaults to None.
    tmax : float | None
        The last time point to include. If None, the fit runs up to the last
        sample of the epoch. Defaults to None.

    Returns
    -------
    slopes : ndarray of float, shape (n_epochs, n_channels)
        The regression slopes (betas) representing contingent negative
        variation, in scaled units per unit of ``epochs.times``.
    intercepts : ndarray of float, shape (n_epochs, n_channels)
        The regression intercepts, i.e. the fitted value at ``tmin``.

    Notes
    -----
    ``n_epochs`` is ``len(epochs.events)``. NOTE(review): if iterating
    ``epochs`` yields fewer epochs (e.g. bad ones dropped on the fly), the
    trailing rows stay zero -- confirm this cannot happen for callers.
    """
    picks = mne.pick_types(epochs.info, meg=True, eeg=True)
    n_epochs = len(epochs.events)
    n_channels = len(picks)
    # one value per (epoch, channel): the fit reduces over time samples
    slopes = np.zeros((n_epochs, n_channels))
    intercepts = np.zeros((n_epochs, n_channels))
    if tmax is None:
        tmax = epochs.times[-1]
    if tmin is None:
        tmin = epochs.times[0]

    fit_range = np.where(_time_mask(epochs.times, tmin, tmax))[0]

    # design: intercept + time elapsed since tmin
    design_matrix = np.c_[np.ones(len(fit_range)),
                          epochs.times[fit_range] - tmin]

    # scale every channel by the default scaling of its channel type;
    # the defaults are looked up once instead of once per channel type
    default_scalings = _handle_default('scalings')
    scales = np.zeros(n_channels)
    info_ = pick_info(epochs.info, picks)
    for this_type, this_picks in _picks_by_type(info_):
        scales[this_picks] = default_scalings[this_type]

    for ii, epoch in enumerate(epochs):
        # rows are time samples, columns are channels
        y = epoch[picks][:, fit_range].T
        # one lstsq call fits all channels at once (one column of b each)
        betas, _, _, _ = linalg.lstsq(a=design_matrix, b=y * scales)
        intercepts[ii] = betas[0]
        slopes[ii] = betas[1]

    return slopes, intercepts
# Example No. 11
# 0
def _plot_topomap(data, pos, vmin=None, vmax=None, cmap=None, sensors=True,
                  res=64, axes=None, names=None, show_names=False, mask=None,
                  mask_params=None, outlines='head',
                  contours=6, image_interp='bilinear', show=True,
                  head_pos=None, onselect=None, extrapolate='box', border=0):
    """Plot an interpolated topographic map of ``data`` at ``pos``.

    Parameters
    ----------
    data : array-like
        One value per sensor; must be 1D (after gradiometer merging when
        ``pos`` is an Info containing planar gradiometers).
    pos : instance of Info | ndarray
        Either an Info with a single channel type (positions are derived
        from it) or an array of shape (n_sensors, 2), or (n_sensors, 4)
        where the extra columns are e.g. width/height as in Layout.pos.
    vmin, vmax :
        Color limits, resolved by ``_setup_vmin_vmax``.
    cmap : str | None
        Colormap; if None, 'Reds' for non-negative data, else 'RdBu_r'.
    sensors : bool | str
        Whether (and how) to draw sensor markers via ``_plot_sensors``.
    res : int
        Number of grid points per axis of the interpolation grid.
    axes : Axes | None
        Axes to draw into; defaults to ``plt.gca()``.
    names : list of str | None
        Sensor names; required when ``show_names`` is truthy.
    show_names : bool | callable
        If True, draw ``names`` as-is; if callable, draw ``show_names(name)``.
    mask : ndarray of bool | None
        Sensors drawn with ``mask_params``; when given, only these sensors
        get their name drawn.
    mask_params : dict | None
        Marker parameters, merged with the defaults by ``_handle_default``.
    outlines :
        Passed to ``_check_outlines`` together with ``head_pos``.
    contours : int | array-like
        Number of contour levels or precomputed levels; 0 draws a single
        invisible contour (see comment in the body).
    image_interp : str
        ``interpolation`` argument for ``ax.imshow``.
    show : bool
        Passed to ``plt_show``.
    onselect : callable | None
        If given, a RectangleSelector with this callback is attached to the
        axes (stored as ``ax.RS``).
    extrapolate, border :
        Passed to ``_GridData``; with ``extrapolate='local'`` no head radius
        is passed.

    Returns
    -------
    im : AxesImage
        The interpolated image.
    cont : QuadContourSet | None
        The contours, or None when the interpolated data are constant.
    interp : _GridData
        The interpolator.
    patch_ : Patch | None
        The clip patch, if any.
    """
    import matplotlib.pyplot as plt
    from matplotlib.widgets import RectangleSelector
    data = np.asarray(data)

    if isinstance(pos, Info):  # infer pos from Info object
        picks = _pick_data_channels(pos)  # pick only data channels
        pos = pick_info(pos, picks)

        # check if there is only 1 channel type, and n_chans matches the data
        ch_type = {channel_type(pos, idx)
                   for idx, _ in enumerate(pos["chs"])}
        info_help = ("Pick Info with e.g. mne.pick_info and "
                     "mne.io.pick.channel_indices_by_type.")
        if len(ch_type) > 1:
            raise ValueError("Multiple channel types in Info structure. "
                             + info_help)
        elif len(pos["chs"]) != data.shape[0]:
            raise ValueError("Number of channels in the Info object and "
                             "the data array does not match. " + info_help)
        else:
            ch_type = ch_type.pop()

        if any(type_ in ch_type for type_ in ('planar', 'grad')):
            # deal with grad pairs
            from mne.channels.layout import (_merge_grad_data, find_layout,
                                             _pair_grad_sensors)
            picks, pos = _pair_grad_sensors(pos, find_layout(pos))
            data = _merge_grad_data(data[picks]).reshape(-1)
        else:
            picks = list(range(data.shape[0]))
            pos = _find_topomap_coords(pos, picks=picks)

    if data.ndim > 1:
        raise ValueError("Data needs to be array of shape (n_sensors,); got "
                         "shape %s." % str(data.shape))

    # Give a helpful error message for common mistakes regarding the position
    # matrix.
    pos_help = ("Electrode positions should be specified as a 2D array with "
                "shape (n_channels, 2). Each row in this matrix contains the "
                "(x, y) position of an electrode.")
    if pos.ndim != 2:
        error = ("{ndim}D array supplied as electrode positions, where a 2D "
                 "array was expected").format(ndim=pos.ndim)
        raise ValueError(error + " " + pos_help)
    elif pos.shape[1] == 3:
        error = ("The supplied electrode positions matrix contains 3 columns. "
                 "Are you trying to specify XYZ coordinates? Perhaps the "
                 "mne.channels.create_eeg_layout function is useful for you.")
        raise ValueError(error + " " + pos_help)
    # No error is raised in case of pos.shape[1] == 4. In this case, it is
    # assumed the position matrix contains both (x, y) and (width, height)
    # values, such as Layout.pos.
    elif pos.shape[1] == 1 or pos.shape[1] > 4:
        raise ValueError(pos_help)

    if len(data) != len(pos):
        raise ValueError("Data and pos need to be of same length. Got data of "
                         "length %s, pos of length %s" % (len(data), len(pos)))

    # non-negative data get a one-sided colormap and color range
    norm = min(data) >= 0
    vmin, vmax = _setup_vmin_vmax(data, vmin, vmax, norm)
    if cmap is None:
        cmap = 'Reds' if norm else 'RdBu_r'

    pos, outlines = _check_outlines(pos, outlines, head_pos)
    assert isinstance(outlines, dict)

    ax = axes if axes else plt.gca()
    _prepare_topomap(pos, ax)

    _use_default_outlines = any(k.startswith('head') for k in outlines)

    if _use_default_outlines:
        # prepare masking
        _autoshrink(outlines, pos, res)

    mask_params = _handle_default('mask_params', mask_params)

    # find mask limits
    # (start from an empty interval that the mask outline then extends)
    xlim = np.inf, -np.inf,
    ylim = np.inf, -np.inf,
    mask_ = np.c_[outlines['mask_pos']]
    xmin, xmax = (np.min(np.r_[xlim[0], mask_[:, 0]]),
                  np.max(np.r_[xlim[1], mask_[:, 0]]))
    ymin, ymax = (np.min(np.r_[ylim[0], mask_[:, 1]]),
                  np.max(np.r_[ylim[1], mask_[:, 1]]))

    # interpolate the data, we multiply clip radius by 1.06 so that pixelated
    # edges of the interpolated image would appear under the mask
    head_radius = (None if extrapolate == 'local' else
                   outlines['clip_radius'][0] * 1.06)
    xi = np.linspace(xmin, xmax, res)
    yi = np.linspace(ymin, ymax, res)
    Xi, Yi = np.meshgrid(xi, yi)
    interp = _GridData(pos, extrapolate, head_radius, border).set_values(data)
    Zi = interp.set_locations(Xi, Yi)()

    # plot outline
    patch_ = None
    if 'patch' in outlines:
        patch_ = outlines['patch']
        patch_ = patch_() if callable(patch_) else patch_
        patch_.set_clip_on(False)
        ax.add_patch(patch_)
        ax.set_transform(ax.transAxes)
        ax.set_clip_path(patch_)
    if _use_default_outlines:
        # default head outlines: clip to an ellipse of the clip radius
        # (this replaces any patch taken from outlines['patch'] above)
        from matplotlib import patches
        patch_ = patches.Ellipse((0, 0),
                                 2 * outlines['clip_radius'][0],
                                 2 * outlines['clip_radius'][1],
                                 clip_on=True,
                                 transform=ax.transData)

    # plot interpolated map
    im = ax.imshow(Zi, cmap=cmap, vmin=vmin, vmax=vmax, origin='lower',
                   aspect='equal', extent=(xmin, xmax, ymin, ymax),
                   interpolation=image_interp)

    # This tackles an incomprehensible matplotlib bug if no contours are
    # drawn. To avoid rescalings, we will always draw contours.
    # But if no contours are desired we only draw one and make it invisible.
    linewidth = mask_params['markeredgewidth']
    no_contours = False
    if isinstance(contours, (np.ndarray, list)):
        pass  # contours precomputed
    elif contours == 0:
        contours, no_contours = 1, True
    if (Zi == Zi[0, 0]).all():
        cont = None  # can't make contours for constant-valued functions
    else:
        with warnings.catch_warnings(record=True):
            warnings.simplefilter('ignore')
            cont = ax.contour(Xi, Yi, Zi, contours, colors='k',
                              linewidths=linewidth / 2.)
    if no_contours and cont is not None:
        for col in cont.collections:
            col.set_visible(False)

    # clip image and contours to the head patch
    if patch_ is not None:
        im.set_clip_path(patch_)
        if cont is not None:
            for col in cont.collections:
                col.set_clip_path(patch_)

    # sensor markers: masked sensors use mask_params, the rest _plot_sensors
    pos_x, pos_y = pos.T
    if sensors is not False and mask is None:
        _plot_sensors(pos_x, pos_y, sensors=sensors, ax=ax)
    elif sensors and mask is not None:
        idx = np.where(mask)[0]
        ax.plot(pos_x[idx], pos_y[idx], **mask_params)
        idx = np.where(~mask)[0]
        _plot_sensors(pos_x[idx], pos_y[idx], sensors=sensors, ax=ax)
    elif not sensors and mask is not None:
        idx = np.where(mask)[0]
        ax.plot(pos_x[idx], pos_y[idx], **mask_params)

    if isinstance(outlines, dict):
        _draw_outlines(ax, outlines)

    if show_names:
        if names is None:
            raise ValueError("To show names, a list of names must be provided"
                             " (see `names` keyword).")
        if show_names is True:
            def _show_names(x):
                return x
        else:
            _show_names = show_names
        # with a mask, only the masked sensors are labeled
        show_idx = np.arange(len(names)) if mask is None else np.where(mask)[0]
        for ii, (p, ch_id) in enumerate(zip(pos, names)):
            if ii not in show_idx:
                continue
            ch_id = _show_names(ch_id)
            ax.text(p[0], p[1], ch_id, horizontalalignment='center',
                    verticalalignment='center', size='x-small')

    if onselect is not None:
        ax.RS = RectangleSelector(ax, onselect=onselect)
    plt_show(show)
    return im, cont, interp, patch_
# Example No. 12
# 0
def _plot_evoked(evoked, plot_type, colorbar=True, hline=None, ylim=None,
                 picks=None, exclude='bads', unit=True, show=True,
                 clim=None, proj=False, xlim='tight', units=None,
                 scalings=None, titles=None, axes=None, cmap='RdBu_r'):
    """Aux function for plot_evoked and plot_evoked_image (cf. docstrings).

    Extra param is:

    plot_type : str, value ('butterfly' | 'image' | 'mean')
        The type of graph to plot: 'butterfly' plots each channel as a line
        (x axis: time, y axis: amplitude). 'image' plots a 2D image where
        color depicts the amplitude of each channel at a given time point
        (x axis: time, y axis: channel). 'mean' plots, per channel type, the
        mean absolute amplitude across channels over time. In 'image' mode,
        the plot is not interactive.

    Notes
    -----
    ``clim`` is accepted for API compatibility but not used; for 'image' the
    color limits are taken from ``ylim``. Interactive SSP selection is not
    implemented here.

    Returns
    -------
    fig : instance of Figure
        The figure containing the plot.
    """
    # validate first, so that no figure is created for an invalid request
    if plot_type not in ('butterfly', 'image', 'mean'):
        raise ValueError("plot_type has to be 'butterfly', 'image' or "
                         "'mean'. Got %s." % plot_type)
    import matplotlib.pyplot as plt
    if axes is not None and proj == 'interactive':
        raise RuntimeError('Currently only single axis figures are supported'
                           ' for interactive SSP selection.')

    scalings = _handle_default('scalings', scalings)
    titles = _handle_default('titles', titles)
    units = _handle_default('units', units)

    # every channel type known to any of the dicts, in a consistent order
    channel_types = sorted(set(scalings) | set(titles) | set(units))

    if picks is None:
        picks = list(range(evoked.info['nchan']))

    bad_ch_idx = [evoked.ch_names.index(ch) for ch in evoked.info['bads']
                  if ch in evoked.ch_names]
    if len(exclude) > 0:
        if isinstance(exclude, string_types) and exclude == 'bads':
            exclude = bad_ch_idx
        elif (isinstance(exclude, list) and
              all(isinstance(ch, string_types) for ch in exclude)):
            exclude = [evoked.ch_names.index(ch) for ch in exclude]
        else:
            raise ValueError('exclude has to be a list of channel names or '
                             '"bads"')

        picks = list(set(picks).difference(exclude))

    types = [channel_type(evoked.info, idx) for idx in picks]
    ch_types_used = [t for t in channel_types if t in types]
    n_channel_types = len(ch_types_used)
    if n_channel_types == 0:
        raise ValueError('None of the picked channels has a channel type '
                         'that can be plotted.')

    axes_init = axes  # remember if axes where given as input

    fig = None
    if axes is None:
        fig, axes = plt.subplots(n_channel_types, 1)

    if isinstance(axes, plt.Axes):
        axes = [axes]
    elif isinstance(axes, np.ndarray):
        axes = list(axes)

    if axes_init is not None:
        fig = axes[0].get_figure()

    if not len(axes) == n_channel_types:
        raise ValueError('Number of axes (%g) must match number of channel '
                         'types (%g)' % (len(axes), n_channel_types))

    # instead of projecting during each iteration let's use the mixin here.
    if proj is True and evoked.proj is not True:
        evoked = evoked.copy()
        evoked.apply_proj()

    times = 1e3 * evoked.times  # time in miliseconds
    # resolve 'tight' once; comparing e.g. an ndarray to a string would be
    # elementwise (and ambiguous in an ``if``)
    if isinstance(xlim, string_types) and xlim == 'tight':
        xlim = (times[0], times[-1])

    for ax, t in zip(axes, ch_types_used):
        ch_unit = units[t]
        this_scaling = scalings[t]
        if unit is False:
            this_scaling = 1.0
            ch_unit = 'NA'  # no unit
        idx = [pick for pick, this_type in zip(picks, types) if this_type == t]
        if len(idx) == 0:
            continue
        # Set amplitude scaling
        D = this_scaling * evoked.data[idx, :]
        if plot_type == 'butterfly':
            lines = ax.plot(times, D.T)
            # good channels in black, bad channels in red; set on the lines
            # themselves because assigning the private
            # ``ax._get_lines.color_cycle`` is ignored by newer matplotlib
            for line, ch_idx in zip(lines, idx):
                line.set_color('r' if ch_idx in bad_ch_idx else 'k')
        elif plot_type == 'image':
            im = ax.imshow(D, interpolation='nearest', origin='lower',
                           extent=[times[0], times[-1], 0, D.shape[0]],
                           aspect='auto', cmap=cmap)
            if colorbar:
                cbar = plt.colorbar(im, ax=ax)
                cbar.ax.set_title(ch_unit)
        else:  # 'mean': mean absolute amplitude across channels
            ax.plot(times, np.abs(D).mean(axis=0))
        if xlim is not None:
            ax.set_xlim(xlim)
        if ylim is not None and t in ylim:
            if plot_type == 'image':
                im.set_clim(ylim[t])
            else:
                ax.set_ylim(ylim[t])
        ax.set_title(titles[t] + ' (%d channel%s)' % (
                     len(D), 's' if len(D) > 1 else ''))
        ax.set_xlabel('time (ms)')
        if plot_type == 'image':
            ax.set_ylabel('channels (%s)' % 'index')
        else:
            ax.set_ylabel('data (%s)' % ch_unit)
            if hline is not None:
                for h in hline:
                    ax.axhline(h, color='r', linestyle='--', linewidth=2)

    if axes_init is None:
        plt.subplots_adjust(0.175, 0.08, 0.94, 0.94, 0.2, 0.63)

    # backend names may be reported as 'agg' or 'Agg'
    if show and plt.get_backend().lower() != 'agg':
        plt.show()
        fig.canvas.draw()  # for axes plots update axes.
    tight_layout(fig=fig)

    return fig
# Example No. 13
# 0
def _plot_evoked(evoked, plot_type, colorbar=True, hline=None, ylim=None,
                 picks=None, exclude='bads', unit=True, show=True,
                 clim=None, proj=False, xlim='tight', units=None,
                 scalings=None, titles=None, axes=None, cmap='RdBu_r'):
    """Aux function for plot_evoked and plot_evoked_image (cf. docstrings).

    Extra param is:

    plot_type : str, value ('butterfly' | 'image' | 'mean')
        The type of graph to plot: 'butterfly' plots each channel as a line
        (x axis: time, y axis: amplitude). 'image' plots a 2D image where
        color depicts the amplitude of each channel at a given time point
        (x axis: time, y axis: channel). 'mean' plots, per channel type, the
        mean absolute amplitude across channels over time. In 'image' mode,
        the plot is not interactive.

    Notes
    -----
    ``clim`` is accepted for API compatibility but not used; for 'image' the
    color limits are taken from ``ylim``. Interactive SSP selection is not
    implemented here.

    Returns
    -------
    fig : instance of Figure
        The figure containing the plot.
    """
    # validate first, so that no figure is created for an invalid request
    if plot_type not in ('butterfly', 'image', 'mean'):
        raise ValueError("plot_type has to be 'butterfly', 'image' or "
                         "'mean'. Got %s." % plot_type)
    import matplotlib.pyplot as plt
    if axes is not None and proj == 'interactive':
        raise RuntimeError('Currently only single axis figures are supported'
                           ' for interactive SSP selection.')

    scalings = _handle_default('scalings', scalings)
    titles = _handle_default('titles', titles)
    units = _handle_default('units', units)

    # every channel type known to any of the dicts, in a consistent order
    channel_types = sorted(set(scalings) | set(titles) | set(units))

    if picks is None:
        picks = list(range(evoked.info['nchan']))

    bad_ch_idx = [evoked.ch_names.index(ch) for ch in evoked.info['bads']
                  if ch in evoked.ch_names]
    if len(exclude) > 0:
        if isinstance(exclude, string_types) and exclude == 'bads':
            exclude = bad_ch_idx
        elif (isinstance(exclude, list) and
              all(isinstance(ch, string_types) for ch in exclude)):
            exclude = [evoked.ch_names.index(ch) for ch in exclude]
        else:
            raise ValueError('exclude has to be a list of channel names or '
                             '"bads"')

        picks = list(set(picks).difference(exclude))

    types = [channel_type(evoked.info, idx) for idx in picks]
    ch_types_used = [t for t in channel_types if t in types]
    n_channel_types = len(ch_types_used)
    if n_channel_types == 0:
        raise ValueError('None of the picked channels has a channel type '
                         'that can be plotted.')

    axes_init = axes  # remember if axes where given as input

    fig = None
    if axes is None:
        fig, axes = plt.subplots(n_channel_types, 1)

    if isinstance(axes, plt.Axes):
        axes = [axes]
    elif isinstance(axes, np.ndarray):
        axes = list(axes)

    if axes_init is not None:
        fig = axes[0].get_figure()

    if not len(axes) == n_channel_types:
        raise ValueError('Number of axes (%g) must match number of channel '
                         'types (%g)' % (len(axes), n_channel_types))

    # instead of projecting during each iteration let's use the mixin here.
    if proj is True and evoked.proj is not True:
        evoked = evoked.copy()
        evoked.apply_proj()

    times = 1e3 * evoked.times  # time in miliseconds
    # resolve 'tight' once; comparing e.g. an ndarray to a string would be
    # elementwise (and ambiguous in an ``if``)
    if isinstance(xlim, string_types) and xlim == 'tight':
        xlim = (times[0], times[-1])

    for ax, t in zip(axes, ch_types_used):
        ch_unit = units[t]
        this_scaling = scalings[t]
        if unit is False:
            this_scaling = 1.0
            ch_unit = 'NA'  # no unit
        idx = [pick for pick, this_type in zip(picks, types) if this_type == t]
        if len(idx) == 0:
            continue
        # Set amplitude scaling
        D = this_scaling * evoked.data[idx, :]
        if plot_type == 'butterfly':
            lines = ax.plot(times, D.T)
            # good channels in black, bad channels in red; set on the lines
            # themselves because assigning the private
            # ``ax._get_lines.color_cycle`` is ignored by newer matplotlib
            for line, ch_idx in zip(lines, idx):
                line.set_color('r' if ch_idx in bad_ch_idx else 'k')
        elif plot_type == 'image':
            im = ax.imshow(D, interpolation='nearest', origin='lower',
                           extent=[times[0], times[-1], 0, D.shape[0]],
                           aspect='auto', cmap=cmap)
            if colorbar:
                cbar = plt.colorbar(im, ax=ax)
                cbar.ax.set_title(ch_unit)
        else:  # 'mean': mean absolute amplitude across channels
            ax.plot(times, np.abs(D).mean(axis=0))
        if xlim is not None:
            ax.set_xlim(xlim)
        if ylim is not None and t in ylim:
            if plot_type == 'image':
                im.set_clim(ylim[t])
            else:
                ax.set_ylim(ylim[t])
        ax.set_title(titles[t] + ' (%d channel%s)' % (
                     len(D), 's' if len(D) > 1 else ''))
        ax.set_xlabel('time (ms)')
        if plot_type == 'image':
            ax.set_ylabel('channels (%s)' % 'index')
        else:
            ax.set_ylabel('data (%s)' % ch_unit)
            if hline is not None:
                for h in hline:
                    ax.axhline(h, color='r', linestyle='--', linewidth=2)

    if axes_init is None:
        plt.subplots_adjust(0.175, 0.08, 0.94, 0.94, 0.2, 0.63)

    # backend names may be reported as 'agg' or 'Agg'
    if show and plt.get_backend().lower() != 'agg':
        plt.show()
        fig.canvas.draw()  # for axes plots update axes.
    tight_layout(fig=fig)

    return fig
# Example No. 14
# 0
    def plot(self,
             fmin=0,
             fmax=None,
             proj=False,
             picks=None,
             ax=None,
             color='black',
             xscale='linear',
             area_mode='std',
             area_alpha=0.33,
             dB=True,
             estimate='auto',
             show=True,
             n_jobs=1,
             average=False,
             line_alpha=None,
             spatial_colors=True,
             verbose=None,
             sphere=None):
        """Plot the spectra of this instance, one axes per channel type.

        The setup adapts to the installed MNE version: for MNE >= 0.22 the
        figure, picks, units, scalings and titles are prepared here; for
        older versions they come from ``mne.viz.utils._set_psd_plot_params``.
        Drawing is delegated to ``mne.viz.utils._plot_psd``.

        Parameters
        ----------
        fmin, fmax : float | None
            Frequency range. If ``fmin == 0`` and ``fmax is None`` the data
            are not cropped; ``fmax=None`` means the last of ``self.freqs``.
        proj : bool
            Only forwarded to ``_set_psd_plot_params`` (MNE < 0.22).
        picks :
            Channels to plot (resolved by ``_picks_to_idx`` on MNE >= 0.22,
            otherwise by ``_set_psd_plot_params``).
        ax : Axes | None
            Axes to draw into. With MNE >= 0.22 a new figure is created when
            None.
        color, xscale, area_mode, area_alpha, dB, estimate, average, \
line_alpha, spatial_colors :
            Forwarded positionally to ``_plot_psd`` (``area_mode`` also to
            ``_set_psd_plot_params``).
        show : bool
            Passed to ``plt_show``.
        n_jobs, verbose :
            Accepted but not used by this method.
        sphere :
            Forwarded to ``_plot_psd`` only for MNE >= 0.20.

        Returns
        -------
        fig : Figure
            The figure returned by ``_plot_psd``.
        """
        from mne.viz.utils import _plot_psd, plt_show

        # set up default vars
        from packaging import version
        mne_version = version.parse(mne.__version__)
        has_new_mne = mne_version >= version.parse('0.22.0')
        has_20_mne = (mne_version >= version.parse('0.20.0')
                      and mne_version < version.parse('0.22.0'))
        if has_new_mne:
            from mne.defaults import _handle_default
            from mne.io.pick import _picks_to_idx
            from mne.viz._figure import _split_picks_by_type

            if ax is None:
                import matplotlib.pyplot as plt
                fig, ax = plt.subplots()
            else:
                fig = ax.figure
            ax_list = [ax]

            units = _handle_default('units', None)
            picks = _picks_to_idx(self.info, picks)
            titles = _handle_default('titles', None)
            scalings = _handle_default('scalings', None)

            # only label the axes if this method owns every axes in the figure
            make_label = len(ax_list) == len(fig.axes)
            # x label only on the last axes
            xlabels_list = [False] * (len(ax_list) - 1) + [True]
            (picks_list, units_list, scalings_list,
             titles_list) = _split_picks_by_type(self, picks, units, scalings,
                                                 titles)
        elif has_20_mne:
            from mne.viz.utils import _set_psd_plot_params
            fig, picks_list, titles_list, units_list, scalings_list, \
                ax_list, make_label, xlabels_list = _set_psd_plot_params(
                    self.info, proj, picks, ax, area_mode)
        else:
            # NOTE: xlabels_list is not defined on this path; it is only
            # used below when has_20_mne or has_new_mne is True.
            from mne.viz.utils import _set_psd_plot_params
            fig, picks_list, titles_list, units_list, scalings_list, ax_list, \
                make_label = _set_psd_plot_params(self.info, proj, picks, ax,
                                                  area_mode)
        del ax  # use ax_list from here on

        crop_inst = not (fmin == 0 and fmax is None)
        fmax = self.freqs[-1] if fmax is None else fmax

        # work on a copy so that cropping never modifies self
        inst = self.copy()
        if crop_inst:
            inst.crop(fmin=fmin, fmax=fmax)
        # NOTE(review): the return value of average() is discarded; this
        # only has an effect if average() works in place on ``inst`` --
        # confirm against the class definition.
        inst.average()

        # create list of psd's (one element for each channel type)
        psd_list = list()
        for picks in picks_list:
            psd_list.append(inst.data[picks])

        # positional argument order must match _plot_psd's signature
        args = [
            inst, fig, inst.freqs, psd_list, picks_list, titles_list,
            units_list, scalings_list, ax_list, make_label, color, area_mode,
            area_alpha, dB, estimate, average, spatial_colors, xscale,
            line_alpha
        ]
        if has_20_mne or has_new_mne:
            args += [sphere, xlabels_list]

        fig = _plot_psd(*args)
        plt_show(show)
        return fig
# Example No. 15
# 0
# create epochs objects
# Split the data by confidence label (epochs start at tmin=-1).
# NOTE(review): X is presumably (n_epochs, n_channels, n_times) and y one
# label per epoch -- confirm against where they are built.
ep_low = EpochsArray(X[y == LOW_CONF_EPOCH, ...], info, tmin=-1)
ep_high = EpochsArray(X[y == HIGH_CONF_EPOCH, ...], info, tmin=-1)

# Alternative multitaper estimate, kept for reference:
# psds_high, freqs = psd_multitaper(ep_high, tmin=0, tmax=1, fmax=50)
# psds_low, freqs = psd_multitaper(ep_low, tmin=0, tmax=1, fmax=50)
# Welch PSDs over the window tmin=0.3 .. tmax=0.9, up to fmax=50
psds_high, freqs = psd_welch(ep_high, tmin=0.3, tmax=0.9, fmax=50)
psds_low, freqs = psd_welch(ep_low, tmin=0.3, tmax=0.9, fmax=50)

# normalize
# (optional per-epoch normalization by the mean over axis 2, disabled)
# psds_high /= psds_high.mean(axis=2, keepdims=True)
# psds_low /= psds_low.mean(axis=2, keepdims=True)

# relative difference of the epoch-averaged PSDs: (high - low) / low
psds = (psds_high.mean(axis=0) - psds_low.mean(axis=0)) / psds_low.mean(axis=0)
ch_type = _get_ch_type(ep_high, None)
units = _handle_default("units", None)
unit = units[ch_type]

# sensor positions / outline information for the topomap
# (note: ch_type is reassigned by this unpacking)
(
    picks,
    pos,
    merge_channels,
    names,
    ch_type,
    sphere,
    clip_origin,
) = _prepare_topomap_plot(ep_low, ch_type, sphere=None)
outlines = _make_head_outlines(sphere, pos, "head", clip_origin)

# NOTE(review): psds_merge is only defined when merge_channels is True;
# any later use must handle the other case.
if merge_channels:
    psds_merge, names = _merge_ch_data(psds, ch_type, names, method="mean")
# Example No. 16
# 0
def plot_topomap(data, pos, vmin=None, vmax=None, cmap=None, sensors=True,
                 res=64, axes=None, names=None, show_names=False, mask=None,
                 mask_params=None, outlines='head', image_mask=None,
                 contours=6, image_interp='bilinear', show=True,
                 head_pos=None, onselect=None, axis=None):
    """Plot a topographic map and return the intermediate plot objects.

    This is a modified copy of ``mne.viz.plot_topomap``; see that function's
    docstring for the full parameter descriptions. Unlike the original it
    returns the axes, image, contour set and sensor coordinates.

    Parameters
    ----------
    data : array-like, shape (n_sensors,)
        The values to plot. When ``pos`` is an Info object with planar
        gradiometer channels, the gradiometer pairs are merged first.
    pos : array, shape (n_sensors, 2) or (n_sensors, 4) | instance of Info
        Sensor positions, or an Info object (restricted to its data
        channels, which must all share one type) to derive them from.
    axes : matplotlib Axes | None
        The axes to draw into; the current axes are used when None.
    show_names : bool | callable
        If True, draw ``names`` at the sensor positions; a callable is
        applied to each name first. Requires ``names``.
    mask : array of bool, shape (n_sensors,) | None
        Sensors drawn with ``mask_params`` (and the only ones labeled when
        ``show_names`` is set).
    onselect : callable | None
        If given, a ``RectangleSelector`` calling it is stored as ``ax.RS``.
    axis : matplotlib Axes | None
        Deprecated alias of ``axes``.

    All other parameters are as in ``mne.viz.plot_topomap``.

    Returns
    -------
    ax : matplotlib Axes
        The axes the map was drawn into.
    im : matplotlib AxesImage
        The interpolated image.
    cont : matplotlib ContourSet
        The contour lines (made invisible when ``contours`` is False/None).
    pos_x, pos_y : array
        The sensor coordinates used for plotting.

    Raises
    ------
    ValueError
        If ``data`` is not 1-D, ``pos`` has an invalid shape, their lengths
        differ, the Info holds several channel types or a mismatching number
        of channels, or ``show_names`` is set without ``names``.
    """
    from matplotlib.widgets import RectangleSelector
    from mne.io.pick import (channel_type, pick_info, _pick_data_channels)
    from mne.utils import warn
    from mne.viz.utils import (_setup_vmin_vmax, plt_show)
    from mne.defaults import _handle_default
    from mne.channels.layout import _find_topomap_coords
    from mne.io.meas_info import Info
    from mne.viz.topomap import _check_outlines, _prepare_topomap, _griddata, _make_image_mask, _plot_sensors, \
        _draw_outlines

    data = np.asarray(data)

    if isinstance(pos, Info):  # infer pos from Info object
        picks = _pick_data_channels(pos)  # pick only data channels
        pos = pick_info(pos, picks)

        # check if there is only 1 channel type, and n_chans matches the data
        ch_type = set(channel_type(pos, idx)
                      for idx, _ in enumerate(pos["chs"]))
        info_help = ("Pick Info with e.g. mne.pick_info and "
                     "mne.channels.channel_indices_by_type.")
        if len(ch_type) > 1:
            raise ValueError("Multiple channel types in Info structure. " +
                             info_help)
        elif len(pos["chs"]) != data.shape[0]:
            raise ValueError("Number of channels in the Info object and "
                             "the data array does not match. " + info_help)
        else:
            ch_type = ch_type.pop()

        if any(type_ in ch_type for type_ in ('planar', 'grad')):
            # deal with grad pairs. Absolute import: this copy lives outside
            # the mne package, so the relative ``..channels`` form of the
            # upstream code cannot be resolved here.
            from mne.channels.layout import (_merge_grad_data, find_layout,
                                             _pair_grad_sensors)
            picks, pos = _pair_grad_sensors(pos, find_layout(pos))
            data = _merge_grad_data(data[picks]).reshape(-1)
        else:
            picks = list(range(data.shape[0]))
            pos = _find_topomap_coords(pos, picks=picks)

    if data.ndim > 1:
        raise ValueError("Data needs to be array of shape (n_sensors,); got "
                         "shape %s." % str(data.shape))

    # Give a helpful error message for common mistakes regarding the position
    # matrix.
    pos_help = ("Electrode positions should be specified as a 2D array with "
                "shape (n_channels, 2). Each row in this matrix contains the "
                "(x, y) position of an electrode.")
    if pos.ndim != 2:
        error = ("{ndim}D array supplied as electrode positions, where a 2D "
                 "array was expected").format(ndim=pos.ndim)
        raise ValueError(error + " " + pos_help)
    elif pos.shape[1] == 3:
        error = ("The supplied electrode positions matrix contains 3 columns. "
                 "Are you trying to specify XYZ coordinates? Perhaps the "
                 "mne.channels.create_eeg_layout function is useful for you.")
        raise ValueError(error + " " + pos_help)
    # No error is raised in case of pos.shape[1] == 4. In this case, it is
    # assumed the position matrix contains both (x, y) and (width, height)
    # values, such as Layout.pos.
    elif pos.shape[1] == 1 or pos.shape[1] > 4:
        raise ValueError(pos_help)

    if len(data) != len(pos):
        raise ValueError("Data and pos need to be of same length. Got data of "
                         "length %s, pos of length %s" % (len(data), len(pos)))

    # Non-negative data gets a sequential colormap, signed data a diverging one.
    norm = min(data) >= 0
    vmin, vmax = _setup_vmin_vmax(data, vmin, vmax, norm)
    if cmap is None:
        cmap = 'Reds' if norm else 'RdBu_r'

    pos, outlines = _check_outlines(pos, outlines, head_pos)

    if axis is not None:
        axes = axis
        warn('axis parameter is deprecated and will be removed in 0.13. '
             'Use axes instead.', DeprecationWarning)
    ax = axes if axes else plt.gca()
    pos_x, pos_y = _prepare_topomap(pos, ax)
    if outlines is None:
        xmin, xmax = pos_x.min(), pos_x.max()
        ymin, ymax = pos_y.min(), pos_y.max()
    else:
        # The image extent is taken from the outline mask positions.
        xlim = np.inf, -np.inf,
        ylim = np.inf, -np.inf,
        mask_ = np.c_[outlines['mask_pos']]
        xmin, xmax = (np.min(np.r_[xlim[0], mask_[:, 0]]),
                      np.max(np.r_[xlim[1], mask_[:, 0]]))
        ymin, ymax = (np.min(np.r_[ylim[0], mask_[:, 1]]),
                      np.max(np.r_[ylim[1], mask_[:, 1]]))

    # interpolate data
    xi = np.linspace(xmin, xmax, res)
    yi = np.linspace(ymin, ymax, res)
    Xi, Yi = np.meshgrid(xi, yi)
    Zi = _griddata(pos_x, pos_y, data, Xi, Yi)

    # Only a head-style outline dict gets the elliptical image mask/clip;
    # this is always defined (False for None or any non-dict outlines).
    _is_default_outlines = (isinstance(outlines, dict) and
                            any(k.startswith('head') for k in outlines))

    if _is_default_outlines and image_mask is None:
        # prepare masking
        image_mask, pos = _make_image_mask(outlines, pos, res)

    mask_params = _handle_default('mask_params', mask_params)

    # plot outline
    linewidth = mask_params['markeredgewidth']
    patch = None
    # Guard against outlines=None, which the branches above also allow.
    if outlines is not None and 'patch' in outlines:
        patch = outlines['patch']
        patch_ = patch() if callable(patch) else patch
        patch_.set_clip_on(False)
        ax.add_patch(patch_)
        ax.set_transform(ax.transAxes)
        ax.set_clip_path(patch_)

    # plot map and countour
    im = ax.imshow(Zi, cmap=cmap, vmin=vmin, vmax=vmax, origin='lower',
                   aspect='equal', extent=(xmin, xmax, ymin, ymax),
                   interpolation=image_interp)

    # This tackles an incomprehensible matplotlib bug if no contours are
    # drawn. To avoid rescalings, we will always draw contours.
    # But if no contours are desired we only draw one and make it invisible .
    no_contours = False
    if contours in (False, None):
        contours, no_contours = 1, True
    cont = ax.contour(Xi, Yi, Zi, contours, colors='k',
                      linewidths=linewidth)
    if no_contours is True:
        for col in cont.collections:
            col.set_visible(False)

    if _is_default_outlines:
        from matplotlib import patches
        patch_ = patches.Ellipse((0, 0),
                                 2 * outlines['clip_radius'][0],
                                 2 * outlines['clip_radius'][1],
                                 clip_on=True,
                                 transform=ax.transData)
    if _is_default_outlines or patch is not None:
        im.set_clip_path(patch_)
        if cont is not None:
            for col in cont.collections:
                col.set_clip_path(patch_)

    if sensors is not False and mask is None:
        _plot_sensors(pos_x, pos_y, sensors=sensors, ax=ax)
    elif sensors and mask is not None:
        idx = np.where(mask)[0]
        ax.plot(pos_x[idx], pos_y[idx], **mask_params)
        idx = np.where(~mask)[0]
        _plot_sensors(pos_x[idx], pos_y[idx], sensors=sensors, ax=ax)
    elif not sensors and mask is not None:
        idx = np.where(mask)[0]
        ax.plot(pos_x[idx], pos_y[idx], **mask_params)

    if isinstance(outlines, dict):
        _draw_outlines(ax, outlines)

    if show_names:
        if names is None:
            raise ValueError("To show names, a list of names must be provided"
                             " (see `names` keyword).")
        if show_names is True:
            def _show_names(x):
                return x
        else:
            _show_names = show_names
        # Set of indices to label: all sensors, or only the masked ones.
        if mask is None:
            show_idx = set(range(len(names)))
        else:
            show_idx = set(np.where(mask)[0].tolist())
        for ii, (p, ch_id) in enumerate(zip(pos, names)):
            if ii not in show_idx:
                continue
            ch_id = _show_names(ch_id)
            ax.text(p[0], p[1], ch_id, horizontalalignment='center',
                    verticalalignment='center', size='x-small')

    plt.subplots_adjust(top=.95)

    if onselect is not None:
        ax.RS = RectangleSelector(ax, onselect=onselect)
    plt_show(show)
    return ax, im, cont, pos_x, pos_y
Ejemplo n.º 17
0
def _prepare_mne_browse_epochs(params, projs, n_channels, n_epochs, scalings,
                               title, picks, order=None):
    """Set up the figure, axes and callbacks of the mne_browse_epochs window.

    Parameters
    ----------
    params : dict
        Shared browser state. Must already contain ``'epochs'``, ``'info'``,
        ``'bad_color'``, ``'bads'`` (indices of bad epochs) and ``'fix_log'``
        (array of shape (n_epochs, n_channels) or None). It is updated in
        place with the created figure, axes, artists and plotting state.
    projs : list
        Projectors, stored as ``params['projs']``.
    n_channels : int
        Number of channels per view (capped at the number of picks).
    n_epochs : int
        Number of epochs per view (capped at the number of events).
    scalings : dict
        Scaling factors, stored as ``params['scalings']``.
    title : str | None
        Window title; ``epochs.name`` (or ``''``) is used when None.
    picks : array-like of int | None
        Channels to show; ``_handle_picks(epochs)`` is used when None.
    order : list of str | None
        Display order of the non-MEG channel types (gradiometers and
        magnetometers always come first). A default order is used when None.

    Raises
    ------
    RuntimeError
        If no channels are picked, or some picks match none of the types.
    ValueError
        If ``params['fix_log']`` is given with the wrong shape.
    """
    import matplotlib.pyplot as plt
    import matplotlib as mpl
    from matplotlib.collections import LineCollection
    from matplotlib.colors import colorConverter
    epochs = params['epochs']

    if picks is None:
        picks = _handle_picks(epochs)
    if len(picks) < 1:
        raise RuntimeError('No appropriate channels found. Please'
                           ' check your picks')
    picks = sorted(picks)
    # Reorganize channels: MEG (grad, mag) first, then the types in ``order``,
    # so that channels of the same type are contiguous on screen.
    inds = list()
    types = list()
    for t in ['grad', 'mag']:
        idxs = pick_types(params['info'], meg=t, ref_meg=False, exclude=[])
        if len(idxs) < 1:
            continue
        mask = np.in1d(idxs, picks, assume_unique=True)
        inds.append(idxs[mask])
        types += [t] * len(inds[-1])
    pick_kwargs = dict(meg=False, ref_meg=False, exclude=[])
    if order is None:
        order = ['eeg', 'seeg', 'ecog', 'eog', 'ecg', 'emg', 'ref_meg', 'stim',
                 'resp', 'misc', 'chpi', 'syst', 'ias', 'exci']
    for ch_type in order:
        # Enable one type at a time so each pick_types call returns only it.
        pick_kwargs[ch_type] = True
        idxs = pick_types(params['info'], **pick_kwargs)
        if len(idxs) < 1:
            continue
        mask = np.in1d(idxs, picks, assume_unique=True)
        inds.append(idxs[mask])
        types += [ch_type] * len(inds[-1])
        pick_kwargs[ch_type] = False
    inds = np.concatenate(inds).astype(int)
    if not len(inds) == len(picks):
        raise RuntimeError('Some channels not classified. Please'
                           ' check your picks')
    ch_names = [params['info']['ch_names'][x] for x in inds]

    # set up plotting
    size = get_config('MNE_BROWSE_RAW_SIZE')
    n_epochs = min(n_epochs, len(epochs.events))
    duration = len(epochs.times) * n_epochs
    n_channels = min(n_channels, len(picks))
    if size is not None:
        size = size.split(',')
        size = tuple(float(s) for s in size)
    if title is None:
        title = epochs.name
        if epochs.name is None or len(title) == 0:
            title = ''
    fig = figure_nobar(facecolor='w', figsize=size, dpi=80)
    # FigureCanvasBase.set_window_title was removed in matplotlib 3.6; it only
    # forwarded to the figure manager (and did nothing without one), so call
    # the manager directly, which works on old and new versions alike.
    manager = getattr(fig.canvas, 'manager', None)
    if manager is not None:
        manager.set_window_title('mne_browse_epochs')
    ax = plt.subplot2grid((10, 15), (0, 1), colspan=13, rowspan=9)

    ax.annotate(title, xy=(0.5, 1), xytext=(0, ax.get_ylim()[1] + 15),
                ha='center', va='bottom', size=12, xycoords='axes fraction',
                textcoords='offset points')
    color = _handle_default('color', None)

    ax.axis([0, duration, 0, 200])
    ax2 = ax.twiny()
    ax2.set_zorder(-1)
    ax2.axis([0, duration, 0, 200])
    ax_hscroll = plt.subplot2grid((10, 15), (9, 1), colspan=13)
    ax_hscroll.get_yaxis().set_visible(False)
    ax_hscroll.set_xlabel('Epochs')
    ax_vscroll = plt.subplot2grid((10, 15), (0, 14), rowspan=9)
    ax_vscroll.set_axis_off()
    ax_vscroll.add_patch(mpl.patches.Rectangle((0, 0), 1, len(picks),
                                               facecolor='w', zorder=3))

    ax_help_button = plt.subplot2grid((10, 15), (9, 0), colspan=1)
    help_button = mpl.widgets.Button(ax_help_button, 'Help')
    help_button.on_clicked(partial(_onclick_help, params=params))

    # populate vertical and horizontal scrollbars
    for ci in range(len(picks)):
        if ch_names[ci] in params['info']['bads']:
            this_color = params['bad_color']
        else:
            this_color = color[types[ci]]
        ax_vscroll.add_patch(mpl.patches.Rectangle((0, ci), 1, 1,
                                                   facecolor=this_color,
                                                   edgecolor=this_color,
                                                   zorder=4))

    vsel_patch = mpl.patches.Rectangle((0, 0), 1, n_channels, alpha=0.5,
                                       edgecolor='w', facecolor='w', zorder=5)
    ax_vscroll.add_patch(vsel_patch)

    ax_vscroll.set_ylim(len(types), 0)
    ax_vscroll.set_title('Ch.')

    # populate colors list: one RGBA color per (channel, epoch)
    type_colors = [colorConverter.to_rgba(color[c]) for c in types]
    colors = list()
    for color_idx in range(len(type_colors)):
        colors.append([type_colors[color_idx]] * len(epochs.events))
    lines = list()
    n_times = len(epochs.times)

    for ch_idx in range(n_channels):
        if len(colors) - 1 < ch_idx:
            break
        lc = LineCollection(list(), antialiased=False, linewidths=0.5,
                            zorder=3, picker=3.)
        ax.add_collection(lc)
        lines.append(lc)

    times = epochs.times
    data = np.zeros((params['info']['nchan'], len(times) * n_epochs))

    ylim = (25., 0.)  # Hardcoded 25 because butterfly has max 5 rows (5*5=25).
    # make shells for plotting traces
    offset = ylim[0] / n_channels
    offsets = np.arange(n_channels) * offset + (offset / 2.)

    times = np.arange(len(times) * len(epochs.events))
    epoch_times = np.arange(0, len(times), n_times)

    ax.set_yticks(offsets)
    ax.set_ylim(ylim)
    ticks = epoch_times + 0.5 * n_times
    ax.set_xticks(ticks)
    ax2.set_xticks(ticks[:n_epochs])
    labels = list(range(1, len(ticks) + 1))  # epoch numbers
    ax.set_xticklabels(labels)
    ax2.set_xticklabels(labels)
    xlim = epoch_times[-1] + len(epochs.times)
    ax_hscroll.set_xlim(0, xlim)
    vertline_t = ax_hscroll.text(0, 1, '', color='y', va='bottom', ha='right')

    # fit horizontal scroll bar ticks
    hscroll_ticks = np.arange(0, xlim, xlim / 7.0)
    hscroll_ticks = np.append(hscroll_ticks, epoch_times[-1])
    hticks = list()
    for tick in hscroll_ticks:
        hticks.append(epoch_times.flat[np.abs(epoch_times - tick).argmin()])
    # hticks are epoch start samples (multiples of n_times); integer division
    # yields integer epoch numbers instead of labels such as "1.0".
    hlabels = [x // n_times + 1 for x in hticks]
    ax_hscroll.set_xticks(hticks)
    ax_hscroll.set_xticklabels(hlabels)

    for epoch_idx in range(len(epoch_times)):
        ax_hscroll.add_patch(mpl.patches.Rectangle((epoch_idx * n_times, 0),
                                                   n_times, 1,
                                                   facecolor=(0.8, 0.8, 0.8),
                                                   edgecolor=(0.8, 0.8, 0.8),
                                                   alpha=0.5))
    hsel_patch = mpl.patches.Rectangle((0, 0), duration, 1,
                                       edgecolor='k',
                                       facecolor=(0.5, 0.5, 0.5),
                                       alpha=0.25, linewidth=1, clip_on=False)
    ax_hscroll.add_patch(hsel_patch)
    text = ax.text(0, 0, 'blank', zorder=3, verticalalignment='baseline',
                   ha='left', fontweight='bold')
    text.set_visible(False)

    params.update({'fig': fig,
                   'ax': ax,
                   'ax2': ax2,
                   'ax_hscroll': ax_hscroll,
                   'ax_vscroll': ax_vscroll,
                   'vsel_patch': vsel_patch,
                   'hsel_patch': hsel_patch,
                   'lines': lines,
                   'projs': projs,
                   'ch_names': ch_names,
                   'n_channels': n_channels,
                   'n_epochs': n_epochs,
                   'scalings': scalings,
                   'duration': duration,
                   'ch_start': 0,
                   'colors': colors,
                   'def_colors': type_colors,  # don't change at runtime
                   'picks': picks,
                   'data': data,
                   'times': times,
                   'epoch_times': epoch_times,
                   'offsets': offsets,
                   'labels': labels,
                   'scale_factor': 1.0,
                   'butterfly_scale': 1.0,
                   'fig_proj': None,
                   'types': np.array(types),
                   'inds': inds,
                   'vert_lines': list(),
                   'vertline_t': vertline_t,
                   'butterfly': False,
                   'text': text,
                   'ax_help_button': ax_help_button,  # needed for positioning
                   'help_button': help_button,  # reference needed for clicks
                   'fig_options': None,
                   'settings': [True, True, True, True],
                   'image_plot': None})

    params['plot_fun'] = partial(_plot_traces, params=params)

    # callbacks
    callback_scroll = partial(_plot_onscroll, params=params)
    fig.canvas.mpl_connect('scroll_event', callback_scroll)
    callback_click = partial(_mouse_click, params=params)
    fig.canvas.mpl_connect('button_press_event', callback_click)
    callback_key = partial(_plot_onkey, params=params)
    fig.canvas.mpl_connect('key_press_event', callback_key)
    callback_resize = partial(_resize_event, params=params)
    fig.canvas.mpl_connect('resize_event', callback_resize)
    fig.canvas.mpl_connect('pick_event', partial(_onpick, params=params))
    params['callback_key'] = callback_key

    # Draw event lines for the first time.
    _plot_vert_lines(params)

    # Plot bad epochs: red scrollbar patch and red traces on every channel.
    for epoch_idx in params['bads']:
        params['ax_hscroll'].patches[epoch_idx].set_color((1., 0., 0., 1.))
        params['ax_hscroll'].patches[epoch_idx].set_zorder(3)
        params['ax_hscroll'].patches[epoch_idx].set_edgecolor('w')
        for ch_idx in range(len(params['ch_names'])):
            params['colors'][ch_idx][epoch_idx] = (1., 0., 0., 1.)

    # Plot bad segments. fix_log is optional (plot_epochs defaults it to
    # None), so its shape is only validated when it was actually given.
    fix_log = params['fix_log']
    if fix_log is not None:
        n_all_epochs = len(epochs.events)
        n_chans = len(params['ch_names'])
        expected_shape = (n_all_epochs, n_chans)
        if np.shape(fix_log) != expected_shape:
            raise ValueError('fix_log must have shape (n_epochs, n_channels) '
                             '= %s, got %s'
                             % (expected_shape, np.shape(fix_log)))
        bad_epochs = set(np.asarray(params['bads']).tolist())
        for ch_idx in range(n_chans):
            for epoch_idx in range(n_all_epochs):
                if epoch_idx in bad_epochs:
                    continue  # whole epoch is already drawn in red
                this_log = fix_log[epoch_idx, ch_idx]
                if this_log == 1:
                    params['colors'][ch_idx][epoch_idx] = (1., 0., 0., 1.)
                elif this_log == 2:
                    params['colors'][ch_idx][epoch_idx] = (0., 0., 1., 1.)

    params['plot_fun']()
Ejemplo n.º 18
0
def plot_epochs(epochs, picks=None, scalings=None, n_epochs=20,
                n_channels=20, title=None, show=True, block=False,
                bad_epochs_idx=None, fix_log=None):
    """ Visualize epochs

    Bad epochs can be marked with a left click on top of the epoch. Bad
    channels can be selected by clicking the channel name on the left side of
    the main axes. Calling this function drops all the selected bad epochs as
    well as bad epochs marked beforehand with rejection parameters.

    Parameters
    ----------
    epochs : instance of Epochs
        The epochs object
    picks : array-like of int | None
        Channels to be included. If None only good data channels are used.
        Defaults to None
    scalings : dict | 'auto' | None
        Scaling factors for the traces. If any fields in scalings are 'auto',
        the scaling factor is set to match the 99.5th percentile of a subset of
        the corresponding data. If scalings == 'auto', all scalings fields are
        set to 'auto'. If any fields are 'auto' and data is not preloaded,
        a subset of epochs up to 100mb will be loaded. If None, defaults to::

            dict(mag=1e-12, grad=4e-11, eeg=20e-6, eog=150e-6, ecg=5e-4,
                 emg=1e-3, ref_meg=1e-12, misc=1e-3, stim=1, resp=1, chpi=1e-4)

    n_epochs : int
        The number of epochs per view. Defaults to 20.
    n_channels : int
        The number of channels per view. Defaults to 20.
    title : str | None
        The title of the window. If None, epochs name will be displayed.
        Defaults to None.
    show : bool
        Show figure if True. Defaults to True
    block : bool
        Whether to halt program execution until the figure is closed.
        Useful for rejecting bad trials on the fly by clicking on an epoch.
        Defaults to False.
    bad_epochs_idx : array-like of int | None
        Indices of bad epochs to show; they are drawn in red on all channels
        and in the epoch scrollbar. No bad epochs to visualize if None.
    fix_log : array-like, shape (n_epochs, n_channels) | None
        Per-epoch, per-channel codes of the segments to highlight: 1 marks a
        bad segment (drawn in red) and 2 an interpolated segment (drawn in
        blue); any other value keeps the channel-type color. Rows follow
        ``epochs.events`` after ``epochs.drop_bad()``; columns follow the
        plotted channels in display order (channels grouped by type).
        Entries of epochs listed in ``bad_epochs_idx`` are ignored because
        those epochs are drawn entirely in red. Nothing is highlighted if
        None.

    Returns
    -------
    fig : Instance of matplotlib.figure.Figure
        The figure.

    Raises
    ------
    ValueError
        If ``fix_log`` does not have shape (n_epochs, n_channels).

    Notes
    -----
    The arrow keys (up/down/left/right) can be used to navigate between
    channels and epochs and the scaling can be adjusted with - and + (or =)
    keys, but this depends on the backend matplotlib is configured to use
    (e.g., mpl.use(``TkAgg``) should work). Full screen mode can be toggled
    with f11 key. The amount of epochs and channels per view can be adjusted
    with home/end and page down/page up keys. Butterfly plot can be toggled
    with ``b`` key. Right mouse click adds a vertical line to the plot.
    """
    epochs.drop_bad()
    scalings = _compute_scalings(scalings, epochs)
    scalings = _handle_default('scalings_plot_raw', scalings)

    projs = epochs.info['projs']

    bads = np.array(list(), dtype=int)
    if bad_epochs_idx is not None:
        bads = np.array(bad_epochs_idx).astype(int)

    # The browser reads fix_log.shape and indexes it as fix_log[epoch, ch],
    # so accept any array-like (e.g. nested lists) by converting it here.
    if fix_log is not None:
        fix_log = np.asarray(fix_log)

    params = {'epochs': epochs,
              'info': copy.deepcopy(epochs.info),
              'bad_color': (0.8, 0.8, 0.8),
              't_start': 0,
              'histogram': None,
              'bads': bads,
              'fix_log': fix_log}
    params['label_click_fun'] = partial(_pick_bad_channels, params=params)
    _prepare_mne_browse_epochs(params, projs, n_channels, n_epochs, scalings,
                               title, picks)
    _prepare_projectors(params)
    _layout_figure(params)

    callback_close = partial(_close_event, params=params)
    params['fig'].canvas.mpl_connect('close_event', callback_close)
    try:
        plt_show(show, block=block)
    except TypeError:  # not all versions have this
        plt_show(show)

    return params['fig']