예제 #1
0
def test_api():
    """Test LCMV/DICS API equivalence.

    After mapping DICS's cross-spectral-density argument names onto the
    LCMV covariance names and dropping the DICS-only ``real_filter``
    option, both beamformer constructors must take the same arguments in
    the same order.
    """
    lcmv_args = _get_args(make_lcmv)
    dics_args = _get_args(make_dics)
    for dics_name, lcmv_name in (('csd', 'data_cov'),
                                 ('noise_csd', 'noise_cov')):
        dics_args[dics_args.index(dics_name)] = lcmv_name
    dics_args.remove('real_filter')  # not a thing for LCMV
    assert lcmv_args == dics_args
예제 #2
0
def matplotlib_config():
    """Configure matplotlib (and traits/mayavi when importable) for viz tests.

    Selects the non-interactive Agg backend, turns interactive mode off and
    pins ``figure.dpi`` to 100. If traits and mayavi can be imported, their
    toolkit/backend are switched as well; import failures are tolerated.
    """
    import matplotlib
    # Only pass ``warn`` to ``matplotlib.use`` if this version accepts it.
    use_kwargs = {'warn': False} if 'warn' in _get_args(matplotlib.use) else {}
    # "force" should not really be necessary but should not hurt;
    # Agg does not pop up windows.
    matplotlib.use('agg', force=True, **use_kwargs)
    import matplotlib.pyplot as plt
    assert plt.get_backend() == 'agg'
    # Override params users may have changed locally that can badly slow
    # down tests (they should not otherwise affect functionality).
    plt.ioff()
    plt.rcParams['figure.dpi'] = 100

    try:
        from traits.etsconfig.api import ETSConfig
    except Exception:  # traits could not be imported; skip it
        ETSConfig = None
    if ETSConfig is not None:
        ETSConfig.toolkit = 'qt4'

    try:
        with warnings.catch_warnings(record=True):  # traits
            from mayavi import mlab
    except Exception:  # mayavi could not be imported; skip it
        mlab = None
    if mlab is not None:
        mlab.options.backend = 'test'
예제 #3
0
def matplotlib_config():
    """Set up a headless, fast matplotlib configuration for viz tests.

    The Agg backend is forced, interactive mode is disabled and the figure
    DPI is reset to 100. When traits/mayavi import cleanly, traits is set to
    the ``'qt4'`` toolkit and mayavi's mlab to its ``'test'`` backend.
    """
    import matplotlib
    extra = dict()
    if 'warn' in _get_args(matplotlib.use):  # older matplotlib signature
        extra['warn'] = False
    # "force" should not really be necessary but should not hurt
    matplotlib.use('agg', force=True, **extra)  # don't pop up windows
    import matplotlib.pyplot as plt
    assert plt.get_backend() == 'agg'
    # Locally customized rcParams can make tests very slow.
    plt.ioff()
    plt.rcParams['figure.dpi'] = 100

    try:
        from traits.etsconfig.api import ETSConfig
    except Exception:
        ETSConfig = None
    if ETSConfig is not None:
        ETSConfig.toolkit = 'qt4'

    try:
        with warnings.catch_warnings(record=True):  # traits
            from mayavi import mlab
    except Exception:
        mlab = None
    if mlab is not None:
        mlab.options.backend = 'test'
예제 #4
0
def check_parameters_match(func, doc=None):
    """Compare a function's signature with its documented parameters.

    Parameters
    ----------
    func : callable
        Function or method to inspect. A leading ``self`` argument is
        ignored.
    doc : object | None
        Pre-parsed docstring supporting ``doc['Parameters']`` (e.g. a numpydoc
        ``FunctionDoc``). If None, ``func`` is parsed with
        ``numpydoc.docscrape.FunctionDoc``.

    Returns
    -------
    incorrect : list of str
        One message per mismatch. The list is empty when names and order
        agree. It is also empty when ``func``'s name is outside ``mne.``,
        inside ``mne.externals``, or when ``func`` is a data descriptor.

    Raises
    ------
    RuntimeError
        If any warning is recorded while parsing the docstring.
    """
    from numpydoc import docscrape
    incorrect = []
    qualname = get_name(func)
    outside_mne = (not qualname.startswith('mne.') or
                   qualname.startswith('mne.externals'))
    if outside_mne or inspect.isdatadescriptor(func):
        return incorrect

    sig_args = _get_args(func)
    if sig_args and sig_args[0] == 'self':
        sig_args = sig_args[1:]

    if doc is None:
        with warnings.catch_warnings(record=True) as caught:
            doc = docscrape.FunctionDoc(func)
        if caught:
            raise RuntimeError('Error for %s:\n%s' % (qualname, caught[0]))

    # Normalize docscrape names: drop any ": type" suffix and surrounding
    # backticks/spaces, and skip *args / **kwargs entries.
    doc_args = []
    for raw, _, _ in doc['Parameters']:
        cleaned = raw.split(':')[0].strip('` ')
        if '*' not in cleaned:
            doc_args.append(cleaned)

    if len(doc_args) != len(sig_args):
        # Names present on only one side, sorted for a stable message.
        bad = str(sorted(set(doc_args) ^ set(sig_args)))
        ignored = any(pat in qualname for pat in _docstring_ignores)
        if not ignored and \
                'deprecation_wrapped' not in func.__code__.co_name:
            incorrect.append(qualname + ' arg mismatch: ' + bad)
        return incorrect

    for doc_name, arg_name in zip(doc_args, sig_args):
        if doc_name != arg_name:
            incorrect.append(qualname + ' ' + doc_name + ' != ' + arg_name)
    return incorrect
