示例#1
0
def test_manifest_check_download(tmpdir, n_have, monkeypatch):
    """Test our manifest downloader.

    Parameters
    ----------
    tmpdir : object
        Temporary directory; ``str(tmpdir)`` is the root under which the
        manifest file and the destination directory are created.
    n_have : int
        Number of manifest entries that are created in the destination
        before the check runs. Must be in ``range(len(_zip_fnames) + 1)``.
    monkeypatch : object
        Used to replace ``datasets.utils._fetch_file`` with
        ``_fake_zip_fetch`` for the duration of the test.
    """
    monkeypatch.setattr(datasets.utils, '_fetch_file', _fake_zip_fetch)
    destination = op.join(str(tmpdir), 'empty')
    manifest_path = op.join(str(tmpdir), 'manifest.txt')
    # The manifest lists one expected file name per line
    with open(manifest_path, 'w') as fid:
        for fname in _zip_fnames:
            fid.write('%s\n' % fname)
    n_total = len(_zip_fnames)
    assert n_have in range(n_total + 1)
    assert not op.isdir(destination)
    if n_have > 0:
        os.makedirs(op.join(destination, 'foo'))
        assert op.isdir(op.join(destination, 'foo'))
    for fname in _zip_fnames:
        assert not op.isfile(op.join(destination, fname))
    # Pre-create (as empty files) the first ``n_have`` manifest entries
    for fname in _zip_fnames[:n_have]:
        with open(op.join(destination, fname), 'w'):
            pass
    with catch_logging() as log:
        with use_log_level(True):
            url = hash_ = ''  # we mock the _fetch_file so these are not used
            _manifest_check_download(manifest_path, destination, url, hash_)
    log = log.getvalue()
    # Derive the expected count from the manifest itself rather than assuming
    # it has exactly three entries
    n_missing = n_total - n_have
    assert ('%d file%s missing from' % (n_missing, _pl(n_missing))) in log
    # Extraction messages must appear only when something had to be fetched
    for want in ('Extracting missing', 'Successfully '):
        if n_missing > 0:
            assert want in log
        else:
            assert want not in log
    assert op.isdir(destination)
    for fname in _zip_fnames:
        assert op.isfile(op.join(destination, fname))
示例#2
0
def test_manifest_check_download(tmpdir, n_have, monkeypatch):
    """Test our manifest downloader.

    ``n_have`` of the files named in the manifest are created up front; the
    test then runs ``_manifest_check_download`` (with ``_fetch_file`` replaced
    by ``_fake_zip_fetch``) and checks both the log output and that every
    manifest entry exists afterwards.
    """
    monkeypatch.setattr(datasets.utils, '_fetch_file', _fake_zip_fetch)
    root = str(tmpdir)
    destination = op.join(root, 'empty')
    manifest_path = op.join(root, 'manifest.txt')
    with open(manifest_path, 'w') as fid:
        fid.writelines('%s\n' % fname for fname in _zip_fnames)
    assert n_have in range(len(_zip_fnames) + 1)
    assert not op.isdir(destination)
    if n_have > 0:
        os.makedirs(op.join(destination, 'foo'))
        assert op.isdir(op.join(destination, 'foo'))
    for fname in _zip_fnames:
        assert not op.isfile(op.join(destination, fname))
    # Touch the files that are supposed to be present already
    for fname in _zip_fnames[:n_have]:
        with open(op.join(destination, fname), 'w'):
            pass
    with catch_logging() as log, use_log_level(True):
        url = hash_ = ''  # we mock the _fetch_file so these are not used
        _manifest_check_download(manifest_path, destination, url, hash_)
    log = log.getvalue()
    # Use the real manifest length instead of a hard-coded 3 so the expected
    # count stays in sync with the range check on ``n_have`` above
    n_missing = len(_zip_fnames) - n_have
    assert ('%d file%s missing from' % (n_missing, _pl(n_missing))) in log
    expect_extraction = n_missing > 0
    for want in ('Extracting missing', 'Successfully '):
        # These messages should be logged only if files had to be extracted
        assert (want in log) == expect_extraction
    assert op.isdir(destination)
    for fname in _zip_fnames:
        assert op.isfile(op.join(destination, fname))
