Example #1
0
def assert_indexing(info, picks_by_type, ref_meg=False, all_data=True):
    """Assert our indexing functions work properly.

    Checks that ``channel_indices_by_type``, ``_picks_by_type`` and the
    legacy ``_picks_by_type_old`` agree with the expected ``picks_by_type``
    pairs, and that marking a channel as bad removes it from the
    ``_picks_by_type`` output.

    Parameters
    ----------
    info : Info
        Measurement info to index.
    picks_by_type : list
        Expected ``(ch_type, picks)`` pairs, in the order ``_picks_by_type``
        is expected to return them.
    ref_meg : bool
        Passed to ``_picks_by_type`` / ``_picks_by_type_old``. When True, the
        ``channel_indices_by_type`` comparison is skipped.
    all_data : bool
        When False, expected pairs whose type is not in
        ``_DATA_CH_TYPES_SPLIT`` are ignored.
    """
    # The old and new channel-typing functions must agree.
    _assert_channel_types(info)

    # channel_indices_by_type: each type listed in picks_by_type must map to
    # the picks of its first matching entry; every other type must be empty.
    if not ref_meg:
        idx = channel_indices_by_type(info)
        for key in idx:
            match = next((pair for pair in picks_by_type if key == pair[0]),
                         None)
            if match is None:
                assert len(idx[key]) == 0
            else:
                assert_array_equal(idx[key], match[1])

    # _picks_by_type must return the (optionally data-only) expectations with
    # picks normalized to integer arrays.
    expected = [(pair[0], np.array(pair[1], int))
                for pair in picks_by_type
                if all_data or pair[0] in _DATA_CH_TYPES_SPLIT]
    assert_object_equal(_picks_by_type(info, ref_meg=ref_meg), expected)

    # The legacy implementation is known to raise a TypeError when hbo
    # channels are present; otherwise its output must match as well.
    if not ref_meg and idx['hbo']:
        with pytest.raises(TypeError, match='unexpected keyword argument'):
            _picks_by_type_old(info, ref_meg=ref_meg)
    else:
        assert_object_equal(_picks_by_type_old(info, ref_meg=ref_meg),
                            expected)

    # Marking the first pick of the first type as bad must drop that pick.
    bad_pick = expected[0][1][0]
    info = info.copy()
    info['bads'] = [info['chs'][bad_pick]['ch_name']]
    expected = deepcopy(expected)
    expected[0] = (expected[0][0], expected[0][1][1:])
    assert_object_equal(_picks_by_type(info, ref_meg=ref_meg), expected)
Example #2
0
def test_pick_forward_seeg():
    """Test picking forward with SEEG."""
    fwd = read_forward_solution(test_forward.fname_meeg)
    # Number of channels of each type in the original forward solution.
    counts = {key: len(picks)
              for key, picks in channel_indices_by_type(fwd["info"]).items()}
    counts["meg"] = counts["mag"] + counts["grad"]
    fwd_ = pick_types_forward(fwd, meg=True, eeg=False, seeg=False)
    _check_fwd_n_chan_consistent(fwd_, counts["meg"])
    fwd_ = pick_types_forward(fwd, meg=False, eeg=True, seeg=False)
    _check_fwd_n_chan_consistent(fwd_, counts["eeg"])
    # should raise exception related to emptiness
    assert_raises(ValueError, pick_types_forward, fwd, meg=False, eeg=False,
                  seeg=True)
    # change last chan from EEG to sEEG
    seeg_name = "OTp1"
    rename_channels(fwd["info"], {"EEG 060": seeg_name})
    for ch in fwd["info"]["chs"]:
        if ch["ch_name"] == seeg_name:
            ch["kind"] = FIFF.FIFFV_SEEG_CH
            ch["coil_type"] = FIFF.FIFFV_COIL_EEG
    # rename_channels only touched the info; keep the solution rows in sync
    # (this assumes "EEG 060" is the last channel, as the comment above says).
    fwd["sol"]["row_names"][-1] = fwd["info"]["chs"][-1]["ch_name"]
    counts["eeg"] -= 1
    counts["seeg"] += 1
    # repick & check
    fwd_seeg = pick_types_forward(fwd, meg=False, eeg=False, seeg=True)
    assert_equal(fwd_seeg["sol"]["row_names"], [seeg_name])
    assert_equal(fwd_seeg["info"]["ch_names"], [seeg_name])
    # should work fine
    fwd_ = pick_types_forward(fwd, meg=True, eeg=False, seeg=False)
    _check_fwd_n_chan_consistent(fwd_, counts["meg"])
    fwd_ = pick_types_forward(fwd, meg=False, eeg=True, seeg=False)
    _check_fwd_n_chan_consistent(fwd_, counts["eeg"])
    fwd_ = pick_types_forward(fwd, meg=False, eeg=False, seeg=True)
    _check_fwd_n_chan_consistent(fwd_, counts["seeg"])
Example #3
0
def test_pick_seeg_ecog():
    """Test picking with sEEG and ECoG."""
    names = 'A1 A2 Fz O OTp1 OTp2 E1 OTp3 E2 E3'.split()
    types = 'mag mag eeg eeg seeg seeg ecog seeg ecog ecog'.split()
    info = create_info(names, 1024., types)
    idx = channel_indices_by_type(info)
    assert_array_equal(idx['mag'], [0, 1])
    assert_array_equal(idx['eeg'], [2, 3])
    assert_array_equal(idx['seeg'], [4, 5, 7])
    assert_array_equal(idx['ecog'], [6, 8, 9])
    assert_array_equal(pick_types(info, meg=False, seeg=True), [4, 5, 7])
    for ii, ch_type in enumerate(types):
        assert_equal(channel_type(info, ii), ch_type)
    raw = RawArray(np.zeros((len(names), 10)), info)
    events = np.array([[1, 0, 0], [2, 0, 0]])
    epochs = Epochs(raw, events, {'event': 0}, -1e-5, 1e-5, add_eeg_ref=False)
    evoked = epochs.average(pick_types(epochs.info, meg=True, seeg=True))
    e_seeg = evoked.copy().pick_types(meg=False, seeg=True)
    # Compare the whole list: a zip() loop would silently pass if channels
    # were missing from (or extra in) e_seeg.
    assert_equal(e_seeg.ch_names, [names[4], names[5], names[7]])
    # Deal with constant debacle: a real MEG recording must not report any
    # sEEG or ECoG channels.
    raw = read_raw_fif(op.join(io_dir, 'tests', 'data',
                               'test_chpi_raw_sss.fif'), add_eeg_ref=False)
    assert_equal(len(pick_types(raw.info, meg=False, seeg=True, ecog=True)), 0)
Example #4
0
def test_pick_forward_seeg():
    """Test picking forward with SEEG."""
    fwd = read_forward_solution(test_forward.fname_meeg)
    # Number of channels of each type in the original forward solution.
    counts = {key: len(picks)
              for key, picks in channel_indices_by_type(fwd['info']).items()}
    counts['meg'] = counts['mag'] + counts['grad']
    fwd_ = pick_types_forward(fwd, meg=True, eeg=False, seeg=False)
    _check_fwd_n_chan_consistent(fwd_, counts['meg'])
    fwd_ = pick_types_forward(fwd, meg=False, eeg=True, seeg=False)
    _check_fwd_n_chan_consistent(fwd_, counts['eeg'])
    # should raise exception related to emptiness
    assert_raises(ValueError, pick_types_forward, fwd, meg=False, eeg=False,
                  seeg=True)
    # change last chan from EEG to sEEG (rename and retype in one call)
    seeg_name = 'OTp1'
    rename_channels(fwd['info'], {'EEG 060': (seeg_name, 'seeg')})
    # rename_channels only touched the info; keep the solution rows in sync
    # (this assumes 'EEG 060' is the last channel).
    fwd['sol']['row_names'][-1] = fwd['info']['chs'][-1]['ch_name']
    counts['eeg'] -= 1
    counts['seeg'] += 1
    # repick & check
    fwd_seeg = pick_types_forward(fwd, meg=False, eeg=False, seeg=True)
    assert_equal(fwd_seeg['sol']['row_names'], [seeg_name])
    assert_equal(fwd_seeg['info']['ch_names'], [seeg_name])
    # should work fine
    fwd_ = pick_types_forward(fwd, meg=True, eeg=False, seeg=False)
    _check_fwd_n_chan_consistent(fwd_, counts['meg'])
    fwd_ = pick_types_forward(fwd, meg=False, eeg=True, seeg=False)
    _check_fwd_n_chan_consistent(fwd_, counts['eeg'])
    fwd_ = pick_types_forward(fwd, meg=False, eeg=False, seeg=True)
    _check_fwd_n_chan_consistent(fwd_, counts['seeg'])
Example #5
0
def test_pick_seeg_ecog():
    """Test picking with sEEG and ECoG."""
    names = 'A1 A2 Fz O OTp1 OTp2 E1 OTp3 E2 E3'.split()
    types = 'mag mag eeg eeg seeg seeg ecog seeg ecog ecog'.split()
    info = create_info(names, 1024., types)
    idx = channel_indices_by_type(info)
    assert_array_equal(idx['mag'], [0, 1])
    assert_array_equal(idx['eeg'], [2, 3])
    assert_array_equal(idx['seeg'], [4, 5, 7])
    assert_array_equal(idx['ecog'], [6, 8, 9])
    assert_array_equal(pick_types(info, meg=False, seeg=True), [4, 5, 7])
    for ii, ch_type in enumerate(types):
        assert_equal(channel_type(info, ii), ch_type)
    raw = RawArray(np.zeros((len(names), 10)), info)
    events = np.array([[1, 0, 0], [2, 0, 0]])
    epochs = Epochs(raw, events, {'event': 0}, -1e-5, 1e-5)
    evoked = epochs.average(pick_types(epochs.info, meg=True, seeg=True))
    e_seeg = evoked.copy().pick_types(meg=False, seeg=True)
    # Compare the whole list: a zip() loop would silently pass if channels
    # were missing from (or extra in) e_seeg.
    assert_equal(e_seeg.ch_names, [names[4], names[5], names[7]])
    # Deal with constant debacle: a real MEG recording must not report any
    # sEEG or ECoG channels.
    raw = read_raw_fif(
        op.join(io_dir, 'tests', 'data', 'test_chpi_raw_sss.fif'))
    assert_equal(len(pick_types(raw.info, meg=False, seeg=True, ecog=True)), 0)
Example #6
0
def assert_indexing(info, picks_by_type, ref_meg=False, all_data=True):
    """Assert our indexing functions work properly.

    Checks that ``channel_indices_by_type``, ``_picks_by_type`` and the
    legacy ``_picks_by_type_old`` agree with the expected ``picks_by_type``
    pairs, and that marking a channel as bad removes it from the
    ``_picks_by_type`` output.

    Parameters
    ----------
    info : Info
        Measurement info to index.
    picks_by_type : list
        Expected ``(ch_type, picks)`` pairs, in the order ``_picks_by_type``
        is expected to return them.
    ref_meg : bool
        Passed to ``_picks_by_type`` / ``_picks_by_type_old``. When True, the
        ``channel_indices_by_type`` comparison is skipped.
    all_data : bool
        When False, expected pairs whose type is not in
        ``_DATA_CH_TYPES_SPLIT`` are ignored.
    """
    # The old and new channel-typing functions must agree.
    _assert_channel_types(info)

    # channel_indices_by_type: each type listed in picks_by_type must map to
    # the picks of its first matching entry; every other type must be empty.
    if not ref_meg:
        idx = channel_indices_by_type(info)
        for key in idx:
            match = next((pair for pair in picks_by_type if key == pair[0]),
                         None)
            if match is None:
                assert len(idx[key]) == 0
            else:
                assert_array_equal(idx[key], match[1])

    # _picks_by_type must return the (optionally data-only) expectations with
    # picks normalized to integer arrays.
    expected = [(pair[0], np.array(pair[1], int))
                for pair in picks_by_type
                if all_data or pair[0] in _DATA_CH_TYPES_SPLIT]
    assert_object_equal(_picks_by_type(info, ref_meg=ref_meg), expected)

    # The legacy implementation is known to raise a TypeError when hbo
    # channels are present; otherwise its output must match as well.
    if not ref_meg and idx['hbo']:
        with pytest.raises(TypeError, match='unexpected keyword argument'):
            _picks_by_type_old(info, ref_meg=ref_meg)
    else:
        assert_object_equal(_picks_by_type_old(info, ref_meg=ref_meg),
                            expected)

    # Marking the first pick of the first type as bad must drop that pick.
    bad_pick = expected[0][1][0]
    info = info.copy()
    info['bads'] = [info['chs'][bad_pick]['ch_name']]
    expected = deepcopy(expected)
    expected[0] = (expected[0][0], expected[0][1][1:])
    assert_object_equal(_picks_by_type(info, ref_meg=ref_meg), expected)
Example #7
0
def test_pick_forward_seeg():
    """Test picking forward with SEEG."""
    fwd = read_forward_solution(test_forward.fname_meeg)
    # Number of channels of each type in the original forward solution.
    counts = {key: len(picks)
              for key, picks in channel_indices_by_type(fwd['info']).items()}
    counts['meg'] = counts['mag'] + counts['grad']
    fwd_ = pick_types_forward(fwd, meg=True, eeg=False, seeg=False)
    _check_fwd_n_chan_consistent(fwd_, counts['meg'])
    fwd_ = pick_types_forward(fwd, meg=False, eeg=True, seeg=False)
    _check_fwd_n_chan_consistent(fwd_, counts['eeg'])
    # should raise exception related to emptiness
    assert_raises(ValueError, pick_types_forward, fwd, meg=False, eeg=False,
                  seeg=True)
    # change last chan from EEG to sEEG (rename and retype in one call)
    seeg_name = 'OTp1'
    rename_channels(fwd['info'], {'EEG 060': (seeg_name, 'seeg')})
    # rename_channels only touched the info; keep the solution rows in sync
    # (this assumes 'EEG 060' is the last channel).
    fwd['sol']['row_names'][-1] = fwd['info']['chs'][-1]['ch_name']
    counts['eeg'] -= 1
    counts['seeg'] += 1
    # repick & check
    fwd_seeg = pick_types_forward(fwd, meg=False, eeg=False, seeg=True)
    assert_equal(fwd_seeg['sol']['row_names'], [seeg_name])
    assert_equal(fwd_seeg['info']['ch_names'], [seeg_name])
    # should work fine
    fwd_ = pick_types_forward(fwd, meg=True, eeg=False, seeg=False)
    _check_fwd_n_chan_consistent(fwd_, counts['meg'])
    fwd_ = pick_types_forward(fwd, meg=False, eeg=True, seeg=False)
    _check_fwd_n_chan_consistent(fwd_, counts['eeg'])
    fwd_ = pick_types_forward(fwd, meg=False, eeg=False, seeg=True)
    _check_fwd_n_chan_consistent(fwd_, counts['seeg'])


def test_pick_bio():
    """Test picking BIO channels."""
    ch_names = 'A1 A2 Fz O BIO1 BIO2 BIO3'.split()
    ch_types = 'mag mag eeg eeg bio bio bio'.split()
    idx = channel_indices_by_type(create_info(ch_names, 1024., ch_types))
    # Expected channel indices for each type present in the info.
    expected = {'mag': [0, 1], 'eeg': [2, 3], 'bio': [4, 5, 6]}
    for ch_type, picks in expected.items():
        assert_array_equal(idx[ch_type], picks)


def test_pick_fnirs():
    """Test picking fNIRS channels."""
    ch_names = 'A1 A2 Fz O hbo1 hbo2 hbr1'.split()
    ch_types = 'mag mag eeg eeg hbo hbo hbr'.split()
    idx = channel_indices_by_type(create_info(ch_names, 1024., ch_types))
    # Expected channel indices for each type present in the info.
    expected = {'mag': [0, 1], 'eeg': [2, 3], 'hbo': [4, 5], 'hbr': [6]}
    for ch_type, picks in expected.items():
        assert_array_equal(idx[ch_type], picks)
Example #10
0
def test_plot_tfr_topomap():
    """Test plotting of TFR data."""
    import matplotlib as mpl
    import matplotlib.pyplot as plt
    raw = read_raw_fif(raw_fname)
    times = np.linspace(-0.1, 0.1, 200)
    res = 8
    n_freqs = 3
    nave = 1
    rng = np.random.RandomState(42)
    picks = [93, 94, 96, 97, 21, 22, 24, 25, 129, 130, 315, 316, 2, 5, 8, 11]
    info = pick_info(raw.info, picks)
    data = rng.randn(len(picks), n_freqs, len(times))
    tfr = AverageTFR(info, data, times, np.arange(n_freqs), nave)
    tfr.plot_topomap(ch_type='mag',
                     tmin=0.05,
                     tmax=0.150,
                     fmin=0,
                     fmax=10,
                     res=res,
                     contours=0)

    # Simulate a mouse drag on the current figure to exercise the
    # interactive selection callbacks.
    eclick = mpl.backend_bases.MouseEvent('button_press_event',
                                          plt.gcf().canvas, 0, 0, 1)
    eclick.xdata = eclick.ydata = 0.1
    eclick.inaxes = plt.gca()
    erelease = mpl.backend_bases.MouseEvent('button_release_event',
                                            plt.gcf().canvas, 0.9, 0.9, 1)
    erelease.xdata = 0.3
    erelease.ydata = 0.2
    pos = [[0.11, 0.11], [0.25, 0.5], [0.0, 0.2], [0.2, 0.39]]
    _onselect(eclick, erelease, tfr, pos, 'grad', 1, 3, 1, 3, 'RdBu_r', list())
    _onselect(eclick, erelease, tfr, pos, 'mag', 1, 3, 1, 3, 'RdBu_r', list())
    eclick.xdata = eclick.ydata = 0.
    erelease.xdata = erelease.ydata = 0.9
    tfr._onselect(eclick, erelease, None, 'mean', None)
    plt.close('all')

    # test plot_psds_topomap
    info = raw.info.copy()
    chan_inds = channel_indices_by_type(info)
    info = pick_info(info, chan_inds['grad'][:4])

    _, ax = plt.subplots()
    freqs = np.arange(3., 9.5)
    bands = [(4, 8, 'Theta')]
    # Use the seeded generator so the PSD data is reproducible.
    psd = rng.rand(len(info['ch_names']), freqs.shape[0])
    plot_psds_topomap(psd, freqs, info, bands=bands, axes=[ax])
    # Do not leak this figure into subsequent tests.
    plt.close('all')
Example #11
0
def test_load_generator(fname, recwarn):
    """Test IO of annotations from edf and bdf files with raw info."""
    expected_names = ['squarewave', 'ramp', 'pulse', 'ECG', 'noise',
                      'sine 1 Hz', 'sine 8 Hz', 'sine 8.5 Hz', 'sine 15 Hz',
                      'sine 17 Hz', 'sine 50 Hz']
    raw = read_raw_edf(fname)
    assert len(raw.annotations.onset) == 2
    # Exactly one channel type may own any channels in this file.
    indices_by_type = channel_indices_by_type(raw.info, picks=None)
    present_types = [ch_type for ch_type, picks in indices_by_type.items()
                     if picks]
    assert len(present_types) == 1
    events, event_id = events_from_annotations(raw)
    assert raw.get_data().shape == (len(expected_names), 120000)
    assert raw.ch_names == expected_names
    assert event_id == {'RECORD START': 1, 'REC STOP': 2}
    assert_array_equal(events, [[0, 0, 1], [120000, 0, 2]])
Example #12
0
def test_load_generator(fname, recwarn):
    """Test IO of annotations from edf and bdf files with raw info."""
    expected_names = ['squarewave', 'ramp', 'pulse', 'ECG', 'noise',
                      'sine 1 Hz', 'sine 8 Hz', 'sine 8.5 Hz', 'sine 15 Hz',
                      'sine 17 Hz', 'sine 50 Hz']
    raw = read_raw_edf(fname)
    assert len(raw.annotations.onset) == 2
    # Exactly one channel type may own any channels in this file.
    indices_by_type = channel_indices_by_type(raw.info, picks=None)
    present_types = [ch_type for ch_type, picks in indices_by_type.items()
                     if picks]
    assert len(present_types) == 1
    events, event_id = events_from_annotations(raw)
    assert raw.get_data().shape == (len(expected_names), 120000)
    assert raw.ch_names == expected_names
    assert event_id == {'RECORD START': 1, 'REC STOP': 2}
    assert_array_equal(events, [[0, 0, 1], [120000, 0, 2]])
Example #13
0
def _get_channel_type(epochs):
    """Return the single sensor family (``'eeg'`` or ``'meg'``) in ``epochs``.

    Returns
    -------
    ch_type : str | None
        ``'eeg'`` or ``'meg'``; None if neither is present.

    Raises
    ------
    ValueError
        If ``epochs`` contains a channel type other than mag, grad or eeg,
        or contains both MEG and EEG channels.
    """
    allowed = ('mag', 'grad', 'eeg')
    unexpected = [ch_type for ch_type in channel_indices_by_type(epochs.info)
                  if ch_type not in allowed and ch_type in epochs]
    if unexpected:
        raise ValueError('Invalid channel types present in epochs.'
                         ' Expected ONLY `meg` or ONLY `eeg`. Got %s' %
                         ', '.join(unexpected))

    has_eeg = 'eeg' in epochs
    has_meg = 'meg' in epochs
    if has_meg and has_eeg:
        raise ValueError('Got mixed channel types. Pick either eeg or meg'
                         ' but not both')
    if has_eeg:
        return 'eeg'
    if has_meg:
        return 'meg'
    return None
Example #14
0
def test_pick_forward_seeg_ecog():
    """Test picking forward with SEEG and ECoG."""
    fwd = read_forward_solution(fname_meeg)
    # Number of channels of each type in the original forward solution.
    counts = {key: len(picks)
              for key, picks in channel_indices_by_type(fwd['info']).items()}
    counts['meg'] = counts['mag'] + counts['grad']
    fwd_ = pick_types_forward(fwd, meg=True)
    _check_fwd_n_chan_consistent(fwd_, counts['meg'])
    fwd_ = pick_types_forward(fwd, meg=False, eeg=True)
    _check_fwd_n_chan_consistent(fwd_, counts['eeg'])
    # should raise exception related to emptiness
    assert_raises(ValueError, pick_types_forward, fwd, meg=False, seeg=True)
    assert_raises(ValueError, pick_types_forward, fwd, meg=False, ecog=True)
    # change last chan from EEG to sEEG, second-to-last to ECoG
    ecog_name = 'E1'
    seeg_name = 'OTp1'
    rename_channels(fwd['info'], {'EEG 059': ecog_name})
    rename_channels(fwd['info'], {'EEG 060': seeg_name})
    for ch in fwd['info']['chs']:
        if ch['ch_name'] == seeg_name:
            ch['kind'] = FIFF.FIFFV_SEEG_CH
            ch['coil_type'] = FIFF.FIFFV_COIL_EEG
        elif ch['ch_name'] == ecog_name:
            ch['kind'] = FIFF.FIFFV_ECOG_CH
            ch['coil_type'] = FIFF.FIFFV_COIL_EEG
    # rename_channels only touched the info; keep the solution rows in sync.
    fwd['sol']['row_names'][-1] = fwd['info']['chs'][-1]['ch_name']
    fwd['sol']['row_names'][-2] = fwd['info']['chs'][-2]['ch_name']
    counts['eeg'] -= 2
    counts['seeg'] += 1
    counts['ecog'] += 1
    # repick & check: each new type must select exactly its renamed channel
    fwd_seeg = pick_types_forward(fwd, meg=False, seeg=True)
    assert_equal(fwd_seeg['sol']['row_names'], [seeg_name])
    assert_equal(fwd_seeg['info']['ch_names'], [seeg_name])
    fwd_ecog = pick_types_forward(fwd, meg=False, ecog=True)
    assert_equal(fwd_ecog['sol']['row_names'], [ecog_name])
    assert_equal(fwd_ecog['info']['ch_names'], [ecog_name])
    # should work fine
    fwd_ = pick_types_forward(fwd, meg=True)
    _check_fwd_n_chan_consistent(fwd_, counts['meg'])
    fwd_ = pick_types_forward(fwd, meg=False, eeg=True)
    _check_fwd_n_chan_consistent(fwd_, counts['eeg'])
    fwd_ = pick_types_forward(fwd, meg=False, seeg=True)
    _check_fwd_n_chan_consistent(fwd_, counts['seeg'])
    fwd_ = pick_types_forward(fwd, meg=False, ecog=True)
    _check_fwd_n_chan_consistent(fwd_, counts['ecog'])
