예제 #1
0
def test_read_epochs_bad_events():
    """Average epochs built from a single event on a file boundary.

    The event is placed on ``raw.first_samp`` (averaged directly, then again
    after an explicit ``drop_bad_epochs``) and on ``raw.last_samp``. Warnings
    emitted while averaging are captured rather than shown.
    """
    def _boundary_epochs(sample):
        # Epochs object whose only event sits on ``sample``.
        single_event = np.array([[sample, 0, event_id]])
        return Epochs(raw, single_event, event_id, tmin, tmax,
                      picks=picks, baseline=(None, 0))

    # Event on the first sample, averaged without dropping anything first
    epo = _boundary_epochs(raw.first_samp)
    with warnings.catch_warnings(record=True):
        evoked = epo.average()

    # Event on the first sample, bad epochs dropped before averaging
    epo = _boundary_epochs(raw.first_samp)
    epo.drop_bad_epochs()
    with warnings.catch_warnings(record=True):
        evoked = epo.average()

    # Event on the last sample
    epo = _boundary_epochs(raw.last_samp)
    with warnings.catch_warnings(record=True):
        evoked = epo.average()
        assert evoked
    warnings.resetwarnings()
예제 #2
0
def test_read_epochs_bad_events():
    """Average epochs whose only event lies on the first or last sample.

    Three cases are run: first sample, first sample after an explicit
    ``drop_bad_epochs``, and last sample (whose average must be truthy).
    """
    shared = dict(picks=picks, baseline=(None, 0))

    # Event at the beginning
    at_start = np.array([[raw.first_samp, 0, event_id]])
    epo = Epochs(raw, at_start, event_id, tmin, tmax, **shared)
    evoked = epo.average()

    # Event at the beginning, dropping bad epochs explicitly first
    at_start = np.array([[raw.first_samp, 0, event_id]])
    epo = Epochs(raw, at_start, event_id, tmin, tmax, **shared)
    epo.drop_bad_epochs()
    evoked = epo.average()

    # Event at the end
    at_end = np.array([[raw.last_samp, 0, event_id]])
    epo = Epochs(raw, at_end, event_id, tmin, tmax, **shared)
    evoked = epo.average()
    assert evoked
예제 #3
0
def test_reject_epochs():
    """Check epoch rejection driven by the ``reject`` / ``flat`` thresholds.

    Three scenarios are covered: plain rejection (expected count and drop
    log), rejection when the offending channel is marked bad, and rejection
    restricted to a ``reject_tmin``/``reject_tmax`` window.
    """
    baseline = (None, 0)

    epochs = Epochs(raw, events, event_id, tmin, tmax, baseline=baseline,
                    reject=reject, flat=flat)
    n_events = len(epochs.events)
    n_clean_epochs = len(epochs.get_data())
    # The expected numbers should match the output of
    # mne_process_raw --raw test_raw.fif --projoff \
    #   --saveavetag -ave --ave test.ave --filteroff
    assert_true(n_events > n_clean_epochs)
    assert_true(n_clean_epochs == 3)
    expected_drop_log = [[]] * 3 + [['MEG 2443']] * 4
    assert_true(epochs.drop_log == expected_drop_log)

    # Ensure epochs are not dropped based on a bad channel
    raw_with_bad = raw.copy()
    raw_with_bad.info['bads'] = ['MEG 2443']
    reject_crazy = dict(grad=1000e-15, mag=4e-15, eeg=80e-9, eog=150e-9)
    epochs = Epochs(raw_with_bad, events, event_id, tmin, tmax,
                    baseline=baseline, reject=reject_crazy, flat=flat)
    epochs.drop_bad_epochs()
    assert_true(all('MEG 2442' in entry for entry in epochs.drop_log))
    assert_true(all('MEG 2443' not in entry for entry in epochs.drop_log))

    # Rejection limited to the [0, 0.1] window of each epoch
    epochs = Epochs(raw, events, event_id, tmin, tmax, picks=picks,
                    baseline=baseline, reject=reject, flat=flat,
                    reject_tmin=0., reject_tmax=.1)
    n_clean_epochs = len(epochs.get_data())
    assert_true(n_clean_epochs == 7)
    reject_window = epochs.times[epochs._reject_time]
    assert_true(reject_window[0] >= 0.)
    assert_true(reject_window[-1] <= 0.1)
def test_source_psd_epochs():
    """Test multi-taper source PSD computation in label from epochs.

    The PSD of a single epoch is computed in list mode and in generator mode
    (both must agree), then compared with a PSD computed directly from the
    inverse solution of the same epoch.
    """
    raw = fiff.Raw(fname_data)
    inverse_operator = read_inverse_operator(fname_inv)
    label = read_label(fname_label)

    event_id = 1
    tmin, tmax = -0.2, 0.5
    lambda2 = 1. / 9.
    method = 'dSPM'
    bandwidth = 8.
    fmin, fmax = 0, 100

    picks = fiff.pick_types(raw.info, meg=True, eeg=False, stim=True,
                            ecg=True, eog=True, include=['STI 014'],
                            exclude='bads')
    reject = dict(grad=4000e-13, mag=4e-12, eog=150e-6)

    events = find_events(raw)
    epochs = Epochs(raw, events, event_id, tmin, tmax, picks=picks,
                    baseline=(None, 0), reject=reject)

    # keep only the first good epoch
    epochs.drop_bad_epochs()
    one_epochs = epochs[:1]

    # keyword arguments shared by the inverse / PSD calls below
    inverse_kwargs = dict(lambda2=lambda2, method=method, pick_normal=True,
                          label=label)
    band_kwargs = dict(bandwidth=bandwidth, fmin=fmin, fmax=fmax)

    # list mode: take the first (only) estimate
    psd_list = compute_source_psd_epochs(one_epochs, inverse_operator,
                                         **inverse_kwargs, **band_kwargs)
    stc_psd = psd_list[0]

    # generator mode: keep the last estimate that is yielded
    psd_iter = compute_source_psd_epochs(one_epochs, inverse_operator,
                                         return_generator=True,
                                         **inverse_kwargs, **band_kwargs)
    for stc_psd_gen in psd_iter:
        pass

    assert_array_almost_equal(stc_psd.data, stc_psd_gen.data)

    # compare with direct computation
    stc = apply_inverse_epochs(one_epochs, inverse_operator,
                               **inverse_kwargs)[0]

    psd, freqs = multitaper_psd(stc.data, sfreq=epochs.info['sfreq'],
                                **band_kwargs)

    assert_array_almost_equal(psd, stc_psd.data)
    assert_array_almost_equal(freqs, stc_psd.times)
