def test_read_epochs_bad_events():
    """Test epochs when events are at the beginning or the end of the file."""
    # Event at the beginning
    epochs = Epochs(raw, np.array([[raw.first_samp, 0, event_id]]),
                    event_id, tmin, tmax, picks=picks, baseline=(None, 0))
    with warnings.catch_warnings(record=True):
        evoked = epochs.average()

    epochs = Epochs(raw, np.array([[raw.first_samp, 0, event_id]]),
                    event_id, tmin, tmax, picks=picks, baseline=(None, 0))
    epochs.drop_bad_epochs()
    with warnings.catch_warnings(record=True):
        evoked = epochs.average()

    # Event at the end
    epochs = Epochs(raw, np.array([[raw.last_samp, 0, event_id]]),
                    event_id, tmin, tmax, picks=picks, baseline=(None, 0))
    with warnings.catch_warnings(record=True):
        evoked = epochs.average()
    assert evoked
    warnings.resetwarnings()
def test_reject_epochs():
    """Test of epochs rejection."""
    epochs = Epochs(raw, events, event_id, tmin, tmax, baseline=(None, 0),
                    reject=reject, flat=flat)
    n_events = len(epochs.events)
    data = epochs.get_data()
    n_clean_epochs = len(data)
    # Should match
    # mne_process_raw --raw test_raw.fif --projoff \
    #     --saveavetag -ave --ave test.ave --filteroff
    assert_true(n_events > n_clean_epochs)
    assert_true(n_clean_epochs == 3)
    assert_true(epochs.drop_log == [[], [], [], ['MEG 2443'], ['MEG 2443'],
                                    ['MEG 2443'], ['MEG 2443']])

    # Ensure epochs are not dropped based on a bad channel
    raw_2 = raw.copy()
    raw_2.info['bads'] = ['MEG 2443']
    reject_crazy = dict(grad=1000e-15, mag=4e-15, eeg=80e-9, eog=150e-9)
    epochs = Epochs(raw_2, events, event_id, tmin, tmax, baseline=(None, 0),
                    reject=reject_crazy, flat=flat)
    epochs.drop_bad_epochs()
    assert_true(all(['MEG 2442' in e for e in epochs.drop_log]))
    assert_true(all(['MEG 2443' not in e for e in epochs.drop_log]))

    epochs = Epochs(raw, events, event_id, tmin, tmax, picks=picks,
                    baseline=(None, 0), reject=reject, flat=flat,
                    reject_tmin=0., reject_tmax=.1)
    data = epochs.get_data()
    n_clean_epochs = len(data)
    assert_true(n_clean_epochs == 7)
    assert_true(epochs.times[epochs._reject_time][0] >= 0.)
    assert_true(epochs.times[epochs._reject_time][-1] <= 0.1)
def test_indexing_slicing():
    """Test of indexing and slicing operations."""
    raw, events, picks = _get_data()
    epochs = Epochs(raw, events[:20], event_id, tmin, tmax, picks=picks,
                    baseline=(None, 0), preload=False, reject=reject,
                    flat=flat)
    data_normal = epochs.get_data()

    n_good_events = data_normal.shape[0]

    # indices for slicing
    start_index = 1
    end_index = n_good_events - 1

    assert((end_index - start_index) > 0)

    for preload in [True, False]:
        epochs2 = Epochs(raw, events[:20], event_id, tmin, tmax,
                         picks=picks, baseline=(None, 0), preload=preload,
                         reject=reject, flat=flat)

        if not preload:
            epochs2.drop_bad_epochs()

        # using slicing
        epochs2_sliced = epochs2[start_index:end_index]

        data_epochs2_sliced = epochs2_sliced.get_data()
        assert_array_equal(data_epochs2_sliced,
                           data_normal[start_index:end_index])

        # using indexing
        pos = 0
        for idx in range(start_index, end_index):
            data = epochs2_sliced[pos].get_data()
            assert_array_equal(data[0], data_normal[idx])
            pos += 1

        # using indexing with an int
        data = epochs2[data_epochs2_sliced.shape[0]].get_data()
        assert_array_equal(data, data_normal[[idx]])

        # using indexing with an array
        idx = np.random.randint(0, data_epochs2_sliced.shape[0], 10)
        data = epochs2[idx].get_data()
        assert_array_equal(data, data_normal[idx])

        # using indexing with a list of indices
        idx = [0]
        data = epochs2[idx].get_data()
        assert_array_equal(data, data_normal[idx])
        idx = [0, 1]
        data = epochs2[idx].get_data()
        assert_array_equal(data, data_normal[idx])