Example #15
0
def test_pick_forward_seeg_ecog():
    """Test picking forward with SEEG and ECoG."""
    fwd = read_forward_solution(fname_meeg)
    # Number of channels of each type in the original forward solution.
    counts = {key: len(picks)
              for key, picks in channel_indices_by_type(fwd['info']).items()}
    counts['meg'] = counts['mag'] + counts['grad']
    fwd_ = pick_types_forward(fwd, meg=True)
    _check_fwd_n_chan_consistent(fwd_, counts['meg'])
    fwd_ = pick_types_forward(fwd, meg=False, eeg=True)
    _check_fwd_n_chan_consistent(fwd_, counts['eeg'])
    # should raise exception related to emptiness
    assert_raises(ValueError, pick_types_forward, fwd, meg=False, seeg=True)
    assert_raises(ValueError, pick_types_forward, fwd, meg=False, ecog=True)
    # change last chan from EEG to sEEG, second-to-last to ECoG
    ecog_name = 'E1'
    seeg_name = 'OTp1'
    rename_channels(fwd['info'], {'EEG 059': ecog_name})
    rename_channels(fwd['info'], {'EEG 060': seeg_name})
    for ch in fwd['info']['chs']:
        if ch['ch_name'] == seeg_name:
            ch['kind'] = FIFF.FIFFV_SEEG_CH
            ch['coil_type'] = FIFF.FIFFV_COIL_EEG
        elif ch['ch_name'] == ecog_name:
            ch['kind'] = FIFF.FIFFV_ECOG_CH
            ch['coil_type'] = FIFF.FIFFV_COIL_EEG
    # rename_channels only touched the info; keep the solution rows in sync.
    fwd['sol']['row_names'][-1] = fwd['info']['chs'][-1]['ch_name']
    fwd['sol']['row_names'][-2] = fwd['info']['chs'][-2]['ch_name']
    counts['eeg'] -= 2
    counts['seeg'] += 1
    counts['ecog'] += 1
    # repick & check: each new type must select exactly its renamed channel
    fwd_seeg = pick_types_forward(fwd, meg=False, seeg=True)
    assert_equal(fwd_seeg['sol']['row_names'], [seeg_name])
    assert_equal(fwd_seeg['info']['ch_names'], [seeg_name])
    fwd_ecog = pick_types_forward(fwd, meg=False, ecog=True)
    assert_equal(fwd_ecog['sol']['row_names'], [ecog_name])
    assert_equal(fwd_ecog['info']['ch_names'], [ecog_name])
    # should work fine
    fwd_ = pick_types_forward(fwd, meg=True)
    _check_fwd_n_chan_consistent(fwd_, counts['meg'])
    fwd_ = pick_types_forward(fwd, meg=False, eeg=True)
    _check_fwd_n_chan_consistent(fwd_, counts['eeg'])
    fwd_ = pick_types_forward(fwd, meg=False, seeg=True)
    _check_fwd_n_chan_consistent(fwd_, counts['seeg'])
    fwd_ = pick_types_forward(fwd, meg=False, ecog=True)
    _check_fwd_n_chan_consistent(fwd_, counts['ecog'])
Example #16
0
def test_pick_seeg():
    """Test picking with SEEG."""
    names = 'A1 A2 Fz O OTp1 OTp2 OTp3'.split()
    types = 'mag mag eeg eeg seeg seeg seeg'.split()
    info = create_info(names, 1024., types)
    idx = channel_indices_by_type(info)
    assert_array_equal(idx['mag'], [0, 1])
    assert_array_equal(idx['eeg'], [2, 3])
    assert_array_equal(idx['seeg'], [4, 5, 6])
    assert_array_equal(pick_types(info, meg=False, seeg=True), [4, 5, 6])
    for ii, ch_type in enumerate(types):
        assert_equal(channel_type(info, ii), ch_type)
    raw = RawArray(zeros((len(names), 10)), info)
    events = array([[1, 0, 0], [2, 0, 0]]).astype('d')
    epochs = Epochs(raw, events, {'event': 0}, -1e-5, 1e-5)
    evoked = epochs.average(pick_types(epochs.info, meg=True, seeg=True))
    e_seeg = pick_types_evoked(evoked, meg=False, seeg=True)
    # Compare the whole list: a zip() loop would silently pass if channels
    # were missing from (or extra in) e_seeg.
    assert_equal(e_seeg.ch_names, names[4:])
Example #17
0
def test_pick_seeg():
    """Test picking with SEEG."""
    names = 'A1 A2 Fz O OTp1 OTp2 OTp3'.split()
    types = 'mag mag eeg eeg seeg seeg seeg'.split()
    info = create_info(names, 1024., types)
    idx = channel_indices_by_type(info)
    assert_array_equal(idx['mag'], [0, 1])
    assert_array_equal(idx['eeg'], [2, 3])
    assert_array_equal(idx['seeg'], [4, 5, 6])
    assert_array_equal(pick_types(info, meg=False, seeg=True), [4, 5, 6])
    for ii, ch_type in enumerate(types):
        assert_equal(channel_type(info, ii), ch_type)
    raw = RawArray(zeros((len(names), 10)), info)
    events = array([[1, 0, 0], [2, 0, 0]]).astype('d')
    epochs = Epochs(raw, events, {'event': 0}, -1e-5, 1e-5)
    evoked = epochs.average(pick_types(epochs.info, meg=True, seeg=True))
    e_seeg = pick_types_evoked(evoked, meg=False, seeg=True)
    # Compare the whole list: a zip() loop would silently pass if channels
    # were missing from (or extra in) e_seeg.
    assert_equal(e_seeg.ch_names, names[4:])
Example #18
0
def test_plot_tfr_topomap():
    """Test plotting of TFR data."""
    import matplotlib as mpl
    import matplotlib.pyplot as plt
    raw = read_raw_fif(raw_fname)
    times = np.linspace(-0.1, 0.1, 200)
    res = 8
    n_freqs = 3
    nave = 1
    rng = np.random.RandomState(42)
    picks = [93, 94, 96, 97, 21, 22, 24, 25, 129, 130, 315, 316, 2, 5, 8, 11]
    info = pick_info(raw.info, picks)
    data = rng.randn(len(picks), n_freqs, len(times))
    tfr = AverageTFR(info, data, times, np.arange(n_freqs), nave)
    tfr.plot_topomap(ch_type='mag', tmin=0.05, tmax=0.150, fmin=0, fmax=10,
                     res=res, contours=0)

    # Simulate a mouse drag on the current figure to exercise the
    # interactive selection callbacks.
    eclick = mpl.backend_bases.MouseEvent('button_press_event',
                                          plt.gcf().canvas, 0, 0, 1)
    eclick.xdata = eclick.ydata = 0.1
    eclick.inaxes = plt.gca()
    erelease = mpl.backend_bases.MouseEvent('button_release_event',
                                            plt.gcf().canvas, 0.9, 0.9, 1)
    erelease.xdata = 0.3
    erelease.ydata = 0.2
    pos = [[0.11, 0.11], [0.25, 0.5], [0.0, 0.2], [0.2, 0.39]]
    _onselect(eclick, erelease, tfr, pos, 'grad', 1, 3, 1, 3, 'RdBu_r', list())
    _onselect(eclick, erelease, tfr, pos, 'mag', 1, 3, 1, 3, 'RdBu_r', list())
    eclick.xdata = eclick.ydata = 0.
    erelease.xdata = erelease.ydata = 0.9
    tfr._onselect(eclick, erelease, None, 'mean', None)
    plt.close('all')

    # test plot_psds_topomap
    info = raw.info.copy()
    chan_inds = channel_indices_by_type(info)
    info = pick_info(info, chan_inds['grad'][:4])

    _, ax = plt.subplots()
    freqs = np.arange(3., 9.5)
    bands = [(4, 8, 'Theta')]
    # Use the seeded generator so the PSD data is reproducible.
    psd = rng.rand(len(info['ch_names']), freqs.shape[0])
    plot_psds_topomap(psd, freqs, info, bands=bands, axes=[ax])
    # Do not leak this figure into subsequent tests.
    plt.close('all')
Example #19
0
def test_pick_seeg():
    """Test picking with SEEG."""
    names = "A1 A2 Fz O OTp1 OTp2 OTp3".split()
    types = "mag mag eeg eeg seeg seeg seeg".split()
    info = create_info(names, 1024.0, types)
    idx = channel_indices_by_type(info)
    assert_array_equal(idx["mag"], [0, 1])
    assert_array_equal(idx["eeg"], [2, 3])
    assert_array_equal(idx["seeg"], [4, 5, 6])
    assert_array_equal(pick_types(info, meg=False, seeg=True), [4, 5, 6])
    for ii, ch_type in enumerate(types):
        assert_equal(channel_type(info, ii), ch_type)
    raw = RawArray(np.zeros((len(names), 10)), info)
    events = np.array([[1, 0, 0], [2, 0, 0]]).astype("d")
    epochs = Epochs(raw, events, {"event": 0}, -1e-5, 1e-5)
    evoked = epochs.average(pick_types(epochs.info, meg=True, seeg=True))
    e_seeg = evoked.pick_types(meg=False, seeg=True, copy=True)
    # Compare the whole list: a zip() loop would silently pass if channels
    # were missing from (or extra in) e_seeg.
    assert_equal(e_seeg.ch_names, names[4:])
Example #20
0
def write_info(fname, info, overwrite=False):
    """Save Info object to ``.hdf5`` file.

    The file stores the channel names, the sampling frequency, the channel
    positions (as returned by ``get_ch_pos``) and the channel type(s): a
    single type string when only one channel type is present, otherwise a
    dict mapping each present type to its channel indices.

    Parameters
    ----------
    fname : str
        Name of the file.
    info : mne.Info
        Info object to save.
    overwrite : bool
        Passed through to ``h5io.write_hdf5``; presumably controls whether an
        existing ``fname`` may be replaced. Defaults to False.
    """
    from .channels import get_ch_pos
    from mne.utils import _validate_type
    from mne.externals import h5io
    from mne.io.pick import channel_indices_by_type

    # make sure the types are correct
    _validate_type(fname, 'str', item_name='fname')
    _validate_type(info, 'info', item_name='info')

    # Map each channel type to its indices, keeping only types that are
    # actually present in ``info``.
    tps = {ch_type: picks
           for ch_type, picks in channel_indices_by_type(info).items()
           if len(picks) > 0}

    # A single type is stored as a plain string; several as the full mapping.
    has_types = list(tps)
    ch_type = has_types[0] if len(has_types) == 1 else tps

    # save to .hdf5
    data_dict = {
        'ch_names': info['ch_names'],
        'sfreq': info['sfreq'],
        'ch_type': ch_type,
        'pos': get_ch_pos(info)
    }
    h5io.write_hdf5(fname, data_dict, overwrite=overwrite)
Example #21
0
def test_pick_seeg():
    """Test picking with SEEG."""
    names = 'A1 A2 Fz O OTp1 OTp2 OTp3'.split()
    types = 'mag mag eeg eeg seeg seeg seeg'.split()
    info = create_info(names, 1024., types)
    idx = channel_indices_by_type(info)
    assert_array_equal(idx['mag'], [0, 1])
    assert_array_equal(idx['eeg'], [2, 3])
    assert_array_equal(idx['seeg'], [4, 5, 6])
    assert_array_equal(pick_types(info, meg=False, seeg=True), [4, 5, 6])
    for ii, ch_type in enumerate(types):
        assert_equal(channel_type(info, ii), ch_type)
    raw = RawArray(np.zeros((len(names), 10)), info)
    events = np.array([[1, 0, 0], [2, 0, 0]])
    epochs = Epochs(raw, events, {'event': 0}, -1e-5, 1e-5)
    evoked = epochs.average(pick_types(epochs.info, meg=True, seeg=True))
    e_seeg = evoked.pick_types(meg=False, seeg=True, copy=True)
    # Compare the whole list: a zip() loop would silently pass if channels
    # were missing from (or extra in) e_seeg.
    assert_equal(e_seeg.ch_names, names[4:])
    # Deal with constant debacle: a real MEG recording must not report any
    # sEEG channels.
    raw = Raw(fname_mc)
    assert_equal(len(pick_types(raw.info, meg=False, seeg=True)), 0)
Example #22
0
def noise_reducer(fname_raw, raw=None, signals=None, noiseref=None, detrending=None,
                  tmin=None, tmax=None, reflp=None, refhp=None, refnotch=None,
                  exclude_artifacts=True, checkresults=True, return_raw=False,
                  complementary_signal=False, fnout=None, verbose=False):
    """Apply noise reduction to signal channels using reference channels.

    For each input file, regression weights of the signal channels on the
    (optionally filtered) reference channels are estimated from the samples
    in [tmin, tmax]. The weighted reference signal (mean-removed) is then
    subtracted from the entire recording.

    Parameters
    ----------
    fname_raw : str | list of str
        Raw file name(s), conventionally ending in '-raw.fif'.
    raw : mne.io.Raw | None
        Allows passing an already loaded raw object as well. Cannot be
        combined with a list of file names or with ``detrending``.
    signals : list of str | None
        Channels to compensate using ``noiseref``. If None or empty, the
        magnetometer channels (bads excluded) are used.
    noiseref : list of str | str | None
        Channels to use as noise reference. If None or empty, the magnetic
        reference channels (bads excluded) are used (default).
        ``signals`` and ``noiseref`` may contain regexps, which are resolved
        by ``channel_indices_from_list``. All other channels are copied.
    detrending : bool | None
        If true, subtract the linear trend (via ``perform_detrending``)
        before processing. Not supported together with ``raw``. [None]
    tmin : float | None
        Lower latency bound (s) for weight calculation [start of trace].
    tmax : float | None
        Upper latency bound (s) for weight calculation [end of trace].
        Weights are calculated for (tmin, tmax) but applied to the entire
        data set.
    reflp : float | None
        Low-pass frequency for the reference signal filter [None].
    refhp : float | None
        High-pass frequency for the reference signal filter [None].
        reflp < refhp: band-stop filter
        reflp > refhp: band-pass filter
        reflp is not None, refhp is None: low-pass filter
        reflp is None, refhp is not None: high-pass filter
    refnotch : float | None
        (Base) notch frequency for the reference signal filter [None].
        raw(ref) - notched(ref) is used as the reference signal. Cannot be
        combined with ``reflp`` or ``refhp``.
    exclude_artifacts : bool
        Skip segments rejected by ``_is_good()`` when estimating the
        covariances [True]. The rejection thresholds are hard-coded.
    checkresults : bool
        Run internal consistency checks and report the signal power before
        and after compensation [True].
    return_raw : bool
        If True, the compensated raw object is returned and no file is
        written to disk. Suggested when noise_reducer is applied multiple
        times. Cannot be combined with a list of file names. [False]
    complementary_signal : bool
        Replace the data by the traces that would be subtracted instead of
        subtracting them; useful for debugging [False].
    fnout : str | None
        Output file name. If None, '<stem>,nr-raw.fif' is derived from the
        input '<stem>-raw.fif'. With several input files, every file is
        written to this same name.
    verbose : bool
        Print progress and diagnostics [False].

    Output file
    -----------
    <wawa>,nr-raw.fif for input <wawa>-raw.fif (unless ``fnout`` is given).

    Returns
    -------
    raw : mne.io.Raw | None
        The compensated raw object if ``return_raw`` is True, else None.

    Raises
    ------
    ValueError
        On invalid argument combinations, empty channel selections, an empty
        or too short time window, or too few usable samples.

    Bugs
    ----
    - artifact checking is incomplete (and with arb. window of tstep=0.2s)
    - no accounting of channels used as signal/reference
    - non existing input file handled ungracefully
    """
    if not isinstance(complementary_signal, bool):
        raise ValueError("Argument complementary_signal must be of type bool")

    # A single Raw object can only belong to a single file.
    if raw is not None and isinstance(fname_raw, list):
        raise ValueError('List of file names cannot be combined with one Raw object')

    # Only one object can be returned, so return_raw needs a single file.
    if return_raw and isinstance(fname_raw, list):
        raise ValueError('List of file names cannot be combined with return_raw. '
                         'Please pass one file at a time.')

    # TODO include perform_detrending for Raw objects
    if raw is not None and detrending:
        raise ValueError('Please perform detrending on the raw file directly. Cannot perform '
                         'detrending on the raw object')

    # time.clock() was removed in Python 3.8; fall back to it only where
    # time.process_time() does not exist (Python 2).
    cpu_clock = getattr(time, 'process_time', None) or time.clock

    # Rejection thresholds, only used in _is_good() calls. _is_good() seems
    # to ignore ref-channels (not covered by this dict) and checks individual
    # data segments, so artifacts across a buffer boundary are not found.
    reject = dict(grad=4000e-13,  # T / m (gradiometers)
                  mag=4e-12,      # T (magnetometers)
                  eeg=40e-6,      # V, i.e. 40 uV (EEG channels)
                  eog=250e-6)     # V, i.e. 250 uV (EOG channels)
    tstep = 0.2                   # segment length (s) for artifact detection
    progress_step = 10000         # samples per compensation chunk/progress line

    raw_in = raw
    fnraw = get_files_from_list(fname_raw)

    # loop across all filenames
    for fname in fnraw:

        if verbose:
            print("########## Read raw data:")

        tc0 = cpu_clock()
        tw0 = time.time()

        if raw_in is None:
            # Load every file afresh. Reusing the object from a previous
            # iteration would process that (already compensated) data again.
            if detrending:
                raw = perform_detrending(fname, save=False)
            else:
                raw = mne.io.Raw(fname, preload=True)
        else:
            raw = raw_in
            # sanity check: Raw object and file name should refer to the same file
            raw_fname = raw.info.get('filename')
            if raw_fname is None or \
               os.path.basename(fname) != os.path.basename(raw_fname):
                warnings.warn('The file name within the Raw object and provided '
                              'fname are not the same. Please check again.')

        tc1 = cpu_clock()
        tw1 = time.time()

        if verbose:
            print(">>> loading raw data took %.1f ms (%.2f s walltime)"
                  % (1000. * (tc1 - tc0), (tw1 - tw0)))

        # Time window selection.
        # Weights are calculated on [tmin, tmax] but applied to the entire
        # data set; tmin/tmax themselves must not be changed here.
        # raw[...] and raw.times index the loaded data from 0, so the last
        # valid index is last_samp - first_samp (not last_samp).
        sfreq = raw.info['sfreq']
        ilast = raw.last_samp - raw.first_samp
        itmin = 0 if tmin is None else int(floor(tmin * sfreq))
        itmax = ilast if tmax is None else min(int(ceil(tmax * sfreq)), ilast)

        if itmax - itmin < 2:
            raise ValueError("Time-window for noise compensation empty or too short")

        if verbose:
            print(">>> Set time-range to [%7.3f,%7.3f]"
                  % (raw.times[itmin], raw.times[itmax]))

        if signals is None or len(signals) == 0:
            sigpick = mne.pick_types(raw.info, meg='mag', eeg=False, stim=False,
                                     eog=False, exclude='bads')
        else:
            sigpick = channel_indices_from_list(raw.info['ch_names'][:], signals,
                                                raw.info.get('bads'))
        nsig = len(sigpick)
        if nsig == 0:
            raise ValueError("No channel selected for noise compensation")

        if noiseref is None or len(noiseref) == 0:
            # References are not limited to 4D ref-chans, but can be anything,
            # incl. ECG or powerline monitor.
            if verbose:
                print(">>> Using all refchans.")
            refpick = mne.pick_types(raw.info, ref_meg=True, meg=False, eeg=False,
                                     stim=False, eog=False, exclude='bads')
        else:
            refpick = channel_indices_from_list(raw.info['ch_names'][:], noiseref,
                                                raw.info.get('bads'))
        nref = len(refpick)
        if nref == 0:
            raise ValueError("No channel selected as noise reference")

        if verbose:
            print(">>> sigpick: %3d chans, refpick: %3d chans" % (nsig, nref))

        use_reffilter = not (reflp is None and refhp is None and refnotch is None)
        use_refantinotch = False
        if use_reffilter:
            if verbose:
                print("########## Filter reference channels:")

            if refnotch is not None:
                # A notch cannot be combined with a low- OR high-pass filter.
                if reflp is not None or refhp is not None:
                    raise ValueError("Cannot specify notch- and high-/low-pass "
                                     "reference filter together")
                use_refantinotch = True
                freqlast = np.min([5.01 * refnotch, 0.5 * sfreq])
                if verbose:
                    print(">>> notches at freq %.1f and harmonics below %.1f"
                          % (refnotch, freqlast))
            elif verbose:
                if reflp is not None:
                    print(">>>  low-pass with cutoff-freq %.1f" % reflp)
                if refhp is not None:
                    print(">>> high-pass with cutoff-freq %.1f" % refhp)

            # Keep only the reference channels (drop all-but-refpick).
            refset = set(refpick)
            droplist = [raw.info['ch_names'][k] for k in range(raw.info['nchan'])
                        if k not in refset]
            tct = cpu_clock()
            twt = time.time()
            fltref = raw.copy().drop_channels(droplist)
            refidx = np.arange(nref)
            if use_refantinotch:
                # Reference signal = raw(ref) - notched(ref), i.e. only the
                # line-frequency components and their harmonics.
                rawref_data = fltref._data.copy()
                fltref.notch_filter(np.arange(refnotch, freqlast, refnotch),
                                    picks=refidx, method='iir')
                fltref._data = (rawref_data - fltref._data)
            else:
                fltref.filter(refhp, reflp, picks=refidx, method='iir')
            tc1 = cpu_clock()
            tw1 = time.time()
            if verbose:
                print(">>> filtering ref-chans  took %.1f ms (%.2f s walltime)"
                      % (1000. * (tc1 - tct), (tw1 - twt)))

        if verbose:
            print("########## Calculating sig-ref/ref-ref-channel covariances:")
        # Calculate sig-ref/ref-ref-channel covariance (the idea is copied
        # from compute_raw_data_covariance() and truncated as appropriate;
        # inter-signal-channel covariance is only needed on the diagonal).
        tct = cpu_clock()
        twt = time.time()

        # Info restricted to the signal channels, only used by _is_good().
        infosig = copy.copy(raw.info)
        infosig['chs'] = [raw.info['chs'][k] for k in sigpick]
        infosig['ch_names'] = [raw.info['ch_names'][k] for k in sigpick]
        infosig['nchan'] = len(sigpick)
        idx_by_typesig = channel_indices_by_type(infosig)

        # Read data in chunks:
        itstep = int(ceil(tstep * sfreq))
        sigmean = 0
        refmean = 0
        sscovdata = 0
        srcovdata = 0
        rrcovdata = 0
        n_samples = 0

        for first in range(itmin, itmax, itstep):
            last = min(first + itstep, itmax)
            raw_segmentsig, times = raw[sigpick, first:last]
            if use_reffilter:
                raw_segmentref, times = fltref[:, first:last]
            else:
                raw_segmentref, times = raw[refpick, first:last]

            if not exclude_artifacts or \
               _is_good(raw_segmentsig, infosig['ch_names'], idx_by_typesig, reject, flat=None,
                        ignore_chs=raw.info['bads']):
                sigmean += raw_segmentsig.sum(axis=1)
                refmean += raw_segmentref.sum(axis=1)
                sscovdata += (raw_segmentsig * raw_segmentsig).sum(axis=1)
                srcovdata += np.dot(raw_segmentsig, raw_segmentref.T)
                rrcovdata += np.dot(raw_segmentref, raw_segmentref.T)
                n_samples += raw_segmentsig.shape[1]
            else:
                logger.info("Artefact detected in [%d, %d]" % (first, last))
        if n_samples <= 1:
            raise ValueError('Too few samples to calculate weights')
        # Turn the raw sums into unbiased (co)variances.
        sigmean /= n_samples
        refmean /= n_samples
        sscovdata -= n_samples * sigmean * sigmean
        sscovdata /= (n_samples - 1)
        srcovdata -= n_samples * sigmean[:, None] * refmean[None, :]
        srcovdata /= (n_samples - 1)
        rrcovdata -= n_samples * refmean[:, None] * refmean[None, :]
        rrcovdata /= (n_samples - 1)
        sscovinit = np.copy(sscovdata)
        if verbose:
            print(">>> Normalize srcov...")

        # Normalize each reference column by that channel's variance;
        # (near-)constant references (variance <= TINY) are zeroed out.
        rrslope = copy.copy(rrcovdata)
        for iref in range(nref):
            dtmp = rrcovdata[iref, iref]
            if dtmp > TINY:
                srcovdata[:, iref] /= dtmp
                rrslope[:, iref] /= dtmp
            else:
                srcovdata[:, iref] = 0.
                rrslope[:, iref] = 0.

        if verbose:
            print(">>> Number of samples used : %d" % n_samples)
            tc1 = cpu_clock()
            tw1 = time.time()
            print(">>> sigrefchn covar-calc took %.1f ms (%.2f s walltime)"
                  % (1000. * (tc1 - tct), (tw1 - twt)))

        if checkresults and verbose:
            print("########## Calculated initial signal channel covariance:")
            # Initial signal channel power (only used as quality measure).
            print(">>> initl rt(avg sig pwr) = %12.5e" % np.sqrt(np.mean(sscovdata)))
            for i in range(min(5, nsig)):
                print(">>> initl signal-rms[%3d] = %12.5e"
                      % (i, np.sqrt(sscovdata.flatten()[i])))
            print(">>>")

        U, s, V = np.linalg.svd(rrslope, full_matrices=True)
        if verbose:
            print(">>> singular values:")
            print(s)
            print(">>> Applying cutoff for smallest SVs:")

        # Truncated SVD: discard singular values below the relative cutoff.
        s *= (abs(s) >= s.max() * SVD_RELCUTOFF)
        sinv = [1. / sval if sval != 0. else 0. for sval in s]
        if verbose:
            print(">>> singular values (after cutoff):")
            print(s)

        stat = np.allclose(rrslope, np.dot(U, np.dot(np.diag(s), V)))
        if verbose:
            print(">>> Testing svd-result: %s" % stat)
            if not stat:
                print("    (Maybe due to SV-cutoff?)")

        # Solve for inverse coefficients: RRinv = (U diag(sinv) V)^T,
        # the (pseudo-)inverse of rrslope.
        RRinv = np.transpose(np.dot(U, np.dot(np.diag(sinv), V)))
        if checkresults:
            stat = np.allclose(np.identity(nref), np.dot(RRinv, rrslope))
            if stat:
                if verbose:
                    print(">>> Testing RRinv-result (should be unit-matrix): ok")
            else:
                # Reported regardless of verbose: the weights may be unreliable.
                print(">>> Testing RRinv-result (should be unit-matrix): failed")
                print(np.transpose(np.dot(RRinv, rrslope)))
                print(">>>")

        if verbose:
            print("########## Calc weight matrix...")

        # One weight row per channel of raw (zero for non-signal channels)
        # so the compensation can be applied to the full data array at once.
        weights = np.zeros((raw._data.shape[0], nref))
        weights[sigpick, :] = np.dot(srcovdata, RRinv)

        if verbose:
            print("########## Compensating signal channels:")
            if complementary_signal:
                print(">>> Caveat: REPLACING signal by compensation signal")

        tct = cpu_clock()
        twt = time.time()

        # Work on the entire data stream in chunks. The reference block is
        # copied (fancy indexing / arithmetic) before raw._data is modified.
        n_times = raw._data.shape[1]
        for start in range(0, n_times, progress_step):
            stop = min(start + progress_step, n_times)
            if use_reffilter:
                refblock = fltref._data[:, start:stop]
            else:
                refblock = raw._data[refpick, start:stop]
            subref = np.dot(weights, refblock - refmean[:, None])

            if complementary_signal:
                raw._data[:, start:stop] = subref
            else:
                raw._data[:, start:stop] -= subref

            if verbose:
                print("\rProcessed slice %6d" % start)

        if verbose:
            print("\nDone.")
            tc1 = cpu_clock()
            tw1 = time.time()
            print(">>> compensation loop took %.1f ms (%.2f s walltime)"
                  % (1000. * (tc1 - tct), (tw1 - twt)))

        if checkresults:
            if verbose:
                print("########## Calculating final signal channel covariance:")
            # Final signal channel power (only used as quality measure).
            tct = cpu_clock()
            twt = time.time()
            sigmean = 0
            sscovdata = 0
            n_samples = 0
            for first in range(itmin, itmax, itstep):
                last = min(first + itstep, itmax)
                raw_segmentsig, times = raw[sigpick, first:last]
                # Artifacts found here will probably differ from pre-noisered artifacts!
                if not exclude_artifacts or \
                   _is_good(raw_segmentsig, infosig['ch_names'], idx_by_typesig, reject,
                            flat=None, ignore_chs=raw.info['bads']):
                    sigmean += raw_segmentsig.sum(axis=1)
                    sscovdata += (raw_segmentsig * raw_segmentsig).sum(axis=1)
                    n_samples += raw_segmentsig.shape[1]
            if n_samples <= 1:
                # Do not abort here: the compensated data is still valid and
                # is saved/returned below.
                warnings.warn('Too few samples to calculate the final signal '
                              'channel covariance; skipping quality check.')
            else:
                sigmean /= n_samples
                sscovdata -= n_samples * sigmean * sigmean
                sscovdata /= (n_samples - 1)
                if verbose:
                    print(">>> no channel got worse:  %s"
                          % np.all(np.less_equal(sscovdata, sscovinit)))
                    print(">>> final rt(avg sig pwr) = %12.5e" % np.sqrt(np.mean(sscovdata)))
                    for i in range(min(5, nsig)):
                        print(">>> final signal-rms[%3d] = %12.5e"
                              % (i, np.sqrt(sscovdata.flatten()[i])))
                    tc1 = cpu_clock()
                    tw1 = time.time()
                    print(">>> signal covar-calc took %.1f ms (%.2f s walltime)"
                          % (1000. * (tc1 - tct), (tw1 - twt)))
                    print(">>>")

        if return_raw:
            if verbose:
                tc1 = cpu_clock()
                tw1 = time.time()
                print(">>> Total run took %.1f ms (%.2f s walltime)"
                      % (1000. * (tc1 - tc0), (tw1 - tw0)))
            return raw

        if fnout is not None:
            fnoutloc = fnout
        else:
            # '<stem>-raw.fif' -> '<stem>,nr-raw.fif'; fall back to stripping
            # the extension instead of chopping a character when the
            # conventional suffix is missing.
            isuffix = fname.rfind('-raw.fif')
            stem = fname[:isuffix] if isuffix >= 0 else os.path.splitext(fname)[0]
            fnoutloc = stem + ',nr-raw.fif'

        if verbose:
            print(">>> Saving '%s'..." % fnoutloc)
        raw.save(fnoutloc, overwrite=True)

        tc1 = cpu_clock()
        tw1 = time.time()
        if verbose:
            print(">>> Total run took %.1f ms (%.2f s walltime)"
                  % (1000. * (tc1 - tc0), (tw1 - tw0)))
