Exemple #1
0
def find_adjacency(inst, picks=None):
    '''Find channel adjacency matrix.

    A Delaunay triangulation of the picked channels' 2D topomap positions is
    computed twice: once with the x coordinates doubled, once with the y
    coordinates doubled. The two resulting adjacency matrices are combined
    element-wise with ``|``.

    Parameters
    ----------
    inst : object with ``.ch_names`` and ``.info`` (e.g. an MNE Raw/Epochs)
    picks : array-like of int | None
        Indices of the channels to use. ``None`` uses all channels.

    Returns
    -------
    adjacency : ndarray, shape (n_picks, n_picks)
        Dense union (element-wise ``|``) of both triangulations' adjacency.
    ch_names : list of str
        Names of the picked channels, in ``picks`` order.
    '''
    from scipy.spatial import Delaunay
    from mne.channels.layout import _find_topomap_coords
    try:
        # older MNE name of the function
        from mne.source_estimate import \
            spatial_tris_connectivity as tris_adjacency
    except ImportError:
        # newer MNE exposes the same function under the "adjacency" name
        from mne.source_estimate import \
            spatial_tris_adjacency as tris_adjacency

    n_channels = len(inst.ch_names)
    picks = np.arange(n_channels) if picks is None else picks
    ch_names = [inst.info['ch_names'][pick] for pick in picks]
    xy = _find_topomap_coords(inst.info, picks)

    # Triangulate on (2x, y) and on (x, 2y); presumably the anisotropic
    # stretching is meant to pick up neighbours a single isotropic
    # triangulation would miss.
    neighbors = []
    for stretched_axis in (0, 1):
        coords = xy.copy()
        coords[:, stretched_axis] *= 2
        tri = Delaunay(coords)
        neighbors.append(tris_adjacency(tri.simplices).toarray())

    adjacency = neighbors[0] | neighbors[1]
    return adjacency, ch_names
Exemple #2
0
def test_open_topo_viewer():
    """Open a MainViewer with a trace view and an EEG topomap view.

    Reads ``small_BrainAmp.vhdr`` (relative to the current working
    directory), applies the ``standard_1020`` montage, keeps only EEG
    channels and runs the Qt event loop. The call blocks until the window
    is closed.
    """
    montage_name = 'standard_1020'

    vhdr_fname = 'small_BrainAmp.vhdr'
    # The readers no longer accept a `montage=` keyword; apply the montage
    # to the loaded Raw object instead.
    raw = mne.io.read_raw_brainvision(vhdr_fname,
                                      eog=['EOG'],
                                      preload=True)
    raw.set_montage(montage_name)
    raw = raw.pick('eeg')

    source = MneRawSource(raw)

    # 2D sensor positions for the topomap, derived from the montage above
    channel_positions = _find_topomap_coords(raw.info, None)

    # you must first create a main Qt application (for event loop)
    app = mkQApp()

    win = MainViewer(show_auto_scale=True)

    view1 = TraceViewer(source=source, name='sigs')
    win.add_view(view1)

    view3 = TopoEegViewer(source=source,
                          name='topo',
                          channel_positions=channel_positions)
    win.add_view(view3)

    # show main window and run Qapp
    win.show()

    app.exec_()
Exemple #3
0
 def plot_correlation_matrix(self):
     """Plot the correlation between template ICA components and self.ica.

     Loads ``template-ica.fif``, correlates template components 0 and 7
     with the components of ``self.ica`` over the channels common to both,
     and passes the result to ``plot_correlation`` together with topomap
     positions of the common channels. It also passes ``quality``, the
     fraction of this recording's channels that are shared with the
     template.

     Returns
     -------
     tuple
         Always an empty tuple (kept for backward compatibility).
     """
     raw = self.raw.copy()
     if self.apply_montage is True:
         raw.set_montage(self.montage)
     ica_template = mne.preprocessing.read_ica('template-ica.fif')
     common = find_common_channels(ica_template, self.ica)
     components_template, components_ics = extract_common_components(
         ica_template, self.ica)
     templates = components_template[[0, 7]]
     df = compute_correlation(templates, components_ics)
     raw.rename_channels(tolow)
     # Count channels BEFORE reorder_channels: it keeps only the channels
     # listed in `common`. Counting afterwards would always give quality 1.0.
     n_channels_total = len(raw.info["ch_names"])
     raw.reorder_channels(common)
     ch_names = raw.info["ch_names"]
     picks = [idx for idx, name in enumerate(ch_names)
              if name.lower() in common]
     pos = _find_topomap_coords(raw.info, picks=picks)
     quality = len(common) / n_channels_total
     plot_correlation(df, templates, pos, quality)
     return ()