예제 #5
0
def test_indexing_slicing():
    """Test of indexing and slicing operations.

    A reference data array is taken from non-preloaded epochs. An equivalent
    Epochs object (preloaded and not preloaded) is then sliced and indexed
    with an int, a random integer array and lists; each result must equal
    the matching rows of the reference array.
    """
    raw, events, picks = _get_data()
    epochs = Epochs(raw, events[:20], event_id, tmin, tmax, picks=picks,
                    baseline=(None, 0), preload=False,
                    reject=reject, flat=flat)

    data_normal = epochs.get_data()

    n_good_events = data_normal.shape[0]

    # indices for slicing
    start_index = 1
    end_index = n_good_events - 1

    # the slice must be non-empty for the checks below to mean anything
    assert end_index - start_index > 0

    for preload in [True, False]:
        epochs2 = Epochs(raw, events[:20], event_id, tmin, tmax,
                         picks=picks, baseline=(None, 0), preload=preload,
                         reject=reject, flat=flat)

        if not preload:
            epochs2.drop_bad_epochs()

        # using slicing
        epochs2_sliced = epochs2[start_index:end_index]

        data_epochs2_sliced = epochs2_sliced.get_data()
        assert_array_equal(data_epochs2_sliced,
                           data_normal[start_index:end_index])

        # using indexing: position ``pos`` of the slice is row ``idx`` of
        # the reference data
        for pos, idx in enumerate(range(start_index, end_index)):
            data = epochs2_sliced[pos].get_data()
            assert_array_equal(data[0], data_normal[idx])

        # using indexing with an int; the same index is used on both sides
        # instead of relying on the loop variable left over from above
        # (which only matched because start_index == 1)
        int_index = data_epochs2_sliced.shape[0]
        data = epochs2[int_index].get_data()
        assert_array_equal(data, data_normal[[int_index]])

        # using indexing with an array
        idx = np.random.randint(0, data_epochs2_sliced.shape[0], 10)
        data = epochs2[idx].get_data()
        assert_array_equal(data, data_normal[idx])

        # using indexing with a list of indices
        idx = [0]
        data = epochs2[idx].get_data()
        assert_array_equal(data, data_normal[idx])
        idx = [0, 1]
        data = epochs2[idx].get_data()
        assert_array_equal(data, data_normal[idx])
예제 #6
0
def test_indexing_slicing():
    """Test of indexing and slicing operations.

    A reference data array is taken from non-preloaded epochs. An equivalent
    Epochs object (preloaded and not preloaded) is then sliced and indexed
    with an int, a random integer array and lists; each result must equal
    the matching rows of the reference array.
    """
    epochs = Epochs(raw, events[:20], event_id, tmin, tmax, picks=picks,
                    baseline=(None, 0), preload=False,
                    reject=reject, flat=flat)

    data_normal = epochs.get_data()

    n_good_events = data_normal.shape[0]

    # indices for slicing
    start_index = 1
    end_index = n_good_events - 1

    # the slice must be non-empty for the checks below to mean anything
    assert end_index - start_index > 0

    for preload in [True, False]:
        epochs2 = Epochs(raw, events[:20], event_id, tmin, tmax,
                         picks=picks, baseline=(None, 0), preload=preload,
                         reject=reject, flat=flat)

        if not preload:
            epochs2.drop_bad_epochs()

        # using slicing
        epochs2_sliced = epochs2[start_index:end_index]

        data_epochs2_sliced = epochs2_sliced.get_data()
        assert_array_equal(data_epochs2_sliced,
                           data_normal[start_index:end_index])

        # using indexing: position ``pos`` of the slice is row ``idx`` of
        # the reference data
        for pos, idx in enumerate(range(start_index, end_index)):
            data = epochs2_sliced[pos].get_data()
            assert_array_equal(data[0], data_normal[idx])

        # using indexing with an int; the same index is used on both sides
        # instead of relying on the loop variable left over from above
        # (which only matched because start_index == 1)
        int_index = data_epochs2_sliced.shape[0]
        data = epochs2[int_index].get_data()
        assert_array_equal(data, data_normal[[int_index]])

        # using indexing with an array
        idx = np.random.randint(0, data_epochs2_sliced.shape[0], 10)
        data = epochs2[idx].get_data()
        assert_array_equal(data, data_normal[idx])

        # using indexing with a list of indices
        idx = [0]
        data = epochs2[idx].get_data()
        assert_array_equal(data, data_normal[idx])
        idx = [0, 1]
        data = epochs2[idx].get_data()
        assert_array_equal(data, data_normal[idx])
예제 #7
0
def test_read_epochs_bad_events():
    """Average epochs whose single event sits on a boundary of the file.

    Runs the first-sample case twice (the second time after an explicit
    ``drop_bad_epochs``) and the last-sample case once.
    """
    def _one_event(sample):
        # Fresh one-row events array for an event on ``sample``.
        return np.array([[sample, 0, event_id]])

    common = dict(picks=picks, baseline=(None, 0))

    # Event at the beginning
    epochs = Epochs(raw, _one_event(raw.first_samp), event_id, tmin, tmax,
                    **common)
    evoked = epochs.average()

    # Event at the beginning, with bad epochs dropped before averaging
    epochs = Epochs(raw, _one_event(raw.first_samp), event_id, tmin, tmax,
                    **common)
    epochs.drop_bad_epochs()
    evoked = epochs.average()

    # Event at the end
    epochs = Epochs(raw, _one_event(raw.last_samp), event_id, tmin, tmax,
                    **common)
    evoked = epochs.average()