def test_noise_reducer():

    data_path = os.environ['SUBJECTS_DIR']
    subject = os.environ['SUBJECT']

    dname = data_path + '/' + 'empty_room_files' + '/109925_empty_room_file-raw.fif'
    subjects_dir = data_path + '/subjects'
    #
    checkresults = True
    exclart = False
    use_reffilter = True
    refflt_lpfreq = 52.
    refflt_hpfreq = 48.

    print "########## before of noisereducer call ##########"
    sigchanlist = ['MEG ..1', 'MEG ..3', 'MEG ..5', 'MEG ..7', 'MEG ..9']
    sigchanlist = None
    refchanlist = ['RFM 001', 'RFM 003', 'RFM 005', 'RFG ...']
    tmin = 15.
    noise_reducer(dname,
                  signals=sigchanlist,
                  noiseref=refchanlist,
                  tmin=tmin,
                  reflp=refflt_lpfreq,
                  refhp=refflt_hpfreq,
                  exclude_artifacts=exclart,
                  complementary_signal=True)
    print "########## behind of noisereducer call ##########"

    print "########## Read raw data:"
    tc0 = time.clock()
    tw0 = time.time()
    raw = mne.io.Raw(dname, preload=True)
    tc1 = time.clock()
    tw1 = time.time()
    print "loading raw data  took %.1f ms (%.2f s walltime)" % (1000. *
                                                                (tc1 - tc0),
                                                                (tw1 - tw0))

    # Time window selection
    # weights are calc'd based on [tmin,tmax], but applied to the entire data set.
    # tstep is used in artifact detection
    tmax = raw.index_as_time(raw.last_samp)[0]
    tstep = 0.2
    itmin = int(floor(tmin * raw.info['sfreq']))
    itmax = int(ceil(tmax * raw.info['sfreq']))
    itstep = int(ceil(tstep * raw.info['sfreq']))
    print ">>> Set time-range to [%7.3f,%7.3f]" % (tmin, tmax)

    if sigchanlist is None:
        sigpick = mne.pick_types(raw.info,
                                 meg='mag',
                                 eeg=False,
                                 stim=False,
                                 eog=False,
                                 exclude='bads')
    else:
        sigpick = channel_indices_from_list(raw.info['ch_names'][:],
                                            sigchanlist)
    nsig = len(sigpick)
    print "sigpick: %3d chans" % nsig
    if nsig == 0:
        raise ValueError("No channel selected for noise compensation")

    if refchanlist is None:
        # References are not limited to 4D ref-chans, but can be anything,
        # incl. ECG or powerline monitor.
        print ">>> Using all refchans."
        refexclude = "bads"
        refpick = mne.pick_types(raw.info,
                                 ref_meg=True,
                                 meg=False,
                                 eeg=False,
                                 stim=False,
                                 eog=False,
                                 exclude=refexclude)
    else:
        refpick = channel_indices_from_list(raw.info['ch_names'][:],
                                            refchanlist)
        print "refpick = '%s'" % refpick
    nref = len(refpick)
    print "refpick: %3d chans" % nref
    if nref == 0:
        raise ValueError("No channel selected as noise reference")

    print "########## Refchan geo data:"
    # This is just for info to locate special 4D-refs.
    for iref in refpick:
        print raw.info['chs'][iref]['ch_name'], raw.info['chs'][iref]['loc'][
            0:3]
    print ""

    if use_reffilter:
        print "########## Filter reference channels:"
        if refflt_lpfreq is not None:
            print " low-pass with cutoff-freq %.1f" % refflt_lpfreq
        if refflt_hpfreq is not None:
            print "high-pass with cutoff-freq %.1f" % refflt_hpfreq
        # Adapt followg drop-chans cmd to use 'all-but-refpick'
        droplist = [
            raw.info['ch_names'][k] for k in xrange(raw.info['nchan'])
            if not k in refpick
        ]
        fltref = raw.drop_channels(droplist, copy=True)
        tct = time.clock()
        twt = time.time()
        fltref.filter(refflt_hpfreq,
                      refflt_lpfreq,
                      picks=np.array(xrange(nref)),
                      method='iir')
        tc1 = time.clock()
        tw1 = time.time()
        print "filtering ref-chans  took %.1f ms (%.2f s walltime)" % (
            1000. * (tc1 - tct), (tw1 - twt))

    print "########## Calculating sig-ref/ref-ref-channel covariances:"
    # Calculate sig-ref/ref-ref-channel covariance:
    # (there is no need to calc inter-signal-chan cov,
    #  but there seems to be no appropriat fct available)
    # Here we copy the idea from compute_raw_data_covariance()
    # and truncate it as appropriate.
    tct = time.clock()
    twt = time.time()
    # The following reject and info{sig,ref} entries are only
    # used in _is_good-calls.
    # _is_good() from mne-0.9.git-py2.7.egg/mne/epochs.py seems to
    # ignore ref-channels (not covered by dict) and checks individual
    # data segments - artifacts across a buffer boundary are not found.
    reject = dict(
        grad=4000e-13,  # T / m (gradiometers)
        mag=4e-12,  # T (magnetometers)
        eeg=40e-6,  # uV (EEG channels)
        eog=250e-6)  # uV (EOG channels)

    infosig = copy.copy(raw.info)
    infosig['chs'] = [raw.info['chs'][k] for k in sigpick]
    infosig['ch_names'] = [raw.info['ch_names'][k] for k in sigpick]
    infosig['nchan'] = len(sigpick)
    idx_by_typesig = channel_indices_by_type(infosig)

    # inforef not good w/ filtering, but anyway useless
    inforef = copy.copy(raw.info)
    inforef['chs'] = [raw.info['chs'][k] for k in refpick]
    inforef['ch_names'] = [raw.info['ch_names'][k] for k in refpick]
    inforef['nchan'] = len(refpick)
    idx_by_typeref = channel_indices_by_type(inforef)

    # Read data in chunks:
    sigmean = 0
    refmean = 0
    sscovdata = 0
    srcovdata = 0
    rrcovdata = 0
    n_samples = 0
    for first in range(itmin, itmax, itstep):
        last = first + itstep
        if last >= itmax:
            last = itmax
        raw_segmentsig, times = raw[sigpick, first:last]
        if use_reffilter:
            raw_segmentref, times = fltref[:, first:last]
        else:
            raw_segmentref, times = raw[refpick, first:last]
        # if True:
        # if _is_good(raw_segmentsig, infosig['ch_names'], idx_by_typesig, reject, flat=None,
        #            ignore_chs=raw.info['bads']) and _is_good(raw_segmentref,
        #              inforef['ch_names'], idx_by_typeref, reject, flat=None,
        #                ignore_chs=raw.info['bads']):
        if not exclart or \
           _is_good(raw_segmentsig, infosig['ch_names'], idx_by_typesig, reject,
                    flat=None, ignore_chs=raw.info['bads']):
            sigmean += raw_segmentsig.sum(axis=1)
            refmean += raw_segmentref.sum(axis=1)
            sscovdata += (raw_segmentsig * raw_segmentsig).sum(axis=1)
            srcovdata += np.dot(raw_segmentsig, raw_segmentref.T)
            rrcovdata += np.dot(raw_segmentref, raw_segmentref.T)
            n_samples += raw_segmentsig.shape[1]
        else:
            logger.info("Artefact detected in [%d, %d]" % (first, last))

    #_check_n_samples(n_samples, len(picks))
    sigmean /= n_samples
    refmean /= n_samples
    sscovdata -= n_samples * sigmean[:] * sigmean[:]
    sscovdata /= (n_samples - 1)
    srcovdata -= n_samples * sigmean[:, None] * refmean[None, :]
    srcovdata /= (n_samples - 1)
    rrcovdata -= n_samples * refmean[:, None] * refmean[None, :]
    rrcovdata /= (n_samples - 1)
    sscovinit = sscovdata
    print "Normalize srcov..."
    rrslopedata = copy.copy(rrcovdata)
    for iref in xrange(nref):
        dtmp = rrcovdata[iref][iref]
        if dtmp > TINY:
            for isig in xrange(nsig):
                srcovdata[isig][iref] /= dtmp
            for jref in xrange(nref):
                rrslopedata[jref][iref] /= dtmp
        else:
            for isig in xrange(nsig):
                srcovdata[isig][iref] = 0.
            for jref in xrange(nref):
                rrslopedata[jref][iref] = 0.
    logger.info("Number of samples used : %d" % n_samples)
    tc1 = time.clock()
    tw1 = time.time()
    print "sigrefchn covar-calc took %.1f ms (%.2f s walltime)" % (1000. *
                                                                   (tc1 - tct),
                                                                   (tw1 - twt))

    print "########## Calculating sig-ref/ref-ref-channel covariances (robust):"
    # Calculate sig-ref/ref-ref-channel covariance:
    # (usg B.P.Welford, "Note on a method for calculating corrected sums
    #                   of squares and products", Technometrics4 (1962) 419-420)
    # (there is no need to calc inter-signal-chan cov,
    #  but there seems to be no appropriat fct available)
    # Here we copy the idea from compute_raw_data_covariance()
    # and truncate it as appropriate.
    tct = time.clock()
    twt = time.time()
    # The following reject and info{sig,ref} entries are only
    # used in _is_good-calls.
    # _is_good() from mne-0.9.git-py2.7.egg/mne/epochs.py seems to
    # ignore ref-channels (not covered by dict) and checks individual
    # data segments - artifacts across a buffer boundary are not found.
    reject = dict(
        grad=4000e-13,  # T / m (gradiometers)
        mag=4e-12,  # T (magnetometers)
        eeg=40e-6,  # uV (EEG channels)
        eog=250e-6)  # uV (EOG channels)

    infosig = copy.copy(raw.info)
    infosig['chs'] = [raw.info['chs'][k] for k in sigpick]
    infosig['ch_names'] = [raw.info['ch_names'][k] for k in sigpick]
    infosig['nchan'] = len(sigpick)
    idx_by_typesig = channel_indices_by_type(infosig)

    # inforef not good w/ filtering, but anyway useless
    inforef = copy.copy(raw.info)
    inforef['chs'] = [raw.info['chs'][k] for k in refpick]
    inforef['ch_names'] = [raw.info['ch_names'][k] for k in refpick]
    inforef['nchan'] = len(refpick)
    idx_by_typeref = channel_indices_by_type(inforef)

    # Read data in chunks:
    smean = np.zeros(nsig)
    smold = np.zeros(nsig)
    rmean = np.zeros(nref)
    rmold = np.zeros(nref)
    sscov = 0
    srcov = 0
    rrcov = np.zeros((nref, nref))
    srcov = np.zeros((nsig, nref))
    n_samples = 0
    for first in range(itmin, itmax, itstep):
        last = first + itstep
        if last >= itmax:
            last = itmax
        raw_segmentsig, times = raw[sigpick, first:last]
        if use_reffilter:
            raw_segmentref, times = fltref[:, first:last]
        else:
            raw_segmentref, times = raw[refpick, first:last]
        # if True:
        # if _is_good(raw_segmentsig, infosig['ch_names'], idx_by_typesig, reject, flat=None,
        #            ignore_chs=raw.info['bads']) and _is_good(raw_segmentref,
        #              inforef['ch_names'], idx_by_typeref, reject, flat=None,
        #                ignore_chs=raw.info['bads']):
        if not exclart or \
           _is_good(raw_segmentsig, infosig['ch_names'], idx_by_typesig, reject,
                    flat=None, ignore_chs=raw.info['bads']):
            for isl in xrange(raw_segmentsig.shape[1]):
                nsl = isl + n_samples + 1
                cnslm1dnsl = float((nsl - 1)) / float(nsl)
                sslsubmean = (raw_segmentsig[:, isl] - smold)
                rslsubmean = (raw_segmentref[:, isl] - rmold)
                smean = smold + sslsubmean / nsl
                rmean = rmold + rslsubmean / nsl
                sscov += sslsubmean * (raw_segmentsig[:, isl] - smean)
                srcov += cnslm1dnsl * np.dot(sslsubmean.reshape(
                    (nsig, 1)), rslsubmean.reshape((1, nref)))
                rrcov += cnslm1dnsl * np.dot(rslsubmean.reshape(
                    (nref, 1)), rslsubmean.reshape((1, nref)))
                smold = smean
                rmold = rmean
            n_samples += raw_segmentsig.shape[1]
        else:
            logger.info("Artefact detected in [%d, %d]" % (first, last))

    #_check_n_samples(n_samples, len(picks))
    sscov /= (n_samples - 1)
    srcov /= (n_samples - 1)
    rrcov /= (n_samples - 1)
    print "Normalize srcov..."
    rrslope = copy.copy(rrcov)
    for iref in xrange(nref):
        dtmp = rrcov[iref][iref]
        if dtmp > TINY:
            srcov[:, iref] /= dtmp
            rrslope[:, iref] /= dtmp
        else:
            srcov[:, iref] = 0.
            rrslope[:, iref] = 0.
    logger.info("Number of samples used : %d" % n_samples)
    print "Compare results with 'standard' values:"
    print "cmp(sigmean,smean):", np.allclose(smean, sigmean, atol=0.)
    print "cmp(refmean,rmean):", np.allclose(rmean, refmean, atol=0.)
    print "cmp(sscovdata,sscov):", np.allclose(sscov, sscovdata, atol=0.)
    print "cmp(srcovdata,srcov):", np.allclose(srcov, srcovdata, atol=0.)
    print "cmp(rrcovdata,rrcov):", np.allclose(rrcov, rrcovdata, atol=0.)
    tc1 = time.clock()
    tw1 = time.time()
    print "sigrefchn covar-calc took %.1f ms (%.2f s walltime)" % (1000. *
                                                                   (tc1 - tct),
                                                                   (tw1 - twt))

    if checkresults:
        print "########## Calculated initial signal channel covariance:"
        # Calculate initial signal channel covariance:
        # (only used as quality measure)
        print "initl rt(avg sig pwr) = %12.5e" % np.sqrt(np.mean(sscov))
        for i in xrange(5):
            print "initl signal-rms[%3d] = %12.5e" % (
                i, np.sqrt(sscov.flatten()[i]))
        print " "
    if nref < 6:
        print "rrslope-entries:"
        for i in xrange(nref):
            print rrslope[i][:]

    U, s, V = np.linalg.svd(rrslope, full_matrices=True)
    print s

    print "Applying cutoff for smallest SVs:"
    dtmp = s.max() * SVD_RELCUTOFF
    sinv = np.zeros(nref)
    for i in xrange(nref):
        if abs(s[i]) >= dtmp:
            sinv[i] = 1. / s[i]
        else:
            s[i] = 0.
    # s *= (abs(s)>=dtmp)
    # sinv = ???
    print s
    stat = np.allclose(rrslope, np.dot(U, np.dot(np.diag(s), V)))
    print ">>> Testing svd-result: %s" % stat
    if not stat:
        print "    (Maybe due to SV-cutoff?)"

    # Solve for inverse coefficients:
    print ">>> Setting RRinvtr=U diag(sinv) V"
    RRinvtr = np.zeros((nref, nref))
    RRinvtr = np.dot(U, np.dot(np.diag(sinv), V))
    if checkresults:
        # print ">>> RRinvtr-result:"
        # print RRinvtr
        stat = np.allclose(np.identity(nref),
                           np.dot(rrslope.transpose(), RRinvtr))
        if stat:
            print ">>> Testing RRinvtr-result (shld be unit-matrix): ok"
        else:
            print ">>> Testing RRinvtr-result (shld be unit-matrix): failed"
            print np.dot(rrslope.transpose(), RRinvtr)
            # np.less_equal(np.abs(np.dot(rrslope.transpose(),RRinvtr)-np.identity(nref)),0.01*np.ones((nref,nref)))
        print ""

    print "########## Calc weight matrix..."
    # weights-matrix will be somewhat larger than necessary,
    # (to simplify indexing in compensation loop):
    weights = np.zeros((raw._data.shape[0], nref))
    for isig in xrange(nsig):
        for iref in xrange(nref):
            weights[sigpick[isig]][iref] = np.dot(srcov[isig][:],
                                                  RRinvtr[iref][:])

    if np.allclose(np.zeros(weights.shape), np.abs(weights), atol=1.e-8):
        print ">>> all weights are small (<=1.e-8)."
    else:
        print ">>> largest weight %12.5e" % np.max(np.abs(weights))
        wlrg = np.where(np.abs(weights) >= 0.99 * np.max(np.abs(weights)))
        for iwlrg in xrange(len(wlrg[0])):
            print ">>> weights[%3d,%2d] = %12.5e" % \
                  (wlrg[0][iwlrg], wlrg[1][iwlrg], weights[wlrg[0][iwlrg], wlrg[1][iwlrg]])

    if nref < 5:
        print "weights-entries for first sigchans:"
        for i in xrange(5):
            print 'weights[sp(%2d)][r]=[' % i + ' '.join(
                [' %+10.7f' % val for val in weights[sigpick[i]][:]]) + ']'

    print "########## Compensating signal channels:"
    tct = time.clock()
    twt = time.time()
    # data,times = raw[:,raw.time_as_index(tmin)[0]:raw.time_as_index(tmax)[0]:]
    # Work on entire data stream:
    for isl in xrange(raw._data.shape[1]):
        slice = np.take(raw._data, [isl], axis=1)
        if use_reffilter:
            refslice = np.take(fltref._data, [isl], axis=1)
            refarr = refslice[:].flatten() - rmean
            # refarr = fltres[:,isl]-rmean
        else:
            refarr = slice[refpick].flatten() - rmean
        subrefarr = np.dot(weights[:], refarr)
        # data[:,isl] -= subrefarr   will not modify raw._data?
        raw._data[:, isl] -= subrefarr
        if isl % 10000 == 0:
            print "\rProcessed slice %6d" % isl
    print "\nDone."
    tc1 = time.clock()
    tw1 = time.time()
    print "compensation loop took %.1f ms (%.2f s walltime)" % (1000. *
                                                                (tc1 - tct),
                                                                (tw1 - twt))

    if checkresults:
        print "########## Calculating final signal channel covariance:"
        # Calculate final signal channel covariance:
        # (only used as quality measure)
        tct = time.clock()
        twt = time.time()
        sigmean = 0
        sscovdata = 0
        n_samples = 0
        for first in range(itmin, itmax, itstep):
            last = first + itstep
            if last >= itmax:
                last = itmax
            raw_segmentsig, times = raw[sigpick, first:last]
            # Artifacts found here will probably differ from pre-noisered artifacts!
            if not exclart or \
               _is_good(raw_segmentsig, infosig['ch_names'], idx_by_typesig, reject,
                        flat=None, ignore_chs=raw.info['bads']):
                sigmean += raw_segmentsig.sum(axis=1)
                sscovdata += (raw_segmentsig * raw_segmentsig).sum(axis=1)
                n_samples += raw_segmentsig.shape[1]
        sigmean /= n_samples
        sscovdata -= n_samples * sigmean[:] * sigmean[:]
        sscovdata /= (n_samples - 1)
        print ">>> no channel got worse: ", np.all(
            np.less_equal(sscovdata, sscovinit))
        print "final rt(avg sig pwr) = %12.5e" % np.sqrt(np.mean(sscovdata))
        for i in xrange(5):
            print "final signal-rms[%3d] = %12.5e" % (
                i, np.sqrt(sscovdata.flatten()[i]))
        tc1 = time.clock()
        tw1 = time.time()
        print "signal covar-calc took %.1f ms (%.2f s walltime)" % (
            1000. * (tc1 - tct), (tw1 - twt))
        print " "

    nrname = dname[:dname.rfind('-raw.fif')] + ',nold-raw.fif'
    print "Saving '%s'..." % nrname
    raw.save(nrname, overwrite=True)
    tc1 = time.clock()
    tw1 = time.time()
    print "Total run         took %.1f ms (%.2f s walltime)" % (1000. *
                                                                (tc1 - tc0),
                                                                (tw1 - tw0))
