예제 #1
0
def plot_ica_overlay_evoked(evoked, evoked_cln, title, show):
    """Plot evoked responses before and after ICA cleaning, overlaid.

    Workaround for https://github.com/mne-tools/mne-python/issues/1819,
    copied from ``mne.viz.ica._plot_ica_overlay_evoked()``.

    The uncleaned ``evoked`` is drawn in red and ``evoked_cln`` in black, on
    one axis per channel type ('mag', 'grad', 'eeg') present in the data.

    Parameters
    ----------
    evoked : instance of mne.Evoked
        The evoked response before ICA cleaning.
    evoked_cln : instance of mne.Evoked
        The evoked response after ICA cleaning. Must contain the same
        channel types as ``evoked``.
    title : str | None
        Figure title. If None (or empty), a default title describing the
        red/black color coding is used.
    show : bool
        If True, all open plots will be shown.

    Returns
    -------
    fig : instance of pyplot.Figure
        The figure containing the overlay.

    Raises
    ------
    ValueError
        If ``evoked`` and ``evoked_cln`` do not contain the same channel
        types.
    """
    candidate_types = ['mag', 'grad', 'eeg']
    ch_types_used = [c for c in candidate_types if c in evoked]
    ch_types_used_cln = [c for c in candidate_types if c in evoked_cln]

    # Compare the actual types, not only their count: ['mag', 'eeg'] vs
    # ['grad', 'eeg'] would otherwise overlay unrelated sensor types.
    if ch_types_used != ch_types_used_cln:
        raise ValueError('Raw and clean evokeds must match. '
                         'Found different channels.')

    fig, axes = plt.subplots(len(ch_types_used), 1)
    if not title:
        title = 'Average signal before (red) and after (black) ICA'
    fig.suptitle(title)
    axes = axes.flatten() if isinstance(axes, np.ndarray) else axes

    evoked.plot(axes=axes, show=False)

    # Recolor everything drawn so far (the uncleaned data) in red so it can
    # be told apart from the cleaned data plotted next.
    for ax in fig.axes:
        for line in ax.get_lines():
            line.set_color('r')

    fig.canvas.draw()
    # Do not show yet: the layout below must be applied first, and showing
    # here as well as at the end would display the figure twice.
    evoked_cln.plot(axes=axes, show=False)
    tight_layout(fig=fig)

    # Leave room at the top for the suptitle.
    fig.subplots_adjust(top=0.90)
    fig.canvas.draw()

    if show:
        plt.show()

    return fig
예제 #2
0
파일: viz.py 프로젝트: Qi0116/deepthought
def plot_ica_overlay_evoked(evoked, evoked_cln, title, show):
    """Plot evoked responses before and after ICA cleaning, overlaid.

    Workaround for https://github.com/mne-tools/mne-python/issues/1819,
    copied from ``mne.viz.ica._plot_ica_overlay_evoked()``.

    The uncleaned ``evoked`` is drawn in red and ``evoked_cln`` in black, on
    one axis per channel type ('mag', 'grad', 'eeg') present in the data.

    Parameters
    ----------
    evoked : instance of mne.Evoked
        The evoked response before ICA cleaning.
    evoked_cln : instance of mne.Evoked
        The evoked response after ICA cleaning. Must contain the same
        channel types as ``evoked``.
    title : str | None
        Figure title. If None (or empty), a default title describing the
        red/black color coding is used.
    show : bool
        If True, all open plots will be shown.

    Returns
    -------
    fig : instance of pyplot.Figure
        The figure containing the overlay.

    Raises
    ------
    ValueError
        If ``evoked`` and ``evoked_cln`` do not contain the same channel
        types.
    """
    candidate_types = ['mag', 'grad', 'eeg']
    ch_types_used = [c for c in candidate_types if c in evoked]
    ch_types_used_cln = [c for c in candidate_types if c in evoked_cln]

    # Compare the actual types, not only their count: ['mag', 'eeg'] vs
    # ['grad', 'eeg'] would otherwise overlay unrelated sensor types.
    if ch_types_used != ch_types_used_cln:
        raise ValueError('Raw and clean evokeds must match. '
                         'Found different channels.')

    fig, axes = plt.subplots(len(ch_types_used), 1)
    if not title:
        title = 'Average signal before (red) and after (black) ICA'
    fig.suptitle(title)
    axes = axes.flatten() if isinstance(axes, np.ndarray) else axes

    evoked.plot(axes=axes, show=False)

    # Recolor everything drawn so far (the uncleaned data) in red so it can
    # be told apart from the cleaned data plotted next.
    for ax in fig.axes:
        for line in ax.get_lines():
            line.set_color('r')

    fig.canvas.draw()
    # Do not show yet: the layout below must be applied first, and showing
    # here as well as at the end would display the figure twice.
    evoked_cln.plot(axes=axes, show=False)
    tight_layout(fig=fig)

    # Leave room at the top for the suptitle.
    fig.subplots_adjust(top=0.90)
    fig.canvas.draw()

    if show:
        plt.show()

    return fig