예제 #8
0
def test_epoch_eq():
    """Test epoch count equalization and condition combining.

    Covers three things: equalizing counts across separate Epochs objects
    (``equalize_epoch_counts``), equalizing counts across conditions of one
    Epochs object (``equalize_event_counts``), and merging conditions
    (``combine_event_ids``).
    """
    def _n_kept(drop_log):
        # Number of empty drop-log entries (epochs with no drop reason).
        return len([entry for entry in drop_log if not entry])

    conditions = ["a", "b", "c", "d"]

    # equalizing epochs objects
    epochs_1 = Epochs(raw, events, event_id, tmin, tmax, picks=picks)
    epochs_2 = Epochs(raw, events, event_id_2, tmin, tmax, picks=picks)
    epochs_1.drop_bad_epochs()  # make sure drops are logged
    assert_true(_n_kept(epochs_1.drop_log) == len(epochs_1.events))
    drop_log1 = epochs_1.drop_log = [[] for _ in range(len(epochs_1.events))]
    drop_log2 = [[] if entry == ["EQUALIZED_COUNT"] else entry
                 for entry in epochs_1.drop_log]
    assert_true(drop_log1 == drop_log2)
    assert_true(_n_kept(epochs_1.drop_log) == len(epochs_1.events))
    assert_true(epochs_1.events.shape[0] != epochs_2.events.shape[0])
    equalize_epoch_counts([epochs_1, epochs_2], method="mintime")
    assert_true(epochs_1.events.shape[0] == epochs_2.events.shape[0])
    epochs_3 = Epochs(raw, events, event_id, tmin, tmax, picks=picks)
    epochs_4 = Epochs(raw, events, event_id_2, tmin, tmax, picks=picks)
    equalize_epoch_counts([epochs_3, epochs_4], method="truncate")
    assert_true(epochs_1.events.shape[0] == epochs_3.events.shape[0])
    assert_true(epochs_3.events.shape[0] == epochs_4.events.shape[0])

    # equalizing conditions
    epochs = Epochs(raw, events, {"a": 1, "b": 2, "c": 3, "d": 4},
                    tmin, tmax, picks=picks, reject=reject)
    epochs.drop_bad_epochs()  # make sure drops are logged
    assert_true(_n_kept(epochs.drop_log) == len(epochs.events))
    drop_log1 = deepcopy(epochs.drop_log)
    old_shapes = [epochs[key].events.shape[0] for key in conditions]
    epochs.equalize_event_counts(["a", "b"], copy=False)
    # undo the eq logging
    drop_log2 = [[] if entry == ["EQUALIZED_COUNT"] else entry
                 for entry in epochs.drop_log]
    assert_true(drop_log1 == drop_log2)

    assert_true(_n_kept(epochs.drop_log) == len(epochs.events))
    new_shapes = [epochs[key].events.shape[0] for key in conditions]
    assert_true(new_shapes[0] == new_shapes[1])
    # conditions that were not part of the equalization must be untouched
    # (previously these two checks compared new_shapes with itself)
    assert_true(new_shapes[2] == old_shapes[2])
    assert_true(new_shapes[3] == old_shapes[3])
    # now with two conditions collapsed
    old_shapes = new_shapes
    epochs.equalize_event_counts([["a", "b"], "c"], copy=False)
    new_shapes = [epochs[key].events.shape[0] for key in conditions]
    assert_true(new_shapes[0] + new_shapes[1] == new_shapes[2])
    assert_true(new_shapes[3] == old_shapes[3])
    assert_raises(KeyError, epochs.equalize_event_counts, [1, "a"])

    # now let's combine conditions
    old_shapes = new_shapes
    epochs = epochs.equalize_event_counts([["a", "b"], ["c", "d"]])[0]
    new_shapes = [epochs[key].events.shape[0] for key in conditions]
    assert_true(old_shapes[0] + old_shapes[1] == new_shapes[0] + new_shapes[1])
    assert_true(new_shapes[0] + new_shapes[1] == new_shapes[2] + new_shapes[3])
    assert_raises(ValueError, combine_event_ids, epochs, ["a", "b"], {"ab": 1})

    combine_event_ids(epochs, ["a", "b"], {"ab": 12}, copy=False)
    # after combining in place, both original keys must be gone
    caught = 0
    for key in ["a", "b"]:
        try:
            epochs[key]
        except KeyError:
            caught += 1
    # assert_raises(Exception, caught == 2) tried to *call* a bool, which
    # always raises TypeError and therefore could never fail
    assert_true(caught == 2)
    assert_true(not np.any(epochs.events[:, 2] == 1))
    assert_true(not np.any(epochs.events[:, 2] == 2))
    epochs = combine_event_ids(epochs, ["c", "d"], {"cd": 34})
    assert_true(np.all(np.logical_or(epochs.events[:, 2] == 12,
                                     epochs.events[:, 2] == 34)))
    assert_true(epochs["ab"].events.shape[0] == old_shapes[0] + old_shapes[1])
    assert_true(epochs["ab"].events.shape[0] == epochs["cd"].events.shape[0])