def noise_reducer_4raw_data(fname_raw,
                            raw=None,
                            signals=None,
                            noiseref=None,
                            detrending=None,
                            tmin=None,
                            tmax=None,
                            reflp=None,
                            refhp=None,
                            refnotch=None,
                            exclude_artifacts=True,
                            checkresults=True,
                            fif_extention="-raw.fif",
                            fif_postfix="nr",
                            reject=None,
                            complementary_signal=False,
                            fnout=None,
                            verbose=False,
                            save=True):
    """Apply noise reduction to signal channels using reference channels.

    !!! ONLY ONE RAW Obj Interface Version FB !!!

    Regression weights of the signal channels on the reference channels are
    estimated from the sig-ref and ref-ref covariances over [tmin, tmax]
    (the normalized ref-ref matrix is inverted via an SVD with a relative
    cutoff SVD_RELCUTOFF for small singular values) and then applied to the
    ENTIRE data set: for every sample, ``weights . (ref - refmean)`` is
    subtracted from the signal channels. The data array of the raw object
    (``raw._data``) is modified in place.

    Parameters
    ----------
    fname_raw : str
        Raw file name; passed together with ``raw`` to
        ``jumeg_base.get_raw_obj()``.
    raw : mne Raw object | None
        Raw object; passed together with ``fname_raw`` to
        ``jumeg_base.get_raw_obj()``.
    signals : list of str | None
        Channels to compensate using ``noiseref``. If None or empty, the
        channels returned by ``jumeg_base.pick_meg_nobads()`` (the MEG
        signal channels) are used.
    noiseref : list of str | str | None
        Channels to use as noise reference. If None or empty, the channels
        returned by ``jumeg_base.pick_ref_nobads()`` (the magnetic reference
        channels) are used.
        ``signals`` and ``noiseref`` may contain regexps, which are resolved
        by ``channel_indices_from_list()`` (the bad-channel list is passed
        to it as well). All other channels are copied unchanged.
    detrending : bool | None
        If true, ``perform_detrending()`` is applied to the data first.
    tmin, tmax : float | None
        Lower / upper latency bound (s) for the weight calculation
        [start / end of trace]. Weights are calculated for (tmin, tmax),
        but applied to the entire data set.
    refhp : float | None
        High-pass frequency for the reference signal filter.
    reflp : float | None
        Low-pass frequency for the reference signal filter.
        reflp < refhp: band-stop filter
        reflp > refhp: band-pass filter
        reflp is not None, refhp is None: low-pass filter
        reflp is None, refhp is not None: high-pass filter
    refnotch : float | None
        (Base) notch frequency for the reference signal filter; the
        reference signal becomes raw(ref) - notched(ref), with notches at
        refnotch and its harmonics below min(5.01 * refnotch, sfreq / 2).
        Cannot be combined with ``reflp`` / ``refhp``.
    exclude_artifacts : bool
        Only use data segments (tstep = 0.2 s) accepted by ``_is_good()``
        with the thresholds in ``reject``.
    checkresults : bool
        Run internal checks (SVD / inverse consistency) and compute the
        final signal channel covariance as a quality measure.
    fif_extention : str
        Extension passed to ``jumeg_base.get_fif_name()``.
    fif_postfix : str
        Postfix passed to ``jumeg_base.get_fif_name()``.
    reject : dict | None
        Rejection thresholds for ``_is_good()``. None (default) means
        {'grad': 4000e-13, 'mag': 4e-12, 'eeg': 40e-6, 'eog': 250e-6},
        in T/m (gradiometers), T (magnetometers) and V (EEG/EOG channels,
        i.e. 40e-6 = 40 uV).
    complementary_signal : bool
        If True, replace the signal by the traces that would be subtracted
        (can be useful for debugging).
    fnout : None
        Unused; kept for backward compatibility.
    verbose : bool
        Print progress and diagnostic information.
    save : bool
        Save the data via ``jumeg_base.apply_save_mne_data()``.

    Returns
    -------
    raw : mne Raw object
        The noise-reduced raw object (modified in place).
    fname_out : str
        Output file name from ``jumeg_base.get_fif_name()`` (originally
        documented as <wawa>,nr-raw.fif for input <wawa>-raw.fif); returned
        even if ``save`` is False.

    Raises
    ------
    ValueError
        If ``complementary_signal`` is not a bool, the time window is
        shorter than 2 samples, no signal or no reference channel is
        selected, a notch filter is combined with a high-/low-pass filter,
        or fewer than 2 samples remain for the weight calculation.

    Bugs
    ----
    - artifact checking is incomplete (and with arb. window of tstep=0.2s)
    - no accounting of channels used as signal/reference
    - non existing input file handled ungracefully
    """
    tc0 = time.clock()
    tw0 = time.time()

    if not isinstance(complementary_signal, bool):
        raise ValueError("Argument complementary_signal must be of type bool")

    # Build the default thresholds per call (no shared mutable default).
    if reject is None:
        reject = {'grad': 4000e-13,  # T / m (gradiometers)
                  'mag': 4e-12,      # T (magnetometers)
                  'eeg': 40e-6,      # V (EEG channels)
                  'eog': 250e-6}     # V (EOG channels)

    raw, fname_raw = jumeg_base.get_raw_obj(fname_raw, raw=raw)

    if detrending:
        raw = perform_detrending(None, raw=raw, save=False)

    tc1 = time.clock()
    tw1 = time.time()
    if verbose:
        print(">>> loading raw data took %.1f ms (%.2f s walltime)" %
              (1000. * (tc1 - tc0), (tw1 - tw0)))

    # Time window selection:
    # weights are calc'd based on [tmin,tmax], but applied to the entire
    # data set. tstep is used in artifact detection.
    # tmin,tmax variables must not be changed here!
    if tmin is None:
        itmin = 0
    else:
        itmin = int(floor(tmin * raw.info['sfreq']))
    if tmax is None:
        itmax = raw.last_samp
    else:
        itmax = int(ceil(tmax * raw.info['sfreq']))

    if itmax - itmin < 2:
        raise ValueError(
            "Time-window for noise compensation empty or too short")

    if verbose:
        print(">>> Set time-range to [%7.3f,%7.3f]" %
              (raw.index_as_time(itmin)[0], raw.index_as_time(itmax)[0]))

    # --- channel selection
    if signals is None or len(signals) == 0:
        sigpick = jumeg_base.pick_meg_nobads(raw)
    else:
        sigpick = channel_indices_from_list(raw.info['ch_names'][:], signals,
                                            raw.info.get('bads'))
    nsig = len(sigpick)
    if nsig == 0:
        raise ValueError("No channel selected for noise compensation")

    if noiseref is None or len(noiseref) == 0:
        # References are not limited to 4D ref-chans, but can be anything,
        # incl. ECG or powerline monitor.
        if verbose:
            print(">>> Using all refchans.")
        refpick = jumeg_base.pick_ref_nobads(raw)
    else:
        refpick = channel_indices_from_list(raw.info['ch_names'][:], noiseref,
                                            raw.info.get('bads'))
    nref = len(refpick)
    if nref == 0:
        raise ValueError("No channel selected as noise reference")

    if verbose:
        print(">>> sigpick: %3d chans, refpick: %3d chans" % (nsig, nref))

    # --- optional filtering of the reference channels
    use_reffilter = not (reflp is None and refhp is None and refnotch is None)
    use_refantinotch = False
    if use_reffilter:
        if verbose:
            print("########## Filter reference channels:")

        if refnotch is not None:
            # A notch ("anti-notch") reference excludes high-/low-pass.
            if reflp is not None or refhp is not None:
                raise ValueError("Cannot specify notch- and high-/low-pass "
                                 "reference filter together")
            use_refantinotch = True
            freqlast = np.min([5.01 * refnotch, 0.5 * raw.info['sfreq']])
            if verbose:
                print(">>> notches at freq %.1f and harmonics below %.1f" %
                      (refnotch, freqlast))
        elif verbose:
            if reflp is not None:
                print(">>>  low-pass with cutoff-freq %.1f" % reflp)
            if refhp is not None:
                print(">>> high-pass with cutoff-freq %.1f" % refhp)

        # Drop every channel except the references from the filtered copy.
        refpick_set = set(refpick)
        droplist = [raw.info['ch_names'][k] for k in xrange(raw.info['nchan'])
                    if k not in refpick_set]
        tct = time.clock()
        twt = time.time()
        fltref = raw.drop_channels(droplist, copy=True)
        ref_picks = np.arange(nref)
        if use_refantinotch:
            # Reference = raw(ref) - notched(ref): keep only the components
            # removed by the notch filter.
            ref_unfiltered = fltref._data.copy()
            fltref.notch_filter(np.arange(refnotch, freqlast, refnotch),
                                picks=ref_picks,
                                method='iir')
            fltref._data = ref_unfiltered - fltref._data
        else:
            fltref.filter(refhp,
                          reflp,
                          picks=ref_picks,
                          method='iir')
        tc1 = time.clock()
        tw1 = time.time()
        if verbose:
            print(">>> filtering ref-chans  took %.1f ms (%.2f s walltime)" %
                  (1000. * (tc1 - tct), (tw1 - twt)))

    if verbose:
        print("########## Calculating sig-ref/ref-ref-channel covariances:")
    # Calculate sig-ref/ref-ref-channel covariance:
    # (there is no need to calc inter-signal-chan cov,
    #  but there seems to be no appropriate fct available)
    # Here we copy the idea from compute_raw_data_covariance()
    # and truncate it as appropriate.
    tct = time.clock()
    twt = time.time()
    # The following reject and infosig entries are only used in
    # _is_good-calls. _is_good() from mne-0.9.git-py2.7.egg/mne/epochs.py
    # seems to ignore ref-channels (not covered by dict) and checks
    # individual data segments - artifacts across a buffer boundary are
    # not found.
    infosig = copy.copy(raw.info)
    infosig['chs'] = [raw.info['chs'][k] for k in sigpick]
    infosig['ch_names'] = [raw.info['ch_names'][k] for k in sigpick]
    infosig['nchan'] = len(sigpick)
    idx_by_typesig = channel_indices_by_type(infosig)

    # Read data in chunks of tstep seconds:
    tstep = 0.2
    itstep = int(ceil(tstep * raw.info['sfreq']))
    sigmean = 0
    refmean = 0
    sscovdata = 0
    srcovdata = 0
    rrcovdata = 0
    n_samples = 0

    for first in range(itmin, itmax, itstep):
        last = min(first + itstep, itmax)
        raw_segmentsig, times = raw[sigpick, first:last]
        if use_reffilter:
            raw_segmentref, times = fltref[:, first:last]
        else:
            raw_segmentref, times = raw[refpick, first:last]

        if not exclude_artifacts or \
           _is_good(raw_segmentsig, infosig['ch_names'], idx_by_typesig,
                    reject, flat=None, ignore_chs=raw.info['bads']):
            sigmean += raw_segmentsig.sum(axis=1)
            refmean += raw_segmentref.sum(axis=1)
            sscovdata += (raw_segmentsig * raw_segmentsig).sum(axis=1)
            srcovdata += np.dot(raw_segmentsig, raw_segmentref.T)
            rrcovdata += np.dot(raw_segmentref, raw_segmentref.T)
            n_samples += raw_segmentsig.shape[1]
        else:
            logger.info("Artefact detected in [%d, %d]" % (first, last))
    if n_samples <= 1:
        raise ValueError('Too few samples to calculate weights')

    # Convert the accumulated sums into (unbiased) covariances.
    sigmean /= n_samples
    refmean /= n_samples
    sscovdata -= n_samples * sigmean[:] * sigmean[:]
    sscovdata /= (n_samples - 1)
    srcovdata -= n_samples * sigmean[:, None] * refmean[None, :]
    srcovdata /= (n_samples - 1)
    rrcovdata -= n_samples * refmean[:, None] * refmean[None, :]
    rrcovdata /= (n_samples - 1)
    sscovinit = np.copy(sscovdata)
    if verbose:
        print(">>> Normalize srcov...")

    # Normalize each reference column by that reference's variance.
    rrslope = copy.copy(rrcovdata)
    for iref in xrange(nref):
        dtmp = rrcovdata[iref, iref]
        if dtmp > TINY:
            srcovdata[:, iref] /= dtmp
            rrslope[:, iref] /= dtmp
        else:
            srcovdata[:, iref] = 0.
            rrslope[:, iref] = 0.

    n_show = min(5, nsig)  # number of channels listed in diagnostics
    if verbose:
        print(">>> Number of samples used : %d" % n_samples)
        tc1 = time.clock()
        tw1 = time.time()
        print(">>> sigrefchn covar-calc took %.1f ms (%.2f s walltime)" %
              (1000. * (tc1 - tct), (tw1 - twt)))

    if checkresults and verbose:
        # Initial signal channel covariance (only used as quality measure).
        print("########## Calculated initial signal channel covariance:")
        print(">>> initl rt(avg sig pwr) = %12.5e" %
              np.sqrt(np.mean(sscovdata)))
        for i in xrange(n_show):
            print(">>> initl signal-rms[%3d] = %12.5e" %
                  (i, np.sqrt(sscovdata.flatten()[i])))
        print(">>>")

    U, s, V = np.linalg.svd(rrslope, full_matrices=True)
    if verbose:
        print(">>> singular values:")
        print(s)
        print(">>> Applying cutoff for smallest SVs:")

    # Zero singular values below the relative cutoff; their inverse is 0.
    dtmp = s.max() * SVD_RELCUTOFF
    s *= (abs(s) >= dtmp)
    sinv = [1. / s[k] if s[k] != 0. else 0. for k in xrange(nref)]
    if verbose:
        print(">>> singular values (after cutoff):")
        print(s)

    stat = np.allclose(rrslope, np.dot(U, np.dot(np.diag(s), V)))
    if verbose:
        print(">>> Testing svd-result: %s" % stat)
        if not stat:
            print("    (Maybe due to SV-cutoff?)")

    # Solve for inverse coefficients:
    # Set RRinv.tr=U diag(sinv) V
    RRinv = np.transpose(np.dot(U, np.dot(np.diag(sinv), V)))
    if checkresults:
        stat = np.allclose(np.identity(nref), np.dot(RRinv, rrslope))
        if stat:
            if verbose:
                print(">>> Testing RRinv-result (should be unit-matrix): ok")
        else:
            # Reported regardless of verbose: the inversion is suspicious.
            print(">>> Testing RRinv-result (should be unit-matrix): failed")
            print(np.transpose(np.dot(RRinv, rrslope)))
            print(">>>")

    if verbose:
        print("########## Calc weight matrix...")

    # weights-matrix will be somewhat larger than necessary,
    # (to simplify indexing in compensation loop): rows of non-signal
    # channels stay zero, so they are left unchanged.
    weights = np.zeros((raw._data.shape[0], nref))
    for isig in xrange(nsig):
        for iref in xrange(nref):
            weights[sigpick[isig], iref] = np.dot(srcovdata[isig, :],
                                                  RRinv[:, iref])

    if verbose:
        print("########## Compensating signal channels:")
        if complementary_signal:
            print(">>> Caveat: REPLACING signal by compensation signal")

    tct = time.clock()
    twt = time.time()

    # Work on entire data stream, one sample at a time. The reference
    # vector is read (as a copy) before the sample is modified.
    for isl in xrange(raw._data.shape[1]):
        if use_reffilter:
            refarr = fltref._data[:, isl] - refmean
        else:
            refarr = raw._data[refpick, isl] - refmean
        subrefarr = np.dot(weights, refarr)

        if not complementary_signal:
            raw._data[:, isl] -= subrefarr
        else:
            raw._data[:, isl] = subrefarr

        if (isl % 10000 == 0) and verbose:
            print("\rProcessed slice %6d" % isl)

    if verbose:
        print("\nDone.")
        tc1 = time.clock()
        tw1 = time.time()
        print(">>> compensation loop took %.1f ms (%.2f s walltime)" %
              (1000. * (tc1 - tct), (tw1 - twt)))

    if checkresults:
        if verbose:
            print("########## Calculating final signal channel covariance:")
        # Final signal channel covariance (only used as quality measure).
        tct = time.clock()
        twt = time.time()
        sigmean = 0
        sscovdata = 0
        n_samples = 0
        for first in range(itmin, itmax, itstep):
            last = min(first + itstep, itmax)
            raw_segmentsig, times = raw[sigpick, first:last]
            # Artifacts found here will probably differ from pre-noisered
            # artifacts!
            if not exclude_artifacts or \
               _is_good(raw_segmentsig, infosig['ch_names'], idx_by_typesig,
                        reject, flat=None, ignore_chs=raw.info['bads']):
                sigmean += raw_segmentsig.sum(axis=1)
                sscovdata += (raw_segmentsig * raw_segmentsig).sum(axis=1)
                n_samples += raw_segmentsig.shape[1]

        if n_samples <= 1:
            # Quality check only: do not abort after the data were already
            # compensated.
            print(">>> Too few samples left for final covariance check")
        else:
            sigmean /= n_samples
            sscovdata -= n_samples * sigmean[:] * sigmean[:]
            sscovdata /= (n_samples - 1)

            if verbose:
                print(">>> no channel got worse:  %s" %
                      np.all(np.less_equal(sscovdata, sscovinit)))
                print(">>> final rt(avg sig pwr) = %12.5e" %
                      np.sqrt(np.mean(sscovdata)))
                for i in xrange(n_show):
                    print(">>> final signal-rms[%3d] = %12.5e" %
                          (i, np.sqrt(sscovdata.flatten()[i])))
                tc1 = time.clock()
                tw1 = time.time()
                print(">>> signal covar-calc took %.1f ms (%.2f s walltime)" %
                      (1000. * (tc1 - tct), (tw1 - twt)))
                print(">>>")

    # --- fb update 21.07.2015
    fname_out = jumeg_base.get_fif_name(raw=raw,
                                        postfix=fif_postfix,
                                        extention=fif_extention)

    if save:
        jumeg_base.apply_save_mne_data(raw, fname=fname_out, overwrite=True)

    tc1 = time.clock()
    tw1 = time.time()
    if verbose:
        print(">>> Total run took %.1f ms (%.2f s walltime)" %
              (1000. * (tc1 - tc0), (tw1 - tw0)))

    return raw, fname_out