def test_epoch_eq():
    """Test epoch count equalization and condition combining."""
    # equalizing epochs objects
    epochs_1 = Epochs(raw, events, event_id, tmin, tmax, picks=picks)
    epochs_2 = Epochs(raw, events, event_id_2, tmin, tmax, picks=picks)
    epochs_1.drop_bad_epochs()  # make sure drops are logged
    assert_true(len([l for l in epochs_1.drop_log if not l]) ==
                len(epochs_1.events))
    drop_log1 = epochs_1.drop_log = [[] for _ in range(len(epochs_1.events))]
    drop_log2 = [[] if l == ['EQUALIZED_COUNT'] else l
                 for l in epochs_1.drop_log]
    assert_true(drop_log1 == drop_log2)
    assert_true(len([l for l in epochs_1.drop_log if not l]) ==
                len(epochs_1.events))
    assert_true(epochs_1.events.shape[0] != epochs_2.events.shape[0])
    equalize_epoch_counts([epochs_1, epochs_2], method='mintime')
    assert_true(epochs_1.events.shape[0] == epochs_2.events.shape[0])
    epochs_3 = Epochs(raw, events, event_id, tmin, tmax, picks=picks)
    epochs_4 = Epochs(raw, events, event_id_2, tmin, tmax, picks=picks)
    equalize_epoch_counts([epochs_3, epochs_4], method='truncate')
    assert_true(epochs_1.events.shape[0] == epochs_3.events.shape[0])
    assert_true(epochs_3.events.shape[0] == epochs_4.events.shape[0])

    # equalizing conditions
    epochs = Epochs(raw, events, {'a': 1, 'b': 2, 'c': 3, 'd': 4},
                    tmin, tmax, picks=picks, reject=reject)
    epochs.drop_bad_epochs()  # make sure drops are logged
    assert_true(len([l for l in epochs.drop_log if not l]) ==
                len(epochs.events))
    drop_log1 = deepcopy(epochs.drop_log)
    old_shapes = [epochs[key].events.shape[0] for key in ['a', 'b', 'c', 'd']]
    epochs.equalize_event_counts(['a', 'b'], copy=False)
    # undo the eq logging
    drop_log2 = [[] if l == ['EQUALIZED_COUNT'] else l
                 for l in epochs.drop_log]
    assert_true(drop_log1 == drop_log2)
    assert_true(len([l for l in epochs.drop_log if not l]) ==
                len(epochs.events))
    new_shapes = [epochs[key].events.shape[0] for key in ['a', 'b', 'c', 'd']]
    assert_true(new_shapes[0] == new_shapes[1])
    # 'c' and 'd' were not equalized, so their counts must be untouched
    assert_true(new_shapes[2] == old_shapes[2])
    assert_true(new_shapes[3] == old_shapes[3])
    # now with two conditions collapsed
    old_shapes = new_shapes
    epochs.equalize_event_counts([['a', 'b'], 'c'], copy=False)
    new_shapes = [epochs[key].events.shape[0] for key in ['a', 'b', 'c', 'd']]
    assert_true(new_shapes[0] + new_shapes[1] == new_shapes[2])
    assert_true(new_shapes[3] == old_shapes[3])
    assert_raises(KeyError, epochs.equalize_event_counts, [1, 'a'])

    # now let's combine conditions
    old_shapes = new_shapes
    epochs = epochs.equalize_event_counts([['a', 'b'], ['c', 'd']])[0]
    new_shapes = [epochs[key].events.shape[0] for key in ['a', 'b', 'c', 'd']]
    assert_true(old_shapes[0] + old_shapes[1] == new_shapes[0] + new_shapes[1])
    assert_true(new_shapes[0] + new_shapes[1] == new_shapes[2] + new_shapes[3])
    assert_raises(ValueError, combine_event_ids, epochs, ['a', 'b'], {'ab': 1})
    combine_event_ids(epochs, ['a', 'b'], {'ab': 12}, copy=False)
    caught = 0
    for key in ['a', 'b']:
        try:
            epochs[key]
        except KeyError:
            caught += 1
    # both 'a' and 'b' must be gone after combining them into 'ab'
    assert_true(caught == 2)
    assert_true(not np.any(epochs.events[:, 2] == 1))
    assert_true(not np.any(epochs.events[:, 2] == 2))
    epochs = combine_event_ids(epochs, ['c', 'd'], {'cd': 34})
    assert_true(np.all(np.logical_or(epochs.events[:, 2] == 12,
                                     epochs.events[:, 2] == 34)))
    assert_true(epochs['ab'].events.shape[0] == old_shapes[0] + old_shapes[1])
    assert_true(epochs['ab'].events.shape[0] == epochs['cd'].events.shape[0])