예제 #9
0
def test_source_psd_epochs():
    """Test multi-taper source PSD computation in label from epochs.

    Using a prepared inverse operator, the PSD of one epoch is computed in
    list and generator mode (both must agree) and compared with a PSD
    computed directly from the inverse solution. Finally a tiny bandwidth is
    used to check that the expected warnings are emitted.
    """
    raw = io.Raw(fname_data)
    inverse_operator = read_inverse_operator(fname_inv)
    label = read_label(fname_label)

    event_id = 1
    tmin, tmax = -0.2, 0.5
    lambda2 = 1. / 9.
    method = 'dSPM'
    bandwidth = 8.
    fmin, fmax = 0, 100

    picks = pick_types(raw.info, meg=True, eeg=False, stim=True,
                       ecg=True, eog=True, include=['STI 014'],
                       exclude='bads')
    reject = dict(grad=4000e-13, mag=4e-12, eog=150e-6)

    events = find_events(raw, stim_channel='STI 014')
    epochs = Epochs(raw, events, event_id, tmin, tmax, picks=picks,
                    baseline=(None, 0), reject=reject)

    # keep only the first good epoch
    epochs.drop_bad_epochs()
    one_epochs = epochs[:1]

    inv = prepare_inverse_operator(inverse_operator, nave=1,
                                   lambda2=1. / 9., method="dSPM")

    # keyword arguments shared by every inverse / PSD call below
    inverse_kwargs = dict(lambda2=lambda2, method=method, pick_ori="normal",
                          label=label, prepared=True)
    freq_kwargs = dict(fmin=fmin, fmax=fmax)

    # list mode: take the first (only) estimate
    psd_list = compute_source_psd_epochs(one_epochs, inv,
                                         bandwidth=bandwidth,
                                         **inverse_kwargs, **freq_kwargs)
    stc_psd = psd_list[0]

    # generator mode: keep the last estimate that is yielded
    psd_iter = compute_source_psd_epochs(one_epochs, inv,
                                         bandwidth=bandwidth,
                                         return_generator=True,
                                         **inverse_kwargs, **freq_kwargs)
    for stc_psd_gen in psd_iter:
        pass

    assert_array_almost_equal(stc_psd.data, stc_psd_gen.data)

    # compare with direct computation
    stc = apply_inverse_epochs(one_epochs, inv, **inverse_kwargs)[0]

    psd, freqs = multitaper_psd(stc.data, sfreq=epochs.info['sfreq'],
                                bandwidth=bandwidth, **freq_kwargs)

    assert_array_almost_equal(psd, stc_psd.data)
    assert_array_almost_equal(freqs, stc_psd.times)

    # Check corner cases caused by tiny bandwidth
    with warnings.catch_warnings(record=True) as w:
        warnings.simplefilter('always')
        for low_bias in (True, False):
            compute_source_psd_epochs(one_epochs, inv,
                                      bandwidth=0.01, low_bias=low_bias,
                                      return_generator=False,
                                      **inverse_kwargs, **freq_kwargs)
    messages = [str(caught.message) for caught in w]
    assert_true(len(w) >= 2)
    assert_true(any('not properly use' in msg for msg in messages))
    assert_true(any('Bandwidth too small' in msg for msg in messages))