Example #25
0
def test_plot_topomap_basic(monkeypatch):
    """Test basics of topomap plotting.

    Exercises ``Evoked.plot_topomap``, ``plot_topomap`` and
    ``plot_evoked_topomap``: extrapolation modes, ``border`` handling,
    channel-name display, masks, titles, interactive projector toggling
    and colormaps, custom outlines, invalid ``pos`` inputs, the peak finder
    and preservation of ``info['bads']``.

    The test is stateful: ``evoked`` is re-read and mutated part-way through,
    and ``pos``, ``outlines`` and ``evoked_grad`` are reused by later
    sections, so statement order matters.
    """
    evoked = read_evokeds(evoked_fname, 'Left Auditory', baseline=(None, 0))
    res = 8
    fast_test = dict(res=res, contours=0, sensors=False, time_unit='s')
    fast_test_noscale = dict(res=res, contours=0, sensors=False)
    # EEG-only evoked reduced to its first two channels.
    ev_bad = evoked.copy().pick_types(meg=False, eeg=True)
    ev_bad.pick_channels(ev_bad.ch_names[:2])
    plt_topomap = partial(ev_bad.plot_topomap, **fast_test)
    plt_topomap(times=ev_bad.times[:2] - 1e-6)  # auto, plots EEG
    pytest.raises(ValueError, plt_topomap, ch_type='mag')
    pytest.raises(ValueError, plt_topomap, times=[-100])  # bad time
    pytest.raises(ValueError, plt_topomap, times=[[0]])  # bad time

    evoked.plot_topomap([0.1],
                        ch_type='eeg',
                        scalings=1,
                        res=res,
                        contours=[-100, 0, 100],
                        time_unit='ms')

    # extrapolation to the edges of the convex hull or the head circle
    evoked.plot_topomap([0.1],
                        ch_type='eeg',
                        scalings=1,
                        res=res,
                        contours=[-100, 0, 100],
                        time_unit='ms',
                        extrapolate='local')
    evoked.plot_topomap([0.1],
                        ch_type='eeg',
                        scalings=1,
                        res=res,
                        contours=[-100, 0, 100],
                        time_unit='ms',
                        extrapolate='head')
    evoked.plot_topomap([0.1],
                        ch_type='eeg',
                        scalings=1,
                        res=res,
                        contours=[-100, 0, 100],
                        time_unit='ms',
                        extrapolate='head',
                        outlines='skirt')

    # extrapolation options when < 4 channels:
    # random data for the first three magnetometers
    temp_data = np.random.random(3)
    picks = channel_indices_by_type(evoked.info)['mag'][:3]
    info_sel = pick_info(evoked.info, picks)
    plot_topomap(temp_data, info_sel, extrapolate='local', res=res)
    plot_topomap(temp_data, info_sel, extrapolate='head', res=res)

    # make sure extrapolation works for 3 channels with border='mean'
    # (if extra points are placed incorrectly some of them have only
    #  other extra points as neighbours and border='mean' fails)
    plot_topomap(temp_data,
                 info_sel,
                 extrapolate='local',
                 border='mean',
                 res=res)

    # border=0 and border='mean':
    # ---------------------------
    # 25 synthetic EEG positions: for each of 5 radii r in [0.2, 1.0], one
    # point on the z axis ([0, 0, r]) and four in the x-y plane.
    ch_pos = np.array(
        sum(([[0, 0, r], [r, 0, 0], [-r, 0, 0], [0, -r, 0], [0, r, 0]]
             for r in np.linspace(0.2, 1.0, 5)), []))
    rng = np.random.RandomState(23)
    # data are 5 plus standard-normal noise (seeded)
    data = np.full(len(ch_pos), 5) + rng.randn(len(ch_pos))
    info = create_info(len(ch_pos), 250, 'eeg')
    ch_pos_dict = {name: pos for name, pos in zip(info['ch_names'], ch_pos)}
    dig = make_dig_montage(ch_pos_dict, coord_frame='head')
    info.set_montage(dig)

    # border=0
    # NOTE(review): the first return value is used as an image
    # (get_array()), despite being named `ax`.
    ax, _ = plot_topomap(data, info, extrapolate='head', border=0, sphere=1)
    img_data = ax.get_array().data

    # presumably [31, 31] is the image centre at the default resolution,
    # where channel 0 ([0, 0, 0.2]) projects — confirm.
    assert np.abs(img_data[31, 31] - data[0]) < 0.12
    # with border=0 the corner pixel is expected to stay close to zero
    assert np.abs(img_data[0, 0]) < 1.5

    # border='mean'
    ax, _ = plot_topomap(data,
                         info,
                         extrapolate='head',
                         border='mean',
                         sphere=1)
    img_data = ax.get_array().data

    assert np.abs(img_data[31, 31] - data[0]) < 0.12
    # with border='mean' the corner is expected to exceed 5 (data ~ 5 + noise)
    assert img_data[0, 0] > 5

    # error when not numeric or str:
    error_msg = 'border must be an instance of numeric or str'
    with pytest.raises(TypeError, match=error_msg):
        plot_topomap(data, info, extrapolate='head', border=[1, 2, 3])

    # error when str is not 'mean':
    error_msg = "The only allowed value is 'mean', but got 'fancy' instead."
    with pytest.raises(ValueError, match=error_msg):
        plot_topomap(data, info, extrapolate='head', border='fancy')

    # test channel placement when only 'grad' are picked:
    # ---------------------------------------------------
    info_grad = evoked.copy().pick('grad').info
    n_grads = len(info_grad['ch_names'])
    data = np.random.randn(n_grads)
    img, _ = plot_topomap(data, info_grad)

    # check that channels are scattered around x == 0
    # (the last collection on the axes holds the sensor markers' offsets)
    pos = img.axes.collections[-1].get_offsets()
    prop_channels_on_the_right = (pos[:, 0] > 0).mean()
    assert prop_channels_on_the_right < 0.6

    # other:
    # ------
    plt_topomap = partial(evoked.plot_topomap, **fast_test)
    plt.close('all')
    axes = [plt.subplot(221), plt.subplot(222)]
    plt_topomap(axes=axes, colorbar=False)
    plt.close('all')
    plt_topomap(times=[-0.1, 0.2])
    plt.close('all')
    evoked_grad = evoked.copy().crop(0, 0).pick_types(meg='grad')
    # one time point; assumes evoked_grad has 204 gradiometers — the mask
    # shape must match (n_channels, n_times).
    mask = np.zeros((204, 1), bool)
    mask[[0, 3, 5, 6]] = True
    names = []

    # show_names callback: records every name it is given and strips the
    # first four characters (e.g. 'MEG ').
    def proc_names(x):
        names.append(x)
        return x[4:]

    evoked_grad.plot_topomap(ch_type='grad',
                             times=[0],
                             mask=mask,
                             show_names=proc_names,
                             **fast_test)
    # only the masked channels' names are expected to be passed to the
    # callback (the 'x' suffix suggests merged gradiometer pairs — confirm)
    assert_equal(sorted(names),
                 ['MEG 011x', 'MEG 012x', 'MEG 013x', 'MEG 014x'])
    mask = np.zeros_like(evoked.data, dtype=bool)
    mask[[1, 5], :] = True
    plt_topomap(ch_type='mag', outlines=None)
    times = [0.1]
    plt_topomap(times, ch_type='grad', mask=mask)
    plt_topomap(times, ch_type='planar1')
    plt_topomap(times, ch_type='planar2')
    plt_topomap(times,
                ch_type='grad',
                mask=mask,
                show_names=True,
                mask_params={'marker': 'x'})
    plt.close('all')
    with pytest.raises(ValueError, match='number of seconds; got -'):
        plt_topomap(times, ch_type='eeg', average=-1e3)
    with pytest.raises(TypeError, match='number of seconds; got type'):
        plt_topomap(times, ch_type='eeg', average='x')

    p = plt_topomap(times,
                    ch_type='grad',
                    image_interp='bilinear',
                    show_names=lambda x: x.replace('MEG', ''))
    # locate the subplot axes among the figure's children
    subplot = [x for x in p.get_children() if 'Subplot' in str(type(x))]
    assert len(subplot) >= 1, [type(x) for x in p.get_children()]
    subplot = subplot[0]

    # every drawn channel label must have had 'MEG' removed by show_names
    have_all = all('MEG' not in x.get_text() for x in subplot.get_children()
                   if isinstance(x, matplotlib.text.Text))
    assert have_all

    # Plot array
    for ch_type in ('mag', 'grad'):
        evoked_ = evoked.copy().pick_types(eeg=False, meg=ch_type)
        plot_topomap(evoked_.data[:, 0], evoked_.info, **fast_test_noscale)
    # fail with multiple channel types
    pytest.raises(ValueError, plot_topomap, evoked.data[0, :], evoked.info)

    # Test title
    # (collect the Text artists that are direct children of the figure)
    def get_texts(p):
        return [
            x.get_text() for x in p.get_children()
            if isinstance(x, matplotlib.text.Text)
        ]

    p = plt_topomap(times, ch_type='eeg', average=0.01)
    assert_equal(len(get_texts(p)), 0)
    p = plt_topomap(times, ch_type='eeg', title='Custom')
    texts = get_texts(p)
    assert_equal(len(texts), 1)
    assert_equal(texts[0], 'Custom')
    plt.close('all')

    # delaunay triangulation warning
    plt_topomap(times, ch_type='mag')
    # projs have already been applied
    pytest.raises(RuntimeError,
                  plot_evoked_topomap,
                  evoked,
                  0.1,
                  'mag',
                  proj='interactive',
                  time_unit='s')

    # change to no-proj mode
    evoked = read_evokeds(evoked_fname,
                          'Left Auditory',
                          baseline=(None, 0),
                          proj=False)
    fig1 = evoked.plot_topomap('interactive',
                               'mag',
                               proj='interactive',
                               **fast_test)
    _fake_click(fig1, fig1.axes[1], (0.5, 0.5))  # click slider
    # image maximum before toggling the projector, for comparison below
    data_max = np.max(fig1.axes[0].images[0]._A)
    fig2 = plt.gcf()
    _fake_click(fig2, fig2.axes[0], (0.075, 0.775))  # toggle projector
    # make sure projector gets toggled
    assert (np.max(fig1.axes[0].images[0]._A) != data_max)

    with monkeypatch.context() as m:  # speed it up by not actually plotting
        m.setattr(topomap, '_plot_topomap', lambda *args, **kwargs:
                  (None, None, None))
        with pytest.warns(RuntimeWarning, match='More than 25 topomaps plots'):
            plot_evoked_topomap(evoked, [0.1] * 26, colorbar=False)

    pytest.raises(ValueError,
                  plot_evoked_topomap,
                  evoked, [-3e12, 15e6],
                  time_unit='s')

    # Zero the stored locations of all EEG channels (mutates evoked.info).
    for ch in evoked.info['chs']:
        if ch['coil_type'] == FIFF.FIFFV_COIL_EEG:
            ch['loc'].fill(0)

    # Remove extra digitization point, so EEG digitization points
    # correspond with the EEG electrodes
    del evoked.info['dig'][85]

    # Plot skirt
    evoked.plot_topomap(times, ch_type='eeg', outlines='skirt', **fast_test)

    # Pass custom outlines without patch
    # (`pos` and `outlines` computed here are reused further below)
    eeg_picks = pick_types(evoked.info, meg=False, eeg=True)
    pos, outlines = _get_pos_outlines(evoked.info, eeg_picks, 0.1)
    evoked.plot_topomap(times, ch_type='eeg', outlines=outlines, **fast_test)
    plt.close('all')

    # Test interactive cmap
    fig = plot_evoked_topomap(evoked,
                              times=[0., 0.1],
                              ch_type='eeg',
                              cmap=('Reds', True),
                              title='title',
                              **fast_test)
    fig.canvas.key_press_event('up')
    fig.canvas.key_press_event(' ')
    fig.canvas.key_press_event('down')
    cbar = fig.get_axes()[0].CB  # Fake dragging with mouse.
    ax = cbar.cbar.ax
    _fake_click(fig, ax, (0.1, 0.1))
    _fake_click(fig, ax, (0.1, 0.2), kind='motion')
    _fake_click(fig, ax, (0.1, 0.3), kind='release')

    # same drag sequence with mouse button 3 pressed
    _fake_click(fig, ax, (0.1, 0.1), button=3)
    _fake_click(fig, ax, (0.1, 0.2), button=3, kind='motion')
    _fake_click(fig, ax, (0.1, 0.3), kind='release')

    fig.canvas.scroll_event(0.5, 0.5, -0.5)  # scroll down
    fig.canvas.scroll_event(0.5, 0.5, 0.5)  # scroll up

    plt.close('all')

    # Pass custom outlines with patch callable
    def patch():
        return Circle((0.5, 0.4687),
                      radius=.46,
                      clip_on=True,
                      transform=plt.gca().transAxes)

    outlines['patch'] = patch
    plot_evoked_topomap(evoked,
                        times,
                        ch_type='eeg',
                        outlines=outlines,
                        **fast_test)

    # Remove digitization points. Now topomap should fail
    evoked.info['dig'] = None
    pytest.raises(RuntimeError,
                  plot_evoked_topomap,
                  evoked,
                  times,
                  ch_type='eeg',
                  time_unit='s')
    plt.close('all')

    # Error for missing names
    # (`pos` is the EEG position array returned by _get_pos_outlines above)
    n_channels = len(pos)
    data = np.ones(n_channels)
    pytest.raises(ValueError, plot_topomap, data, pos, show_names=True)

    # Test error messages for invalid pos parameter
    pos_1d = np.zeros(n_channels)
    pos_3d = np.zeros((n_channels, 2, 2))
    pytest.raises(ValueError, plot_topomap, data, pos_1d)
    pytest.raises(ValueError, plot_topomap, data, pos_3d)
    pytest.raises(ValueError, plot_topomap, data, pos[:3, :])

    pos_x = pos[:, :1]
    pos_xyz = np.c_[pos, np.zeros(n_channels)[:, np.newaxis]]
    pytest.raises(ValueError, plot_topomap, data, pos_x)
    pytest.raises(ValueError, plot_topomap, data, pos_xyz)

    # An #channels x 4 matrix should work though. In this case (x, y, width,
    # height) is assumed.
    pos_xywh = np.c_[pos, np.zeros((n_channels, 2))]
    plot_topomap(data, pos_xywh)
    plt.close('all')

    # Test peak finder
    axes = [plt.subplot(131), plt.subplot(132)]
    evoked.plot_topomap(times='peaks', axes=axes, **fast_test)
    plt.close('all')
    # Zero the data and place isolated peaks to check _find_peaks directly.
    evoked.data = np.zeros(evoked.data.shape)
    evoked.data[50][1] = 1
    assert_array_equal(_find_peaks(evoked, 10), evoked.times[1])
    evoked.data[80][100] = 1
    assert_array_equal(_find_peaks(evoked, 10), evoked.times[[1, 100]])
    evoked.data[2][95] = 2
    assert_array_equal(_find_peaks(evoked, 10), evoked.times[[1, 95]])
    assert_array_equal(_find_peaks(evoked, 1), evoked.times[95])

    # Test excluding bads channels
    # (plotting must not modify info['bads'])
    evoked_grad.info['bads'] += [evoked_grad.info['ch_names'][0]]
    orig_bads = evoked_grad.info['bads']
    evoked_grad.plot_topomap(ch_type='grad', times=[0], time_unit='ms')
    assert_array_equal(evoked_grad.info['bads'], orig_bads)
    plt.close('all')
Example #26
0
def get_rejection_threshold(epochs):
    """Compute global rejection thresholds.

    For every supported channel type present in ``epochs``, the candidate
    thresholds are the per-epoch maxima (over the good channels of that
    type) of the peak-to-peak amplitude. Each candidate is scored with a
    5-fold ``validation_curve`` of ``GlobalAutoReject`` and the candidate
    with the best mean test score is kept.

    Parameters
    ----------
    epochs : mne.Epochs object
        The epochs from which to estimate the rejection dictionary.

    Returns
    -------
    reject : dict
        The rejection dictionary, with a subset of the keys 'mag', 'grad',
        'eeg', 'eog' and 'ecg' (only types that are present in ``epochs``
        and have at least one good channel).

    Notes
    -----
    Sensors marked as bad by user will be excluded when estimating the
    rejection dictionary. A channel type whose channels are all marked as
    bad is skipped, since there is no data to derive candidates from.
    """
    reject = dict()
    X = epochs.get_data()
    picks = channel_indices_by_type(epochs.info)
    ch_names = epochs.info['ch_names']
    bads = set(epochs.info['bads'])
    for ch_type in ['mag', 'grad', 'eeg', 'eog', 'ecg']:
        if ch_type not in epochs:
            continue
        if ch_type == 'ecg' and 'mag' not in epochs:
            continue
        if ch_type == 'eog' and not \
                ('mag' in epochs or 'grad' in epochs or 'eeg' in epochs):
            continue

        this_picks = [p for p in picks[ch_type] if ch_names[p] not in bads]
        if not this_picks:
            # Every channel of this type is bad: nothing to estimate from
            # (the max over an empty channel axis would raise ValueError).
            continue

        # Peak-to-peak amplitude per (epoch, channel); the maximum over
        # channels gives one candidate threshold per epoch.
        deltas = np.ptp(X[:, this_picks, :], axis=2)
        param_range = deltas.max(axis=1)
        print('Estimating rejection dictionary for %s with %d candidate'
              ' thresholds' % (ch_type, param_range.shape[0]))

        if ch_type == 'mag' or ch_type == 'ecg':
            this_epoch = epochs.copy().pick_types(meg='mag', eeg=False)
        elif ch_type == 'eeg':
            this_epoch = epochs.copy().pick_types(meg=False, eeg=True)
        elif ch_type == 'eog':
            # Cannot mix channel types in cv score
            if 'eeg' in epochs:
                this_epoch = epochs.copy().pick_types(meg=False, eeg=True)
            elif 'grad' in epochs:
                this_epoch = epochs.copy().pick_types(meg='grad', eeg=False)
            elif 'mag' in epochs:
                this_epoch = epochs.copy().pick_types(meg='mag', eeg=False)
        elif ch_type == 'grad':
            this_epoch = epochs.copy().pick_types(meg='grad', eeg=False)

        _, test_scores = validation_curve(GlobalAutoReject(),
                                          this_epoch,
                                          y=None,
                                          param_name="thresh",
                                          param_range=param_range,
                                          cv=5)

        # Negate the mean test score so that argmin picks the best one.
        test_scores = -test_scores.mean(axis=1)
        reject[ch_type] = param_range[np.argmin(test_scores)]
    return reject
Example #27
0
def preproc1epoch(eeg,
                  info,
                  projs=None,
                  SSP=True,
                  reject=None,
                  mne_reject=1,
                  reject_ch=None,
                  flat=None,
                  bad_channels=None,
                  opt_detrend=1,
                  HP=0,
                  LP=40,
                  phase='zero-double'):
    '''
    Preprocesses a single epoch of EEG data.

    # Input
    - eeg: 2-D array holding one EEG epoch, shape (time samples, channels).
    - info: predefined MNE info structure containing the channels etc.
    - projs: SSP projectors, applied if SSP=True. None (default) means no
      projectors.
    - SSP: whether to add and apply ``projs``.
    - reject: rejection thresholds passed to ``mne.epochs._is_good``
      (e.g. ``dict(eeg=100)``). If None (default), no channel rejection.
    - mne_reject: if 1, detect bad channels with ``_is_good`` and
      interpolate them; otherwise drop ``bad_channels``. Only used when
      ``reject`` is not None.
    - reject_ch: if truthy, drop the channels Fp1, Fp2, Fz, AF3, AF4, T7,
      T8, F7 and F8 before filtering.
    - flat: flatness thresholds passed to ``_is_good``.
    - bad_channels: channel names dropped when ``reject`` is given and
      ``mne_reject != 1``. None (default) means no channels.
    - opt_detrend: if 1, linearly detrend the data along time first.
    - HP: high-pass cut-off of the band-pass filter in Hz (default 0).
    - LP: low-pass cut-off of the band-pass filter in Hz (default 40).
    - phase: FIR filter phase passed to ``epoch.filter``
      (default 'zero-double').

    # Preprocessing
    - Linear detrending (if opt_detrend == 1)
    - Conversion to an MNE EpochsArray starting at -0.1 s (no baseline yet)
    - Dropping of predefined channels (if reject_ch)
    - Band-pass filter (HP-LP Hz)
    - Resample to 100 Hz
    - Baseline correction
    - SSP (if SSP)
    - Reject bad channels (if reject is not None)
        - interpolate detected bad channels, or drop predefined ones
    - Rereference to average
    - Baseline correction

    # Output
    - Preprocessed epoch as a NumPy array of shape (channels, time samples).
    '''
    # None sentinels instead of shared mutable list defaults.
    if projs is None:
        projs = []
    if bad_channels is None:
        bad_channels = []

    n_samples = eeg.shape[0]
    n_channels = eeg.shape[1]
    # (time, channels) -> (1 epoch, channels, time) for EpochsArray.
    eeg = np.reshape(eeg.T, (1, n_channels, n_samples))
    tmin = -0.1  # start baseline at

    # Temporal detrending:
    if opt_detrend == 1:
        eeg = detrend(eeg, axis=2, type='linear')

    epoch = mne.EpochsArray(eeg, info, tmin=tmin, baseline=None, verbose=False)

    # Drop list of channels known to be problematic:
    if reject_ch:
        bads = ['Fp1', 'Fp2', 'Fz', 'AF3', 'AF4', 'T7', 'T8', 'F7', 'F8']
        epoch.drop_channels(bads)

    # Band-pass filter (HP/LP/phase are parameters; they were previously
    # referenced here without being defined, raising NameError)
    epoch.filter(HP, LP, fir_design='firwin', phase=phase, verbose=False)

    # Downsample
    epoch.resample(100, npad='auto', verbose=False)

    # Apply baseline correction
    epoch.apply_baseline(baseline=(None, 0), verbose=False)

    # Apply SSP projectors
    if SSP:
        # Apply projection to the epochs already defined
        epoch.add_proj(projs)
        epoch.apply_proj()

    if reject is not None:  # currently not used
        if mne_reject == 1:  # use mne method to reject+interpolate bad channels
            from mne.epochs import _is_good
            from mne.io.pick import channel_indices_by_type
            # reject=dict(eeg=100)
            idx_by_type = channel_indices_by_type(epoch.info)
            A, bad_channels = _is_good(epoch.get_data()[0],
                                       epoch.ch_names,
                                       channel_type_idx=idx_by_type,
                                       reject=reject,
                                       flat=flat,
                                       full_report=True)
            print(A)
            if not A:
                epoch.info['bads'] = bad_channels
                epoch.interpolate_bads(reset_bads=True, verbose=False)
        else:  # bad_channels is predefined
            epoch.drop_channels(bad_channels)

    # Rereferencing
    epoch.set_eeg_reference(verbose=False)
    # Apply baseline after rereference
    epoch.apply_baseline(baseline=(None, 0), verbose=False)

    epoch = epoch.get_data()[0]
    return epoch
def preproc1epoch(eeg,
                  info,
                  projs=None,
                  SSP=True,
                  reject=None,
                  mne_reject=1,
                  reject_ch=None,
                  flat=None,
                  bad_channels=None,
                  opt_detrend=1,
                  HP=0,
                  LP=40,
                  phase='zero-double'):
    '''
    Preprocesses EEG data epoch-wise.

    # Arguments
        eeg: NumPy array
            EEG epoch in the following format: [time samples, channels].

        info: MNE info structure.
            Predefined info structure. Can be generated using createInfoMNE function.

        projs: list | None
            MNE SSP projector objects. Used if SSP = True. None (default) means no projectors.

        SSP: boolean
            Whether to apply SSP projectors (artefact correction) to the EEG epoch.

        reject: dict | None
            Rejection thresholds passed as `reject` to the MNE built-in function
            epochs._is_good (e.g. dict(eeg=100)). If None (default), no channel
            rejection is performed.

        mne_reject: int
            If 1, bad channels are detected with epochs._is_good and interpolated;
            otherwise the channels listed in bad_channels are dropped.
            Only used if reject is not None.

        reject_ch: boolean
            Whether to drop nine predefined channels
            (Fp1, Fp2, Fz, AF3, AF4, T7, T8, F7, F8; can be changed to any channels).

        flat: dict | None
            Flatness thresholds passed as `flat` to epochs._is_good. See function documentation.

        bad_channels: list | None
            Names of channels to drop (manual rejection) when reject is given and
            mne_reject != 1. None (default) means no channels.

        opt_detrend: int
            If 1, apply temporal EEG detrending (linear).

        HP: int
            High-pass filter cut-off, default 0 Hz.

        LP: int
            Low-pass filter cut-off, default 40 Hz.

        phase: string
            FIR filter phase (refer to MNE filtering function for options), default 'zero-double'.


    # Preprocessing steps - based on inputs

        Linear temporal detrending

        Initial rejection of pre-defined channels

        Bandpass filtering (HP-LP Hz, default 0-40 Hz, filter phase set by: phase)

        Resampling to 100 Hz

        Baseline correction

        SSP artefact correction

        Analysis and rejection of bad channels
            Interpolation of bad channels (or dropping of predefined ones)

        Average re-referencing

        Baseline correction

    # Returns
        epoch: NumPy array
            Preprocessed EEG epoch, shape [channels, time samples].

    '''
    # None sentinels instead of shared mutable list defaults.
    if projs is None:
        projs = []
    if bad_channels is None:
        bad_channels = []

    n_samples = eeg.shape[0]
    n_channels = eeg.shape[1]
    # [time, channels] -> [1 epoch, channels, time] for EpochsArray.
    eeg = np.reshape(eeg.T, (1, n_channels, n_samples))
    tmin = -0.1  # Baseline start, i.e. 100 ms before stimulus onset

    # Temporal detrending:
    if opt_detrend == 1:
        eeg = detrend(eeg, axis=2, type='linear')

    epoch = mne.EpochsArray(eeg, info, tmin=tmin, baseline=None, verbose=False)

    # Drop list of channels known to be problematic:
    if reject_ch:
        bads = ['Fp1', 'Fp2', 'Fz', 'AF3', 'AF4', 'T7', 'T8', 'F7', 'F8']
        epoch.drop_channels(bads)

    # Band-pass filter
    epoch.filter(HP, LP, fir_design='firwin', phase=phase, verbose=False)

    # Downsample
    epoch.resample(100, npad='auto', verbose=False)

    # Apply baseline correction
    epoch.apply_baseline(baseline=(None, 0), verbose=False)

    # Apply SSP projectors
    if SSP:
        epoch.add_proj(projs)
        epoch.apply_proj()

    if reject is not None:  # Rejection of channels, either manually defined or based on MNE analysis. Currently not used.
        if mne_reject == 1:  # Use MNE method to reject+interpolate bad channels
            from mne.epochs import _is_good
            from mne.io.pick import channel_indices_by_type
            # reject=dict(eeg=100)
            idx_by_type = channel_indices_by_type(epoch.info)
            A, bad_channels = _is_good(epoch.get_data()[0],
                                       epoch.ch_names,
                                       channel_type_idx=idx_by_type,
                                       reject=reject,
                                       flat=flat,
                                       full_report=True)
            print(A)
            if not A:
                epoch.info['bads'] = bad_channels
                epoch.interpolate_bads(reset_bads=True, verbose=False)
        else:  # Predefined bad_channels
            epoch.drop_channels(bad_channels)

    # Re-referencing
    epoch.set_eeg_reference(verbose=False)

    # Apply baseline after rereference
    epoch.apply_baseline(baseline=(None, 0), verbose=False)

    epoch = epoch.get_data()[0]

    return epoch
Example #29
0
def _noise_reducer_raw_fname(raw):
    """Return the first file name stored in a Raw object.

    Uses ``raw.filenames`` when ``raw.info`` has a 'filename' entry and the
    private ``raw._filenames`` otherwise (the lookup noise_reducer has
    always used). NOTE(review): presumably None for in-memory Raw objects.
    """
    if 'filename' in raw.info:
        return raw.filenames[0]
    return raw._filenames[0]