def test_docstring_parameters():
    """Test module docstring formatting.

    Every public class and function of each module in ``public_modules`` is
    passed to ``check_parameters_match``; the de-duplicated, sorted problems
    it returns are reported together in a single ``AssertionError``.
    """
    import warnings
    # skip modules that require mayavi if mayavi is not installed
    public_modules_ = public_modules[:]
    try:
        import mayavi  # noqa: F401 analysis:ignore
    except ImportError:
        pass
    else:
        public_modules_.append('mne.gui')

    incorrect = []
    for name in public_modules_:
        # Assert that by default we import all public names with `import mne`
        if name not in ('mne', 'mne.gui'):
            extra = name.split('.')[1]
            assert hasattr(mne, extra)
        # Record import-time warnings (traits warnings) instead of letting
        # them propagate. ``pytest.warns(None)`` is deprecated since pytest 7
        # and rejected by pytest 8, so use the stdlib recorder; the "always"
        # filter mirrors what the pytest recorder installed.
        with warnings.catch_warnings(record=True):
            warnings.simplefilter('always')
            module = __import__(name, globals())
        # __import__ returns the top-level package: walk down to the submodule
        for submod in name.split('.')[1:]:
            module = getattr(module, submod)
        classes = inspect.getmembers(module, inspect.isclass)
        for cname, cls in classes:
            if cname.startswith('_'):
                continue
            incorrect += check_parameters_match(cls)
        functions = inspect.getmembers(module, inspect.isfunction)
        for fname, func in functions:
            if fname.startswith('_'):
                continue
            incorrect += check_parameters_match(func)
    # The same object can be reachable from several modules: report it once
    incorrect = sorted(set(incorrect))
    if incorrect:
        msg = '\n' + '\n'.join(incorrect)
        msg += '\n%d error%s' % (len(incorrect), _pl(incorrect))
        raise AssertionError(msg)
def test_docstring_parameters():
    """Test module docstring formatting.

    For each module in ``public_modules`` the public classes (plus the
    methods numpydoc lists for them and any custom ``__call__``) and public
    functions are passed to ``check_parameters_match``. All problems are
    collected, de-duplicated and raised together as one ``AssertionError``.
    """
    import warnings
    from numpydoc import docscrape
    # skip modules that require mayavi if mayavi is not installed
    public_modules_ = public_modules[:]
    try:
        import mayavi  # noqa: F401 analysis:ignore
    except ImportError:
        pass
    else:
        public_modules_.append('mne.gui')

    incorrect = []
    for name in public_modules_:
        # Assert that by default we import all public names with `import mne`
        if name not in ('mne', 'mne.gui'):
            extra = name.split('.')[1]
            assert hasattr(mne, extra)
        # Record import-time warnings (traits warnings) instead of letting
        # them propagate. ``pytest.warns(None)`` is deprecated since pytest 7
        # and rejected by pytest 8, so use the stdlib recorder; the "always"
        # filter mirrors what the pytest recorder installed.
        with warnings.catch_warnings(record=True):
            warnings.simplefilter('always')
            module = __import__(name, globals())
        # __import__ returns the top-level package: walk down to the submodule
        for submod in name.split('.')[1:]:
            module = getattr(module, submod)
        classes = inspect.getmembers(module, inspect.isclass)
        # XXX eventually this should be public
        if name == 'mne.viz':
            from mne.viz._brain import _Brain
            classes.append(('Brain', _Brain))
        for cname, cls in classes:
            if cname.startswith('_'):
                continue
            incorrect += check_parameters_match(cls)
            cdoc = docscrape.ClassDoc(cls)
            for method_name in cdoc.methods:
                method = getattr(cls, method_name)
                incorrect += check_parameters_match(method, cls=cls)
            # Skip __call__ when its repr contains 'of type object'
            # (presumably the one inherited from ``type`` -- TODO confirm)
            # NOTE(review): ``cls`` is passed positionally here but by
            # keyword for methods above -- confirm against the signature of
            # check_parameters_match.
            if hasattr(cls, '__call__') and \
                    'of type object' not in str(cls.__call__):
                incorrect += check_parameters_match(cls.__call__, cls)
        functions = inspect.getmembers(module, inspect.isfunction)
        for fname, func in functions:
            if fname.startswith('_'):
                continue
            incorrect += check_parameters_match(func)
    # The same object can be reachable from several modules: report it once
    incorrect = sorted(set(incorrect))
    if incorrect:
        msg = '\n' + '\n'.join(incorrect)
        msg += '\n%d error%s' % (len(incorrect), _pl(incorrect))
        raise AssertionError(msg)