예제 #10
0
def do_preprocessing_combined(p, subjects, run_indices):
    """Do preprocessing on all raw files together.

    Calculates projection vectors to use to clean data. For each subject
    this optionally writes an automatic bad-channel file, computes and
    writes empty-room, ECG and EOG projectors, and collects the drop log of
    fixed-length epochs for optional plotting.

    Parameters
    ----------
    p : instance of Parameters
        Analysis parameters.
    subjects : list of str
        Subject names to analyze (e.g., ['Eric_SoP_001', ...]).
    run_indices : array-like | None
        Run indices to include.

    Raises
    ------
    NameError
        If one of the expected raw or empty-room files does not exist.
    ImportError
        If ``p.auto_bad_reject == 'auto'`` and ``autoreject`` is missing.
    RuntimeError
        If automatic bad-channel detection has no criterion, finds too many
        bad channels, if empty-room projectors are requested without
        empty-room files, or if fewer than 20 good ECG epochs are found.
    """
    drop_logs = list()
    for si, subj in enumerate(subjects):
        proj_nums = _proj_nums(p, subj)
        ecg_channel = _handle_dict(p.ecg_channel, subj)
        flat = _handle_dict(p.flat, subj)
        if p.disp_files:
            print('  Preprocessing subject %g/%g (%s).' %
                  (si + 1, len(subjects), subj))
        pca_dir = _get_pca_dir(p, subj)
        bad_file = get_bad_fname(p, subj, check_exists=False)

        # Create SSP projection vectors after marking bad channels
        raw_names = get_raw_fnames(p, subj, 'sss', False, False,
                                   run_indices[si])
        empty_names = get_raw_fnames(p, subj, 'sss', 'only')
        for r in raw_names + empty_names:
            if not op.isfile(r):
                raise NameError('File not found (' + r + ')')

        fir_kwargs, old_kwargs = _get_fir_kwargs(p.fir_design)
        if isinstance(p.auto_bad, float):
            print('    Creating post SSS bad channel file:\n'
                  '        %s' % bad_file)
            # do autobad
            raw = _raw_LRFCP(raw_names,
                             p.proj_sfreq,
                             None,
                             None,
                             p.n_jobs_fir,
                             p.n_jobs_resample,
                             list(),
                             None,
                             p.disp_files,
                             method='fir',
                             filter_length=p.filter_length,
                             apply_proj=False,
                             force_bads=False,
                             l_trans=p.hp_trans,
                             h_trans=p.lp_trans,
                             phase=p.phase,
                             fir_window=p.fir_window,
                             pick=True,
                             skip_by_annotation='edge',
                             **fir_kwargs)
            events = fixed_len_events(p, raw)
            # fall back to the epoch limits when no reject window is given
            rtmin = p.reject_tmin \
                if p.reject_tmin is not None else p.tmin
            rtmax = p.reject_tmax \
                if p.reject_tmax is not None else p.tmax
            # do not mark eog channels bad
            meg, eeg = 'meg' in raw, 'eeg' in raw
            picks = pick_types(raw.info,
                               meg=meg,
                               eeg=eeg,
                               eog=False,
                               exclude=[])
            assert p.auto_bad_flat is None or isinstance(p.auto_bad_flat, dict)
            assert p.auto_bad_reject is None or \
                isinstance(p.auto_bad_reject, dict) or \
                p.auto_bad_reject == 'auto'
            if p.auto_bad_reject == 'auto':
                print('    Auto bad channel selection active. '
                      'Will try using Autoreject module to '
                      'compute rejection criterion.')
                try:
                    from autoreject import get_rejection_threshold
                except ImportError:
                    raise ImportError('     Autoreject module not installed.\n'
                                      '     Noisy channel detection parameter '
                                      '     not defined. To use autobad '
                                      '     channel selection either define '
                                      '     rejection criteria or install '
                                      '     Autoreject module.\n')
                print('    Computing thresholds.\n', end='')
                temp_epochs = Epochs(raw,
                                     events,
                                     event_id=None,
                                     tmin=rtmin,
                                     tmax=rtmax,
                                     baseline=_get_baseline(p),
                                     proj=True,
                                     reject=None,
                                     flat=None,
                                     preload=True,
                                     decim=1)
                kwargs = dict()
                # only pass ``verbose`` if this autoreject version accepts it
                if 'verbose' in get_args(get_rejection_threshold):
                    kwargs['verbose'] = False
                reject = get_rejection_threshold(temp_epochs, **kwargs)
            elif p.auto_bad_reject is None and p.auto_bad_flat is None:
                raise RuntimeError('Auto bad channel detection active. Noisy '
                                   'and flat channel detection '
                                   'parameters not defined. '
                                   'At least one criterion must be defined.')
            else:
                reject = p.auto_bad_reject
            # The picks above leave out EOG channels, so drop any EOG
            # threshold. Build a new dict rather than popping in place so
            # the caller's p.auto_bad_reject is not modified, and allow
            # reject to be None (flat-only detection), which previously
            # crashed on ``reject.keys()``.
            if reject is not None:
                reject = {kk: vv for kk, vv in reject.items() if kk != 'eog'}
            epochs = Epochs(raw,
                            events,
                            None,
                            tmin=rtmin,
                            tmax=rtmax,
                            baseline=_get_baseline(p),
                            picks=picks,
                            reject=reject,
                            flat=p.auto_bad_flat,
                            proj=True,
                            preload=True,
                            decim=1,
                            reject_tmin=rtmin,
                            reject_tmax=rtmax)
            # channel scores from drop log
            drops = Counter([ch for d in epochs.drop_log for ch in d])
            # get rid of non-channel reasons in drop log
            scores = {
                kk: vv
                for kk, vv in drops.items() if kk in epochs.ch_names
            }
            ch_names = np.array(list(scores.keys()))
            # channel scores expressed as percentile and rank ordered
            counts = (100 * np.array([scores[ch] for ch in ch_names], float) /
                      len(epochs.drop_log))
            order = np.argsort(counts)[::-1]
            # boolean array masking out channels with <= % epochs dropped
            mask = counts[order] > p.auto_bad
            badchs = ch_names[order[mask]]
            if len(badchs) > 0:
                # Make sure we didn't get too many bad MEG or EEG channels
                for m, e, thresh in zip(
                    [True, False], [False, True],
                    [p.auto_bad_meg_thresh, p.auto_bad_eeg_thresh]):
                    picks = pick_types(epochs.info, meg=m, eeg=e, exclude=[])
                    if len(picks) > 0:
                        ch_names = [epochs.ch_names[pp] for pp in picks]
                        n_bad_type = sum(ch in ch_names for ch in badchs)
                        if n_bad_type > thresh:
                            stype = 'meg' if m else 'eeg'
                            raise RuntimeError('Too many bad %s channels '
                                               'found: %s > %s' %
                                               (stype, n_bad_type, thresh))

                # NOTE(review): ``counts`` above is already in percent and
                # is compared to p.auto_bad directly, yet the message scales
                # p.auto_bad by 100 -- confirm which unit is intended.
                print('    The following channels resulted in greater than '
                      '{:.0f}% trials dropped:\n'.format(p.auto_bad * 100))
                print(badchs)
                with open(bad_file, 'w') as f:
                    f.write('\n'.join(badchs))
        if not op.isfile(bad_file):
            print('    Clearing bad channels (no file %s)' %
                  op.sep.join(bad_file.split(op.sep)[-3:]))
            bad_file = None

        ecg_t_lims = _handle_dict(p.ecg_t_lims, subj)
        ecg_f_lims = p.ecg_f_lims

        ecg_eve = op.join(pca_dir, 'preproc_ecg-eve.fif')
        ecg_epo = op.join(pca_dir, 'preproc_ecg-epo.fif')
        ecg_proj = op.join(pca_dir, 'preproc_ecg-proj.fif')
        all_proj = op.join(pca_dir, 'preproc_all-proj.fif')

        # restrict projector computation to the requested runs (all runs
        # when nothing is specified)
        get_projs_from = _handle_dict(p.get_projs_from, subj)
        if get_projs_from is None:
            get_projs_from = np.arange(len(raw_names))
        pre_list = [
            r for ri, r in enumerate(raw_names) if ri in get_projs_from
        ]

        projs = list()
        raw_orig = _raw_LRFCP(raw_names=pre_list,
                              sfreq=p.proj_sfreq,
                              l_freq=None,
                              h_freq=None,
                              n_jobs=p.n_jobs_fir,
                              n_jobs_resample=p.n_jobs_resample,
                              projs=projs,
                              bad_file=bad_file,
                              disp_files=p.disp_files,
                              method='fir',
                              filter_length=p.filter_length,
                              force_bads=False,
                              l_trans=p.hp_trans,
                              h_trans=p.lp_trans,
                              phase=p.phase,
                              fir_window=p.fir_window,
                              pick=True,
                              skip_by_annotation='edge',
                              **fir_kwargs)

        # Apply any user-supplied extra projectors
        if p.proj_extra is not None:
            if p.disp_files:
                print('    Adding extra projectors from "%s".' % p.proj_extra)
            projs.extend(read_proj(op.join(pca_dir, p.proj_extra)))

        proj_kwargs, p_sl = _get_proj_kwargs(p)
        #
        # Calculate and apply ERM projectors
        #
        if not p.cont_as_esss:
            if any(proj_nums[2]):
                assert proj_nums[2][2] == 0  # no EEG projectors for ERM
                if len(empty_names) == 0:
                    raise RuntimeError('Cannot compute empty-room projectors '
                                       'from continuous raw data')
                if p.disp_files:
                    print('    Computing continuous projectors using ERM.')
                # Use empty room(s), but processed the same way
                projs.extend(_compute_erm_proj(p, subj, projs, 'sss',
                                               bad_file))
            else:
                cont_proj = op.join(pca_dir, 'preproc_cont-proj.fif')
                _safe_remove(cont_proj)

        #
        # Calculate and apply the ECG projectors
        #
        if any(proj_nums[0]):
            if p.disp_files:
                print('    Computing ECG projectors...', end='')
            raw = raw_orig.copy()

            raw.filter(ecg_f_lims[0],
                       ecg_f_lims[1],
                       n_jobs=p.n_jobs_fir,
                       method='fir',
                       filter_length=p.filter_length,
                       l_trans_bandwidth=0.5,
                       h_trans_bandwidth=0.5,
                       phase='zero-double',
                       fir_window='hann',
                       skip_by_annotation='edge',
                       **old_kwargs)
            raw.add_proj(projs)
            raw.apply_proj()
            find_kwargs = dict()
            # only pass ``reject_by_annotation`` if find_ecg_events takes it
            if 'reject_by_annotation' in get_args(find_ecg_events):
                find_kwargs['reject_by_annotation'] = True
            elif len(raw.annotations) > 0:
                print('    WARNING: ECG event detection will not make use of '
                      'annotations, please update MNE-Python')
            # We've already filtered the data channels above, but this
            # filters the ECG channel
            ecg_events = find_ecg_events(raw,
                                         999,
                                         ecg_channel,
                                         0.,
                                         ecg_f_lims[0],
                                         ecg_f_lims[1],
                                         qrs_threshold='auto',
                                         return_ecg=False,
                                         **find_kwargs)[0]
            use_reject, use_flat = _restrict_reject_flat(
                _handle_dict(p.ssp_ecg_reject, subj), flat, raw)
            ecg_epochs = Epochs(raw,
                                ecg_events,
                                999,
                                ecg_t_lims[0],
                                ecg_t_lims[1],
                                baseline=None,
                                reject=use_reject,
                                flat=use_flat,
                                preload=True)
            print('  obtained %d epochs from %d events.' %
                  (len(ecg_epochs), len(ecg_events)))
            if len(ecg_epochs) >= 20:
                write_events(ecg_eve, ecg_epochs.events)
                ecg_epochs.save(ecg_epo, **_get_epo_kwargs())
                desc_prefix = 'ECG-%s-%s' % tuple(ecg_t_lims)
                pr = compute_proj_wrap(ecg_epochs,
                                       p.proj_ave,
                                       n_grad=proj_nums[0][0],
                                       n_mag=proj_nums[0][1],
                                       n_eeg=proj_nums[0][2],
                                       desc_prefix=desc_prefix,
                                       **proj_kwargs)
                assert len(pr) == np.sum(proj_nums[0][::p_sl])
                write_proj(ecg_proj, pr)
                projs.extend(pr)
            else:
                # show what went wrong before aborting
                plot_drop_log(ecg_epochs.drop_log)
                raw.plot(events=ecg_epochs.events)
                raise RuntimeError('Only %d/%d good ECG epochs found' %
                                   (len(ecg_epochs), len(ecg_events)))
            del raw, ecg_epochs, ecg_events
        else:
            _safe_remove([ecg_proj, ecg_eve, ecg_epo])

        #
        # Next calculate and apply the EOG projectors
        #
        for idx, kind in ((1, 'EOG'), (3, 'HEOG'), (4, 'VEOG')):
            _compute_add_eog(p, subj, raw_orig, projs, proj_nums[idx], kind,
                             pca_dir, flat, proj_kwargs, old_kwargs, p_sl)
        del proj_nums

        # save the projectors
        write_proj(all_proj, projs)

        #
        # Look at raw_orig for trial DQs now, it will be quick
        #
        raw_orig.filter(p.hp_cut,
                        p.lp_cut,
                        n_jobs=p.n_jobs_fir,
                        method='fir',
                        filter_length=p.filter_length,
                        l_trans_bandwidth=p.hp_trans,
                        phase=p.phase,
                        h_trans_bandwidth=p.lp_trans,
                        fir_window=p.fir_window,
                        skip_by_annotation='edge',
                        **fir_kwargs)
        raw_orig.add_proj(projs)
        raw_orig.apply_proj()
        # now let's epoch with 1-sec windows to look for DQs
        events = fixed_len_events(p, raw_orig)
        reject = _handle_dict(p.reject, subj)
        use_reject, use_flat = _restrict_reject_flat(reject, flat, raw_orig)
        epochs = Epochs(raw_orig,
                        events,
                        None,
                        p.tmin,
                        p.tmax,
                        preload=False,
                        baseline=_get_baseline(p),
                        reject=use_reject,
                        flat=use_flat,
                        proj=True)
        try:
            epochs.drop_bad()
        except AttributeError:  # old way
            epochs.drop_bad_epochs()
        drop_logs.append(epochs.drop_log)
        del raw_orig
        del epochs
    if p.plot_drop_logs:
        for subj, drop_log in zip(subjects, drop_logs):
            plot_drop_log(drop_log, p.drop_thresh, subject=subj)