def check_parameters_match(func, doc=None):
    """Check a docstring against the signature, returning a list of problems.

    Relies on a module-level ``docscrape`` (numpydoc) when ``doc`` is None.
    Each entry of the returned list describes one discrepancy between the
    argument names of ``func`` (minus a leading ``self``) and the names in
    its ``Parameters`` section. A RuntimeError is raised if parsing the
    docstring records any warning.
    """
    problems = []
    func_name = get_name(func)
    if not func_name.startswith("mne.") or func_name.startswith("mne.externals"):
        return problems
    if inspect.isdatadescriptor(func):
        return problems

    arg_names = _get_args(func)
    if arg_names and arg_names[0] == "self":
        arg_names = arg_names[1:]

    if doc is None:
        with warnings.catch_warnings(record=True) as recorded:
            doc = docscrape.FunctionDoc(func)
        if recorded:
            raise RuntimeError("Error for %s:\n%s" % (func_name, recorded[0]))

    raw_names = [entry_name for entry_name, _, _ in doc["Parameters"]]
    stripped = (entry_name.split(":")[0].strip("` ") for entry_name in raw_names)
    documented = [entry_name for entry_name in stripped if "*" not in entry_name]

    if len(documented) == len(arg_names):
        problems.extend(
            func_name + " " + doc_name + " != " + real_name
            for doc_name, real_name in zip(documented, arg_names)
            if doc_name != real_name
        )
        return problems

    bad = str(sorted(set(documented).symmetric_difference(arg_names)))
    if any(d in func_name for d in _docstring_ignores):
        return problems
    if "deprecation_wrapped" not in func.__code__.co_name:
        problems.append(func_name + " arg mismatch: " + bad)
    return problems
def check_parameters_match(func, doc=None):
    """Compare a function's signature with its documented parameters.

    Parameters
    ----------
    func : callable
        Function or method to inspect. A leading ``self`` argument is
        ignored.
    doc : object | None
        Pre-parsed docstring supporting ``doc['Parameters']`` (e.g. a numpydoc
        ``FunctionDoc``). If None, ``func`` is parsed with
        ``numpydoc.docscrape.FunctionDoc``.

    Returns
    -------
    incorrect : list of str
        One message per mismatch. The list is empty when names and order
        agree. It is also empty when ``func``'s name is outside ``mne.``,
        inside ``mne.externals``, or when ``func`` is a data descriptor.

    Raises
    ------
    RuntimeError
        If any warning is recorded while parsing the docstring.
    """
    from numpydoc import docscrape
    incorrect = []
    qualname = get_name(func)
    outside_mne = (not qualname.startswith('mne.') or
                   qualname.startswith('mne.externals'))
    if outside_mne or inspect.isdatadescriptor(func):
        return incorrect

    sig_args = _get_args(func)
    if sig_args and sig_args[0] == 'self':
        sig_args = sig_args[1:]

    if doc is None:
        with warnings.catch_warnings(record=True) as caught:
            doc = docscrape.FunctionDoc(func)
        if caught:
            raise RuntimeError('Error for %s:\n%s' % (qualname, caught[0]))

    # Normalize docscrape names: drop any ": type" suffix and surrounding
    # backticks/spaces, and skip *args / **kwargs entries.
    doc_args = []
    for raw, _, _ in doc['Parameters']:
        cleaned = raw.split(':')[0].strip('` ')
        if '*' not in cleaned:
            doc_args.append(cleaned)

    if len(doc_args) != len(sig_args):
        # Names present on only one side, sorted for a stable message.
        bad = str(sorted(set(doc_args) ^ set(sig_args)))
        ignored = any(pat in qualname for pat in _docstring_ignores)
        if not ignored and \
                'deprecation_wrapped' not in func.__code__.co_name:
            incorrect.append(qualname + ' arg mismatch: ' + bad)
        return incorrect

    for doc_name, arg_name in zip(doc_args, sig_args):
        if doc_name != arg_name:
            incorrect.append(qualname + ' ' + doc_name + ' != ' + arg_name)
    return incorrect