예제 #3
0
def _plot_evoked(evoked, plot_type, colorbar=True, hline=None, ylim=None,
                picks=None, exclude='bads', unit=True, show=True,
                      clim=None, proj=False, xlim='tight', units=None,
                      scalings=None, titles=None, axes=None, cmap='RdBu_r'):
    """Aux function for plot_evoked and plot_evoked_image (cf. docstrings)

    Extra param is:

    plot_type : str, value ('butterfly' | 'image')
        The type of graph to plot: 'butterfly' plots each channel as a line
        (x axis: time, y axis: amplitude). 'image' plots a 2D image where
        color depicts the amplitude of each channel at a given time point
        (x axis: time, y axis: channel). In 'image' mode, the plot is not
        interactive.
    """
    import matplotlib.pyplot as plt
    if axes is not None and proj == 'interactive':
        raise RuntimeError('Currently only single axis figures are supported'
                           ' for interactive SSP selection.')

    scalings = _handle_default('scalings', scalings)
    titles = _handle_default('titles', titles)
    units = _handle_default('units', units)

    channel_types = set(key for d in [scalings, titles, units] for key in d)
    channel_types = sorted(channel_types)  # to guarantee consistent order

    if picks is None:
        picks = list(range(evoked.info['nchan']))

    bad_ch_idx = [evoked.ch_names.index(ch) for ch in evoked.info['bads']
                  if ch in evoked.ch_names]
    if len(exclude) > 0:
        if isinstance(exclude, string_types) and exclude == 'bads':
            exclude = bad_ch_idx
        elif (isinstance(exclude, list)
              and all([isinstance(ch, string_types) for ch in exclude])):
            exclude = [evoked.ch_names.index(ch) for ch in exclude]
        else:
            raise ValueError('exclude has to be a list of channel names or '
                             '"bads"')

        picks = list(set(picks).difference(exclude))

    types = [channel_type(evoked.info, idx) for idx in picks]
    n_channel_types = 0
    ch_types_used = []
    for t in channel_types:
        if t in types:
            n_channel_types += 1
            ch_types_used.append(t)

    axes_init = axes  # remember if axes where given as input

    fig = None
    if axes is None:
        fig, axes = plt.subplots(n_channel_types, 1)

    if isinstance(axes, plt.Axes):
        axes = [axes]
    elif isinstance(axes, np.ndarray):
        axes = list(axes)

    if axes_init is not None:
        fig = axes[0].get_figure()

    if not len(axes) == n_channel_types:
        raise ValueError('Number of axes (%g) must match number of channel '
                         'types (%g)' % (len(axes), n_channel_types))

    # instead of projecting during each iteration let's use the mixin here.
    if proj is True and evoked.proj is not True:
        evoked = evoked.copy()
        evoked.apply_proj()

    times = 1e3 * evoked.times  # time in miliseconds
    for ax, t in zip(axes, ch_types_used):
        ch_unit = units[t]
        this_scaling = scalings[t]
        if unit is False:
            this_scaling = 1.0
            ch_unit = 'NA'  # no unit
        idx = [picks[i] for i in range(len(picks)) if types[i] == t]
        if len(idx) > 0:
            # Parameters for butterfly interactive plots
            if plot_type == 'butterfly':
                if any([i in bad_ch_idx for i in idx]):
                    colors = ['k'] * len(idx)
                    for i in bad_ch_idx:
                        if i in idx:
                            colors[idx.index(i)] = 'r'

                    ax._get_lines.color_cycle = iter(colors)
                else:
                    ax._get_lines.color_cycle = cycle(['k'])
            # Set amplitude scaling
            D = this_scaling * evoked.data[idx, :]
            # plt.axes(ax)
            if plot_type == 'butterfly':
                ax.plot(times, D.T)
            elif plot_type == 'image':
                im = ax.imshow(D, interpolation='nearest', origin='lower',
                               extent=[times[0], times[-1], 0, D.shape[0]],
                               aspect='auto', cmap=cmap)
                if colorbar:
                    cbar = plt.colorbar(im, ax=ax)
                    cbar.ax.set_title(ch_unit)
            elif plot_type == 'mean' :