예제 #11
0
# Script fragment: epoch the data for one condition, grid-search the
# consensus / interpolation settings, and apply the best combination.
# ``raw``, ``data_path``, ``n_interpolates``, ``consensus_percs``,
# ``n_folds`` and ``compute_threshes`` are expected to be defined earlier.
event_fname = data_path + ('/MEG/sample/sample_audvis_filt-0-40_raw-'
                           'eve.fif')
event_id = {'Auditory/Left': 1}
tmin, tmax = -0.2, 0.5
events = mne.read_events(event_fname)

# clear the list of bad channels, then pick MEG channels only
# (eeg, stim and eog are all switched off)
raw.info['bads'] = []
picks = mne.pick_types(raw.info, meg=True, eeg=False, stim=False, eog=False,
                       include=[], exclude=[])

n_events = events.shape[0]
# no amplitude-based rejection here (reject=None); data are preloaded
epochs = Epochs(raw, events, event_id, tmin, tmax,
                picks=picks, baseline=(None, 0), reject=None,
                verbose=False, detrend=0, preload=True)
epochs.drop_bad_epochs()

prefix = 'MEG data'
err_cons = grid_search(epochs, n_interpolates, consensus_percs,
                       prefix=prefix, n_folds=n_folds)

# try the best consensus perc to get clean evoked now:
# average the error over the last axis and take the (consensus_perc,
# n_interpolate) index pair with the smallest mean error
best_idx, best_jdx = np.unravel_index(err_cons.mean(axis=-1).argmin(),
                                      err_cons.shape[:2])
consensus_perc = consensus_percs[best_idx]
n_interpolate = n_interpolates[best_jdx]
auto_reject = ConsensusAutoReject(compute_threshes, consensus_perc,
                                  n_interpolate=n_interpolate)