예제 #7
0
def check_parameters_match(func, doc=None, cls=None):
    """Check docstring, return list of incorrect results.

    Compares the argument names of ``func`` (minus a leading ``self``)
    against the names in the ``Parameters`` section of its numpydoc
    docstring. It also checks that each parameter description stays within
    ``char_limit`` characters.

    Parameters
    ----------
    func : callable
        The function or method to check.
    doc : object | None
        Pre-parsed docstring supporting ``doc['Parameters']`` (e.g. a numpydoc
        ``FunctionDoc``). If None, ``func`` is parsed here.
    cls : type | None
        Passed through to ``get_name`` to build the qualified name.

    Returns
    -------
    incorrect : list of str
        One human-readable message per problem found; empty if none.

    Raises
    ------
    RuntimeError
        If any warning is emitted while parsing the docstring.
    """
    import warnings
    from numpydoc import docscrape
    incorrect = []
    name_ = get_name(func, cls=cls)
    if not name_.startswith('mne.') or name_.startswith('mne.externals'):
        return incorrect
    if inspect.isdatadescriptor(func):
        return incorrect
    args = _get_args(func)
    # drop self
    if len(args) > 0 and args[0] == 'self':
        args = args[1:]

    if doc is None:
        # ``pytest.warns(None)`` was deprecated in pytest 7 and removed in
        # pytest 8. Record every warning explicitly instead; 'always' matches
        # what the pytest recorder did, so repeated warnings are not hidden.
        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter('always')
            try:
                doc = docscrape.FunctionDoc(func)
            except Exception as exp:
                incorrect += [name_ + ' parsing error: ' + str(exp)]
                return incorrect
        if len(w):
            raise RuntimeError('Error for %s:\n%s' % (name_, w[0]))
    # check set
    parameters = doc['Parameters']
    # clean up some docscrape output: keep (name, description-lines),
    # strip ": type" suffixes/backticks, and drop *args / **kwargs entries
    parameters = [[p[0].split(':')[0].strip('` '), p[2]] for p in parameters]
    parameters = [p for p in parameters if '*' not in p[0]]
    param_names = [p[0] for p in parameters]
    if len(param_names) != len(args):
        bad = str(
            sorted(
                list(set(param_names) - set(args)) +
                list(set(args) - set(param_names))))
        if not any(re.match(d, name_) for d in docstring_ignores) and \
                'deprecation_wrapped' not in func.__code__.co_name:
            incorrect += [name_ + ' arg mismatch: ' + bad]
    else:
        for n1, n2 in zip(param_names, args):
            if n1 != n2:
                incorrect += [name_ + ' ' + n1 + ' != ' + n2]
        for param_name, desc in parameters:
            desc = '\n'.join(desc)
            full_name = name_ + '::' + param_name
            if full_name in docstring_length_ignores:
                # an ignore entry must actually be needed, else it is stale
                assert len(desc) > char_limit  # assert it actually needs to be
            elif len(desc) > char_limit:
                incorrect += [
                    '%s too long (%d > %d chars)' %
                    (full_name, len(desc), char_limit)
                ]
    return incorrect
예제 #8
0
def test_regularize_cov():
    """Test cov regularization.

    Regularizing the noise covariance (with one channel marked bad) must
    keep its dimension and data shape. Only a small fraction of entries
    (< 8%) may increase. Every split data channel type must also be an
    argument of ``regularize``.
    """
    raw = read_raw_fif(raw_fname)
    first_ch = raw.ch_names[0]
    raw.info['bads'].append(first_ch)  # test with bad channels
    noise_cov = read_cov(cov_fname)

    reg_kwargs = dict(mag=0.1, grad=0.1, eeg=0.1, proj=True, exclude='bads')
    reg_noise_cov = regularize(noise_cov, raw.info, **reg_kwargs)

    orig_data, reg_data = noise_cov['data'], reg_noise_cov['data']
    assert noise_cov['dim'] == reg_noise_cov['dim']
    assert orig_data.shape == reg_data.shape
    assert np.mean(orig_data < reg_data) < 0.08

    # make sure all args are represented
    missing = set(_DATA_CH_TYPES_SPLIT).difference(_get_args(regularize))
    assert missing == set()
예제 #9
0
def test_regularize_cov():
    """Test cov regularization (full-rank variant).

    Same checks as the basic test, but ``regularize`` is called with
    ``rank='full'``. Shapes must be preserved, few entries may grow, and
    every split data channel type must be a ``regularize`` argument.
    """
    raw = read_raw_fif(raw_fname)
    first_ch = raw.ch_names[0]
    raw.info['bads'].append(first_ch)  # test with bad channels
    noise_cov = read_cov(cov_fname)

    reg_kwargs = dict(mag=0.1, grad=0.1, eeg=0.1, proj=True,
                      exclude='bads', rank='full')
    reg_noise_cov = regularize(noise_cov, raw.info, **reg_kwargs)

    orig_data, reg_data = noise_cov['data'], reg_noise_cov['data']
    assert noise_cov['dim'] == reg_noise_cov['dim']
    assert orig_data.shape == reg_data.shape
    assert np.mean(orig_data < reg_data) < 0.08

    # make sure all args are represented
    missing = set(_DATA_CH_TYPES_SPLIT).difference(_get_args(regularize))
    assert missing == set()