Exemple #4
0
def test_find_topomap_coords():
    """Test mapping of coordinates in 3D space to 2D.

    ``info`` is mutated step by step, so each check below depends on the
    state left behind by the previous ones; the statement order matters.
    """
    info = read_info(fif_fname)
    picks = pick_types(info, meg=False, eeg=True, eog=False, stim=False)

    # Remove extra digitization point, so EEG digitization points match up
    # with the EEG channels
    del info['dig'][85]

    # Use channel locations
    kwargs = dict(ignore_overlap=False,
                  to_sphere=True,
                  sphere=HEAD_SIZE_DEFAULT)
    l0 = _find_topomap_coords(info, picks, **kwargs)

    # Remove electrode position information, use digitization points from now
    # on.
    for ch in info['chs']:
        ch['loc'].fill(np.nan)

    # Coordinates derived from the digitization points must agree with the
    # ones derived from the channel locations above.
    l1 = _find_topomap_coords(info, picks, **kwargs)
    assert_allclose(l1, l0, atol=1e-3)

    # Move the last digitization point onto the z=0 plane at radius
    # HEAD_SIZE_DEFAULT (on the x axis, then the y axis); its 2D position
    # (taken from the last row of the result) must equal its (x, y).
    for z_pt in ((HEAD_SIZE_DEFAULT, 0., 0.), (0., HEAD_SIZE_DEFAULT, 0.)):
        info['dig'][-1]['r'] = z_pt
        l1 = _find_topomap_coords(info, picks, **kwargs)
        assert_allclose(l1[-1], z_pt[:2], err_msg='Z=0 point moved', atol=1e-6)

    # Test plotting mag topomap without channel locations: it should fail
    # (all channel 'loc' arrays were set to NaN above)
    mag_picks = pick_types(info, meg='mag')
    with pytest.raises(ValueError, match='Cannot determine location'):
        _find_topomap_coords(info, mag_picks, **kwargs)

    # Test function with too many EEG digitization points: it should fail
    # (one extra EEG point is appended)
    info['dig'].append({'r': [1, 2, 3], 'kind': FIFF.FIFFV_POINT_EEG})
    with pytest.raises(ValueError, match='Number of EEG digitization points'):
        _find_topomap_coords(info, picks, **kwargs)

    # Test function with too little EEG digitization points: it should fail
    # (drops the point appended above plus one original point)
    info['dig'] = info['dig'][:-2]
    with pytest.raises(ValueError, match='Number of EEG digitization points'):
        _find_topomap_coords(info, picks, **kwargs)

    # Electrode positions must be unique
    # (duplicating the last point restores the count but creates an overlap)
    info['dig'].append(info['dig'][-1])
    with pytest.raises(ValueError, match='overlapping positions'):
        _find_topomap_coords(info, picks, **kwargs)

    # Test function without EEG digitization points: it should fail
    info['dig'] = [d for d in info['dig'] if d['kind'] != FIFF.FIFFV_POINT_EEG]
    with pytest.raises(RuntimeError, match='Did not find any digitization'):
        _find_topomap_coords(info, picks, **kwargs)

    # Test function without any digitization points, it should fail
    # (both a missing (None) and an empty dig list are checked)
    info['dig'] = None
    with pytest.raises(RuntimeError, match='No digitization points found'):
        _find_topomap_coords(info, picks, **kwargs)
    info['dig'] = []
    with pytest.raises(RuntimeError, match='No digitization points found'):
        _find_topomap_coords(info, picks, **kwargs)
