def get_clean_idx(raw_file, duration=1, reject=60):
    '''Get indices of the good segments of the data.

    Divide the data into epochs of length ``duration`` and reject those
    whose peak-to-peak amplitude exceeds a predefined threshold. Returns
    the indices of the good epochs.

    Parameters
    ----------
    raw_file : Raw
        The raw data.
    duration : float
        Length of the epochs (default 1 s).
    reject : float
        Peak-to-peak amplitude threshold for rejection (default 60).

    Returns
    -------
    list of int
        Indices of the good epochs.
    '''
    events = make_fixed_length_events(raw_file.copy().crop(tmax=1467),
                                      id=1, duration=duration)
    epochs = Epochs(raw_file.copy().crop(tmax=1467), events, tmin=0,
                    tmax=duration, reject=dict(eeg=reject), baseline=None)
    epochs.drop_bad()
    return epochs.selection
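# Hedged usage sketch for get_clean_idx. Assumes `my_eeg_raw.fif` (a
# placeholder name) holds EEG data; note that `reject` is passed straight
# into dict(eeg=...), so MNE expects it in volts (60e-6 for a 60 uV
# peak-to-peak threshold), and the hard-coded crop at tmax=1467 s above must
# fit within the recording.
from mne.io import read_raw_fif

raw = read_raw_fif('my_eeg_raw.fif', preload=True)  # hypothetical file
good_idx = get_clean_idx(raw, duration=1, reject=60e-6)
print('%d clean 1-second segments kept' % len(good_idx))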
def test_events_long():
    """Test events."""
    data_path = testing.data_path()
    raw_fname = data_path + '/MEG/sample/sample_audvis_trunc_raw.fif'
    raw = read_raw_fif(raw_fname, preload=True)
    raw_tmin, raw_tmax = 0, 90
    tmin, tmax = -0.2, 0.5
    event_id = dict(aud_l=1, vis_l=3)

    # select gradiometers
    picks = pick_types(raw.info, meg='grad', eeg=False, eog=True,
                       stim=True, exclude=raw.info['bads'])

    # load data with usual Epochs for later verification
    raw = concatenate_raws([raw, raw.copy(), raw.copy(), raw.copy(),
                            raw.copy(), raw.copy()])
    assert 110 < raw.times[-1] < 130
    raw_cropped = raw.copy().crop(raw_tmin, raw_tmax)
    events_offline = find_events(raw_cropped)
    epochs_offline = Epochs(raw_cropped, events_offline, event_id=event_id,
                            tmin=tmin, tmax=tmax, picks=picks, decim=1,
                            reject=dict(grad=4000e-13, eog=150e-6),
                            baseline=None)
    epochs_offline.drop_bad()

    # create the mock-client object
    rt_client = MockRtClient(raw)
    rt_epochs = RtEpochs(rt_client, event_id, tmin, tmax, picks=picks,
                         decim=1, reject=dict(grad=4000e-13, eog=150e-6),
                         baseline=None, isi_max=1.)

    rt_epochs.start()
    rt_client.send_data(rt_epochs, picks, tmin=raw_tmin, tmax=raw_tmax,
                        buffer_size=1000)

    expected_events = epochs_offline.events.copy()
    expected_events[:, 0] = expected_events[:, 0] - raw_cropped.first_samp
    assert np.all(expected_events[:, 0] <=
                  (raw_tmax - tmax) * raw.info['sfreq'])
    assert_array_equal(rt_epochs.events, expected_events)
    assert len(rt_epochs) == len(epochs_offline)

    data_picks = pick_types(epochs_offline.info, meg='grad', eeg=False,
                            eog=True, stim=False,
                            exclude=raw.info['bads'])

    for ev_num, ev in enumerate(rt_epochs.iter_evoked()):
        if ev_num == 0:
            X_rt = ev.data[None, data_picks, :]
            y_rt = int(ev.comment)  # comment attribute contains the event_id
        else:
            X_rt = np.concatenate((X_rt, ev.data[None, data_picks, :]),
                                  axis=0)
            y_rt = np.append(y_rt, int(ev.comment))

    X_offline = epochs_offline.get_data()[:, data_picks, :]
    y_offline = epochs_offline.events[:, 2]
    assert_array_equal(X_rt, X_offline)
    assert_array_equal(y_rt, y_offline)
def test_fieldtrip_rtepochs(free_tcp_port, tmpdir):
    """Test FieldTrip RtEpochs."""
    raw_tmax = 7
    raw = read_raw_fif(raw_fname, preload=True)
    raw.crop(tmin=0, tmax=raw_tmax)
    events_offline = find_events(raw, stim_channel='STI 014')
    event_id = list(np.unique(events_offline[:, 2]))
    tmin, tmax = -0.2, 0.5
    epochs_offline = Epochs(raw, events_offline, event_id=event_id,
                            tmin=tmin, tmax=tmax)
    epochs_offline.drop_bad()
    isi_max = (np.max(np.diff(epochs_offline.events[:, 0])) /
               raw.info['sfreq']) + 1.0

    neuromag2ft_fname = op.realpath(op.join(os.environ['NEUROMAG2FT_ROOT'],
                                            'neuromag2ft'))
    # Works with neuromag2ft-3.0.2
    cmd = (neuromag2ft_fname, '--file', raw_fname, '--speed', '8.0',
           '--bufport', str(free_tcp_port))
    with running_subprocess(cmd, after='terminate', verbose=False):
        data_rt = None
        events_ids_rt = None
        with pytest.warns(RuntimeWarning, match='Trying to guess it'):
            with FieldTripClient(host='localhost', port=free_tcp_port,
                                 tmax=raw_tmax, wait_max=2) as rt_client:
                # get measurement info guessed by MNE-Python
                raw_info = rt_client.get_measurement_info()
                assert ([ch['ch_name'] for ch in raw_info['chs']] ==
                        [ch['ch_name'] for ch in raw.info['chs']])

                # create the real-time epochs object
                epochs_rt = RtEpochs(rt_client, event_id, tmin, tmax,
                                     stim_channel='STI 014',
                                     isi_max=isi_max)
                epochs_rt.start()
                time.sleep(0.5)
                for ev_num, ev in enumerate(epochs_rt.iter_evoked()):
                    if ev_num == 0:
                        data_rt = ev.data[None, :, :]
                        # comment attribute contains event_id
                        events_ids_rt = int(ev.comment)
                    else:
                        data_rt = np.concatenate(
                            (data_rt, ev.data[None, :, :]), axis=0)
                        events_ids_rt = np.append(events_ids_rt,
                                                  int(ev.comment))
                _call_base_epochs_public_api(epochs_rt, tmpdir)
                epochs_rt.stop(stop_receive_thread=True)

    assert_array_equal(events_ids_rt, epochs_rt.events[:, 2])
    assert_array_equal(data_rt, epochs_rt.get_data())
    assert len(epochs_rt) == len(epochs_offline)
    assert_array_equal(events_ids_rt, epochs_offline.events[:, 2])
    assert_allclose(epochs_rt.get_data(), epochs_offline.get_data(),
                    rtol=1.e-5, atol=1.e-8)  # defaults of np.isclose
def test_events_sampledata():
    """Based on examples/realtime/plot_compute_rt_decoder.py."""
    data_path = sample.data_path()
    raw_fname = data_path + '/MEG/sample/sample_audvis_filt-0-40_raw.fif'
    raw = read_raw_fif(raw_fname, preload=True)
    raw_tmin, raw_tmax = 0, 90
    tmin, tmax = -0.2, 0.5
    event_id = dict(aud_l=1, vis_l=3)

    # select gradiometers
    picks = pick_types(raw.info, meg='grad', eeg=False, eog=True,
                       stim=True, exclude=raw.info['bads'])

    # load data with usual Epochs for later verification
    raw_cropped = raw.copy().crop(raw_tmin, raw_tmax)
    events_offline = find_events(raw_cropped)
    epochs_offline = Epochs(raw_cropped, events_offline, event_id=event_id,
                            tmin=tmin, tmax=tmax, picks=picks, decim=1,
                            reject=dict(grad=4000e-13, eog=150e-6),
                            baseline=None)
    epochs_offline.drop_bad()

    # create the mock-client object
    rt_client = MockRtClient(raw)
    rt_epochs = RtEpochs(rt_client, event_id, tmin, tmax, picks=picks,
                         decim=1, reject=dict(grad=4000e-13, eog=150e-6),
                         baseline=None, isi_max=1.)

    rt_epochs.start()
    rt_client.send_data(rt_epochs, picks, tmin=raw_tmin, tmax=raw_tmax,
                        buffer_size=1000)

    expected_events = epochs_offline.events.copy()
    expected_events[:, 0] = expected_events[:, 0] - raw_cropped.first_samp
    assert np.all(expected_events[:, 0] <=
                  (raw_tmax - tmax) * raw.info['sfreq'])
    assert_array_equal(rt_epochs.events, expected_events)
    assert len(rt_epochs) == len(epochs_offline)

    data_picks = pick_types(epochs_offline.info, meg='grad', eeg=False,
                            eog=True, stim=False,
                            exclude=raw.info['bads'])

    for ev_num, ev in enumerate(rt_epochs.iter_evoked()):
        if ev_num == 0:
            X_rt = ev.data[None, data_picks, :]
            y_rt = int(ev.comment)  # comment attribute contains the event_id
        else:
            X_rt = np.concatenate((X_rt, ev.data[None, data_picks, :]),
                                  axis=0)
            y_rt = np.append(y_rt, int(ev.comment))

    X_offline = epochs_offline.get_data()[:, data_picks, :]
    y_offline = epochs_offline.events[:, 2]
    assert_array_equal(X_rt, X_offline)
    assert_array_equal(y_rt, y_offline)
def test_fieldtrip_rtepochs(free_tcp_port, tmpdir):
    """Test FieldTrip RtEpochs."""
    raw_tmax = 7
    raw = read_raw_fif(raw_fname, preload=True)
    raw.crop(tmin=0, tmax=raw_tmax)
    events_offline = find_events(raw, stim_channel='STI 014')
    event_id = list(np.unique(events_offline[:, 2]))
    tmin, tmax = -0.2, 0.5
    epochs_offline = Epochs(raw, events_offline, event_id=event_id,
                            tmin=tmin, tmax=tmax)
    epochs_offline.drop_bad()
    isi_max = (np.max(np.diff(epochs_offline.events[:, 0])) /
               raw.info['sfreq']) + 1.0

    kill_signal = _start_buffer_thread(free_tcp_port)
    try:
        data_rt = None
        events_ids_rt = None
        with pytest.warns(RuntimeWarning, match='Trying to guess it'):
            with FieldTripClient(host='localhost', port=free_tcp_port,
                                 tmax=raw_tmax, wait_max=2) as rt_client:
                # get measurement info guessed by MNE-Python
                raw_info = rt_client.get_measurement_info()
                assert ([ch['ch_name'] for ch in raw_info['chs']] ==
                        [ch['ch_name'] for ch in raw.info['chs']])

                # create the real-time epochs object
                epochs_rt = RtEpochs(rt_client, event_id, tmin, tmax,
                                     stim_channel='STI 014',
                                     isi_max=isi_max)
                epochs_rt.start()
                time.sleep(0.5)
                for ev_num, ev in enumerate(epochs_rt.iter_evoked()):
                    if ev_num == 0:
                        data_rt = ev.data[None, :, :]
                        # comment attribute contains event_id
                        events_ids_rt = int(ev.comment)
                    else:
                        data_rt = np.concatenate(
                            (data_rt, ev.data[None, :, :]), axis=0)
                        events_ids_rt = np.append(events_ids_rt,
                                                  int(ev.comment))
                _call_base_epochs_public_api(epochs_rt, tmpdir)
                epochs_rt.stop(stop_receive_thread=True)

        assert_array_equal(events_ids_rt, epochs_rt.events[:, 2])
        assert_array_equal(data_rt, epochs_rt.get_data())
        assert len(epochs_rt) == len(epochs_offline)
        assert_array_equal(events_ids_rt, epochs_offline.events[:, 2])
        assert_allclose(epochs_rt.get_data(), epochs_offline.get_data(),
                        rtol=1.e-5, atol=1.e-8)  # defaults of np.isclose
    finally:
        kill_signal.put(False)  # stop the buffer
def _define_epochs(fif_file, t_min, t_max, events_id, events_file='',
                   decim=1):
    """Split a raw .fif file into epochs based on an events file.

    Epochs span t_min to t_max and are dropped according to rejection
    criteria derived from the channel types.
    """
    raw = read_raw_fif(fif_file)
    reject = _create_reject_dict(raw.info)
    picks = pick_types(raw.info, meg=True, ref_meg=False, eog=True,
                       stim=True, exclude='bads')

    data_path, base, ext = split_filename(fif_file)
    if events_file:
        events_fpath = glob.glob(op.join(data_path, events_file))
        events = read_events(events_fpath[0])
    else:
        events = find_events(raw)

    # TODO -> use autoreject ?
    # reject_tmax = 0.8  # duration we really care about
    epochs = Epochs(raw, events, events_id, t_min, t_max, proj=True,
                    picks=picks, baseline=(None, 0), reject=reject,
                    decim=decim, preload=True)
    epochs.drop_bad(reject=reject)

    good_events_file = os.path.join(data_path, 'good_events.txt')
    np.savetxt(good_events_file, epochs.events)

    # TODO -> decide where to save...
    savename = os.path.abspath(base + '-epo' + ext)
    # savename = os.path.join(data_path, base + '-epo' + ext)
    epochs.save(savename, overwrite=True)
    return savename
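# Hedged usage sketch for _define_epochs. The private helpers it calls
# (_create_reject_dict, split_filename) come from the surrounding package;
# the file name and event id below are placeholders.
epo_fname = _define_epochs('sub01_task_raw.fif', t_min=-0.2, t_max=0.5,
                           events_id={'stim': 1}, decim=2)
print('epochs saved to', epo_fname)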
def test_verbose_method(verbose):
    """Test for gh-8772."""
    # raw
    raw = read_raw_fif(fname_raw, verbose=verbose)
    with catch_logging() as log:
        raw.load_data(verbose=True)
    log = log.getvalue()
    assert 'Reading 0 ... 14399' in log
    with catch_logging() as log:
        raw.load_data(verbose=False)
    log = log.getvalue()
    assert log == ''
    # epochs
    events = np.array([[raw.first_samp + 200, 0, 1]], int)
    epochs = Epochs(raw, events, verbose=verbose)
    with catch_logging() as log:
        epochs.drop_bad(verbose=True)
    log = log.getvalue()
    assert '0 bad epochs dropped' in log
    epochs = Epochs(raw, events, verbose=verbose)
    with catch_logging() as log:
        epochs.drop_bad(verbose=False)
    log = log.getvalue()
    assert log == ''
def test_source_psd_epochs():
    """Test multi-taper source PSD computation in label from epochs."""
    raw = read_raw_fif(fname_data)
    inverse_operator = read_inverse_operator(fname_inv)
    label = read_label(fname_label)

    event_id, tmin, tmax = 1, -0.2, 0.5
    lambda2, method = 1. / 9., 'dSPM'
    bandwidth = 8.
    fmin, fmax = 0, 100

    picks = pick_types(raw.info, meg=True, eeg=False, stim=True,
                       ecg=True, eog=True, include=['STI 014'],
                       exclude='bads')
    reject = dict(grad=4000e-13, mag=4e-12, eog=150e-6)

    events = find_events(raw, stim_channel='STI 014')
    epochs = Epochs(raw, events, event_id, tmin, tmax, picks=picks,
                    baseline=(None, 0), reject=reject)

    # only look at one epoch
    epochs.drop_bad()
    one_epochs = epochs[:1]

    inv = prepare_inverse_operator(inverse_operator, nave=1,
                                   lambda2=1. / 9., method="dSPM")
    # return list
    stc_psd = compute_source_psd_epochs(one_epochs, inv, lambda2=lambda2,
                                        method=method, pick_ori="normal",
                                        label=label, bandwidth=bandwidth,
                                        fmin=fmin, fmax=fmax,
                                        prepared=True)[0]

    # return generator
    stcs = compute_source_psd_epochs(one_epochs, inv, lambda2=lambda2,
                                     method=method, pick_ori="normal",
                                     label=label, bandwidth=bandwidth,
                                     fmin=fmin, fmax=fmax,
                                     return_generator=True, prepared=True)

    for stc in stcs:
        stc_psd_gen = stc

    assert_array_almost_equal(stc_psd.data, stc_psd_gen.data)

    # compare with direct computation
    stc = apply_inverse_epochs(one_epochs, inv, lambda2=lambda2,
                               method=method, pick_ori="normal",
                               label=label, prepared=True)[0]

    sfreq = epochs.info['sfreq']
    psd, freqs = psd_array_multitaper(stc.data, sfreq=sfreq,
                                      bandwidth=bandwidth,
                                      fmin=fmin, fmax=fmax)

    assert_array_almost_equal(psd, stc_psd.data)
    assert_array_almost_equal(freqs, stc_psd.times)

    # Check corner cases caused by tiny bandwidth
    with warnings.catch_warnings(record=True) as w:
        warnings.simplefilter('always')
        compute_source_psd_epochs(one_epochs, inv, lambda2=lambda2,
                                  method=method, pick_ori="normal",
                                  label=label, bandwidth=0.01,
                                  low_bias=True, fmin=fmin, fmax=fmax,
                                  return_generator=False, prepared=True)
        compute_source_psd_epochs(one_epochs, inv, lambda2=lambda2,
                                  method=method, pick_ori="normal",
                                  label=label, bandwidth=0.01,
                                  low_bias=False, fmin=fmin, fmax=fmax,
                                  return_generator=False, prepared=True)
        assert_true(len(w) >= 2)
        assert_true(any('not properly use' in str(ww.message) for ww in w))
        assert_true(any('Bandwidth too small' in str(ww.message)
                        for ww in w))
# Infer window size based on the frequency being used
w_size = n_cycles / ((fmax + fmin) / 2.)  # in seconds

# Apply band-pass filter to isolate the specified frequencies
raw_filter = raw.copy().filter(fmin, fmax, n_jobs=1, fir_design='firwin')

# Extract epochs from filtered data, padded by window size
epochs = Epochs(raw_filter, events, event_id, tmin - w_size, tmax + w_size,
                proj=False, baseline=None, preload=True)
epochs.drop_bad()
y = epochs.events[:, 2] - 2

# Roll covariance, csp and lda over time
for t, w_time in enumerate(centered_w_times):

    # Center the min and max of the window
    w_tmin = w_time - w_size / 2.
    w_tmax = w_time + w_size / 2.

    # Crop data into time-window of interest
    X = epochs.copy().crop(w_tmin, w_tmax).get_data()

    # Save mean scores over folds for each frequency and time window
    scores[freq, t] = np.mean(cross_val_score(estimator=clf, X=X, y=y,
                                              scoring='roc_auc', cv=cv,
                                              n_jobs=1), axis=0)
def test_source_psd_epochs():
    """Test multi-taper source PSD computation in label from epochs."""
    raw = read_raw_fif(fname_data)
    inverse_operator = read_inverse_operator(fname_inv)
    label = read_label(fname_label)

    event_id, tmin, tmax = 1, -0.2, 0.5
    lambda2, method = 1. / 9., 'dSPM'
    bandwidth = 8.
    fmin, fmax = 0, 100

    picks = pick_types(raw.info, meg=True, eeg=False, stim=True,
                       ecg=True, eog=True, include=['STI 014'],
                       exclude='bads')
    reject = dict(grad=4000e-13, mag=4e-12, eog=150e-6)

    events = find_events(raw, stim_channel='STI 014')
    epochs = Epochs(raw, events, event_id, tmin, tmax, picks=picks,
                    baseline=(None, 0), reject=reject)

    # only look at one epoch
    epochs.drop_bad()
    one_epochs = epochs[:1]

    inv = prepare_inverse_operator(inverse_operator, nave=1,
                                   lambda2=1. / 9., method="dSPM")
    # return list
    stc_psd = compute_source_psd_epochs(one_epochs, inv, lambda2=lambda2,
                                        method=method, pick_ori="normal",
                                        label=label, bandwidth=bandwidth,
                                        fmin=fmin, fmax=fmax,
                                        prepared=True)[0]

    # return generator
    stcs = compute_source_psd_epochs(one_epochs, inv, lambda2=lambda2,
                                     method=method, pick_ori="normal",
                                     label=label, bandwidth=bandwidth,
                                     fmin=fmin, fmax=fmax,
                                     return_generator=True, prepared=True)

    for stc in stcs:
        stc_psd_gen = stc

    assert_array_almost_equal(stc_psd.data, stc_psd_gen.data)

    # compare with direct computation
    stc = apply_inverse_epochs(one_epochs, inv, lambda2=lambda2,
                               method=method, pick_ori="normal",
                               label=label, prepared=True)[0]

    sfreq = epochs.info['sfreq']
    psd, freqs = _psd_multitaper(stc.data, sfreq=sfreq,
                                 bandwidth=bandwidth,
                                 fmin=fmin, fmax=fmax)

    assert_array_almost_equal(psd, stc_psd.data)
    assert_array_almost_equal(freqs, stc_psd.times)

    # Check corner cases caused by tiny bandwidth
    with warnings.catch_warnings(record=True) as w:
        warnings.simplefilter('always')
        compute_source_psd_epochs(one_epochs, inv, lambda2=lambda2,
                                  method=method, pick_ori="normal",
                                  label=label, bandwidth=0.01,
                                  low_bias=True, fmin=fmin, fmax=fmax,
                                  return_generator=False, prepared=True)
        compute_source_psd_epochs(one_epochs, inv, lambda2=lambda2,
                                  method=method, pick_ori="normal",
                                  label=label, bandwidth=0.01,
                                  low_bias=False, fmin=fmin, fmax=fmax,
                                  return_generator=False, prepared=True)
        assert_true(len(w) >= 2)
        assert_true(any('not properly use' in str(ww.message) for ww in w))
        assert_true(any('Bandwidth too small' in str(ww.message)
                        for ww in w))
def do_preprocessing_combined(p, subjects, run_indices):
    """Do preprocessing on all raw files together.

    Calculates projection vectors to use to clean data.

    Parameters
    ----------
    p : instance of Parameters
        Analysis parameters.
    subjects : list of str
        Subject names to analyze (e.g., ['Eric_SoP_001', ...]).
    run_indices : array-like | None
        Run indices to include.
    """
    drop_logs = list()
    for si, subj in enumerate(subjects):
        proj_nums = _proj_nums(p, subj)
        ecg_channel = _handle_dict(p.ecg_channel, subj)
        flat = _handle_dict(p.flat, subj)
        if p.disp_files:
            print('  Preprocessing subject %g/%g (%s).'
                  % (si + 1, len(subjects), subj))
        pca_dir = _get_pca_dir(p, subj)
        bad_file = get_bad_fname(p, subj, check_exists=False)

        # Create SSP projection vectors after marking bad channels
        raw_names = get_raw_fnames(p, subj, 'sss', False, False,
                                   run_indices[si])
        empty_names = get_raw_fnames(p, subj, 'sss', 'only')
        for r in raw_names + empty_names:
            if not op.isfile(r):
                raise NameError('File not found (' + r + ')')

        fir_kwargs, old_kwargs = _get_fir_kwargs(p.fir_design)
        if isinstance(p.auto_bad, float):
            print('    Creating post SSS bad channel file:\n'
                  '        %s' % bad_file)
            # do autobad
            raw = _raw_LRFCP(raw_names, p.proj_sfreq, None, None,
                             p.n_jobs_fir, p.n_jobs_resample, list(), None,
                             p.disp_files, method='fir',
                             filter_length=p.filter_length,
                             apply_proj=False, force_bads=False,
                             l_trans=p.hp_trans, h_trans=p.lp_trans,
                             phase=p.phase, fir_window=p.fir_window,
                             pick=True, skip_by_annotation='edge',
                             **fir_kwargs)
            events = fixed_len_events(p, raw)
            rtmin = p.reject_tmin \
                if p.reject_tmin is not None else p.tmin
            rtmax = p.reject_tmax \
                if p.reject_tmax is not None else p.tmax
            # do not mark eog channels bad
            meg, eeg = 'meg' in raw, 'eeg' in raw
            picks = pick_types(raw.info, meg=meg, eeg=eeg, eog=False,
                               exclude=[])
            assert p.auto_bad_flat is None or \
                isinstance(p.auto_bad_flat, dict)
            assert p.auto_bad_reject is None or \
                isinstance(p.auto_bad_reject, dict) or \
                p.auto_bad_reject == 'auto'
            if p.auto_bad_reject == 'auto':
                print('    Auto bad channel selection active. '
                      'Will try using Autoreject module to '
                      'compute rejection criterion.')
                try:
                    from autoreject import get_rejection_threshold
                except ImportError:
                    raise ImportError('    Autoreject module not installed.\n'
                                      '    Noisy channel detection parameter '
                                      'not defined. To use autobad '
                                      'channel selection either define '
                                      'rejection criteria or install '
                                      'Autoreject module.\n')
                print('    Computing thresholds.\n', end='')
                temp_epochs = Epochs(
                    raw, events, event_id=None, tmin=rtmin, tmax=rtmax,
                    baseline=_get_baseline(p), proj=True, reject=None,
                    flat=None, preload=True, decim=1)
                kwargs = dict()
                if 'verbose' in get_args(get_rejection_threshold):
                    kwargs['verbose'] = False
                reject = get_rejection_threshold(temp_epochs, **kwargs)
                reject = {kk: vv for kk, vv in reject.items()}
            elif p.auto_bad_reject is None and p.auto_bad_flat is None:
                raise RuntimeError('Auto bad channel detection active. Noisy '
                                   'and flat channel detection '
                                   'parameters not defined. '
                                   'At least one criterion must be defined.')
            else:
                reject = p.auto_bad_reject
            if 'eog' in reject.keys():
                reject.pop('eog', None)
            epochs = Epochs(raw, events, None, tmin=rtmin, tmax=rtmax,
                            baseline=_get_baseline(p), picks=picks,
                            reject=reject, flat=p.auto_bad_flat,
                            proj=True, preload=True, decim=1,
                            reject_tmin=rtmin, reject_tmax=rtmax)
            # channel scores from drop log
            drops = Counter([ch for d in epochs.drop_log for ch in d])
            # get rid of non-channel reasons in drop log
            scores = {kk: vv for kk, vv in drops.items()
                      if kk in epochs.ch_names}
            ch_names = np.array(list(scores.keys()))
            # channel scores expressed as percentile and rank ordered
            counts = (100 * np.array([scores[ch] for ch in ch_names],
                                     float) / len(epochs.drop_log))
            order = np.argsort(counts)[::-1]
            # boolean array masking out channels with <= % epochs dropped
            mask = counts[order] > p.auto_bad
            badchs = ch_names[order[mask]]
            if len(badchs) > 0:
                # Make sure we didn't get too many bad MEG or EEG channels
                for m, e, thresh in zip([True, False], [False, True],
                                        [p.auto_bad_meg_thresh,
                                         p.auto_bad_eeg_thresh]):
                    picks = pick_types(epochs.info, meg=m, eeg=e,
                                       exclude=[])
                    if len(picks) > 0:
                        ch_names = [epochs.ch_names[pp] for pp in picks]
                        n_bad_type = sum(ch in ch_names for ch in badchs)
                        if n_bad_type > thresh:
                            stype = 'meg' if m else 'eeg'
                            raise RuntimeError('Too many bad %s channels '
                                               'found: %s > %s'
                                               % (stype, n_bad_type, thresh))
                print('    The following channels resulted in greater than '
                      '{:.0f}% trials dropped:\n'.format(p.auto_bad * 100))
                print(badchs)
                with open(bad_file, 'w') as f:
                    f.write('\n'.join(badchs))
        if not op.isfile(bad_file):
            print('    Clearing bad channels (no file %s)'
                  % op.sep.join(bad_file.split(op.sep)[-3:]))
            bad_file = None

        ecg_t_lims = _handle_dict(p.ecg_t_lims, subj)
        ecg_f_lims = p.ecg_f_lims

        ecg_eve = op.join(pca_dir, 'preproc_ecg-eve.fif')
        ecg_epo = op.join(pca_dir, 'preproc_ecg-epo.fif')
        ecg_proj = op.join(pca_dir, 'preproc_ecg-proj.fif')
        all_proj = op.join(pca_dir, 'preproc_all-proj.fif')

        get_projs_from = _handle_dict(p.get_projs_from, subj)
        if get_projs_from is None:
            get_projs_from = np.arange(len(raw_names))
        pre_list = [r for ri, r in enumerate(raw_names)
                    if ri in get_projs_from]

        projs = list()
        raw_orig = _raw_LRFCP(
            raw_names=pre_list, sfreq=p.proj_sfreq, l_freq=None,
            h_freq=None, n_jobs=p.n_jobs_fir,
            n_jobs_resample=p.n_jobs_resample, projs=projs,
            bad_file=bad_file, disp_files=p.disp_files, method='fir',
            filter_length=p.filter_length, force_bads=False,
            l_trans=p.hp_trans, h_trans=p.lp_trans, phase=p.phase,
            fir_window=p.fir_window, pick=True,
            skip_by_annotation='edge', **fir_kwargs)

        # Apply any user-supplied extra projectors
        if p.proj_extra is not None:
            if p.disp_files:
                print('    Adding extra projectors from "%s".'
                      % p.proj_extra)
            projs.extend(read_proj(op.join(pca_dir, p.proj_extra)))

        proj_kwargs, p_sl = _get_proj_kwargs(p)
        #
        # Calculate and apply ERM projectors
        #
        if not p.cont_as_esss:
            if any(proj_nums[2]):
                assert proj_nums[2][2] == 0  # no EEG projectors for ERM
                if len(empty_names) == 0:
                    raise RuntimeError('Cannot compute empty-room '
                                       'projectors from continuous raw '
                                       'data')
                if p.disp_files:
                    print('    Computing continuous projectors using ERM.')
                # Use empty room(s), but processed the same way
                projs.extend(_compute_erm_proj(p, subj, projs, 'sss',
                                               bad_file))
            else:
                cont_proj = op.join(pca_dir, 'preproc_cont-proj.fif')
                _safe_remove(cont_proj)

        #
        # Calculate and apply the ECG projectors
        #
        if any(proj_nums[0]):
            if p.disp_files:
                print('    Computing ECG projectors...', end='')
            raw = raw_orig.copy()

            raw.filter(ecg_f_lims[0], ecg_f_lims[1], n_jobs=p.n_jobs_fir,
                       method='fir', filter_length=p.filter_length,
                       l_trans_bandwidth=0.5, h_trans_bandwidth=0.5,
                       phase='zero-double', fir_window='hann',
                       skip_by_annotation='edge', **old_kwargs)
            raw.add_proj(projs)
            raw.apply_proj()
            find_kwargs = dict()
            if 'reject_by_annotation' in get_args(find_ecg_events):
                find_kwargs['reject_by_annotation'] = True
            elif len(raw.annotations) > 0:
                print('    WARNING: ECG event detection will not make use '
                      'of annotations, please update MNE-Python')
            # We've already filtered the data channels above, but this
            # filters the ECG channel
            ecg_events = find_ecg_events(
                raw, 999, ecg_channel, 0., ecg_f_lims[0], ecg_f_lims[1],
                qrs_threshold='auto', return_ecg=False, **find_kwargs)[0]
            use_reject, use_flat = _restrict_reject_flat(
                _handle_dict(p.ssp_ecg_reject, subj), flat, raw)
            ecg_epochs = Epochs(raw, ecg_events, 999, ecg_t_lims[0],
                                ecg_t_lims[1], baseline=None,
                                reject=use_reject, flat=use_flat,
                                preload=True)
            print('  obtained %d epochs from %d events.'
                  % (len(ecg_epochs), len(ecg_events)))
            if len(ecg_epochs) >= 20:
                write_events(ecg_eve, ecg_epochs.events)
                ecg_epochs.save(ecg_epo, **_get_epo_kwargs())
                desc_prefix = 'ECG-%s-%s' % tuple(ecg_t_lims)
                pr = compute_proj_wrap(
                    ecg_epochs, p.proj_ave, n_grad=proj_nums[0][0],
                    n_mag=proj_nums[0][1], n_eeg=proj_nums[0][2],
                    desc_prefix=desc_prefix, **proj_kwargs)
                assert len(pr) == np.sum(proj_nums[0][::p_sl])
                write_proj(ecg_proj, pr)
                projs.extend(pr)
            else:
                plot_drop_log(ecg_epochs.drop_log)
                raw.plot(events=ecg_epochs.events)
                raise RuntimeError('Only %d/%d good ECG epochs found'
                                   % (len(ecg_epochs), len(ecg_events)))
            del raw, ecg_epochs, ecg_events
        else:
            _safe_remove([ecg_proj, ecg_eve, ecg_epo])

        #
        # Next calculate and apply the EOG projectors
        #
        for idx, kind in ((1, 'EOG'), (3, 'HEOG'), (4, 'VEOG')):
            _compute_add_eog(p, subj, raw_orig, projs, proj_nums[idx],
                             kind, pca_dir, flat, proj_kwargs, old_kwargs,
                             p_sl)
        del proj_nums

        # save the projectors
        write_proj(all_proj, projs)

        #
        # Look at raw_orig for trial DQs now, it will be quick
        #
        raw_orig.filter(p.hp_cut, p.lp_cut, n_jobs=p.n_jobs_fir,
                        method='fir', filter_length=p.filter_length,
                        l_trans_bandwidth=p.hp_trans, phase=p.phase,
                        h_trans_bandwidth=p.lp_trans,
                        fir_window=p.fir_window,
                        skip_by_annotation='edge', **fir_kwargs)
        raw_orig.add_proj(projs)
        raw_orig.apply_proj()
        # now let's epoch with 1-sec windows to look for DQs
        events = fixed_len_events(p, raw_orig)
        reject = _handle_dict(p.reject, subj)
        use_reject, use_flat = _restrict_reject_flat(reject, flat,
                                                     raw_orig)
        epochs = Epochs(raw_orig, events, None, p.tmin, p.tmax,
                        preload=False, baseline=_get_baseline(p),
                        reject=use_reject, flat=use_flat, proj=True)
        try:
            epochs.drop_bad()
        except AttributeError:  # old way
            epochs.drop_bad_epochs()
        drop_logs.append(epochs.drop_log)
        del raw_orig
        del epochs
    if p.plot_drop_logs:
        for subj, drop_log in zip(subjects, drop_logs):
            plot_drop_log(drop_log, p.drop_thresh, subject=subj)
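# Hedged usage sketch for do_preprocessing_combined. The function follows
# mnefun pipeline conventions, so it needs a fully populated Parameters
# object; the import, attribute values, and subject names below are
# placeholders, not a working configuration.
import mnefun

params = mnefun.Params(tmin=-0.2, tmax=0.5)  # plus many more attributes
subjects = ['subj_01', 'subj_02']
run_indices = [None] * len(subjects)  # None -> use all runs for a subject
do_preprocessing_combined(params, subjects, run_indices)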
def test_csd_degenerate(evoked_csd_sphere):
    """Test degenerate conditions."""
    evoked, csd, sphere = evoked_csd_sphere
    warn_evoked = evoked.copy()
    warn_evoked.info['bads'].append(warn_evoked.ch_names[3])
    with pytest.raises(ValueError, match='Either drop.*or interpolate'):
        compute_current_source_density(warn_evoked)

    with pytest.raises(TypeError, match='must be an instance of'):
        compute_current_source_density(None)

    fail_evoked = evoked.copy()
    with pytest.raises(ValueError, match='Zero or infinite position'):
        for ch in fail_evoked.info['chs']:
            ch['loc'][:3] = np.array([0, 0, 0])
        compute_current_source_density(fail_evoked, sphere=sphere)

    with pytest.raises(ValueError, match='Zero or infinite position'):
        fail_evoked.info['chs'][3]['loc'][:3] = np.inf
        compute_current_source_density(fail_evoked, sphere=sphere)

    with pytest.raises(ValueError, match='No EEG channels found.'):
        fail_evoked = evoked.copy()
        fail_evoked.set_channel_types({ch_name: 'ecog' for ch_name in
                                       fail_evoked.ch_names})
        compute_current_source_density(fail_evoked, sphere=sphere)

    with pytest.raises(TypeError, match='lambda2'):
        compute_current_source_density(evoked, lambda2='0', sphere=sphere)

    with pytest.raises(ValueError, match='lambda2 must be between 0 and 1'):
        compute_current_source_density(evoked, lambda2=2, sphere=sphere)

    with pytest.raises(TypeError, match='stiffness must be'):
        compute_current_source_density(evoked, stiffness='0',
                                       sphere=sphere)

    with pytest.raises(ValueError, match='stiffness must be non-negative'):
        compute_current_source_density(evoked, stiffness=-2, sphere=sphere)

    with pytest.raises(TypeError, match='n_legendre_terms must be'):
        compute_current_source_density(evoked, n_legendre_terms=0.1,
                                       sphere=sphere)

    with pytest.raises(ValueError, match=('n_legendre_terms must be '
                                          'greater than 0')):
        compute_current_source_density(evoked, n_legendre_terms=0,
                                       sphere=sphere)

    with pytest.raises(ValueError, match='sphere must be'):
        compute_current_source_density(evoked, sphere=-0.1)

    with pytest.raises(ValueError, match=('sphere radius must be '
                                          'greater than 0')):
        compute_current_source_density(evoked, sphere=(-0.1, 0., 0., -1.))

    with pytest.raises(TypeError):
        compute_current_source_density(evoked, copy=2, sphere=sphere)

    # gh-7859
    raw = RawArray(evoked.data, evoked.info)
    epochs = Epochs(
        raw, [[0, 0, 1]], tmin=0, tmax=evoked.times[-1] - evoked.times[0],
        baseline=None, preload=False, proj=False)
    epochs.drop_bad()
    assert len(epochs) == 1
    assert_allclose(epochs.get_data()[0], evoked.data)
    with pytest.raises(RuntimeError,
                       match='Computing CSD requires.*preload'):
        compute_current_source_density(epochs)
    epochs.load_data()
    raw = compute_current_source_density(raw)
    assert not np.allclose(raw.get_data(), evoked.data)
    evoked = compute_current_source_density(evoked)
    assert_allclose(raw.get_data(), evoked.data)
    epochs = compute_current_source_density(epochs)
    assert_allclose(epochs.get_data()[0], evoked.data)
def test_source_psd_epochs(method):
    """Test multi-taper source PSD computation in label from epochs."""
    raw = read_raw_fif(fname_data)
    inverse_operator = read_inverse_operator(fname_inv)
    label = read_label(fname_label)

    event_id, tmin, tmax = 1, -0.2, 0.5
    lambda2 = 1. / 9.
    bandwidth = 8.
    fmin, fmax = 0, 100

    picks = pick_types(raw.info, meg=True, eeg=False, stim=True,
                       ecg=True, eog=True, include=['STI 014'],
                       exclude='bads')
    reject = dict(grad=4000e-13, mag=4e-12, eog=150e-6)

    events = find_events(raw, stim_channel='STI 014')
    epochs = Epochs(raw, events, event_id, tmin, tmax, picks=picks,
                    baseline=(None, 0), reject=reject)

    # only look at one epoch
    epochs.drop_bad()
    one_epochs = epochs[:1]

    inv = prepare_inverse_operator(inverse_operator, nave=1,
                                   lambda2=1. / 9., method="dSPM")
    # return list
    stc_psd = compute_source_psd_epochs(one_epochs, inv, lambda2=lambda2,
                                        method=method, pick_ori="normal",
                                        label=label, bandwidth=bandwidth,
                                        fmin=fmin, fmax=fmax,
                                        prepared=True)[0]

    # return generator
    stcs = compute_source_psd_epochs(one_epochs, inv, lambda2=lambda2,
                                     method=method, pick_ori="normal",
                                     label=label, bandwidth=bandwidth,
                                     fmin=fmin, fmax=fmax,
                                     return_generator=True, prepared=True)

    for stc in stcs:
        stc_psd_gen = stc

    assert_allclose(stc_psd.data, stc_psd_gen.data, atol=1e-7)

    # compare with direct computation
    stc = apply_inverse_epochs(one_epochs, inv, lambda2=lambda2,
                               method=method, pick_ori="normal",
                               label=label, prepared=True)[0]

    sfreq = epochs.info['sfreq']
    psd, freqs = psd_array_multitaper(stc.data, sfreq=sfreq,
                                      bandwidth=bandwidth,
                                      fmin=fmin, fmax=fmax)

    assert_allclose(psd, stc_psd.data, atol=1e-7)
    assert_allclose(freqs, stc_psd.times)

    # Check corner cases caused by tiny bandwidth
    with pytest.raises(ValueError, match='use a value of at least'):
        compute_source_psd_epochs(
            one_epochs, inv, lambda2=lambda2, method=method,
            pick_ori="normal", label=label, bandwidth=0.01, low_bias=True,
            fmin=fmin, fmax=fmax, return_generator=False, prepared=True)
def extract(filepaths, ica=None, fit_ica=True):
    if ica is None:
        ica = get_ica_transformers("infomax")

    epochs, labels = list(), list()
    for filepath in filepaths:
        print("loading", filepath)
        try:
            gdf = read_raw_gdf(
                filepath,
                eog=["EOG-left", "EOG-central", "EOG-right"],
                exclude=["EOG-left", "EOG-central", "EOG-right"])
            events = events_from_annotations(
                gdf, event_id={"769": 0, "770": 1, "771": 2, "772": 3})
            epoch = Epochs(gdf, events[0], event_repeated="drop",
                           reject_by_annotation=True, tmin=-.3, tmax=.7,
                           reject=dict(eeg=1e-4))
            epoch.drop_bad()
        except ValueError:
            print("Error in", filepath)
            continue

        epochs.append(epoch)
        labels.append(epoch.events[:, 2])

    labels = np.concatenate(labels)
    n_epochs, n_channels, n_times = epochs[0].get_data().shape
    ica_vec = [epoch.get_data().transpose(1, 0, 2).reshape(n_channels, -1).T
               for epoch in epochs]
    ica_vec = np.concatenate(ica_vec, axis=0)

    if fit_ica:
        ica.fit(ica_vec)

    transformed = ica.transform(ica_vec)
    # NOTE: the next line overwrites the ICA output, so the features below
    # are computed on the untransformed vectors
    transformed = ica_vec
    transformed = transformed.T.reshape(n_channels, -1,
                                        n_times).transpose(1, 0, 2)

    features, freqs = psd_array_multitaper(transformed, 250., fmin=0,
                                           fmax=20, bandwidth=2)
    n_epochs, _, _ = features.shape
    # features = features.reshape(n_epochs, -1)
    features = features.mean(axis=2)

    # one-hot encode the labels
    labels_placeholder = np.zeros((len(labels), 4))
    for i, l in enumerate(labels):
        labels_placeholder[i, l] = 1
    labels = labels_placeholder

    return features, labels, ica
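# Hedged usage sketch for extract(): assumes BCI Competition IV 2a-style GDF
# recordings on disk and the get_ica_transformers helper used above; the
# glob pattern is a placeholder.
from glob import glob

train_files = sorted(glob('data/A0?T.gdf'))  # hypothetical location
features, labels, ica = extract(train_files, fit_ica=True)
print(features.shape, labels.shape)  # (n_epochs, n_channels), (n_epochs, 4)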
def test_source_psd_epochs():
    """Test multi-taper source PSD computation in label from epochs."""
    raw = read_raw_fif(fname_data)
    inverse_operator = read_inverse_operator(fname_inv)
    label = read_label(fname_label)

    event_id, tmin, tmax = 1, -0.2, 0.5
    lambda2, method = 1. / 9., 'dSPM'
    bandwidth = 8.
    fmin, fmax = 0, 100

    picks = pick_types(raw.info, meg=True, eeg=False, stim=True,
                       ecg=True, eog=True, include=['STI 014'],
                       exclude='bads')
    reject = dict(grad=4000e-13, mag=4e-12, eog=150e-6)

    events = find_events(raw, stim_channel='STI 014')
    epochs = Epochs(raw, events, event_id, tmin, tmax, picks=picks,
                    baseline=(None, 0), reject=reject)

    # only look at one epoch
    epochs.drop_bad()
    one_epochs = epochs[:1]

    inv = prepare_inverse_operator(inverse_operator, nave=1,
                                   lambda2=1. / 9., method="dSPM")
    # return list
    stc_psd = compute_source_psd_epochs(one_epochs, inv, lambda2=lambda2,
                                        method=method, pick_ori="normal",
                                        label=label, bandwidth=bandwidth,
                                        fmin=fmin, fmax=fmax,
                                        prepared=True)[0]

    # return generator
    stcs = compute_source_psd_epochs(one_epochs, inv, lambda2=lambda2,
                                     method=method, pick_ori="normal",
                                     label=label, bandwidth=bandwidth,
                                     fmin=fmin, fmax=fmax,
                                     return_generator=True, prepared=True)

    for stc in stcs:
        stc_psd_gen = stc

    assert_allclose(stc_psd.data, stc_psd_gen.data, atol=1e-7)

    # compare with direct computation
    stc = apply_inverse_epochs(one_epochs, inv, lambda2=lambda2,
                               method=method, pick_ori="normal",
                               label=label, prepared=True)[0]

    sfreq = epochs.info['sfreq']
    psd, freqs = psd_array_multitaper(stc.data, sfreq=sfreq,
                                      bandwidth=bandwidth,
                                      fmin=fmin, fmax=fmax)

    assert_allclose(psd, stc_psd.data, atol=1e-7)
    assert_allclose(freqs, stc_psd.times)

    # Check corner cases caused by tiny bandwidth
    with pytest.raises(ValueError, match='use a value of at least'):
        compute_source_psd_epochs(
            one_epochs, inv, lambda2=lambda2, method=method,
            pick_ori="normal", label=label, bandwidth=0.01, low_bias=True,
            fmin=fmin, fmax=fmax, return_generator=False, prepared=True)
def CSP_dec(raw, events, event_id, subject):
    """Decode conditions with the common spatial pattern (CSP) method.

    CSP is used for frequency filtering; linear discriminant analysis then
    labels the data into conditions based on the frequency components
    created by CSP.
    """
    event_id = dict(ISI_time_correct=event_id['ISI_time_correct'],
                    ISI_time_control=event_id['ISI_time_control'])

    # Extract information from the raw file
    sfreq = raw.info['sfreq']
    raw.pick_types(meg=False, eeg=True, stim=False, eog=False)

    # Assemble the classifier using scikit-learn pipeline
    clf = make_pipeline(CSP(n_components=4, reg=None, log=True),
                        LinearDiscriminantAnalysis())
    n_splits = 5  # how many folds to use for cross-validation
    cv = StratifiedKFold(n_splits=n_splits, shuffle=True)

    # Classification & time-frequency parameters
    tmin, tmax = 1.0, 4.0
    n_cycles = 10.  # how many complete cycles: used to define window size
    min_freq = 2.
    max_freq = 50.
    n_freqs = 24  # how many frequency bins to use

    # Assemble list of frequency range tuples
    freqs = np.linspace(min_freq, max_freq, n_freqs)  # assemble frequencies
    freq_ranges = list(zip(freqs[:-1], freqs[1:]))  # freqs as list of tuples

    # Infer window spacing from the max freq and number of cycles to avoid
    # gaps
    window_spacing = (n_cycles / np.max(freqs) / 2.)
    centered_w_times = np.arange(tmin, tmax, window_spacing)[1:]
    n_windows = len(centered_w_times)

    # Instantiate label encoder
    le = LabelEncoder()

    # init scores
    freq_scores = np.zeros((n_freqs - 1,))

    # Loop through each frequency range of interest
    for freq, (fmin, fmax) in enumerate(freq_ranges):

        # Infer window size based on the frequency being used
        w_size = n_cycles / ((fmax + fmin) / 2.)  # in seconds

        # Apply band-pass filter to isolate the specified frequencies
        raw_filter = raw.copy().filter(fmin, fmax, n_jobs=1,
                                       fir_design='firwin')

        # Extract epochs from filtered data, padded by window size
        epochs = Epochs(raw_filter, events, event_id, tmin - w_size,
                        tmax + w_size, proj=False, baseline=None,
                        preload=True)
        e1 = epochs['ISI_time_correct']
        e2 = epochs['ISI_time_control']
        mne.epochs.equalize_epoch_counts([e1, e2])
        epochs = mne.epochs.concatenate_epochs([e1, e2])
        epochs.drop_bad()
        y = le.fit_transform(epochs.events[:, 2])

        X = epochs.get_data()

        # Save mean scores over folds for each frequency and time window
        freq_scores[freq] = np.mean(cross_val_score(estimator=clf, X=X,
                                                    y=y, scoring='roc_auc',
                                                    cv=cv, n_jobs=8),
                                    axis=0)

    fig, axes = plt.subplots(figsize=(20, 10))
    axes.bar(freqs[:-1], freq_scores, width=np.diff(freqs)[0],
             align='edge', edgecolor='black')
    axes.set_xticks([int(round(f)) for f in freqs])
    axes.set_ylim([0, 1])
    axes.axhline(len(epochs['ISI_time_correct']) / len(epochs), color='k',
                 linestyle='--', label='chance level')
    axes.legend()
    axes.set_xlabel('Frequency (Hz)')
    axes.set_ylabel('Decoding Scores')
    fig.suptitle('Frequency Decoding Scores ' + subject)
    fig.savefig('CSP_decoding/' + subject)
    return freq_scores, freqs
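# Hedged usage sketch for CSP_dec: assumes `raw`, `events`, and `event_id`
# were produced earlier with MNE (e.g. find_events or
# events_from_annotations) and that a 'CSP_decoding/' directory exists for
# the saved figure.
import numpy as np

freq_scores, freqs = CSP_dec(raw, events, event_id, subject='sub-01')
best = np.argmax(freq_scores)
print('best band: %.1f-%.1f Hz (AUC %.2f)'
      % (freqs[best], freqs[best + 1], freq_scores[best]))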
freq_scores = np.zeros((n_freqs - 1,))

# Loop through each frequency range of interest
for freq, (fmin, fmax) in enumerate(freq_ranges):

    # Infer window size based on the frequency being used
    w_size = n_cycles / ((fmax + fmin) / 2.)  # in seconds

    # Apply band-pass filter to isolate the specified frequencies
    raw_filter = raw.copy().filter(fmin, fmax, n_jobs=1,
                                   fir_design='firwin',
                                   skip_by_annotation='edge')

    # Extract epochs from filtered data, padded by window size
    epochs = Epochs(raw_filter, events, event_id, tmin - w_size,
                    tmax + w_size, proj=False, baseline=None, preload=True)
    epochs.drop_bad()
    y = le.fit_transform(epochs.events[:, 2])

    X = epochs.get_data()

    # Save mean scores over folds for each frequency and time window
    freq_scores[freq] = np.mean(cross_val_score(estimator=clf, X=X, y=y,
                                                scoring='roc_auc', cv=cv,
                                                n_jobs=1), axis=0)

###############################################################################
# Plot frequency results

plt.bar(freqs[:-1], freq_scores, width=np.diff(freqs)[0],
        align='edge', edgecolor='black')
plt.xticks(freqs)
class SSVEP_Analysis_Offline:
    '''
    This class analyzes SSVEP data. It provides five main functions: data
    loading, data preprocessing, feature extraction, classifier building,
    and offline classification with result statistics.
    '''

    def __init__(self):
        self.all_acc = []
        self.all_itr = []

    def load_data(self, filename='aaaa', data_format='eeg',
                  trial_list={'Stimulus/S  1': 1, 'Stimulus/S  2': 2,
                              'Stimulus/S  3': 3, 'Stimulus/S  4': 4},
                  tmin=0.0, tmax=8.0, fmin=5.5, fmax=35.0):
        '''
        :param filename: EEG data file name.
        :param data_format: data format; .eeg and .fif EEG files are
            supported.
        :param trial_list: list of the marks (event ids) to analyze.
        :param tmin: start time of the analysis window.
        :param tmax: end time of the analysis window.
        :param fmin: lower edge of the analysis frequency band.
        :param fmax: upper edge (cutoff) of the analysis frequency band.
        :return:
        '''
        if data_format == 'eeg':
            self.raw_data = read_raw_brainvision(filename, preload=True,
                                                 verbose=False)
        elif data_format == 'fif':
            self.raw_data = read_raw_fif(filename, preload=True,
                                         verbose=False)
        else:
            print('No reader for this file format has been added yet; '
                  'contributions welcome')
        self.sfreq = self.raw_data.info['sfreq']  # sampling rate
        self.trial_list = trial_list  # mark (event) list
        # self.trial_list = {'Start': 1, 'End': 2}
        self.tmin, self.tmax = tmin, tmax  # mark start/end times
        self.fmin, self.fmax = fmin, fmax  # filter band

    def data_preprocess(self, window_size=2., window_step=0.1,
                        data_augmentation=True):
        '''
        :param window_size: length of the sliding time window.
        :param window_step: step of the sliding time window.
        :param data_augmentation: whether to augment the EEG samples.
        :return:
        '''
        self.window_size = window_size
        events, _ = events_from_annotations(self.raw_data,
                                            event_id=self.trial_list,
                                            verbose=False)
        flag = self.tmin
        events_step = np.zeros_like(events)
        events_step[:, 0] += int(window_step * self.sfreq)
        event_augmentation = events
        events_temp = events
        while flag < self.tmax - window_size:
            events_temp = events_temp + events_step
            event_augmentation = np.concatenate(
                (event_augmentation, events_temp), axis=0)
            flag += window_step
        if data_augmentation:
            all_event = event_augmentation
            all_event = all_event[np.argsort(all_event[:, 0])]
        else:
            all_event = events
        # extract epochs
        self.epochs = Epochs(self.raw_data, events=all_event,
                             event_id=self.trial_list, tmin=0,
                             tmax=window_size, proj=False, baseline=None,
                             preload=True, verbose=False)
        self.epochs.drop_bad()
        # band-pass filter
        self.epochs_filter = self.epochs.filter(self.fmin, self.fmax,
                                                n_jobs=-1,
                                                fir_design='firwin',
                                                verbose=False)
        # encode the labels
        le = LabelEncoder()
        self.all_label = le.fit_transform(self.epochs_filter.events[:, 2])

    def feature_extract(self, method=None, plot=True):
        '''
        :param method: feature extraction method. Currently only PSD is
            supported; time-domain, frequency-domain, entropy, and combined
            features may be added later. The current results are already
            good, so this is left as is for now.
        :param plot: whether to visualize the features.
        :return:
        '''
        self.all_feature, self.frequencies = psd_welch(self.epochs_filter,
                                                       fmin=self.fmin,
                                                       fmax=self.fmax,
                                                       verbose=False,
                                                       n_per_seg=128)
        # flatten the per-channel features into one vector
        self.all_feature = self.all_feature.reshape(
            self.all_feature.shape[0], -1)
        # average over the channel dimension to reduce dimensionality
        # self.all_feature = np.mean(self.all_feature, axis=1)
        if plot:
            for label in np.unique(self.all_label):
                _, ax = plt.subplots(1, 1)
                ax.set_title(list(self.trial_list.keys())[label])
                # plt.show()
                self.epochs_filter.copy().\
                    drop(indices=(self.all_label != label)).\
                    plot_psd(dB=False, fmin=0.5, fmax=32., color='blue',
                             ax=ax)

    def classifier_building(self, scaler_form='StandardScaler',
                            train_with_GridSearchCV=False):
        '''
        :param scaler_form: type of standardization applied to the input X,
            one of StandardScaler, MinMaxScaler, or Normalizer. Defaults to
            StandardScaler.
        :param train_with_GridSearchCV: whether to tune hyperparameters
            with a grid search. Since classification already works well,
            this feature has low priority and is still being finished.
        :return:
        '''
        if scaler_form == 'StandardScaler':
            scaler = StandardScaler()
        elif scaler_form == 'MinMaxScaler':
            scaler = MinMaxScaler()
        else:
            scaler = Normalizer()
        # scaler = StandardScaler()
        # scaler.fit(X)
        # print(scaler.mean_, scaler.var_)
        # X = scaler.transform(X)

        # create a folder to store the classifier models
        if not os.path.exists('./record_model'):
            os.mkdir('./record_model')

        # prepare the dataset
        X = self.all_feature
        y = self.all_label
        # from sklearn.utils import shuffle
        # X, y = shuffle(X, y, random_state=0)  # shuffle the data
        # X, X_test, y, y_test = train_test_split(
        #     X, y, test_size=0.2, random_state=0)  # train/test split
        print(X.shape)
        X, X_test, y, y_test = \
            X[0:int(len(y) * 3 / 4), :], X[int(len(y) * 3 / 4):, :], \
            y[0:int(len(y) * 3 / 4)], y[int(len(y) * 3 / 4):]
        # EEG signals vary strongly over time, so shuffling before the
        # split is not recommended. (Training the classifier on earlier
        # data to predict the labels of later data is a fairer evaluation.)

        # build the classifiers
        lr_clf = LogisticRegression(multi_class='auto',
                                    solver="liblinear", random_state=42)
        lda_clf = LinearDiscriminantAnalysis()
        gn_clf = GaussianNB()
        svm_clf_1 = SVC(kernel='linear', gamma="auto", random_state=42)
        svm_clf_2 = SVC(kernel='poly', gamma="auto", random_state=42)
        svm_clf_3 = SVC(kernel='rbf', gamma="auto", random_state=42)
        svm_clf_4 = SVC(kernel='sigmoid', gamma="auto", random_state=42)
        knn_clf = KNeighborsClassifier()
        rf_clf = RandomForestClassifier(n_estimators=100, random_state=42,
                                        n_jobs=-1)
        gb_clf = GradientBoostingClassifier()
        ada_clf = AdaBoostClassifier(DecisionTreeClassifier(max_depth=1),
                                     n_estimators=1000,
                                     algorithm="SAMME.R",
                                     learning_rate=0.5, random_state=42)
        bag_clf_rf = BaggingClassifier(n_jobs=-1, random_state=42)
        bag_clf_knn = BaggingClassifier(KNeighborsClassifier(),
                                        max_samples=0.5, max_features=0.5)
        xgb_clf = XGBClassifier()
        mlp_clf = MLPClassifier(solver='sgd', activation='logistic',
                                alpha=1e-4, hidden_layer_sizes=(30, 10),
                                random_state=42, max_iter=1000,
                                verbose=False, learning_rate_init=.1)
        # voting_clf = VotingClassifier(
        #     estimators=[('lr_clf', lr_clf), ('lda_clf', lda_clf),
        #                 ('gn_clf', gn_clf), ('svm_clf_1', svm_clf_1),
        #                 ('svm_clf_2', svm_clf_2), ('svm_clf_3', svm_clf_3),
        #                 ('svm_clf_4', svm_clf_4), ('knn_clf', knn_clf),
        #                 ('rf_clf', rf_clf), ('gb_clf', gb_clf),
        #                 ('ada_clf', ada_clf), ('bag_clf_rf', bag_clf_rf),
        #                 ('bag_clf_knn', bag_clf_knn), ('xgb_clf', xgb_clf),
        #                 ('mlp_clf', mlp_clf)],
        #     voting='soft', n_jobs=-1)
        # Combine the classifiers above into a voting classifier. Different
        # weights can be assigned to the classifiers via weights=None.
        voting_clf = VotingClassifier(
            estimators=[
                ('lr_clf', lr_clf), ('lda_clf', lda_clf),
                ('gn_clf', gn_clf),
                # ('svm_clf_1', svm_clf_1), ('svm_clf_2', svm_clf_2),
                # ('svm_clf_3', svm_clf_3), ('svm_clf_4', svm_clf_4),
                ('knn_clf', knn_clf), ('rf_clf', rf_clf),
                ('gb_clf', gb_clf), ('ada_clf', ada_clf),
                ('bag_clf_rf', bag_clf_rf), ('bag_clf_knn', bag_clf_knn),
                ('xgb_clf', xgb_clf), ('mlp_clf', mlp_clf)
            ],
            voting='soft', n_jobs=-1, weights=None)
        # gather all the classifiers above into one list
        listing_clf = [lr_clf, lda_clf, gn_clf, svm_clf_1, svm_clf_2,
                       svm_clf_3, svm_clf_4, knn_clf, rf_clf, gb_clf,
                       ada_clf, bag_clf_rf, bag_clf_knn, xgb_clf, mlp_clf,
                       voting_clf]
        self.listing_clf_name = ['lr_clf', 'lda_clf', 'gn_clf',
                                 'svm_clf_1', 'svm_clf_2', 'svm_clf_3',
                                 'svm_clf_4', 'knn_clf', 'rf_clf',
                                 'gb_clf', 'ada_clf', 'bag_clf_rf',
                                 'bag_clf_knn', 'xgb_clf', 'mlp_clf',
                                 'voting_clf']

        # train (without hyperparameter search)
        if not train_with_GridSearchCV:
            # start training
            cv = StratifiedKFold(n_splits=5, shuffle=True)
            cv_scores = np.zeros((len(listing_clf)))  # accuracies
            for i, classify in enumerate(listing_clf):
                models = Pipeline(
                    memory=None,
                    steps=[
                        ('Scaler', scaler),  # data standardization
                        (self.listing_clf_name[i], classify),  # classifier
                    ])
                models.fit(X=X, y=y)
                joblib.dump(classify,
                            './record_model/' + str(i) + '-' +
                            self.listing_clf_name[i] + '.pkl')
                print('Classifier', i, ':', self.listing_clf_name[i])
                # new_svm = joblib.load('svm.pkl')
                y_pred = models.predict(X_test)
                # compute acc and itr
                acc = accuracy_score(y_pred, y_test)
                self.all_acc.append(acc)
                itr = self.cal_itr(len(self.trial_list), acc,
                                   self.window_size)
                self.all_itr.append(itr)
                print('accuracy:', acc, ' itr:', itr)
                # # evaluate with K-fold cross-validation
                # cv_scores[i] = np.mean(cross_val_score(
                #     estimator=classify, X=X, y=y, scoring='accuracy',
                #     cv=cv, n_jobs=-1), axis=0)
                # print(cv_scores)
        # train (hyperparameter tuning via randomized search)
        else:
            print('sss')
            param_dist = {
                'n_estimators': range(80, 200, 4),
                'max_depth': range(2, 15, 1),
                'learning_rate': np.linspace(0.01, 2, 20),
                'subsample': np.linspace(0.7, 0.9, 20),
                'colsample_bytree': np.linspace(0.5, 0.98, 10),
                'min_child_weight': range(1, 9, 1)
            }
            grid_search = RandomizedSearchCV(xgb_clf, param_dist,
                                             n_iter=300, cv=5,
                                             scoring='accuracy', n_jobs=-1)
            print('sss')
            grid_search.fit(X, y)
            print(grid_search.best_estimator_.feature_importances_)
            print(grid_search.best_params_)
            print(grid_search.best_estimator_)
            xgb_clf_final = grid_search.best_estimator_
            y_pred = xgb_clf_final.predict(X_test)
            print('accuracy:', accuracy_score(y_pred, y_test))

    def classify_offline(self,
                         model_file='./record_model/15-voting_clf.pkl'):
        '''
        :param model_file: model file name.
        :return:
        '''
        # prepare the dataset
        X_test = self.all_feature
        y_test = self.all_label
        scaler = StandardScaler()
        scaler.fit(X_test)
        print(scaler.mean_, scaler.var_)
        X_test = scaler.transform(X_test)
        # load the model file
        model_clf = joblib.load(model_file)
        # predict
        y_pred = model_clf.predict(X_test)
        print('accuracy:', accuracy_score(y_pred, y_test))

    def result_statistics(self):
        '''
        To be completed.

        :return:
        '''
        return self.all_acc, self.all_itr

    def cal_itr(self, q, p, t):
        '''
        BCI performance metric: Information Transfer Rate (ITR).

        This is the ideal ITR, i.e. the average trial time does not
        include the simulated rest period.
        :param q: number of targets.
        :param p: classification accuracy.
        :param t: average trial duration in seconds.
        :return: itr: information transfer rate.
        '''
        if p == 1:
            itr = np.log2(q) * 60 / t
        else:
            itr = 60 / t * (np.log2(q) + p * np.log2(p) +
                            (1 - p) * np.log2((1 - p) / (q - 1)))
        return itr
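# Hedged end-to-end sketch of the offline SSVEP pipeline above; the
# BrainVision header file name and the default marker list are placeholders
# that must match the actual recording.
import numpy as np

analysis = SSVEP_Analysis_Offline()
analysis.load_data(filename='subject01.vhdr', data_format='eeg')
analysis.data_preprocess(window_size=2., window_step=0.1,
                         data_augmentation=True)
analysis.feature_extract(plot=False)
analysis.classifier_building(scaler_form='StandardScaler')
all_acc, all_itr = analysis.result_statistics()
print('mean accuracy:', np.mean(all_acc), 'mean ITR:', np.mean(all_itr))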