Beispiel #1
0
def test_epoch_eq():
    """Test epoch count equalization and condition combining.

    Covers ``equalize_epoch_counts`` across two Epochs objects (``'mintime'``
    and ``'truncate'`` methods), ``Epochs.equalize_event_counts`` within one
    object (plain and collapsed condition lists) and ``combine_event_ids``.
    """
    # equalizing epochs objects
    epochs_1 = Epochs(raw, events, event_id, tmin, tmax, picks=picks)
    epochs_2 = Epochs(raw, events, event_id_2, tmin, tmax, picks=picks)
    assert_true(epochs_1.events.shape[0] != epochs_2.events.shape[0])
    equalize_epoch_counts([epochs_1, epochs_2], method='mintime')
    assert_true(epochs_1.events.shape[0] == epochs_2.events.shape[0])
    epochs_3 = Epochs(raw, events, event_id, tmin, tmax, picks=picks)
    epochs_4 = Epochs(raw, events, event_id_2, tmin, tmax, picks=picks)
    equalize_epoch_counts([epochs_3, epochs_4], method='truncate')
    assert_true(epochs_1.events.shape[0] == epochs_3.events.shape[0])
    assert_true(epochs_3.events.shape[0] == epochs_4.events.shape[0])

    # equalizing conditions
    epochs = Epochs(raw, events, {'a': 1, 'b': 2, 'c': 3, 'd': 4},
                    tmin, tmax, picks=picks)
    old_shapes = [epochs[key].events.shape[0] for key in ['a', 'b', 'c', 'd']]
    epochs.equalize_event_counts(['a', 'b'], copy=False)
    new_shapes = [epochs[key].events.shape[0] for key in ['a', 'b', 'c', 'd']]
    assert_true(new_shapes[0] == new_shapes[1])
    # conditions that were not part of the equalization must keep their
    # counts (comparing a value with itself would always pass)
    assert_true(new_shapes[2] == old_shapes[2])
    assert_true(new_shapes[3] == old_shapes[3])
    # now with two conditions collapsed
    old_shapes = new_shapes
    epochs.equalize_event_counts([['a', 'b'], 'c'], copy=False)
    new_shapes = [epochs[key].events.shape[0] for key in ['a', 'b', 'c', 'd']]
    assert_true(new_shapes[0] + new_shapes[1] == new_shapes[2])
    assert_true(new_shapes[3] == old_shapes[3])
    assert_raises(KeyError, epochs.equalize_event_counts, [1, 'a'])

    # now let's combine conditions
    old_shapes = new_shapes
    epochs = epochs.equalize_event_counts([['a', 'b'], ['c', 'd']])[0]
    new_shapes = [epochs[key].events.shape[0] for key in ['a', 'b', 'c', 'd']]
    assert_true(old_shapes[0] + old_shapes[1] == new_shapes[0] + new_shapes[1])
    assert_true(new_shapes[0] + new_shapes[1] == new_shapes[2] + new_shapes[3])
    assert_raises(ValueError, combine_event_ids, epochs, ['a', 'b'],
                  {'ab': 1})

    combine_event_ids(epochs, ['a', 'b'], {'ab': 12}, copy=False)
    caught = 0
    for key in ['a', 'b']:
        try:
            epochs[key]
        except KeyError:
            caught += 1
    # both original keys must be gone after combining in place; a bare
    # ``assert_raises(caught == 2)`` performs no check at all
    assert_true(caught == 2)
    assert_true(not np.any(epochs.events[:, 2] == 1))
    assert_true(not np.any(epochs.events[:, 2] == 2))
    epochs = combine_event_ids(epochs, ['c', 'd'], {'cd': 34})
    assert_true(np.all(np.logical_or(epochs.events[:, 2] == 12,
                                     epochs.events[:, 2] == 34)))
    assert_true(epochs['ab'].events.shape[0] == old_shapes[0] + old_shapes[1])
    assert_true(epochs['ab'].events.shape[0] == epochs['cd'].events.shape[0])
Beispiel #2
0
def test_epoch_combine_ids():
    """Check ``combine_event_ids`` on Epochs against ``merge_events``.

    Collapsing conditions ``"a"`` (1) and ``"b"`` (2) into ``"ab"`` (12) must
    yield the same events array as merging codes 1 and 2 into 12 directly.
    """
    condition_codes = dict(a=1, b=2, c=3, d=4, e=5, f=32)
    for preload in (False,):
        epochs = Epochs(raw, events, condition_codes, tmin, tmax,
                        picks=picks, preload=preload)
        reference = merge_events(events, [1, 2], 12)
        collapsed = combine_event_ids(epochs, ["a", "b"], {"ab": 12})
        assert_array_equal(reference, collapsed.events)
