Example no. 1
from mne import Epochs, make_fixed_length_events


def get_clean_idx(raw_file, duration=1, reject=60):
    '''Get the indices of clean segments of data.

    Divide the data into epochs of length duration and reject epochs whose
    peak-to-peak amplitude exceeds a predefined threshold. Return the
    indices of the good epochs.

    Parameters
    ----------
    raw_file : Raw
        The raw data.
    duration : float
        Length of the epochs in seconds (default 1).
    reject : float
        Peak-to-peak rejection threshold, in the data's units (default 60).

    Returns
    -------
    selection : array
        Indices of the good epochs.
    '''
    events = make_fixed_length_events(raw_file.copy().crop(tmax=1467),
                                      id=1,
                                      duration=duration)
    epochs = Epochs(raw_file.copy().crop(tmax=1467),
                    events,
                    tmin=0,
                    tmax=duration,
                    reject=dict(eeg=reject),
                    baseline=None)
    epochs.drop_bad()
    return epochs.selection
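A minimal usage sketch (not from the original source): it assumes an MNE Raw object loaded from a hypothetical file; the 60e-6 peak-to-peak threshold is illustrative, assuming EEG data in volts.

from mne.io import read_raw_fif

raw = read_raw_fif('subject01_raw.fif', preload=True)  # hypothetical file name
good_idx = get_clean_idx(raw, duration=1, reject=60e-6)  # assumed 60 uV limit
print('%d clean epochs kept' % len(good_idx))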
Example no. 2
def test_events_long():
    """Test events."""
    data_path = testing.data_path()
    raw_fname = data_path + '/MEG/sample/sample_audvis_trunc_raw.fif'
    raw = read_raw_fif(raw_fname, preload=True)
    raw_tmin, raw_tmax = 0, 90

    tmin, tmax = -0.2, 0.5
    event_id = dict(aud_l=1, vis_l=3)

    # select gradiometers
    picks = pick_types(raw.info, meg='grad', eeg=False, eog=True,
                       stim=True, exclude=raw.info['bads'])

    # load data with usual Epochs for later verification
    raw = concatenate_raws([raw, raw.copy(), raw.copy(), raw.copy(),
                            raw.copy(), raw.copy()])
    assert 110 < raw.times[-1] < 130
    raw_cropped = raw.copy().crop(raw_tmin, raw_tmax)
    events_offline = find_events(raw_cropped)
    epochs_offline = Epochs(raw_cropped, events_offline, event_id=event_id,
                            tmin=tmin, tmax=tmax, picks=picks, decim=1,
                            reject=dict(grad=4000e-13, eog=150e-6),
                            baseline=None)
    epochs_offline.drop_bad()

    # create the mock-client object
    rt_client = MockRtClient(raw)
    rt_epochs = RtEpochs(rt_client, event_id, tmin, tmax, picks=picks, decim=1,
                         reject=dict(grad=4000e-13, eog=150e-6), baseline=None,
                         isi_max=1.)

    rt_epochs.start()
    rt_client.send_data(rt_epochs, picks, tmin=raw_tmin, tmax=raw_tmax,
                        buffer_size=1000)

    expected_events = epochs_offline.events.copy()
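    # make the offline event samples relative to the start of the data;
    # RtEpochs counts samples from the first streamed sample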
    expected_events[:, 0] = expected_events[:, 0] - raw_cropped.first_samp
    assert np.all(expected_events[:, 0] <=
                  (raw_tmax - tmax) * raw.info['sfreq'])
    assert_array_equal(rt_epochs.events, expected_events)
    assert len(rt_epochs) == len(epochs_offline)

    data_picks = pick_types(epochs_offline.info, meg='grad', eeg=False,
                            eog=True,
                            stim=False, exclude=raw.info['bads'])

    for ev_num, ev in enumerate(rt_epochs.iter_evoked()):
        if ev_num == 0:
            X_rt = ev.data[None, data_picks, :]
            y_rt = int(ev.comment)  # comment attribute contains the event_id
        else:
            X_rt = np.concatenate((X_rt, ev.data[None, data_picks, :]), axis=0)
            y_rt = np.append(y_rt, int(ev.comment))

    X_offline = epochs_offline.get_data()[:, data_picks, :]
    y_offline = epochs_offline.events[:, 2]
    assert_array_equal(X_rt, X_offline)
    assert_array_equal(y_rt, y_offline)
Example no. 3
def test_fieldtrip_rtepochs(free_tcp_port, tmpdir):
    """Test FieldTrip RtEpochs."""
    raw_tmax = 7
    raw = read_raw_fif(raw_fname, preload=True)
    raw.crop(tmin=0, tmax=raw_tmax)
    events_offline = find_events(raw, stim_channel='STI 014')
    event_id = list(np.unique(events_offline[:, 2]))
    tmin, tmax = -0.2, 0.5
    epochs_offline = Epochs(raw, events_offline, event_id=event_id,
                            tmin=tmin, tmax=tmax)
    epochs_offline.drop_bad()
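    # isi_max: longest gap between consecutive events (in seconds) plus a 1 s
    # margin, so the real-time client keeps waiting long enough for the next
    # epoch to arrive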
    isi_max = (np.max(np.diff(epochs_offline.events[:, 0])) /
               raw.info['sfreq']) + 1.0

    neuromag2ft_fname = op.realpath(op.join(os.environ['NEUROMAG2FT_ROOT'],
                                            'neuromag2ft'))
    # Works with neuromag2ft-3.0.2
    cmd = (neuromag2ft_fname, '--file', raw_fname, '--speed', '8.0',
           '--bufport', str(free_tcp_port))

    with running_subprocess(cmd, after='terminate', verbose=False):
        data_rt = None
        events_ids_rt = None
        with pytest.warns(RuntimeWarning, match='Trying to guess it'):
            with FieldTripClient(host='localhost', port=free_tcp_port,
                                 tmax=raw_tmax, wait_max=2) as rt_client:
                # get measurement info guessed by MNE-Python
                raw_info = rt_client.get_measurement_info()
                assert ([ch['ch_name'] for ch in raw_info['chs']] ==
                        [ch['ch_name'] for ch in raw.info['chs']])

                # create the real-time epochs object
                epochs_rt = RtEpochs(rt_client, event_id, tmin, tmax,
                                     stim_channel='STI 014', isi_max=isi_max)
                epochs_rt.start()

                time.sleep(0.5)
                for ev_num, ev in enumerate(epochs_rt.iter_evoked()):
                    if ev_num == 0:
                        data_rt = ev.data[None, :, :]
                        events_ids_rt = int(
                            ev.comment)  # comment attribute contains event_id
                    else:
                        data_rt = np.concatenate(
                            (data_rt, ev.data[None, :, :]), axis=0)
                        events_ids_rt = np.append(events_ids_rt,
                                                  int(ev.comment))

                _call_base_epochs_public_api(epochs_rt, tmpdir)
                epochs_rt.stop(stop_receive_thread=True)

        assert_array_equal(events_ids_rt, epochs_rt.events[:, 2])
        assert_array_equal(data_rt, epochs_rt.get_data())
        assert len(epochs_rt) == len(epochs_offline)
        assert_array_equal(events_ids_rt, epochs_offline.events[:, 2])
        assert_allclose(epochs_rt.get_data(), epochs_offline.get_data(),
                        rtol=1.e-5, atol=1.e-8)  # defaults of np.isclose
Example no. 4
def test_events_sampledata():
    """ based on examples/realtime/plot_compute_rt_decoder.py"""
    data_path = sample.data_path()
    raw_fname = data_path + '/MEG/sample/sample_audvis_filt-0-40_raw.fif'
    raw = read_raw_fif(raw_fname, preload=True)
    raw_tmin, raw_tmax = 0, 90

    tmin, tmax = -0.2, 0.5
    event_id = dict(aud_l=1, vis_l=3)

    # select gradiometers
    picks = pick_types(raw.info, meg='grad', eeg=False, eog=True,
                       stim=True, exclude=raw.info['bads'])

    # load data with usual Epochs for later verification
    raw_cropped = raw.copy().crop(raw_tmin, raw_tmax)
    events_offline = find_events(raw_cropped)
    epochs_offline = Epochs(raw_cropped, events_offline, event_id=event_id,
                            tmin=tmin, tmax=tmax, picks=picks, decim=1,
                            reject=dict(grad=4000e-13, eog=150e-6),
                            baseline=None)
    epochs_offline.drop_bad()

    # create the mock-client object
    rt_client = MockRtClient(raw)
    rt_epochs = RtEpochs(rt_client, event_id, tmin, tmax, picks=picks, decim=1,
                         reject=dict(grad=4000e-13, eog=150e-6), baseline=None,
                         isi_max=1.)

    rt_epochs.start()
    rt_client.send_data(rt_epochs, picks, tmin=raw_tmin, tmax=raw_tmax,
                        buffer_size=1000)

    expected_events = epochs_offline.events.copy()
    expected_events[:, 0] = expected_events[:, 0] - raw_cropped.first_samp
    assert np.all(expected_events[:, 0] <=
                  (raw_tmax - tmax) * raw.info['sfreq'])
    assert_array_equal(rt_epochs.events, expected_events)
    assert len(rt_epochs) == len(epochs_offline)

    data_picks = pick_types(epochs_offline.info, meg='grad', eeg=False,
                            eog=True,
                            stim=False, exclude=raw.info['bads'])

    for ev_num, ev in enumerate(rt_epochs.iter_evoked()):
        if ev_num == 0:
            X_rt = ev.data[None, data_picks, :]
            y_rt = int(ev.comment)  # comment attribute contains the event_id
        else:
            X_rt = np.concatenate((X_rt, ev.data[None, data_picks, :]), axis=0)
            y_rt = np.append(y_rt, int(ev.comment))

    X_offline = epochs_offline.get_data()[:, data_picks, :]
    y_offline = epochs_offline.events[:, 2]
    assert_array_equal(X_rt, X_offline)
    assert_array_equal(y_rt, y_offline)
Example no. 5
def test_fieldtrip_rtepochs(free_tcp_port, tmpdir):
    """Test FieldTrip RtEpochs."""
    raw_tmax = 7
    raw = read_raw_fif(raw_fname, preload=True)
    raw.crop(tmin=0, tmax=raw_tmax)
    events_offline = find_events(raw, stim_channel='STI 014')
    event_id = list(np.unique(events_offline[:, 2]))
    tmin, tmax = -0.2, 0.5
    epochs_offline = Epochs(raw, events_offline, event_id=event_id,
                            tmin=tmin, tmax=tmax)
    epochs_offline.drop_bad()
    isi_max = (np.max(np.diff(epochs_offline.events[:, 0])) /
               raw.info['sfreq']) + 1.0

    kill_signal = _start_buffer_thread(free_tcp_port)

    try:
        data_rt = None
        events_ids_rt = None
        with pytest.warns(RuntimeWarning, match='Trying to guess it'):
            with FieldTripClient(host='localhost', port=free_tcp_port,
                                 tmax=raw_tmax, wait_max=2) as rt_client:
                # get measurement info guessed by MNE-Python
                raw_info = rt_client.get_measurement_info()
                assert ([ch['ch_name'] for ch in raw_info['chs']] ==
                        [ch['ch_name'] for ch in raw.info['chs']])

                # create the real-time epochs object
                epochs_rt = RtEpochs(rt_client, event_id, tmin, tmax,
                                     stim_channel='STI 014', isi_max=isi_max)
                epochs_rt.start()

                time.sleep(0.5)
                for ev_num, ev in enumerate(epochs_rt.iter_evoked()):
                    if ev_num == 0:
                        data_rt = ev.data[None, :, :]
                        events_ids_rt = int(
                            ev.comment)  # comment attribute contains event_id
                    else:
                        data_rt = np.concatenate(
                            (data_rt, ev.data[None, :, :]), axis=0)
                        events_ids_rt = np.append(events_ids_rt,
                                                  int(ev.comment))

                _call_base_epochs_public_api(epochs_rt, tmpdir)
                epochs_rt.stop(stop_receive_thread=True)

        assert_array_equal(events_ids_rt, epochs_rt.events[:, 2])
        assert_array_equal(data_rt, epochs_rt.get_data())
        assert len(epochs_rt) == len(epochs_offline)
        assert_array_equal(events_ids_rt, epochs_offline.events[:, 2])
        assert_allclose(epochs_rt.get_data(), epochs_offline.get_data(),
                        rtol=1.e-5, atol=1.e-8)  # defaults of np.isclose
    finally:
        kill_signal.put(False)  # stop the buffer
Example no. 7
def _define_epochs(fif_file, t_min, t_max, events_id, events_file='', decim=1):
    """Split a raw .fif file into epochs based on an events file.

    Epochs span t_min to t_max and are dropped according to the rejection
    criteria built from the raw data.
    """
    raw = read_raw_fif(fif_file)
    reject = _create_reject_dict(raw.info)
    picks = pick_types(raw.info,
                       meg=True,
                       ref_meg=False,
                       eog=True,
                       stim=True,
                       exclude='bads')

    data_path, base, ext = split_filename(fif_file)

    if events_file:
        events_fpath = glob.glob(op.join(data_path, events_file))
        events = read_events(events_fpath[0])
    else:
        events = find_events(raw)

    # TODO -> use autoreject ?
    # reject_tmax = 0.8  # duration we really care about
    epochs = Epochs(raw,
                    events,
                    events_id,
                    t_min,
                    t_max,
                    proj=True,
                    picks=picks,
                    baseline=(None, 0),
                    reject=reject,
                    decim=decim,
                    preload=True)

    epochs.drop_bad(reject=reject)

    good_events_file = os.path.join(data_path, 'good_events.txt')
    np.savetxt(good_events_file, epochs.events)

    # TODO -> decide where to save...
    savename = os.path.abspath(base + '-epo' + ext)
    # savename = os.path.join(data_path, base + '-epo' + ext)
    epochs.save(savename, overwrite=True)
    return savename
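A hedged usage sketch (not from the original): the file name and event dict are hypothetical, and _create_reject_dict and split_filename are assumed to come from the snippet's surrounding package.

epo_fname = _define_epochs('sub01_task_raw.fif', t_min=-0.2, t_max=0.5,
                           events_id={'audio': 1, 'visual': 2})
print('epochs saved to', epo_fname)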
Example no. 8
def test_verbose_method(verbose):
    """Test for gh-8772."""
    # raw
    raw = read_raw_fif(fname_raw, verbose=verbose)
    with catch_logging() as log:
        raw.load_data(verbose=True)
    log = log.getvalue()
    assert 'Reading 0 ... 14399' in log
    with catch_logging() as log:
        raw.load_data(verbose=False)
    log = log.getvalue()
    assert log == ''
    # epochs
    events = np.array([[raw.first_samp + 200, 0, 1]], int)
    epochs = Epochs(raw, events, verbose=verbose)
    with catch_logging() as log:
        epochs.drop_bad(verbose=True)
    log = log.getvalue()
    assert '0 bad epochs dropped' in log
    epochs = Epochs(raw, events, verbose=verbose)
    with catch_logging() as log:
        epochs.drop_bad(verbose=False)
    log = log.getvalue()
    assert log == ''
Example no. 9
def test_source_psd_epochs():
    """Test multi-taper source PSD computation in label from epochs."""
    raw = read_raw_fif(fname_data)
    inverse_operator = read_inverse_operator(fname_inv)
    label = read_label(fname_label)

    event_id, tmin, tmax = 1, -0.2, 0.5
    lambda2, method = 1. / 9., 'dSPM'
    bandwidth = 8.
    fmin, fmax = 0, 100

    picks = pick_types(raw.info,
                       meg=True,
                       eeg=False,
                       stim=True,
                       ecg=True,
                       eog=True,
                       include=['STI 014'],
                       exclude='bads')
    reject = dict(grad=4000e-13, mag=4e-12, eog=150e-6)

    events = find_events(raw, stim_channel='STI 014')
    epochs = Epochs(raw,
                    events,
                    event_id,
                    tmin,
                    tmax,
                    picks=picks,
                    baseline=(None, 0),
                    reject=reject)

    # only look at one epoch
    epochs.drop_bad()
    one_epochs = epochs[:1]

    inv = prepare_inverse_operator(inverse_operator,
                                   nave=1,
                                   lambda2=1. / 9.,
                                   method="dSPM")
    # return list
    stc_psd = compute_source_psd_epochs(one_epochs,
                                        inv,
                                        lambda2=lambda2,
                                        method=method,
                                        pick_ori="normal",
                                        label=label,
                                        bandwidth=bandwidth,
                                        fmin=fmin,
                                        fmax=fmax,
                                        prepared=True)[0]

    # return generator
    stcs = compute_source_psd_epochs(one_epochs,
                                     inv,
                                     lambda2=lambda2,
                                     method=method,
                                     pick_ori="normal",
                                     label=label,
                                     bandwidth=bandwidth,
                                     fmin=fmin,
                                     fmax=fmax,
                                     return_generator=True,
                                     prepared=True)

    for stc in stcs:
        stc_psd_gen = stc

    assert_array_almost_equal(stc_psd.data, stc_psd_gen.data)

    # compare with direct computation
    stc = apply_inverse_epochs(one_epochs,
                               inv,
                               lambda2=lambda2,
                               method=method,
                               pick_ori="normal",
                               label=label,
                               prepared=True)[0]

    sfreq = epochs.info['sfreq']
    psd, freqs = psd_array_multitaper(stc.data,
                                      sfreq=sfreq,
                                      bandwidth=bandwidth,
                                      fmin=fmin,
                                      fmax=fmax)

    assert_array_almost_equal(psd, stc_psd.data)
    assert_array_almost_equal(freqs, stc_psd.times)

    # Check corner cases caused by tiny bandwidth
    with warnings.catch_warnings(record=True) as w:
        warnings.simplefilter('always')
        compute_source_psd_epochs(one_epochs,
                                  inv,
                                  lambda2=lambda2,
                                  method=method,
                                  pick_ori="normal",
                                  label=label,
                                  bandwidth=0.01,
                                  low_bias=True,
                                  fmin=fmin,
                                  fmax=fmax,
                                  return_generator=False,
                                  prepared=True)
        compute_source_psd_epochs(one_epochs,
                                  inv,
                                  lambda2=lambda2,
                                  method=method,
                                  pick_ori="normal",
                                  label=label,
                                  bandwidth=0.01,
                                  low_bias=False,
                                  fmin=fmin,
                                  fmax=fmax,
                                  return_generator=False,
                                  prepared=True)
    assert_true(len(w) >= 2)
    assert_true(any('not properly use' in str(ww.message) for ww in w))
    assert_true(any('Bandwidth too small' in str(ww.message) for ww in w))
Example no. 10
    # Infer window size based on the frequency being used
    w_size = n_cycles / ((fmax + fmin) / 2.)  # in seconds

    # Apply band-pass filter to isolate the specified frequencies
    raw_filter = raw.copy().filter(fmin, fmax, n_jobs=1, fir_design='firwin')

    # Extract epochs from filtered data, padded by window size
    epochs = Epochs(raw_filter,
                    events,
                    event_id,
                    tmin - w_size,
                    tmax + w_size,
                    proj=False,
                    baseline=None,
                    preload=True)
    epochs.drop_bad()
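    # shift the event codes so they become 0/1 class labels (in the original
    # MNE tutorial the codes are 2 and 3, hands vs. feet)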
    y = epochs.events[:, 2] - 2

    # Roll covariance, csp and lda over time
    for t, w_time in enumerate(centered_w_times):

        # Center the min and max of the window
        w_tmin = w_time - w_size / 2.
        w_tmax = w_time + w_size / 2.

        # Crop data into time-window of interest
        X = epochs.copy().crop(w_tmin, w_tmax).get_data()

        # Save mean scores over folds for each frequency and time window
        scores[freq, t] = np.mean(cross_val_score(estimator=clf,
                                                  X=X, y=y,
                                                  scoring='roc_auc',
                                                  cv=cv, n_jobs=1), axis=0)
Example no. 11
def test_source_psd_epochs():
    """Test multi-taper source PSD computation in label from epochs."""
    raw = read_raw_fif(fname_data)
    inverse_operator = read_inverse_operator(fname_inv)
    label = read_label(fname_label)

    event_id, tmin, tmax = 1, -0.2, 0.5
    lambda2, method = 1. / 9., 'dSPM'
    bandwidth = 8.
    fmin, fmax = 0, 100

    picks = pick_types(raw.info, meg=True, eeg=False, stim=True,
                       ecg=True, eog=True, include=['STI 014'],
                       exclude='bads')
    reject = dict(grad=4000e-13, mag=4e-12, eog=150e-6)

    events = find_events(raw, stim_channel='STI 014')
    epochs = Epochs(raw, events, event_id, tmin, tmax, picks=picks,
                    baseline=(None, 0), reject=reject)

    # only look at one epoch
    epochs.drop_bad()
    one_epochs = epochs[:1]

    inv = prepare_inverse_operator(inverse_operator, nave=1,
                                   lambda2=1. / 9., method="dSPM")
    # return list
    stc_psd = compute_source_psd_epochs(one_epochs, inv,
                                        lambda2=lambda2, method=method,
                                        pick_ori="normal", label=label,
                                        bandwidth=bandwidth,
                                        fmin=fmin, fmax=fmax,
                                        prepared=True)[0]

    # return generator
    stcs = compute_source_psd_epochs(one_epochs, inv,
                                     lambda2=lambda2, method=method,
                                     pick_ori="normal", label=label,
                                     bandwidth=bandwidth,
                                     fmin=fmin, fmax=fmax,
                                     return_generator=True,
                                     prepared=True)

    for stc in stcs:
        stc_psd_gen = stc

    assert_array_almost_equal(stc_psd.data, stc_psd_gen.data)

    # compare with direct computation
    stc = apply_inverse_epochs(one_epochs, inv,
                               lambda2=lambda2, method=method,
                               pick_ori="normal", label=label,
                               prepared=True)[0]

    sfreq = epochs.info['sfreq']
    psd, freqs = _psd_multitaper(stc.data, sfreq=sfreq, bandwidth=bandwidth,
                                 fmin=fmin, fmax=fmax)

    assert_array_almost_equal(psd, stc_psd.data)
    assert_array_almost_equal(freqs, stc_psd.times)

    # Check corner cases caused by tiny bandwidth
    with warnings.catch_warnings(record=True) as w:
        warnings.simplefilter('always')
        compute_source_psd_epochs(one_epochs, inv,
                                  lambda2=lambda2, method=method,
                                  pick_ori="normal", label=label,
                                  bandwidth=0.01, low_bias=True,
                                  fmin=fmin, fmax=fmax,
                                  return_generator=False,
                                  prepared=True)
        compute_source_psd_epochs(one_epochs, inv,
                                  lambda2=lambda2, method=method,
                                  pick_ori="normal", label=label,
                                  bandwidth=0.01, low_bias=False,
                                  fmin=fmin, fmax=fmax,
                                  return_generator=False,
                                  prepared=True)
    assert_true(len(w) >= 2)
    assert_true(any('not properly use' in str(ww.message) for ww in w))
    assert_true(any('Bandwidth too small' in str(ww.message) for ww in w))
Example no. 12
def do_preprocessing_combined(p, subjects, run_indices):
    """Do preprocessing on all raw files together.

    Calculates projection vectors to use to clean data.

    Parameters
    ----------
    p : instance of Parameters
        Analysis parameters.
    subjects : list of str
        Subject names to analyze (e.g., ['Eric_SoP_001', ...]).
    run_indices : array-like | None
        Run indices to include.
    """
    drop_logs = list()
    for si, subj in enumerate(subjects):
        proj_nums = _proj_nums(p, subj)
        ecg_channel = _handle_dict(p.ecg_channel, subj)
        flat = _handle_dict(p.flat, subj)
        if p.disp_files:
            print('  Preprocessing subject %g/%g (%s).' %
                  (si + 1, len(subjects), subj))
        pca_dir = _get_pca_dir(p, subj)
        bad_file = get_bad_fname(p, subj, check_exists=False)

        # Create SSP projection vectors after marking bad channels
        raw_names = get_raw_fnames(p, subj, 'sss', False, False,
                                   run_indices[si])
        empty_names = get_raw_fnames(p, subj, 'sss', 'only')
        for r in raw_names + empty_names:
            if not op.isfile(r):
                raise NameError('File not found (' + r + ')')

        fir_kwargs, old_kwargs = _get_fir_kwargs(p.fir_design)
        if isinstance(p.auto_bad, float):
            print('    Creating post SSS bad channel file:\n'
                  '        %s' % bad_file)
            # do autobad
            raw = _raw_LRFCP(raw_names,
                             p.proj_sfreq,
                             None,
                             None,
                             p.n_jobs_fir,
                             p.n_jobs_resample,
                             list(),
                             None,
                             p.disp_files,
                             method='fir',
                             filter_length=p.filter_length,
                             apply_proj=False,
                             force_bads=False,
                             l_trans=p.hp_trans,
                             h_trans=p.lp_trans,
                             phase=p.phase,
                             fir_window=p.fir_window,
                             pick=True,
                             skip_by_annotation='edge',
                             **fir_kwargs)
            events = fixed_len_events(p, raw)
            rtmin = p.reject_tmin \
                if p.reject_tmin is not None else p.tmin
            rtmax = p.reject_tmax \
                if p.reject_tmax is not None else p.tmax
            # do not mark eog channels bad
            meg, eeg = 'meg' in raw, 'eeg' in raw
            picks = pick_types(raw.info,
                               meg=meg,
                               eeg=eeg,
                               eog=False,
                               exclude=[])
            assert p.auto_bad_flat is None or isinstance(p.auto_bad_flat, dict)
            assert p.auto_bad_reject is None or \
                isinstance(p.auto_bad_reject, dict) or \
                p.auto_bad_reject == 'auto'
            if p.auto_bad_reject == 'auto':
                print('    Auto bad channel selection active. '
                      'Will try using Autoreject module to '
                      'compute rejection criterion.')
                try:
                    from autoreject import get_rejection_threshold
                except ImportError:
                    raise ImportError('Autoreject module not installed.\n'
                                      'Noisy channel detection parameters '
                                      'not defined. To use autobad channel '
                                      'selection, either define rejection '
                                      'criteria or install the Autoreject '
                                      'module.')
                print('    Computing thresholds.\n', end='')
                temp_epochs = Epochs(raw,
                                     events,
                                     event_id=None,
                                     tmin=rtmin,
                                     tmax=rtmax,
                                     baseline=_get_baseline(p),
                                     proj=True,
                                     reject=None,
                                     flat=None,
                                     preload=True,
                                     decim=1)
                kwargs = dict()
                if 'verbose' in get_args(get_rejection_threshold):
                    kwargs['verbose'] = False
                reject = get_rejection_threshold(temp_epochs, **kwargs)
                reject = {kk: vv for kk, vv in reject.items()}
            elif p.auto_bad_reject is None and p.auto_bad_flat is None:
                raise RuntimeError('Auto bad channel detection active. Noisy '
                                   'and flat channel detection '
                                   'parameters not defined. '
                                   'At least one criterion must be defined.')
            else:
                reject = p.auto_bad_reject
            # never reject epochs based on EOG channels here
            reject.pop('eog', None)
            epochs = Epochs(raw,
                            events,
                            None,
                            tmin=rtmin,
                            tmax=rtmax,
                            baseline=_get_baseline(p),
                            picks=picks,
                            reject=reject,
                            flat=p.auto_bad_flat,
                            proj=True,
                            preload=True,
                            decim=1,
                            reject_tmin=rtmin,
                            reject_tmax=rtmax)
            # channel scores from drop log
            drops = Counter([ch for d in epochs.drop_log for ch in d])
            # get rid of non-channel reasons in drop log
            scores = {
                kk: vv
                for kk, vv in drops.items() if kk in epochs.ch_names
            }
            ch_names = np.array(list(scores.keys()))
            # channel scores expressed as percentile and rank ordered
            counts = (100 * np.array([scores[ch] for ch in ch_names], float) /
                      len(epochs.drop_log))
            order = np.argsort(counts)[::-1]
            # boolean array masking out channels with <= % epochs dropped
            mask = counts[order] > p.auto_bad
            badchs = ch_names[order[mask]]
            if len(badchs) > 0:
                # Make sure we didn't get too many bad MEG or EEG channels
                for m, e, thresh in zip(
                    [True, False], [False, True],
                    [p.auto_bad_meg_thresh, p.auto_bad_eeg_thresh]):
                    picks = pick_types(epochs.info, meg=m, eeg=e, exclude=[])
                    if len(picks) > 0:
                        ch_names = [epochs.ch_names[pp] for pp in picks]
                        n_bad_type = sum(ch in ch_names for ch in badchs)
                        if n_bad_type > thresh:
                            stype = 'meg' if m else 'eeg'
                            raise RuntimeError('Too many bad %s channels '
                                               'found: %s > %s' %
                                               (stype, n_bad_type, thresh))

                print('    The following channels resulted in greater than '
                      '{:.0f}% trials dropped:\n'.format(p.auto_bad * 100))
                print(badchs)
                with open(bad_file, 'w') as f:
                    f.write('\n'.join(badchs))
        if not op.isfile(bad_file):
            print('    Clearing bad channels (no file %s)' %
                  op.sep.join(bad_file.split(op.sep)[-3:]))
            bad_file = None

        ecg_t_lims = _handle_dict(p.ecg_t_lims, subj)
        ecg_f_lims = p.ecg_f_lims

        ecg_eve = op.join(pca_dir, 'preproc_ecg-eve.fif')
        ecg_epo = op.join(pca_dir, 'preproc_ecg-epo.fif')
        ecg_proj = op.join(pca_dir, 'preproc_ecg-proj.fif')
        all_proj = op.join(pca_dir, 'preproc_all-proj.fif')

        get_projs_from = _handle_dict(p.get_projs_from, subj)
        if get_projs_from is None:
            get_projs_from = np.arange(len(raw_names))
        pre_list = [
            r for ri, r in enumerate(raw_names) if ri in get_projs_from
        ]

        projs = list()
        raw_orig = _raw_LRFCP(raw_names=pre_list,
                              sfreq=p.proj_sfreq,
                              l_freq=None,
                              h_freq=None,
                              n_jobs=p.n_jobs_fir,
                              n_jobs_resample=p.n_jobs_resample,
                              projs=projs,
                              bad_file=bad_file,
                              disp_files=p.disp_files,
                              method='fir',
                              filter_length=p.filter_length,
                              force_bads=False,
                              l_trans=p.hp_trans,
                              h_trans=p.lp_trans,
                              phase=p.phase,
                              fir_window=p.fir_window,
                              pick=True,
                              skip_by_annotation='edge',
                              **fir_kwargs)

        # Apply any user-supplied extra projectors
        if p.proj_extra is not None:
            if p.disp_files:
                print('    Adding extra projectors from "%s".' % p.proj_extra)
            projs.extend(read_proj(op.join(pca_dir, p.proj_extra)))

        proj_kwargs, p_sl = _get_proj_kwargs(p)
        #
        # Calculate and apply ERM projectors
        #
        if not p.cont_as_esss:
            if any(proj_nums[2]):
                assert proj_nums[2][2] == 0  # no EEG projectors for ERM
                if len(empty_names) == 0:
                    raise RuntimeError('Cannot compute empty-room projectors '
                                       'from continuous raw data')
                if p.disp_files:
                    print('    Computing continuous projectors using ERM.')
                # Use empty room(s), but processed the same way
                projs.extend(_compute_erm_proj(p, subj, projs, 'sss',
                                               bad_file))
            else:
                cont_proj = op.join(pca_dir, 'preproc_cont-proj.fif')
                _safe_remove(cont_proj)

        #
        # Calculate and apply the ECG projectors
        #
        if any(proj_nums[0]):
            if p.disp_files:
                print('    Computing ECG projectors...', end='')
            raw = raw_orig.copy()

            raw.filter(ecg_f_lims[0],
                       ecg_f_lims[1],
                       n_jobs=p.n_jobs_fir,
                       method='fir',
                       filter_length=p.filter_length,
                       l_trans_bandwidth=0.5,
                       h_trans_bandwidth=0.5,
                       phase='zero-double',
                       fir_window='hann',
                       skip_by_annotation='edge',
                       **old_kwargs)
            raw.add_proj(projs)
            raw.apply_proj()
            find_kwargs = dict()
            if 'reject_by_annotation' in get_args(find_ecg_events):
                find_kwargs['reject_by_annotation'] = True
            elif len(raw.annotations) > 0:
                print('    WARNING: ECG event detection will not make use of '
                      'annotations, please update MNE-Python')
            # We've already filtered the data channels above, but this
            # filters the ECG channel
            ecg_events = find_ecg_events(raw,
                                         999,
                                         ecg_channel,
                                         0.,
                                         ecg_f_lims[0],
                                         ecg_f_lims[1],
                                         qrs_threshold='auto',
                                         return_ecg=False,
                                         **find_kwargs)[0]
            use_reject, use_flat = _restrict_reject_flat(
                _handle_dict(p.ssp_ecg_reject, subj), flat, raw)
            ecg_epochs = Epochs(raw,
                                ecg_events,
                                999,
                                ecg_t_lims[0],
                                ecg_t_lims[1],
                                baseline=None,
                                reject=use_reject,
                                flat=use_flat,
                                preload=True)
            print('  obtained %d epochs from %d events.' %
                  (len(ecg_epochs), len(ecg_events)))
            if len(ecg_epochs) >= 20:
                write_events(ecg_eve, ecg_epochs.events)
                ecg_epochs.save(ecg_epo, **_get_epo_kwargs())
                desc_prefix = 'ECG-%s-%s' % tuple(ecg_t_lims)
                pr = compute_proj_wrap(ecg_epochs,
                                       p.proj_ave,
                                       n_grad=proj_nums[0][0],
                                       n_mag=proj_nums[0][1],
                                       n_eeg=proj_nums[0][2],
                                       desc_prefix=desc_prefix,
                                       **proj_kwargs)
                assert len(pr) == np.sum(proj_nums[0][::p_sl])
                write_proj(ecg_proj, pr)
                projs.extend(pr)
            else:
                plot_drop_log(ecg_epochs.drop_log)
                raw.plot(events=ecg_epochs.events)
                raise RuntimeError('Only %d/%d good ECG epochs found' %
                                   (len(ecg_epochs), len(ecg_events)))
            del raw, ecg_epochs, ecg_events
        else:
            _safe_remove([ecg_proj, ecg_eve, ecg_epo])

        #
        # Next calculate and apply the EOG projectors
        #
        for idx, kind in ((1, 'EOG'), (3, 'HEOG'), (4, 'VEOG')):
            _compute_add_eog(p, subj, raw_orig, projs, proj_nums[idx], kind,
                             pca_dir, flat, proj_kwargs, old_kwargs, p_sl)
        del proj_nums

        # save the projectors
        write_proj(all_proj, projs)

        #
        # Look at raw_orig for trial DQs now, it will be quick
        #
        raw_orig.filter(p.hp_cut,
                        p.lp_cut,
                        n_jobs=p.n_jobs_fir,
                        method='fir',
                        filter_length=p.filter_length,
                        l_trans_bandwidth=p.hp_trans,
                        phase=p.phase,
                        h_trans_bandwidth=p.lp_trans,
                        fir_window=p.fir_window,
                        skip_by_annotation='edge',
                        **fir_kwargs)
        raw_orig.add_proj(projs)
        raw_orig.apply_proj()
        # now let's epoch with 1-sec windows to look for DQs
        events = fixed_len_events(p, raw_orig)
        reject = _handle_dict(p.reject, subj)
        use_reject, use_flat = _restrict_reject_flat(reject, flat, raw_orig)
        epochs = Epochs(raw_orig,
                        events,
                        None,
                        p.tmin,
                        p.tmax,
                        preload=False,
                        baseline=_get_baseline(p),
                        reject=use_reject,
                        flat=use_flat,
                        proj=True)
        try:
            epochs.drop_bad()
        except AttributeError:  # old way
            epochs.drop_bad_epochs()
        drop_logs.append(epochs.drop_log)
        del raw_orig
        del epochs
    if p.plot_drop_logs:
        for subj, drop_log in zip(subjects, drop_logs):
            plot_drop_log(drop_log, p.drop_thresh, subject=subj)
Example no. 13
def test_csd_degenerate(evoked_csd_sphere):
    """Test degenerate conditions."""
    evoked, csd, sphere = evoked_csd_sphere
    warn_evoked = evoked.copy()
    warn_evoked.info['bads'].append(warn_evoked.ch_names[3])
    with pytest.raises(ValueError, match='Either drop.*or interpolate'):
        compute_current_source_density(warn_evoked)

    with pytest.raises(TypeError, match='must be an instance of'):
        compute_current_source_density(None)

    fail_evoked = evoked.copy()
    with pytest.raises(ValueError, match='Zero or infinite position'):
        for ch in fail_evoked.info['chs']:
            ch['loc'][:3] = np.array([0, 0, 0])
        compute_current_source_density(fail_evoked, sphere=sphere)

    with pytest.raises(ValueError, match='Zero or infinite position'):
        fail_evoked.info['chs'][3]['loc'][:3] = np.inf
        compute_current_source_density(fail_evoked, sphere=sphere)

    with pytest.raises(ValueError, match='No EEG channels found.'):
        fail_evoked = evoked.copy()
        fail_evoked.set_channel_types({ch_name: 'ecog' for ch_name in
                                       fail_evoked.ch_names})
        compute_current_source_density(fail_evoked, sphere=sphere)

    with pytest.raises(TypeError, match='lambda2'):
        compute_current_source_density(evoked, lambda2='0', sphere=sphere)

    with pytest.raises(ValueError, match='lambda2 must be between 0 and 1'):
        compute_current_source_density(evoked, lambda2=2, sphere=sphere)

    with pytest.raises(TypeError, match='stiffness must be'):
        compute_current_source_density(evoked, stiffness='0', sphere=sphere)

    with pytest.raises(ValueError, match='stiffness must be non-negative'):
        compute_current_source_density(evoked, stiffness=-2, sphere=sphere)

    with pytest.raises(TypeError, match='n_legendre_terms must be'):
        compute_current_source_density(evoked, n_legendre_terms=0.1,
                                       sphere=sphere)

    with pytest.raises(ValueError, match=('n_legendre_terms must be '
                                          'greater than 0')):
        compute_current_source_density(evoked, n_legendre_terms=0,
                                       sphere=sphere)

    with pytest.raises(ValueError, match='sphere must be'):
        compute_current_source_density(evoked, sphere=-0.1)

    with pytest.raises(ValueError, match=('sphere radius must be '
                                          'greater than 0')):
        compute_current_source_density(evoked, sphere=(-0.1, 0., 0., -1.))

    with pytest.raises(TypeError):
        compute_current_source_density(evoked, copy=2, sphere=sphere)

    # gh-7859
    raw = RawArray(evoked.data, evoked.info)
    epochs = Epochs(
        raw, [[0, 0, 1]], tmin=0, tmax=evoked.times[-1] - evoked.times[0],
        baseline=None, preload=False, proj=False)
    epochs.drop_bad()
    assert len(epochs) == 1
    assert_allclose(epochs.get_data()[0], evoked.data)
    with pytest.raises(RuntimeError, match='Computing CSD requires.*preload'):
        compute_current_source_density(epochs)
    epochs.load_data()
    raw = compute_current_source_density(raw)
    assert not np.allclose(raw.get_data(), evoked.data)
    evoked = compute_current_source_density(evoked)
    assert_allclose(raw.get_data(), evoked.data)
    epochs = compute_current_source_density(epochs)
    assert_allclose(epochs.get_data()[0], evoked.data)
Example no. 14
def test_source_psd_epochs(method):
    """Test multi-taper source PSD computation in label from epochs."""
    raw = read_raw_fif(fname_data)
    inverse_operator = read_inverse_operator(fname_inv)
    label = read_label(fname_label)

    event_id, tmin, tmax = 1, -0.2, 0.5
    lambda2 = 1. / 9.
    bandwidth = 8.
    fmin, fmax = 0, 100

    picks = pick_types(raw.info, meg=True, eeg=False, stim=True,
                       ecg=True, eog=True, include=['STI 014'],
                       exclude='bads')
    reject = dict(grad=4000e-13, mag=4e-12, eog=150e-6)

    events = find_events(raw, stim_channel='STI 014')
    epochs = Epochs(raw, events, event_id, tmin, tmax, picks=picks,
                    baseline=(None, 0), reject=reject)

    # only look at one epoch
    epochs.drop_bad()
    one_epochs = epochs[:1]

    inv = prepare_inverse_operator(inverse_operator, nave=1,
                                   lambda2=1. / 9., method="dSPM")
    # return list
    stc_psd = compute_source_psd_epochs(one_epochs, inv,
                                        lambda2=lambda2, method=method,
                                        pick_ori="normal", label=label,
                                        bandwidth=bandwidth,
                                        fmin=fmin, fmax=fmax,
                                        prepared=True)[0]

    # return generator
    stcs = compute_source_psd_epochs(one_epochs, inv,
                                     lambda2=lambda2, method=method,
                                     pick_ori="normal", label=label,
                                     bandwidth=bandwidth,
                                     fmin=fmin, fmax=fmax,
                                     return_generator=True,
                                     prepared=True)

    for stc in stcs:
        stc_psd_gen = stc

    assert_allclose(stc_psd.data, stc_psd_gen.data, atol=1e-7)

    # compare with direct computation
    stc = apply_inverse_epochs(one_epochs, inv,
                               lambda2=lambda2, method=method,
                               pick_ori="normal", label=label,
                               prepared=True)[0]

    sfreq = epochs.info['sfreq']
    psd, freqs = psd_array_multitaper(stc.data, sfreq=sfreq,
                                      bandwidth=bandwidth, fmin=fmin,
                                      fmax=fmax)

    assert_allclose(psd, stc_psd.data, atol=1e-7)
    assert_allclose(freqs, stc_psd.times)

    # Check corner cases caused by tiny bandwidth
    with pytest.raises(ValueError, match='use a value of at least'):
        compute_source_psd_epochs(
            one_epochs, inv, lambda2=lambda2, method=method,
            pick_ori="normal", label=label, bandwidth=0.01, low_bias=True,
            fmin=fmin, fmax=fmax, return_generator=False, prepared=True)
Example no. 15
def extract(filepaths, ica=None, fit_ica=True):
    """Epoch GDF recordings, apply ICA, and extract multitaper PSD features."""
    if ica is None:
        ica = get_ica_transformers("infomax")

    epochs, labels = list(), list()
    for filepath in filepaths:
        print("loading", filepath)
        try:
            gdf = read_raw_gdf(
                filepath,
                eog=["EOG-left", "EOG-central", "EOG-right"],
                exclude=["EOG-left", "EOG-central", "EOG-right"])
            events = events_from_annotations(gdf,
                                             event_id={
                                                 "769": 0,
                                                 "770": 1,
                                                 "771": 2,
                                                 "772": 3
                                             })
            epoch = Epochs(gdf,
                           events[0],
                           event_repeated="drop",
                           reject_by_annotation=True,
                           tmin=-.3,
                           tmax=.7,
                           reject=dict(eeg=1e-4))
            epoch.drop_bad()
        except ValueError:
            print("Error in", filepath)
            continue
        epochs.append(epoch)
        labels.append(epoch.events[:, 2])

    labels = np.concatenate(labels)
    n_epochs, n_channels, n_times = epochs[0].get_data().shape
    ica_vec = [
        epoch.get_data().transpose(1, 0, 2).reshape(n_channels, -1).T
        for epoch in epochs
    ]
    ica_vec = np.concatenate(ica_vec, axis=0)

    if fit_ica:
        ica.fit(ica_vec)

    transformed = ica.transform(ica_vec)
    # NOTE: the next line overwrites the ICA output, so the features below are
    # computed on the untransformed data; remove it to use the ICA components
    transformed = ica_vec

    transformed = transformed.T.reshape(n_channels, -1,
                                        n_times).transpose(1, 0, 2)

    features, freqs = psd_array_multitaper(transformed,
                                           250.,
                                           fmin=0,
                                           fmax=20,
                                           bandwidth=2)

    n_epochs, _, _ = features.shape
    #    features = features.reshape(n_epochs, -1)
    features = features.mean(axis=2)
    labels_placeholder = np.zeros((len(labels), 4))

    for i, l in enumerate(labels):
        labels_placeholder[i, l] = 1

    labels = labels_placeholder
    return features, labels, ica
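A usage sketch under stated assumptions: get_ica_transformers is defined elsewhere in this project, and the GDF file names are hypothetical BCI-competition-style recordings.

X_train, y_train, ica = extract(['A01T.gdf', 'A02T.gdf'])  # fit ICA on training files
X_test, y_test, _ = extract(['A01E.gdf'], ica=ica, fit_ica=False)  # reuse the fitted ICA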
Example no. 16
def test_source_psd_epochs():
    """Test multi-taper source PSD computation in label from epochs."""
    raw = read_raw_fif(fname_data)
    inverse_operator = read_inverse_operator(fname_inv)
    label = read_label(fname_label)

    event_id, tmin, tmax = 1, -0.2, 0.5
    lambda2, method = 1. / 9., 'dSPM'
    bandwidth = 8.
    fmin, fmax = 0, 100

    picks = pick_types(raw.info, meg=True, eeg=False, stim=True,
                       ecg=True, eog=True, include=['STI 014'],
                       exclude='bads')
    reject = dict(grad=4000e-13, mag=4e-12, eog=150e-6)

    events = find_events(raw, stim_channel='STI 014')
    epochs = Epochs(raw, events, event_id, tmin, tmax, picks=picks,
                    baseline=(None, 0), reject=reject)

    # only look at one epoch
    epochs.drop_bad()
    one_epochs = epochs[:1]

    inv = prepare_inverse_operator(inverse_operator, nave=1,
                                   lambda2=1. / 9., method="dSPM")
    # return list
    stc_psd = compute_source_psd_epochs(one_epochs, inv,
                                        lambda2=lambda2, method=method,
                                        pick_ori="normal", label=label,
                                        bandwidth=bandwidth,
                                        fmin=fmin, fmax=fmax,
                                        prepared=True)[0]

    # return generator
    stcs = compute_source_psd_epochs(one_epochs, inv,
                                     lambda2=lambda2, method=method,
                                     pick_ori="normal", label=label,
                                     bandwidth=bandwidth,
                                     fmin=fmin, fmax=fmax,
                                     return_generator=True,
                                     prepared=True)

    for stc in stcs:
        stc_psd_gen = stc

    assert_allclose(stc_psd.data, stc_psd_gen.data, atol=1e-7)

    # compare with direct computation
    stc = apply_inverse_epochs(one_epochs, inv,
                               lambda2=lambda2, method=method,
                               pick_ori="normal", label=label,
                               prepared=True)[0]

    sfreq = epochs.info['sfreq']
    psd, freqs = psd_array_multitaper(stc.data, sfreq=sfreq,
                                      bandwidth=bandwidth, fmin=fmin,
                                      fmax=fmax)

    assert_allclose(psd, stc_psd.data, atol=1e-7)
    assert_allclose(freqs, stc_psd.times)

    # Check corner cases caused by tiny bandwidth
    with pytest.raises(ValueError, match='use a value of at least'):
        compute_source_psd_epochs(
            one_epochs, inv, lambda2=lambda2, method=method,
            pick_ori="normal", label=label, bandwidth=0.01, low_bias=True,
            fmin=fmin, fmax=fmax, return_generator=False, prepared=True)
Example no. 17
def CSP_dec(raw, events, event_id, subject):
    """Decode conditions with CSP across frequency bands.

    The raw data are band-pass filtered in successive frequency bins, the
    common spatial pattern (CSP) method extracts a spatial component per bin,
    and linear discriminant analysis labels the data into conditions based on
    that component.
    """

    event_id = dict(ISI_time_correct=event_id['ISI_time_correct'],
                    ISI_time_control=event_id['ISI_time_control'])
    # Extract information from the raw file
    sfreq = raw.info['sfreq']
    raw.pick_types(meg=False, eeg=True, stim=False, eog=False)
    
    # Assemble the classifier using scikit-learn pipeline
    clf = make_pipeline(CSP(n_components=4, reg=None, log=True),
                        LinearDiscriminantAnalysis())
    n_splits = 5  # how many folds to use for cross-validation
    cv = StratifiedKFold(n_splits=n_splits, shuffle=True)
    
    # Classification & Time-frequency parameters
    tmin, tmax = 1.0, 4.0
    n_cycles = 10.  # how many complete cycles: used to define window size
    min_freq = 2.
    max_freq = 50.
    n_freqs = 24  # how many frequency bins to use
    
    # Assemble list of frequency range tuples
    freqs = np.linspace(min_freq, max_freq, n_freqs)  # assemble frequencies
    freq_ranges = list(zip(freqs[:-1], freqs[1:]))  # make freqs list of tuples
    
    # Infer window spacing from the max freq and number of cycles to avoid gaps
    window_spacing = (n_cycles / np.max(freqs) / 2.)
    centered_w_times = np.arange(tmin, tmax, window_spacing)[1:]
    n_windows = len(centered_w_times)
    
    # Instantiate label encoder
    le = LabelEncoder()

    # init scores
    freq_scores = np.zeros((n_freqs - 1,))
    
    # Loop through each frequency range of interest
    for freq, (fmin, fmax) in enumerate(freq_ranges):
    
        # Infer window size based on the frequency being used
        w_size = n_cycles / ((fmax + fmin) / 2.)  # in seconds
    
        # Apply band-pass filter to isolate the specified frequencies
        raw_filter = raw.copy().filter(fmin, fmax, n_jobs=1, fir_design='firwin')
    
        # Extract epochs from filtered data, padded by window size
        epochs = Epochs(raw_filter, events, event_id, tmin - w_size, tmax + w_size,
                        proj=False, baseline=None, preload=True)
        
        e1 = epochs['ISI_time_correct']
        e2 = epochs['ISI_time_control']
        mne.epochs.equalize_epoch_counts([e1,e2])
        epochs = mne.epochs.concatenate_epochs([e1,e2])
        
        epochs.drop_bad()
        y = le.fit_transform(epochs.events[:, 2])
    
        X = epochs.get_data()
    
        # Save mean scores over folds for each frequency and time window
        freq_scores[freq] = np.mean(cross_val_score(estimator=clf, X=X, y=y,
                                                    scoring='roc_auc', cv=cv,
                                                    n_jobs=8), axis=0)
        
    
    fig, axes = plt.subplots(figsize=(20, 10))
    axes.bar(freqs[:-1], freq_scores, width=np.diff(freqs)[0],
             align='edge', edgecolor='black')
    axes.set_xticks([int(round(f)) for f in freqs])
    axes.set_ylim([0, 1])
    axes.axhline(len(epochs['ISI_time_correct']) / len(epochs), color='k', linestyle='--',
                label='chance level')
    axes.legend()
    axes.set_xlabel('Frequency (Hz)')
    axes.set_ylabel('Decoding Scores')                                       
    fig.suptitle('Frequency Decoding Scores ' + subject)
    
    fig.savefig('CSP_decoding/' + subject)
    
    return freq_scores, freqs
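A hypothetical call, assuming a preloaded Raw object whose annotations include the ISI_time_correct and ISI_time_control markers; the subject label is a placeholder:

import numpy as np
import mne

# events/event_id come from the recording's annotations (assumed markers)
events, event_id = mne.events_from_annotations(raw)
freq_scores, freqs = CSP_dec(raw, events, event_id, subject='subj01')

# Report the best-decoded frequency band
best = int(np.argmax(freq_scores))
print('best band: %.1f-%.1f Hz, ROC AUC = %.2f'
      % (freqs[best], freqs[best + 1], freq_scores[best]))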
Esempio n. 18
0
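The following fragment comes from the same CSP-over-frequencies analysis as Esempio n. 17 and reuses names defined earlier in its source script. A minimal sketch of that setup, assuming raw, events, and event_id come from the recording as above:

import numpy as np
from mne.decoding import CSP
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import LabelEncoder
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.model_selection import StratifiedKFold

# Classifier, cross-validation scheme, and label encoder
clf = make_pipeline(CSP(n_components=4, reg=None, log=True),
                    LinearDiscriminantAnalysis())
cv = StratifiedKFold(n_splits=5, shuffle=True)
le = LabelEncoder()

# Decoding window and frequency bands, as in Esempio n. 17
tmin, tmax = 1.0, 4.0
n_cycles = 10.
min_freq, max_freq, n_freqs = 2., 50., 24
freqs = np.linspace(min_freq, max_freq, n_freqs)
freq_ranges = list(zip(freqs[:-1], freqs[1:]))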
freq_scores = np.zeros((n_freqs - 1,))

# Loop through each frequency range of interest
for freq, (fmin, fmax) in enumerate(freq_ranges):

    # Infer window size based on the frequency being used
    w_size = n_cycles / ((fmax + fmin) / 2.)  # in seconds

    # Apply band-pass filter to isolate the specified frequencies
    raw_filter = raw.copy().filter(fmin, fmax, n_jobs=1, fir_design='firwin',
                                   skip_by_annotation='edge')

    # Extract epochs from filtered data, padded by window size
    epochs = Epochs(raw_filter, events, event_id, tmin - w_size, tmax + w_size,
                    proj=False, baseline=None, preload=True)
    epochs.drop_bad()
    y = le.fit_transform(epochs.events[:, 2])

    X = epochs.get_data()

    # Save mean scores over folds for each frequency and time window
    freq_scores[freq] = np.mean(cross_val_score(estimator=clf, X=X, y=y,
                                                scoring='roc_auc', cv=cv,
                                                n_jobs=1), axis=0)

###############################################################################
# Plot frequency results

plt.bar(freqs[:-1], freq_scores, width=np.diff(freqs)[0],
        align='edge', edgecolor='black')
plt.xticks(freqs)
Esempio n. 19
0
class SSVEP_Analysis_Offline:
    '''
    Offline analysis of SSVEP data. Covers five stages: data loading, data
    preprocessing, feature extraction, classifier building, and offline
    classification with result statistics.
    '''
    def __init__(self):
        self.all_acc = []
        self.all_itr = []
        pass

    def load_data(
            self,
            filename='aaaa',
            data_format='eeg',
            trial_list={
                'Stimulus/S  1': 1,
                'Stimulus/S  2': 2,
                'Stimulus/S  3': 3,
                'Stimulus/S  4': 4
            },
            tmin=0.0,
            tmax=8.0,
            fmin=5.5,
            fmax=35.0):
        '''
        :param filename: EEG data file name
        :param data_format: data format; the .eeg (BrainVision) and .fif formats are supported
        :param trial_list: mapping of the event markers to analyze
        :param tmin: start time of the analysis window (s)
        :param tmax: end time of the analysis window (s)
        :param fmin: lower edge of the analysis band (Hz)
        :param fmax: upper edge of the analysis band (Hz)
        :return:
        '''
        if data_format == 'eeg':
            self.raw_data = read_raw_brainvision(filename,
                                                 preload=True,
                                                 verbose=False)
        elif data_format == 'fif':
            self.raw_data = read_raw_fif(filename, preload=True, verbose=False)
        else:
            print('No reader for this file format has been added yet; contributions welcome')

        self.sfreq = self.raw_data.info['sfreq']  # sampling rate

        self.trial_list = trial_list  # event-marker mapping
        # self.trial_list = {'Start': 1, 'End': 2}

        self.tmin, self.tmax = tmin, tmax  # trial start/end relative to each marker
        self.fmin, self.fmax = fmin, fmax  # filter band

    def data_preprocess(self,
                        window_size=2.,
                        window_step=0.1,
                        data_augmentation=True):
        '''
        :param window_size: sliding-window length (s)
        :param window_step: sliding-window step (s)
        :param data_augmentation: whether to augment the EEG samples with sliding windows
        :return:
        '''
        self.window_size = window_size
        events, _ = events_from_annotations(self.raw_data,
                                            event_id=self.trial_list,
                                            verbose=False)

        # Sliding-window augmentation: repeatedly shift all events forward by
        # window_step until the shifted window would run past tmax
        flag = self.tmin
        events_step = np.zeros_like(events)
        events_step[:, 0] += int(window_step * self.sfreq)
        event_augmentation = events
        events_temp = events
        while flag < self.tmax - window_size:
            events_temp = events_temp + events_step
            event_augmentation = np.concatenate(
                (event_augmentation, events_temp), axis=0)
            flag += window_step

        if data_augmentation:
            all_event = event_augmentation
            all_event = all_event[np.argsort(all_event[:, 0])]  # sort by sample index
        else:
            all_event = events

        # Extract epochs
        self.epochs = Epochs(self.raw_data,
                             events=all_event,
                             event_id=self.trial_list,
                             tmin=0,
                             tmax=window_size,
                             proj=False,
                             baseline=None,
                             preload=True,
                             verbose=False)
        self.epochs.drop_bad()

        # Band-pass filter
        self.epochs_filter = self.epochs.filter(self.fmin,
                                                self.fmax,
                                                n_jobs=-1,
                                                fir_design='firwin',
                                                verbose=False)

        # Encode the labels
        le = LabelEncoder()
        self.all_label = le.fit_transform(self.epochs_filter.events[:, 2])

    def feature_extract(self, method=None, plot=True):
        '''
        :param method: feature-extraction method. Only PSD is supported for now;
                       time-domain features, frequency-domain features, entropy,
                       and combined features may be added later. The current
                       results are already good, so this will do for the moment.
        :param plot: whether to visualize the features.
        :return:
        '''
        self.all_feature, self.frequencies = psd_welch(self.epochs_filter,
                                                       fmin=self.fmin,
                                                       fmax=self.fmax,
                                                       verbose=False,
                                                       n_per_seg=128)
        self.all_feature = self.all_feature.reshape(
            self.all_feature.shape[0], -1)  # flatten per-channel features into one vector
        # self.all_feature = np.mean(self.all_feature, axis=1)  # average over channels to reduce dimensionality

        if plot:

            for label in np.unique(self.all_label):
                _, ax = plt.subplots(1, 1)
                ax.set_title(list(self.trial_list.keys())[label])
                # plt.show()
                self.epochs_filter.copy().drop(indices=(self.all_label != label)).\
                    plot_psd(dB=False, fmin=0.5, fmax=32., color='blue', ax=ax)

    def classifier_building(self,
                            scaler_form='StandardScaler',
                            train_with_GridSearchCV=False):
        '''
        :param scaler_form: type of standardization applied to the input X; one of
                            StandardScaler, MinMaxScaler, or Normalizer.
                            Defaults to StandardScaler.
        :param train_with_GridSearchCV: whether to tune hyperparameters with a search.
                                        Classification already performs well, so this
                                        has low priority and is still being completed.
        :return:
        '''

        if scaler_form == 'StandardScaler':
            scaler = StandardScaler()
        elif scaler_form == 'MinMaxScaler':
            scaler = MinMaxScaler()
        else:
            scaler = Normalizer()

        # scaler = StandardScaler()
        # scaler.fit(X)
        # print(scaler.mean_, scaler.var_)
        # X = scaler.transform(X)

        # Create a folder to store the trained classifier models
        if not os.path.exists('./record_model'):
            os.mkdir('./record_model')

        # Prepare the dataset
        X = self.all_feature
        y = self.all_label

        # from sklearn.utils import shuffle
        # X, y = shuffle(X, y, random_state=0)  # shuffle the data

        # X, X_test, y, y_test = train_test_split(X, y, test_size=0.2, random_state=0)  # train/test split
        print(X.shape)

        X, X_test, y, y_test = X[0:int(len(y)*3/4), :], X[int(len(y)*3/4):, :], \
                               y[0:int(len(y)*3/4)], y[int(len(y)*3/4):]
        # EEG varies strongly over time, so shuffling before the train/test split is
        # discouraged: training on earlier data to predict later data is a fairer test.

        # Build the classifiers
        lr_clf = LogisticRegression(multi_class='auto',
                                    solver="liblinear",
                                    random_state=42)
        lda_clf = LinearDiscriminantAnalysis()
        gn_clf = GaussianNB()
        svm_clf_1 = SVC(kernel='linear', gamma="auto", random_state=42)
        svm_clf_2 = SVC(kernel='poly', gamma="auto", random_state=42)
        svm_clf_3 = SVC(kernel='rbf', gamma="auto", random_state=42)
        svm_clf_4 = SVC(kernel='sigmoid', gamma="auto", random_state=42)
        knn_clf = KNeighborsClassifier()
        rf_clf = RandomForestClassifier(n_estimators=100,
                                        random_state=42,
                                        n_jobs=-1)
        gb_clf = GradientBoostingClassifier()
        ada_clf = AdaBoostClassifier(DecisionTreeClassifier(max_depth=1),
                                     n_estimators=1000,
                                     algorithm="SAMME.R",
                                     learning_rate=0.5,
                                     random_state=42)
        bag_clf_rf = BaggingClassifier(n_jobs=-1, random_state=42)
        bag_clf_knn = BaggingClassifier(KNeighborsClassifier(),
                                        max_samples=0.5,
                                        max_features=0.5)
        xgb_clf = XGBClassifier()
        mlp_clf = MLPClassifier(solver='sgd',
                                activation='logistic',
                                alpha=1e-4,
                                hidden_layer_sizes=(30, 10),
                                random_state=42,
                                max_iter=1000,
                                verbose=False,
                                learning_rate_init=.1)

        # voting_clf = VotingClassifier(estimators=[('lr_clf', lr_clf), ('lda_clf', lda_clf), ('gn_clf', gn_clf),
        #                                           ('svm_clf_1', svm_clf_1), ('svm_clf_2', svm_clf_2),
        #                                           ('svm_clf_3', svm_clf_3), ('svm_clf_4', svm_clf_4),
        #                                           ('knn_clf', knn_clf), ('rf_clf', rf_clf), ('gb_clf', gb_clf),
        #                                           ('ada_clf', ada_clf), ('bag_clf_rf', bag_clf_rf),
        #                                           ('bag_clf_knn', bag_clf_knn), ('xgb_clf', xgb_clf),
        #                                           ('mlp_clf', mlp_clf)],
        #                               voting='soft', n_jobs=-1)

        # Combine the classifiers above into a voting classifier. Per-classifier
        # weights can be supplied via the weights argument (weights=None here).
        voting_clf = VotingClassifier(
            estimators=[
                ('lr_clf', lr_clf),
                ('lda_clf', lda_clf),
                ('gn_clf', gn_clf),
                # ('svm_clf_1', svm_clf_1), ('svm_clf_2', svm_clf_2),
                # ('svm_clf_3', svm_clf_3), ('svm_clf_4', svm_clf_4),
                ('knn_clf', knn_clf),
                ('rf_clf', rf_clf),
                ('gb_clf', gb_clf),
                ('ada_clf', ada_clf),
                ('bag_clf_rf', bag_clf_rf),
                ('bag_clf_knn', bag_clf_knn),
                ('xgb_clf', xgb_clf),
                ('mlp_clf', mlp_clf)
            ],
            voting='soft',
            n_jobs=-1,
            weights=None)

        # Collect all of the classifiers above into a list
        listing_clf = [
            lr_clf, lda_clf, gn_clf, svm_clf_1, svm_clf_2, svm_clf_3,
            svm_clf_4, knn_clf, rf_clf, gb_clf, ada_clf, bag_clf_rf,
            bag_clf_knn, xgb_clf, mlp_clf, voting_clf
        ]
        self.listing_clf_name = [
            'lr_clf', 'lda_clf', 'gn_clf', 'svm_clf_1', 'svm_clf_2',
            'svm_clf_3', 'svm_clf_4', 'knn_clf', 'rf_clf', 'gb_clf', 'ada_clf',
            'bag_clf_rf', 'bag_clf_knn', 'xgb_clf', 'mlp_clf', 'voting_clf'
        ]

        # Train without hyperparameter search
        if not train_with_GridSearchCV:
            # Begin training
            cv = StratifiedKFold(n_splits=5, shuffle=True)
            cv_scores = np.zeros((len(listing_clf)))  # per-classifier accuracy array
            for i, classify in enumerate(listing_clf):
                models = Pipeline(
                    memory=None,
                    steps=[
                        ('Scaler', scaler),  # standardization
                        (self.listing_clf_name[i], classify),  # classifier
                    ])

                models.fit(X=X, y=y)
                # Note: only the fitted classifier is dumped, not the scaler step
                joblib.dump(
                    classify, './record_model/' + str(i) + '-' +
                    self.listing_clf_name[i] + '.pkl')
                print('Classifier', i, ':', self.listing_clf_name[i])
                # new_svm = joblib.load('svm.pkl')

                y_pred = models.predict(X_test)

                # Compute accuracy and ITR
                acc = accuracy_score(y_pred, y_test)
                self.all_acc.append(acc)
                itr = self.cal_itr(len(self.trial_list), acc, self.window_size)
                self.all_itr.append(itr)
                print('accuracy:', acc, '   itr:', itr)

                # # Evaluate with K-fold cross-validation
                # cv_scores[i] = np.mean(cross_val_score(estimator=classify, X=X, y=y,
                #                                        scoring='accuracy', cv=cv, n_jobs=-1), axis=0)

            # print(cv_scores)

        # Train with a randomized hyperparameter search (over the XGBoost model)
        else:

            param_dist = {
                'n_estimators': range(80, 200, 4),
                'max_depth': range(2, 15, 1),
                'learning_rate': np.linspace(0.01, 2, 20),
                'subsample': np.linspace(0.7, 0.9, 20),
                'colsample_bytree': np.linspace(0.5, 0.98, 10),
                'min_child_weight': range(1, 9, 1)
            }
            grid_search = RandomizedSearchCV(xgb_clf,
                                             param_dist,
                                             n_iter=300,
                                             cv=5,
                                             scoring='accuracy',
                                             n_jobs=-1)
            grid_search.fit(X, y)

            print(grid_search.best_estimator_.feature_importances_)
            print(grid_search.best_params_)
            print(grid_search.best_estimator_)

            xgb_clf_final = grid_search.best_estimator_
            y_pred = xgb_clf_final.predict(X_test)
            print('accuracy:', accuracy_score(y_pred, y_test))

    def classify_offline(self, model_file='./record_model/15-voting_clf.pkl'):
        '''
        :param model_file: path to a saved model file
        :return:
        '''
        # Prepare the dataset
        X_test = self.all_feature
        y_test = self.all_label

        # The saved model does not include the training-time scaler, so a new
        # scaler is fit on the test features here
        scaler = StandardScaler()
        scaler.fit(X_test)
        print(scaler.mean_, scaler.var_)
        X_test = scaler.transform(X_test)

        # Load the model file
        model_clf = joblib.load(model_file)

        # Predict
        y_pred = model_clf.predict(X_test)
        print('accuracy:', accuracy_score(y_pred, y_test))

    def result_statistics(self):
        '''
        To be extended.
        :return: accuracies and ITRs accumulated across classifiers
        '''
        return self.all_acc, self.all_itr

    def cal_itr(self, q, p, t):
        '''
        BCI performance metric: Information Transfer Rate (ITR).
        This computes the ideal ITR, i.e. the average trial time excludes any
        simulated rest period. With q targets, accuracy p, and trial time t (s):
        ITR = 60/t * (log2(q) + p*log2(p) + (1-p)*log2((1-p)/(q-1))) bits/min.

        :param q: number of targets
        :param p: classification accuracy
        :param t: average trial duration, in seconds
        :return: itr: information transfer rate, in bits/min
        '''

        if p == 1:
            itr = np.log2(q) * 60 / t
        else:
            itr = 60 / t * (np.log2(q) + p * np.log2(p) + (1 - p) * np.log2(
                (1 - p) / (q - 1)))

        return itr
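
A hypothetical end-to-end run; the file name and parameters are placeholders for an actual BrainVision recording:

# Sketch of the offline SSVEP pipeline defined above; 'subject01.vhdr' is a
# placeholder path, and the default trial_list markers are assumed to exist.
analysis = SSVEP_Analysis_Offline()
analysis.load_data(filename='subject01.vhdr', data_format='eeg',
                   tmin=0.0, tmax=8.0, fmin=5.5, fmax=35.0)
analysis.data_preprocess(window_size=2., window_step=0.1,
                         data_augmentation=True)
analysis.feature_extract(plot=False)
analysis.classifier_building(scaler_form='StandardScaler',
                             train_with_GridSearchCV=False)

all_acc, all_itr = analysis.result_statistics()
for name, acc, itr in zip(analysis.listing_clf_name, all_acc, all_itr):
    print('%-12s acc=%.3f  itr=%.1f bits/min' % (name, acc, itr))

# Sanity check of the ITR formula: 4 targets, 95% accuracy, 2 s trials
print(analysis.cal_itr(4, 0.95, 2.0))  # approximately 49 bits/min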