#                 ax.plot(times, D.mean(axis=0))
                ax.plot(times, np.abs(D).mean(axis=0))
            if xlim is not None:
                if xlim == 'tight':
                    xlim = (times[0], times[-1])
                ax.set_xlim(xlim)
            if ylim is not None and t in ylim:
                if plot_type == 'butterfly' or plot_type == 'mean':
                    ax.set_ylim(ylim[t])
                elif plot_type == 'image':
                    im.set_clim(ylim[t])
            ax.set_title(titles[t] + ' (%d channel%s)' % (
                         len(D), 's' if len(D) > 1 else ''))
            ax.set_xlabel('time (ms)')
            if plot_type == 'butterfly' or plot_type == 'mean':
                ax.set_ylabel('data (%s)' % ch_unit)
            elif plot_type == 'image':
                ax.set_ylabel('channels (%s)' % 'index')
            else:
                raise ValueError("plot_type has to be 'butterfly' or 'image'."
                                 "Got %s." % plot_type)

            if (plot_type == 'butterfly' or plot_type == 'mean') and (hline is not None):
                for h in hline:
                    ax.axhline(h, color='r', linestyle='--', linewidth=2)

    if axes_init is None:
        plt.subplots_adjust(0.175, 0.08, 0.94, 0.94, 0.2, 0.63)

    # if proj == 'interactive':
    #     _check_delayed_ssp(evoked)
    #     params = dict(evoked=evoked, fig=fig, projs=evoked.info['projs'],
    #                   axes=axes, types=types, units=units, scalings=scalings,
    #                   unit=unit, ch_types_used=ch_types_used, picks=picks,
    #                   plot_update_proj_callback=_plot_update_evoked,
    #                   plot_type=plot_type)
    #     _draw_proj_checkbox(None, params)

    if show and plt.get_backend() != 'agg':
        plt.show()
        fig.canvas.draw()  # for axes plots update axes.
    tight_layout(fig=fig)

    return fig
예제 #4
0
def _chpi_plot_panel(ax, tvec, data, transform, ylabel, title,
                     label_fontsize, title_fontsize, tick_fontsize):
    """Plot one trace per row of ``data`` plus a dotted black median trace.

    ``transform`` is applied elementwise to the rows and to their median;
    the median is taken before transforming (e.g. before conversion to dB).
    Returns the list of per-row lines and the median line.
    """
    lines = ax.plot(tvec, transform(data).T)
    med_line = ax.plot(tvec, transform(np.median(data, axis=0)),
                       lw=2, ls=':', color='k')[0]
    ax.set_xlim([tvec.min(), tvec.max()])
    ax.set_ylabel(ylabel, fontsize=label_fontsize)
    ax.set_title(title, fontsize=title_fontsize)
    ax.tick_params(axis='both', which='major', labelsize=tick_fontsize)
    return lines, med_line