def test_source_psd_epochs():
    """Test multi-taper source PSD computation in label from epochs."""
    raw = io.Raw(fname_data)
    inverse_operator = read_inverse_operator(fname_inv)
    label = read_label(fname_label)

    event_id, tmin, tmax = 1, -0.2, 0.5
    lambda2, method = 1. / 9., 'dSPM'
    bandwidth = 8.
    fmin, fmax = 0, 100

    picks = pick_types(raw.info, meg=True, eeg=False, stim=True,
                       ecg=True, eog=True, include=['STI 014'],
                       exclude='bads')
    reject = dict(grad=4000e-13, mag=4e-12, eog=150e-6)

    events = find_events(raw, stim_channel='STI 014')
    epochs = Epochs(raw, events, event_id, tmin, tmax, picks=picks,
                    baseline=(None, 0), reject=reject)

    # only look at one epoch
    epochs.drop_bad_epochs()
    one_epochs = epochs[:1]

    inv = prepare_inverse_operator(inverse_operator, nave=1,
                                   lambda2=1. / 9., method="dSPM")
    # return list
    stc_psd = compute_source_psd_epochs(one_epochs, inv,
                                        lambda2=lambda2, method=method,
                                        pick_ori="normal", label=label,
                                        bandwidth=bandwidth,
                                        fmin=fmin, fmax=fmax,
                                        prepared=True)[0]

    # return generator
    stcs = compute_source_psd_epochs(one_epochs, inv,
                                     lambda2=lambda2, method=method,
                                     pick_ori="normal", label=label,
                                     bandwidth=bandwidth,
                                     fmin=fmin, fmax=fmax,
                                     return_generator=True,
                                     prepared=True)

    for stc in stcs:
        stc_psd_gen = stc

    assert_array_almost_equal(stc_psd.data, stc_psd_gen.data)

    # compare with direct computation
    stc = apply_inverse_epochs(one_epochs, inv,
                               lambda2=lambda2, method=method,
                               pick_ori="normal", label=label,
                               prepared=True)[0]

    sfreq = epochs.info['sfreq']
    psd, freqs = multitaper_psd(stc.data, sfreq=sfreq,
                                bandwidth=bandwidth, fmin=fmin, fmax=fmax)

    assert_array_almost_equal(psd, stc_psd.data)
    assert_array_almost_equal(freqs, stc_psd.times)

    # Check corner cases caused by tiny bandwidth
    with warnings.catch_warnings(record=True) as w:
        warnings.simplefilter('always')
        compute_source_psd_epochs(one_epochs, inv,
                                  lambda2=lambda2, method=method,
                                  pick_ori="normal", label=label,
                                  bandwidth=0.01, low_bias=True,
                                  fmin=fmin, fmax=fmax,
                                  return_generator=False,
                                  prepared=True)
        compute_source_psd_epochs(one_epochs, inv,
                                  lambda2=lambda2, method=method,
                                  pick_ori="normal", label=label,
                                  bandwidth=0.01, low_bias=False,
                                  fmin=fmin, fmax=fmax,
                                  return_generator=False,
                                  prepared=True)
    assert_true(len(w) >= 2)
    assert_true(any('not properly use' in str(ww.message) for ww in w))
    assert_true(any('Bandwidth too small' in str(ww.message) for ww in w))