Beispiel #3
0
def test_epoch_combine_ids():
    """Compare event-id combining on Epochs with merging on raw events.

    Merging ids ``'a'``/``'b'`` into ``'ab'`` (code 12) on an Epochs object
    must reproduce ``merge_events(events, [1, 2], 12)`` exactly.
    """
    id_map = {'a': 1, 'b': 2, 'c': 3, 'd': 4, 'e': 5, 'f': 32}
    preload_options = [False]
    for preload in preload_options:
        epochs = Epochs(raw, events, id_map, tmin, tmax,
                        picks=picks, preload=preload)
        merged_events = merge_events(events, [1, 2], 12)
        combined_epochs = combine_event_ids(epochs, ['a', 'b'], {'ab': 12})
        assert_array_equal(merged_events, combined_epochs.events)
def test_epoch_combine_ids():
    """Merging 'a' and 'b' into 'ab' must match the ``merge_events`` result.
    """
    for preload in (False,):
        epochs = Epochs(raw, events,
                        dict(a=1, b=2, c=3, d=4, e=5, f=32),
                        tmin, tmax,
                        picks=picks,
                        preload=preload)
        expected = merge_events(events, [1, 2], 12)
        actual = combine_event_ids(epochs, ['a', 'b'], {'ab': 12}).events
        assert_array_equal(expected, actual)
Beispiel #5
0
def eq_trials(epochs, analysis, nn, in_names_match, names, numbers):
    """Equalize trial counts.

    Only the ``'Oddball'`` analysis is handled here; for every other allowed
    analysis ``None`` is returned to signal that the default method should be
    used. For ``'Oddball'``:

    1. the ``'ba'`` and ``'wa'`` trials are equalized,
    2. ``['std']`` is collapsed into ``names[0]`` with code ``numbers[0]`` and
       ``['ba', 'wa']`` into ``names[1]`` with code ``numbers[1]``,
    3. the two collapsed conditions are equalized and the result returned.
    """
    # Someday we should pass these in...
    # Temporary codes are shifted past every existing event code (and past
    # ``nn``) so that the intermediate ids cannot clash with real ones.
    shift = max(epochs.events[:, 2].max(), nn.max()) + 1
    allowed = 'All Individual Oddball AM IDs'.split()
    assert analysis in allowed, analysis
    if analysis != 'Oddball':
        return None  # signal to use default method

    print(f'      Equalizing with sub-condition matching: {in_names_match}')
    assert in_names_match == ['std', 'ba', 'wa'], in_names_match
    assert names == ['standard', 'deviant'], names
    subset = epochs[in_names_match]
    subset.equalize_event_counts(['ba', 'wa'])
    groups = (['std'], ['ba', 'wa'])
    for position, group in enumerate(groups):
        label = names[position]
        target = numbers[position]
        temp_code = target + shift
        combine_event_ids(subset, group, {label: temp_code}, copy=False)
        # move the temporary code back down to the requested final code
        codes = subset.events[:, 2]
        subset.events[codes == temp_code, 2] -= shift
        subset.event_id[label] = target
    subset.equalize_event_counts(['standard', 'deviant'])
    return subset
def create_epochs_from_intervals(raw: Raw, intervals: List[tuple]) -> Epochs:
    """Build epochs from the data between pairs of annotation markers.

    For each ``(start_code, end_code)`` pair in ``intervals`` the raw data are
    cropped from the first event with ``start_code`` to the first event with
    ``end_code``. The cropped data are epoched with ``create_epochs``, and all
    resulting event ids are combined into a single id ``start_code``. The
    per-interval epochs are then concatenated.

    Parameters
    ----------
    raw : Raw
        Continuous data whose annotations mark the interval boundaries.
    intervals : list of tuple
        ``(start_code, end_code)`` pairs of event codes, as produced by
        ``events_from_annotations``. ``start_code`` is also passed to
        ``combine_event_ids`` as the new event id.

    Returns
    -------
    Epochs
        Concatenation of the epochs created for every interval.
    """
    events, _ = events_from_annotations(raw)
    sfreq = raw.info["sfreq"]
    # Event sample numbers include ``raw.first_samp``, whereas ``crop``
    # expects times relative to the first sample of the data, so remove the
    # offset before converting samples to seconds.
    first_samp = raw.first_samp

    epochs_list = []
    for interval in intervals:
        start_idx = np.where(events[:, 2] == interval[0])[0]
        end_idx = np.where(events[:, 2] == interval[1])[0]

        raw_cropped = raw.copy().crop(
            tmin=(events[start_idx[0], 0] - first_samp) / sfreq,
            tmax=(events[end_idx[0], 0] - first_samp) / sfreq,
        )

        epochs = create_epochs(raw_cropped)
        # collapse every event type inside the interval into one id; the
        # int form of ``new_event_id`` is used (NOTE(review): assumes
        # interval[0] is a plain ``int``, not a numpy integer — confirm)
        combine_event_ids(epochs,
                          list(epochs.event_id.keys()),
                          interval[0],
                          copy=False)

        epochs_list.append(epochs)

    return concatenate_epochs(epochs_list)
