Example #1
def test_io_surface(tmp_path):
    """Test reading and writing of Freesurfer surface mesh files."""
    tempdir = str(tmp_path)
    fname_quad = op.join(data_path, 'subjects', 'bert', 'surf',
                         'lh.inflated.nofix')
    fname_tri = op.join(data_path, 'subjects', 'sample', 'bem',
                        'inner_skull.surf')
    for fname in (fname_quad, fname_tri):
        with _record_warnings():  # no volume info
            pts, tri, vol_info = read_surface(fname, read_metadata=True)
        write_surface(op.join(tempdir, 'tmp'), pts, tri, volume_info=vol_info,
                      overwrite=True)
        with _record_warnings():  # no volume info
            c_pts, c_tri, c_vol_info = read_surface(op.join(tempdir, 'tmp'),
                                                    read_metadata=True)
        assert_array_equal(pts, c_pts)
        assert_array_equal(tri, c_tri)
        assert_equal(object_diff(vol_info, c_vol_info), '')
        if fname != fname_tri:  # only test Wavefront I/O on the smaller file
            continue

        # Test writing/reading a Wavefront .obj file
        write_surface(op.join(tempdir, 'tmp.obj'), pts, tri, volume_info=None,
                      overwrite=True)
        c_pts, c_tri = read_surface(op.join(tempdir, 'tmp.obj'),
                                    read_metadata=False)
        assert_array_equal(pts, c_pts)
        assert_array_equal(tri, c_tri)

    # reading patches (just a smoke test, let the flatmap viz tests be more
    # complete)
    fname_patch = op.join(
        data_path, 'subjects', 'fsaverage', 'surf', 'rh.cortex.patch.flat')
    _read_patch(fname_patch)
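
Every example on this page leans on the same idiom: `_record_warnings` is a private MNE-Python test helper used as a context manager that records any warnings emitted in its block instead of letting them escape or fail the test. A minimal standalone sketch of that pattern using only the standard library (an assumption about its behaviour, not MNE's actual implementation):

import warnings
from contextlib import contextmanager

@contextmanager
def record_warnings():
    """Record every warning raised in the block, suppressing none."""
    with warnings.catch_warnings(record=True) as records:
        warnings.simplefilter('always')  # capture all, raise nothing
        yield records

with record_warnings() as w:
    warnings.warn('dangerous projection', RuntimeWarning)
assert any('dangerous' in str(ww.message) for ww in w)
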
Example #2
def test_morph():
    """Test inter-subject label morphing."""
    label_orig = read_label(real_label_fname)
    label_orig.subject = 'sample'
    # should work for specifying vertices for both hemis, or just the
    # hemi of the given label
    vals = list()
    for grade in [5, [np.arange(10242), np.arange(10242)], np.arange(10242)]:
        label = label_orig.copy()
        # this should throw an error because the label has all zero values
        pytest.raises(ValueError, label.morph, 'sample', 'fsaverage')
        label.values.fill(1)
        label = label.morph(None, 'fsaverage', 5, grade, subjects_dir, 1)
        label = label.morph('fsaverage', 'sample', 5, None, subjects_dir, 2)
        assert (np.in1d(label_orig.vertices, label.vertices).all())
        assert (len(label.vertices) < 3 * len(label_orig.vertices))
        vals.append(label.vertices)
    assert_array_equal(vals[0], vals[1])
    # make sure label smoothing can run
    assert_equal(label.subject, 'sample')
    verts = [np.arange(10242), np.arange(10242)]
    for hemi in ['lh', 'rh']:
        label.hemi = hemi
        with _record_warnings():  # morph map maybe missing
            label.morph(None, 'fsaverage', 5, verts, subjects_dir, 2)
    pytest.raises(TypeError, label.morph, None, 1, 5, verts, subjects_dir, 2)
    pytest.raises(TypeError, label.morph, None, 'fsaverage', 5.5, verts,
                  subjects_dir, 2)
    with _record_warnings():  # morph map maybe missing
        label.smooth(subjects_dir=subjects_dir)  # make sure this runs
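
The positional Label.morph calls above are hard to scan; with keyword arguments (parameter names inferred from the call sites, so treat them as an assumption), the first call would read:

# keyword names assumed from the positional call sites above
label = label.morph(subject_from=None, subject_to='fsaverage', smooth=5,
                    grade=grade, subjects_dir=subjects_dir, n_jobs=1)
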
Example #3
def test_setup_source_space(tmp_path):
    """Test setting up ico, oct, and all source spaces."""
    fname_ico = op.join(data_path, 'subjects', 'fsaverage', 'bem',
                        'fsaverage-ico-5-src.fif')
    # first let's test some input params
    for spacing in ('oct', 'oct6e'):
        with pytest.raises(ValueError, match='subdivision must be an integer'):
            setup_source_space('sample', spacing=spacing,
                               add_dist=False, subjects_dir=subjects_dir)
    for spacing in ('oct0', 'oct-4'):
        with pytest.raises(ValueError, match='oct subdivision must be >= 1'):
            setup_source_space('sample', spacing=spacing,
                               add_dist=False, subjects_dir=subjects_dir)
    with pytest.raises(ValueError, match='ico subdivision must be >= 0'):
        setup_source_space('sample', spacing='ico-4',
                           add_dist=False, subjects_dir=subjects_dir)
    with pytest.raises(ValueError, match='must be a string with values'):
        setup_source_space('sample', spacing='7emm',
                           add_dist=False, subjects_dir=subjects_dir)
    with pytest.raises(ValueError, match='must be a string with values'):
        setup_source_space('sample', spacing='alls',
                           add_dist=False, subjects_dir=subjects_dir)

    # ico 5 (fsaverage) - compare against pre-made source space
    src = read_source_spaces(fname_ico)
    with _record_warnings():  # sklearn equiv neighbors
        src_new = setup_source_space('fsaverage', spacing='ico5',
                                     subjects_dir=subjects_dir, add_dist=False)
    _compare_source_spaces(src, src_new, mode='approx')
    assert repr(src).split('~')[0] == repr(src_new).split('~')[0]
    assert repr(src).count('surface (') == 2
    assert_array_equal(src[0]['vertno'], np.arange(10242))
    assert_array_equal(src[1]['vertno'], np.arange(10242))

    # oct-6 (sample) - auto filename + IO
    src = read_source_spaces(fname)
    temp_name = tmp_path / 'temp-src.fif'
    with _record_warnings():  # sklearn equiv neighbors
        src_new = setup_source_space('sample', spacing='oct6',
                                     subjects_dir=subjects_dir, add_dist=False)
        write_source_spaces(temp_name, src_new, overwrite=True)
    assert_equal(src_new[0]['nuse'], 4098)
    _compare_source_spaces(src, src_new, mode='approx', nearest=False)
    src_new = read_source_spaces(temp_name)
    _compare_source_spaces(src, src_new, mode='approx', nearest=False)

    # all source points - no file writing
    src_new = setup_source_space('sample', spacing='all',
                                 subjects_dir=subjects_dir, add_dist=False)
    assert src_new[0]['nuse'] == len(src_new[0]['rr'])
    assert src_new[1]['nuse'] == len(src_new[1]['rr'])

    # dense source space to hit surf['inuse'] lines of _create_surf_spacing
    pytest.raises(RuntimeError, setup_source_space, 'sample',
                  spacing='ico6', subjects_dir=subjects_dir, add_dist=False)
Example #4
def _check_warnings(raw, events, picks=None, count=3):
    """Count warnings."""
    with _record_warnings() as w:
        Epochs(raw, events, dict(aud_l=1, vis_l=3),
               -0.2, 0.5, picks=picks, preload=True, proj=True)
    assert len(w) == count
    assert all('dangerous' in str(ww.message) for ww in w)
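
A hypothetical call for context, assuming the `raw` and `events` fixtures from the surrounding test module:

# hypothetical call: expect exactly three 'dangerous' projection warnings
_check_warnings(raw, events, picks=None, count=3)
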
Example #5
def test_multitaper_psd():
    """Test multi-taper PSD computation."""
    import nitime as ni
    for n_times in (100, 101):
        n_channels = 5
        data = np.random.RandomState(0).randn(n_channels, n_times)
        sfreq = 500
        info = create_info(n_channels, sfreq, 'eeg')
        raw = RawArray(data, info)
        pytest.raises(ValueError, psd_multitaper, raw, sfreq,
                      normalization='foo')
        norm = 'full'
        for adaptive, n_jobs in zip((False, True, True), (1, 1, 2)):
            psd, freqs = psd_multitaper(raw, adaptive=adaptive,
                                        n_jobs=n_jobs,
                                        normalization=norm)
            with _record_warnings():  # nitime integers
                freqs_ni, psd_ni, _ = ni.algorithms.spectral.multi_taper_psd(
                    data, sfreq, adaptive=adaptive, jackknife=False)
            assert_array_almost_equal(psd, psd_ni, decimal=4)
            if n_times % 2 == 0:
                # nitime's frequency definitions must be incorrect,
                # they give the same values for 100 and 101 samples
                assert_array_almost_equal(freqs, freqs_ni)
        with pytest.raises(ValueError, match='use a value of at least'):
            psd_multitaper(raw, bandwidth=4.9)
Example #6
def test_what(tmp_path, verbose_debug):
    """Test mne.what."""
    # ICA
    ica = ICA(max_iter=1)
    raw = RawArray(np.random.RandomState(0).randn(3, 10),
                   create_info(3, 1000., 'eeg'))
    with _record_warnings():  # convergence sometimes
        ica.fit(raw)
    fname = op.join(str(tmp_path), 'x-ica.fif')
    ica.save(fname)
    assert what(fname) == 'ica'
    # test files
    fnames = glob.glob(
        op.join(data_path, 'MEG', 'sample', '*.fif'))
    fnames += glob.glob(
        op.join(data_path, 'subjects', 'sample', 'bem', '*.fif'))
    fnames = sorted(fnames)
    want_dict = dict(eve='events', ave='evoked', cov='cov', inv='inverse',
                     fwd='forward', trans='transform', proj='proj',
                     raw='raw', meg='raw', sol='bem solution',
                     bem='bem surfaces', src='src', dense='bem surfaces',
                     sparse='bem surfaces', head='bem surfaces',
                     fiducials='fiducials')
    for fname in fnames:
        kind = op.splitext(fname)[0].split('-')[-1]
        if len(kind) > 5:
            kind = kind.split('_')[-1]
        this = what(fname)
        assert this == want_dict[kind]
    fname = op.join(data_path, 'MEG', 'sample', 'sample_audvis-ave_xfit.dip')
    assert what(fname) == 'unknown'
Example #7
def test_localization_bias_free(bias_params_free, reg, pick_ori, weight_norm,
                                use_cov, depth, lower, upper, lower_ori,
                                upper_ori):
    """Test localization bias for free-orientation LCMV."""
    evoked, fwd, noise_cov, data_cov, want = bias_params_free
    if not use_cov:
        evoked.pick_types(meg='grad')
        noise_cov = None
    with _record_warnings():  # rank deficiency of data_cov
        filters = make_lcmv(evoked.info,
                            fwd,
                            data_cov,
                            reg,
                            noise_cov,
                            pick_ori=pick_ori,
                            weight_norm=weight_norm,
                            depth=depth)
    loc = apply_lcmv(evoked, filters).data
    if pick_ori == 'vector':
        ori = loc.copy() / np.linalg.norm(loc, axis=1, keepdims=True)
    else:
        # doesn't make sense for pooled (None) or max-power (can't be all 3)
        ori = None
    loc = np.linalg.norm(loc, axis=1) if pick_ori == 'vector' else np.abs(loc)
    # Compute the percentage of sources for which there is no loc bias:
    max_idx = np.argmax(loc, axis=0)
    perc = (want == max_idx).mean() * 100
    assert lower <= perc <= upper
    _assert_free_ori_match(ori, max_idx, lower_ori, upper_ori)
Example #8
def test_add_reorder(n_ref):
    """Test that a reference channel can be added and then data reordered."""
    # gh-8300
    raw = read_raw_fif(raw_fname).crop(0, 0.1).del_proj().pick('eeg')
    assert len(raw.ch_names) == 60
    chs = ['EEG %03d' % (60 + ii) for ii in range(1, n_ref)] + ['EEG 000']
    with pytest.raises(RuntimeError, match='preload'):
        with _record_warnings():  # ignore multiple warning
            add_reference_channels(raw, chs, copy=False)
    raw.load_data()
    if n_ref == 1:
        ctx = nullcontext()
    else:
        assert n_ref == 2
        ctx = pytest.warns(RuntimeWarning, match='locations of multiple')
    with ctx:
        add_reference_channels(raw, chs, copy=False)
    data = raw.get_data()
    assert_array_equal(data[-1], 0.)
    assert raw.ch_names[-n_ref:] == chs
    raw.reorder_channels(raw.ch_names[-1:] + raw.ch_names[:-1])
    assert raw.ch_names == ['EEG %03d' % ii for ii in range(60 + n_ref)]
    data_new = raw.get_data()
    data_new = np.concatenate([data_new[1:], data_new[:1]])
    assert_allclose(data, data_new)
Example #9
def test_surface_source_morph_round_trip(smooth, lower, upper, n_warn, dtype):
    """Test round-trip morphing yields similar STCs."""
    kwargs = dict(smooth=smooth, warn=True, subjects_dir=subjects_dir)
    stc = mne.read_source_estimate(fname_smorph)
    if dtype is complex:
        stc.data = 1j * stc.data
        assert_array_equal(stc.data.real, 0.)
    if smooth == 'nearest' and not check_version('scipy', '1.3'):
        with pytest.raises(ValueError, match='required to use nearest'):
            morph = compute_source_morph(stc, 'sample', 'fsaverage', **kwargs)
        return
    with _record_warnings() as w:
        morph = compute_source_morph(stc, 'sample', 'fsaverage', **kwargs)
    w = [ww for ww in w if 'vertices not included' in str(ww.message)]
    assert len(w) == n_warn
    assert morph.morph_mat.shape == (20484, len(stc.data))
    stc_fs = morph.apply(stc)
    morph_back = compute_source_morph(stc_fs,
                                      'fsaverage',
                                      'sample',
                                      spacing=stc.vertices,
                                      **kwargs)
    assert morph_back.morph_mat.shape == (len(stc.data), 20484)
    stc_back = morph_back.apply(stc_fs)
    corr = np.corrcoef(stc.data.ravel(), stc_back.data.ravel())[0, 1]
    assert lower <= corr <= upper
    # check the round-trip power
    assert_power_preserved(stc, stc_back)
Example #10
def test_read_raw_curry(fname, tol, preload, bdf_curry_ref):
    """Test reading CURRY files."""
    with _record_warnings() as wrn:
        raw = read_raw_curry(fname, preload=preload)

    if not check_version('numpy',
                         '1.16') and preload and fname.endswith('ASCII.dat'):
        assert len(wrn) > 0
    else:
        assert len(wrn) == 0

    assert hasattr(raw, '_data') == preload
    assert raw.n_times == bdf_curry_ref.n_times
    assert raw.info['sfreq'] == bdf_curry_ref.info['sfreq']

    for field in ['kind', 'ch_name']:
        assert_array_equal([ch[field] for ch in raw.info['chs']],
                           [ch[field] for ch in bdf_curry_ref.info['chs']])

    assert_allclose(raw.get_data(verbose='error'),
                    bdf_curry_ref.get_data(),
                    atol=tol)

    picks, start, stop = ["C3", "C4"], 200, 800
    assert_allclose(raw.get_data(picks=picks,
                                 start=start,
                                 stop=stop,
                                 verbose='error'),
                    bdf_curry_ref.get_data(picks=picks, start=start,
                                           stop=stop),
                    rtol=tol)
    assert raw.info['dev_head_t'] is None
Example #11
def test_plot_raw_nan(raw, browser_backend):
    """Test plotting all NaNs."""
    raw._data[:] = np.nan
    # this should (at least) not die; the output should clearly show that
    # there is a problem, so it's probably okay to plot something blank
    with _record_warnings():
        raw.plot(scalings='auto')
Example #12
def test_read_raw_edf_stim_channel_input_parameters():
    """Test edf raw reader deprecation."""
    _MSG = "`read_raw_edf` is not supposed to trigger a deprecation warning"
    with _record_warnings() as recwarn:
        read_raw_edf(edf_path)
    assert all([w.category != DeprecationWarning for w in recwarn]), _MSG

    for invalid_stim_parameter in ['EDF Annotations', 'BDF Annotations']:
        with pytest.raises(ValueError, match="stim channel is not supported"):
            read_raw_edf(edf_path, stim_channel=invalid_stim_parameter)
Example #13
def test_plot_raw_meas_date(raw, browser_backend):
    """Test effect of mismatched meas_date in raw.plot()."""
    raw.set_meas_date(_dt_to_stamp(raw.info['meas_date'])[0])
    annot = Annotations([1 + raw.first_samp / raw.info['sfreq']], [5], ['bad'])
    with pytest.warns(RuntimeWarning, match='outside data range'):
        raw.set_annotations(annot)
    with _record_warnings():  # sometimes projection
        raw.plot(group_by='position', order=np.arange(8))
    fig = raw.plot()
    for key in ['down', 'up', 'escape']:
        fig._fake_keypress(key, fig=fig.mne.fig_selection)
Example #14
def test_plot_volume_source_estimates(mode, stype, init_t, want_t, init_p,
                                      want_p, bg_img):
    """Test interactive plotting of volume source estimates."""
    forward = read_forward_solution(fwd_fname)
    sample_src = forward['src']
    if init_p is not None:
        init_p = np.array(init_p) / 1000.

    vertices = [s['vertno'] for s in sample_src]
    n_verts = sum(len(v) for v in vertices)
    n_time = 2
    data = np.random.RandomState(0).rand(n_verts, n_time)

    if stype == 'vec':
        stc = VolVectorSourceEstimate(np.tile(data[:, np.newaxis], (1, 3, 1)),
                                      vertices, 1, 1)
    else:
        assert stype == 's'
        stc = VolSourceEstimate(data, vertices, 1, 1)
    # sometimes get scalars/index warning
    with _record_warnings():
        with catch_logging() as log:
            fig = stc.plot(sample_src,
                           subject='sample',
                           subjects_dir=subjects_dir,
                           mode=mode,
                           initial_time=init_t,
                           initial_pos=init_p,
                           bg_img=bg_img,
                           verbose=True)
    log = log.getvalue()
    want_str = 't = %0.3f s' % want_t
    assert want_str in log, (want_str, init_t)
    want_str = '(%0.1f, %0.1f, %0.1f) mm' % want_p
    assert want_str in log, (want_str, init_p)
    for ax_idx in [0, 2, 3, 4]:
        _fake_click(fig, fig.axes[ax_idx], (0.3, 0.5))
    fig.canvas.key_press_event('left')
    fig.canvas.key_press_event('shift+right')
    if bg_img is not None:
        with pytest.raises(FileNotFoundError, match='MRI file .* not found'):
            stc.plot(sample_src,
                     subject='sample',
                     subjects_dir=subjects_dir,
                     mode='stat_map',
                     bg_img='junk.mgz')
    use_ax = None
    for ax in fig.axes:
        if ax.get_xlabel().startswith('Time'):
            use_ax = ax
            break
    assert use_ax is not None
    label = use_ax.get_legend().get_texts()[0].get_text()
    assert re.match('[0-9]*', label) is not None, label
Example #15
def test_cluster_permutation_t_test(numba_conditional, stat_fun):
    """Test cluster level permutations T-test."""
    condition1_1d, condition2_1d, condition1_2d, condition2_2d = \
        _get_conditions()

    # use a very large sigma to make sure Ts are not independent
    for condition1, p in ((condition1_1d, 0.01), (condition1_2d, 0.01)):
        # these are so significant we can get away with fewer perms
        T_obs, clusters, cluster_p_values, hist =\
            permutation_cluster_1samp_test(condition1, n_permutations=100,
                                           tail=0, seed=1, out_type='mask',
                                           buffer_size=None)
        assert_equal(np.sum(cluster_p_values < 0.05), 1)
        p_min = np.min(cluster_p_values)
        assert_allclose(p_min, p, atol=1e-6)

        T_obs_pos, c_1, cluster_p_values_pos, _ =\
            permutation_cluster_1samp_test(condition1, n_permutations=100,
                                           tail=1, threshold=1.67, seed=1,
                                           stat_fun=stat_fun, out_type='mask',
                                           buffer_size=None)

        T_obs_neg, _, cluster_p_values_neg, _ =\
            permutation_cluster_1samp_test(-condition1, n_permutations=100,
                                           tail=-1, threshold=-1.67,
                                           seed=1, stat_fun=stat_fun,
                                           buffer_size=None, out_type='mask')
        assert_array_equal(T_obs_pos, -T_obs_neg)
        assert_array_equal(cluster_p_values_pos < 0.05,
                           cluster_p_values_neg < 0.05)

        # test with 2 jobs and buffer_size enabled
        buffer_size = condition1.shape[1] // 10
        with _record_warnings():  # sometimes "independently"
            T_obs_neg_buff, _, cluster_p_values_neg_buff, _ = \
                permutation_cluster_1samp_test(
                    -condition1, n_permutations=100, tail=-1, out_type='mask',
                    threshold=-1.67, seed=1, n_jobs=2, stat_fun=stat_fun,
                    buffer_size=buffer_size)

        assert_array_equal(T_obs_neg, T_obs_neg_buff)
        assert_array_equal(cluster_p_values_neg, cluster_p_values_neg_buff)

        # Bad stat_fun
        with pytest.raises(TypeError, match='must be .* ndarray'):
            permutation_cluster_1samp_test(condition1,
                                           threshold=1,
                                           stat_fun=lambda x: None,
                                           out_type='mask')
        with pytest.raises(ValueError, match='not compatible'):
            permutation_cluster_1samp_test(condition1,
                                           threshold=1,
                                           stat_fun=lambda x: stat_fun(x)[:-1],
                                           out_type='mask')
Example #16
def test_report(tmp_path):
    """Test mne report."""
    check_usage(mne_report)
    tempdir = str(tmp_path)
    use_fname = op.join(tempdir, op.basename(raw_fname))
    shutil.copyfile(raw_fname, use_fname)
    with ArgvSetter(('-p', tempdir, '-i', use_fname, '-d', subjects_dir, '-s',
                     'sample', '--no-browser', '-m', '30')):
        with _record_warnings():  # contour levels
            mne_report.run()
    fnames = glob.glob(op.join(tempdir, '*.html'))
    assert len(fnames) == 1
Example #17
def test_stockwell_check_input():
    """Test input checker for stockwell."""
    # check for data size equal and unequal to a power of 2

    for last_dim in (127, 128):
        data = np.zeros((2, 10, last_dim))
        with _record_warnings():  # n_fft sometimes
            x_in, n_fft, zero_pad = _check_input_st(data, None)

        assert_equal(x_in.shape, (2, 10, 128))
        assert_equal(n_fft, 128)
        assert_equal(zero_pad, 128 - last_dim)
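
The assertions pin down the padding rule: the last axis is zero-padded up to the next power of two. A self-contained sketch of that rule (an assumption about what `_check_input_st` does internally, not its actual code):

import numpy as np

def next_pow2_pad(data):
    """Zero-pad the last axis up to the next power of two."""
    n = data.shape[-1]
    n_fft = 1 << max(n - 1, 0).bit_length()  # smallest power of two >= n
    pad = [(0, 0)] * (data.ndim - 1) + [(0, n_fft - n)]
    return np.pad(data, pad), n_fft, n_fft - n

x_in, n_fft, zero_pad = next_pow2_pad(np.zeros((2, 10, 127)))
assert x_in.shape == (2, 10, 128) and n_fft == 128 and zero_pad == 1
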
Example #18
def test_read_epochs(cur_system, version, use_info, monkeypatch):
    """Test comparing reading an Epochs object and the FieldTrip version."""
    pandas = _check_pandas_installed(strict=False)
    has_pandas = pandas is not False
    test_data_folder_ft = get_data_paths(cur_system)
    mne_epoched = get_epochs(cur_system)
    if use_info:
        info = get_raw_info(cur_system)
        ctx = nullcontext()
    else:
        info = None
        ctx = pytest.warns(**no_info_warning)

    cur_fname = os.path.join(test_data_folder_ft,
                             'epoched_%s.mat' % (version, ))
    if has_pandas:
        if version == 'v73' and not _has_h5py():
            with pytest.raises(ImportError):
                mne.io.read_epochs_fieldtrip(cur_fname, info)
            return
        with ctx:
            epoched_ft = mne.io.read_epochs_fieldtrip(cur_fname, info)
        assert isinstance(epoched_ft.metadata, pandas.DataFrame)
    else:
        with _record_warnings() as warn_record:
            if version == 'v73' and not _has_h5py():
                with pytest.raises(ImportError):
                    mne.io.read_epochs_fieldtrip(cur_fname, info)
                return
            epoched_ft = mne.io.read_epochs_fieldtrip(cur_fname, info)
            assert epoched_ft.metadata is None
            assert_warning_in_record(pandas_not_found_warning_msg, warn_record)
            if info is None:
                assert_warning_in_record(NOINFO_WARNING, warn_record)

    mne_data = mne_epoched.get_data()[:, :, :-1]
    ft_data = epoched_ft.get_data()

    check_data(mne_data, ft_data, cur_system)
    check_info_fields(mne_epoched, epoched_ft, use_info)

    # weird sfreq
    from mne.externals.pymatreader import read_mat

    def modify_mat(fname, variable_names=None, ignore_fields=None):
        out = read_mat(fname, variable_names, ignore_fields)
        if 'fsample' in out['data']:
            out['data']['fsample'] = np.repeat(out['data']['fsample'], 2)
        return out

    monkeypatch.setattr(mne.externals.pymatreader, 'read_mat', modify_mat)
    with pytest.warns(RuntimeWarning, match='multiple'):
        mne.io.read_epochs_fieldtrip(cur_fname, info)
Example #19
def test_dpss_windows():
    """Test computation of DPSS windows."""
    import nitime as ni
    N = 1000
    half_nbw = 4
    Kmax = int(2 * half_nbw)

    dpss, eigs = dpss_windows(N, half_nbw, Kmax, low_bias=False)
    with _record_warnings():  # conversions
        dpss_ni, eigs_ni = ni.algorithms.dpss_windows(N, half_nbw, Kmax)

    assert_array_almost_equal(dpss, dpss_ni)
    assert_array_almost_equal(eigs, eigs_ni)

    dpss, eigs = dpss_windows(N, half_nbw, Kmax, interp_from=200,
                              low_bias=False)
    with _record_warnings():  # conversions
        dpss_ni, eigs_ni = ni.algorithms.dpss_windows(N, half_nbw, Kmax,
                                                      interp_from=200)

    assert_array_almost_equal(dpss, dpss_ni)
    assert_array_almost_equal(eigs, eigs_ni)
Example #20
def test_bdip(fname_dip_, fname_bdip_, tmp_path):
    """Test bdip I/O."""
    # use the text-format dipole as the reference
    with _record_warnings():  # ignored fields
        dip = read_dipole(fname_dip_)
    # read binary
    orig_size = os.stat(fname_bdip_).st_size
    bdip = read_dipole(fname_bdip_)
    # test round-trip by writing and reading, too
    fname = tmp_path / 'test.bdip'
    bdip.save(fname)
    bdip_read = read_dipole(fname)
    write_size = os.stat(str(fname)).st_size
    assert orig_size == write_size
    assert len(dip) == len(bdip) == len(bdip_read) == 17
    dip_has_conf = fname_dip_ == fname_dip_xfit
    for kind, this_bdip in (('orig', bdip), ('read', bdip_read)):
        for key, atol in (('pos', 5e-5), ('ori', 5e-3), ('gof', 0.5e-1),
                          ('times', 5e-5), ('khi2', 1e-2)):
            d = getattr(dip, key)
            b = getattr(this_bdip, key)
            if key == 'khi2' and dip_has_conf:
                if d is not None:
                    assert_allclose(d,
                                    b,
                                    atol=atol,
                                    err_msg='%s: %s' % (kind, key))
                else:
                    assert b is None
        if dip_has_conf:
            # conf
            conf_keys = _BDIP_ERROR_KEYS + ('vol', )
            assert (set(this_bdip.conf.keys()) == set(dip.conf.keys()) ==
                    set(conf_keys))
            for key in conf_keys:
                d = dip.conf[key]
                b = this_bdip.conf[key]
                assert_allclose(
                    d,
                    b,
                    rtol=0.12,  # not so great, text I/O
                    err_msg='%s: %s' % (kind, key))
        # Not stored
        assert this_bdip.name is None
        assert this_bdip.nfree is None

        # Test whether indexing works
        this_bdip0 = this_bdip[0]
        _check_dipole(this_bdip0, 1)
Example #21
def test_documented():
    """Test that public functions and classes are documented."""
    doc_dir = op.abspath(op.join(op.dirname(__file__), '..', '..', 'doc'))
    doc_file = op.join(doc_dir, 'python_reference.rst')
    if not op.isfile(doc_file):
        pytest.skip('Documentation file not found: %s' % doc_file)
    api_files = ('covariance', 'creating_from_arrays', 'datasets', 'decoding',
                 'events', 'file_io', 'forward', 'inverse', 'logging',
                 'most_used_classes', 'mri', 'preprocessing',
                 'reading_raw_data', 'realtime', 'report', 'sensor_space',
                 'simulation', 'source_space', 'statistics', 'time_frequency',
                 'visualization', 'export')
    known_names = list()
    for api_file in api_files:
        with open(op.join(doc_dir, f'{api_file}.rst'), 'rb') as fid:
            for line in fid:
                line = line.decode('utf-8')
                if not line.startswith('  '):  # at least two spaces
                    continue
                line = line.split()
                if len(line) == 1 and line[0] != ':':
                    known_names.append(line[0].split('.')[-1])
    known_names = set(known_names)

    missing = []
    for name in public_modules:
        with _record_warnings():  # traits warnings
            module = __import__(name, globals())
        for submod in name.split('.')[1:]:
            module = getattr(module, submod)
        classes = inspect.getmembers(module, inspect.isclass)
        functions = inspect.getmembers(module, inspect.isfunction)
        checks = list(classes) + list(functions)
        for name, cf in checks:
            if not name.startswith('_') and name not in known_names:
                from_mod = inspect.getmodule(cf).__name__
                if (from_mod.startswith('mne')
                        and not from_mod.startswith('mne.externals')
                        and not any(
                            from_mod.startswith(x)
                            for x in documented_ignored_mods)
                        and name not in documented_ignored_names
                        and not hasattr(cf, '_deprecated_original')):
                    missing.append('%s (%s.%s)' % (name, from_mod, name))
    if len(missing) > 0:
        raise AssertionError('\n\nFound new public members missing from '
                             'doc/python_reference.rst:\n\n* ' +
                             '\n* '.join(sorted(set(missing))))
Example #22
def test_plot_head_positions():
    """Test plotting of head positions."""
    info = read_info(evoked_fname)
    pos = np.random.RandomState(0).randn(4, 10)
    pos[:, 0] = np.arange(len(pos))
    destination = (0., 0., 0.04)
    with _record_warnings():  # old MPL will cause a warning
        plot_head_positions(pos)
        plot_head_positions(pos, mode='field', info=info,
                            destination=destination)
        plot_head_positions([pos, pos])  # list support
        pytest.raises(ValueError, plot_head_positions, ['pos'])
        pytest.raises(ValueError, plot_head_positions, pos[:, :9])
    pytest.raises(ValueError, plot_head_positions, pos, 'foo')
    with pytest.raises(ValueError, match='shape'):
        plot_head_positions(pos, axes=1.)
Example #23
def test_compute_proj_exg(tmp_path, fun):
    """Test mne compute_proj_ecg/eog."""
    check_usage(fun)
    tempdir = str(tmp_path)
    use_fname = op.join(tempdir, op.basename(raw_fname))
    bad_fname = op.join(tempdir, 'bads.txt')
    with open(bad_fname, 'w') as fid:
        fid.write('MEG 2443\n')
    shutil.copyfile(raw_fname, use_fname)
    with ArgvSetter(
        ('-i', use_fname, '--bad=' + bad_fname, '--rej-eeg', '150')):
        with _record_warnings():  # samples, sometimes
            fun.run()
    fnames = glob.glob(op.join(tempdir, '*proj.fif'))
    assert len(fnames) == 1
    fnames = glob.glob(op.join(tempdir, '*-eve.fif'))
    assert len(fnames) == 1
Example #24
def test_tabs():
    """Test that there are no tabs in our source files."""
    for _, modname, ispkg in walk_packages(mne.__path__, prefix='mne.'):
        # because we don't import e.g. mne.tests w/mne
        if not ispkg and modname not in tab_ignores:
            # mod = importlib.import_module(modname)  # not py26 compatible!
            try:
                with _record_warnings():
                    __import__(modname)
            except Exception:  # can't import properly
                continue
            mod = sys.modules[modname]
            try:
                source = getsource(mod)
            except IOError:  # user probably should have run "make clean"
                continue
            assert '\t' not in source, ('"%s" has tabs, please remove them '
                                        'or add it to the ignore list' %
                                        modname)
Example #25
def test_make_morph_maps(tmp_path):
    """Test reading and creating morph maps."""
    # make a new fake subjects_dir
    tempdir = str(tmp_path)
    for subject in ('sample', 'sample_ds', 'fsaverage_ds'):
        os.mkdir(op.join(tempdir, subject))
        os.mkdir(op.join(tempdir, subject, 'surf'))
        regs = ('reg',
                'left_right') if subject == 'fsaverage_ds' else ('reg', )
        for hemi in ['lh', 'rh']:
            for reg in regs:
                args = [subject, 'surf', hemi + '.sphere.' + reg]
                copyfile(op.join(subjects_dir, *args), op.join(tempdir, *args))

    for subject_from, subject_to, xhemi in (
            ('fsaverage_ds', 'sample_ds', False),
            ('fsaverage_ds', 'fsaverage_ds', True)):
        # trigger the creation of morph-maps dir and create the map
        with catch_logging() as log:
            mmap = read_morph_map(subject_from,
                                  subject_to,
                                  tempdir,
                                  xhemi=xhemi,
                                  verbose=True)
        log = log.getvalue()
        assert 'does not exist' in log
        assert 'Creating' in log
        mmap2 = read_morph_map(subject_from,
                               subject_to,
                               subjects_dir,
                               xhemi=xhemi)
        assert len(mmap) == len(mmap2)
        for m1, m2 in zip(mmap, mmap2):
            # deal with sparse matrix stuff
            diff = (m1 - m2).data
            assert_allclose(diff, np.zeros_like(diff), atol=1e-3, rtol=0)

    # This will also trigger creation, but it's trivial
    with _record_warnings():
        mmap = read_morph_map('sample', 'sample', subjects_dir=tempdir)
    for mm in mmap:
        assert (mm - sparse.eye(mm.shape[0], mm.shape[0])).sum() == 0
Example #26
def test_plot_volume_source_estimates_morph():
    """Test interactive plotting of volume source estimates with morph."""
    forward = read_forward_solution(fwd_fname)
    sample_src = forward['src']
    vertices = [s['vertno'] for s in sample_src]
    n_verts = sum(len(v) for v in vertices)
    n_time = 2
    data = np.random.RandomState(0).rand(n_verts, n_time)
    stc = VolSourceEstimate(data, vertices, 1, 1)
    sample_src[0]['subject_his_id'] = 'sample'  # old src
    morph = compute_source_morph(sample_src,
                                 'sample',
                                 'fsaverage',
                                 zooms=5,
                                 subjects_dir=subjects_dir)
    initial_pos = (-0.05, -0.01, -0.006)
    # sometimes get scalars/index warning
    with _record_warnings():
        with catch_logging() as log:
            stc.plot(morph,
                     subjects_dir=subjects_dir,
                     mode='glass_brain',
                     initial_pos=initial_pos,
                     verbose=True)
    log = log.getvalue()
    assert 't = 1.000 s' in log
    assert '(-52.0, -8.0, -7.0) mm' in log

    with pytest.raises(ValueError, match='Allowed values are'):
        stc.plot(sample_src, 'sample', subjects_dir, mode='abcd')
    vertices.append([])
    surface_stc = SourceEstimate(data, vertices, 1, 1)
    with pytest.raises(TypeError, match='an instance of VolSourceEstimate'):
        plot_volume_source_estimates(surface_stc, sample_src, 'sample',
                                     subjects_dir)
    with pytest.raises(ValueError, match='Negative colormap limits'):
        stc.plot(sample_src,
                 'sample',
                 subjects_dir,
                 clim=dict(lims=[-1, 2, 3], kind='value'))
Example #27
def test_tf_mxne():
    """Test convergence of TF-MxNE solver."""
    alpha_space = 10.
    alpha_time = 5.

    M, G, active_set = _generate_tf_data()

    with _record_warnings():  # CD
        X_hat_tf, active_set_hat_tf, E, gap_tfmxne = tf_mixed_norm_solver(
            M,
            G,
            alpha_space,
            alpha_time,
            maxit=200,
            tol=1e-8,
            verbose=True,
            n_orient=1,
            tstep=4,
            wsize=32,
            return_gap=True)
    assert_array_less(gap_tfmxne, 1e-8)
    assert_array_equal(np.where(active_set_hat_tf)[0], active_set)
Example #28
def test_label_fill_restrict(fname):
    """Test label in fill and restrict."""
    src = read_source_spaces(src_fname)
    label = read_label(fname)

    # construct label from source space vertices
    label_src = label.restrict(src)
    vert_in_src = label_src.vertices
    values_in_src = label_src.values
    if check_version('scipy', '1.3') and fname == real_label_fname:
        # Check that we can auto-fill patch info quickly for one condition
        for s in src:
            s['nearest'] = None
        with _record_warnings():
            label_src = label_src.fill(src)
    else:
        label_src = label_src.fill(src)
    assert src[0]['nearest'] is not None

    # check label vertices
    vertices_status = np.in1d(src[0]['nearest'], label.vertices)
    vertices_in = np.nonzero(vertices_status)[0]
    vertices_out = np.nonzero(np.logical_not(vertices_status))[0]
    assert_array_equal(label_src.vertices, vertices_in)
    assert_array_equal(np.in1d(vertices_out, label_src.vertices), False)

    # check values
    value_idx = np.digitize(src[0]['nearest'][vertices_in], vert_in_src, True)
    assert_array_equal(label_src.values, values_in_src[value_idx])

    # test exception
    vertices = np.append([-1], vert_in_src)
    with pytest.raises(ValueError, match='does not contain all of the label'):
        Label(vertices, hemi='lh').fill(src)

    # test filling empty label
    label = Label([], hemi='lh')
    label.fill(src)
    assert_array_equal(label.vertices, np.array([], int))
Example #29
def test_docstring_parameters():
    """Test module docstring formatting."""
    from numpydoc import docscrape

    incorrect = []
    for name in public_modules:
        # Assert that by default we import all public names with `import mne`
        if name not in ('mne', 'mne.gui'):
            extra = name.split('.')[1]
            assert hasattr(mne, extra)
        with _record_warnings():  # traits warnings
            module = __import__(name, globals())
        for submod in name.split('.')[1:]:
            module = getattr(module, submod)
        classes = inspect.getmembers(module, inspect.isclass)
        for cname, cls in classes:
            if cname.startswith('_'):
                continue
            incorrect += check_parameters_match(cls)
            cdoc = docscrape.ClassDoc(cls)
            for method_name in cdoc.methods:
                method = getattr(cls, method_name)
                incorrect += check_parameters_match(method, cls=cls)
            if hasattr(cls, '__call__') and \
                    'of type object' not in str(cls.__call__) and \
                    'of ABCMeta object' not in str(cls.__call__):
                incorrect += check_parameters_match(cls.__call__, cls)
        functions = inspect.getmembers(module, inspect.isfunction)
        for fname, func in functions:
            if fname.startswith('_'):
                continue
            incorrect += check_parameters_match(func)
    incorrect = sorted(list(set(incorrect)))
    msg = '\n' + '\n'.join(incorrect)
    msg += '\n%d error%s' % (len(incorrect), _pl(incorrect))
    if len(incorrect) > 0:
        raise AssertionError(msg)
Example #30
def test_interpolation_eeg(offset, avg_proj, ctol, atol, method):
    """Test interpolation of EEG channels."""
    raw, epochs_eeg = _load_data('eeg')
    epochs_eeg = epochs_eeg.copy()
    assert not _has_eeg_average_ref_proj(epochs_eeg.info['projs'])
    # Offsetting the coordinate frame should have no effect on the output
    for inst in (raw, epochs_eeg):
        for ch in inst.info['chs']:
            if ch['kind'] == io.constants.FIFF.FIFFV_EEG_CH:
                ch['loc'][:3] += offset
                ch['loc'][3:6] += offset
        for d in inst.info['dig']:
            d['r'] += offset

    # check that interpolation does nothing if no bads are marked
    epochs_eeg.info['bads'] = []
    evoked_eeg = epochs_eeg.average()
    kw = dict(method=method)
    with pytest.warns(RuntimeWarning, match='Doing nothing'):
        evoked_eeg.interpolate_bads(**kw)

    # create good and bad channels for EEG
    epochs_eeg.info['bads'] = []
    goods_idx = np.ones(len(epochs_eeg.ch_names), dtype=bool)
    goods_idx[epochs_eeg.ch_names.index('EEG 012')] = False
    bads_idx = ~goods_idx
    pos = epochs_eeg._get_channel_positions()

    evoked_eeg = epochs_eeg.average()
    if avg_proj:
        evoked_eeg.set_eeg_reference(projection=True).apply_proj()
        assert_allclose(evoked_eeg.data.mean(0), 0., atol=1e-20)
    ave_before = evoked_eeg.data[bads_idx]

    # interpolate bad channels for EEG
    epochs_eeg.info['bads'] = ['EEG 012']
    evoked_eeg = epochs_eeg.average()
    if avg_proj:
        evoked_eeg.set_eeg_reference(projection=True).apply_proj()
        good_picks = pick_types(evoked_eeg.info, meg=False, eeg=True)
        assert_allclose(evoked_eeg.data[good_picks].mean(0), 0., atol=1e-20)
    evoked_eeg_bad = evoked_eeg.copy()
    bads_picks = pick_channels(epochs_eeg.ch_names,
                               include=epochs_eeg.info['bads'],
                               ordered=True)
    evoked_eeg_bad.data[bads_picks, :] = 1e10

    # Test first the exclude parameter
    evoked_eeg_2_bads = evoked_eeg_bad.copy()
    evoked_eeg_2_bads.info['bads'] = ['EEG 004', 'EEG 012']
    evoked_eeg_2_bads.data[pick_channels(evoked_eeg_bad.ch_names,
                                         ['EEG 004', 'EEG 012'])] = 1e10
    evoked_eeg_interp = evoked_eeg_2_bads.interpolate_bads(origin=(0., 0., 0.),
                                                           exclude=['EEG 004'],
                                                           **kw)
    assert evoked_eeg_interp.info['bads'] == ['EEG 004']
    assert np.all(evoked_eeg_interp.get_data('EEG 004') == 1e10)
    assert np.all(evoked_eeg_interp.get_data('EEG 012') != 1e10)

    # Now test without exclude parameter
    evoked_eeg_bad.info['bads'] = ['EEG 012']
    evoked_eeg_interp = evoked_eeg_bad.copy().interpolate_bads(
        origin=(0., 0., 0.), **kw)
    if avg_proj:
        assert_allclose(evoked_eeg_interp.data.mean(0), 0., atol=1e-6)
    interp_zero = evoked_eeg_interp.data[bads_idx]
    if method is None:  # using the default spline interpolation
        pos_good = pos[goods_idx]
        pos_bad = pos[bads_idx]
        interpolation = _make_interpolation_matrix(pos_good, pos_bad)
        assert interpolation.shape == (1, len(epochs_eeg.ch_names) - 1)
        interp_manual = np.dot(interpolation, evoked_eeg_bad.data[goods_idx])
        assert_array_equal(interp_manual, interp_zero)
        del interp_manual, interpolation, pos, pos_good, pos_bad
    assert_allclose(ave_before, interp_zero, atol=atol)
    assert ctol[0] < np.corrcoef(ave_before, interp_zero)[0, 1] < ctol[1]
    interp_fit = evoked_eeg_bad.copy().interpolate_bads(**kw).data[bads_idx]
    assert_allclose(ave_before, interp_fit, atol=2.5e-6)
    assert ctol[1] < np.corrcoef(ave_before, interp_fit)[0, 1]  # better

    # check that interpolation fails when preload is False
    epochs_eeg.preload = False
    with pytest.raises(RuntimeError, match='requires epochs data to be loade'):
        epochs_eeg.interpolate_bads(**kw)
    epochs_eeg.preload = True

    # check that interpolation changes the data in raw
    raw_eeg = io.RawArray(data=epochs_eeg._data[0], info=epochs_eeg.info)
    raw_before = raw_eeg._data[bads_idx]
    raw_after = raw_eeg.interpolate_bads(**kw)._data[bads_idx]
    assert not np.all(raw_before == raw_after)

    # check that interpolation fails when preload is False
    for inst in [raw, epochs_eeg]:
        assert hasattr(inst, 'preload')
        inst.preload = False
        inst.info['bads'] = [inst.ch_names[1]]
        with pytest.raises(RuntimeError, match='requires.*data to be loaded'):
            inst.interpolate_bads(**kw)

    # check that interpolation works with few channels
    raw_few = raw.copy().crop(0, 0.1).load_data()
    raw_few.pick_channels(raw_few.ch_names[:1] + raw_few.ch_names[3:4])
    assert len(raw_few.ch_names) == 2
    raw_few.del_proj()
    raw_few.info['bads'] = [raw_few.ch_names[-1]]
    orig_data = raw_few[1][0]
    with _record_warnings() as w:
        raw_few.interpolate_bads(reset_bads=False, **kw)
    assert len([ww for ww in w if 'more than' not in str(ww.message)]) == 0
    new_data = raw_few[1][0]
    assert (new_data == 0).mean() < 0.5
    assert np.corrcoef(new_data, orig_data)[0, 1] > 0.2