def do_preprocessing_combined(p, subjects, run_indices):
    """Do preprocessing on all raw files together.

    Calculates projection vectors to use to clean data.

    Parameters
    ----------
    p : instance of Parameters
        Analysis parameters.
    subjects : list of str
        Subject names to analyze (e.g., ['Eric_SoP_001', ...]).
    run_indices : array-like | None
        Run indices to include.
    """
    drop_logs = list()
    for si, subj in enumerate(subjects):
        proj_nums = _proj_nums(p, subj)
        ecg_channel = _handle_dict(p.ecg_channel, subj)
        flat = _handle_dict(p.flat, subj)
        if p.disp_files:
            print(' Preprocessing subject %g/%g (%s).'
                  % (si + 1, len(subjects), subj))
        pca_dir = _get_pca_dir(p, subj)
        bad_file = get_bad_fname(p, subj, check_exists=False)

        # Create SSP projection vectors after marking bad channels
        raw_names = get_raw_fnames(p, subj, 'sss', False, False,
                                   run_indices[si])
        empty_names = get_raw_fnames(p, subj, 'sss', 'only')
        for r in raw_names + empty_names:
            if not op.isfile(r):
                raise NameError('File not found (' + r + ')')

        fir_kwargs, old_kwargs = _get_fir_kwargs(p.fir_design)
        if isinstance(p.auto_bad, float):
            print(' Creating post SSS bad channel file:\n'
                  ' %s' % bad_file)
            # do autobad
            raw = _raw_LRFCP(raw_names, p.proj_sfreq, None, None,
                             p.n_jobs_fir, p.n_jobs_resample, list(), None,
                             p.disp_files, method='fir',
                             filter_length=p.filter_length,
                             apply_proj=False, force_bads=False,
                             l_trans=p.hp_trans, h_trans=p.lp_trans,
                             phase=p.phase, fir_window=p.fir_window,
                             pick=True, skip_by_annotation='edge',
                             **fir_kwargs)
            events = fixed_len_events(p, raw)
            rtmin = p.reject_tmin if p.reject_tmin is not None else p.tmin
            rtmax = p.reject_tmax if p.reject_tmax is not None else p.tmax
            # do not mark eog channels bad
            meg, eeg = 'meg' in raw, 'eeg' in raw
            picks = pick_types(raw.info, meg=meg, eeg=eeg, eog=False,
                               exclude=[])
            assert p.auto_bad_flat is None or isinstance(p.auto_bad_flat, dict)
            assert p.auto_bad_reject is None or \
                isinstance(p.auto_bad_reject, dict) or \
                p.auto_bad_reject == 'auto'
            if p.auto_bad_reject == 'auto':
                print(' Auto bad channel selection active. '
                      'Will try using Autoreject module to '
                      'compute rejection criterion.')
                try:
                    from autoreject import get_rejection_threshold
                except ImportError:
                    raise ImportError(' Autoreject module not installed.\n'
                                      ' Noisy channel detection parameter '
                                      ' not defined. To use autobad '
                                      ' channel selection either define '
                                      ' rejection criteria or install '
                                      ' Autoreject module.\n')
                print(' Computing thresholds.\n', end='')
                temp_epochs = Epochs(raw, events, event_id=None,
                                     tmin=rtmin, tmax=rtmax,
                                     baseline=_get_baseline(p), proj=True,
                                     reject=None, flat=None, preload=True,
                                     decim=1)
                kwargs = dict()
                if 'verbose' in get_args(get_rejection_threshold):
                    kwargs['verbose'] = False
                reject = get_rejection_threshold(temp_epochs, **kwargs)
                reject = {kk: vv for kk, vv in reject.items()}
            elif p.auto_bad_reject is None and p.auto_bad_flat is None:
                raise RuntimeError('Auto bad channel detection active. Noisy '
                                   'and flat channel detection '
                                   'parameters not defined. '
                                   'At least one criterion must be defined.')
            else:
                reject = p.auto_bad_reject
            if 'eog' in reject.keys():
                reject.pop('eog', None)
            epochs = Epochs(raw, events, None, tmin=rtmin, tmax=rtmax,
                            baseline=_get_baseline(p), picks=picks,
                            reject=reject, flat=p.auto_bad_flat,
                            proj=True, preload=True, decim=1,
                            reject_tmin=rtmin, reject_tmax=rtmax)
            # channel scores from drop log
            drops = Counter([ch for d in epochs.drop_log for ch in d])
            # get rid of non-channel reasons in drop log
            scores = {kk: vv for kk, vv in drops.items()
                      if kk in epochs.ch_names}
            ch_names = np.array(list(scores.keys()))
            # channel scores expressed as percentile and rank ordered
            counts = (100 * np.array([scores[ch] for ch in ch_names], float) /
                      len(epochs.drop_log))
            order = np.argsort(counts)[::-1]
            # boolean array masking out channels with <= % epochs dropped
            mask = counts[order] > p.auto_bad
            badchs = ch_names[order[mask]]
            if len(badchs) > 0:
                # Make sure we didn't get too many bad MEG or EEG channels
                for m, e, thresh in zip([True, False], [False, True],
                                        [p.auto_bad_meg_thresh,
                                         p.auto_bad_eeg_thresh]):
                    picks = pick_types(epochs.info, meg=m, eeg=e, exclude=[])
                    if len(picks) > 0:
                        ch_names = [epochs.ch_names[pp] for pp in picks]
                        n_bad_type = sum(ch in ch_names for ch in badchs)
                        if n_bad_type > thresh:
                            stype = 'meg' if m else 'eeg'
                            raise RuntimeError('Too many bad %s channels '
                                               'found: %s > %s'
                                               % (stype, n_bad_type, thresh))
                print(' The following channels resulted in greater than '
                      '{:.0f}% trials dropped:\n'.format(p.auto_bad * 100))
                print(badchs)
                with open(bad_file, 'w') as f:
                    f.write('\n'.join(badchs))
        if not op.isfile(bad_file):
            print(' Clearing bad channels (no file %s)'
                  % op.sep.join(bad_file.split(op.sep)[-3:]))
            bad_file = None

        ecg_t_lims = _handle_dict(p.ecg_t_lims, subj)
        ecg_f_lims = p.ecg_f_lims

        ecg_eve = op.join(pca_dir, 'preproc_ecg-eve.fif')
        ecg_epo = op.join(pca_dir, 'preproc_ecg-epo.fif')
        ecg_proj = op.join(pca_dir, 'preproc_ecg-proj.fif')
        all_proj = op.join(pca_dir, 'preproc_all-proj.fif')

        get_projs_from = _handle_dict(p.get_projs_from, subj)
        if get_projs_from is None:
            get_projs_from = np.arange(len(raw_names))
        pre_list = [r for ri, r in enumerate(raw_names)
                    if ri in get_projs_from]

        projs = list()
        raw_orig = _raw_LRFCP(raw_names=pre_list, sfreq=p.proj_sfreq,
                              l_freq=None, h_freq=None, n_jobs=p.n_jobs_fir,
                              n_jobs_resample=p.n_jobs_resample,
                              projs=projs, bad_file=bad_file,
                              disp_files=p.disp_files, method='fir',
                              filter_length=p.filter_length,
                              force_bads=False, l_trans=p.hp_trans,
                              h_trans=p.lp_trans, phase=p.phase,
                              fir_window=p.fir_window, pick=True,
                              skip_by_annotation='edge', **fir_kwargs)

        # Apply any user-supplied extra projectors
        if p.proj_extra is not None:
            if p.disp_files:
                print(' Adding extra projectors from "%s".' % p.proj_extra)
            projs.extend(read_proj(op.join(pca_dir, p.proj_extra)))

        proj_kwargs, p_sl = _get_proj_kwargs(p)
        #
        # Calculate and apply ERM projectors
        #
        if not p.cont_as_esss:
            if any(proj_nums[2]):
                assert proj_nums[2][2] == 0  # no EEG projectors for ERM
                if len(empty_names) == 0:
                    raise RuntimeError('Cannot compute empty-room projectors '
                                       'from continuous raw data')
                if p.disp_files:
                    print(' Computing continuous projectors using ERM.')
                # Use empty room(s), but processed the same way
                projs.extend(_compute_erm_proj(p, subj, projs, 'sss',
                                               bad_file))
            else:
                cont_proj = op.join(pca_dir, 'preproc_cont-proj.fif')
                _safe_remove(cont_proj)

        #
        # Calculate and apply the ECG projectors
        #
        if any(proj_nums[0]):
            if p.disp_files:
                print(' Computing ECG projectors...', end='')
            raw = raw_orig.copy()

            raw.filter(ecg_f_lims[0], ecg_f_lims[1], n_jobs=p.n_jobs_fir,
                       method='fir', filter_length=p.filter_length,
                       l_trans_bandwidth=0.5, h_trans_bandwidth=0.5,
                       phase='zero-double', fir_window='hann',
                       skip_by_annotation='edge', **old_kwargs)
            raw.add_proj(projs)
            raw.apply_proj()
            find_kwargs = dict()
            if 'reject_by_annotation' in get_args(find_ecg_events):
                find_kwargs['reject_by_annotation'] = True
            elif len(raw.annotations) > 0:
                print(' WARNING: ECG event detection will not make use of '
                      'annotations, please update MNE-Python')
            # We've already filtered the data channels above, but this
            # filters the ECG channel
            ecg_events = find_ecg_events(raw, 999, ecg_channel, 0.,
                                         ecg_f_lims[0], ecg_f_lims[1],
                                         qrs_threshold='auto',
                                         return_ecg=False, **find_kwargs)[0]
            use_reject, use_flat = _restrict_reject_flat(
                _handle_dict(p.ssp_ecg_reject, subj), flat, raw)
            ecg_epochs = Epochs(raw, ecg_events, 999, ecg_t_lims[0],
                                ecg_t_lims[1], baseline=None,
                                reject=use_reject, flat=use_flat,
                                preload=True)
            print(' obtained %d epochs from %d events.'
                  % (len(ecg_epochs), len(ecg_events)))
            if len(ecg_epochs) >= 20:
                write_events(ecg_eve, ecg_epochs.events)
                ecg_epochs.save(ecg_epo, **_get_epo_kwargs())
                desc_prefix = 'ECG-%s-%s' % tuple(ecg_t_lims)
                pr = compute_proj_wrap(ecg_epochs, p.proj_ave,
                                       n_grad=proj_nums[0][0],
                                       n_mag=proj_nums[0][1],
                                       n_eeg=proj_nums[0][2],
                                       desc_prefix=desc_prefix,
                                       **proj_kwargs)
                assert len(pr) == np.sum(proj_nums[0][::p_sl])
                write_proj(ecg_proj, pr)
                projs.extend(pr)
            else:
                plot_drop_log(ecg_epochs.drop_log)
                raw.plot(events=ecg_epochs.events)
                raise RuntimeError('Only %d/%d good ECG epochs found'
                                   % (len(ecg_epochs), len(ecg_events)))
            del raw, ecg_epochs, ecg_events
        else:
            _safe_remove([ecg_proj, ecg_eve, ecg_epo])

        #
        # Next calculate and apply the EOG projectors
        #
        for idx, kind in ((1, 'EOG'), (3, 'HEOG'), (4, 'VEOG')):
            _compute_add_eog(p, subj, raw_orig, projs, proj_nums[idx],
                             kind, pca_dir, flat, proj_kwargs, old_kwargs,
                             p_sl)
        del proj_nums

        # save the projectors
        write_proj(all_proj, projs)

        #
        # Look at raw_orig for trial DQs now, it will be quick
        #
        raw_orig.filter(p.hp_cut, p.lp_cut, n_jobs=p.n_jobs_fir,
                        method='fir', filter_length=p.filter_length,
                        l_trans_bandwidth=p.hp_trans, phase=p.phase,
                        h_trans_bandwidth=p.lp_trans,
                        fir_window=p.fir_window, skip_by_annotation='edge',
                        **fir_kwargs)
        raw_orig.add_proj(projs)
        raw_orig.apply_proj()
        # now let's epoch with 1-sec windows to look for DQs
        events = fixed_len_events(p, raw_orig)
        reject = _handle_dict(p.reject, subj)
        use_reject, use_flat = _restrict_reject_flat(reject, flat, raw_orig)
        epochs = Epochs(raw_orig, events, None, p.tmin, p.tmax,
                        preload=False, baseline=_get_baseline(p),
                        reject=use_reject, flat=use_flat, proj=True)
        try:
            epochs.drop_bad()
        except AttributeError:  # old way
            epochs.drop_bad_epochs()
        drop_logs.append(epochs.drop_log)
        del raw_orig
        del epochs
    if p.plot_drop_logs:
        for subj, drop_log in zip(subjects, drop_logs):
            plot_drop_log(drop_log, p.drop_thresh, subject=subj)