def check_parameters_match(func, doc=None, cls=None):
    """Check docstring, return list of incorrect results.

    Compares the argument names of ``func`` (minus a leading ``self``)
    against the names in the ``Parameters`` section of its numpydoc
    docstring. It also checks that each parameter description stays within
    ``char_limit`` characters.

    Parameters
    ----------
    func : callable
        The function or method to check.
    doc : object | None
        Pre-parsed docstring supporting ``doc['Parameters']`` (e.g. a numpydoc
        ``FunctionDoc``). If None, ``func`` is parsed here.
    cls : type | None
        Passed through to ``get_name`` to build the qualified name.

    Returns
    -------
    incorrect : list of str
        One human-readable message per problem found; empty if none.

    Raises
    ------
    RuntimeError
        If any warning is emitted while parsing the docstring.
    """
    import warnings
    from numpydoc import docscrape
    incorrect = []
    name_ = get_name(func, cls=cls)
    if not name_.startswith('mne.') or name_.startswith('mne.externals'):
        return incorrect
    if inspect.isdatadescriptor(func):
        return incorrect
    args = _get_args(func)
    # drop self
    if len(args) > 0 and args[0] == 'self':
        args = args[1:]

    if doc is None:
        # ``pytest.warns(None)`` was deprecated in pytest 7 and removed in
        # pytest 8. Record every warning explicitly instead; 'always' matches
        # what the pytest recorder did, so repeated warnings are not hidden.
        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter('always')
            try:
                doc = docscrape.FunctionDoc(func)
            except Exception as exp:
                incorrect += [name_ + ' parsing error: ' + str(exp)]
                return incorrect
        if len(w):
            raise RuntimeError('Error for %s:\n%s' % (name_, w[0]))
    # check set
    parameters = doc['Parameters']
    # clean up some docscrape output: keep (name, description-lines),
    # strip ": type" suffixes/backticks, and drop *args / **kwargs entries
    parameters = [[p[0].split(':')[0].strip('` '), p[2]]
                  for p in parameters]
    parameters = [p for p in parameters if '*' not in p[0]]
    param_names = [p[0] for p in parameters]
    if len(param_names) != len(args):
        bad = str(sorted(list(set(param_names) - set(args)) +
                         list(set(args) - set(param_names))))
        if not any(re.match(d, name_) for d in docstring_ignores) and \
                'deprecation_wrapped' not in func.__code__.co_name:
            incorrect += [name_ + ' arg mismatch: ' + bad]
    else:
        for n1, n2 in zip(param_names, args):
            if n1 != n2:
                incorrect += [name_ + ' ' + n1 + ' != ' + n2]
        for param_name, desc in parameters:
            desc = '\n'.join(desc)
            full_name = name_ + '::' + param_name
            if full_name in docstring_length_ignores:
                # an ignore entry must actually be needed, else it is stale
                assert len(desc) > char_limit  # assert it actually needs to be
            elif len(desc) > char_limit:
                incorrect += ['%s too long (%d > %d chars)'
                              % (full_name, len(desc), char_limit)]
    return incorrect
예제 #11
0
def matplotlib_config():
    """Configure matplotlib for viz tests.

    Switches to the non-interactive Agg backend, disables interactive mode
    and resets ``figure.dpi`` to 100, so locally customized rcParams cannot
    slow the tests down.
    """
    import matplotlib
    backend_kwargs = {}
    if 'warn' in _get_args(matplotlib.use):  # only older signatures have it
        backend_kwargs['warn'] = False
    # "force" should not really be necessary but should not hurt
    matplotlib.use('agg', force=True, **backend_kwargs)  # no GUI windows
    import matplotlib.pyplot as plt
    assert plt.get_backend() == 'agg'
    # Settings below only affect speed, not functionality.
    plt.ioff()
    plt.rcParams['figure.dpi'] = 100
예제 #12
0
def verbose(function):
    """Verbose decorator to allow functions to override log-level.

    The logging level for a call is taken from a ``verbose`` argument,
    passed either by keyword or positionally. A keyword ``verbose`` is
    removed from the kwargs before ``function`` is called. If the first
    argument of ``function`` is named ``self`` and no non-None level was
    given, ``self.verbose`` (when present) is used instead. When a level is
    found, the call runs inside ``use_log_level(level)``.

    Parameters
    ----------
    function : callable
        Function to be decorated by setting the verbosity level.

    Returns
    -------
    dec : callable
        The decorated function, generated by ``FunctionMaker.create`` from
        ``function``'s signature.
    """
    arg_names = _get_args(function)
    takes_self = bool(arg_names) and arg_names[0] == 'self'

    def wrapper(*args, **kwargs):
        fallback = getattr(args[0], 'verbose', None) if takes_self else None
        if 'verbose' in kwargs:
            level = kwargs.pop('verbose')
        else:
            level = None
            try:
                level = args[arg_names.index('verbose')]
            except (IndexError, ValueError):
                # no ``verbose`` parameter, or it was not passed positionally
                pass

        # This ensures that object.method(verbose=None) will use object.verbose
        if level is None:
            level = fallback
        if level is None:
            return function(*args, **kwargs)
        # use_log_level restores the previous level even if we raise
        with use_log_level(level):
            return function(*args, **kwargs)

    return FunctionMaker.create(
        function, 'return decfunc(%(signature)s)',
        dict(decfunc=wrapper), __wrapped__=function,
        __qualname__=function.__qualname__)
예제 #13
0
파일: _utils.py 프로젝트: mdclarke/mnefun
def _get_epo_kwargs():
    """Return version-appropriate keyword arguments for saving epochs.

    Always contains ``verbose=False``. ``overwrite=True`` is added only
    when the installed ``Epochs.save`` has an ``overwrite`` argument
    (presumably these kwargs are passed to ``Epochs.save``; verify callers).
    """
    from mne.fixes import _get_args
    save_args = _get_args(Epochs.save)
    kwargs = {'verbose': False}
    if 'overwrite' in save_args:
        kwargs['overwrite'] = True
    return kwargs
예제 #14
0
def test_sourcemorph_consistency():
    """Test SourceMorph class consistency.

    The constructor arguments of ``SourceMorph`` (excluding ``self``) must
    match ``mne.morph._SOURCE_MORPH_ATTRIBUTES`` exactly, in order.
    """
    init_args = _get_args(SourceMorph.__init__)
    expected = mne.morph._SOURCE_MORPH_ATTRIBUTES
    assert init_args[1:] == expected