def noise_reducer(fname_raw, raw=None, signals=None, noiseref=None, detrending=None,
                  tmin=None, tmax=None, reflp=None, refhp=None, refnotch=None,
                  exclude_artifacts=True, checkresults=True, return_raw=False,
                  complementary_signal=False, fnout=None, verbose=False):
    """
    Apply noise reduction to signal channels using reference channels.

    Regression weights of the signal channels on the reference channels are
    computed from the sig-ref and ref-ref covariances within [tmin, tmax]
    (using an SVD-based inverse with small singular values cut off). The
    weighted reference signals are then subtracted from the entire data set.

    Parameters
    ----------
    fname_raw : (list of) rawfile name(s) | None
    raw : mne Raw objects | None
        Allows passing of (preloaded) raw object in addition to fname_raw
        or solely (use fname_raw=None in this case).
    signals : list of string | None
              List of channels to compensate using noiseref.
              If None or empty use the meg magnetometer channels.
    noiseref : list of string | str | None
              List of channels to use as noise reference.
              If None or empty use the magnetic reference channels (default).
    signals and noiseref may contain regexp, which are resolved by
    channel_indices_from_list() (presumably via mne.pick_channels_regexp()).
    All other channels are copied.
    tmin : lower latency bound for weight-calc [start of trace]
    tmax : upper latency bound for weight-calc [ end  of trace]
           Weights are calc'd for (tmin,tmax), but applied to entire data set
    refhp : high-pass frequency for reference signal filter [None]
    reflp :  low-pass frequency for reference signal filter [None]
            reflp < refhp: band-stop filter
            reflp > refhp: band-pass filter
            reflp is not None, refhp is None: low-pass filter
            reflp is None, refhp is not None: high-pass filter
    refnotch : (list of) notch frequencies for reference signal filter [None]
               use raw(ref)-notched(ref) as reference signal;
               cannot be combined with reflp/refhp
    exclude_artifacts: filter signal-channels through _is_good() [True]
                       (parameters are at present hard-coded!)
    return_raw : bool
        If return_raw is true, the raw object is returned and raw file
        is not written to disk unless fnout is explicitly specified.
        It is suggested that this option be used in cases where the
        noise_reducer is applied multiple times. [False]
    fnout : explicit specification for an output file name [None]
        Automatic filenames replace '-raw.fif' by ',nr-raw.fif'
        (names without '-raw.fif' get ',nr-raw.fif' instead of their
        extension).
    complementary_signal : replaced signal by traces that would be
                           subtracted [False]
                           (can be useful for debugging)
    detrending: boolean to ctrl subtraction of linear trend from all
                magn. chans [False]
    checkresults : boolean to control internal checks and overall success
                   [True]
    verbose : bool
        Print progress and diagnostics. [False]

    Outputfile
    ----------
    <wawa>,nr-raw.fif for input <wawa>-raw.fif

    Returns
    -------
    If return_raw is True, then mne.io.Raw instance is returned.

    Raises
    ------
    ValueError
        For inconsistent arguments, empty channel selections, too short
        time windows or invalid notch frequencies.
    Warning
        If signal and reference channels overlap.

    Bugs
    ----
    - artifact checking is incomplete (and with arb. window of tstep=0.2s)
    - no accounting of channels used as signal/reference
    - non existing input file handled ungracefully
    """
    # time.clock() was removed in Python 3.8; process_time() is its CPU-time
    # replacement (clock() is only looked up where process_time is missing).
    cpu_time = getattr(time, 'process_time', None) or time.clock

    if not isinstance(complementary_signal, bool):
        raise ValueError("Argument complementary_signal must be of type bool")

    # handle error if Raw object passed with file list
    if raw is not None and isinstance(fname_raw, list):
        raise ValueError('List of file names cannot be combined with '
                         'one Raw object')

    # handle error if return_raw is requested with file list
    if return_raw and isinstance(fname_raw, list):
        raise ValueError('List of file names cannot be combined with '
                         'return_raw. Please pass one file at a time.')

    # handle error if Raw object is passed with detrending option
    # TODO include perform_detrending for Raw objects
    if raw is not None and detrending:
        raise ValueError('Please perform detrending on the raw file directly. '
                         'Cannot perform detrending on the raw object')

    # Handle combinations of fname_raw and raw object:
    if fname_raw is not None:
        fnraw = get_files_from_list(fname_raw)
        have_input_file = True
    elif raw is not None:
        # Always a one-element list, so the loop below visits one file name
        # (a bare string would be iterated character by character).
        fnintern = _noise_reducer_raw_fname(raw)
        fnraw = [os.path.basename(fnintern) if fnintern is not None else None]
        warnings.warn('Setting file name from Raw object')
        have_input_file = False
        if fnout is None and not return_raw:
            raise ValueError('Refusing to waste resources without result')
    else:
        raise ValueError('Refusing Creatio ex nihilo')

    # Keep the caller's Raw object separate from the per-file working object,
    # so that every file in a list is actually read (and compensated once).
    raw_given = raw

    # loop across all filenames
    for fname in fnraw:

        if verbose:
            print("########## Read raw data:")

        tc0 = cpu_time()
        tw0 = time.time()

        if raw_given is None:
            if detrending:
                raw = perform_detrending(fname, save=False)
            else:
                raw = mne.io.Raw(fname, preload=True)
        else:
            raw = raw_given
            # perform sanity check to make sure Raw object and file are same
            fnintern = _noise_reducer_raw_fname(raw)
            if fname is not None and (
                    fnintern is None or
                    os.path.basename(fname) != os.path.basename(fnintern)):
                warnings.warn('The file name within the Raw object and provided\n   '
                              'fname are not the same. Please check again.')

        tc1 = cpu_time()
        tw1 = time.time()

        if verbose:
            print(">>> loading raw data took {:.1f} ms ({:.2f} s walltime)".format((1000. * (tc1 - tc0)), (tw1 - tw0)))

        # Time window selection
        # weights are calc'd based on [tmin,tmax], but applied to the entire data set.
        # tstep is used in artifact detection
        # tmin,tmax variables must not be changed here!
        if tmin is None:
            itmin = 0
        else:
            itmin = int(floor(tmin * raw.info['sfreq']))
        if tmax is None:
            itmax = raw.last_samp - raw.first_samp
        else:
            itmax = int(ceil(tmax * raw.info['sfreq']))

        if itmax - itmin < 2:
            raise ValueError("Time-window for noise compensation empty or too short")

        if verbose:
            print(">>> Set time-range to [%7.3f,%7.3f]" %
                  (raw.times[itmin], raw.times[itmax]))

        if signals is None or len(signals) == 0:
            sigpick = mne.pick_types(raw.info, meg='mag', eeg=False, stim=False,
                                     eog=False, exclude='bads')
        else:
            sigpick = channel_indices_from_list(raw.info['ch_names'][:], signals,
                                                raw.info.get('bads'))
        nsig = len(sigpick)
        if nsig == 0:
            raise ValueError("No channel selected for noise compensation")

        if noiseref is None or len(noiseref) == 0:
            # References are not limited to 4D ref-chans, but can be anything,
            # incl. ECG or powerline monitor.
            if verbose:
                print(">>> Using all refchans.")
            refpick = mne.pick_types(raw.info, ref_meg=True, meg=False,
                                     eeg=False, stim=False,
                                     eog=False, exclude='bads')
        else:
            refpick = channel_indices_from_list(raw.info['ch_names'][:],
                                                noiseref, raw.info.get('bads'))
        nref = len(refpick)
        if nref == 0:
            raise ValueError("No channel selected as noise reference")

        if verbose:
            print(">>> sigpick: %3d chans, refpick: %3d chans" % (nsig, nref))
        badpick = np.intersect1d(sigpick, refpick, assume_unique=False)
        if len(badpick) > 0:
            # Deliberately aborts: a channel must not be regressed on itself.
            raise Warning("Intersection of signal and reference channels not empty")

        if reflp is None and refhp is None and refnotch is None:
            use_reffilter = False
            use_refantinotch = False
        else:
            use_reffilter = True
            if verbose:
                print("########## Filter reference channels:")

            use_refantinotch = False
            if refnotch is not None:
                if reflp is not None or refhp is not None:
                    raise ValueError("Cannot specify notch- and high-/low-pass "
                                     "reference filter together")
                nyquist = (0.5 * raw.info['sfreq'])
                if isinstance(refnotch, list):
                    notchfrqs = refnotch
                else:
                    notchfrqs = [refnotch]
                notchfrqscln = []
                for nfrq in notchfrqs:
                    if not isinstance(nfrq, (float, int)):
                        raise ValueError("Illegal entry for notch-frequency (%s)" % (nfrq,))
                    if nfrq >= nyquist:
                        warnings.warn('Ignoring notch frequency > 0.5*sample_rate=%.1fHz' % nyquist)
                    else:
                        notchfrqscln.append(nfrq)
                if len(notchfrqscln) == 0:
                    raise ValueError("Notch frequency list is (now) empty")
                use_refantinotch = True
                if verbose:
                    print(">>> notches at freq ", end=' ')
                    print(notchfrqscln)
            else:
                if verbose:
                    if reflp is not None:
                        print(">>>  low-pass with cutoff-freq %.1f" % reflp)
                    if refhp is not None:
                        print(">>> high-pass with cutoff-freq %.1f" % refhp)

            # Drop all channels except the reference channels (refpick).
            refset = set(refpick)
            droplist = [name for k, name in enumerate(raw.info['ch_names'])
                        if k not in refset]
            tct = cpu_time()
            twt = time.time()
            fltref = raw.copy().drop_channels(droplist)
            ref_picks = np.arange(nref)
            if use_refantinotch:
                rawref = raw.copy().drop_channels(droplist)
                fltref.notch_filter(notchfrqscln, fir_design='firwin',
                                    fir_window='hann', phase='zero',
                                    picks=ref_picks,
                                    method='fir')
                # "Anti-notch": keep only what the notch filter removed.
                fltref._data = (rawref._data - fltref._data)
            else:
                fltref.filter(refhp, reflp, fir_design='firwin',
                              fir_window='hann', phase='zero',
                              picks=ref_picks,
                              method='fir')
            tc1 = cpu_time()
            tw1 = time.time()
            if verbose:
                print(">>> filtering ref-chans  took {:.1f} ms ({:.2f} s walltime)".format((1000. * (tc1 - tct)),
                                                                                           (tw1 - twt)))

        if verbose:
            print("########## Calculating sig-ref/ref-ref-channel covariances:")
        # Calculate sig-ref/ref-ref-channel covariance:
        # (there is no need to calc inter-signal-chan cov,
        #  but there seems to be no appropriat fct available)
        # Here we copy the idea from compute_raw_data_covariance()
        # and truncate it as appropriate.
        tct = cpu_time()
        twt = time.time()
        # The following reject and infosig entries are only
        # used in _is_good-calls.
        # _is_good() from mne-0.9.git-py2.7.egg/mne/epochs.py seems to
        # ignore ref-channels (not covered by dict) and checks individual
        # data segments - artifacts across a buffer boundary are not found.
        reject = dict(grad=4000e-13,  # T / m (gradiometers)
                      mag=4e-12,  # T (magnetometers)
                      eeg=40e-6,  # uV (EEG channels)
                      eog=250e-6)  # uV (EOG channels)

        infosig = copy.copy(raw.info)
        infosig['chs'] = [raw.info['chs'][k] for k in sigpick]
        # the below fields are *NOT* (190103) updated automatically when 'chs' is updated
        infosig['ch_names'] = [raw.info['ch_names'][k] for k in sigpick]
        infosig['nchan'] = len(sigpick)
        idx_by_typesig = channel_indices_by_type(infosig)

        # Read data in chunks:
        tstep = 0.2
        itstep = int(ceil(tstep * raw.info['sfreq']))
        sigmean = 0
        refmean = 0
        sscovdata = 0
        srcovdata = 0
        rrcovdata = 0
        n_samples = 0

        for first in range(itmin, itmax, itstep):
            last = first + itstep
            if last >= itmax:
                last = itmax
            raw_segmentsig, times = raw[sigpick, first:last]
            if use_reffilter:
                raw_segmentref, times = fltref[:, first:last]
            else:
                raw_segmentref, times = raw[refpick, first:last]

            if not exclude_artifacts or \
                    _is_good(raw_segmentsig, infosig['ch_names'], idx_by_typesig, reject, flat=None,
                             ignore_chs=raw.info['bads']):
                sigmean += raw_segmentsig.sum(axis=1)
                refmean += raw_segmentref.sum(axis=1)
                sscovdata += (raw_segmentsig * raw_segmentsig).sum(axis=1)
                srcovdata += np.dot(raw_segmentsig, raw_segmentref.T)
                rrcovdata += np.dot(raw_segmentref, raw_segmentref.T)
                n_samples += raw_segmentsig.shape[1]
            else:
                logger.info("Artefact detected in [%d, %d]" % (first, last))
        if n_samples <= 1:
            raise ValueError('Too few samples to calculate weights')
        sigmean /= n_samples
        refmean /= n_samples
        sscovdata -= n_samples * sigmean[:] * sigmean[:]
        sscovdata /= (n_samples - 1)
        srcovdata -= n_samples * sigmean[:, None] * refmean[None, :]
        srcovdata /= (n_samples - 1)
        rrcovdata -= n_samples * refmean[:, None] * refmean[None, :]
        rrcovdata /= (n_samples - 1)
        sscovinit = np.copy(sscovdata)
        if verbose:
            print(">>> Normalize srcov...")

        # Scale every reference column by that reference's variance;
        # near-zero variances (<= TINY) zero the column instead.
        rrslope = copy.copy(rrcovdata)
        for iref in range(nref):
            dtmp = rrcovdata[iref, iref]
            if dtmp > TINY:
                srcovdata[:, iref] /= dtmp
                rrslope[:, iref] /= dtmp
            else:
                srcovdata[:, iref] = 0.
                rrslope[:, iref] = 0.

        if verbose:
            print(">>> Number of samples used : %d" % n_samples)
            tc1 = cpu_time()
            tw1 = time.time()
            print(">>> sigrefchn covar-calc took %.1f ms (%.2f s walltime)" % (1000. * (tc1 - tct), (tw1 - twt)))

        if checkresults:
            if verbose:
                print("########## Calculated initial signal channel covariance:")
                # Calculate initial signal channel covariance:
                # (only used as quality measure)
                print(">>> initl rt(avg sig pwr) = %12.5e" % np.sqrt(np.mean(sscovdata)))
                for i in range(min(5, nsig)):
                    print(">>> initl signal-rms[%3d] = %12.5e" % (i, np.sqrt(sscovdata.flatten()[i])))
                for i in range(max(0, nsig - 5), nsig):
                    print(">>> initl signal-rms[%3d] = %12.5e" % (i, np.sqrt(sscovdata.flatten()[i])))
                print(">>>")

        U, s, V = np.linalg.svd(rrslope, full_matrices=True)
        if verbose:
            print(">>> singular values:")
            print(s)
            print(">>> Applying cutoff for smallest SVs:")

        dtmp = s.max() * SVD_RELCUTOFF
        s *= (abs(s) >= dtmp)
        sinv = [1. / s[k] if s[k] != 0. else 0. for k in range(nref)]
        if verbose:
            print(">>> singular values (after cutoff):")
            print(s)

        stat = np.allclose(rrslope, np.dot(U, np.dot(np.diag(s), V)))
        if verbose:
            print(">>> Testing svd-result: %s" % stat)
            if not stat:
                print("    (Maybe due to SV-cutoff?)")

        # Solve for inverse coefficients:
        # Set RRinv.tr=U diag(sinv) V
        RRinv = np.transpose(np.dot(U, np.dot(np.diag(sinv), V)))
        if checkresults:
            stat = np.allclose(np.identity(nref), np.dot(RRinv, rrslope))
            if stat:
                if verbose:
                    print(">>> Testing RRinv-result (should be unit-matrix): ok")
            else:
                print(">>> Testing RRinv-result (should be unit-matrix): failed")
                print(np.transpose(np.dot(RRinv, rrslope)))
                print(">>>")

        if verbose:
            print("########## Calc weight matrix...")

        # weights-matrix will be somewhat larger than necessary,
        # (to simplify indexing in compensation loop): rows of non-signal
        # channels stay zero, so those channels are not compensated.
        weights = np.zeros((raw._data.shape[0], nref))
        weights[np.asarray(sigpick)] = np.dot(srcovdata, RRinv)

        if verbose:
            print("########## Compensating signal channels:")
            if complementary_signal:
                print(">>> Caveat: REPLACING signal by compensation signal")

        tct = cpu_time()
        twt = time.time()

        # Work on entire data stream, in chunks of samples (vectorized per
        # chunk instead of one Python iteration per sample). The reference
        # block is read before the chunk is modified.
        n_times = raw._data.shape[1]
        chunk = 10000
        refpick_arr = np.asarray(refpick)
        for start in range(0, n_times, chunk):
            stop = min(start + chunk, n_times)
            if use_reffilter:
                refarr = fltref._data[:, start:stop] - refmean[:, None]
            else:
                refarr = raw._data[refpick_arr, start:stop] - refmean[:, None]
            subrefarr = np.dot(weights, refarr)

            if not complementary_signal:
                raw._data[:, start:stop] -= subrefarr
            else:
                raw._data[:, start:stop] = subrefarr

            if verbose:
                print("\rProcessed slice %6d" % (stop - 1), end=" ")
                sys.stdout.flush()

        if verbose:
            print("\nDone.")
            tc1 = cpu_time()
            tw1 = time.time()
            print(">>> compensation loop took {:.1f} ms ({:.2f} s walltime)".format((1000. * (tc1 - tct)), (tw1 - twt)))

        if checkresults:
            if verbose:
                print("########## Calculating final signal channel covariance:")
            # Calculate final signal channel covariance:
            # (only used as quality measure)
            tct = cpu_time()
            twt = time.time()
            sigmean = 0
            sscovdata = 0
            n_samples = 0
            for first in range(itmin, itmax, itstep):
                last = first + itstep
                if last >= itmax:
                    last = itmax
                raw_segmentsig, times = raw[sigpick, first:last]
                # Artifacts found here will probably differ from pre-noisered artifacts!
                if not exclude_artifacts or \
                        _is_good(raw_segmentsig, infosig['ch_names'], idx_by_typesig, reject,
                                 flat=None, ignore_chs=raw.info['bads']):
                    sigmean += raw_segmentsig.sum(axis=1)
                    sscovdata += (raw_segmentsig * raw_segmentsig).sum(axis=1)
                    n_samples += raw_segmentsig.shape[1]
            if n_samples <= 1:
                raise ValueError('Too few samples to calculate final signal channel covariance')
            sigmean /= n_samples
            sscovdata -= n_samples * sigmean[:] * sigmean[:]
            sscovdata /= (n_samples - 1)
            if verbose:
                print(">>> no channel got worse: %s" % str(np.all(np.less_equal(sscovdata, sscovinit))))
                print(">>> final rt(avg sig pwr) = %12.5e" % np.sqrt(np.mean(sscovdata)))
                for i in range(min(5, nsig)):
                    print(">>> final signal-rms[%3d] = %12.5e" % (i, np.sqrt(sscovdata.flatten()[i])))
                for i in range(max(0, nsig - 5), nsig):
                    print(">>> final signal-rms[%3d] = %12.5e" % (i, np.sqrt(sscovdata.flatten()[i])))
                tc1 = cpu_time()
                tw1 = time.time()
                print(">>> signal covar-calc took {:.1f} ms ({:.2f} s walltime)".format((1000. * (tc1 - tct)),
                                                                                        (tw1 - twt)))
                print(">>>")

        if fnout is not None:
            fnoutloc = fnout
        elif return_raw:
            fnoutloc = None
        elif have_input_file:
            cut = fname.rfind('-raw.fif')
            if cut < 0:
                # rfind() == -1 would chop the last character off the name;
                # replace the extension instead.
                cut = len(os.path.splitext(fname)[0])
            fnoutloc = fname[:cut] + ',nr-raw.fif'
        else:
            fnoutloc = None

        if fnoutloc is not None:
            if verbose:
                print(">>> Saving '%s'..." % fnoutloc)
            raw.save(fnoutloc, overwrite=True)

        tc1 = cpu_time()
        tw1 = time.time()
        if verbose:
            print(">>> Total run took {:.1f} ms ({:.2f} s walltime)".format((1000. * (tc1 - tc0)), (tw1 - tw0)))

        if return_raw:
            if verbose:
                print(">>> Returning raw object...")
            return raw
Example #30
0
def test_plot_topomap():
    """Test topomap plotting.

    Exercises ``Evoked.animate_topomap``, ``Evoked.plot_topomap``,
    ``plot_topomap`` and ``plot_evoked_topomap``: argument validation errors,
    extrapolation modes, masks, channel-name callbacks, titles, interactive
    projector / colormap handling, outline handling, invalid ``pos`` arrays,
    peak finding, and that plotting leaves ``info['bads']`` unchanged.
    """
    # evoked
    res = 8
    # Shared low-resolution kwargs (no contours, no sensors) for quick plots;
    # the "noscale" variant omits ``time_unit`` for array-level plot_topomap.
    fast_test = dict(res=res, contours=0, sensors=False, time_unit='s')
    fast_test_noscale = dict(res=res, contours=0, sensors=False)
    evoked = read_evokeds(evoked_fname, 'Left Auditory',
                          baseline=(None, 0))

    # Test animation
    _, anim = evoked.animate_topomap(ch_type='grad', times=[0, 0.1],
                                     butterfly=False, time_unit='s')
    anim._func(1)  # _animate has to be tested separately on 'Agg' backend.
    plt.close('all')

    # Argument validation on an EEG-only evoked with just two channels
    ev_bad = evoked.copy().pick_types(meg=False, eeg=True)
    ev_bad.pick_channels(ev_bad.ch_names[:2])
    plt_topomap = partial(ev_bad.plot_topomap, **fast_test)
    plt_topomap(times=ev_bad.times[:2] - 1e-6)  # auto, plots EEG
    pytest.raises(ValueError, plt_topomap, ch_type='mag')
    pytest.raises(TypeError, plt_topomap, head_pos='foo')
    pytest.raises(KeyError, plt_topomap, head_pos=dict(foo='bar'))
    pytest.raises(ValueError, plt_topomap, head_pos=dict(center=0))
    pytest.raises(ValueError, plt_topomap, times=[-100])  # bad time
    pytest.raises(ValueError, plt_topomap, times=[[0]])  # bad time

    evoked.plot_topomap([0.1], ch_type='eeg', scalings=1, res=res,
                        contours=[-100, 0, 100], time_unit='ms')

    # extrapolation to the edges of the convex hull or the head circle
    evoked.plot_topomap([0.1], ch_type='eeg', scalings=1, res=res,
                        contours=[-100, 0, 100], time_unit='ms',
                        extrapolate='local')
    evoked.plot_topomap([0.1], ch_type='eeg', scalings=1, res=res,
                        contours=[-100, 0, 100], time_unit='ms',
                        extrapolate='head')
    evoked.plot_topomap([0.1], ch_type='eeg', scalings=1, res=res,
                        contours=[-100, 0, 100], time_unit='ms',
                        extrapolate='head', outlines='skirt')

    # extrapolation options when < 4 channels:
    temp_data = np.random.random(3)
    picks = channel_indices_by_type(evoked.info)['mag'][:3]
    info_sel = pick_info(evoked.info, picks)
    plot_topomap(temp_data, info_sel, extrapolate='local', res=res)
    plot_topomap(temp_data, info_sel, extrapolate='head', res=res)

    plt_topomap = partial(evoked.plot_topomap, **fast_test)
    plt_topomap(0.1, layout=layout, scalings=dict(mag=0.1))
    plt.close('all')
    axes = [plt.subplot(221), plt.subplot(222)]
    plt_topomap(axes=axes, colorbar=False)
    plt.close('all')
    plt_topomap(times=[-0.1, 0.2])
    plt.close('all')

    # Mask + show_names callback: only the masked channels are named
    evoked_grad = evoked.copy().crop(0, 0).pick_types(meg='grad')
    mask = np.zeros((204, 1), bool)
    mask[[0, 3, 5, 6]] = True
    names = []

    def proc_names(x):
        names.append(x)
        return x[4:]

    evoked_grad.plot_topomap(ch_type='grad', times=[0], mask=mask,
                             show_names=proc_names, **fast_test)
    assert_equal(sorted(names),
                 ['MEG 011x', 'MEG 012x', 'MEG 013x', 'MEG 014x'])
    mask = np.zeros_like(evoked.data, dtype=bool)
    mask[[1, 5], :] = True
    plt_topomap(ch_type='mag', outlines=None)
    times = [0.1]
    plt_topomap(times, ch_type='grad', mask=mask)
    plt_topomap(times, ch_type='planar1')
    plt_topomap(times, ch_type='planar2')
    plt_topomap(times, ch_type='grad', mask=mask, show_names=True,
                mask_params={'marker': 'x'})
    plt.close('all')
    pytest.raises(ValueError, plt_topomap, times, ch_type='eeg', average=-1e3)
    pytest.raises(ValueError, plt_topomap, times, ch_type='eeg', average='x')

    # A show_names callable must be applied to every rendered channel label
    p = plt_topomap(times, ch_type='grad', image_interp='bilinear',
                    show_names=lambda x: x.replace('MEG', ''))
    subplot = [x for x in p.get_children() if 'Subplot' in str(type(x))]
    assert len(subplot) >= 1, [type(x) for x in p.get_children()]
    subplot = subplot[0]
    assert (all('MEG' not in x.get_text()
                for x in subplot.get_children()
                if isinstance(x, matplotlib.text.Text)))

    # Plot array
    for ch_type in ('mag', 'grad'):
        evoked_ = evoked.copy().pick_types(eeg=False, meg=ch_type)
        plot_topomap(evoked_.data[:, 0], evoked_.info, **fast_test_noscale)
    # fail with multiple channel types
    pytest.raises(ValueError, plot_topomap, evoked.data[0, :], evoked.info)

    # Test title
    def get_texts(p):
        return [x.get_text() for x in p.get_children() if
                isinstance(x, matplotlib.text.Text)]

    p = plt_topomap(times, ch_type='eeg', average=0.01)
    assert_equal(len(get_texts(p)), 0)
    p = plt_topomap(times, ch_type='eeg', title='Custom')
    texts = get_texts(p)
    assert_equal(len(texts), 1)
    assert_equal(texts[0], 'Custom')
    plt.close('all')

    # delaunay triangulation warning
    plt_topomap(times, ch_type='mag', layout=None)
    # projs have already been applied
    pytest.raises(RuntimeError, plot_evoked_topomap, evoked, 0.1, 'mag',
                  proj='interactive', time_unit='s')

    # change to no-proj mode
    evoked = read_evokeds(evoked_fname, 'Left Auditory',
                          baseline=(None, 0), proj=False)
    fig1 = evoked.plot_topomap('interactive', 'mag', proj='interactive',
                               **fast_test)
    _fake_click(fig1, fig1.axes[1], (0.5, 0.5))  # click slider
    data_max = np.max(fig1.axes[0].images[0]._A)
    fig2 = plt.gcf()
    _fake_click(fig2, fig2.axes[0], (0.075, 0.775))  # toggle projector
    # make sure projector gets toggled
    assert (np.max(fig1.axes[0].images[0]._A) != data_max)

    pytest.raises(RuntimeError, plot_evoked_topomap, evoked,
                  np.repeat(.1, 50), time_unit='s')
    pytest.raises(ValueError, plot_evoked_topomap, evoked, [-3e12, 15e6],
                  time_unit='s')

    # Zero the EEG channel locations so layouts must come from digitization
    for ch in evoked.info['chs']:
        if ch['coil_type'] == FIFF.FIFFV_COIL_EEG:
            ch['loc'].fill(0)

    # Remove extra digitization point, so EEG digitization points
    # correspond with the EEG electrodes
    del evoked.info['dig'][85]

    pos = make_eeg_layout(evoked.info).pos[:, :2]
    pos, outlines = _check_outlines(pos, 'head')
    assert ('head' in outlines.keys())
    assert ('nose' in outlines.keys())
    assert ('ear_left' in outlines.keys())
    assert ('ear_right' in outlines.keys())
    assert ('autoshrink' in outlines.keys())
    assert (outlines['autoshrink'])
    assert ('clip_radius' in outlines.keys())
    assert_array_equal(outlines['clip_radius'], 0.5)

    pos, outlines = _check_outlines(pos, 'skirt')
    assert ('head' in outlines.keys())
    assert ('nose' in outlines.keys())
    assert ('ear_left' in outlines.keys())
    assert ('ear_right' in outlines.keys())
    assert ('autoshrink' in outlines.keys())
    assert (not outlines['autoshrink'])
    assert ('clip_radius' in outlines.keys())
    assert_array_equal(outlines['clip_radius'], 0.625)

    pos, outlines = _check_outlines(pos, 'skirt',
                                    head_pos={'scale': [1.2, 1.2]})
    assert_array_equal(outlines['clip_radius'], 0.75)

    # Plot skirt
    evoked.plot_topomap(times, ch_type='eeg', outlines='skirt', **fast_test)

    # Pass custom outlines without patch
    evoked.plot_topomap(times, ch_type='eeg', outlines=outlines, **fast_test)
    plt.close('all')

    # Test interactive cmap
    fig = plot_evoked_topomap(evoked, times=[0., 0.1], ch_type='eeg',
                              cmap=('Reds', True), title='title', **fast_test)
    fig.canvas.key_press_event('up')
    fig.canvas.key_press_event(' ')
    fig.canvas.key_press_event('down')
    cbar = fig.get_axes()[0].CB  # Fake dragging with mouse.
    ax = cbar.cbar.ax
    _fake_click(fig, ax, (0.1, 0.1))
    _fake_click(fig, ax, (0.1, 0.2), kind='motion')
    _fake_click(fig, ax, (0.1, 0.3), kind='release')

    _fake_click(fig, ax, (0.1, 0.1), button=3)
    _fake_click(fig, ax, (0.1, 0.2), button=3, kind='motion')
    _fake_click(fig, ax, (0.1, 0.3), kind='release')

    fig.canvas.scroll_event(0.5, 0.5, -0.5)  # scroll down
    fig.canvas.scroll_event(0.5, 0.5, 0.5)  # scroll up

    plt.close('all')

    # Pass custom outlines with patch callable
    def patch():
        return Circle((0.5, 0.4687), radius=.46,
                      clip_on=True, transform=plt.gca().transAxes)
    outlines['patch'] = patch
    plot_evoked_topomap(evoked, times, ch_type='eeg', outlines=outlines,
                        **fast_test)

    # Remove digitization points. Now topomap should fail
    evoked.info['dig'] = None
    pytest.raises(RuntimeError, plot_evoked_topomap, evoked,
                  times, ch_type='eeg', time_unit='s')
    plt.close('all')

    # Error for missing names
    n_channels = len(pos)
    data = np.ones(n_channels)
    pytest.raises(ValueError, plot_topomap, data, pos, show_names=True)

    # Test error messages for invalid pos parameter
    pos_1d = np.zeros(n_channels)
    pos_3d = np.zeros((n_channels, 2, 2))
    pytest.raises(ValueError, plot_topomap, data, pos_1d)
    pytest.raises(ValueError, plot_topomap, data, pos_3d)
    pytest.raises(ValueError, plot_topomap, data, pos[:3, :])

    pos_x = pos[:, :1]
    pos_xyz = np.c_[pos, np.zeros(n_channels)[:, np.newaxis]]
    pytest.raises(ValueError, plot_topomap, data, pos_x)
    pytest.raises(ValueError, plot_topomap, data, pos_xyz)

    # An #channels x 4 matrix should work though. In this case (x, y, width,
    # height) is assumed.
    pos_xywh = np.c_[pos, np.zeros((n_channels, 2))]
    plot_topomap(data, pos_xywh)
    plt.close('all')

    # Test peak finder
    axes = [plt.subplot(131), plt.subplot(132)]
    evoked.plot_topomap(times='peaks', axes=axes, **fast_test)
    plt.close('all')
    evoked.data = np.zeros(evoked.data.shape)
    evoked.data[50][1] = 1
    assert_array_equal(_find_peaks(evoked, 10), evoked.times[1])
    evoked.data[80][100] = 1
    assert_array_equal(_find_peaks(evoked, 10), evoked.times[[1, 100]])
    evoked.data[2][95] = 2
    assert_array_equal(_find_peaks(evoked, 10), evoked.times[[1, 95]])
    assert_array_equal(_find_peaks(evoked, 1), evoked.times[95])

    # Test excluding bads channels
    evoked_grad.info['bads'] += [evoked_grad.info['ch_names'][0]]
    # Snapshot a copy: binding the same list object would make the check
    # below compare the list with itself, so it could never fail even if
    # plotting mutated info['bads'] in place.
    orig_bads = list(evoked_grad.info['bads'])
    evoked_grad.plot_topomap(ch_type='grad', times=[0], time_unit='ms')
    assert_array_equal(evoked_grad.info['bads'], orig_bads)
    plt.close('all')