epochs_transformed = auto_reject.fit_transform(epochs)

# NOTE(review): this averages the original ``epochs``, not
# ``epochs_transformed`` -- presumably kept as the "before" reference;
# confirm against the rest of the script.
evoked = epochs.average()
예제 #12
0
def test_epoch_eq():
    """Test epoch count equalization and condition combining.

    Three things are exercised in turn:

    1. ``equalize_epoch_counts`` on separate Epochs objects, with the
       'mintime' and 'truncate' methods.
    2. ``Epochs.equalize_event_counts`` on conditions inside one Epochs
       object, including collapsed (grouped) conditions and an invalid key.
    3. ``combine_event_ids``, both in place (``copy=False``) and returning
       a new object.
    """
    # equalizing epochs objects
    epochs_1 = Epochs(raw, events, event_id, tmin, tmax, picks=picks)
    epochs_2 = Epochs(raw, events, event_id_2, tmin, tmax, picks=picks)
    epochs_1.drop_bad_epochs()  # make sure drops are logged
    # every empty drop-log entry corresponds to a retained event
    assert_true(
        len([log for log in epochs_1.drop_log if not log]) ==
        len(epochs_1.events))
    drop_log1 = epochs_1.drop_log = [[] for _ in range(len(epochs_1.events))]
    drop_log2 = [[] if log == ['EQUALIZED_COUNT'] else log
                 for log in epochs_1.drop_log]
    assert_true(drop_log1 == drop_log2)
    assert_true(
        len([log for log in epochs_1.drop_log if not log]) ==
        len(epochs_1.events))
    assert_true(epochs_1.events.shape[0] != epochs_2.events.shape[0])
    equalize_epoch_counts([epochs_1, epochs_2], method='mintime')
    assert_true(epochs_1.events.shape[0] == epochs_2.events.shape[0])
    epochs_3 = Epochs(raw, events, event_id, tmin, tmax, picks=picks)
    epochs_4 = Epochs(raw, events, event_id_2, tmin, tmax, picks=picks)
    equalize_epoch_counts([epochs_3, epochs_4], method='truncate')
    assert_true(epochs_1.events.shape[0] == epochs_3.events.shape[0])
    assert_true(epochs_3.events.shape[0] == epochs_4.events.shape[0])

    # equalizing conditions
    keys = ['a', 'b', 'c', 'd']
    epochs = Epochs(raw, events, {'a': 1, 'b': 2, 'c': 3, 'd': 4},
                    tmin, tmax, picks=picks, reject=reject)
    epochs.drop_bad_epochs()  # make sure drops are logged
    assert_true(
        len([log for log in epochs.drop_log if not log]) ==
        len(epochs.events))
    drop_log1 = deepcopy(epochs.drop_log)
    old_shapes = [epochs[key].events.shape[0] for key in keys]
    epochs.equalize_event_counts(['a', 'b'], copy=False)
    # undo the eq logging
    drop_log2 = [[] if log == ['EQUALIZED_COUNT'] else log
                 for log in epochs.drop_log]
    assert_true(drop_log1 == drop_log2)

    assert_true(
        len([log for log in epochs.drop_log if not log]) ==
        len(epochs.events))
    new_shapes = [epochs[key].events.shape[0] for key in keys]
    assert_true(new_shapes[0] == new_shapes[1])
    # 'c' and 'd' were not part of the equalization, so their counts must
    # be unchanged (previously these compared new_shapes to itself, which
    # could never fail).
    assert_true(new_shapes[2] == old_shapes[2])
    assert_true(new_shapes[3] == old_shapes[3])
    # now with two conditions collapsed
    old_shapes = new_shapes
    epochs.equalize_event_counts([['a', 'b'], 'c'], copy=False)
    new_shapes = [epochs[key].events.shape[0] for key in keys]
    assert_true(new_shapes[0] + new_shapes[1] == new_shapes[2])
    assert_true(new_shapes[3] == old_shapes[3])
    assert_raises(KeyError, epochs.equalize_event_counts, [1, 'a'])

    # now let's combine conditions
    old_shapes = new_shapes
    epochs = epochs.equalize_event_counts([['a', 'b'], ['c', 'd']])[0]
    new_shapes = [epochs[key].events.shape[0] for key in keys]
    assert_true(old_shapes[0] + old_shapes[1] == new_shapes[0] + new_shapes[1])
    assert_true(new_shapes[0] + new_shapes[1] == new_shapes[2] + new_shapes[3])
    # the new id value 1 collides with an existing event id
    assert_raises(ValueError, combine_event_ids, epochs, ['a', 'b'], {'ab': 1})

    combine_event_ids(epochs, ['a', 'b'], {'ab': 12}, copy=False)
    # After the in-place merge the old keys must no longer be selectable.
    # (The former ``assert_raises(Exception, caught == 2)`` tried to call a
    # bool, raised TypeError, and therefore passed unconditionally.)
    for key in ['a', 'b']:
        assert_raises(KeyError, epochs.__getitem__, key)
    assert_true(not np.any(epochs.events[:, 2] == 1))
    assert_true(not np.any(epochs.events[:, 2] == 2))
    epochs = combine_event_ids(epochs, ['c', 'd'], {'cd': 34})
    event_codes = epochs.events[:, 2]
    assert_true(np.all(np.logical_or(event_codes == 12, event_codes == 34)))
    assert_true(epochs['ab'].events.shape[0] == old_shapes[0] + old_shapes[1])
    assert_true(epochs['ab'].events.shape[0] == epochs['cd'].events.shape[0])