Beispiel #7
0
def save_epochs(p, subjects, in_names, in_numbers, analyses, out_names,
                out_numbers, must_match, decim, run_indices):
    """Generate epochs from raw data based on events

    Can only complete after preprocessing is complete.

    For every subject this reads and concatenates the ``'pca'`` raw files and
    epochs them with the reject/flat criteria from ``p``. When
    ``p.autoreject_thresholds`` is set, rejection thresholds are computed
    first with autoreject and written next to the epochs file. For each
    analysis, one evoked file is written containing the average and standard
    error of every output condition. Finally the epochs are saved in the
    formats listed in ``p.epochs_type``.

    Parameters
    ----------
    p : instance of Parameters
        Analysis parameters.
    subjects : list of str
        Subject names to analyze (e.g., ['Eric_SoP_001', ...]).
    in_names : list of str
        Names of input events.
    in_numbers : list of list of int
        Event numbers (in scored event files) associated with each name.
    analyses : list of str
        Lists of analyses of interest.
    out_names : list of list of str
        Event types to make out of old ones.
    out_numbers : list of list of int
        Event numbers to convert to (e.g., [[1, 1, 2, 3, 3], ...] would create
        three event types, where the first two and last two event types from
        the original list get collapsed over).
    must_match : list of int
        Indices from the original in_names that must match in event counts
        before collapsing. Should eventually be expanded to allow for
        ratio-based collapsing.
    decim : int | list of int
        Amount to decimate. Indexed per subject (``decim[si]``), so a
        sequence with one entry per subject must be passed here.
    run_indices : array-like | None
        Run indices to include. Indexed per subject (``run_indices[si]``).

    Raises
    ------
    RuntimeError / ValueError
        When the name/number specifications are inconsistent, or when a
        subject ends up with no valid epochs.
    """
    in_names = np.asanyarray(in_names)
    old_dict = dict()
    for n, e in zip(in_names, in_numbers):
        old_dict[n] = e

    # let's do some sanity checks
    if len(in_names) != len(in_numbers):
        raise RuntimeError('in_names (%d) must have same length as '
                           'in_numbers (%d)' %
                           (len(in_names), len(in_numbers)))
    if np.any(np.array(in_numbers) <= 0):
        raise ValueError('in_numbers must all be > 0')
    if len(out_names) != len(out_numbers):
        raise RuntimeError('out_names must have same length as out_numbers')
    for name, num in zip(out_names, out_numbers):
        num = np.array(num)
        if len(name) != len(np.unique(num[num > 0])):
            raise RuntimeError('each entry in out_names must have length '
                               'equal to the number of unique elements in the '
                               'corresponding entry in out_numbers:\n%s\n%s' %
                               (name, np.unique(num[num > 0])))
        if len(num) != len(in_names):
            raise RuntimeError('each entry in out_numbers must have the same '
                               'length as in_names')
        if (num == 0).any():
            raise ValueError('no element of out_numbers can be zero')

    drop_logs = list()
    sfreqs = set()
    for si, subj in enumerate(subjects):
        if p.disp_files:
            print('  Loading raw files for subject %s.' % subj)
        epochs_dir = op.join(p.work_dir, subj, p.epochs_dir)
        if not op.isdir(epochs_dir):
            os.mkdir(epochs_dir)
        evoked_dir = op.join(p.work_dir, subj, p.inverse_dir)
        if not op.isdir(evoked_dir):
            os.mkdir(evoked_dir)
        # read in raw files (each file is opened exactly once)
        raw_names = get_raw_fnames(p, subj, 'pca', False, False,
                                   run_indices[si])
        raw = [read_raw_fif(fname, preload=False) for fname in raw_names]
        _fix_raw_eog_cals(raw)  # EOG epoch scales might be bad!
        raw = concatenate_raws(raw)
        # read in events
        events = _read_events(p, subj, run_indices[si], raw)
        this_decim = _handle_decim(decim[si], raw.info['sfreq'])
        new_sfreq = raw.info['sfreq'] / this_decim
        if p.disp_files:
            print('    Epoching data (decim=%s -> sfreq=%0.1f Hz).' %
                  (this_decim, new_sfreq))
        # warn (once per new value) when subjects end up at different rates
        if new_sfreq not in sfreqs:
            if len(sfreqs) > 0:
                warnings.warn('resulting new sampling frequency %s not equal '
                              'to previous values %s' % (new_sfreq, sfreqs))
            sfreqs.add(new_sfreq)
        epochs_fnames, evoked_fnames = get_epochs_evokeds_fnames(
            p, subj, analyses)
        mat_file, fif_file = epochs_fnames
        if p.autoreject_thresholds:
            assert len(p.autoreject_types) > 0
            assert all(a in ('mag', 'grad', 'eeg', 'ecg', 'eog')
                       for a in p.autoreject_types)
            from autoreject import get_rejection_threshold
            print('    Computing autoreject thresholds', end='')
            rtmin = p.reject_tmin if p.reject_tmin is not None else p.tmin
            rtmax = p.reject_tmax if p.reject_tmax is not None else p.tmax
            # unrejected epochs over the rejection window feed autoreject
            temp_epochs = Epochs(raw,
                                 events,
                                 event_id=None,
                                 tmin=rtmin,
                                 tmax=rtmax,
                                 baseline=_get_baseline(p),
                                 proj=True,
                                 reject=None,
                                 flat=None,
                                 preload=True,
                                 decim=this_decim,
                                 reject_by_annotation=p.reject_epochs_by_annot)
            kwargs = dict()
            if 'verbose' in get_args(get_rejection_threshold):
                kwargs['verbose'] = False
            new_dict = get_rejection_threshold(temp_epochs, **kwargs)
            use_reject = dict()
            msgs = list()
            for k in p.autoreject_types:
                msgs.append('%s=%d %s' % (k, DEFAULTS['scalings'][k] *
                                          new_dict[k], DEFAULTS['units'][k]))
                use_reject[k] = new_dict[k]
            print(': ' + ', '.join(msgs))
            hdf5_file = fif_file.replace('-epo.fif', '-reject.h5')
            assert hdf5_file.endswith('.h5')
            write_hdf5(hdf5_file, use_reject, overwrite=True)
        else:
            use_reject = _handle_dict(p.reject, subj)
        # create epochs
        flat = _handle_dict(p.flat, subj)
        use_reject, use_flat = _restrict_reject_flat(use_reject, flat, raw)
        epochs = Epochs(raw,
                        events,
                        event_id=old_dict,
                        tmin=p.tmin,
                        tmax=p.tmax,
                        baseline=_get_baseline(p),
                        reject=use_reject,
                        flat=use_flat,
                        proj=p.epochs_proj,
                        preload=True,
                        decim=this_decim,
                        on_missing=p.on_missing,
                        reject_tmin=p.reject_tmin,
                        reject_tmax=p.reject_tmax,
                        reject_by_annotation=p.reject_epochs_by_annot)
        del raw
        if epochs.events.shape[0] < 1:
            epochs.plot_drop_log()
            raise ValueError('No valid epochs')
        drop_logs.append(epochs.drop_log)
        sfreq = epochs.info['sfreq']
        # now deal with conditions to save evoked
        if p.disp_files:
            print('    Matching trial counts and saving data to disk.')
        for var, name in ((out_names, 'out_names'),
                          (out_numbers, 'out_numbers'),
                          (must_match, 'must_match'),
                          (evoked_fnames, 'evoked_fnames')):
            if len(var) != len(analyses):
                raise ValueError('len(%s) (%s) != len(analyses) (%s)' %
                                 (name, len(var), len(analyses)))
        for analysis, names, numbers, match, fn in zip(analyses, out_names,
                                                       out_numbers, must_match,
                                                       evoked_fnames):
            # do matching
            numbers = np.asanyarray(numbers)
            nn = numbers[numbers >= 0]
            new_numbers = []
            for num in numbers:
                if num > 0 and num not in new_numbers:
                    # Eventually we could relax this requirement, but not
                    # having it in place is likely to cause people pain...
                    if any(num < n for n in new_numbers):
                        raise RuntimeError('each list of new_numbers must be '
                                           'monotonically increasing')
                    new_numbers.append(num)
            new_numbers = np.array(new_numbers)
            in_names_match = in_names[match]
            # use some variables to allow safe name re-use: collapsed ids are
            # first created above every existing code and under a suffixed
            # name, then moved back to their final code/name
            offset = max(epochs.events[:, 2].max(), new_numbers.max()) + 1
            safety_str = '__mnefun_copy__'
            assert len(new_numbers) == len(names)  # checked above
            if p.match_fun is None:
                # first, equalize trial counts (this will make a copy)
                e = epochs[list(in_names[numbers > 0])]
                if len(in_names_match) > 1:
                    e.equalize_event_counts(in_names_match)

                # second, collapse relevant types
                for num, name in zip(new_numbers, names):
                    collapse = [
                        x for x in in_names[num == numbers] if x in e.event_id
                    ]
                    combine_event_ids(e,
                                      collapse,
                                      {name + safety_str: num + offset},
                                      copy=False)
                for num, name in zip(new_numbers, names):
                    e.events[e.events[:, 2] == num + offset, 2] -= offset
                    e.event_id[name] = num
                    del e.event_id[name + safety_str]
            else:  # custom matching
                # NOTE(review): match_fun is called with five positional
                # arguments; confirm the configured function's signature
                e = p.match_fun(epochs.copy(), analysis, nn, in_names_match,
                                names)

            # now make evoked for each out type (average + standard error);
            # optional 'even'/'odd' split-half evokeds follow the standard ones
            evokeds = list()
            n_standard = 0
            kinds = ['standard']
            if p.every_other:
                kinds += ['even', 'odd']
            for kind in kinds:
                for name in names:
                    this_e = e[name]
                    if kind == 'even':
                        this_e = this_e[::2]
                    elif kind == 'odd':
                        this_e = this_e[1::2]
                    else:
                        assert kind == 'standard'
                    if len(this_e) > 0:
                        ave = this_e.average(picks='all')
                        stde = this_e.standard_error(picks='all')
                        if kind != 'standard':
                            ave.comment += ' %s' % (kind, )
                            stde.comment += ' %s' % (kind, )
                        evokeds.append(ave)
                        evokeds.append(stde)
                        if kind == 'standard':
                            n_standard += 2
            write_evokeds(fn, evokeds)
            naves = [
                str(n) for n in sorted(
                    set([evoked.nave for evoked in evokeds[:n_standard]]))
            ]
            naves = ', '.join(naves)
            if p.disp_files:
                print('      Analysis "%s": %s epochs / condition' %
                      (analysis, naves))

        if p.disp_files:
            print('    Saving epochs to disk.')
        if 'mat' in p.epochs_type:
            spio.savemat(mat_file,
                         dict(epochs=epochs.get_data(),
                              events=epochs.events,
                              sfreq=sfreq,
                              drop_log=epochs.drop_log),
                         do_compression=True,
                         oned_as='column')
        if 'fif' in p.epochs_type:
            epochs.save(fif_file, **_get_epo_kwargs())

    if p.plot_drop_logs:
        for subj, drop_log in zip(subjects, drop_logs):
            plot_drop_log(drop_log, threshold=p.drop_thresh, subject=subj)