Example #31
0
def test_plot_topomap():
    """Test topomap plotting.

    Exercises ``Evoked.animate_topomap``, ``Evoked.plot_topomap``,
    ``plot_topomap`` and ``plot_evoked_topomap``: argument validation errors,
    extrapolation modes, masks, channel-name callbacks, titles, interactive
    projector / colormap handling, outline handling, invalid ``pos`` arrays,
    peak finding, and that plotting leaves ``info['bads']`` unchanged.
    """
    # evoked
    res = 8
    # Shared low-resolution kwargs (no contours, no sensors) for quick plots;
    # the "noscale" variant omits ``time_unit`` for array-level plot_topomap.
    fast_test = dict(res=res, contours=0, sensors=False, time_unit='s')
    fast_test_noscale = dict(res=res, contours=0, sensors=False)
    evoked = read_evokeds(evoked_fname, 'Left Auditory', baseline=(None, 0))

    # Test animation
    _, anim = evoked.animate_topomap(ch_type='grad',
                                     times=[0, 0.1],
                                     butterfly=False,
                                     time_unit='s')
    anim._func(1)  # _animate has to be tested separately on 'Agg' backend.
    plt.close('all')

    # Argument validation on an EEG-only evoked with just two channels
    ev_bad = evoked.copy().pick_types(meg=False, eeg=True)
    ev_bad.pick_channels(ev_bad.ch_names[:2])
    plt_topomap = partial(ev_bad.plot_topomap, **fast_test)
    plt_topomap(times=ev_bad.times[:2] - 1e-6)  # auto, plots EEG
    pytest.raises(ValueError, plt_topomap, ch_type='mag')
    pytest.raises(TypeError, plt_topomap, head_pos='foo')
    pytest.raises(KeyError, plt_topomap, head_pos=dict(foo='bar'))
    pytest.raises(ValueError, plt_topomap, head_pos=dict(center=0))
    pytest.raises(ValueError, plt_topomap, times=[-100])  # bad time
    pytest.raises(ValueError, plt_topomap, times=[[0]])  # bad time

    evoked.plot_topomap([0.1],
                        ch_type='eeg',
                        scalings=1,
                        res=res,
                        contours=[-100, 0, 100],
                        time_unit='ms')

    # extrapolation to the edges of the convex hull or the head circle
    evoked.plot_topomap([0.1],
                        ch_type='eeg',
                        scalings=1,
                        res=res,
                        contours=[-100, 0, 100],
                        time_unit='ms',
                        extrapolate='local')
    evoked.plot_topomap([0.1],
                        ch_type='eeg',
                        scalings=1,
                        res=res,
                        contours=[-100, 0, 100],
                        time_unit='ms',
                        extrapolate='head')
    evoked.plot_topomap([0.1],
                        ch_type='eeg',
                        scalings=1,
                        res=res,
                        contours=[-100, 0, 100],
                        time_unit='ms',
                        extrapolate='head',
                        outlines='skirt')

    # extrapolation options when < 4 channels:
    temp_data = np.random.random(3)
    picks = channel_indices_by_type(evoked.info)['mag'][:3]
    info_sel = pick_info(evoked.info, picks)
    plot_topomap(temp_data, info_sel, extrapolate='local', res=res)
    plot_topomap(temp_data, info_sel, extrapolate='head', res=res)

    plt_topomap = partial(evoked.plot_topomap, **fast_test)
    plt_topomap(0.1, layout=layout, scalings=dict(mag=0.1))
    plt.close('all')
    axes = [plt.subplot(221), plt.subplot(222)]
    plt_topomap(axes=axes, colorbar=False)
    plt.close('all')
    plt_topomap(times=[-0.1, 0.2])
    plt.close('all')

    # Mask + show_names callback: only the masked channels are named
    evoked_grad = evoked.copy().crop(0, 0).pick_types(meg='grad')
    mask = np.zeros((204, 1), bool)
    mask[[0, 3, 5, 6]] = True
    names = []

    def proc_names(x):
        names.append(x)
        return x[4:]

    evoked_grad.plot_topomap(ch_type='grad',
                             times=[0],
                             mask=mask,
                             show_names=proc_names,
                             **fast_test)
    assert_equal(sorted(names),
                 ['MEG 011x', 'MEG 012x', 'MEG 013x', 'MEG 014x'])
    mask = np.zeros_like(evoked.data, dtype=bool)
    mask[[1, 5], :] = True
    plt_topomap(ch_type='mag', outlines=None)
    times = [0.1]
    plt_topomap(times, ch_type='grad', mask=mask)
    plt_topomap(times, ch_type='planar1')
    plt_topomap(times, ch_type='planar2')
    plt_topomap(times,
                ch_type='grad',
                mask=mask,
                show_names=True,
                mask_params={'marker': 'x'})
    plt.close('all')
    pytest.raises(ValueError, plt_topomap, times, ch_type='eeg', average=-1e3)
    pytest.raises(ValueError, plt_topomap, times, ch_type='eeg', average='x')

    # A show_names callable must be applied to every rendered channel label
    p = plt_topomap(times,
                    ch_type='grad',
                    image_interp='bilinear',
                    show_names=lambda x: x.replace('MEG', ''))
    subplot = [x for x in p.get_children() if 'Subplot' in str(type(x))]
    assert len(subplot) >= 1, [type(x) for x in p.get_children()]
    subplot = subplot[0]

    have_all = all('MEG' not in x.get_text() for x in subplot.get_children()
                   if isinstance(x, matplotlib.text.Text))
    assert have_all

    # Plot array
    for ch_type in ('mag', 'grad'):
        evoked_ = evoked.copy().pick_types(eeg=False, meg=ch_type)
        plot_topomap(evoked_.data[:, 0], evoked_.info, **fast_test_noscale)
    # fail with multiple channel types
    pytest.raises(ValueError, plot_topomap, evoked.data[0, :], evoked.info)

    # Test title
    def get_texts(p):
        return [
            x.get_text() for x in p.get_children()
            if isinstance(x, matplotlib.text.Text)
        ]

    p = plt_topomap(times, ch_type='eeg', average=0.01)
    assert_equal(len(get_texts(p)), 0)
    p = plt_topomap(times, ch_type='eeg', title='Custom')
    texts = get_texts(p)
    assert_equal(len(texts), 1)
    assert_equal(texts[0], 'Custom')
    plt.close('all')

    # delaunay triangulation warning
    plt_topomap(times, ch_type='mag', layout=None)
    # projs have already been applied
    pytest.raises(RuntimeError,
                  plot_evoked_topomap,
                  evoked,
                  0.1,
                  'mag',
                  proj='interactive',
                  time_unit='s')

    # change to no-proj mode
    evoked = read_evokeds(evoked_fname,
                          'Left Auditory',
                          baseline=(None, 0),
                          proj=False)
    fig1 = evoked.plot_topomap('interactive',
                               'mag',
                               proj='interactive',
                               **fast_test)
    _fake_click(fig1, fig1.axes[1], (0.5, 0.5))  # click slider
    data_max = np.max(fig1.axes[0].images[0]._A)
    fig2 = plt.gcf()
    _fake_click(fig2, fig2.axes[0], (0.075, 0.775))  # toggle projector
    # make sure projector gets toggled
    assert (np.max(fig1.axes[0].images[0]._A) != data_max)

    pytest.raises(RuntimeError,
                  plot_evoked_topomap,
                  evoked,
                  np.repeat(.1, 50),
                  time_unit='s')
    pytest.raises(ValueError,
                  plot_evoked_topomap,
                  evoked, [-3e12, 15e6],
                  time_unit='s')

    # Zero the EEG channel locations so layouts must come from digitization
    for ch in evoked.info['chs']:
        if ch['coil_type'] == FIFF.FIFFV_COIL_EEG:
            ch['loc'].fill(0)

    # Remove extra digitization point, so EEG digitization points
    # correspond with the EEG electrodes
    del evoked.info['dig'][85]

    pos = make_eeg_layout(evoked.info).pos[:, :2]
    pos, outlines = _check_outlines(pos, 'head')
    assert ('head' in outlines.keys())
    assert ('nose' in outlines.keys())
    assert ('ear_left' in outlines.keys())
    assert ('ear_right' in outlines.keys())
    assert ('autoshrink' in outlines.keys())
    assert (outlines['autoshrink'])
    assert ('clip_radius' in outlines.keys())
    assert_array_equal(outlines['clip_radius'], 0.5)

    pos, outlines = _check_outlines(pos, 'skirt')
    assert ('head' in outlines.keys())
    assert ('nose' in outlines.keys())
    assert ('ear_left' in outlines.keys())
    assert ('ear_right' in outlines.keys())
    assert ('autoshrink' in outlines.keys())
    assert (not outlines['autoshrink'])
    assert ('clip_radius' in outlines.keys())
    assert_array_equal(outlines['clip_radius'], 0.625)

    pos, outlines = _check_outlines(pos,
                                    'skirt',
                                    head_pos={'scale': [1.2, 1.2]})
    assert_array_equal(outlines['clip_radius'], 0.75)

    # Plot skirt
    evoked.plot_topomap(times, ch_type='eeg', outlines='skirt', **fast_test)

    # Pass custom outlines without patch
    evoked.plot_topomap(times, ch_type='eeg', outlines=outlines, **fast_test)
    plt.close('all')

    # Test interactive cmap
    fig = plot_evoked_topomap(evoked,
                              times=[0., 0.1],
                              ch_type='eeg',
                              cmap=('Reds', True),
                              title='title',
                              **fast_test)
    fig.canvas.key_press_event('up')
    fig.canvas.key_press_event(' ')
    fig.canvas.key_press_event('down')
    cbar = fig.get_axes()[0].CB  # Fake dragging with mouse.
    ax = cbar.cbar.ax
    _fake_click(fig, ax, (0.1, 0.1))
    _fake_click(fig, ax, (0.1, 0.2), kind='motion')
    _fake_click(fig, ax, (0.1, 0.3), kind='release')

    _fake_click(fig, ax, (0.1, 0.1), button=3)
    _fake_click(fig, ax, (0.1, 0.2), button=3, kind='motion')
    _fake_click(fig, ax, (0.1, 0.3), kind='release')

    fig.canvas.scroll_event(0.5, 0.5, -0.5)  # scroll down
    fig.canvas.scroll_event(0.5, 0.5, 0.5)  # scroll up

    plt.close('all')

    # Pass custom outlines with patch callable
    def patch():
        return Circle((0.5, 0.4687),
                      radius=.46,
                      clip_on=True,
                      transform=plt.gca().transAxes)

    outlines['patch'] = patch
    plot_evoked_topomap(evoked,
                        times,
                        ch_type='eeg',
                        outlines=outlines,
                        **fast_test)

    # Remove digitization points. Now topomap should fail
    evoked.info['dig'] = None
    pytest.raises(RuntimeError,
                  plot_evoked_topomap,
                  evoked,
                  times,
                  ch_type='eeg',
                  time_unit='s')
    plt.close('all')

    # Error for missing names
    n_channels = len(pos)
    data = np.ones(n_channels)
    pytest.raises(ValueError, plot_topomap, data, pos, show_names=True)

    # Test error messages for invalid pos parameter
    pos_1d = np.zeros(n_channels)
    pos_3d = np.zeros((n_channels, 2, 2))
    pytest.raises(ValueError, plot_topomap, data, pos_1d)
    pytest.raises(ValueError, plot_topomap, data, pos_3d)
    pytest.raises(ValueError, plot_topomap, data, pos[:3, :])

    pos_x = pos[:, :1]
    pos_xyz = np.c_[pos, np.zeros(n_channels)[:, np.newaxis]]
    pytest.raises(ValueError, plot_topomap, data, pos_x)
    pytest.raises(ValueError, plot_topomap, data, pos_xyz)

    # An #channels x 4 matrix should work though. In this case (x, y, width,
    # height) is assumed.
    pos_xywh = np.c_[pos, np.zeros((n_channels, 2))]
    plot_topomap(data, pos_xywh)
    plt.close('all')

    # Test peak finder
    axes = [plt.subplot(131), plt.subplot(132)]
    evoked.plot_topomap(times='peaks', axes=axes, **fast_test)
    plt.close('all')
    evoked.data = np.zeros(evoked.data.shape)
    evoked.data[50][1] = 1
    assert_array_equal(_find_peaks(evoked, 10), evoked.times[1])
    evoked.data[80][100] = 1
    assert_array_equal(_find_peaks(evoked, 10), evoked.times[[1, 100]])
    evoked.data[2][95] = 2
    assert_array_equal(_find_peaks(evoked, 10), evoked.times[[1, 95]])
    assert_array_equal(_find_peaks(evoked, 1), evoked.times[95])

    # Test excluding bads channels
    evoked_grad.info['bads'] += [evoked_grad.info['ch_names'][0]]
    # Snapshot a copy: binding the same list object would make the check
    # below compare the list with itself, so it could never fail even if
    # plotting mutated info['bads'] in place.
    orig_bads = list(evoked_grad.info['bads'])
    evoked_grad.plot_topomap(ch_type='grad', times=[0], time_unit='ms')
    assert_array_equal(evoked_grad.info['bads'], orig_bads)
    plt.close('all')
Example #32
0
def preproc1epoch(eeg,info,projs=None,SSP=True,reject=None,mne_reject=1,reject_ch=None,flat=None,bad_channels=None,opt_detrend=1):

    '''
    Preprocesses a single epoch of EEG data.
    
    # Input
    - eeg: numPy array. EEG epoch in the following format: (time samples, channels).
    - info: MNE info data structure. Predefined info containing channels etc. Can be generated using create_info_mne function.
    - projs: list of MNE SSP projector objects, or None (treated as an empty list). Used if SSP is truthy.
    - SSP: bool. Whether to add `projs` to the epoch and apply them.
    - reject: dict or None. Per-channel-type rejection thresholds passed to MNE's _is_good
      (e.g. dict(eeg=100)). If None, the whole channel-rejection step is skipped.
    - mne_reject: int. Only used when `reject` is not None. If 1, detect bad channels with
      MNE's _is_good and interpolate them; otherwise drop the channels in `bad_channels`.
    - reject_ch: bool. Whether to drop nine predefined channels
      (Fp1, Fp2, Fz, AF3, AF4, T7, T8, F7, F8).
    - flat: dict or None. Per-channel-type flatness thresholds passed to MNE's _is_good.
    - bad_channels: list of channel names, or None (treated as an empty list). Dropped when
      `reject` is not None and `mne_reject` != 1.
    - opt_detrend: int. If 1, apply linear temporal detrending.
    
    # Preprocessing steps - based on inputs 
    - Linear temporal detrending
    - Dropping of the nine predefined channels
    - Band-pass filter using the module-level HP / LP cutoffs and `phase`
      (NOTE(review): previously documented as 0-40Hz; confirm against the globals)
    - Resample to 100Hz
    - Baseline correction (epoch starts at -0.1 s; baseline is start..0 s)
    - SSP correction 
    - Rejection of bad channels
        - Interpolation (mne_reject == 1) or dropping (otherwise) of bad channels
    - Rereference to average
    - Baseline correction
    
    # Output
    - Preprocessed EEG epoch as a numPy array of shape (channels, time samples),
      i.e. transposed relative to the input, with time samples after resampling.
    
    '''
    # None sentinels avoid shared mutable default lists across calls.
    projs = [] if projs is None else projs
    bad_channels = [] if bad_channels is None else bad_channels

    n_samples = eeg.shape[0]
    n_channels = eeg.shape[1]
    # MNE EpochsArray expects (n_epochs, n_channels, n_times): transpose and add an epoch axis.
    eeg = np.reshape(eeg.T,(1,n_channels,n_samples))
    tmin = -0.1 # Baseline start (s)
    
    # Temporal detrending along the time axis:
    if opt_detrend == 1:
        eeg = detrend(eeg, axis=2, type='linear')
        
    epoch = mne.EpochsArray(eeg, info, tmin=tmin, baseline=None, verbose=False)
    
    # Drop list of channels known to be problematic:
    if reject_ch:
        bads = ['Fp1','Fp2','Fz','AF3','AF4','T7','T8','F7','F8']
        epoch.drop_channels(bads)
    
    # Band-pass filter; HP, LP and phase are module-level globals, not parameters.
    epoch.filter(HP, LP, fir_design='firwin', phase=phase, verbose=False)
    
    # Downsample
    epoch.resample(100, npad='auto',verbose=False)
    
    # Apply baseline correction
    epoch.apply_baseline(baseline=(None,0),verbose=False)
    
    # Apply SSP projectors
    if SSP:
        # Apply projection to the epochs already defined
        epoch.add_proj(projs)
        epoch.apply_proj()
        
    if reject is not None: # Rejection of channels, either manually defined or based on MNE analysis.
        if mne_reject == 1: # Use MNE method to reject+interpolate bad channels
            # Private MNE helpers; imported lazily because only this branch needs them.
            from mne.epochs import _is_good
            from mne.io.pick import channel_indices_by_type
            idx_by_type = channel_indices_by_type(epoch.info)
            is_good, found_bads = _is_good(epoch.get_data()[0], epoch.ch_names, channel_type_idx=idx_by_type,reject=reject, flat=flat, full_report=True)
            if not is_good:
                # The full report may be a tuple that can repeat a channel (e.g. flagged by
                # both reject and flat); store a de-duplicated list in info['bads'].
                epoch.info['bads'] = list(dict.fromkeys(found_bads))
                epoch.interpolate_bads(reset_bads=True, verbose=False)
        else: # Predefined bad_channels
            epoch.drop_channels(bad_channels)
    
    # Rereferencing to the average reference
    epoch.set_eeg_reference(verbose=False)
    
    # Apply baseline after rereference
    epoch.apply_baseline(baseline=(None,0),verbose=False)
        
    # Single epoch -> (channels, time samples)
    epoch = epoch.get_data()[0]
    
    return epoch