예제 #13
0
def test_source_psd_epochs():
    """Test multi-taper source PSD computation in label from epochs.

    A single epoch is run through ``compute_source_psd_epochs`` in both its
    list-returning and generator-returning forms; the two results are
    compared with each other and with a PSD computed directly from the
    inverse solution of the same epoch. Finally a very small bandwidth is
    used to check that the expected warnings are emitted.
    """
    raw = io.Raw(fname_data)
    inverse_operator = read_inverse_operator(fname_inv)
    label = read_label(fname_label)

    event_id, tmin, tmax = 1, -0.2, 0.5
    lambda2, method = 1. / 9., 'dSPM'
    bandwidth = 8.
    fmin, fmax = 0, 100

    picks = pick_types(raw.info, meg=True, eeg=False, stim=True, ecg=True,
                       eog=True, include=['STI 014'], exclude='bads')
    reject = dict(grad=4000e-13, mag=4e-12, eog=150e-6)

    events = find_events(raw, stim_channel='STI 014')
    epochs = Epochs(raw, events, event_id, tmin, tmax, picks=picks,
                    baseline=(None, 0), reject=reject)

    # drop the bad epochs first, then keep only the first remaining one
    epochs.drop_bad_epochs()
    one_epochs = epochs[:1]

    inv = prepare_inverse_operator(inverse_operator, nave=1,
                                   lambda2=1. / 9., method="dSPM")

    # keyword arguments shared by the calls below
    freq_kwargs = dict(fmin=fmin, fmax=fmax)
    inverse_kwargs = dict(lambda2=lambda2, method=method, pick_ori="normal",
                          label=label, prepared=True)
    psd_kwargs = dict(inverse_kwargs, **freq_kwargs)

    # list form: take the only element
    stc_psd = compute_source_psd_epochs(one_epochs, inv, bandwidth=bandwidth,
                                        **psd_kwargs)[0]

    # generator form: iterate and keep the last item produced
    stcs = compute_source_psd_epochs(one_epochs, inv, bandwidth=bandwidth,
                                     return_generator=True, **psd_kwargs)
    for generated in stcs:
        stc_psd_gen = generated

    assert_array_almost_equal(stc_psd.data, stc_psd_gen.data)

    # compare with direct computation
    stc = apply_inverse_epochs(one_epochs, inv, **inverse_kwargs)[0]
    psd, freqs = _psd_multitaper(stc.data, sfreq=epochs.info['sfreq'],
                                 bandwidth=bandwidth, **freq_kwargs)

    assert_array_almost_equal(psd, stc_psd.data)
    assert_array_almost_equal(freqs, stc_psd.times)

    # Check corner cases caused by tiny bandwidth
    with warnings.catch_warnings(record=True) as w:
        warnings.simplefilter('always')
        for low_bias in (True, False):
            compute_source_psd_epochs(one_epochs, inv, bandwidth=0.01,
                                      low_bias=low_bias,
                                      return_generator=False, **psd_kwargs)
    assert_true(len(w) >= 2)
    messages = [str(ww.message) for ww in w]
    assert_true(any('not properly use' in msg for msg in messages))
    assert_true(any('Bandwidth too small' in msg for msg in messages))
예제 #14
0
def test_source_psd_epochs():
    """Test multi-taper source PSD computation in label from epochs.

    A single epoch is run through ``compute_source_psd_epochs`` in both its
    list-returning and generator-returning forms; the two results are
    compared with each other and with a PSD computed directly from the
    inverse solution of the same epoch.
    """
    raw = fiff.Raw(fname_data)
    inverse_operator = read_inverse_operator(fname_inv)
    label = read_label(fname_label)

    event_id, tmin, tmax = 1, -0.2, 0.5
    lambda2, method = 1. / 9., 'dSPM'
    bandwidth = 8.
    fmin, fmax = 0, 100

    picks = fiff.pick_types(raw.info, meg=True, eeg=False, stim=True,
                            ecg=True, eog=True, include=['STI 014'],
                            exclude='bads')
    reject = dict(grad=4000e-13, mag=4e-12, eog=150e-6)

    events = find_events(raw, stim_channel='STI 014')
    epochs = Epochs(raw, events, event_id, tmin, tmax, picks=picks,
                    baseline=(None, 0), reject=reject)

    # drop the bad epochs first, then keep only the first remaining one
    epochs.drop_bad_epochs()
    one_epochs = epochs[:1]

    # keyword arguments shared by the calls below
    freq_kwargs = dict(fmin=fmin, fmax=fmax)
    inverse_kwargs = dict(lambda2=lambda2, method=method, pick_ori="normal",
                          label=label)
    psd_kwargs = dict(inverse_kwargs, bandwidth=bandwidth, **freq_kwargs)

    # list form: take the only element
    stc_psd = compute_source_psd_epochs(one_epochs, inverse_operator,
                                        **psd_kwargs)[0]

    # generator form: iterate and keep the last item produced
    stcs = compute_source_psd_epochs(one_epochs, inverse_operator,
                                     return_generator=True, **psd_kwargs)
    for generated in stcs:
        stc_psd_gen = generated

    assert_array_almost_equal(stc_psd.data, stc_psd_gen.data)

    # compare with direct computation
    stc = apply_inverse_epochs(one_epochs, inverse_operator,
                               **inverse_kwargs)[0]
    psd, freqs = multitaper_psd(stc.data, sfreq=epochs.info['sfreq'],
                                bandwidth=bandwidth, **freq_kwargs)

    assert_array_almost_equal(psd, stc_psd.data)
    assert_array_almost_equal(freqs, stc_psd.times)
예제 #15
0
                       include=[],
                       exclude=[])

n_events = events.shape[0]
# Preloaded epochs with no rejection thresholds (reject=None).
epochs = Epochs(raw,
                events,
                event_id,
                tmin,
                tmax,
                picks=picks,
                baseline=(None, 0),
                reject=None,
                verbose=False,
                detrend=0,
                preload=True)
epochs.drop_bad_epochs()

prefix = 'MEG data'
# n_interpolates, consensus_percs and n_folds are defined outside this
# fragment.
err_cons = grid_search(epochs,
                       n_interpolates,
                       consensus_percs,
                       prefix=prefix,
                       n_folds=n_folds)

# try the best consensus perc to get clean evoked now:
# average err_cons over its last axis, locate the minimum, and map it back
# to indices over the first two axes. The first index selects from
# consensus_percs and the second from n_interpolates.
best_idx, best_jdx = np.unravel_index(
    err_cons.mean(axis=-1).argmin(), err_cons.shape[:2])
consensus_perc = consensus_percs[best_idx]
n_interpolate = n_interpolates[best_jdx]
auto_reject = ConsensusAutoReject(compute_threshes,
                                  consensus_perc,