예제 #15
0
def test_sourcemorph_consistency():
    """Test SourceMorph class consistency.

    ``SourceMorph.__init__``'s arguments after ``self`` must equal
    ``mne.morph._SOURCE_MORPH_ATTRIBUTES`` element for element.
    """
    _, *ctor_args = _get_args(SourceMorph.__init__)
    assert ctor_args == list(mne.morph._SOURCE_MORPH_ATTRIBUTES) and \
        len(ctor_args) == len(mne.morph._SOURCE_MORPH_ATTRIBUTES) and \
        _get_args(SourceMorph.__init__)[1:] == \
        mne.morph._SOURCE_MORPH_ATTRIBUTES
예제 #16
0
def test_search_light():
    """Test SlidingEstimator.

    Covers construction and validation, fit/predict/score shapes and dtypes,
    and custom scoring (including multi-class ROC AUC when supported). It
    also covers n_jobs consistency, use inside pipelines, n-dimensional
    feature spaces and bagging classifiers.
    """
    import warnings
    from sklearn.linear_model import Ridge, LogisticRegression
    from sklearn.pipeline import make_pipeline
    from sklearn.metrics import roc_auc_score, make_scorer
    # ``pytest.warns(None)`` was removed in pytest 8; record (and discard)
    # any warnings emitted by the import explicitly instead.
    with warnings.catch_warnings(record=True):  # NumPy module import
        warnings.simplefilter('always')
        from sklearn.ensemble import BaggingClassifier
    from sklearn.base import is_classifier

    logreg = LogisticRegression(solver='liblinear',
                                multi_class='ovr',
                                random_state=0)

    X, y = make_data()
    n_epochs, _, n_time = X.shape
    # init
    pytest.raises(ValueError, SlidingEstimator, 'foo')
    sl = SlidingEstimator(Ridge())
    assert (not is_classifier(sl))
    sl = SlidingEstimator(LogisticRegression(solver='liblinear'))
    assert (is_classifier(sl))
    # fit
    assert_equal(sl.__repr__()[:18], '<SlidingEstimator(')
    sl.fit(X, y)
    assert_equal(sl.__repr__()[-28:], ', fitted with 10 estimators>')
    pytest.raises(ValueError, sl.fit, X[1:], y)
    pytest.raises(ValueError, sl.fit, X[:, :, 0], y)
    sl.fit(X, y, sample_weight=np.ones_like(y))

    # transforms
    pytest.raises(ValueError, sl.predict, X[:, :, :2])
    y_pred = sl.predict(X)
    assert (y_pred.dtype == int)
    assert_array_equal(y_pred.shape, [n_epochs, n_time])
    y_proba = sl.predict_proba(X)
    assert (y_proba.dtype == float)
    assert_array_equal(y_proba.shape, [n_epochs, n_time, 2])

    # score
    score = sl.score(X, y)
    assert_array_equal(score.shape, [n_time])
    assert (np.sum(np.abs(score)) != 0)
    assert (score.dtype == float)

    sl = SlidingEstimator(logreg)
    assert_equal(sl.scoring, None)

    # Scoring method
    for scoring in ['foo', 999]:
        sl = SlidingEstimator(logreg, scoring=scoring)
        sl.fit(X, y)
        pytest.raises((ValueError, TypeError), sl.score, X, y)

    # Check sklearn's roc_auc fix: scikit-learn/scikit-learn#6874
    # -- 3 class problem
    sl = SlidingEstimator(logreg, scoring='roc_auc')
    y = np.arange(len(X)) % 3
    sl.fit(X, y)
    with pytest.raises(ValueError, match='for two-class'):
        sl.score(X, y)
    # But check that valid ones should work with new enough sklearn
    if 'multi_class' in _get_args(roc_auc_score):
        scoring = make_scorer(roc_auc_score,
                              needs_proba=True,
                              multi_class='ovo')
        sl = SlidingEstimator(logreg, scoring=scoring)
        sl.fit(X, y)
        sl.score(X, y)  # smoke test

    # -- 2 class problem not in [0, 1]
    y = np.arange(len(X)) % 2 + 1
    sl.fit(X, y)
    score = sl.score(X, y)
    assert_array_equal(score, [
        roc_auc_score(y - 1, _y_pred - 1)
        for _y_pred in sl.decision_function(X).T
    ])
    y = np.arange(len(X)) % 2

    # Cannot pass a metric as a scoring parameter
    sl1 = SlidingEstimator(logreg, scoring=roc_auc_score)
    sl1.fit(X, y)
    pytest.raises(ValueError, sl1.score, X, y)

    # Now use string as scoring
    sl1 = SlidingEstimator(logreg, scoring='roc_auc')
    sl1.fit(X, y)
    rng = np.random.RandomState(0)
    X = rng.randn(*X.shape)  # randomize X to avoid AUCs in [0, 1]
    score_sl = sl1.score(X, y)
    assert_array_equal(score_sl.shape, [n_time])
    assert (score_sl.dtype == float)

    # Check that scoring was applied adequately
    scoring = make_scorer(roc_auc_score, needs_threshold=True)
    score_manual = [
        scoring(est, x, y)
        for est, x in zip(sl1.estimators_, X.transpose(2, 0, 1))
    ]
    assert_array_equal(score_manual, score_sl)

    # n_jobs
    sl = SlidingEstimator(logreg, n_jobs=1, scoring='roc_auc')
    score_1job = sl.fit(X, y).score(X, y)
    sl.n_jobs = 2
    score_njobs = sl.fit(X, y).score(X, y)
    assert_array_equal(score_1job, score_njobs)
    sl.predict(X)

    # n_jobs > n_estimators
    sl.fit(X[..., [0]], y)
    sl.predict(X[..., [0]])

    # pipeline

    class _LogRegTransformer(LogisticRegression):
        # XXX needs transformer in pipeline to get first proba only
        def __init__(self):
            super(_LogRegTransformer, self).__init__()
            self.multi_class = 'ovr'
            self.random_state = 0
            self.solver = 'liblinear'

        def transform(self, X):
            return super(_LogRegTransformer, self).predict_proba(X)[..., 1]

    pipe = make_pipeline(SlidingEstimator(_LogRegTransformer()), logreg)
    pipe.fit(X, y)
    pipe.predict(X)

    # n-dimensional feature space
    # (use the seeded ``rng`` rather than the global RNG so any failure
    # is reproducible)
    X = rng.rand(10, 3, 4, 2)
    y = np.arange(10) % 2
    y_preds = list()
    for n_jobs in [1, 2]:
        pipe = SlidingEstimator(make_pipeline(Vectorizer(), logreg),
                                n_jobs=n_jobs)
        y_preds.append(pipe.fit(X, y).predict(X))
        features_shape = pipe.estimators_[0].steps[0][1].features_shape_
        assert_array_equal(features_shape, [3, 4])
    assert_array_equal(y_preds[0], y_preds[1])

    # Bagging classifiers
    X = rng.rand(10, 3, 4)
    for n_jobs in (1, 2):
        pipe = SlidingEstimator(BaggingClassifier(None, 2), n_jobs=n_jobs)
        pipe.fit(X, y)
        pipe.score(X, y)
        assert (isinstance(pipe.estimators_[0], BaggingClassifier))