Beispiel #8
0
def test_epoch_eq():
    """Test epoch count equalization and condition combining.

    Covers ``equalize_epoch_counts`` across two Epochs objects, drop-log
    bookkeeping of ``Epochs.equalize_event_counts`` ("EQUALIZED_COUNT"
    entries), equalization of plain and collapsed conditions, and
    ``combine_event_ids``.
    """
    # equalizing epochs objects
    epochs_1 = Epochs(raw, events, event_id, tmin, tmax, picks=picks)
    epochs_2 = Epochs(raw, events, event_id_2, tmin, tmax, picks=picks)
    epochs_1.drop_bad_epochs()  # make sure drops are logged
    assert_true(len([l for l in epochs_1.drop_log if not l]) == len(epochs_1.events))
    # NOTE(review): this overwrites epochs_1.drop_log with all-empty entries
    # and compares it with a copy of itself, so the assertion below cannot
    # fail and does not exercise equalization; kept as in the original.
    drop_log1 = epochs_1.drop_log = [[] for _ in range(len(epochs_1.events))]
    drop_log2 = [[] if l == ["EQUALIZED_COUNT"] else l for l in epochs_1.drop_log]
    assert_true(drop_log1 == drop_log2)
    assert_true(len([l for l in epochs_1.drop_log if not l]) == len(epochs_1.events))
    assert_true(epochs_1.events.shape[0] != epochs_2.events.shape[0])
    equalize_epoch_counts([epochs_1, epochs_2], method="mintime")
    assert_true(epochs_1.events.shape[0] == epochs_2.events.shape[0])
    epochs_3 = Epochs(raw, events, event_id, tmin, tmax, picks=picks)
    epochs_4 = Epochs(raw, events, event_id_2, tmin, tmax, picks=picks)
    equalize_epoch_counts([epochs_3, epochs_4], method="truncate")
    assert_true(epochs_1.events.shape[0] == epochs_3.events.shape[0])
    assert_true(epochs_3.events.shape[0] == epochs_4.events.shape[0])

    # equalizing conditions
    epochs = Epochs(raw, events, {"a": 1, "b": 2, "c": 3, "d": 4}, tmin, tmax, picks=picks, reject=reject)
    epochs.drop_bad_epochs()  # make sure drops are logged
    assert_true(len([l for l in epochs.drop_log if not l]) == len(epochs.events))
    drop_log1 = deepcopy(epochs.drop_log)
    old_shapes = [epochs[key].events.shape[0] for key in ["a", "b", "c", "d"]]
    epochs.equalize_event_counts(["a", "b"], copy=False)
    # undo the eq logging
    drop_log2 = [[] if l == ["EQUALIZED_COUNT"] else l for l in epochs.drop_log]
    assert_true(drop_log1 == drop_log2)

    assert_true(len([l for l in epochs.drop_log if not l]) == len(epochs.events))
    new_shapes = [epochs[key].events.shape[0] for key in ["a", "b", "c", "d"]]
    assert_true(new_shapes[0] == new_shapes[1])
    # conditions outside the equalized pair must keep their counts
    assert_true(new_shapes[2] == old_shapes[2])
    assert_true(new_shapes[3] == old_shapes[3])
    # now with two conditions collapsed
    old_shapes = new_shapes
    epochs.equalize_event_counts([["a", "b"], "c"], copy=False)
    new_shapes = [epochs[key].events.shape[0] for key in ["a", "b", "c", "d"]]
    assert_true(new_shapes[0] + new_shapes[1] == new_shapes[2])
    assert_true(new_shapes[3] == old_shapes[3])
    assert_raises(KeyError, epochs.equalize_event_counts, [1, "a"])

    # now let's combine conditions
    old_shapes = new_shapes
    epochs = epochs.equalize_event_counts([["a", "b"], ["c", "d"]])[0]
    new_shapes = [epochs[key].events.shape[0] for key in ["a", "b", "c", "d"]]
    assert_true(old_shapes[0] + old_shapes[1] == new_shapes[0] + new_shapes[1])
    assert_true(new_shapes[0] + new_shapes[1] == new_shapes[2] + new_shapes[3])
    assert_raises(ValueError, combine_event_ids, epochs, ["a", "b"], {"ab": 1})

    combine_event_ids(epochs, ["a", "b"], {"ab": 12}, copy=False)
    caught = 0
    for key in ["a", "b"]:
        try:
            epochs[key]
        except KeyError:
            caught += 1
    # both old keys must be gone; ``assert_raises(Exception, caught == 2)``
    # "passes" only because calling a bool raises TypeError
    assert_true(caught == 2)
    assert_true(not np.any(epochs.events[:, 2] == 1))
    assert_true(not np.any(epochs.events[:, 2] == 2))
    epochs = combine_event_ids(epochs, ["c", "d"], {"cd": 34})
    assert_true(np.all(np.logical_or(epochs.events[:, 2] == 12, epochs.events[:, 2] == 34)))
    assert_true(epochs["ab"].events.shape[0] == old_shapes[0] + old_shapes[1])
    assert_true(epochs["ab"].events.shape[0] == epochs["cd"].events.shape[0])