def _chpi_shrink_axis(ax, factor=0.8):
    """Shrink ``ax`` horizontally to leave room for a legend on its right."""
    box = ax.get_position()
    ax.set_position([box.x0, box.y0, box.width * factor, box.height])


def _chpi_sorted_legend(ax, lines, med_line, data, labels, leg_kwargs):
    """Add a legend ordered by descending row mean of ``data``, then median."""
    order = np.argsort(data.mean(axis=1))[::-1]
    handles = [lines[i] for i in order] + [med_line]
    entries = [labels[i] for i in order] + ['Median']
    ax.legend(handles, entries, **leg_kwargs)


def plot_chpi_snr_raw(raw,
                      win_length,
                      n_harmonics=None,
                      show=True,
                      verbose=True):
    """Compute and plot cHPI SNR from raw data

    Parameters
    ----------
    raw : instance of mne.io.Raw
        Raw data containing MEG channels and cHPI information.
    win_length : float
        Length of window to use for SNR estimates (seconds). A longer window
        will naturally include more low frequency power, resulting in lower
        SNR.
    n_harmonics : int or None
        Number of line frequency harmonics to include in the model in
        addition to the fundamental line frequency. If None, use all
        harmonics below the MEG analog lowpass corner.
    show : bool
        Show figure if True.
    verbose : bool
        If True, print the frequencies and buffer length used, and progress
        for each buffer.

    Returns
    -------
    fig : instance of matplotlib.figure.Figure
        cHPI SNR as function of time, residual variance.

    Raises
    ------
    ValueError
        If ``win_length`` corresponds to less than one sample or is longer
        than the data.

    Notes
    -----
    A general linear model including cHPI and line frequencies is fit into
    each data window. The cHPI power obtained from the model is then divided
    by the residual variance (variance of signal unexplained by the model) to
    obtain the SNR.

    The SNR may decrease either due to decrease of cHPI amplitudes (e.g.
    head moving away from the helmet), or due to increase in the residual
    variance. In case of broadband interference that overlaps with the cHPI
    frequencies, the resulting decreased SNR accurately reflects the true
    situation. However, increased narrowband interference outside the cHPI
    and line frequencies would also cause an increase in the residual variance,
    even though it wouldn't necessarily affect estimation of the cHPI
    amplitudes. Thus, this method is intended for a rough overview of cHPI
    signal quality. A more accurate picture of cHPI quality (at an increased
    computational cost) can be obtained by examining the goodness-of-fit of
    the cHPI coil fits.
    """
    import matplotlib.pyplot as plt
    from mne.chpi import _get_hpi_info

    # plotting parameters
    legend_fontsize = 6
    label_fontsize = 10
    panel_fontsizes = dict(label_fontsize=label_fontsize,
                           title_fontsize=10,
                           tick_fontsize=10)

    # get some info from fiff
    sfreq = raw.info['sfreq']
    linefreq = raw.info['line_freq']
    if n_harmonics is not None:
        # fundamental plus n_harmonics harmonics
        linefreqs = (np.arange(n_harmonics + 1) + 1) * linefreq
    else:
        linefreqs = np.arange(linefreq, raw.info['lowpass'], linefreq)
    buflen = int(win_length * sfreq)
    if buflen <= 0:
        raise ValueError('Window length should be >0')
    cfreqs = _get_hpi_info(raw.info, verbose=False)[0]
    if verbose:
        print('Nominal cHPI frequencies: %s Hz' % cfreqs)
        print('Sampling frequency: %s Hz' % sfreq)
        print('Using line freqs: %s Hz' % linefreqs)
        print('Using buffers of %s samples = %s seconds\n' %
              (buflen, buflen / sfreq))

    pick_meg = pick_types(raw.info, meg=True, exclude=[])
    pick_mag = pick_types(raw.info, meg='mag', exclude=[])
    pick_grad = pick_types(raw.info, meg='grad', exclude=[])
    nchan = len(pick_meg)
    # grad and mag indices into an array that already has meg channels only
    pick_mag_ = np.isin(pick_meg, pick_mag).nonzero()[0]
    pick_grad_ = np.isin(pick_meg, pick_grad).nonzero()[0]

    # create general linear model for the data: linear trend, offset, and a
    # cosine/sine pair for every line and cHPI frequency
    t = np.arange(buflen) / float(sfreq)
    model = np.empty((len(t), 2 + 2 * (len(linefreqs) + len(cfreqs))))
    model[:, 0] = t
    model[:, 1] = np.ones(t.shape)
    allfreqs = np.concatenate([linefreqs, cfreqs])
    model[:, 2::2] = np.cos(2 * np.pi * t[:, np.newaxis] * allfreqs)
    model[:, 3::2] = np.sin(2 * np.pi * t[:, np.newaxis] * allfreqs)
    inv_model = linalg.pinv(model)

    # Use every buffer that fits entirely in the data. Unconditionally
    # dropping the last start also discarded a complete final buffer when
    # n_times was an exact multiple of buflen.
    bufs = np.arange(0, raw.n_times - buflen + 1, buflen)
    if len(bufs) == 0:
        raise ValueError('Window length (%s samples) is longer than the data '
                         '(%s samples)' % (buflen, raw.n_times))
    tvec = bufs / sfreq
    snr_avg_grad = np.zeros([len(cfreqs), len(bufs)])
    hpi_pow_grad = np.zeros([len(cfreqs), len(bufs)])
    snr_avg_mag = np.zeros([len(cfreqs), len(bufs)])
    resid_vars = np.zeros([nchan, len(bufs)])
    for ind, buf0 in enumerate(bufs):
        if verbose:
            print('Buffer %s/%s' % (ind + 1, len(bufs)))
        megbuf = raw[pick_meg, buf0:buf0 + buflen][0].T
        coeffs = np.dot(inv_model, megbuf)
        coeffs_hpi = coeffs[2 + 2 * len(linefreqs):]
        resid_vars[:, ind] = np.var(megbuf - np.dot(model, coeffs), 0)
        # get total power by combining sine and cosine terms
        # sinusoidal of amplitude A has power of A**2/2
        hpi_pow = (coeffs_hpi[0::2, :]**2 + coeffs_hpi[1::2, :]**2) / 2
        hpi_pow_grad[:, ind] = hpi_pow[:, pick_grad_].mean(1)
        # divide average HPI power by average variance
        snr_avg_grad[:, ind] = hpi_pow_grad[:, ind] / \
            resid_vars[pick_grad_, ind].mean()
        snr_avg_mag[:, ind] = hpi_pow[:, pick_mag_].mean(1) / \
            resid_vars[pick_mag_, ind].mean()

    cfreqs_legend = ['%s Hz' % fre for fre in cfreqs]
    fig, axs = plt.subplots(4, 1, sharex=True)

    def to_db(x):
        return 10 * np.log10(x)

    def identity(x):
        return x

    # SNR plots for gradiometers and magnetometers, then cHPI power
    lines1, med1 = _chpi_plot_panel(
        axs[0], tvec, snr_avg_grad, to_db, 'SNR (dB)',
        'Mean cHPI power / mean residual variance, gradiometers',
        **panel_fontsizes)
    lines2, med2 = _chpi_plot_panel(
        axs[1], tvec, snr_avg_mag, to_db, 'SNR (dB)',
        'Mean cHPI power / mean residual variance, magnetometers',
        **panel_fontsizes)
    lines3, med3 = _chpi_plot_panel(
        axs[2], tvec, hpi_pow_grad, identity, 'Power (T/m)$^2$',
        'Mean cHPI power, gradiometers', **panel_fontsizes)

    # residual (unexplained) variance as function of time; the color map is
    # spread over the gradiometer lines that are actually drawn
    ax = axs[3]
    cls = plt.get_cmap('plasma')(np.linspace(0., 0.7,
                                             max(len(pick_grad_), 1)))
    ax.set_prop_cycle(color=cls)
    ax.semilogy(tvec, resid_vars[pick_grad_, :].T, alpha=.4)
    ax.set_xlim([tvec.min(), tvec.max()])
    ax.set_xlabel('Time (s)', fontsize=label_fontsize)
    ax.set_ylabel('Var. (T/m)$^2$', fontsize=label_fontsize)
    ax.set_title('Residual (unexplained) variance, all gradiometer channels',
                 fontsize=panel_fontsizes['title_fontsize'])
    ax.tick_params(axis='both', which='major',
                   labelsize=panel_fontsizes['tick_fontsize'])
    tight_layout(pad=.5, w_pad=.1, h_pad=.2, fig=fig)  # from mne.viz

    # tight_layout would undo these, so shrink axes and add legends after it
    leg_kwargs = dict(
        prop={'size': legend_fontsize},
        bbox_to_anchor=(1.02, 0.5),
        loc='center left',
        borderpad=1,
        handlelength=1,
    )
    for ax, lines, med_line, data in ((axs[0], lines1, med1, snr_avg_grad),
                                      (axs[1], lines2, med2, snr_avg_mag),
                                      (axs[2], lines3, med3, hpi_pow_grad)):
        _chpi_shrink_axis(ax)
        _chpi_sorted_legend(ax, lines, med_line, data, cfreqs_legend,
                            leg_kwargs)
    # keep the residual-variance panel aligned with the panels above
    _chpi_shrink_axis(axs[3])

    if show:
        plt.show()

    return fig