示例#5
0
def test_docstring_parameters():
    """Check that public docstrings agree with what they document.

    Walks every module in ``public_modules``, runs ``check_parameters_match``
    on each public class, on the methods numpydoc reports for it, on a custom
    ``__call__`` and on each public function, then raises one
    ``AssertionError`` listing all unique problems.
    """
    from numpydoc import docscrape

    problems = []
    for mod_name in public_modules:
        parts = mod_name.split('.')
        # Every public submodule must be reachable right after `import mne`
        if mod_name not in ('mne', 'mne.gui'):
            assert hasattr(mne, parts[1])
        with _record_warnings():  # traits warnings
            module = __import__(mod_name, globals())
        # __import__ hands back the top-level package, so descend to the leaf
        for attr in parts[1:]:
            module = getattr(module, attr)

        for cls_name, klass in inspect.getmembers(module, inspect.isclass):
            if cls_name.startswith('_'):
                continue
            problems += check_parameters_match(klass)
            for meth_name in docscrape.ClassDoc(klass).methods:
                meth = getattr(klass, meth_name)
                problems += check_parameters_match(meth, cls=klass)
            if not hasattr(klass, '__call__'):
                continue
            # Only check __call__ when its repr carries neither marker
            # NOTE(review): ``klass`` is passed positionally below but by
            # keyword for methods above -- confirm against the signature of
            # check_parameters_match.
            call_repr = str(klass.__call__)
            skip_markers = ('of type object', 'of ABCMeta object')
            if not any(marker in call_repr for marker in skip_markers):
                problems += check_parameters_match(klass.__call__, klass)

        for func_name, func in inspect.getmembers(module,
                                                  inspect.isfunction):
            if not func_name.startswith('_'):
                problems += check_parameters_match(func)

    unique = sorted(set(problems))
    summary = '\n%d error%s' % (len(unique), _pl(unique))
    msg = '\n' + '\n'.join(unique) + summary
    if unique:
        raise AssertionError(msg)