Beispiel #9
0
# Script fragment: epoch EEG data, equalize trial counts across four
# conditions, collapse them into 'Auditory' and 'Visual', save the evoked
# responses and plot the auditory one.
# NOTE(review): raw, event_fname, event_ids, tmin and tmax are defined earlier
# in the script (outside this view).
events = mne.read_events(event_fname)

#   Set up pick list: EEG + EOG - bad channels (modify to your needs);
#   add 'STI 014' to ``include`` to also keep the stim channel
include = []  # or stim channels ['STI 014']
raw.info['bads'] += ['EEG 053']  # bads + 1 more

# pick EEG and EOG channels (stim=False); EOG is presumably kept so that the
# ``eog`` entry of ``reject`` below has channels to act on
picks = fiff.pick_types(raw.info, meg=False, eeg=True, stim=False, eog=True,
                        include=include, exclude='bads')
# Read epochs, baseline-corrected on the pre-stimulus interval and rejected
# according to the eeg/eog thresholds in ``reject``
epochs = mne.Epochs(raw, events, event_ids, tmin, tmax, picks=picks,
                    baseline=(None, 0), reject=dict(eeg=80e-6, eog=150e-6))
# Let's equalize the trial counts in each condition (in place: copy=False)
epochs.equalize_event_counts(['AudL', 'AudR', 'VisL', 'VisR'], copy=False)
# Now let's combine some conditions: AudL + AudR -> 'Auditory' (code 12),
# VisL + VisR -> 'Visual' (code 34), both modifying ``epochs`` in place
combine_event_ids(epochs, ['AudL', 'AudR'], {'Auditory': 12}, copy=False)
combine_event_ids(epochs, ['VisL', 'VisR'], {'Visual': 34}, copy=False)