event_fname = data_path + ('/MEG/sample/sample_audvis_filt-0-40_raw-'
                           'eve.fif')

event_id = {'Auditory/Left': 1}
tmin, tmax = -0.2, 0.5

events = mne.read_events(event_fname)

# pick MEG channels
raw.info['bads'] = []
picks = mne.pick_types(raw.info, meg=True, eeg=False, stim=False, eog=False,
                       include=[], exclude=[])

n_events = events.shape[0]
epochs = Epochs(raw, events, event_id, tmin, tmax,
                picks=picks, baseline=(None, 0), reject=None,
                verbose=False, detrend=0, preload=True)
epochs.drop_bad_epochs()

prefix = 'MEG data'
err_cons = grid_search(epochs, n_interpolates, consensus_percs,
                       prefix=prefix, n_folds=n_folds)

# try the best consensus perc to get clean evoked now
best_idx, best_jdx = np.unravel_index(err_cons.mean(axis=-1).argmin(),
                                      err_cons.shape[:2])
consensus_perc = consensus_percs[best_idx]
n_interpolate = n_interpolates[best_jdx]
auto_reject = ConsensusAutoReject(compute_threshes, consensus_perc,
                                  n_interpolate=n_interpolate)
epochs_transformed = auto_reject.fit_transform(epochs)

evoked = epochs.average()