Exemple #5
0
def plot_results(fourier_ica_obj, meg_data,
                 W_orig, A_orig, info, picks,
                 cluster_quality=None, fnout=None,
                 show=True, plot_all=True):

    """
    Generate plot containing all results achieved by
    applying FourierICA, i.e., plot activations in
    time- and source-space, as well as fourier
    amplitudes and topoplots.

        Parameters
        ----------
        fourier_ica_obj: FourierICA object
        meg_data: array of data to be decomposed [nchan, ntsl].
        W_orig: estimated de-mixing matrix
        A_orig: estimated mixing matrix [nchan, ncomp]. If it is
            complex-valued, a magnitude and a phase-difference
            topoplot are drawn per component; otherwise a single
            magnitude topoplot.
        info: instance of mne.io.meas_info.Info
            Measurement info.
        picks: Channel indices to generate topomap coords for.
        cluster_quality: sequence/array with one cluster-quality value
            per component. If any value is non-zero, it is added as
            text info on the plot (next to the peak value; the text
            is only drawn when pk_values contains a non-zero value).
            Cluster quality is of interest when FourierICA combined
            with ICASSO was applied.
            default: cluster_quality=None (no cluster quality; an
            empty sequence is treated the same way)
        fnout: output name for the result image. If not set, the
            image won't be saved. Note, the ending '.png' is
            automatically added
            default: fnout=None
        show: if set plotting results are shown
            default: show=True
        plot_all: if set results for all components are plotted.
            Otherwise only the first 10 components are plotted.
            default: plot_all=True

        Returns
        -------
        pk_values: second value returned by
            fourier_ica_obj.get_temporal_envelope()
    """

    # ------------------------------------------
    # import necessary modules
    # ------------------------------------------
    from matplotlib import pyplot as plt
    from matplotlib import gridspec as grd
    from mne.viz import plot_topomap
    from mne.channels.layout import _find_topomap_coords

    # ------------------------------------------
    # generate sources for plotting
    # ------------------------------------------
    temporal_envelope, pk_values = fourier_ica_obj.get_temporal_envelope(meg_data, W_orig)
    rec_signal_avg, orig_avg = fourier_ica_obj.get_reconstructed_signal(meg_data, W_orig, A_orig)
    fourier_ampl = fourier_ica_obj.get_fourier_ampl(meg_data, W_orig)

    # ------------------------------------------
    # collect some general information
    # ------------------------------------------
    ntsl = int(np.floor(fourier_ica_obj.sfreq*fourier_ica_obj.win_length_sec))
    tpost = fourier_ica_obj.tpre + fourier_ica_obj.win_length_sec
    _, ncomp = A_orig.shape
    nbins = fourier_ampl.shape[1]
    sfreq_bins = nbins/(fourier_ica_obj.fhigh - fourier_ica_obj.flow)

    # define axis/positions for plots
    xaxis_time = np.arange(ntsl)/fourier_ica_obj.sfreq + fourier_ica_obj.tpre
    xaxis_fourier = np.arange(nbins)/sfreq_bins + fourier_ica_obj.flow
    ylim_act = [np.min(temporal_envelope), np.max(temporal_envelope)]
    ylim_meg = [np.min(orig_avg), np.max(orig_avg)]
    pos = _find_topomap_coords(info, picks)

    # None and [] both mean "no cluster quality". np.asarray makes the
    # any()-check work for plain lists as well (a list has no .any()).
    cluster_quality = np.asarray([] if cluster_quality is None else cluster_quality)
    has_cluster_quality = bool(np.any(cluster_quality))
    has_pk_values = bool(np.any(pk_values))

    # types.ComplexType only exists in Python 2; inspect the dtype instead.
    is_complex = np.iscomplexobj(A_orig)
    nplots_per_comp = 8 if is_complex else 7
    if not is_complex:
        # only needed for real-valued mixing matrices
        from jumeg.jumeg_math import rescale

    def _add_colorbar(im, ticks, ticklabels, shrink):
        """Add a horizontal colorbar for `im` with the given tick labels."""
        cbar = plt.colorbar(im, ticks=ticks, orientation='horizontal',
                            shrink=shrink, fraction=0.04, pad=0.04)
        # horizontal colorbars carry their ticks on the x axis
        cbar.ax.set_xticklabels(ticklabels)

    # ------------------------------------------
    # loop over all activations
    # ------------------------------------------
    nimg = int(np.ceil(ncomp / 10.0)) if plot_all else 1

    plt.ioff()
    try:
        # loop over all images (10 components per image)
        for iimg in range(nimg):

            plt.figure('FourierICA plots', figsize=(18, 14))

            # estimate how many plots on current image
            istart_plot = int(10*iimg)
            nplot = np.min([10*(iimg+1), ncomp])
            gs = grd.GridSpec(10, nplots_per_comp)

            for icomp in range(istart_plot, nplot):
                row = icomp - istart_plot
                is_last = icomp == nplot - 1
                is_first = icomp == istart_plot
                # only the bottom row of each image shows its x axis
                spines = ['bottom'] if is_last else []

                # ----------------------------------------------
                # (1.) plot activations in time domain
                # ----------------------------------------------
                p1 = plt.subplot(gs[row, 0:2])
                plt.xlim(fourier_ica_obj.tpre, tpost)
                plt.ylim(ylim_act)
                adjust_spines(p1, spines, labelsize=13)
                if is_last:
                    plt.xlabel('time [s]')
                elif is_first:
                    p1.set_title("activations [time domain]")
                p1.plot(xaxis_time, temporal_envelope[icomp, :])

                # add some information
                txt_info = 'cluster qual.: %0.2f; ' % cluster_quality[icomp] if has_cluster_quality else ''

                if has_pk_values:
                    txt_info += 'pk: %0.2f' % pk_values[icomp]
                    p1.text(0.97*fourier_ica_obj.tpre+0.03*tpost, 0.8*ylim_act[1] + 0.1*ylim_act[0],
                            txt_info, color='r')

                IC_number = 'IC#%d' % (icomp+1)
                p1.text(1.1*fourier_ica_obj.tpre-0.1*tpost, 0.4*ylim_act[1] + 0.6*ylim_act[0],
                        IC_number, color='black', rotation=90)

                # ----------------------------------------------
                # (2.) plot back-transformed signals
                # ----------------------------------------------
                p2 = plt.subplot(gs[row, 2:4])
                plt.xlim(fourier_ica_obj.tpre, tpost)
                plt.ylim(ylim_meg)
                adjust_spines(p2, spines, labelsize=13)
                if is_last:
                    plt.xlabel('time [s]')
                elif is_first:
                    p2.set_title("reconstructed MEG-signals")
                p2.plot(xaxis_time, orig_avg.T, 'b', linewidth=0.5)
                p2.plot(xaxis_time, rec_signal_avg[icomp, :, :].T, 'r', linewidth=0.5)

                # ----------------------------------------------
                # (3.) plot Fourier amplitudes
                # ----------------------------------------------
                p3 = plt.subplot(gs[row, 4:6])
                plt.xlim(fourier_ica_obj.flow, fourier_ica_obj.fhigh)
                plt.ylim(0.0, 1.0)
                adjust_spines(p3, spines, labelsize=13)
                if is_last:
                    plt.xlabel('freq [Hz]')
                elif is_first:
                    p3.set_title("Fourier amplitude (arbitrary units)")

                p3.bar(xaxis_fourier, fourier_ampl[icomp, :], 0.8, color='b')

                # ----------------------------------------------
                # (4.) topoplots (magnitude / phase difference)
                # ----------------------------------------------
                if is_complex:
                    # magnitude scaled to [-1, 1] relative to its maximum
                    magnitude = np.abs(A_orig[:, icomp])
                    magnitude = (2 * magnitude/np.max(magnitude)) - 1
                    p4 = plt.subplot(gs[row, 6])
                    im, _ = plot_topomap(magnitude, pos, res=200, vmin=-1, vmax=1, contours=0)
                    if is_first:
                        p4.set_title("Magnitude")
                    if is_last:
                        _add_colorbar(im, [-1, 0, 1], ['-1.0', '0.0', '1.0'], shrink=0.8)

                    # phase mapped from [-pi, pi] to [0, 1]
                    phase_diff = (np.angle(A_orig[:, icomp]) + np.pi) / (2 * np.pi)
                    p5 = plt.subplot(gs[row, 7])
                    im, _ = plot_topomap(phase_diff, pos, res=200, vmin=0, vmax=1, contours=0)
                    if is_first:
                        p5.set_title("Phase differences")
                    if is_last:
                        _add_colorbar(im, [0, 0.5, 1], ['0.0', '0.5', '1.0'], shrink=0.9)

                else:
                    p4 = plt.subplot(gs[row, 6])
                    magnitude = rescale(A_orig[:, icomp], -1, 1)
                    im, _ = plot_topomap(magnitude, pos, res=200, vmin=-1, vmax=1, contours=0)
                    if is_first:
                        p4.set_title("Magnitude distribution")
                    if is_last:
                        _add_colorbar(im, [-1, 0, 1], ['-1.0', '0.0', '1.0'], shrink=0.9)

            # save image
            if fnout:
                if plot_all:
                    fnout_complete = '%s%2d.png' % (fnout, iimg+1)
                else:
                    fnout_complete = '%s.png' % fnout

                plt.savefig(fnout_complete, format='png')

            # show image if requested
            if show:
                plt.show()

            plt.close('FourierICA plots')
    finally:
        # restore interactive mode even if plotting failed
        plt.ion()

    return pk_values