示例#6
0
文件: _report.py 项目: LABSN/mnefun
def _proj_fig(fname, info, proj_nums, proj_meg, kind):
    """Plot projector topomaps (and evoked traces, if available).

    Parameters
    ----------
    fname : str
        Projector file passed to ``read_proj``. An epochs file is looked for
        at the same path with ``'-proj.fif'`` replaced by ``'-epo.fif'``;
        when it exists, its average is plotted below each topomap together
        with the projection of that average onto the projector.
    info : object
        Measurement info; used for channel picking (``info['ch_names']``,
        ``mne.pick_types``) and passed on to ``plot_projs_topomap``.
    proj_nums : array-like of int, shape (3,)
        Expected number of projectors for the ``'grad'``, ``'mag'`` and
        ``'eeg'`` channel types, in that order. Types with a count of 0 are
        skipped.
    proj_meg : str
        If ``'separate'``, each projector may match only one channel type;
        otherwise a projector may be matched by up to two channel types.
    kind : str
        Only used in the error message raised on a count mismatch.

    Returns
    -------
    fig : matplotlib figure
        The figure created for the plots.

    Raises
    ------
    RuntimeError
        If the number of projectors covering a channel type differs from the
        corresponding entry of ``proj_nums``.
    """
    import matplotlib.pyplot as plt
    proj_nums = np.array(proj_nums, int)
    assert proj_nums.shape == (3,)
    projs = read_proj(fname)
    epochs_fname = fname.replace('-proj.fif', '-epo.fif')
    n_col = proj_nums.max()
    rs_topo = 3  # grid rows spanned by each topomap
    if op.isfile(epochs_fname):
        evoked = mne.read_epochs(epochs_fname).average()
        rs_trace = 2  # grid rows spanned by each trace plot
    else:
        evoked = None
        rs_trace = 0
    rs_total = rs_topo + rs_trace
    n_row = proj_nums.astype(bool).sum() * rs_total
    shape = (n_row, n_col)
    fig = plt.figure(figsize=(n_col * 2, n_row * 0.75))
    used = np.zeros(len(projs), int)  # how often each projector was plotted
    ri = 0
    for count, ch_type in zip(proj_nums, ('grad', 'mag', 'eeg')):
        if count == 0:
            continue
        if ch_type == 'eeg':
            meg, eeg = False, True
        else:
            meg, eeg = ch_type, False
        ch_names = [info['ch_names'][pick]
                    for pick in mne.pick_types(info, meg=meg, eeg=eeg)]
        # Projectors whose columns include every channel of this type
        # (np.isin replaces np.in1d, which is deprecated in NumPy 2.0)
        idx = np.where([np.isin(ch_names, proj['data']['col_names']).all()
                        for proj in projs])[0]
        if len(idx) != count:
            raise RuntimeError('Expected %d %s projector%s for channel type '
                               '%s based on proj_nums but got %d in %s'
                               % (count, kind, _pl(count), ch_type, len(idx),
                                  fname))
        if proj_meg == 'separate':
            assert not used[idx].any()
        else:
            assert (used[idx] <= 1).all()
        used[idx] += 1
        # Work on copies restricted (and reordered) to this type's channels
        these_projs = [deepcopy(projs[ii]) for ii in idx]
        for proj in these_projs:
            sub_idx = [proj['data']['col_names'].index(name)
                       for name in ch_names]
            proj['data']['data'] = proj['data']['data'][:, sub_idx]
            proj['data']['col_names'] = ch_names
        topo_axes = [plt.subplot2grid(
            shape, (ri * rs_total, ci),
            rowspan=rs_topo) for ci in range(count)]
        # topomaps
        with warnings.catch_warnings(record=True):
            plot_projs_topomap(these_projs, info=info, show=False,
                               axes=topo_axes)
        plt.setp(topo_axes, title='', xlabel='')
        topo_axes[0].set(ylabel=ch_type)
        if rs_trace:
            trace_axes = [plt.subplot2grid(
                shape, (ri * rs_total + rs_topo, ci),
                rowspan=rs_trace) for ci in range(count)]
            for proj, ax in zip(these_projs, trace_axes):
                this_evoked = evoked.copy().pick_channels(ch_names)
                p = proj['data']['data']
                assert p.shape == (1, len(this_evoked.data))
                with warnings.catch_warnings(record=True):  # tight_layout
                    this_evoked.plot(
                        picks=np.arange(len(this_evoked.data)), axes=[ax])
                # Remove the text annotations added by the plot. Assigning
                # ``ax.texts = []`` fails on Matplotlib >= 3.5, where the
                # attribute is a read-only property.
                for text in list(ax.texts):
                    text.remove()
                trace = np.dot(p, this_evoked.data)[0]
                # Rescale the projected trace to 80% of the y-limits; skip
                # the rescale for an all-zero trace to avoid dividing by 0
                peak = np.abs(trace).max()
                if peak > 0:
                    trace *= 0.8 * (np.abs(ax.get_ylim()).max() / peak)
                ax.plot(this_evoked.times, trace, color='#9467bd')
                ax.set(title='', ylabel='', xlabel='')
        ri += 1
    # Every projector in the file must have been plotted once or twice
    assert used.all() and (used <= 2).all()
    fig.subplots_adjust(0.1, 0.1, 0.95, 1, 0.3, 0.3)
    return fig