# average epochs and get Evoked datasets
evokeds = [epochs[cond].average() for cond in ['Auditory', 'Visual']]

# save evoked data to disk
fiff.write_evoked('sample_auditory_and_visual_eeg-ave.fif', evokeds)

###############################################################################
# View evoked response (only the auditory condition, evokeds[0], is plotted,
# in the top of a 2-row subplot grid)
import matplotlib.pyplot as plt
plt.clf()
ax = plt.subplot(2, 1, 1)
evokeds[0].plot(axes=ax)
plt.title('EEG evoked potential, auditory trials')
Beispiel #10
0
def test_epoch_eq():
    """Test epoch count equalization and condition combining.

    Covers ``equalize_epoch_counts`` across two Epochs objects, drop-log
    bookkeeping of ``Epochs.equalize_event_counts`` ('EQUALIZED_COUNT'
    entries), equalization of plain and collapsed conditions, and
    ``combine_event_ids``.
    """
    # equalizing epochs objects
    epochs_1 = Epochs(raw, events, event_id, tmin, tmax, picks=picks)
    epochs_2 = Epochs(raw, events, event_id_2, tmin, tmax, picks=picks)
    epochs_1.drop_bad_epochs()  # make sure drops are logged
    assert_true(
        len([l for l in epochs_1.drop_log if not l]) == len(epochs_1.events))
    # NOTE(review): this overwrites epochs_1.drop_log with all-empty entries
    # and compares it with a copy of itself, so the assertion below cannot
    # fail and does not exercise equalization; kept as in the original.
    drop_log1 = epochs_1.drop_log = [[] for _ in range(len(epochs_1.events))]
    drop_log2 = [[] if l == ['EQUALIZED_COUNT'] else l
                 for l in epochs_1.drop_log]
    assert_true(drop_log1 == drop_log2)
    assert_true(
        len([l for l in epochs_1.drop_log if not l]) == len(epochs_1.events))
    assert_true(epochs_1.events.shape[0] != epochs_2.events.shape[0])
    equalize_epoch_counts([epochs_1, epochs_2], method='mintime')
    assert_true(epochs_1.events.shape[0] == epochs_2.events.shape[0])
    epochs_3 = Epochs(raw, events, event_id, tmin, tmax, picks=picks)
    epochs_4 = Epochs(raw, events, event_id_2, tmin, tmax, picks=picks)
    equalize_epoch_counts([epochs_3, epochs_4], method='truncate')
    assert_true(epochs_1.events.shape[0] == epochs_3.events.shape[0])
    assert_true(epochs_3.events.shape[0] == epochs_4.events.shape[0])

    # equalizing conditions
    epochs = Epochs(raw, events, {'a': 1, 'b': 2, 'c': 3, 'd': 4},
                    tmin, tmax, picks=picks, reject=reject)
    epochs.drop_bad_epochs()  # make sure drops are logged
    assert_true(
        len([l for l in epochs.drop_log if not l]) == len(epochs.events))
    drop_log1 = deepcopy(epochs.drop_log)
    old_shapes = [epochs[key].events.shape[0] for key in ['a', 'b', 'c', 'd']]
    epochs.equalize_event_counts(['a', 'b'], copy=False)
    # undo the eq logging
    drop_log2 = [[] if l == ['EQUALIZED_COUNT'] else l
                 for l in epochs.drop_log]
    assert_true(drop_log1 == drop_log2)

    assert_true(
        len([l for l in epochs.drop_log if not l]) == len(epochs.events))
    new_shapes = [epochs[key].events.shape[0] for key in ['a', 'b', 'c', 'd']]
    assert_true(new_shapes[0] == new_shapes[1])
    # conditions outside the equalized pair must keep their counts
    assert_true(new_shapes[2] == old_shapes[2])
    assert_true(new_shapes[3] == old_shapes[3])
    # now with two conditions collapsed
    old_shapes = new_shapes
    epochs.equalize_event_counts([['a', 'b'], 'c'], copy=False)
    new_shapes = [epochs[key].events.shape[0] for key in ['a', 'b', 'c', 'd']]
    assert_true(new_shapes[0] + new_shapes[1] == new_shapes[2])
    assert_true(new_shapes[3] == old_shapes[3])
    assert_raises(KeyError, epochs.equalize_event_counts, [1, 'a'])

    # now let's combine conditions
    old_shapes = new_shapes
    epochs = epochs.equalize_event_counts([['a', 'b'], ['c', 'd']])[0]
    new_shapes = [epochs[key].events.shape[0] for key in ['a', 'b', 'c', 'd']]
    assert_true(old_shapes[0] + old_shapes[1] == new_shapes[0] + new_shapes[1])
    assert_true(new_shapes[0] + new_shapes[1] == new_shapes[2] + new_shapes[3])
    assert_raises(ValueError, combine_event_ids, epochs, ['a', 'b'], {'ab': 1})

    combine_event_ids(epochs, ['a', 'b'], {'ab': 12}, copy=False)
    caught = 0
    for key in ['a', 'b']:
        try:
            epochs[key]
        except KeyError:
            caught += 1
    # both old keys must be gone; ``assert_raises(Exception, caught == 2)``
    # "passes" only because calling a bool raises TypeError
    assert_true(caught == 2)
    assert_true(not np.any(epochs.events[:, 2] == 1))
    assert_true(not np.any(epochs.events[:, 2] == 2))
    epochs = combine_event_ids(epochs, ['c', 'd'], {'cd': 34})
    assert_true(
        np.all(np.logical_or(epochs.events[:, 2] == 12,
                             epochs.events[:, 2] == 34)))
    assert_true(epochs['ab'].events.shape[0] == old_shapes[0] + old_shapes[1])
    assert_true(epochs['ab'].events.shape[0] == epochs['cd'].events.shape[0])