Example #33
0
def test_noise_reducer():
    """Run noise_reducer() and independently re-derive its reference compensation.

    Script-style check written in Python 2 syntax (print statements, xrange).
    Steps:
      1. call ``noise_reducer`` on the empty-room file below;
      2. re-read the raw file and compute signal/reference covariances twice
         (plain running sums, then Welford-style updates) and print whether
         both agree;
      3. build regression weights from an SVD pseudo-inverse of the
         normalized ref-ref covariance;
      4. subtract the weighted reference signals from ``raw._data`` for the
         whole recording and save it as ``<input-stem>,nold-raw.fif``.

    Nothing is asserted and nothing is returned; all results are printed or
    logged.  Reads the ``SUBJECTS_DIR`` and ``SUBJECT`` environment variables
    (``os.environ[...]`` raises KeyError if either is missing).

    Relies on module-level names defined elsewhere in the file: os, time,
    copy, np, mne, floor, ceil, logger, noise_reducer,
    channel_indices_from_list, channel_indices_by_type, _is_good, TINY and
    SVD_RELCUTOFF.
    NOTE(review): time.clock() was removed in Python 3.8 and
    raw.drop_channels(..., copy=True) is an old MNE API -- this only runs on
    the legacy stack it was written for.
    """

    data_path = os.environ['SUBJECTS_DIR']
    subject   = os.environ['SUBJECT']  # NOTE(review): read but never used below

    dname = data_path + '/' + 'empty_room_files' + '/109925_empty_room_file-raw.fif'
    subjects_dir = data_path + '/subjects'  # NOTE(review): unused below
    # Test configuration switches:
    checkresults = True    # print extra before/after signal power diagnostics
    exclart = False        # if True, chunks rejected by _is_good are skipped
    use_reffilter = True   # filter reference channels before the regression
    refflt_lpfreq = 52.    # upper edge of the reference filter (Hz)
    refflt_hpfreq = 48.    # lower edge; 48-52 Hz presumably targets 50 Hz mains

    print "########## before of noisereducer call ##########"
    # The explicit list is immediately overridden: with sigchanlist=None the
    # code below picks all good magnetometers via mne.pick_types.
    sigchanlist = ['MEG ..1', 'MEG ..3', 'MEG ..5', 'MEG ..7', 'MEG ..9']
    sigchanlist = None
    # Names containing '.' look like patterns resolved by
    # channel_indices_from_list -- TODO confirm its matching rules.
    refchanlist = ['RFM 001', 'RFM 003', 'RFM 005', 'RFG ...']
    tmin = 15.
    noise_reducer(dname, signals=sigchanlist, noiseref=refchanlist, tmin=tmin,
                  reflp=refflt_lpfreq, refhp=refflt_hpfreq,
                  exclude_artifacts=exclart, complementary_signal=True)
    print "########## behind of noisereducer call ##########"

    print "########## Read raw data:"
    # tc*/tw* pairs record CPU time (time.clock) and wall time (time.time).
    tc0 = time.clock()
    tw0 = time.time()
    raw = mne.io.Raw(dname, preload=True)
    tc1 = time.clock()
    tw1 = time.time()
    print "loading raw data  took %.1f ms (%.2f s walltime)" % (1000. * (tc1 - tc0), (tw1 - tw0))

    # Time window selection
    # Weights are calculated from [tmin, tmax] only, but applied to the entire data set.
    # tstep is the chunk length used to accumulate the covariances; with
    # exclart each chunk is checked for artifacts on its own.
    # NOTE(review): raw.times starts at sample 0 while last_samp includes
    # first_samp; if first_samp != 0 this index may be out of range -- confirm.
    tmax = raw.times[raw.last_samp]
    tstep = 0.2
    itmin = int(floor(tmin * raw.info['sfreq']))
    itmax = int(ceil(tmax * raw.info['sfreq']))
    itstep = int(ceil(tstep * raw.info['sfreq']))
    print ">>> Set time-range to [%7.3f,%7.3f]" % (tmin, tmax)

    # Signal channels: all good magnetometers, or the explicit list.
    if sigchanlist is None:
        sigpick = mne.pick_types(raw.info, meg='mag', eeg=False, stim=False, eog=False, exclude='bads')
    else:
        sigpick = channel_indices_from_list(raw.info['ch_names'][:], sigchanlist)
    nsig = len(sigpick)
    print "sigpick: %3d chans" % nsig
    if nsig == 0:
        raise ValueError("No channel selected for noise compensation")

    # Reference channels: all good MEG reference channels, or the explicit list.
    if refchanlist is None:
        # References are not limited to 4D ref-chans, but can be anything,
        # incl. ECG or powerline monitor.
        print ">>> Using all refchans."
        refexclude = "bads"
        refpick = mne.pick_types(raw.info, ref_meg=True, meg=False, eeg=False,
                                 stim=False, eog=False, exclude=refexclude)
    else:
        refpick = channel_indices_from_list(raw.info['ch_names'][:], refchanlist)
        print "refpick = '%s'" % refpick
    nref = len(refpick)
    print "refpick: %3d chans" % nref
    if nref == 0:
        raise ValueError("No channel selected as noise reference")

    print "########## Refchan geo data:"
    # This is just for info to locate special 4D-refs (first 3 entries of 'loc').
    for iref in refpick:
        print raw.info['chs'][iref]['ch_name'], raw.info['chs'][iref]['loc'][0:3]
    print ""

    if use_reffilter:
        print "########## Filter reference channels:"
        if refflt_lpfreq is not None:
            print " low-pass with cutoff-freq %.1f" % refflt_lpfreq
        if refflt_hpfreq is not None:
            print "high-pass with cutoff-freq %.1f" % refflt_hpfreq
        # Drop every channel that is NOT in refpick, so fltref holds only the references.
        # NOTE(review): fltref rows follow raw's channel order; they line up
        # with raw[refpick] only if refpick is sorted ascending -- confirm.
        droplist = [raw.info['ch_names'][k] for k in xrange(raw.info['nchan']) if not k in refpick]
        fltref = raw.drop_channels(droplist, copy=True)
        tct = time.clock()
        twt = time.time()
        # MNE filter(l_freq, h_freq): 48 Hz lower edge, 52 Hz upper edge, IIR design.
        fltref.filter(refflt_hpfreq, refflt_lpfreq, picks=np.array(xrange(nref)), method='iir')
        tc1 = time.clock()
        tw1 = time.time()
        print "filtering ref-chans  took %.1f ms (%.2f s walltime)" % (1000. * (tc1 - tct), (tw1 - twt))

    print "########## Calculating sig-ref/ref-ref-channel covariances:"
    # Pass 1: sig-ref/ref-ref-channel covariance from plain running sums
    # (sum x, sum x*x, sum x y^T); the means are subtracted afterwards.
    # (there is no need to calc inter-signal-chan cov,
    #  but there seems to be no appropriate function available)
    # Here we copy the idea from compute_raw_data_covariance()
    # and truncate it as appropriate.
    tct = time.clock()
    twt = time.time()
    # The following reject and info{sig,ref} entries are only
    # used in _is_good-calls.
    # _is_good() from mne-0.9.git-py2.7.egg/mne/epochs.py seems to
    # ignore ref-channels (not covered by dict) and checks individual
    # data segments - artifacts across a buffer boundary are not found.
    reject = dict(grad=4000e-13, # T / m (gradiometers)
                  mag=4e-12,     # T (magnetometers)
                  eeg=40e-6,     # V, i.e. 40 uV (EEG channels)
                  eog=250e-6)    # V, i.e. 250 uV (EOG channels)

    # Shallow copy of info restricted to the signal channels, so that the
    # channel-type indices match the rows of raw[sigpick, ...].
    infosig = copy.copy(raw.info)
    infosig['chs'] = [raw.info['chs'][k] for k in sigpick]
    infosig['ch_names'] = [raw.info['ch_names'][k] for k in sigpick]
    infosig['nchan'] = len(sigpick)
    idx_by_typesig = channel_indices_by_type(infosig)

    # inforef not good w/ filtering, but anyway useless
    # (only referenced by the commented-out _is_good call below).
    inforef = copy.copy(raw.info)
    inforef['chs'] = [raw.info['chs'][k] for k in refpick]
    inforef['ch_names'] = [raw.info['ch_names'][k] for k in refpick]
    inforef['nchan'] = len(refpick)
    idx_by_typeref = channel_indices_by_type(inforef)

    # Read data in chunks:
    # Accumulators start as scalar 0 and become arrays on the first +=.
    sigmean = 0
    refmean = 0
    sscovdata = 0    # per-signal-channel sum of squares (diagonal only)
    srcovdata = 0    # (nsig, nref) cross products
    rrcovdata = 0    # (nref, nref) cross products
    n_samples = 0
    for first in range(itmin, itmax, itstep):
        last = first + itstep
        if last >= itmax:
            last = itmax
        raw_segmentsig, times = raw[sigpick, first:last]
        if use_reffilter:
            raw_segmentref, times = fltref[:, first:last]
        else:
            raw_segmentref, times = raw[refpick, first:last]
        # if True:
        # if _is_good(raw_segmentsig, infosig['ch_names'], idx_by_typesig, reject, flat=None,
        #            ignore_chs=raw.info['bads']) and _is_good(raw_segmentref,
        #              inforef['ch_names'], idx_by_typeref, reject, flat=None,
        #                ignore_chs=raw.info['bads']):
        if not exclart or \
           _is_good(raw_segmentsig, infosig['ch_names'], idx_by_typesig, reject,
                    flat=None, ignore_chs=raw.info['bads']):
            sigmean += raw_segmentsig.sum(axis=1)
            refmean += raw_segmentref.sum(axis=1)
            sscovdata += (raw_segmentsig * raw_segmentsig).sum(axis=1)
            srcovdata += np.dot(raw_segmentsig, raw_segmentref.T)
            rrcovdata += np.dot(raw_segmentref, raw_segmentref.T)
            n_samples += raw_segmentsig.shape[1]
        else:
            logger.info("Artefact detected in [%d, %d]" % (first, last))

    #_check_n_samples(n_samples, len(picks))
    # Convert sums to unbiased covariances: (sum xy - n*mean_x*mean_y) / (n - 1).
    # NOTE(review): if every chunk was rejected, n_samples is 0 and the
    # divisions below raise ZeroDivisionError.
    sigmean /= n_samples
    refmean /= n_samples
    sscovdata -= n_samples * sigmean[:] * sigmean[:]
    sscovdata /= (n_samples - 1)
    srcovdata -= n_samples * sigmean[:, None] * refmean[None, :]
    srcovdata /= (n_samples - 1)
    rrcovdata -= n_samples * refmean[:, None] * refmean[None, :]
    rrcovdata /= (n_samples - 1)
    # Alias, not a copy; safe because sscovdata is later rebound, not modified in place.
    sscovinit = sscovdata
    print "Normalize srcov..."
    # Divide column iref by the variance of reference iref (regression slopes);
    # columns whose variance is <= TINY are zeroed instead.
    rrslopedata = copy.copy(rrcovdata)
    for iref in xrange(nref):
        dtmp = rrcovdata[iref][iref]
        if dtmp > TINY:
            for isig in xrange(nsig):
                srcovdata[isig][iref] /= dtmp
            for jref in xrange(nref):
                rrslopedata[jref][iref] /= dtmp
        else:
            for isig in xrange(nsig):
                srcovdata[isig][iref] = 0.
            for jref in xrange(nref):
                rrslopedata[jref][iref] = 0.
    logger.info("Number of samples used : %d" % n_samples)
    tc1 = time.clock()
    tw1 = time.time()
    print "sigrefchn covar-calc took %.1f ms (%.2f s walltime)" % (1000. * (tc1 - tct), (tw1 - twt))

    print "########## Calculating sig-ref/ref-ref-channel covariances (robust):"
    # Pass 2: same covariances, updated sample by sample
    # (using B.P.Welford, "Note on a method for calculating corrected sums
    #                   of squares and products", Technometrics4 (1962) 419-420)
    # (there is no need to calc inter-signal-chan cov,
    #  but there seems to be no appropriate function available)
    # Here we copy the idea from compute_raw_data_covariance()
    # and truncate it as appropriate.
    tct = time.clock()
    twt = time.time()
    # The following reject and info{sig,ref} entries are only
    # used in _is_good-calls.
    # _is_good() from mne-0.9.git-py2.7.egg/mne/epochs.py seems to
    # ignore ref-channels (not covered by dict) and checks individual
    # data segments - artifacts across a buffer boundary are not found.
    reject = dict(grad=4000e-13, # T / m (gradiometers)
                  mag=4e-12,     # T (magnetometers)
                  eeg=40e-6,     # V, i.e. 40 uV (EEG channels)
                  eog=250e-6)    # V, i.e. 250 uV (EOG channels)

    # (Recomputed exactly as in pass 1.)
    infosig = copy.copy(raw.info)
    infosig['chs'] = [raw.info['chs'][k] for k in sigpick]
    infosig['ch_names'] = [raw.info['ch_names'][k] for k in sigpick]
    infosig['nchan'] = len(sigpick)
    idx_by_typesig = channel_indices_by_type(infosig)

    # inforef not good w/ filtering, but anyway useless
    inforef = copy.copy(raw.info)
    inforef['chs'] = [raw.info['chs'][k] for k in refpick]
    inforef['ch_names'] = [raw.info['ch_names'][k] for k in refpick]
    inforef['nchan'] = len(refpick)
    idx_by_typeref = channel_indices_by_type(inforef)

    # Read data in chunks:
    # *mean = running mean after the current sample, *mold = before it.
    smean = np.zeros(nsig)
    smold = np.zeros(nsig)
    rmean = np.zeros(nref)
    rmold = np.zeros(nref)
    sscov = 0
    srcov = 0  # dead assignment, replaced two lines below
    rrcov = np.zeros((nref, nref))
    srcov = np.zeros((nsig, nref))
    n_samples = 0
    for first in range(itmin, itmax, itstep):
        last = first + itstep
        if last >= itmax:
            last = itmax
        raw_segmentsig, times = raw[sigpick, first:last]
        if use_reffilter:
            raw_segmentref, times = fltref[:, first:last]
        else:
            raw_segmentref, times = raw[refpick, first:last]
        # if True:
        # if _is_good(raw_segmentsig, infosig['ch_names'], idx_by_typesig, reject, flat=None,
        #            ignore_chs=raw.info['bads']) and _is_good(raw_segmentref,
        #              inforef['ch_names'], idx_by_typeref, reject, flat=None,
        #                ignore_chs=raw.info['bads']):
        if not exclart or \
           _is_good(raw_segmentsig, infosig['ch_names'], idx_by_typesig, reject,
                    flat=None, ignore_chs=raw.info['bads']):
            # Welford update with the 1-based running sample count nsl:
            #   d = x - mean_old;  mean_new = mean_old + d / nsl
            #   C_ss += d_s * (x_s - mean_new)          (diagonal only)
            #   C_xy += (nsl - 1) / nsl * d_x d_y^T     (cross terms)
            # NOTE(review): pure-Python per-sample loop; slow on long recordings.
            for isl in xrange(raw_segmentsig.shape[1]):
                nsl = isl + n_samples + 1
                cnslm1dnsl = float((nsl - 1)) / float(nsl)
                sslsubmean = (raw_segmentsig[:, isl] - smold)
                rslsubmean = (raw_segmentref[:, isl] - rmold)
                smean = smold + sslsubmean / nsl
                rmean = rmold + rslsubmean / nsl
                sscov += sslsubmean * (raw_segmentsig[:, isl] - smean)
                srcov += cnslm1dnsl * np.dot(sslsubmean.reshape((nsig, 1)), rslsubmean.reshape((1, nref)))
                rrcov += cnslm1dnsl * np.dot(rslsubmean.reshape((nref, 1)), rslsubmean.reshape((1, nref)))
                smold = smean
                rmold = rmean
            n_samples += raw_segmentsig.shape[1]
        else:
            logger.info("Artefact detected in [%d, %d]" % (first, last))

    #_check_n_samples(n_samples, len(picks))
    # Corrected sums -> unbiased covariances.
    sscov /= (n_samples - 1)
    srcov /= (n_samples - 1)
    rrcov /= (n_samples - 1)
    print "Normalize srcov..."
    # Same slope normalization as pass 1, vectorized per column.
    rrslope = copy.copy(rrcov)
    for iref in xrange(nref):
        dtmp = rrcov[iref][iref]
        if dtmp > TINY:
            srcov[:, iref] /= dtmp
            rrslope[:, iref] /= dtmp
        else:
            srcov[:, iref] = 0.
            rrslope[:, iref] = 0.
    logger.info("Number of samples used : %d" % n_samples)
    # Cross-check pass 2 against pass 1 (atol=0. -> relative tolerance only).
    print "Compare results with 'standard' values:"
    print "cmp(sigmean,smean):", np.allclose(smean, sigmean, atol=0.)
    print "cmp(refmean,rmean):", np.allclose(rmean, refmean, atol=0.)
    print "cmp(sscovdata,sscov):", np.allclose(sscov, sscovdata, atol=0.)
    print "cmp(srcovdata,srcov):", np.allclose(srcov, srcovdata, atol=0.)
    print "cmp(rrcovdata,rrcov):", np.allclose(rrcov, rrcovdata, atol=0.)
    tc1 = time.clock()
    tw1 = time.time()
    print "sigrefchn covar-calc took %.1f ms (%.2f s walltime)" % (1000. * (tc1 - tct), (tw1 - twt))

    if checkresults:
        print "########## Calculated initial signal channel covariance:"
        # Calculate initial signal channel covariance:
        # (only used as quality measure)
        # NOTE(review): xrange(5) raises IndexError when fewer than 5 signal channels.
        print "initl rt(avg sig pwr) = %12.5e" % np.sqrt(np.mean(sscov))
        for i in xrange(5):
            print "initl signal-rms[%3d] = %12.5e" % (i, np.sqrt(sscov.flatten()[i]))
        print " "
    if nref < 6:
        print "rrslope-entries:"
        for i in xrange(nref):
            print rrslope[i][:]

    # numpy returns V with the right singular vectors as rows: rrslope = U diag(s) V.
    U, s, V = np.linalg.svd(rrslope, full_matrices=True)
    print s

    print "Applying cutoff for smallest SVs:"
    # Singular values below SVD_RELCUTOFF * max(s) are zeroed and get sinv = 0.
    dtmp = s.max() * SVD_RELCUTOFF
    sinv = np.zeros(nref)
    for i in xrange(nref):
        if abs(s[i]) >= dtmp:
            sinv[i] = 1. / s[i]
        else:
            s[i] = 0.
    # s *= (abs(s)>=dtmp)
    # sinv = ???
    print s
    # Reconstruction check; fails when any singular value was cut off.
    stat = np.allclose(rrslope, np.dot(U, np.dot(np.diag(s), V)))
    print ">>> Testing svd-result: %s" % stat
    if not stat:
        print "    (Maybe due to SV-cutoff?)"

    # Solve for inverse coefficients:
    # U diag(sinv) V is the TRANSPOSE of the (truncated) pseudo-inverse
    # V^T diag(sinv) U^T of rrslope -- hence the name RRinvtr.
    print ">>> Setting RRinvtr=U diag(sinv) V"
    RRinvtr = np.zeros((nref, nref))  # dead assignment, overwritten next line
    RRinvtr = np.dot(U, np.dot(np.diag(sinv), V))
    if checkresults:
        # print ">>> RRinvtr-result:"
        # print RRinvtr
        # rrslope^T . RRinvtr = V^T diag(s*sinv) V, the identity if nothing was cut off.
        stat = np.allclose(np.identity(nref), np.dot(rrslope.transpose(), RRinvtr))
        if stat:
            print ">>> Testing RRinvtr-result (shld be unit-matrix): ok"
        else:
            print ">>> Testing RRinvtr-result (shld be unit-matrix): failed"
            print np.dot(rrslope.transpose(), RRinvtr)
            # np.less_equal(np.abs(np.dot(rrslope.transpose(),RRinvtr)-np.identity(nref)),0.01*np.ones((nref,nref)))
        print ""

    print "########## Calc weight matrix..."
    # weights-matrix will be somewhat larger than necessary,
    # (to simplify indexing in compensation loop):
    # one row per raw channel; rows not in sigpick stay zero, so those
    # channels are left unchanged by the compensation below.
    # Row sigpick[isig] = srcov[isig] . RRinvtr^T, i.e. srcov times pinv(rrslope).
    weights = np.zeros((raw._data.shape[0], nref))
    for isig in xrange(nsig):
        for iref in xrange(nref):
            weights[sigpick[isig]][iref] = np.dot(srcov[isig][:], RRinvtr[iref][:])

    if np.allclose(np.zeros(weights.shape), np.abs(weights), atol=1.e-8):
        print ">>> all weights are small (<=1.e-8)."
    else:
        # Report every weight within 1% of the largest magnitude.
        print ">>> largest weight %12.5e" % np.max(np.abs(weights))
        wlrg = np.where(np.abs(weights) >= 0.99 * np.max(np.abs(weights)))
        for iwlrg in xrange(len(wlrg[0])):
            print ">>> weights[%3d,%2d] = %12.5e" % \
                  (wlrg[0][iwlrg], wlrg[1][iwlrg], weights[wlrg[0][iwlrg], wlrg[1][iwlrg]])

    if nref < 5:
        # NOTE(review): xrange(5) raises IndexError when fewer than 5 signal channels.
        print "weights-entries for first sigchans:"
        for i in xrange(5):
            print 'weights[sp(%2d)][r]=[' % i + ' '.join([' %+10.7f' %
                             val for val in weights[sigpick[i]][:]]) + ']'

    print "########## Compensating signal channels:"
    tct = time.clock()
    twt = time.time()
    # data,times = raw[:,raw.time_as_index(tmin)[0]:raw.time_as_index(tmax)[0]:]
    # Work on entire data stream (not only [tmin, tmax]):
    # for each sample, subtract weights . (ref - rmean) in place from raw._data,
    # using the reference means of the fit window.
    # NOTE(review): 'slice' shadows the builtin; this per-sample loop could
    # be vectorized as one matrix product.
    for isl in xrange(raw._data.shape[1]):
        slice = np.take(raw._data, [isl], axis=1)
        if use_reffilter:
            refslice = np.take(fltref._data, [isl], axis=1)
            refarr = refslice[:].flatten() - rmean
            # refarr = fltres[:,isl]-rmean
        else:
            refarr = slice[refpick].flatten() - rmean
        subrefarr = np.dot(weights[:], refarr)
        # data[:,isl] -= subrefarr   will not modify raw._data?
        raw._data[:, isl] -= subrefarr
        if isl%10000 == 0:
            print "\rProcessed slice %6d" % isl
    print "\nDone."
    tc1 = time.clock()
    tw1 = time.time()
    print "compensation loop took %.1f ms (%.2f s walltime)" % (1000. * (tc1 - tct), (tw1 - twt))

    if checkresults:
        print "########## Calculating final signal channel covariance:"
        # Calculate final signal channel covariance:
        # (only used as quality measure; compared against sscovinit from pass 1)
        tct = time.clock()
        twt = time.time()
        sigmean = 0
        sscovdata = 0
        n_samples = 0
        for first in range(itmin, itmax, itstep):
            last = first + itstep
            if last >= itmax:
                last = itmax
            raw_segmentsig, times = raw[sigpick, first:last]
            # Artifacts found here will probably differ from pre-noisereducer artifacts!
            if not exclart or \
               _is_good(raw_segmentsig, infosig['ch_names'], idx_by_typesig, reject,
                        flat=None, ignore_chs=raw.info['bads']):
                sigmean += raw_segmentsig.sum(axis=1)
                sscovdata += (raw_segmentsig * raw_segmentsig).sum(axis=1)
                n_samples += raw_segmentsig.shape[1]
        sigmean /= n_samples
        sscovdata -= n_samples * sigmean[:] * sigmean[:]
        sscovdata /= (n_samples - 1)
        print ">>> no channel got worse: ", np.all(np.less_equal(sscovdata, sscovinit))
        print "final rt(avg sig pwr) = %12.5e" % np.sqrt(np.mean(sscovdata))
        for i in xrange(5):
            print "final signal-rms[%3d] = %12.5e" % (i, np.sqrt(sscovdata.flatten()[i]))
        tc1 = time.clock()
        tw1 = time.time()
        print "signal covar-calc took %.1f ms (%.2f s walltime)" % (1000. * (tc1 - tct), (tw1 - twt))
        print " "

    # Output file: input stem (text before '-raw.fif') + ',nold-raw.fif'.
    nrname = dname[:dname.rfind('-raw.fif')] + ',nold-raw.fif'
    print "Saving '%s'..." % nrname
    raw.save(nrname, overwrite=True)
    tc1 = time.clock()
    tw1 = time.time()
    print "Total run         took %.1f ms (%.2f s walltime)" % (1000. * (tc1 - tc0), (tw1 - tw0))
def preprocessEpoch(eeg, info, downsample, tmin, reject=None, mne_reject=1, reject_ch=None, flat=None, bad_channels=None,
                    opt_detrend=1, HP=0, LP=40, phase='zero-double'):
    """Preprocess one continuous EEG segment with MNE and return the cleaned data.

    Steps, in order: optional linear detrending, optional removal of a fixed
    set of auxiliary channels, FIR filtering between ``HP`` and ``LP``,
    resampling, baseline correction, optional bad-channel handling,
    EEG re-referencing and a second baseline correction.

    Parameters
    ----------
    eeg : array, shape (n_samples, n_channels)
        Samples of a single epoch; transposed internally to MNE's
        (n_epochs, n_channels, n_times) layout with n_epochs == 1.
    info : mne.Info
        Measurement info matching the channels of ``eeg``.
    downsample : float
        Target sampling rate passed to ``Epochs.resample``.
    tmin : float
        Time (s) of the first sample; both baseline corrections use the
        interval from the start of the epoch up to time 0.
    reject, flat : dict | None
        Thresholds forwarded to MNE's private ``_is_good``. When ``reject``
        is None, no bad-channel handling is done at all.
    mne_reject : int
        If 1, channels flagged by ``_is_good`` are marked bad and
        interpolated; otherwise the predefined ``bad_channels`` are dropped.
    reject_ch : bool | None
        If truthy, drop the channels RAW_CQ, GYROX, GYROY and TIMESTAMP
        (whichever of them are present).
    flat : see ``reject``.
    bad_channels : list of str | None
        Channels dropped when ``reject`` is given and ``mne_reject != 1``.
        None (the default) means no channels.
    opt_detrend : int
        If 1, remove a linear trend from every channel before filtering.
    HP, LP : float
        Lower and upper edges passed to ``Epochs.filter``.
    phase : str
        Filter phase passed to ``Epochs.filter``.

    Returns
    -------
    array, shape (n_channels_kept, n_times_resampled)
        Data of the single processed epoch.
    """
    # Avoid a shared mutable default list.
    if bad_channels is None:
        bad_channels = []

    n_samples = eeg.shape[0]
    n_channels = eeg.shape[1]
    # MNE expects (n_epochs, n_channels, n_times); build a single epoch.
    eeg = np.reshape(eeg.T, (1, n_channels, n_samples))

    # Temporal detrending along the time axis:
    if opt_detrend == 1:
        eeg = detrend(eeg, axis=2, type='linear')

    epoch = mne.EpochsArray(eeg, info, tmin=tmin, baseline=None, verbose=False)

    # Drop the listed auxiliary channels (contact quality / gyro / timestamp,
    # judging by their names) if they exist in this recording. Keeping the
    # list order makes the call deterministic, unlike a set intersection.
    if reject_ch:
        bads = ['RAW_CQ', 'GYROX', 'GYROY', 'TIMESTAMP']
        present_bads = [ch for ch in bads if ch in epoch.ch_names]
        epoch.drop_channels(present_bads)

    # Filter between HP (lower edge) and LP (upper edge).
    epoch.filter(HP, LP, fir_design='firwin', phase=phase, verbose=False)

    # Downsample
    epoch.resample(downsample, npad='auto', verbose=False)

    # Baseline correction over [epoch start, 0 s].
    epoch.apply_baseline(baseline=(None, 0), verbose=False)

    # Rejection of channels, either manually defined or based on MNE analysis.
    if reject is not None:
        if mne_reject == 1:
            # Use MNE's (private) _is_good to find bad channels, then interpolate them.
            # NOTE(review): private MNE API; mne.io.pick moved to mne._fiff.pick
            # in newer MNE releases -- confirm against the installed version.
            from mne.epochs import _is_good
            from mne.io.pick import channel_indices_by_type
            idx_by_type = channel_indices_by_type(epoch.info)
            is_good, bad_channels = _is_good(epoch.get_data()[0], epoch.ch_names,
                                             channel_type_idx=idx_by_type, reject=reject,
                                             flat=flat, full_report=True)
            print(is_good)
            if not is_good:
                # _is_good may report the offending names as a tuple; Info
                # expects 'bads' to be a list.
                epoch.info['bads'] = list(bad_channels)
                epoch.interpolate_bads(reset_bads=True, verbose=False)
        else:
            # Predefined bad channels are removed outright.
            epoch.drop_channels(bad_channels)

    # Re-referencing (MNE default reference for set_eeg_reference).
    epoch.set_eeg_reference(verbose=False)

    # Apply baseline again after re-referencing.
    epoch.apply_baseline(baseline=(None, 0), verbose=False)

    # Single epoch -> (n_channels, n_times).
    return epoch.get_data()[0]