예제 #5
0
파일: viz.py 프로젝트: Qi0116/deepthought
def _plot_evoked(evoked, plot_type, colorbar=True, hline=None, ylim=None,
                picks=None, exclude='bads', unit=True, show=True,
                      clim=None, proj=False, xlim='tight', units=None,
                      scalings=None, titles=None, axes=None, cmap='RdBu_r'):
    """Aux function for plot_evoked and plot_evoked_image (cf. docstrings)

    Extra param is:

    plot_type : str, value ('butterfly' | 'image')
        The type of graph to plot: 'butterfly' plots each channel as a line
        (x axis: time, y axis: amplitude). 'image' plots a 2D image where
        color depicts the amplitude of each channel at a given time point
        (x axis: time, y axis: channel). In 'image' mode, the plot is not
        interactive.
    """
    import matplotlib.pyplot as plt
    if axes is not None and proj == 'interactive':
        raise RuntimeError('Currently only single axis figures are supported'
                           ' for interactive SSP selection.')

    scalings = _handle_default('scalings', scalings)
    titles = _handle_default('titles', titles)
    units = _handle_default('units', units)

    channel_types = set(key for d in [scalings, titles, units] for key in d)
    channel_types = sorted(channel_types)  # to guarantee consistent order

    if picks is None:
        picks = list(range(evoked.info['nchan']))

    bad_ch_idx = [evoked.ch_names.index(ch) for ch in evoked.info['bads']
                  if ch in evoked.ch_names]
    if len(exclude) > 0:
        if isinstance(exclude, string_types) and exclude == 'bads':
            exclude = bad_ch_idx
        elif (isinstance(exclude, list)
              and all([isinstance(ch, string_types) for ch in exclude])):
            exclude = [evoked.ch_names.index(ch) for ch in exclude]
        else:
            raise ValueError('exclude has to be a list of channel names or '
                             '"bads"')

        picks = list(set(picks).difference(exclude))

    types = [channel_type(evoked.info, idx) for idx in picks]
    n_channel_types = 0
    ch_types_used = []
    for t in channel_types:
        if t in types:
            n_channel_types += 1
            ch_types_used.append(t)

    axes_init = axes  # remember if axes where given as input

    fig = None
    if axes is None:
        fig, axes = plt.subplots(n_channel_types, 1)

    if isinstance(axes, plt.Axes):
        axes = [axes]
    elif isinstance(axes, np.ndarray):
        axes = list(axes)

    if axes_init is not None:
        fig = axes[0].get_figure()

    if not len(axes) == n_channel_types:
        raise ValueError('Number of axes (%g) must match number of channel '
                         'types (%g)' % (len(axes), n_channel_types))

    # instead of projecting during each iteration let's use the mixin here.
    if proj is True and evoked.proj is not True:
        evoked = evoked.copy()
        evoked.apply_proj()

    times = 1e3 * evoked.times  # time in miliseconds
    for ax, t in zip(axes, ch_types_used):
        ch_unit = units[t]
        this_scaling = scalings[t]
        if unit is False:
            this_scaling = 1.0
            ch_unit = 'NA'  # no unit
        idx = [picks[i] for i in range(len(picks)) if types[i] == t]
        if len(idx) > 0:
            # Parameters for butterfly interactive plots
            if plot_type == 'butterfly':
                if any([i in bad_ch_idx for i in idx]):
                    colors = ['k'] * len(idx)
                    for i in bad_ch_idx:
                        if i in idx:
                            colors[idx.index(i)] = 'r'

                    ax._get_lines.color_cycle = iter(colors)
                else:
                    ax._get_lines.color_cycle = cycle(['k'])
            # Set amplitude scaling
            D = this_scaling * evoked.data[idx, :]
            # plt.axes(ax)
            if plot_type == 'butterfly':
                ax.plot(times, D.T)
            elif plot_type == 'image':
                im = ax.imshow(D, interpolation='nearest', origin='lower',
                               extent=[times[0], times[-1], 0, D.shape[0]],
                               aspect='auto', cmap=cmap)
                if colorbar:
                    cbar = plt.colorbar(im, ax=ax)
                    cbar.ax.set_title(ch_unit)
            elif plot_type == 'mean' :