Exemple #6
0
def ICs_topoplot(A_orig, info, picks, fnout=None, show=True):

    """
    Generate topoplots from the mixing matrix received
    by applying FourierICA.

        Parameters
        ----------
        A_orig: estimated mixing matrix [nchan, ncomp]. A magnitude
            topoplot is drawn for each of the first (at most) 10
            components; complex-valued matrices additionally get a
            phase-difference topoplot.
        info: instance of mne.io.meas_info.Info
            Measurement info.
        picks: Channel indices to generate topomap coords for.
        fnout: output name for the result image. If not set, the
            image won't be saved. Note, the ending '.png' is
            automatically added
            default: fnout=None
        show: if set plotting results are shown
            default: show=True
    """

    # ------------------------------------------
    # import necessary modules
    # ------------------------------------------
    from matplotlib import pyplot as plt
    from matplotlib import gridspec as grd
    from mne.viz import plot_topomap
    from mne.channels.layout import _find_topomap_coords

    # ------------------------------------------
    # collect some general information
    # ------------------------------------------
    _, ncomp = A_orig.shape

    # define axis/positions for plots
    pos = _find_topomap_coords(info, picks)

    # types.ComplexType only exists in Python 2; inspect the dtype instead.
    is_complex = np.iscomplexobj(A_orig)
    nplots_per_comp = 2 if is_complex else 1
    nplot = min(10, ncomp)

    def _add_colorbar(im, ticks, ticklabels):
        """Add a horizontal colorbar for `im` with the given tick labels."""
        cbar = plt.colorbar(im, ticks=ticks, orientation='horizontal', shrink=0.8)
        # horizontal colorbars carry their ticks on the x axis
        cbar.ax.set_xticklabels(ticklabels)

    plt.ioff()
    try:
        plt.figure('Topoplots', figsize=(5, 14))
        gs = grd.GridSpec(nplot, nplots_per_comp)

        # ------------------------------------------
        # loop over all activations
        # ------------------------------------------
        for icomp in range(nplot):
            is_last = icomp == nplot - 1

            if is_complex:
                # magnitude scaled to [-1, 1] relative to its maximum
                magnitude = np.abs(A_orig[:, icomp])
                magnitude = (2 * magnitude/np.max(magnitude)) - 1
                p1 = plt.subplot(gs[icomp, 0])
                im, _ = plot_topomap(magnitude, pos, res=200, vmin=-1, vmax=1, contours=0)
                if icomp == 0:
                    p1.set_title("Magnitude")
                if is_last:
                    _add_colorbar(im, [-1, 0, 1], ['-1.0', '0.0', '1.0'])

                # phase mapped from [-pi, pi] to [0, 1]
                phase_diff = (np.angle(A_orig[:, icomp]) + np.pi) / (2 * np.pi)
                p2 = plt.subplot(gs[icomp, 1])
                im, _ = plot_topomap(phase_diff, pos, res=200, vmin=0, vmax=1, contours=0)
                if icomp == 0:
                    p2.set_title("Phase differences")
                if is_last:
                    _add_colorbar(im, [0, 0.5, 1], ['0.0', '0.5', '1.0'])

            else:
                # single-column grid: the topoplot spans the whole row
                p1 = plt.subplot(gs[icomp, 0])
                magnitude = A_orig[:, icomp]
                magnitude = (2 * magnitude/np.max(magnitude)) - 1
                # keep the image handle: the colorbar below needs it
                im, _ = plot_topomap(magnitude, pos, res=200, vmin=-1, vmax=1, contours=0)
                if icomp == 0:
                    p1.set_title("Magnitude distribution")
                if is_last:
                    _add_colorbar(im, [-1, 0, 1], ['-1.0', '0.0', '1.0'])

        # save image
        if fnout:
            plt.savefig(fnout + '.png', format='png')

        # show image if requested
        if show:
            plt.show()
    finally:
        # always release the figure and restore interactive mode
        plt.close('Topoplots')
        plt.ion()