Beispiel #11
0
    fname = op.join(out_dir, f"INTERVALS_{subject}.h5")
    if op.isfile(fname):
        data = mne.externals.h5io.read_hdf5(fname)
        assert events == data["events"]
        assert intervals == data["intervals"]
        aucs[si] = data["auc"]
        continue
    print("Fitting estimator for %s: " % subject)
    ep_fname = op.join(workdir, subject, "epochs",
                       "All_%d-sss_%s-epo.fif" % (lp, subject))

    epochs = read_epochs(ep_fname)
    epochs.crop(-0.2, 0.6).apply_baseline()  #
    epochs = epochs["std", "ba", "wa"]
    if combine:
        combine_event_ids(epochs, ["ba", "wa"], {"dev": 23}, copy=False)
    epochs.equalize_event_counts(epochs.event_id.keys())
    epochs.pick_types(meg="grad")  # get rid of trigger channels
    epochs.drop_bad()
    epochs.filter(None, 25).decimate(3)  # lowpass and !decimate
    # epochs.crop(*win)
    for ci, cs in enumerate(events):
        for ii, interval in enumerate(intervals):
            eps = epochs[cs]
            ix = eps.time_as_index(win[0])[0], eps.time_as_index(win[1])[0]
            if "First" in interval:
                sl = slice(None, len(eps) // 3)
            elif "Second" in interval:
                sl = slice(len(eps) // 3, 2 * len(eps) // 3)
            elif "Third" in interval:
                sl = slice(2 * len(eps) // 3, None)