예제 #17
0
def read_ica(fname):
    """Restore ICA solution from fif file.

    Parameters
    ----------
    fname : str
        Absolute path to fif file containing ICA matrices.
        The file name should end with -ica.fif or -ica.fif.gz.

    Returns
    -------
    ica : instance of ICA
        The ICA estimator.

    Raises
    ------
    ValueError
        If the file contains no ICA block, or if the ICA block lacks one of
        the entries needed to rebuild the estimator.
    """
    # TODO: do I actually need this?

    from mne.utils import logger, check_fname
    from mne.io.open import fiff_open
    from mne.io.meas_info import read_meas_info
    from mne.io.tree import dir_tree_find
    from mne.io.constants import FIFF
    from mne.io.tag import read_tag
    from mne.preprocessing.ica import _deserialize
    from mne import Covariance
    from scipy import linalg
    from mne.fixes import _get_args

    check_fname(fname, 'ICA', ('-ica.fif', '-ica.fif.gz'))

    logger.info('Reading %s ...' % fname)

    # FIFF tag kinds inside the ICA block, and the key under which each
    # tag's payload is collected.
    tag_keys = (
        (FIFF.FIFF_MNE_ICA_INTERFACE_PARAMS, 'ica_init'),
        (FIFF.FIFF_MNE_ROW_NAMES, 'ch_names'),
        (FIFF.FIFF_MNE_ICA_WHITENER, 'pre_whitener'),
        (FIFF.FIFF_MNE_ICA_PCA_COMPONENTS, 'pca_components'),
        (FIFF.FIFF_MNE_ICA_PCA_EXPLAINED_VAR, 'pca_explained_variance'),
        (FIFF.FIFF_MNE_ICA_PCA_MEAN, 'pca_mean'),
        (FIFF.FIFF_MNE_ICA_MATRIX, 'unmixing_matrix'),
        (FIFF.FIFF_MNE_ICA_BADS, 'exclude'),
        (FIFF.FIFF_MNE_ICA_MISC_PARAMS, 'ica_misc'),
    )
    payloads = dict()

    fid, tree, _ = fiff_open(fname)
    # Close the file on every exit path, including read errors.
    try:
        try:
            # we used to store bads that weren't part of the info...
            info, _ = read_meas_info(fid, tree, clean_bads=True)
        except ValueError:
            logger.info('Could not find the measurement info. \n'
                        'Functionality requiring the info won\'t be'
                        ' available.')
            info = None

        ica_data = dir_tree_find(tree, FIFF.FIFFB_MNE_ICA)
        if len(ica_data) == 0:
            # Constant 123 Used before v 0.11
            ica_data = dir_tree_find(tree, 123)
            if len(ica_data) == 0:
                raise ValueError('Could not find ICA data')

        for d in ica_data[0]['directory']:
            for kind, key in tag_keys:
                if d.kind == kind:
                    # a later duplicate tag overrides an earlier one
                    payloads[key] = read_tag(fid, d.pos).data
                    break
    finally:
        fid.close()

    # Report missing mandatory entries clearly instead of failing later
    # with an UnboundLocalError.
    required = ('ica_init', 'ch_names', 'pre_whitener', 'pca_components',
                'pca_explained_variance', 'pca_mean', 'unmixing_matrix')
    missing = [key for key in required if key not in payloads]
    if missing:
        raise ValueError('ICA data in %s is missing required entries: %s'
                         % (fname, ', '.join(missing)))

    ica_init = _deserialize(payloads['ica_init'])
    # The misc-params tag is only probed for optional keys below, so treat
    # its absence as "no extra params".
    if 'ica_misc' in payloads:
        ica_misc = _deserialize(payloads['ica_misc'])
    else:
        ica_misc = dict()
    exclude = payloads.get('exclude')
    unmixing_matrix = payloads['unmixing_matrix']

    current_fit = ica_init.pop('current_fit')
    if ica_init['noise_cov'] == Covariance.__name__:
        logger.info('Reading whitener drawn from noise covariance ...')

    logger.info('Now restoring ICA solution ...')

    # make sure dtypes are np.float64 to satisfy fast_dot
    def f(x):
        return x.astype(np.float64)

    # Only forward stored parameters the current ICA constructor accepts.
    init_args = set(_get_args(ICA.__init__))
    ica = ICA(**{k: v for k, v in ica_init.items() if k in init_args})
    ica.current_fit = current_fit
    ica.ch_names = payloads['ch_names'].split(':')
    ica._pre_whitener = f(payloads['pre_whitener'])
    ica.pca_mean_ = f(payloads['pca_mean'])
    ica.pca_components_ = f(payloads['pca_components'])
    ica.n_components_ = unmixing_matrix.shape[0]
    ica._update_ica_names()
    ica.pca_explained_variance_ = f(payloads['pca_explained_variance'])
    ica.unmixing_matrix_ = f(unmixing_matrix)
    ica.mixing_matrix_ = linalg.pinv(ica.unmixing_matrix_)
    ica.exclude = [] if exclude is None else list(exclude)
    ica.info = info
    if 'n_samples_' in ica_misc:
        ica.n_samples_ = ica_misc['n_samples_']
    if 'labels_' in ica_misc:
        labels_ = ica_misc['labels_']
        if labels_ is not None:
            ica.labels_ = labels_
    if 'method' in ica_misc:
        ica.method = ica_misc['method']

    logger.info('Ready.')

    return ica