#                 ax.plot(times, D.mean(axis=0))
                ax.plot(times, np.abs(D).mean(axis=0))
            if xlim is not None:
                if xlim == 'tight':
                    xlim = (times[0], times[-1])
                ax.set_xlim(xlim)
            if ylim is not None and t in ylim:
                if plot_type == 'butterfly' or plot_type == 'mean':
                    ax.set_ylim(ylim[t])
                elif plot_type == 'image':
                    im.set_clim(ylim[t])
            ax.set_title(titles[t] + ' (%d channel%s)' % (
                         len(D), 's' if len(D) > 1 else ''))
            ax.set_xlabel('time (ms)')
            if plot_type == 'butterfly' or plot_type == 'mean':
                ax.set_ylabel('data (%s)' % ch_unit)
            elif plot_type == 'image':
                ax.set_ylabel('channels (%s)' % 'index')
            else:
                raise ValueError("plot_type has to be 'butterfly' or 'image'."
                                 "Got %s." % plot_type)

            if (plot_type == 'butterfly' or plot_type == 'mean') and (hline is not None):
                for h in hline:
                    ax.axhline(h, color='r', linestyle='--', linewidth=2)

    if axes_init is None:
        plt.subplots_adjust(0.175, 0.08, 0.94, 0.94, 0.2, 0.63)

    # if proj == 'interactive':
    #     _check_delayed_ssp(evoked)
    #     params = dict(evoked=evoked, fig=fig, projs=evoked.info['projs'],
    #                   axes=axes, types=types, units=units, scalings=scalings,
    #                   unit=unit, ch_types_used=ch_types_used, picks=picks,
    #                   plot_update_proj_callback=_plot_update_evoked,
    #                   plot_type=plot_type)
    #     _draw_proj_checkbox(None, params)

    if show and plt.get_backend() != 'agg':
        plt.show()
        fig.canvas.draw()  # for axes plots update axes.
    tight_layout(fig=fig)

    return fig