Exemple #7
0
def plot_topomap(data, pos, vmin=None, vmax=None, cmap=None, sensors=True,
                 res=64, axes=None, names=None, show_names=False, mask=None,
                 mask_params=None, outlines='head', image_mask=None,
                 contours=6, image_interp='bilinear', show=True,
                 head_pos=None, onselect=None, axis=None):
    ''' see the docstring for mne.viz.plot_topomap,
        which i've simply modified to return more objects

    Returns
    -------
    ax : matplotlib Axes the map was drawn into.
    im : AxesImage holding the interpolated map.
    cont : ContourSet (made invisible when ``contours`` is False/None).
    pos_x, pos_y : sensor x/y coordinates as returned by ``_prepare_topomap``.
    '''

    from matplotlib import pyplot as plt
    from matplotlib.widgets import RectangleSelector
    from mne.io.pick import (channel_type, pick_info, _pick_data_channels)
    from mne.utils import warn
    from mne.viz.utils import (_setup_vmin_vmax, plt_show)
    from mne.defaults import _handle_default
    from mne.channels.layout import _find_topomap_coords
    from mne.io.meas_info import Info
    from mne.viz.topomap import (_check_outlines, _prepare_topomap, _griddata,
                                 _make_image_mask, _plot_sensors,
                                 _draw_outlines)

    data = np.asarray(data)

    if isinstance(pos, Info):  # infer pos from Info object
        picks = _pick_data_channels(pos)  # pick only data channels
        pos = pick_info(pos, picks)

        # check if there is only 1 channel type, and n_chans matches the data
        ch_type = set(channel_type(pos, idx)
                      for idx, _ in enumerate(pos["chs"]))
        info_help = ("Pick Info with e.g. mne.pick_info and "
                     "mne.channels.channel_indices_by_type.")
        if len(ch_type) > 1:
            raise ValueError("Multiple channel types in Info structure. " +
                             info_help)
        elif len(pos["chs"]) != data.shape[0]:
            raise ValueError("Number of channels in the Info object and "
                             "the data array does not match. " + info_help)
        else:
            ch_type = ch_type.pop()

        if any(type_ in ch_type for type_ in ('planar', 'grad')):
            # deal with grad pairs. Absolute import: this copy lives outside
            # the mne package, so a relative "..channels.layout" import
            # cannot resolve here.
            from mne.channels.layout import (_merge_grad_data, find_layout,
                                             _pair_grad_sensors)
            picks, pos = _pair_grad_sensors(pos, find_layout(pos))
            data = _merge_grad_data(data[picks]).reshape(-1)
        else:
            picks = list(range(data.shape[0]))
            pos = _find_topomap_coords(pos, picks=picks)

    if data.ndim > 1:
        raise ValueError("Data needs to be array of shape (n_sensors,); got "
                         "shape %s." % str(data.shape))

    # Give a helpful error message for common mistakes regarding the position
    # matrix.
    pos_help = ("Electrode positions should be specified as a 2D array with "
                "shape (n_channels, 2). Each row in this matrix contains the "
                "(x, y) position of an electrode.")
    if pos.ndim != 2:
        error = ("{ndim}D array supplied as electrode positions, where a 2D "
                 "array was expected").format(ndim=pos.ndim)
        raise ValueError(error + " " + pos_help)
    elif pos.shape[1] == 3:
        error = ("The supplied electrode positions matrix contains 3 columns. "
                 "Are you trying to specify XYZ coordinates? Perhaps the "
                 "mne.channels.create_eeg_layout function is useful for you.")
        raise ValueError(error + " " + pos_help)
    # No error is raised in case of pos.shape[1] == 4. In this case, it is
    # assumed the position matrix contains both (x, y) and (width, height)
    # values, such as Layout.pos.
    elif pos.shape[1] == 1 or pos.shape[1] > 4:
        raise ValueError(pos_help)

    if len(data) != len(pos):
        raise ValueError("Data and pos need to be of same length. Got data of "
                         "length %s, pos of length %s" % (len(data), len(pos)))

    # non-negative data -> sequential colormap, otherwise diverging
    norm = min(data) >= 0
    vmin, vmax = _setup_vmin_vmax(data, vmin, vmax, norm)
    if cmap is None:
        cmap = 'Reds' if norm else 'RdBu_r'

    pos, outlines = _check_outlines(pos, outlines, head_pos)

    if axis is not None:
        axes = axis
        warn('axis parameter is deprecated and will be removed in 0.13. '
             'Use axes instead.', DeprecationWarning)
    ax = axes if axes else plt.gca()
    pos_x, pos_y = _prepare_topomap(pos, ax)
    if outlines is None:
        xmin, xmax = pos_x.min(), pos_x.max()
        ymin, ymax = pos_y.min(), pos_y.max()
    else:
        xlim = np.inf, -np.inf,
        ylim = np.inf, -np.inf,
        mask_ = np.c_[outlines['mask_pos']]
        xmin, xmax = (np.min(np.r_[xlim[0], mask_[:, 0]]),
                      np.max(np.r_[xlim[1], mask_[:, 0]]))
        ymin, ymax = (np.min(np.r_[ylim[0], mask_[:, 1]]),
                      np.max(np.r_[ylim[1], mask_[:, 1]]))

    # interpolate data
    xi = np.linspace(xmin, xmax, res)
    yi = np.linspace(ymin, ymax, res)
    Xi, Yi = np.meshgrid(xi, yi)
    Zi = _griddata(pos_x, pos_y, data, Xi, Yi)

    # Always defined (False for None or any non-dict outlines), so the checks
    # below cannot hit an unbound name.
    _is_default_outlines = (isinstance(outlines, dict) and
                            any(k.startswith('head') for k in outlines))

    if _is_default_outlines and image_mask is None:
        # prepare masking
        image_mask, pos = _make_image_mask(outlines, pos, res)

    mask_params = _handle_default('mask_params', mask_params)

    # plot outline
    linewidth = mask_params['markeredgewidth']
    patch = None
    if isinstance(outlines, dict) and 'patch' in outlines:
        patch = outlines['patch']
        patch_ = patch() if callable(patch) else patch
        patch_.set_clip_on(False)
        ax.add_patch(patch_)
        ax.set_transform(ax.transAxes)
        ax.set_clip_path(patch_)

    # plot map and countour
    im = ax.imshow(Zi, cmap=cmap, vmin=vmin, vmax=vmax, origin='lower',
                   aspect='equal', extent=(xmin, xmax, ymin, ymax),
                   interpolation=image_interp)

    # This tackles an incomprehensible matplotlib bug if no contours are
    # drawn. To avoid rescalings, we will always draw contours.
    # But if no contours are desired we only draw one and make it invisible .
    no_contours = False
    if contours in (False, None):
        contours, no_contours = 1, True
    cont = ax.contour(Xi, Yi, Zi, contours, colors='k',
                      linewidths=linewidth)
    if no_contours is True:
        for col in cont.collections:
            col.set_visible(False)

    if _is_default_outlines:
        from matplotlib import patches
        patch_ = patches.Ellipse((0, 0),
                                 2 * outlines['clip_radius'][0],
                                 2 * outlines['clip_radius'][1],
                                 clip_on=True,
                                 transform=ax.transData)
    if _is_default_outlines or patch is not None:
        # clip image and contours to the head outline
        im.set_clip_path(patch_)
        if cont is not None:
            for col in cont.collections:
                col.set_clip_path(patch_)

    if sensors is not False and mask is None:
        _plot_sensors(pos_x, pos_y, sensors=sensors, ax=ax)
    elif sensors and mask is not None:
        idx = np.where(mask)[0]
        ax.plot(pos_x[idx], pos_y[idx], **mask_params)
        idx = np.where(~mask)[0]
        _plot_sensors(pos_x[idx], pos_y[idx], sensors=sensors, ax=ax)
    elif not sensors and mask is not None:
        idx = np.where(mask)[0]
        ax.plot(pos_x[idx], pos_y[idx], **mask_params)

    if isinstance(outlines, dict):
        _draw_outlines(ax, outlines)

    if show_names:
        if names is None:
            raise ValueError("To show names, a list of names must be provided"
                             " (see `names` keyword).")
        if show_names is True:
            def _show_names(x):
                return x
        else:
            _show_names = show_names
        show_idx = np.arange(len(names)) if mask is None else np.where(mask)[0]
        for ii, (p, ch_id) in enumerate(zip(pos, names)):
            if ii not in show_idx:
                continue
            ch_id = _show_names(ch_id)
            ax.text(p[0], p[1], ch_id, horizontalalignment='center',
                    verticalalignment='center', size='x-small')

    plt.subplots_adjust(top=.95)

    if onselect is not None:
        ax.RS = RectangleSelector(ax, onselect=onselect)
    plt_show(show)
    return ax, im, cont, pos_x, pos_y