예제 #18
0
def plot_3d_montage(info,
                    view_map,
                    *,
                    src_det_names='auto',
                    ch_names='numbered',
                    subject='fsaverage',
                    trans='fsaverage',
                    surface='pial',
                    subjects_dir=None,
                    verbose=None):
    """
    Plot a 3D sensor montage.

    Parameters
    ----------
    info : instance of Info
        Measurement info.
    view_map : dict
        Dict of view (key) to channel-pair-numbers (value) to use when
        plotting. Note that, because these get plotted as 1-based channel
        *numbers*, the values should be 1-based rather than 0-based.
        The keys are of the form:

        ``'{side}-{view}'``
            For views like ``'left-lat'`` or ``'right-frontal'`` where the side
            matters. ``side`` must be ``'left'`` or ``'right'``.
        ``'{view}'``
            For views like ``'caudal'`` that are along the midline.

        See :meth:`mne.viz.Brain.show_view` for ``view`` options, and the
        Examples section below for usage examples.
    src_det_names : None | dict | str
        Source and detector names to use. "auto" (default) will see if the
        channel locations correspond to standard 10-20 locations and will
        use those if they do (otherwise will act like None). None will use
        S1, S2, ..., D1, D2, ..., etc. Can also be an explicit dict mapping,
        for example::

            src_det_names=dict(S1='Fz', D1='FCz', ...)
    ch_names : str | dict | None
        If ``'numbered'`` (default), use ``['1', '2', ...]`` for the channel
        names, or ``None`` to use ``['S1_D2', 'S2_D1', ...]``. Can also be a
        dict to provide a mapping from the ``'S1_D2'``-style names (keys) to
        other names, e.g., ``defaultdict(lambda: '')`` will prevent showing
        the names altogether.

        .. versionadded:: 0.3
    subject : str
        The subject.
    trans : str | Transform
        The subject's head<->MRI transform.
    surface : str
        The FreeSurfer surface name (e.g., 'pial', 'white').
    subjects_dir : str
        The subjects directory.
    %(verbose)s

    Returns
    -------
    figure : matplotlib.figure.Figure
        The matplotlib figure holding the rendered montage image.

    Examples
    --------
    For a Hitachi system with two sets of 12 source-detector arrangements,
    one on each side of the head, showing 1-12 on the left and 13-24 on the
    right can be accomplished using the following ``view_map``::

        >>> view_map = {
        ...     'left-lat': np.arange(1, 13),
        ...     'right-lat': np.arange(13, 25),
        ... }

    NIRx typically involves more complicated arrangements. See
    :ref:`the 3D tutorial <tut-fnirs-vis-brain-plot-3d-montage>` for
    an advanced example that incorporates the ``'caudal'`` view as well.
    """  # noqa: E501
    import matplotlib.pyplot as plt
    from scipy.spatial.distance import cdist
    _validate_type(info, Info, 'info')
    # report the real parameter name (was 'views') in the error message
    _validate_type(view_map, dict, 'view_map')
    _validate_type(src_det_names, (None, dict, str), 'src_det_names')
    _validate_type(ch_names, (dict, str, None), 'ch_names')
    # Keep every other fNIRS channel. Presumably each source-detector pair
    # appears twice in a row (e.g. one channel per wavelength/chromophore),
    # so this leaves one channel per pair -- TODO confirm this ordering.
    info = pick_info(info, pick_types(info, fnirs=True, exclude=())[::2])
    if isinstance(ch_names, str):
        _check_option('ch_names', ch_names, ('numbered', ), extra='when str')
        ch_names = {
            name.split()[0]: str(ni)
            for ni, name in enumerate(info['ch_names'], 1)
        }
    info['bads'] = []
    if isinstance(src_det_names, str):
        _check_option('src_det_names',
                      src_det_names, ('auto', ),
                      extra='when str')
        # Decide if we can map to 10-20 locations
        names, pos = zip(
            *transform_to_head(make_standard_montage(
                'standard_1020')).get_positions()['ch_pos'].items())
        pos = np.array(pos, float)
        locs = dict()
        bad = False
        for ch in info['chs']:
            s_name, d_name = ch['ch_name'].split()[0].split('_')
            # loc[3:6] / loc[6:9] hold the source / detector positions
            for sd_name, loc in [(s_name, ch['loc'][3:6]),
                                 (d_name, ch['loc'][6:9])]:
                if sd_name in locs:
                    continue
                # see if it's close enough (within 1 mm)
                idx = np.where(cdist(loc[np.newaxis], pos)[0] < 1e-3)[0]
                if len(idx) < 1:
                    bad = True
                    break
                # Some are duplicated (e.g., T7+T3) but we can rely on the
                # first one being the canonical one
                locs[sd_name] = names[idx[0]]
            if bad:
                break
        if bad:
            src_det_names = None
            logger.info('Could not automatically map source/detector names to '
                        '10-20 locations.')
        else:
            src_det_names = locs
            logger.info('Source-detector names automatically mapped to 10-20 '
                        'locations')

    head_mri_t = _get_trans(trans, 'head', 'mri')[0]
    del trans
    views = list()
    for key, num in view_map.items():
        _validate_type(key, str, f'view_map key {repr(key)}')
        _validate_type(num, np.ndarray, f'view_map[{repr(key)}]')
        if '-' in key:
            side, v = key.split('-', maxsplit=1)
            # Reject unknown sides with a clear message instead of a bare
            # KeyError from the lookup below.
            _check_option(f'view_map key {repr(key)} side', side,
                          ('left', 'right'))
            hemi = dict(left='lh', right='rh')[side]
            views.append((hemi, v, num))
        else:
            # midline views are shown using the left hemisphere
            views.append(('lh', key, num))
    del view_map
    size = (400 * len(views), 400)
    brain = Brain(subject,
                  'both',
                  surface,
                  views=['lat'] * len(views),
                  size=size,
                  background='w',
                  units='m',
                  view_layout='horizontal',
                  subjects_dir=subjects_dir)
    with _safe_brain_close(brain):
        brain.add_head(dense=False, alpha=0.1)
        brain.add_sensors(info,
                          trans=head_mri_t,
                          fnirs=['channels', 'pairs', 'sources', 'detectors'])
        # Defer rendering of each label when the plotter supports it.
        add_text_kwargs = dict()
        if 'render' in _get_args(brain.plotter.add_text):
            add_text_kwargs['render'] = False
        for col, view in enumerate(views):
            plotted = set()
            brain.show_view(view[1],
                            hemi=view[0],
                            focalpoint=(0, -0.02, 0.02),
                            distance=0.4,
                            row=0,
                            col=col)
            brain.plotter.subplot(0, col)
            vp = brain.plotter.renderer
            for ci in view[2]:  # figure out what we need to plot
                this_ch = info['chs'][ci - 1]  # view_map values are 1-based
                ch_name = this_ch['ch_name'].split()[0]
                s_name, d_name = ch_name.split('_')
                needed = [
                    (ch_names, 'ch_names', ch_name, this_ch['loc'][:3], 12,
                     'Centered'),
                    (src_det_names, 'src_det_names', s_name,
                     this_ch['loc'][3:6], 8, 'Bottom'),
                    (src_det_names, 'src_det_names', d_name,
                     this_ch['loc'][6:9], 8, 'Bottom'),
                ]
                for lookup, lname, name, ch_pos, font_size, va in needed:
                    if name in plotted:
                        continue
                    plotted.add(name)
                    orig_name = name
                    if lookup is not None:
                        name = lookup[name]
                    _validate_type(name, str, f'{lname}[{repr(orig_name)}]')
                    # Project the 3D (MRI-space) position to 2D display
                    # coordinates relative to this viewport's origin.
                    ch_pos = apply_trans(head_mri_t, ch_pos)
                    vp.SetWorldPoint(np.r_[ch_pos, 1.])
                    vp.WorldToDisplay()
                    ch_pos = (np.array(vp.GetDisplayPoint()[:2]) -
                              np.array(vp.GetOrigin()))
                    actor = brain.plotter.add_text(name,
                                                   ch_pos,
                                                   font_size=font_size,
                                                   color=(0., 0., 0.),
                                                   **add_text_kwargs)
                    prop = actor.GetTextProperty()
                    getattr(prop, f'SetVerticalJustificationTo{va}')()
                    prop.SetJustificationToCentered()
                    actor.SetTextProperty(prop)
                    prop.SetBold(True)
        img = brain.screenshot()
    return plt.figimage(img, resize=True).figure