Exemple #8
0
def make_topoplot(
    values,
    info,
    ax,
    vmin=None,
    vmax=None,
    plot_head=True,
    cmap="RdBu_r",
    size=30,
    hemisphere="all",
    picks=None,
    pick_color=None,
):
    """Makes an ECoG topo plot with electrodes circles, without interpolation.
    Modified from MNE plot_topomap to plot electrodes without interpolation.

    Parameters
    ----------
    values : array, 1-D
        Values to plot as color-coded circles.
    info : instance of Info
        The x/y-coordinates of the electrodes will be infered from this object.
    ax : instance of Axes
        The axes to plot to.
    vmin : float | None
        Lower bound of the color range. If both vmin and vmax are None:
        minus the maximum absolute value of ``values``.
    vmax : float | None
        Upper bound of the color range. If both vmin and vmax are None:
        the maximum absolute value of ``values``.
    plot_head : True | False
        Whether to plot the outline for the head.
    cmap : matplotlib colormap | None
        Colormap to use for values, if None, defaults to RdBu_r.
    size : int
        Size of electrode circles.
    picks : list | None
        Indices of electrodes to highlight by drawing a thicker edge.
        A 2-D array of indices additionally connects the electrodes in
        each row with a line.
    pick_color : list | str | None
        Edgecolor for highlighted electrodes, one per entry of ``picks``.
        A single color is used for all picks. None defaults to
        ["#2d004f", "#254f00", "#000000"].
    hemisphere : string ("left", "right", "all")
        Restrict which hemisphere of head outlines coordinates to plot.

    Returns
    -------
    sc : matplotlib PathCollection
        The colored electrode circles.
    """
    if pick_color is None:
        pick_color = ["#2d004f", "#254f00", "#000000"]
    elif isinstance(pick_color, str):
        pick_color = [pick_color]

    pos = _find_topomap_coords(info, picks=None)
    sphere = np.array([0.0, 0.0, 0.0, 0.095])
    outlines = _make_head_outlines(
        sphere=sphere, pos=pos, outlines="head", clip_origin=(0.0, 0.0)
    )

    if plot_head:
        for key, (x_coord, y_coord) in outlines.items():
            if key in ("patch", "mask_pos"):
                continue
            # only array-valued outline parts are cut to one hemisphere
            if hemisphere in ("left", "right") and isinstance(x_coord, np.ndarray):
                keep = x_coord <= 0 if hemisphere == "left" else x_coord >= 0
                x_coord = x_coord[keep]
                y_coord = y_coord[keep]
            ax.plot(x_coord, y_coord, color="k", linewidth=1, clip_on=False)

    # Symmetric default range around 0, as documented. Compare against None
    # so that an explicit vmin/vmax of 0 is honoured.
    if vmin is None and vmax is None:
        lim = np.abs(np.asarray(values)).max()
        vmin, vmax = -lim, lim

    sc = ax.scatter(
        pos[:, 0],
        pos[:, 1],
        s=size,
        edgecolors="grey",
        c=values,
        vmin=vmin,
        vmax=vmax,
        cmap=cmap,
    )

    # np.any(picks) would wrongly skip picks == [0]; check for emptiness.
    if picks is not None and np.size(picks) > 0:
        picks = np.asarray(picks)
        if picks.ndim > 0:
            if len(pick_color) == 1:
                # one color for every pick (flat list of colors, not nested)
                pick_color = list(pick_color) * len(picks)
            for i, idxx in enumerate(picks):
                # edge-only circles on top of the colored electrodes
                ax.scatter(
                    pos[idxx, 0],
                    pos[idxx, 1],
                    s=size,
                    edgecolors=pick_color[i],
                    facecolors="None",
                    linewidths=1.5,
                )

        if picks.ndim == 2:
            # each row of picks is drawn as a connected line
            for i, idxx in enumerate(picks):
                ax.plot(
                    pos[idxx, 0],
                    pos[idxx, 1],
                    linestyle="-",
                    color=pick_color[i],
                    linewidth=1.5,
                )

    ax.axis("square")
    ax.axis("off")
    ax.xaxis.set_visible(False)
    ax.yaxis.set_visible(False)

    return sc