示例#1
0
def test_plot_raw_traces(raw, events, browser_backend):
    """Test plotting of raw data.

    Drives the browser returned by ``raw.plot``: marking channels bad via
    trace and name clicks, adding/removing the vline, time and channel
    scrolling, butterfly mode, zen mode, argument validation, annotations
    outside the data range, and ``event_color`` handling.
    """
    ismpl = browser_backend.name == 'matplotlib'
    with raw.info._unlock():
        raw.info['lowpass'] = 10.  # allow heavy decim during plotting
    fig = raw.plot(events=events, order=[1, 7, 5, 2, 3], n_channels=3,
                   group_by='original')
    assert hasattr(fig, 'mne')  # make sure fig.mne param object is present
    if ismpl:
        assert len(fig.axes) == 5

    # setup
    x = fig.mne.traces[0].get_xdata()[5]
    y = fig.mne.traces[0].get_ydata()[5]
    hscroll = fig.mne.ax_hscroll
    vscroll = fig.mne.ax_vscroll
    # test marking bad channels
    label = fig._get_ticklabels('y')[0]
    assert label not in fig.mne.info['bads']
    # click data to mark bad
    fig._fake_click((x, y), xform='data')
    assert label in fig.mne.info['bads']
    # click data to unmark bad
    fig._fake_click((x, y), xform='data')
    assert label not in fig.mne.info['bads']
    # click name to mark bad
    fig._click_ch_name(ch_index=0, button=1)
    assert label in fig.mne.info['bads']
    # test other kinds of clicks
    fig._fake_click((0.5, 0.98))  # click elsewhere (add vline)
    assert fig.mne.vline_visible is True
    fig._fake_click((0.5, 0.98), button=3)  # remove vline
    assert fig.mne.vline_visible is False
    fig._fake_click((0.5, 0.5), ax=hscroll)  # change time
    t_start = fig.mne.t_start
    fig._fake_click((0.5, 0.5), ax=hscroll)  # shouldn't change time this time
    assert round(t_start, 6) == round(fig.mne.t_start, 6)
    # test scrolling through channels
    labels = fig._get_ticklabels('y')
    assert labels == [raw.ch_names[1], raw.ch_names[7], raw.ch_names[5]]
    fig._fake_click((0.5, 0.05), ax=vscroll)  # change channels to end
    labels = fig._get_ticklabels('y')
    assert labels == [raw.ch_names[5], raw.ch_names[2], raw.ch_names[3]]
    # Qt draws its scrollbars differently (the slider value sits at the
    # lower end of the slider, not at its middle), so the click position
    # that lands on the "middle" channels depends on the backend.
    yclick = 0.5 if ismpl else 0.7
    for _ in range(2):
        # first click changes channels to mid; second time shouldn't change
        fig._fake_click((0.5, yclick), ax=vscroll)
        labels = fig._get_ticklabels('y')
        assert labels == [raw.ch_names[7], raw.ch_names[5], raw.ch_names[2]]

    # test clicking a channel name in butterfly mode
    bads = fig.mne.info['bads'].copy()
    fig._fake_keypress('b')
    fig._click_ch_name(ch_index=0, button=1)  # should be no-op
    assert fig.mne.info['bads'] == bads        # unchanged
    fig._fake_keypress('b')

    # test starting up in zen mode
    fig = plot_raw(raw, show_scrollbars=False)
    # test order, title, & show_options kwargs
    with pytest.raises(ValueError, match='order should be array-like; got'):
        raw.plot(order='foo')
    with pytest.raises(TypeError, match='title must be None or a string, got'):
        raw.plot(title=1)
    raw.plot(show_options=True)
    browser_backend._close_all()

    # annotations outside data range
    annot = Annotations([10, 10 + raw.first_samp / raw.info['sfreq']],
                        [10, 10], ['test', 'test'], raw.info['meas_date'])
    with pytest.warns(RuntimeWarning, match='outside data range'):
        raw.set_annotations(annot)

    # Color setting
    with pytest.raises(KeyError, match='must be strictly positive, or -1'):
        raw.plot(event_color={0: 'r'})
    with pytest.raises(TypeError, match='event_color key must be an int, got'):
        raw.plot(event_color={'foo': 'r'})
    plot_raw(raw, events=events, event_color={-1: 'r', 998: 'b'})
    # close the figure just opened so it does not leak into later tests
    browser_backend._close_all()
示例#2
0
def test_anonymize(tmpdir):
    """Test that sensitive information can be anonymized.

    Covers ``anonymize_info`` on Raw and Epochs info, the ``anonymize``
    instance method with and without ``keep_his``, a FIF save/read round
    trip of the anonymized objects, and how anonymization interacts with
    ``meas_date``, ``first_samp`` and annotation onsets / ``orig_time``.
    """
    pytest.raises(TypeError, anonymize_info, 'foo')

    # Fake some subject data
    raw = read_raw_fif(raw_fname)
    raw.set_annotations(
        Annotations(onset=[0, 1],
                    duration=[1, 1],
                    description='dummy',
                    orig_time=None))
    first_samp = raw.first_samp
    expected_onset = np.arange(2) + raw._first_time
    assert raw.first_samp == first_samp
    assert_allclose(raw.annotations.onset, expected_onset)

    # test mne.anonymize_info()
    events = read_events(event_name)
    epochs = Epochs(raw, events[:1], 2, 0., 0.1, baseline=None)
    _test_anonymize_info(raw.info.copy())
    _test_anonymize_info(epochs.info.copy())

    # test instance methods & I/O roundtrip
    for inst, keep_his in zip((raw, epochs), (True, False)):
        inst = inst.copy()

        subject_info = dict(his_id='Volunteer', sex=2, hand=1)
        # Store a copy so that ``subject_info`` stays an independent
        # reference value: comparing against the very dict held by ``info``
        # would pass even if anonymize() had edited it in place.
        inst.info['subject_info'] = subject_info.copy()
        inst.anonymize(keep_his=keep_his)

        si = inst.info['subject_info']
        if keep_his:
            assert si == subject_info
        else:
            assert si['his_id'] == '0'
            assert si['sex'] == 0
            assert 'hand' not in si

        # write to disk & read back
        if isinstance(inst, BaseRaw):
            fname, reader = 'tmp_raw.fif', read_raw_fif
        else:
            fname, reader = 'tmp_epo.fif', read_epochs
        out_path = tmpdir.join(fname)
        inst.save(out_path, overwrite=True)
        reader(out_path)

    # anonymizing must shift meas_date and annotations' orig_time together
    # while leaving first_samp and annotation onsets untouched
    raw.anonymize()
    assert raw.first_samp == first_samp
    assert_allclose(raw.annotations.onset, expected_onset)
    assert raw.annotations.orig_time == raw.info['meas_date']
    stamp = _dt_to_stamp(raw.info['meas_date'])
    assert raw.annotations.orig_time == _stamp_to_dt(stamp)

    raw.info['meas_date'] = None
    with pytest.warns(RuntimeWarning, match='None'):
        raw.anonymize()
    assert raw.annotations.orig_time is None
    assert raw.first_samp == first_samp
    assert_allclose(raw.annotations.onset, expected_onset)
示例#3
0
def test_plot_ica_sources():
    """Test plotting of ICA panel.

    Covers toggling components in ``ica.exclude`` by clicking traces and
    component names, plotting with ``picks`` that omit excluded components,
    right-clicking a component label to open a child figure, annotation
    display, error handling for mismatched bads, and plotting sources of
    Epochs and Evoked data (with and without ``labels_``).
    """
    raw = read_raw_fif(raw_fname).crop(0, 1).load_data()
    picks = _get_picks(raw)
    epochs = _get_epochs()
    raw.pick_channels([raw.ch_names[k] for k in picks])
    ica_picks = pick_types(raw.info,
                           meg=True,
                           eeg=False,
                           stim=False,
                           ecg=False,
                           eog=False,
                           exclude='bads')
    ica = ICA(n_components=2)
    ica.fit(raw, picks=ica_picks)
    ica.exclude = [1]
    fig = ica.plot_sources(raw)
    assert len(plt.get_fignums()) == 1
    # change which component is in ICA.exclude (click data trace to remove
    # current one; click name to add other one)
    fig.canvas.draw()
    x = fig.mne.traces[1].get_xdata()[5]
    y = fig.mne.traces[1].get_ydata()[5]
    _fake_click(fig, fig.mne.ax_main, (x, y), xform='data')  # exclude = []
    _click_ch_name(fig, ch_index=0, button=1)  # exclude = [0]
    fig.canvas.key_press_event(fig.mne.close_key)
    _close_event(fig)
    assert len(plt.get_fignums()) == 0
    assert_array_equal(ica.exclude, [0])
    # test when picks does not include ica.exclude.
    fig = ica.plot_sources(raw, picks=[1])
    assert len(plt.get_fignums()) == 1
    plt.close('all')

    # dtype can change int->np.int64 after load, test it explicitly
    ica.n_components_ = np.int64(ica.n_components_)

    # test clicks on y-label (need >2 secs for plot_properties() to work)
    long_raw = read_raw_fif(raw_fname).crop(0, 5).load_data()
    fig = ica.plot_sources(long_raw)
    assert len(plt.get_fignums()) == 1
    fig.canvas.draw()
    _fake_click(fig, fig.mne.ax_main, (-0.1, 0), xform='data', button=3)
    assert len(fig.mne.child_figs) == 1
    assert len(plt.get_fignums()) == 2
    # close child fig directly (workaround for mpl issue #18609)
    fig.mne.child_figs[0].canvas.key_press_event('escape')
    assert len(plt.get_fignums()) == 1
    fig.canvas.key_press_event(fig.mne.close_key)
    assert len(plt.get_fignums()) == 0
    del long_raw

    # test with annotations
    orig_annot = raw.annotations
    raw.set_annotations(Annotations([0.2], [0.1], 'Test'))
    fig = ica.plot_sources(raw)
    assert len(fig.mne.ax_main.collections) == 1
    assert len(fig.mne.ax_hscroll.collections) == 1
    raw.set_annotations(orig_annot)

    # test error handling
    raw.info['bads'] = ['MEG 0113']
    with pytest.raises(RuntimeError, match="Raw doesn't match fitted data"):
        ica.plot_sources(inst=raw)
    epochs.info['bads'] = ['MEG 0113']
    with pytest.raises(RuntimeError, match="Epochs don't match fitted data"):
        ica.plot_sources(inst=epochs)
    epochs.info['bads'] = []

    # test w/ epochs and evokeds (average once, reuse for every evoked plot)
    ica.plot_sources(epochs)
    evoked = epochs.average()
    fig = ica.plot_sources(evoked)
    # Test a click
    ax = fig.get_axes()[0]
    line = ax.lines[0]
    _fake_click(fig, ax, [line.get_xdata()[0], line.get_ydata()[0]],
                xform='data')
    _fake_click(fig, ax, [ax.get_xlim()[0], ax.get_ylim()[1]], xform='data')
    # plot with bad channels excluded
    ica.exclude = [0]
    ica.plot_sources(evoked)
    ica.labels_ = dict(eog=[0])
    ica.labels_['eog/0/crazy-channel'] = [0]
    ica.plot_sources(evoked)  # now with labels
    with pytest.raises(ValueError, match='must be of Raw or Epochs type'):
        ica.plot_sources('meeow')
    # This test asserts an empty figure list near its start, so it must not
    # leave the figures opened above behind for the next test.
    plt.close('all')
示例#4
0
def test_basics():
    """Exercise construction, combination and editing of Annotations.

    Checks constructor input validation, how onsets are shifted when raws
    carrying annotations are joined with ``concatenate_raws`` (a FIF file
    and ``RawArray`` objects with different ``first_samp``), the
    ``delete``/``append`` helpers, and that every way of setting "empty"
    annotations yields the same ``orig_time``.
    """
    raw = read_raw_fif(fif_fname)
    # XXX both of these are to be fixed in #5416
    assert raw.annotations is not None
    assert len(raw.annotations.onset) == 0
    with pytest.raises(IOError):
        read_annotations(fif_fname)

    onset = np.array(range(10))
    duration = np.ones(10)
    description = np.repeat('test', 10)
    dt = datetime.utcnow()
    meas_date = raw.info['meas_date']

    # Every supported flavour of orig_time must be accepted.
    for orig_time in (None, dt, meas_date[0], meas_date):
        annot = Annotations(onset, duration, description, orig_time)

    # Mismatched lengths / nested inputs must be rejected.
    bad_args = (
        (onset, duration, description[:9]),
        ([onset, 1], duration, description),
        (onset, [duration, 1], description),
    )
    for args in bad_args:
        with pytest.raises(ValueError):
            Annotations(*args)

    # Combining annotations with concatenate_raws
    other = raw.copy()
    delta = raw.times[-1] + 1. / raw.info['sfreq']
    orig_time = meas_date[0] + meas_date[1] * 1e-6 + other._first_time
    offset = orig_time - _handle_meas_date(other.info['meas_date'])
    annot = Annotations(onset, duration, description, orig_time)
    assert ' segments' in repr(annot)
    other.set_annotations(annot)
    assert_array_equal(other.annotations.onset, onset + offset)
    assert other.annotations is not annot  # set_annotations stores a copy
    concatenate_raws([raw, other])
    assert_and_remove_boundary_annot(raw)
    assert_allclose(onset + offset + delta, raw.annotations.onset, rtol=1e-5)
    assert_array_equal(annot.duration, raw.annotations.duration)
    assert_array_equal(raw.annotations.description, np.repeat('test', 10))

    # Combining RawArray objects whose annotations use various orig_times
    data = np.random.randn(2, 1000) * 10e-12
    sfreq = 100.
    info = create_info(ch_names=['MEG1', 'MEG2'], ch_types=['grad'] * 2,
                       sfreq=sfreq)
    info['meas_date'] = (np.pi, 0)
    pieces = []
    for first_samp in (12300, 100, 12):
        piece = RawArray(data.copy(), info, first_samp=first_samp)
        piece.set_annotations(
            Annotations([1., 2.], [.5, .5], 'x', np.pi + first_samp / sfreq))
        pieces.append(piece)
    tail = RawArray(data.copy(), info)
    tail.set_annotations(Annotations([1.], [.5], 'x', None))
    pieces.append(tail)
    raw = concatenate_raws(pieces, verbose='debug')
    assert_and_remove_boundary_annot(raw, 3)
    assert_array_equal(raw.annotations.onset,
                       [124., 125., 134., 135., 144., 145., 154.])

    # Editing in place
    raw.annotations.delete(2)
    assert_array_equal(raw.annotations.onset,
                       [124., 125., 135., 144., 145., 154.])
    raw.annotations.append(5, 1.5, 'y')
    assert_array_equal(raw.annotations.onset,
                       [5., 124., 125., 135., 144., 145., 154.])
    assert_array_equal(raw.annotations.duration, [1.5] + [.5] * 6)
    assert_array_equal(raw.annotations.description, ['y'] + ['x'] * 6)

    # These three ways of clearing annotations must be equivalent
    date = raw.info['meas_date']
    expected_orig_time = date[0] + date[1] / 1000000
    candidates = (Annotations([], [], [], expected_orig_time),
                  Annotations([], [], [], None),
                  None)
    for empty_annot in candidates:
        raw.set_annotations(empty_annot)
        assert isinstance(raw.annotations, Annotations)
        assert len(raw.annotations) == 0
        assert raw.annotations.orig_time == expected_orig_time
示例#5
0
def test_plot_raw():
    """Test plotting of raw data.

    Drives the matplotlib raw browser with synthetic mouse clicks, scroll
    events and key presses for ``group_by`` 'original', 'selection' and
    'position', checks ``event_color`` validation, and plots with
    annotations and with a one-element ``meas_date``. Warnings emitted in
    the test body are recorded by ``warnings.catch_warnings(record=True)``
    and not inspected.
    """
    import matplotlib.pyplot as plt
    raw = _get_raw()
    raw.info['lowpass'] = 10.  # allow heavy decim during plotting
    events = _get_events()
    plt.close('all')  # ensure all are closed
    with warnings.catch_warnings(record=True):
        fig = raw.plot(events=events,
                       show_options=True,
                       order=[1, 7, 3],
                       group_by='original')
        # test mouse clicks
        x = fig.get_axes()[0].lines[1].get_xdata().mean()
        y = fig.get_axes()[0].lines[1].get_ydata().mean()
        data_ax = fig.axes[0]

        _fake_click(fig, data_ax, [x, y], xform='data')  # mark a bad channel
        _fake_click(fig, data_ax, [x, y], xform='data')  # unmark a bad channel
        _fake_click(fig, data_ax, [0.5, 0.999])  # click elsewhere in 1st axes
        _fake_click(fig, data_ax, [-0.1, 0.9])  # click on y-label
        _fake_click(fig, fig.get_axes()[1], [0.5, 0.5])  # change time
        _fake_click(fig, fig.get_axes()[2], [0.5, 0.5])  # change channels
        _fake_click(fig, fig.get_axes()[3], [0.5, 0.5])  # open SSP window
        fig.canvas.button_press_event(1, 1, 1)  # outside any axes
        fig.canvas.scroll_event(0.5, 0.5, -0.5)  # scroll down
        fig.canvas.scroll_event(0.5, 0.5, 0.5)  # scroll up
        # sadly these fail when no renderer is used (i.e., when using Agg):
        # ssp_fig = set(plt.get_fignums()) - set([fig.number])
        # assert_equal(len(ssp_fig), 1)
        # ssp_fig = plt.figure(list(ssp_fig)[0])
        # ax = ssp_fig.get_axes()[0]  # only one axis is used
        # t = [c for c in ax.get_children() if isinstance(c,
        #      matplotlib.text.Text)]
        # pos = np.array(t[0].get_position()) + 0.01
        # _fake_click(ssp_fig, ssp_fig.get_axes()[0], pos, xform='data')  # off
        # _fake_click(ssp_fig, ssp_fig.get_axes()[0], pos, xform='data')  # on
        #  test keypresses
        for key in [
                'down', 'up', 'right', 'left', 'o', '-', '+', '=', 'pageup',
                'pagedown', 'home', 'end', '?', 'f11', 'escape'
        ]:
            fig.canvas.key_press_event(key)
        fig = plot_raw(raw, events=events, group_by='selection')
        for key in [
                'b', 'down', 'up', 'right', 'left', 'o', '-', '+', '=',
                'pageup', 'pagedown', 'home', 'end', '?', 'f11', 'b', 'escape'
        ]:
            fig.canvas.key_press_event(key)
        # Color setting
        assert_raises(KeyError, raw.plot, event_color={0: 'r'})
        assert_raises(TypeError, raw.plot, event_color={'foo': 'r'})
        annot = Annotations([10, 10 + raw.first_samp / raw.info['sfreq']],
                            [10, 10], ['test', 'test'], raw.info['meas_date'])
        raw.annotations = annot
        fig = plot_raw(raw, events=events, event_color={-1: 'r', 998: 'b'})
        plt.close('all')
        for group_by, order in zip(
            ['position', 'selection'],
            [np.arange(len(raw.ch_names))[::-3], [1, 2, 4, 6]]):
            fig = raw.plot(group_by=group_by, order=order)
            # x/y are read from *this* figure's data axes, so the click must
            # be transformed with those same axes; the ``data_ax`` of the
            # first figure (closed above) could map them to another pixel.
            data_ax = fig.get_axes()[0]
            x = data_ax.lines[1].get_xdata()[10]
            y = data_ax.lines[1].get_ydata()[10]
            _fake_click(fig, data_ax, [x, y], xform='data')  # mark bad
            fig.canvas.key_press_event('down')  # change selection
            _fake_click(fig, fig.get_axes()[2], [0.5, 0.5])  # change channels
            sel_fig = plt.figure(1)
            topo_ax = sel_fig.axes[1]
            _fake_click(sel_fig, topo_ax, [-0.425, 0.20223853], xform='data')
            fig.canvas.key_press_event('down')
            fig.canvas.key_press_event('up')
            fig.canvas.scroll_event(0.5, 0.5, -1)  # scroll down
            fig.canvas.scroll_event(0.5, 0.5, 1)  # scroll up
            _fake_click(sel_fig, topo_ax, [-0.5, 0.], xform='data')
            _fake_click(sel_fig,
                        topo_ax, [0.5, 0.],
                        xform='data',
                        kind='motion')
            _fake_click(sel_fig,
                        topo_ax, [0.5, 0.5],
                        xform='data',
                        kind='motion')
            _fake_click(sel_fig,
                        topo_ax, [-0.5, 0.5],
                        xform='data',
                        kind='release')

            plt.close('all')
        # test if meas_date has only one element
        raw.info['meas_date'] = np.array([raw.info['meas_date'][0]],
                                         dtype=np.int32)
        raw.annotations = Annotations([1 + raw.first_samp / raw.info['sfreq']],
                                      [5], ['bad'])
        raw.plot(group_by='position', order=np.arange(8))
        for fig_num in plt.get_fignums():
            fig = plt.figure(fig_num)
            if hasattr(fig, 'radio'):  # Get access to selection fig.
                break
        for key in ['down', 'up', 'escape']:
            fig.canvas.key_press_event(key)
        plt.close('all')
示例#6
0
def test_plot_raw():
    """Test plotting of raw data.

    Drives the matplotlib raw browser with synthetic mouse, scroll and key
    events: marking bad channels, toggling SSP projectors in the projector
    window, keyboard navigation for ``group_by`` 'original', 'selection' and
    'position', zen mode, ``event_color`` validation, annotations outside the
    data range, a truncated ``meas_date`` and all-NaN data.
    """
    raw = _get_raw()
    raw.info['lowpass'] = 10.  # allow heavy decim during plotting
    events = _get_events()
    plt.close('all')  # ensure all are closed
    assert len(plt.get_fignums()) == 0
    fig = raw.plot(events=events, order=[1, 7, 3], group_by='original')
    assert len(plt.get_fignums()) == 1

    # make sure fig._mne_params is present
    assert isinstance(fig._mne_params, dict)

    # test mouse clicks
    x = fig.get_axes()[0].lines[1].get_xdata().mean()
    y = fig.get_axes()[0].lines[1].get_ydata().mean()
    data_ax = fig.axes[0]

    _fake_click(fig, data_ax, [x, y], xform='data')  # mark a bad channel
    _fake_click(fig, data_ax, [x, y], xform='data')  # unmark a bad channel
    _fake_click(fig, data_ax, [0.5, 0.999])  # click elsewhere in 1st axes
    _fake_click(fig, data_ax, [-0.1, 0.9])  # click on y-label
    _fake_click(fig, fig.get_axes()[1], [0.5, 0.5])  # change time
    _fake_click(fig, fig.get_axes()[2], [0.5, 0.5])  # change channels
    assert len(plt.get_fignums()) == 1
    # open SSP window
    _fake_click(fig, fig.get_axes()[-1], [0.5, 0.5])
    _fake_click(fig, fig.get_axes()[-1], [0.5, 0.5], kind='release')
    assert len(plt.get_fignums()) == 2
    ssp_fig = plt.figure(plt.get_fignums()[-1])
    fig.canvas.button_press_event(1, 1, 1)  # outside any axes
    fig.canvas.scroll_event(0.5, 0.5, -0.5)  # scroll down
    fig.canvas.scroll_event(0.5, 0.5, 0.5)  # scroll up

    ax = ssp_fig.get_axes()[0]  # only one axis is used
    assert _proj_status(ax) == [True] * 3
    t = [c for c in ax.get_children() if isinstance(c, matplotlib.text.Text)]
    pos = np.array(t[0].get_position()) + 0.01
    _fake_click(ssp_fig, ssp_fig.get_axes()[0], pos, xform='data')  # off
    assert _proj_status(ax) == [False, True, True]
    _fake_click(ssp_fig, ssp_fig.get_axes()[0], pos, xform='data')  # on
    assert _proj_status(ax) == [True] * 3
    _fake_click(ssp_fig, ssp_fig.get_axes()[1], [0.5, 0.5])  # all off
    _fake_click(ssp_fig, ssp_fig.get_axes()[1], [0.5, 0.5], kind='release')
    assert _proj_status(ax) == [False] * 3
    assert fig._mne_params['projector'] is None  # actually off
    _fake_click(ssp_fig, ssp_fig.get_axes()[1], [0.5, 0.5])  # all on
    _fake_click(ssp_fig, ssp_fig.get_axes()[1], [0.5, 0.5], kind='release')
    assert fig._mne_params['projector'] is not None  # on
    assert _proj_status(ax) == [True] * 3

    # test keypresses
    # test for group_by='original'
    for key in ['down', 'up', 'right', 'left', 'o', '-', '+', '=', 'd', 'd',
                'pageup', 'pagedown', 'home', 'end', '?', 'f11', 'z',
                'escape']:
        fig.canvas.key_press_event(key)

    # test for group_by='selection'
    fig = plot_raw(raw, events=events, group_by='selection')
    for key in ['b', 'down', 'up', 'right', 'left', 'o', '-', '+', '=', 'd',
                'd', 'pageup', 'pagedown', 'home', 'end', '?', 'f11', 'b', 'z',
                'escape']:
        fig.canvas.key_press_event(key)

    # test zen mode
    fig = plot_raw(raw, show_scrollbars=False)

    # Color setting
    pytest.raises(KeyError, raw.plot, event_color={0: 'r'})
    pytest.raises(TypeError, raw.plot, event_color={'foo': 'r'})
    annot = Annotations([10, 10 + raw.first_samp / raw.info['sfreq']],
                        [10, 10], ['test', 'test'], raw.info['meas_date'])
    with pytest.warns(RuntimeWarning, match='outside data range'):
        raw.set_annotations(annot)
    fig = plot_raw(raw, events=events, event_color={-1: 'r', 998: 'b'})
    plt.close('all')
    for group_by, order in zip(['position', 'selection'],
                               [np.arange(len(raw.ch_names))[::-3],
                                [1, 2, 4, 6]]):
        with pytest.warns(None):  # sometimes projection
            fig = raw.plot(group_by=group_by, order=order)
        # x/y are read from *this* figure's data axes, so transform the click
        # with those same axes instead of the ``data_ax`` of the very first
        # figure, which was closed by ``plt.close('all')`` above.
        data_ax = fig.get_axes()[0]
        x = data_ax.lines[1].get_xdata()[10]
        y = data_ax.lines[1].get_ydata()[10]
        with pytest.warns(None):  # old mpl (at least 2.0) can warn
            _fake_click(fig, data_ax, [x, y], xform='data')  # mark bad
        fig.canvas.key_press_event('down')  # change selection
        _fake_click(fig, fig.get_axes()[2], [0.5, 0.5])  # change channels
        sel_fig = plt.figure(1)
        topo_ax = sel_fig.axes[1]
        _fake_click(sel_fig, topo_ax, [-0.425, 0.20223853],
                    xform='data')
        fig.canvas.key_press_event('down')
        fig.canvas.key_press_event('up')
        fig.canvas.scroll_event(0.5, 0.5, -1)  # scroll down
        fig.canvas.scroll_event(0.5, 0.5, 1)  # scroll up
        _fake_click(sel_fig, topo_ax, [-0.5, 0.], xform='data')
        _fake_click(sel_fig, topo_ax, [0.5, 0.], xform='data',
                    kind='motion')
        _fake_click(sel_fig, topo_ax, [0.5, 0.5], xform='data',
                    kind='motion')
        _fake_click(sel_fig, topo_ax, [-0.5, 0.5], xform='data',
                    kind='release')

        plt.close('all')
    # test if meas_date is off
    raw.set_meas_date(_dt_to_stamp(raw.info['meas_date'])[0])
    annot = Annotations([1 + raw.first_samp / raw.info['sfreq']],
                        [5], ['bad'])
    with pytest.warns(RuntimeWarning, match='outside data range'):
        raw.set_annotations(annot)
    with pytest.warns(None):  # sometimes projection
        raw.plot(group_by='position', order=np.arange(8))
    for fig_num in plt.get_fignums():
        fig = plt.figure(fig_num)
        if hasattr(fig, 'radio'):  # Get access to selection fig.
            break
    for key in ['down', 'up', 'escape']:
        fig.canvas.key_press_event(key)

    raw._data[:] = np.nan
    # this should (at least) not die, the output should pretty clearly show
    # that there is a problem so probably okay to just plot something blank
    with pytest.warns(None):
        raw.plot(scalings='auto')

    plt.close('all')
示例#7
0
def test_crop():
    """Check that annotations follow cropping, concatenation and FIF I/O.

    Annotations placed on the stim events of a FIF recording must be split
    correctly by ``crop`` and restored by ``concatenate_raws``. Mixing
    annotations with and without ``orig_time`` must give the expected
    onsets. Annotations, including empty ones, must survive a save/read
    round trip, both on their own and attached to a raw file.
    """
    raw = read_raw_fif(fif_fname)
    events = mne.find_events(raw)
    onset = events[events[:, 2] == 1, 0] / raw.info['sfreq']
    duration = np.full_like(onset, 0.5)
    description = ['bad %d' % k for k in range(len(onset))]
    raw.set_annotations(mne.Annotations(onset, duration, description,
                                        orig_time=raw.info['meas_date']))

    def _check_part(part, sel):
        """Assert ``part`` holds exactly the annotations ``sel`` picks."""
        assert_array_equal(part.annotations.description,
                           raw.annotations.description[sel])
        for name in ('duration', 'onset'):
            assert_allclose(getattr(part.annotations, name),
                            getattr(raw.annotations, name)[sel])

    split_time = raw.times[-1] / 2. + 2.
    split_idx = len(onset) // 2 + 1
    left = raw.copy().crop(0., split_time - 1. / raw.info['sfreq'])
    _check_part(left, slice(None, split_idx))
    right = raw.copy().crop(split_time, None)
    _check_part(right, slice(split_idx, None))

    joined = mne.concatenate_raws([left, right], verbose='debug')
    assert_allclose(joined.times, raw.times)
    assert_allclose(joined[:][0], raw[:][0], atol=1e-20)
    assert_and_remove_boundary_annot(joined)
    # Annotations must survive the crop -> concatenate round trip
    assert_array_equal(joined.annotations.description,
                       raw.annotations.description)
    for name in ('onset', 'duration'):
        assert_allclose(getattr(joined.annotations, name),
                        getattr(raw.annotations, name),
                        err_msg='Failed for %s:' % (name,))

    raw.set_annotations(None)  # undo

    # Concatenate annotations that do and do not carry an orig_time
    second = raw.copy()
    raw.set_annotations(Annotations([45.], [3], 'test', raw.info['meas_date']))
    second.set_annotations(Annotations([2.], [3], 'BAD', None))
    expected_onset = [45., 2. + raw._last_time]
    raw = concatenate_raws([raw, second])
    assert_and_remove_boundary_annot(raw)
    assert_array_almost_equal(raw.annotations.onset, expected_onset, decimal=2)

    # Standalone annotation file I/O
    tempdir = _TempDir()
    annot_path = op.join(tempdir, 'test-annot.fif')
    raw.annotations.save(annot_path)
    annot_read = read_annotations(annot_path)
    for name in ('onset', 'duration', 'orig_time'):
        assert_allclose(getattr(annot_read, name),
                        getattr(raw.annotations, name))
    assert_array_equal(annot_read.description, raw.annotations.description)
    Annotations((), (), ()).save(annot_path)
    with pytest.raises(IOError):  # none in old raw
        read_annotations(fif_fname)
    annot = read_annotations(annot_path)
    assert isinstance(annot, Annotations)
    assert len(annot) == 0

    # Empty annotations can be saved together with a raw object
    raw_path = op.join(tempdir, 'test_raw.fif')
    raw.set_annotations(annot)
    raw.save(raw_path)
    raw_read = read_raw_fif(raw_path)
    assert isinstance(raw_read.annotations, Annotations)
    assert len(raw_read.annotations) == 0
    raw.set_annotations(None)
    raw.save(raw_path, overwrite=True)
    raw_read = read_raw_fif(raw_path)
    # XXX both of these are to be fixed in #5416
    assert raw_read.annotations is not None
    assert len(raw_read.annotations.onset) == 0
示例#8
0
def test_plot_raw_traces(raw, events, browse_backend):
    """Exercise the interactive raw browser through synthetic events.

    Covers toggling bad channels by clicking a trace or a channel name,
    vline clicks, horizontal and vertical scrolling, butterfly mode, zen
    mode, argument validation, annotations outside the data range, and
    ``event_color`` validation.
    """
    raw.info['lowpass'] = 10.  # allow heavy decim during plotting
    fig = raw.plot(events=events, order=[1, 7, 5, 2, 3], n_channels=3,
                   group_by='original')
    assert hasattr(fig, 'mne')  # the fig.mne param object must exist
    assert len(fig.axes) == 5

    # setup
    first_trace = fig.mne.traces[0]
    x, y = first_trace.get_xdata()[5], first_trace.get_ydata()[5]
    data_ax = fig.mne.ax_main
    # ToDo: The interaction with scrollbars will be different in pyqtgraph.
    hscroll, vscroll = fig.mne.ax_hscroll, fig.mne.ax_vscroll

    def _shown_names():
        """Return the channel names currently shown as y tick labels."""
        return [tick.get_text() for tick in data_ax.get_yticklabels()]

    # clicking the trace toggles the top channel's bad status on and off
    label = _shown_names()[0]
    assert label not in fig.mne.info['bads']
    for expect_bad in (True, False):
        fig._fake_click((x, y), xform='data')
        assert (label in fig.mne.info['bads']) is expect_bad
    # clicking the channel name marks it bad as well
    fig._click_ch_name(ch_index=0, button=1)
    assert label in fig.mne.info['bads']

    # vline: left click adds it, right click removes it
    fig._fake_click((0.5, 0.999))
    fig._fake_click((0.5, 0.999), button=3)
    # first click changes time; repeating it shouldn't change time again
    for _ in range(2):
        fig._fake_click((0.5, 0.5), ax=hscroll)

    # scrolling through channels
    names = raw.ch_names
    assert _shown_names() == [names[1], names[7], names[5]]
    fig._fake_click((0.5, 0.01), ax=vscroll)  # jump to the last channels
    assert _shown_names() == [names[5], names[2], names[3]]
    for _ in range(2):
        # first click changes channels to mid; second time shouldn't change
        fig._fake_click((0.5, 0.5), ax=vscroll)
        assert _shown_names() == [names[7], names[5], names[2]]
        assert browse_backend._get_n_figs() == 1

    # in butterfly mode a channel-name click must leave the bads alone
    bads_before = fig.mne.info['bads'].copy()
    fig._fake_keypress('b')
    fig._click_ch_name(ch_index=0, button=1)
    assert fig.mne.info['bads'] == bads_before
    fig._fake_keypress('b')

    # test starting up in zen mode
    fig = plot_raw(raw, show_scrollbars=False)
    # test order, title, & show_options kwargs
    with pytest.raises(ValueError, match='order should be array-like; got'):
        raw.plot(order='foo')
    with pytest.raises(TypeError, match='title must be None or a string, got'):
        raw.plot(title=1)
    raw.plot(show_options=True)
    browse_backend._close_all()

    # annotations outside data range
    sfreq = raw.info['sfreq']
    annot = Annotations([10, 10 + raw.first_samp / sfreq],
                        [10, 10], ['test', 'test'], raw.info['meas_date'])
    with pytest.warns(RuntimeWarning, match='outside data range'):
        raw.set_annotations(annot)

    # Color setting
    with pytest.raises(KeyError, match='must be strictly positive, or -1'):
        raw.plot(event_color={0: 'r'})
    with pytest.raises(TypeError, match='event_color key must be an int, got'):
        raw.plot(event_color={'foo': 'r'})
    fig = plot_raw(raw, events=events, event_color={-1: 'r', 998: 'b'})
    browse_backend._close_all()
示例#9
0
def test_plot_raw_psd(raw, raw_orig):
    """Test plotting of raw PSDs.

    Exercises ``Raw.plot_psd`` and ``Raw.plot_psd_topo`` with different
    averaging, axes, scaling and channel-type options, and checks their
    argument validation. ``raw`` has its channel locations and data modified
    along the way; an untouched copy is kept for the later checks.
    """
    raw_unchanged = raw.copy()
    # normal mode
    fig = raw.plot_psd(average=False)
    fig.canvas.resize_event()
    # specific mode
    picks = pick_types(raw.info, meg='mag', eeg=False)[:4]
    raw.plot_psd(tmax=None,
                 picks=picks,
                 area_mode='range',
                 average=False,
                 spatial_colors=True)
    raw.plot_psd(tmax=20.,
                 color='yellow',
                 dB=False,
                 line_alpha=0.4,
                 n_overlap=0.1,
                 average=False)
    plt.close('all')
    # one axes supplied
    ax = plt.axes()
    raw.plot_psd(tmax=None, picks=picks, ax=ax, average=True)
    plt.close('all')
    # two axes supplied
    _, axs = plt.subplots(2)
    raw.plot_psd(tmax=None, ax=axs, average=True)
    plt.close('all')
    # need 2, got 1
    ax = plt.axes()
    with pytest.raises(ValueError, match='of length 2, while the length is 1'):
        raw.plot_psd(ax=ax, average=True)
    plt.close('all')
    # topo psd
    ax = plt.subplot()
    raw.plot_psd_topo(axes=ax)
    plt.close('all')
    # with channel information not available
    for idx in range(len(raw.info['chs'])):
        raw.info['chs'][idx]['loc'] = np.zeros(12)
    with pytest.warns(RuntimeWarning, match='locations not available'):
        raw.plot_psd(spatial_colors=True, average=False)
    # with a flat channel
    raw[5, :] = 0
    for dB, estimate in itertools.product((True, False),
                                          ('power', 'amplitude')):
        # The flat channel is expected to trigger a warning mentioning either
        # "Infinite" or "Zero" values. Use a regex group here: a character
        # class such as [Infinite|Zero] matches any single one of those
        # letters and therefore accepts almost any warning message.
        with pytest.warns(UserWarning, match='(Infinite|Zero)'):
            fig = raw.plot_psd(average=True, dB=dB, estimate=estimate)
        # check grad axes
        title = fig.axes[0].get_title()
        ylabel = fig.axes[0].get_ylabel()
        ends_dB = ylabel.endswith('mathrm{(dB)}$')
        unit = '(fT/cm)²/Hz' if estimate == 'power' else r'fT/cm/\sqrt{Hz}'
        assert title == 'Gradiometers', title
        assert unit in ylabel, ylabel
        if dB:
            assert ends_dB, ylabel
        else:
            assert not ends_dB, ylabel
        # check mag axes
        title = fig.axes[1].get_title()
        ylabel = fig.axes[1].get_ylabel()
        unit = 'fT²/Hz' if estimate == 'power' else r'fT/\sqrt{Hz}'
        assert title == 'Magnetometers', title
        assert unit in ylabel, ylabel
    # test reject_by_annotation
    raw = raw_unchanged
    raw.set_annotations(Annotations([1, 5], [3, 3], ['test', 'test']))
    raw.plot_psd(reject_by_annotation=True)
    raw.plot_psd(reject_by_annotation=False)
    plt.close('all')

    # test fmax value checking
    with pytest.raises(ValueError, match='must not exceed ½ the sampling'):
        raw.plot_psd(fmax=50000)

    # test xscale value checking
    with pytest.raises(ValueError, match="Invalid value for the 'xscale'"):
        raw.plot_psd(xscale='blah')

    # gh-5046
    raw = raw_orig.crop(0, 1)
    picks = pick_types(raw.info, meg=True)
    raw.plot_psd(picks=picks, average=False)
    raw.plot_psd(picks=picks, average=True)
    plt.close('all')
    raw.set_channel_types(
        {
            'MEG 0113': 'hbo',
            'MEG 0112': 'hbr',
            'MEG 0122': 'fnirs_cw_amplitude',
            'MEG 0123': 'fnirs_od'
        },
        verbose='error')
    fig = raw.plot_psd()
    assert len(fig.axes) == 10
    plt.close('all')

    # gh-7631
    data = 1e-3 * np.random.rand(2, 100)
    info = create_info(['CH1', 'CH2'], 100)
    raw = RawArray(data, info)
    picks = pick_types(raw.info, misc=True)
    raw.plot_psd(picks=picks, spatial_colors=False)
    plt.close('all')
示例#10
0
def simulate_nirs_raw(sfreq=3.,
                      amplitude=1.,
                      annot_desc='A',
                      sig_dur=300.,
                      stim_dur=5.,
                      isi_min=15.,
                      isi_max=45.,
                      ch_name='Simulated',
                      hrf_model='glover'):
    """
    Create simulated fNIRS data.

    The returned data is of type `hbo`.
    One or more conditions can be simulated.
    To simulate multiple conditions pass in a description and amplitude
    for each
    `amplitude=[0., 2., 4.], annot_desc=['Control', 'Cond_A', 'Cond_B']`.

    Parameters
    ----------
    sfreq : Number
        The sample rate.
    amplitude : Number, Array of numbers
        The amplitude of the signal to simulate in uM.
        Pass in an array (list, tuple or 1-D numpy array) to simulate
        multiple conditions.
    annot_desc : str, Array of str
        The name of the annotations for simulated amplitudes.
        Pass in an array to simulate multiple conditions,
        must be the same length as amplitude.
    sig_dur : Number
        The length of the boxcar signal to generate in seconds that will
        be convolved with the HRF.
    stim_dur : Number, Array of numbers
        The length of the stimulus to generate in seconds.
        A single number is used for every condition; an array must be the
        same length as amplitude.
    isi_min : Number
        The minimum duration of the inter stimulus interval in seconds.
    isi_max : Number
        The maximum duration of the inter stimulus interval in seconds.
    ch_name : str
        Channel name to be used in returned raw instance.
    hrf_model : str
        Specifies the hemodynamic response function. See nilearn docs.

    Returns
    -------
    raw : instance of Raw
        The generated raw instance.
    """
    from nilearn.glm.first_level import make_first_level_design_matrix
    from pandas import DataFrame

    def _is_sequence(value):
        # Lists, tuples and non-scalar numpy arrays hold one entry per
        # condition; numbers and strings are a single entry.
        return (isinstance(value, (list, tuple))
                or (isinstance(value, np.ndarray) and value.ndim > 0))

    def _as_list(value):
        return list(value) if _is_sequence(value) else [value]

    amplitude = _as_list(amplitude)
    annot_desc = _as_list(annot_desc)
    if _is_sequence(stim_dur):
        stim_dur = list(stim_dur)
    else:
        # Broadcast a single duration to every condition so the documented
        # multi-condition example works without a list of durations.
        stim_dur = [stim_dur] * len(amplitude)

    frame_times = np.arange(sig_dur * sfreq) / sfreq

    assert len(amplitude) == len(annot_desc), "Same number of amplitudes as " \
                                              "annotations required."
    assert len(amplitude) == len(stim_dur), "Same number of amplitudes as " \
                                            "durations required."

    # Draw stimuli (random condition, random ISI) until 60 s before the end
    # of the signal. The order of the np.random calls is kept so seeded runs
    # stay reproducible.
    onset = 0.
    onsets = []
    conditions = []
    durations = []
    while onset < sig_dur - 60:
        c_idx = np.random.randint(0, len(amplitude))
        onset += np.random.uniform(isi_min, isi_max) + stim_dur[c_idx]
        onsets.append(onset)
        conditions.append(annot_desc[c_idx])
        durations.append(stim_dur[c_idx])

    events = DataFrame({
        'trial_type': conditions,
        'onset': onsets,
        'duration': durations
    })

    dm = make_first_level_design_matrix(frame_times,
                                        events,
                                        hrf_model=hrf_model,
                                        drift_model='polynomial',
                                        drift_order=0)
    dm = dm.drop(columns='constant')

    annotations = Annotations(onsets, durations, conditions)

    info = create_info(ch_names=[ch_name], sfreq=sfreq, ch_types=['hbo'])

    # Scale each condition's regressor by its amplitude; a condition that was
    # never drawn has no column in the design matrix.
    for idx, annot in enumerate(annot_desc):
        if annot in dm.columns:
            dm[annot] *= amplitude[idx]

    # Sum the regressors into one channel and convert uM to M.
    a = np.sum(dm.to_numpy(), axis=1) * 1.e-6
    a = a.reshape(-1, 1).T

    raw = RawArray(a, info, verbose=False)
    raw.set_annotations(annotations)

    return raw
                     output='onset',
                     min_duration=0.001)

###############################################################################
# 4) Extract events from the status channel and save them as file annotations
# events to data frame
# NOTE(review): if ``events`` comes from ``find_events``, its middle column is
# the stim value *before* each event rather than a duration, yet it is
# labelled and used as 'duration' below — confirm this is intended.
events = pd.DataFrame(events, columns=['onset', 'duration', 'description'])
# onset to seconds
events['onset_in_s'] = events['onset'] / sfreq
# sort by onset
events = events.sort_values(by=['onset_in_s'])
# only keep relevant events (trigger codes up to 245)
events = events.loc[(events['description'] <= 245)]

# create annotations object. The onsets are sample indices divided by sfreq,
# so anchor them to the measurement date via ``orig_time`` instead of leaving
# them relative to the first sample of the data.
annotations = Annotations(events['onset_in_s'], events['duration'],
                          events['description'],
                          orig_time=raw.info['meas_date'])
# apply to raw data
raw.set_annotations(annotations)

###############################################################################
# 5) Export to bids
# file name compliant with bids
bids_path = BIDSPath(subject=str(subject).rjust(3, '0'),
                     task=task_name,
                     root=fname.data_dir)

# save in bids format
write_raw_bids(raw, bids_path, overwrite=True)

###############################################################################
# 7) Plot the data for report
示例#12
0
                     output='onset',
                     min_duration=0.001)

# 6) Extract events from the status channel and save them as file annotations
# Tabulate the detected events, add their onsets in seconds, put them in
# chronological order and keep only trigger codes up to 245.
# NOTE(review): if ``events`` comes from ``find_events``, its middle column is
# the stim value before each event, not a duration — confirm that using it as
# the annotation duration below is intended.
events = (
    pd.DataFrame(events, columns=['onset', 'duration', 'description'])
    .assign(onset_in_s=lambda frame: frame['onset'] / raw.info['sfreq'])
    .sort_values(by=['onset_in_s'])
    .loc[lambda frame: frame['description'] <= 245]
)

# Turn the retained events into annotations anchored at the measurement
# date, then attach them to the recording.
annotations = Annotations(events['onset_in_s'],
                          events['duration'],
                          events['description'],
                          orig_time=raw.info['meas_date'])
raw.set_annotations(annotations)

###############################################################################
# 5) Export to bids
# Describe the BIDS target (zero-padded subject label) and write the data.
bids_path = BIDSPath(subject=str(subject).rjust(3, '0'),
                     task=task_name,
                     root=fname.data_dir)
write_raw_bids(raw, bids_path, overwrite=True)

###############################################################################
    duration = np.repeat(2, len(onsets))
    description = np.repeat('Bad', len(onsets))

    # get annotations in data
    artifacts = np.array((onsets, duration, description)).T
    # to pandas data frame
    artifacts = pd.DataFrame(artifacts, columns=annot_infos)
    # annotations from data
    annotations = pd.DataFrame(raw_copy.annotations)
    annotations = annotations[annot_infos]

    # merge artifacts and previous annotations
    artifacts = artifacts.append(annotations, ignore_index=True)

    # create new annotation info
    annotations = Annotations(artifacts['onset'], artifacts['duration'],
                              artifacts['description'])
    # apply to raw data
    raw.set_annotations(annotations)

# save total annotated time
# NOTE(review): ``duration`` is whatever the indented block above (presumably
# a loop) bound last, so this sums only that final batch of annotations —
# confirm a total over all iterations is not what was intended.
total_time = sum(duration)
# save frequency of annotation per channel: each occurrence in
# ``annotated_channels`` is weighted by 2, presumably mirroring the fixed
# duration of 2 given to every 'Bad' annotation above. Counter counts in one
# pass and keeps first-occurrence order, like the previous comprehension.
from collections import Counter  # noqa: E402
frequency_of_annotation = {
    x: count * 2
    for x, count in Counter(annotated_channels).items()
}

# create plot with clean data
plot_artefacts = raw.plot(scalings=dict(eeg=50e-6, eog=50e-6),
                          n_channels=len(raw.info['ch_names']),
                          title='Robust reference applied to Sub-%s' % subject,
示例#14
0
def test_plot_raw():
    """Test plotting of raw data.

    Drives ``Raw.plot`` with synthetic mouse clicks, scroll events and
    keypresses, checks ``event_color`` validation, and re-plots with several
    channel ``order`` settings.
    """
    import matplotlib.pyplot as plt
    raw = _get_raw()
    events = _get_events()
    plt.close('all')  # ensure all are closed
    with warnings.catch_warnings(record=True):
        fig = raw.plot(events=events, show_options=True)
        # test mouse clicks
        x = fig.get_axes()[0].lines[1].get_xdata().mean()
        y = fig.get_axes()[0].lines[1].get_ydata().mean()
        data_ax = fig.get_axes()[0]
        _fake_click(fig, data_ax, [x, y], xform='data')  # mark a bad channel
        _fake_click(fig, data_ax, [x, y], xform='data')  # unmark a bad channel
        _fake_click(fig, data_ax, [0.5, 0.999])  # click elsewhere in 1st axes
        _fake_click(fig, data_ax, [-0.1, 0.9])  # click on y-label
        _fake_click(fig, fig.get_axes()[1], [0.5, 0.5])  # change time
        _fake_click(fig, fig.get_axes()[2], [0.5, 0.5])  # change channels
        _fake_click(fig, fig.get_axes()[3], [0.5, 0.5])  # open SSP window
        fig.canvas.button_press_event(1, 1, 1)  # outside any axes
        fig.canvas.scroll_event(0.5, 0.5, -0.5)  # scroll down
        fig.canvas.scroll_event(0.5, 0.5, 0.5)  # scroll up
        # sadly these fail when no renderer is used (i.e., when using Agg):
        # ssp_fig = set(plt.get_fignums()) - set([fig.number])
        # assert_equal(len(ssp_fig), 1)
        # ssp_fig = plt.figure(list(ssp_fig)[0])
        # ax = ssp_fig.get_axes()[0]  # only one axis is used
        # t = [c for c in ax.get_children() if isinstance(c,
        #      matplotlib.text.Text)]
        # pos = np.array(t[0].get_position()) + 0.01
        # _fake_click(ssp_fig, ssp_fig.get_axes()[0], pos, xform='data')  # off
        # _fake_click(ssp_fig, ssp_fig.get_axes()[0], pos, xform='data')  # on
        #  test keypresses
        fig.canvas.key_press_event('escape')
        fig.canvas.key_press_event('down')
        fig.canvas.key_press_event('up')
        fig.canvas.key_press_event('right')
        fig.canvas.key_press_event('left')
        fig.canvas.key_press_event('o')
        fig.canvas.key_press_event('-')
        fig.canvas.key_press_event('+')
        fig.canvas.key_press_event('=')
        fig.canvas.key_press_event('pageup')
        fig.canvas.key_press_event('pagedown')
        fig.canvas.key_press_event('home')
        fig.canvas.key_press_event('end')
        fig.canvas.key_press_event('?')
        fig.canvas.key_press_event('f11')
        fig.canvas.key_press_event('escape')
        # Color setting
        assert_raises(KeyError, raw.plot, event_color={0: 'r'})
        assert_raises(TypeError, raw.plot, event_color={'foo': 'r'})
        annot = Annotations([10, 10 + raw.first_samp / raw.info['sfreq']],
                            [10, 10], ['test', 'test'], raw.info['meas_date'])
        raw.annotations = annot
        fig = plot_raw(raw, events=events, event_color={-1: 'r', 998: 'b'})
        plt.close('all')
        for order in [
                'position', 'selection',
                range(len(raw.ch_names))[::-1], [1, 2, 4, 6]
        ]:
            fig = raw.plot(order=order)
            # Click on the data axes of *this* figure; the earlier
            # ``data_ax`` belongs to the first figure (closed above) and is
            # not part of ``fig``.
            data_ax = fig.get_axes()[0]
            x = data_ax.lines[1].get_xdata()[10]
            y = data_ax.lines[1].get_ydata()[10]
            _fake_click(fig, data_ax, [x, y], xform='data')  # mark bad
            fig.canvas.key_press_event('down')  # change selection
            _fake_click(fig, fig.get_axes()[2], [0.5, 0.5])  # change channels
            if order == 'position':  # test clicking topo to change selection
                sel_fig = plt.figure(1)
                topo_ax = sel_fig.axes[1]
                _fake_click(sel_fig,
                            topo_ax, [-0.425, 0.20223853],
                            xform='data')
                fig.canvas.key_press_event('down')
                fig.canvas.key_press_event('up')
                fig.canvas.scroll_event(0.5, 0.5, -1)  # scroll down
                fig.canvas.scroll_event(0.5, 0.5, 1)  # scroll up
            plt.close('all')
def test_ica_full_data_recovery(method):
    """Test recovery of full data when no source is rejected.

    ICA is fitted on Raw and on Epochs and applied with ``exclude=[]`` to
    Raw, Epochs and Evoked data. With ``n_pca_components == n_channels`` the
    data must be reproduced almost exactly; with half as many PCA components
    the output must differ measurably from the input.
    """
    # Most basic recovery
    _skip_check_picard(method)
    raw = read_raw_fif(raw_fname).crop(0.5, stop).load_data()
    events = read_events(event_name)
    picks = pick_types(raw.info,
                       meg=True,
                       stim=False,
                       ecg=False,
                       eog=False,
                       exclude='bads')[:10]
    with pytest.warns(RuntimeWarning, match='projection'):
        epochs = Epochs(raw,
                        events[:4],
                        event_id,
                        tmin,
                        tmax,
                        picks=picks,
                        baseline=(None, 0),
                        preload=True)
    evoked = epochs.average()
    n_channels = 5
    # Data to compare against after ICA has been applied.
    data = raw._data[:n_channels].copy()
    data_epochs = epochs.get_data()
    data_evoked = evoked.data
    raw.set_annotations(Annotations([0.5], [0.5], ['BAD']))

    def _check_recovery(orig, recovered, ok):
        """Assert ``recovered`` matches ``orig`` if ``ok``, else that it differs."""
        if ok:
            assert_allclose(orig, recovered, rtol=1e-10, atol=1e-15)
        else:
            diff = np.abs(orig - recovered)
            assert np.max(diff) > 1e-14

    # (n_components, n_pca_components, whether full recovery is expected)
    stuff = [(2, n_channels, True), (2, n_channels // 2, False)]
    for n_components, n_pca_components, ok in stuff:
        ica = ICA(n_components=n_components,
                  max_pca_components=n_pca_components,
                  n_pca_components=n_pca_components,
                  method=method,
                  max_iter=1)
        with pytest.warns(None):  # sometimes warns
            ica.fit(raw, picks=list(range(n_channels)))
        raw2 = ica.apply(raw.copy(), exclude=[])
        _check_recovery(data, raw2._data[:n_channels], ok)

        ica = ICA(n_components=n_components,
                  method=method,
                  max_pca_components=n_pca_components,
                  n_pca_components=n_pca_components)
        with pytest.warns(None):  # sometimes warns
            ica.fit(epochs, picks=list(range(n_channels)))
        epochs2 = ica.apply(epochs.copy(), exclude=[])
        data2 = epochs2.get_data()[:, :n_channels]
        _check_recovery(data_epochs[:, :n_channels], data2, ok)

        evoked2 = ica.apply(evoked.copy(), exclude=[])
        data2 = evoked2.data[:n_channels]
        _check_recovery(data_evoked[:n_channels], data2, ok)
    pytest.raises(ValueError, ICA, method='pizza-decomposision')
    description = np.repeat('Bad', len(onsets))

    # get annotations in data
    artifacts = np.array((onsets, duration, description)).T
    # to pandas data frame
    artifacts = pd.DataFrame(artifacts, columns=annot_infos)
    # annotations from data
    annotations = pd.DataFrame(raw_copy.annotations)
    annotations = annotations[annot_infos]

    # merge artifacts and previous annotations
    artifacts = artifacts.append(annotations, ignore_index=True)

    # create new annotation info
    annotations = Annotations(artifacts['onset'],
                              artifacts['duration'],
                              artifacts['description'],
                              orig_time=raw_copy.annotations.orig_time)
    # apply to raw data
    raw.set_annotations(annotations)

# save total annotated time
# NOTE(review): ``duration`` is whatever the indented block above (presumably
# a loop) bound last, so this sums only that final batch of annotations —
# confirm a total over all iterations is not what was intended.
total_time = sum(duration)
# save frequency of annotation per channel: each occurrence in
# ``annotated_channels`` is weighted by 2, presumably the fixed duration of
# each 'Bad' annotation — confirm. Counter counts in one pass and keeps
# first-occurrence order, like the previous comprehension.
from collections import Counter  # noqa: E402
frequency_of_annotation = {
    x: count * 2
    for x, count in Counter(annotated_channels).items()
}

# create plot with clean data
plot_artefacts = raw.plot(scalings=dict(eeg=50e-6, eog=50e-6),
                          n_channels=len(raw.info['ch_names']),
示例#17
0
def test_find_events():
    """Test find events in raw file.

    Covers defaulting via the ``MNE_STIM_CHANNEL*`` environment variables,
    digital masking, consecutive/offset/step outputs, ``min_duration``,
    ``find_stim_steps`` merging, multiple stim channels, ``initial_event``
    and the error raised when no stim channel is present.
    """
    events = read_events(fname)
    raw = read_raw_fif(raw_fname, preload=True)
    # let's test the defaulting behavior while we're at it
    extra_ends = ['', '_1']
    orig_envs = [os.getenv('MNE_STIM_CHANNEL%s' % s) for s in extra_ends]
    os.environ['MNE_STIM_CHANNEL'] = 'STI 014'
    if 'MNE_STIM_CHANNEL_1' in os.environ:
        del os.environ['MNE_STIM_CHANNEL_1']
    events2 = find_events(raw)
    assert_array_almost_equal(events, events2)
    # now test with mask
    events11 = find_events(raw, mask=3, mask_type='not_and')
    with pytest.warns(RuntimeWarning, match='events masked'):
        events22 = read_events(fname, mask=3, mask_type='not_and')
    assert_array_equal(events11, events22)

    # Reset some data for ease of comparison
    raw._first_samps[0] = 0
    with raw.info._unlock():
        raw.info['sfreq'] = 1000

    stim_channel = 'STI 014'
    stim_channel_idx = pick_channels(raw.info['ch_names'],
                                     include=[stim_channel])

    # test digital masking
    raw._data[stim_channel_idx, :5] = np.arange(5)
    raw._data[stim_channel_idx, 5:] = 0
    # 1 == '0b1', 2 == '0b10', 3 == '0b11', 4 == '0b100'

    pytest.raises(TypeError, find_events, raw, mask="0", mask_type='and')
    pytest.raises(ValueError, find_events, raw, mask=0, mask_type='blah')
    # testing mask_type. default = 'not_and'
    assert_array_equal(find_events(raw, shortest_event=1, mask=1,
                                   mask_type='not_and'),
                       [[2, 0, 2], [4, 2, 4]])
    assert_array_equal(find_events(raw, shortest_event=1, mask=2,
                                   mask_type='not_and'),
                       [[1, 0, 1], [3, 0, 1], [4, 1, 4]])
    assert_array_equal(find_events(raw, shortest_event=1, mask=3,
                                   mask_type='not_and'),
                       [[4, 0, 4]])
    assert_array_equal(find_events(raw, shortest_event=1, mask=4,
                                   mask_type='not_and'),
                       [[1, 0, 1], [2, 1, 2], [3, 2, 3]])
    # testing with mask_type = 'and'
    assert_array_equal(find_events(raw, shortest_event=1, mask=1,
                                   mask_type='and'),
                       [[1, 0, 1], [3, 0, 1]])
    assert_array_equal(find_events(raw, shortest_event=1, mask=2,
                                   mask_type='and'),
                       [[2, 0, 2]])
    assert_array_equal(find_events(raw, shortest_event=1, mask=3,
                                   mask_type='and'),
                       [[1, 0, 1], [2, 1, 2], [3, 2, 3]])
    assert_array_equal(find_events(raw, shortest_event=1, mask=4,
                                   mask_type='and'),
                       [[4, 0, 4]])

    # test empty events channel
    raw._data[stim_channel_idx, :] = 0
    assert_array_equal(find_events(raw), np.empty((0, 3), dtype='int32'))

    raw._data[stim_channel_idx, :4] = 1
    assert_array_equal(find_events(raw), np.empty((0, 3), dtype='int32'))

    raw._data[stim_channel_idx, -1:] = 9
    assert_array_equal(find_events(raw), [[14399, 0, 9]])

    # Test that we can handle consecutive events with no gap
    raw._data[stim_channel_idx, 10:20] = 5
    raw._data[stim_channel_idx, 20:30] = 6
    raw._data[stim_channel_idx, 30:32] = 5
    raw._data[stim_channel_idx, 40] = 6

    assert_array_equal(find_events(raw, consecutive=False),
                       [[10, 0, 5],
                        [40, 0, 6],
                        [14399, 0, 9]])
    assert_array_equal(find_events(raw, consecutive=True),
                       [[10, 0, 5],
                        [20, 5, 6],
                        [30, 6, 5],
                        [40, 0, 6],
                        [14399, 0, 9]])
    assert_array_equal(find_events(raw),
                       [[10, 0, 5],
                        [20, 5, 6],
                        [40, 0, 6],
                        [14399, 0, 9]])
    assert_array_equal(find_events(raw, output='offset', consecutive=False),
                       [[31, 0, 5],
                        [40, 0, 6],
                        [14399, 0, 9]])
    assert_array_equal(find_events(raw, output='offset', consecutive=True),
                       [[19, 6, 5],
                        [29, 5, 6],
                        [31, 0, 5],
                        [40, 0, 6],
                        [14399, 0, 9]])
    pytest.raises(ValueError, find_events, raw, output='step',
                  consecutive=True)
    assert_array_equal(find_events(raw, output='step', consecutive=True,
                                   shortest_event=1),
                       [[10, 0, 5],
                        [20, 5, 6],
                        [30, 6, 5],
                        [32, 5, 0],
                        [40, 0, 6],
                        [41, 6, 0],
                        [14399, 0, 9],
                        [14400, 9, 0]])
    assert_array_equal(find_events(raw, output='offset'),
                       [[19, 6, 5],
                        [31, 0, 6],
                        [40, 0, 6],
                        [14399, 0, 9]])
    assert_array_equal(find_events(raw, consecutive=False, min_duration=0.002),
                       [[10, 0, 5]])
    assert_array_equal(find_events(raw, consecutive=True, min_duration=0.002),
                       [[10, 0, 5],
                        [20, 5, 6],
                        [30, 6, 5]])
    assert_array_equal(find_events(raw, output='offset', consecutive=False,
                                   min_duration=0.002),
                       [[31, 0, 5]])
    assert_array_equal(find_events(raw, output='offset', consecutive=True,
                                   min_duration=0.002),
                       [[19, 6, 5],
                        [29, 5, 6],
                        [31, 0, 5]])
    assert_array_equal(find_events(raw, consecutive=True, min_duration=0.003),
                       [[10, 0, 5],
                        [20, 5, 6]])

    # test find_stim_steps merge parameter
    raw._data[stim_channel_idx, :] = 0
    raw._data[stim_channel_idx, 0] = 1
    raw._data[stim_channel_idx, 10] = 4
    raw._data[stim_channel_idx, 11:20] = 5
    assert_array_equal(find_stim_steps(raw, pad_start=0, merge=0,
                                       stim_channel=stim_channel),
                       [[0, 0, 1],
                        [1, 1, 0],
                        [10, 0, 4],
                        [11, 4, 5],
                        [20, 5, 0]])
    assert_array_equal(find_stim_steps(raw, merge=-1,
                                       stim_channel=stim_channel),
                       [[1, 1, 0],
                        [10, 0, 5],
                        [20, 5, 0]])
    assert_array_equal(find_stim_steps(raw, merge=1,
                                       stim_channel=stim_channel),
                       [[1, 1, 0],
                        [11, 0, 5],
                        [20, 5, 0]])

    # put back the env vars we trampled on: restore the original values and
    # remove variables that were not set before this test, so that
    # MNE_STIM_CHANNEL does not leak into other tests.
    # NOTE(review): this is not in a ``finally``; a failing assertion above
    # still leaves the variables modified.
    for s, o in zip(extra_ends, orig_envs):
        key = 'MNE_STIM_CHANNEL%s' % s
        if o is None:
            os.environ.pop(key, None)
        else:
            os.environ[key] = o

    # Test with list of stim channels
    raw._data[stim_channel_idx, 1:101] = np.zeros(100)
    raw._data[stim_channel_idx, 10:11] = 1
    raw._data[stim_channel_idx, 30:31] = 3
    stim_channel2 = 'STI 015'
    stim_channel2_idx = pick_channels(raw.info['ch_names'],
                                      include=[stim_channel2])
    raw._data[stim_channel2_idx, :] = 0
    raw._data[stim_channel2_idx, :100] = raw._data[stim_channel_idx, 5:105]
    events1 = find_events(raw, stim_channel='STI 014')
    events2 = events1.copy()
    events2[:, 0] -= 5
    events = find_events(raw, stim_channel=['STI 014', stim_channel2])
    assert_array_equal(events[::2], events2)
    assert_array_equal(events[1::2], events1)

    # test initial_event argument
    info = create_info(['MYSTI'], 1000, 'stim')
    data = np.zeros((1, 1000))
    raw = RawArray(data, info)
    data[0, :10] = 100
    data[0, 30:40] = 200
    assert_array_equal(find_events(raw, 'MYSTI'), [[30, 0, 200]])
    assert_array_equal(find_events(raw, 'MYSTI', initial_event=True),
                       [[0, 0, 100], [30, 0, 200]])

    # test error message for raw without stim channels
    raw = read_raw_fif(raw_fname, preload=True)
    raw.pick_types(meg=True, stim=False)
    # raw does not have annotations
    with pytest.raises(ValueError, match="'stim_channel'"):
        find_events(raw)
    # if raw has annotations, we show a different error message
    raw.set_annotations(Annotations(0, 2, "test"))
    with pytest.raises(ValueError, match="mne.events_from_annotations"):
        find_events(raw)
示例#18
0
def test_plot_ica_sources(raw_orig, mpl_backend):
    """Test plotting of ICA panel.

    Covers ``ICA.plot_sources`` for Raw input (including clicks and
    keypresses that edit ``ica.exclude``, a child figure opened from a
    channel name, and annotation rendering), for Epochs and Evoked input, and
    the error paths for mismatched or invalid ``inst``.
    """
    raw = raw_orig.copy().crop(0, 1)
    picks = _get_picks(raw)
    epochs = _get_epochs()
    raw.pick_channels([raw.ch_names[k] for k in picks])
    ica_picks = pick_types(raw.info,
                           meg=True,
                           eeg=False,
                           stim=False,
                           ecg=False,
                           eog=False,
                           exclude='bads')
    # Two-component ICA fitted on the MEG picks only.
    ica = ICA(n_components=2)
    ica.fit(raw, picks=ica_picks)
    ica.exclude = [1]
    fig = ica.plot_sources(raw)
    assert mpl_backend._get_n_figs() == 1
    # change which component is in ICA.exclude (click data trace to remove
    # current one; click name to add other one)
    fig._redraw()
    # ToDo: This will be different methods in pyqtgraph
    x = fig.mne.traces[1].get_xdata()[5]
    y = fig.mne.traces[1].get_ydata()[5]
    fig._fake_click((x, y), xform='data')  # exclude = []
    _click_ch_name(fig, ch_index=0, button=1)  # exclude = [0]
    fig._fake_keypress(fig.mne.close_key)
    fig._close_event()
    assert mpl_backend._get_n_figs() == 0
    # The interactive edits are expected to be written back to ``ica``.
    assert_array_equal(ica.exclude, [0])
    # test when picks does not include ica.exclude.
    fig = ica.plot_sources(raw, picks=[1])
    assert len(plt.get_fignums()) == 1
    mpl_backend._close_all()

    # dtype can change int->np.int64 after load, test it explicitly
    ica.n_components_ = np.int64(ica.n_components_)

    # test clicks on y-label (need >2 secs for plot_properties() to work)
    # NOTE(review): unlike ``raw`` above there is no ``.copy()`` here, so this
    # presumably crops the ``raw_orig`` fixture itself — confirm that is OK.
    long_raw = raw_orig.crop(0, 5)
    fig = ica.plot_sources(long_raw)
    assert len(plt.get_fignums()) == 1
    fig._redraw()
    # A right-click (button=3) on a channel name is expected to open one
    # child figure.
    _click_ch_name(fig, ch_index=0, button=3)
    assert len(fig.mne.child_figs) == 1
    assert len(plt.get_fignums()) == 2
    # close child fig directly (workaround for mpl issue #18609)
    fig._fake_keypress('escape', fig=fig.mne.child_figs[0])
    assert len(plt.get_fignums()) == 1
    fig._fake_keypress(fig.mne.close_key)
    assert len(plt.get_fignums()) == 0
    del long_raw

    # test with annotations
    # A single annotation should appear once in the main axes and once in
    # the horizontal scrollbar; the original annotations are restored after.
    orig_annot = raw.annotations
    raw.set_annotations(Annotations([0.2], [0.1], 'Test'))
    fig = ica.plot_sources(raw)
    assert len(fig.mne.ax_main.collections) == 1
    assert len(fig.mne.ax_hscroll.collections) == 1
    raw.set_annotations(orig_annot)

    # test error handling
    # Dropping a channel the ICA was fitted on must make plot_sources refuse
    # the Raw / Epochs instance.
    raw_ = raw.copy().load_data()
    raw_.drop_channels('MEG 0113')
    with pytest.raises(RuntimeError, match="Raw doesn't match fitted data"), \
         pytest.warns(RuntimeWarning, match='could not be picked'):
        ica.plot_sources(inst=raw_)
    epochs_ = epochs.copy().load_data()
    epochs_.drop_channels('MEG 0113')
    with pytest.raises(RuntimeError, match="Epochs don't match fitted data"), \
         pytest.warns(RuntimeWarning, match='could not be picked'):
        ica.plot_sources(inst=epochs_)
    del raw_
    del epochs_

    # test w/ epochs and evokeds
    ica.plot_sources(epochs)
    ica.plot_sources(epochs.average())
    evoked = epochs.average()
    fig = ica.plot_sources(evoked)
    # Test a click
    # First click hits the first point of the first line; second click goes
    # to the top-left corner of the axes (data coordinates).
    ax = fig.get_axes()[0]
    line = ax.lines[0]
    _fake_click(fig, ax, [line.get_xdata()[0], line.get_ydata()[0]], 'data')
    _fake_click(fig, ax, [ax.get_xlim()[0], ax.get_ylim()[1]], 'data')

    # plot with bad channels excluded
    ica.exclude = [0]
    ica.plot_sources(evoked)

    # pretend find_bads_eog() yielded some results
    ica.labels_ = {'eog': [0], 'eog/0/crazy-channel': [0]}
    ica.plot_sources(evoked)  # now with labels

    # pass an invalid inst
    with pytest.raises(ValueError, match='must be of Raw or Epochs type'):
        ica.plot_sources('meeow')
示例#19
0
def test_plot_raw_psd():
    """Test plotting of raw psds.

    Covers per-channel and averaged plots, user-supplied axes (including the
    error for a wrong number of axes), topo plots, missing channel locations,
    a flat channel, annotation rejection, argument validation and fNIRS
    channel types.
    """
    raw = _get_raw()
    # normal mode
    raw.plot_psd(average=False)
    # specific mode
    picks = pick_types(raw.info, meg='mag', eeg=False)[:4]
    raw.plot_psd(tmax=None, picks=picks, area_mode='range', average=False,
                 spatial_colors=True)
    raw.plot_psd(tmax=20., color='yellow', dB=False, line_alpha=0.4,
                 n_overlap=0.1, average=False)
    plt.close('all')
    ax = plt.axes()
    # if ax is supplied:
    pytest.raises(ValueError, raw.plot_psd, ax=ax, average=True)
    raw.plot_psd(tmax=None, picks=picks, ax=ax, average=True)
    plt.close('all')
    ax = plt.axes()
    with pytest.raises(ValueError, match='2 axes must be supplied, got 1'):
        raw.plot_psd(ax=ax, average=True)
    plt.close('all')
    ax = plt.subplots(2)[1]
    raw.plot_psd(tmax=None, ax=ax, average=True)
    plt.close('all')
    # topo psd
    ax = plt.subplot()
    raw.plot_psd_topo(axes=ax)
    plt.close('all')
    # with channel information not available
    for ch in raw.info['chs']:
        ch['loc'] = np.zeros(12)
    with pytest.warns(RuntimeWarning, match='locations not available'):
        raw.plot_psd(spatial_colors=True, average=False)
    # with a flat channel
    raw[5, :] = 0
    for dB, estimate in itertools.product((True, False),
                                          ('power', 'amplitude')):
        # The flat channel should trigger a warning mentioning an "Infinite"
        # or "Zero" value. Use regex alternation: the former pattern
        # '[Infinite|Zero]' was a character class matching any message that
        # contains one of its letters.
        with pytest.warns(UserWarning, match='Infinite|Zero'):
            fig = raw.plot_psd(average=True, dB=dB, estimate=estimate)
        # axes[1] is expected to carry fT/cm units, axes[0] fT units
        ylabel = fig.axes[1].get_ylabel()
        ends_dB = ylabel.endswith('mathrm{(dB)}$')
        if dB:
            assert ends_dB, ylabel
        else:
            assert not ends_dB, ylabel
        if estimate == 'amplitude':
            assert r'fT/cm/\sqrt{Hz}' in ylabel, ylabel
        else:
            assert estimate == 'power'
            assert '(fT/cm)²/Hz' in ylabel, ylabel
        ylabel = fig.axes[0].get_ylabel()
        if estimate == 'amplitude':
            assert r'fT/\sqrt{Hz}' in ylabel
        else:
            assert 'fT²/Hz' in ylabel
    # test reject_by_annotation
    raw = _get_raw()
    raw.set_annotations(Annotations([1, 5], [3, 3], ['test', 'test']))
    raw.plot_psd(reject_by_annotation=True)
    raw.plot_psd(reject_by_annotation=False)
    plt.close('all')

    # test fmax value checking
    with pytest.raises(ValueError, match='not exceed one half the sampling'):
        raw.plot_psd(fmax=50000)

    # test xscale value checking
    with pytest.raises(ValueError, match="Invalid value for the 'xscale'"):
        raw.plot_psd(xscale='blah')

    # gh-5046
    raw = read_raw_fif(raw_fname, preload=True).crop(0, 1)
    picks = pick_types(raw.info)
    raw.plot_psd(picks=picks, average=False)
    raw.plot_psd(picks=picks, average=True)
    plt.close('all')
    raw.set_channel_types({'MEG 0113': 'hbo', 'MEG 0112': 'hbr',
                           'MEG 0122': 'fnirs_raw', 'MEG 0123': 'fnirs_od'},
                          verbose='error')
    fig = raw.plot_psd()
    assert len(fig.axes) == 10
    plt.close('all')

    # gh-7631
    data = 1e-3 * np.random.rand(2, 100)
    info = create_info(['CH1', 'CH2'], 100)
    raw = RawArray(data, info)
    picks = pick_types(raw.info, misc=True)
    raw.plot_psd(picks=picks, spatial_colors=False)
    plt.close('all')
示例#20
0
def test_plot_raw_psd():
    """Test plotting of raw psds.

    Covers per-channel and averaged plots, user-supplied axes, topo plots,
    missing channel locations, a flat channel, annotation rejection and the
    gh-5046 regression.
    """
    raw = _get_raw()
    # normal mode
    raw.plot_psd(average=False)
    # specific mode
    picks = pick_types(raw.info, meg='mag', eeg=False)[:4]
    raw.plot_psd(tmax=None,
                 picks=picks,
                 area_mode='range',
                 average=False,
                 spatial_colors=True)
    raw.plot_psd(tmax=20.,
                 color='yellow',
                 dB=False,
                 line_alpha=0.4,
                 n_overlap=0.1,
                 average=False)
    plt.close('all')
    ax = plt.axes()
    # if ax is supplied:
    pytest.raises(ValueError, raw.plot_psd, ax=ax, average=True)
    raw.plot_psd(tmax=None, picks=picks, ax=ax, average=True)
    plt.close('all')
    ax = plt.axes()
    pytest.raises(ValueError, raw.plot_psd, ax=ax, average=True)
    plt.close('all')
    ax = plt.subplots(2)[1]
    raw.plot_psd(tmax=None, ax=ax, average=True)
    plt.close('all')
    # topo psd
    ax = plt.subplot()
    raw.plot_psd_topo(axes=ax)
    plt.close('all')
    # with channel information not available
    for ch in raw.info['chs']:
        ch['loc'] = np.zeros(12)
    with pytest.warns(RuntimeWarning, match='locations not available'):
        raw.plot_psd(spatial_colors=True, average=False)
    # with a flat channel
    raw[5, :] = 0
    for dB, estimate in itertools.product((True, False),
                                          ('power', 'amplitude')):
        # The flat channel should trigger a warning mentioning an "Infinite"
        # or "Zero" value. Use regex alternation: the former pattern
        # '[Infinite|Zero]' was a character class matching any message that
        # contains one of its letters.
        with pytest.warns(UserWarning, match='Infinite|Zero'):
            fig = raw.plot_psd(average=True, dB=dB, estimate=estimate)
        # axes[1] is expected to carry fT/cm units, axes[0] fT units
        ylabel = fig.axes[1].get_ylabel()
        ends_dB = ylabel.endswith('mathrm{(dB)}$')
        if dB:
            assert ends_dB, ylabel
        else:
            assert not ends_dB, ylabel
        if estimate == 'amplitude':
            assert r'fT/cm/\sqrt{Hz}' in ylabel, ylabel
        else:
            assert estimate == 'power'
            assert '(fT/cm)²/Hz' in ylabel, ylabel
        ylabel = fig.axes[0].get_ylabel()
        if estimate == 'amplitude':
            assert r'fT/\sqrt{Hz}' in ylabel
        else:
            assert 'fT²/Hz' in ylabel
    # test reject_by_annotation
    raw = _get_raw()
    raw.set_annotations(Annotations([1, 5], [3, 3], ['test', 'test']))
    raw.plot_psd(reject_by_annotation=True)
    raw.plot_psd(reject_by_annotation=False)

    # gh-5046
    raw = read_raw_fif(raw_fname, preload=True).crop(0, 1)
    picks = pick_types(raw.info)
    raw.plot_psd(picks=picks, average=False)
    raw.plot_psd(picks=picks, average=True)
    plt.close('all')
示例#21
0
def test_annotation_filtering():
    """Check that annotations split data into independently filtered parts.

    Four constant (DC) segments are concatenated. A lowpass must leave them
    unchanged and a highpass must zero them, exactly as if each segment had
    been filtered on its own. Annotated spans can also be skipped entirely.
    """
    def _values(inst):
        # data array of the first channel as returned by inst[0]
        return inst[0][0]

    data = np.ones((1, 1000))
    info = create_info(1, 1000., 'eeg')
    raws = [RawArray(data * (offset + 1), info) for offset in range(4)]
    lowpass = dict(l_freq=None, h_freq=50., fir_design='firwin')
    highpass = dict(l_freq=50., h_freq=None, fir_design='firwin')

    # filter every segment on its own (lowpass keeps DC, highpass removes it)
    each_pass = [segment.copy().filter(**lowpass) for segment in raws]
    each_stop = [segment.copy().filter(**highpass) for segment in raws]
    combined = concatenate_raws([segment.copy() for segment in raws])
    zeros = combined.copy().apply_function(lambda x: x * 0)
    combined_pass = concatenate_raws(each_pass)
    combined_stop = concatenate_raws(each_stop)
    # sanity check of the per-segment filtering itself
    assert_allclose(_values(combined), _values(combined_pass), atol=1e-14)
    assert_allclose(_values(zeros), _values(combined_stop), atol=1e-14)

    # filtering the concatenation must respect the 'edge' boundaries
    for kwargs, reference in ((lowpass, combined), (highpass, zeros)):
        result = combined.copy().filter(skip_by_annotation='edge', **kwargs)
        assert_allclose(_values(reference), _values(result), atol=1e-14)

    # skip the 1-3 s window entirely; the second, overlapping annotation
    # must not change the segmentation
    raw = combined.copy()
    for onset, duration in ((1., 2.), (2., 1.)):
        raw.annotations.append(onset, duration, 'foo')
        with catch_logging() as log:
            raw.filter(l_freq=50., h_freq=None, fir_design='firwin',
                       skip_by_annotation='foo', verbose='info')
        assert '2 contiguous segments' in log.getvalue()
    # the highpass zeroes everything that was not skipped
    mask = np.concatenate((np.zeros(1000), np.ones(2000), np.zeros(1000)))
    assert_allclose(raw[0][0][0], combined[0][0][0] * mask, atol=1e-14)

    # a BAD_ACQ_SKIP annotation at the start, then at the end; ``kept`` is
    # the unannotated half that gets returned and filtered
    for onset, kept in ((0., slice(500, None)), (0.5, slice(None, 500))):
        raw = raws[0].copy()
        raw.set_annotations(Annotations([onset], [0.5], ['BAD_ACQ_SKIP']))
        kept_data, kept_times = raw.get_data(reject_by_annotation='omit',
                                             return_times=True)
        assert_allclose(kept_times, raw.times[kept])
        assert kept_data.shape == (1, 500)
        raw_filt = raw.copy().filter(skip_by_annotation='bad_acq_skip',
                                     **highpass)
        expected = data.copy()
        expected[:, kept] = 0
        assert_allclose(raw_filt[:][0], expected, atol=1e-14)
示例#22
0
def test_plot_ica_sources():
    """Test plotting of ICA panel.

    Exercises ``ICA.plot_sources`` on Raw, Epochs and Evoked data, including
    key/mouse interaction, annotations, data whose bad channels no longer
    match the fitted ICA, component labels and an invalid ``inst``.
    """
    raw = read_raw_fif(raw_fname).crop(0, 1).load_data()
    picks = _get_picks(raw)
    epochs = _get_epochs()
    raw.pick_channels([raw.ch_names[k] for k in picks])
    # fit a small ICA on the MEG channels only
    ica_picks = pick_types(raw.info, meg=True, eeg=False, stim=False,
                           ecg=False, eog=False, exclude='bads')
    ica = ICA(n_components=2, max_pca_components=3, n_pca_components=3)
    ica.fit(raw, picks=ica_picks)
    ica.exclude = [1]
    fig = ica.plot_sources(raw)
    fig.canvas.key_press_event('escape')
    # Sadly close_event isn't called on Agg backend and the test always passes.
    assert_array_equal(ica.exclude, [1])
    plt.close('all')

    # dtype can change int->np.int after load, test it explicitly
    ica.n_components_ = np.int64(ica.n_components_)
    fig = ica.plot_sources(raw)
    # also test mouse clicks
    data_ax = fig.axes[0]
    assert len(plt.get_fignums()) == 1
    _fake_click(fig, data_ax, [-0.1, 0.9])  # click on y-label
    # the y-label click is expected to open one additional figure
    assert len(plt.get_fignums()) == 2
    ica.exclude = [1]
    ica.plot_sources(raw)

    # test with annotations
    orig_annot = raw.annotations
    raw.set_annotations(Annotations([0.2], [0.1], 'Test'))
    fig = ica.plot_sources(raw)
    # a single annotation -> one collection on each of the first two axes
    assert len(fig.axes[0].collections) == 1
    assert len(fig.axes[1].collections) == 1
    raw.set_annotations(orig_annot)

    # with 'MEG 0113' marked bad the data no longer match the fitted ICA,
    # so plotting is expected to raise (for Raw here, for Epochs below)
    raw.info['bads'] = ['MEG 0113']
    with pytest.raises(RuntimeError, match="Raw doesn't match fitted data"):
        ica.plot_sources(inst=raw)
    ica.plot_sources(epochs)
    epochs.info['bads'] = ['MEG 0113']
    with pytest.raises(RuntimeError, match="Epochs don't match fitted data"):
        ica.plot_sources(inst=epochs)
    epochs.info['bads'] = []
    ica.plot_sources(epochs.average())
    evoked = epochs.average()
    fig = ica.plot_sources(evoked)
    # Test a click
    ax = fig.get_axes()[0]
    line = ax.lines[0]
    # click on the first data point of a line, then on the top-left corner
    _fake_click(fig, ax,
                [line.get_xdata()[0], line.get_ydata()[0]], 'data')
    _fake_click(fig, ax,
                [ax.get_xlim()[0], ax.get_ylim()[1]], 'data')
    # plot with bad channels excluded
    ica.exclude = [0]
    ica.plot_sources(evoked)
    # pretend labelling found component 0, including a per-channel key
    ica.labels_ = dict(eog=[0])
    ica.labels_['eog/0/crazy-channel'] = [0]
    ica.plot_sources(evoked)  # now with labels
    with pytest.raises(ValueError, match='must be of Raw or Epochs type'):
        ica.plot_sources('meeow')
    plt.close('all')
示例#23
0
def test_events_from_annot_in_raw_objects():
    """Test events_from_annotations on annotations built from raw events.

    The stim-channel events are stored as annotations and read back with and
    without an explicit ``event_id`` and with ``regexp`` filtering.
    """
    raw = read_raw_fif(fif_fname)
    events = mne.find_events(raw)
    event_id = {
        'Auditory/Left': 1,
        'Auditory/Right': 2,
        'Visual/Left': 3,
        'Visual/Right': 4,
        'Visual/Smiley': 32,
        'Motor/Button': 5
    }
    code_to_desc = {code: desc for desc, code in event_id.items()}
    onsets = raw.times[events[:, 0] - raw.first_samp]
    descriptions = [code_to_desc[code] for code in events[:, 2]]
    raw.set_annotations(Annotations(onset=onsets,
                                    duration=np.zeros(len(events)),
                                    description=descriptions,
                                    orig_time=None))

    # explicit mapping: events round-trip unchanged
    events2, event_id2 = events_from_annotations(raw, event_id=event_id,
                                                 regexp=None)
    assert_array_equal(events, events2)
    assert_equal(event_id, event_id2)

    # automatic mapping: same samples and descriptions, codes renumbered
    events3, event_id3 = events_from_annotations(raw, event_id=None,
                                                 regexp=None)
    assert_array_equal(events[:, 0], events3[:, 0])
    assert set(event_id.keys()) == set(event_id3.keys())

    # the automatically assigned codes are exactly 1..n_descriptions
    for codes in (events3[:, 2], list(event_id3.values())):
        observed = np.unique(codes)
        wanted = np.arange(1, len(event_id) + 1, 1).astype(observed.dtype)
        assert_array_equal(observed, wanted)

    # regexp keeps only the "Left" conditions, i.e. codes 1 and 3
    left_id = {desc: code for desc, code in event_id.items()
               if 'Left' in desc}
    left_events = events[(events[:, 2] == 1) | (events[:, 2] == 3)]

    events4, event_id4 = events_from_annotations(raw, event_id=None,
                                                 regexp='.*Left')
    assert_equal(event_id4.keys(), left_id.keys())
    assert_array_equal(left_events[:, 0], events4[:, 0])

    events5, event_id5 = events_from_annotations(raw, event_id=event_id,
                                                 regexp='.*Left')
    assert_equal(event_id5, left_id)
    assert_array_equal(left_events, events5)

    with pytest.raises(ValueError, match='not find any of the events'):
        events_from_annotations(raw, regexp='not_there')

    # without annotations an empty (0, 3) events array comes back
    raw.set_annotations(None)
    events7, _ = events_from_annotations(raw)
    assert_array_equal(events7, np.empty((0, 3), dtype=int))


def test_events_from_annot_in_raw_objects():
    """Test events_from_annotations on annotations built from raw events.

    Also checks the sorted automatic code assignment, invalid ``event_id``
    input, and that concatenation boundaries (BAD/EDGE) are not returned as
    events.
    """
    raw = read_raw_fif(fif_fname)
    events = mne.find_events(raw)
    event_id = {
        'Auditory/Left': 1,
        'Auditory/Right': 2,
        'Visual/Left': 3,
        'Visual/Right': 4,
        'Visual/Smiley': 32,
        'Motor/Button': 5
    }
    code_to_desc = {code: desc for desc, code in event_id.items()}
    onsets = raw.times[events[:, 0] - raw.first_samp]
    descriptions = [code_to_desc[code] for code in events[:, 2]]
    raw.set_annotations(Annotations(onset=onsets,
                                    duration=np.zeros(len(events)),
                                    description=descriptions,
                                    orig_time=None))

    # explicit mapping: events round-trip unchanged
    events2, event_id2 = events_from_annotations(raw, event_id=event_id,
                                                 regexp=None)
    assert_array_equal(events, events2)
    assert_equal(event_id, event_id2)

    # automatic mapping: same samples and descriptions, codes renumbered
    events3, event_id3 = events_from_annotations(raw, event_id=None,
                                                 regexp=None)
    assert_array_equal(events[:, 0], events3[:, 0])
    assert set(event_id.keys()) == set(event_id3.keys())

    # automatic codes follow the sorted descriptions, starting at 1
    n_desc = len(event_id)
    assert event_id3 == dict(zip(sorted(event_id), range(1, n_desc + 1)))

    for codes in (events3[:, 2], list(event_id3.values())):
        observed = np.unique(codes)
        wanted = np.arange(1, n_desc + 1, 1).astype(observed.dtype)
        assert_array_equal(observed, wanted)

    # regexp keeps only the "Left" conditions, i.e. codes 1 and 3
    left_id = {desc: code for desc, code in event_id.items()
               if 'Left' in desc}
    left_events = events[(events[:, 2] == 1) | (events[:, 2] == 3)]

    events4, event_id4 = events_from_annotations(raw, event_id=None,
                                                 regexp='.*Left')
    assert_equal(event_id4.keys(), left_id.keys())
    assert_array_equal(left_events[:, 0], events4[:, 0])

    events5, event_id5 = events_from_annotations(raw, event_id=event_id,
                                                 regexp='.*Left')
    assert_equal(event_id5, left_id)
    assert_array_equal(left_events, events5)

    with pytest.raises(ValueError, match='not find any of the events'):
        events_from_annotations(raw, regexp='not_there')

    with pytest.raises(ValueError, match='Invalid input event_id'):
        events_from_annotations(raw, event_id='wrong')

    # concat does not introduce BAD or EDGE
    raw_concat = concatenate_raws([raw.copy(), raw.copy()])
    _, concat_id = events_from_annotations(raw_concat)
    assert isinstance(concat_id, dict)
    assert len(concat_id) > 0
    for kind in ('BAD', 'EDGE'):
        assert '%s boundary' % kind in raw_concat.annotations.description
        assert not any(kind in key for key in concat_id.keys())

    # remove all events
    raw.set_annotations(None)
    events7, _ = events_from_annotations(raw)
    assert_array_equal(events7, np.empty((0, 3), dtype=int))
示例#25
0
def to_mne_eeg(eegstream=None,
               line_freq=None,
               filenames=None,
               nasion=None,
               lpa=None,
               rpa=None):
    """Convert recordings to MNE format.

    Args:
        eegstream : dict or list of dict
            EEG stream(s) previously imported. A value that is not a list is
            wrapped in a list. Channel labels are read from the first stream
            and reused for every stream.
        line_freq : int
            Powerline frequency (50 or 60).
        filenames : list of str
            Full path of XDF files, one per stream. Used for recording
            identification.
        nasion : array, shape (n_streams, 3)
            Position of the nasion fiducial point for every recording.
            If specified, it must have the same length as eegstream.
            Format for every recording (X, Y, Z) in meters: [0, 0, 0]
        lpa : array, shape (n_streams, 3)
            Position of the left preauricular fiducial point.
            If specified, it must have the same length as eegstream.
            Format for every recording (X, Y, Z) in meters: [0, 0, 0]
        rpa : array, shape (n_streams, 3)
            Position of the right preauricular fiducial point.
            If specified, it must have the same length as eegstream.
            Format for every recording (X, Y, Z) in meters: [0, 0, 0]
    Returns:
        List of MNE RawArray instances, one per stream in eegstream.
    Raises:
        ValueError: if no stream is specified in eegstream (None or an empty
            list) or the powerline frequency is not 50 or 60.
    See also:
        read_raw_xdf
        read_raw_xdf_dir
    """
    if eegstream is None:
        raise ValueError('Enter parameter array of EEG recordings.')

    # ``None not in [50, 60]`` is True, so a missing value is rejected too.
    if line_freq not in [50, 60]:
        raise ValueError(
            'Enter the powerline frequency of your region (50 Hz or 60 Hz).')

    eegstream = [eegstream] if not isinstance(eegstream, list) else eegstream
    if not eegstream:
        # An empty list would otherwise fail below with an IndexError.
        raise ValueError('Enter parameter array of EEG recordings.')

    raweeg = []

    # Get the names of the channels: one label per data column of the first
    # stream, taken from that stream's channel description.
    first_stream = eegstream[0]
    n_columns = len(first_stream['time_series'][0])
    ch_labels = first_stream['info']['desc'][0]['channels'][0]['channel']
    ch_names = [ch_labels[i]['label'][0] for i in range(n_columns)]

    # Fixed sensor coordinates (X, Y, Z) in meters, head frame. Only the
    # first four channel names are paired with them (zip stops early).
    # NOTE(review): the values look like a left-posterior, left-frontal,
    # right-frontal, right-posterior layout (presumably a 4-electrode
    # headset such as TP9/AF7/AF8/TP10) -- confirm against the device.
    sensor_coord = [[-0.0856192, -0.0465147, -0.0457070],
                    [-0.0548397, 0.0685722, -0.0105900],
                    [0.0557433, 0.0696568, -0.0107550],
                    [0.0861618, -0.0470353, -0.0458690]]

    # Column indices used to select/reorder the data before building the
    # RawArray; columns beyond the fourth are dropped.
    # NOTE(review): presumably this aligns the data with the channel order
    # of ``dig_montage.ch_names`` passed to ``create_info`` -- confirm.
    column_order = [1, 2, 3, 0]

    for index, stream in enumerate(eegstream):
        # Get channels position
        dig_montage = channels.make_dig_montage(
            ch_pos=dict(zip(ch_names, sensor_coord)),
            nasion=nasion[index] if nasion is not None else None,
            lpa=lpa[index] if lpa is not None else None,
            rpa=rpa[index] if rpa is not None else None,
            coord_frame='head')
        # Create raw info for processing
        info = create_info(ch_names=dig_montage.ch_names,
                           sfreq=float(stream['info']['nominal_srate'][0]),
                           ch_types='eeg')
        # Add channels position to info
        info.set_montage(dig_montage)
        # Convert data from microvolts to volts
        conv_data = stream["time_series"] * 1e-6
        # Reorder the channel columns, then transpose so rows are channels
        ord_data = np.asarray(conv_data)[:, column_order]
        # Create raw data for mne
        raw = io.RawArray(ord_data.T, info)
        # Get the information of each stream
        stream_info = stream['info']['name'][0][:9] + ' ' + (
            filenames[index] if filenames is not None else '')
        # Print the information of each stream
        print(f'\nInfo: {index} {stream_info}\n')
        # Add the powerline frequency of each stream
        raw.info['line_freq'] = line_freq
        # Create annotation to store the name of the device and the filenames
        annotations = Annotations(0, 0, stream_info)
        # Add the annotations to raw data
        raw.set_annotations(annotations)
        # Add the RawArray object to list of eeg recordings
        raweeg.append(raw)

    return raweeg
示例#26
0
def test_plot_raw_traces(raw):
    """Test plotting of raw data.

    Covers marking channels bad by clicking data or channel names, vline and
    scrollbar clicks, butterfly mode, zen mode, argument validation, event
    colors, group_by='position'/'selection' together with the selection
    figure, an offset meas_date and all-NaN data.
    """
    raw.info['lowpass'] = 10.  # allow heavy decim during plotting
    events = _get_events()
    plt.close('all')  # ensure all are closed
    fig = raw.plot(events=events, order=[1, 7, 5, 2, 3], n_channels=3,
                   group_by='original')
    assert hasattr(fig, 'mne')  # make sure fig.mne param object is present
    assert len(fig.axes) == 5

    # setup
    x = fig.mne.traces[0].get_xdata()[5]
    y = fig.mne.traces[0].get_ydata()[5]
    data_ax = fig.mne.ax_main
    hscroll = fig.mne.ax_hscroll
    vscroll = fig.mne.ax_vscroll
    # test marking bad channels
    label = fig.mne.ax_main.get_yticklabels()[0].get_text()
    assert label not in fig.mne.info['bads']
    _fake_click(fig, data_ax, [x, y], xform='data')  # click data to mark bad
    assert label in fig.mne.info['bads']
    _fake_click(fig, data_ax, [x, y], xform='data')  # click data to unmark bad
    assert label not in fig.mne.info['bads']
    _click_ch_name(fig, ch_index=0, button=1)        # click name to mark bad
    assert label in fig.mne.info['bads']
    # test other kinds of clicks
    _fake_click(fig, data_ax, [0.5, 0.999])  # click elsewhere (add vline)
    _fake_click(fig, data_ax, [0.5, 0.999], button=3)  # remove vline
    _fake_click(fig, hscroll, [0.5, 0.5])  # change time
    _fake_click(fig, hscroll, [0.5, 0.5])  # shouldn't change time this time
    # test scrolling through channels
    labels = [tick.get_text() for tick in data_ax.get_yticklabels()]
    assert labels == [raw.ch_names[1], raw.ch_names[7], raw.ch_names[5]]
    _fake_click(fig, vscroll, [0.5, 0.01])  # change channels to end
    labels = [tick.get_text() for tick in data_ax.get_yticklabels()]
    assert labels == [raw.ch_names[5], raw.ch_names[2], raw.ch_names[3]]
    for _ in (0, 0):
        # first click changes channels to mid; second time shouldn't change
        _fake_click(fig, vscroll, [0.5, 0.5])
        labels = [tick.get_text() for tick in data_ax.get_yticklabels()]
        assert labels == [raw.ch_names[7], raw.ch_names[5], raw.ch_names[2]]
        assert len(plt.get_fignums()) == 1

    # test clicking a channel name in butterfly mode
    bads = fig.mne.info['bads'].copy()
    fig.canvas.key_press_event('b')
    _click_ch_name(fig, ch_index=0, button=1)  # should be no-op
    assert fig.mne.info['bads'] == bads        # unchanged
    fig.canvas.key_press_event('b')

    # test starting up in zen mode
    fig = plot_raw(raw, show_scrollbars=False)
    # test order, title, & show_options kwargs
    with pytest.raises(ValueError, match='order should be array-like; got'):
        raw.plot(order='foo')
    with pytest.raises(TypeError, match='title must be None or a string, got'):
        raw.plot(title=1)
    raw.plot(show_options=True)
    plt.close('all')

    # Color setting
    with pytest.raises(KeyError, match='must be strictly positive, or -1'):
        raw.plot(event_color={0: 'r'})
    with pytest.raises(TypeError, match='event_color key must be an int, got'):
        raw.plot(event_color={'foo': 'r'})
    annot = Annotations([10, 10 + raw.first_samp / raw.info['sfreq']],
                        [10, 10], ['test', 'test'], raw.info['meas_date'])
    with pytest.warns(RuntimeWarning, match='outside data range'):
        raw.set_annotations(annot)
    fig = plot_raw(raw, events=events, event_color={-1: 'r', 998: 'b'})
    plt.close('all')
    for group_by, order in zip(['position', 'selection'],
                               [np.arange(len(raw.ch_names))[::-3],
                                [1, 2, 4, 6]]):
        fig = raw.plot(group_by=group_by, order=order)
        # Use the data axes of *this* figure: the ``data_ax`` bound above
        # belongs to the first figure, which has already been closed.
        data_ax = fig.get_axes()[0]
        x = data_ax.lines[1].get_xdata()[10]
        y = data_ax.lines[1].get_ydata()[10]
        _fake_click(fig, data_ax, [x, y], xform='data')  # mark bad
        fig.canvas.key_press_event('down')  # change selection
        _fake_click(fig, fig.get_axes()[2], [0.5, 0.5])  # change channels
        # NOTE(review): assumes the selection figure is figure number 1
        sel_fig = plt.figure(1)
        topo_ax = sel_fig.axes[1]
        _fake_click(sel_fig, topo_ax, [-0.425, 0.20223853],
                    xform='data')
        fig.canvas.key_press_event('down')
        fig.canvas.key_press_event('up')
        fig.canvas.scroll_event(0.5, 0.5, -1)  # scroll down
        fig.canvas.scroll_event(0.5, 0.5, 1)  # scroll up
        # press / drag / release on the topo axes (lasso-style selection)
        _fake_click(sel_fig, topo_ax, [-0.5, 0.], xform='data')
        _fake_click(sel_fig, topo_ax, [0.5, 0.], xform='data',
                    kind='motion')
        _fake_click(sel_fig, topo_ax, [0.5, 0.5], xform='data',
                    kind='motion')
        _fake_click(sel_fig, topo_ax, [-0.5, 0.5], xform='data',
                    kind='release')

        plt.close('all')
    # test if meas_date is off
    raw.set_meas_date(_dt_to_stamp(raw.info['meas_date'])[0])
    annot = Annotations([1 + raw.first_samp / raw.info['sfreq']],
                        [5], ['bad'])
    with pytest.warns(RuntimeWarning, match='outside data range'):
        raw.set_annotations(annot)
    with pytest.warns(None):  # sometimes projection
        raw.plot(group_by='position', order=np.arange(8))
    for fig_num in plt.get_fignums():
        fig = plt.figure(fig_num)
        if hasattr(fig, 'radio'):  # Get access to selection fig.
            break
    for key in ['down', 'up', 'escape']:
        fig.canvas.key_press_event(key)

    raw._data[:] = np.nan
    # this should (at least) not die, the output should pretty clearly show
    # that there is a problem so probably okay to just plot something blank
    with pytest.warns(None):
        raw.plot(scalings='auto')

    plt.close('all')
示例#27
0
def test_plot_ica_properties():
    """Test plotting of ICA properties.

    Covers ``_create_properties_layout``, ``ICA.plot_properties`` on Raw and
    Epochs with various argument combinations, input validation,
    user-supplied axes, merged gradiometers, all-zero epochs and Raw data
    with a BAD annotation.
    """
    raw = _get_raw(preload=True).crop(0, 5)
    raw.add_proj([], remove_existing=True)
    events = make_fixed_length_events(raw)
    # keep only six of the picked channels to make the test fast
    picks = _get_picks(raw)[:6]
    pick_names = [raw.ch_names[k] for k in picks]
    raw.pick_channels(pick_names)
    reject = dict(grad=4000e-13, mag=4e-12)

    epochs = Epochs(raw,
                    events[:3],
                    event_id,
                    tmin,
                    tmax,
                    baseline=(None, 0),
                    preload=True)

    # max_iter=1 keeps fitting cheap; the result only needs to be plottable
    ica = ICA(noise_cov=read_cov(cov_fname),
              n_components=2,
              max_iter=1,
              random_state=0)
    with pytest.warns(RuntimeWarning, match='projection'):
        ica.fit(raw)

    # test _create_properties_layout
    fig, ax = _create_properties_layout()
    assert_equal(len(ax), 5)
    with pytest.raises(ValueError, match='specify both fig and figsize'):
        _create_properties_layout(figsize=(2, 2), fig=fig)

    # coarse, contour-free topomaps to speed up plotting
    topoargs = dict(topomap_args={'res': 4, 'contours': 0, "sensors": False})
    with catch_logging() as log:
        ica.plot_properties(raw, picks=0, verbose='debug', **topoargs)
    log = log.getvalue()
    assert raw.ch_names[0] == 'MEG 0113'
    assert 'Interpolation mode local to mean' in log, log
    ica.plot_properties(epochs, picks=1, dB=False, plot_std=1.5, **topoargs)
    ica.plot_properties(epochs,
                        picks=1,
                        image_args={'sigma': 1.5},
                        topomap_args={
                            'res': 4,
                            'colorbar': True
                        },
                        psd_args={'fmax': 65.},
                        plot_std=False,
                        figsize=[4.5, 4.5],
                        reject=reject)
    plt.close('all')

    # invalid argument types
    with pytest.raises(TypeError, match='must be an instance'):
        ica.plot_properties(epochs, dB=list('abc'))
    with pytest.raises(TypeError, match='must be an instance'):
        ica.plot_properties(ica)
    with pytest.raises(TypeError, match='must be an instance'):
        ica.plot_properties([0.2])
    with pytest.raises(TypeError, match='must be an instance'):
        plot_ica_properties(epochs, epochs)
    with pytest.raises(TypeError, match='must be an instance'):
        ica.plot_properties(epochs, psd_args='not dict')
    with pytest.raises(TypeError, match='must be an instance'):
        ica.plot_properties(epochs, plot_std=[])

    # user-supplied axes: five axes for one component works, two components
    # with the same axes and a non-axes value are rejected
    fig, ax = plt.subplots(2, 3)
    ax = ax.ravel()[:-1]
    ica.plot_properties(epochs, picks=1, axes=ax, **topoargs)
    pytest.raises(TypeError,
                  plot_ica_properties,
                  epochs,
                  ica,
                  picks=[0, 1],
                  axes=ax)
    pytest.raises(ValueError, ica.plot_properties, epochs, axes='not axes')
    plt.close('all')

    # Test merging grads.
    pick_names = raw.ch_names[:15:2] + raw.ch_names[1:15:2]
    raw = _get_raw(preload=True).pick_channels(pick_names).crop(0, 5)
    raw.info.normalize_proj()
    ica = ICA(random_state=0, max_iter=1)
    with pytest.warns(UserWarning, match='did not converge'):
        ica.fit(raw)
    ica.plot_properties(raw)
    plt.close('all')

    # Test handling of zeros
    ica = ICA(random_state=0, max_iter=1)
    epochs.pick_channels(pick_names)
    with pytest.warns(UserWarning, match='did not converge'):
        ica.fit(epochs)
    epochs._data[0] = 0
    with pytest.warns(None):  # Usually UserWarning: Infinite value .* for epo
        ica.plot_properties(epochs, **topoargs)
    plt.close('all')

    # Test Raw with annotations
    annot = Annotations(onset=[1], duration=[1], description=['BAD'])
    raw_annot = _get_raw(preload=True).set_annotations(annot).crop(0, 8)
    raw_annot.pick(np.arange(10))
    raw_annot.del_proj()

    with pytest.warns(UserWarning, match='did not converge'):
        ica.fit(raw_annot)
    # drop bad data segments
    fig = ica.plot_properties(raw_annot, picks=[0, 1], **topoargs)
    # one figure per requested component
    assert_equal(len(fig), 2)
    # don't drop
    ica.plot_properties(raw_annot, reject_by_annotation=False, **topoargs)
def test_ica_additional(method):
    """Test additional ICA functionality.

    For the ICA ``method`` under test this exercises:

    - fitting with ``n_components=None`` and with a noise covariance;
    - ``find_bads_ecg`` / ``find_bads_eog`` on Raw, Epochs and Evoked;
    - ``corrmap``, component sorting and explained variance;
    - ``decim`` handling and save/read round trips;
    - source scoring and ``detect_artifacts``;
    - Raw/Epochs source export.

    Every file is written inside ``tempdir``, not its parent directory or
    the current working directory. This way runs of the different
    ``method`` parametrizations cannot overwrite each other's fixed-name
    files, and no stray file is left behind.
    """
    _skip_check_picard(method)

    import warnings

    import matplotlib.pyplot as plt
    tempdir = _TempDir()
    stop2 = 500
    raw = read_raw_fif(raw_fname).crop(1.5, stop).load_data()
    raw.del_proj()  # avoid warnings
    raw.set_annotations(Annotations([0.5], [0.5], ['BAD']))
    # XXX This breaks the tests :(
    # raw.info['bads'] = [raw.ch_names[1]]
    test_cov = read_cov(test_cov_name)
    events = read_events(event_name)
    picks = pick_types(raw.info,
                       meg=True,
                       stim=False,
                       ecg=False,
                       eog=False,
                       exclude='bads')[1::2]
    epochs = Epochs(raw,
                    events,
                    None,
                    tmin,
                    tmax,
                    picks=picks,
                    baseline=(None, 0),
                    preload=True,
                    proj=False)
    epochs.decimate(3, verbose='error')
    assert len(epochs) == 4

    # test if n_components=None works
    ica = ICA(n_components=None,
              max_pca_components=None,
              n_pca_components=None,
              random_state=0,
              method=method,
              max_iter=1)
    with pytest.warns(UserWarning, match='did not converge'):
        ica.fit(epochs)
    # for testing eog functionality
    picks2 = np.concatenate([picks, pick_types(raw.info, False, eog=True)])
    epochs_eog = Epochs(raw,
                        events[:4],
                        event_id,
                        tmin,
                        tmax,
                        picks=picks2,
                        baseline=(None, 0),
                        preload=True)
    del picks2

    test_cov2 = test_cov.copy()
    ica = ICA(noise_cov=test_cov2,
              n_components=3,
              max_pca_components=4,
              n_pca_components=4,
              method=method)
    assert (ica.info is None)
    with pytest.warns(RuntimeWarning, match='normalize_proj'):
        ica.fit(raw, picks[:5])
    assert (isinstance(ica.info, Info))
    assert (ica.n_components_ < 5)

    ica = ICA(n_components=3,
              max_pca_components=4,
              method=method,
              n_pca_components=4,
              random_state=0)
    # saving an unfitted ICA must fail
    pytest.raises(RuntimeError, ica.save, '')

    ica.fit(raw, picks=[1, 2, 3, 4, 5], start=start, stop=stop2)

    # check passing a ch_name to find_bads_ecg
    with pytest.warns(RuntimeWarning, match='longer'):
        _, scores_1 = ica.find_bads_ecg(raw)
        _, scores_2 = ica.find_bads_ecg(raw, raw.ch_names[1])
    assert scores_1[0] != scores_2[0]

    # test corrmap
    ica2 = ica.copy()
    ica3 = ica.copy()
    corrmap([ica, ica2], (0, 0),
            threshold='auto',
            label='blinks',
            plot=True,
            ch_type="mag")
    corrmap([ica, ica2], (0, 0), threshold=2, plot=False, show=False)
    assert (ica.labels_["blinks"] == ica2.labels_["blinks"])
    assert (0 in ica.labels_["blinks"])
    # test retrieval of component maps as arrays
    components = ica.get_components()
    template = components[:, 0]
    EvokedArray(components, ica.info, tmin=0.).plot_topomap([0], time_unit='s')

    # same labelling must result when the template is given as an array
    corrmap([ica, ica3],
            template,
            threshold='auto',
            label='blinks',
            plot=True,
            ch_type="mag")
    assert (ica2.labels_["blinks"] == ica3.labels_["blinks"])

    plt.close('all')

    # test warnings on bad filenames (written inside tempdir so the file is
    # not left behind in the shared parent directory)
    ica_badname = op.join(tempdir, 'test-bad-name.fif.gz')
    with pytest.warns(RuntimeWarning, match='-ica.fif'):
        ica.save(ica_badname)
    with pytest.warns(RuntimeWarning, match='-ica.fif'):
        read_ica(ica_badname)

    # test decim: fitting with ``decim`` must not modify the data in place.
    # The fit is done on ``raw_`` itself, otherwise the shape check below
    # would be trivially true.
    ica = ICA(n_components=3,
              max_pca_components=4,
              n_pca_components=4,
              method=method,
              max_iter=1)
    raw_ = raw.copy()
    for _ in range(3):
        raw_.append(raw_)  # lengthen raw_ by appending it to itself
    n_samples = raw_._data.shape[1]
    with pytest.warns(UserWarning, match='did not converge'):
        ica.fit(raw_, picks=picks[:5], decim=3)
    assert raw_._data.shape[1] == n_samples

    # test expl var
    ica = ICA(n_components=1.0,
              max_pca_components=4,
              n_pca_components=4,
              method=method,
              max_iter=1)
    with pytest.warns(UserWarning, match='did not converge'):
        ica.fit(raw, picks=None, decim=3)
    assert (ica.n_components_ == 4)
    ica_var = _ica_explained_variance(ica, raw, normalize=True)
    assert (np.all(ica_var[:-1] >= ica_var[1:]))

    # test ica sorting
    ica.exclude = [0]
    ica.labels_ = dict(blink=[0], think=[1])
    ica_sorted = _sort_components(ica, [3, 2, 1, 0], copy=True)
    assert_equal(ica_sorted.exclude, [3])
    assert_equal(ica_sorted.labels_, dict(blink=[3], think=[2]))

    # epochs extraction from raw fit
    pytest.raises(RuntimeError, ica.get_sources, epochs)
    # test reading and writing
    test_ica_fname = op.join(tempdir, 'test-ica.fif')
    for cov in (None, test_cov):
        ica = ICA(noise_cov=cov,
                  n_components=2,
                  max_pca_components=4,
                  n_pca_components=4,
                  method=method,
                  max_iter=1)
        # ICA does not converge; ignore whatever warnings are emitted
        # (``pytest.warns(None)`` is no longer supported by pytest)
        with warnings.catch_warnings():
            warnings.simplefilter('ignore')
            ica.fit(raw, picks=picks[:10], start=start, stop=stop2)
        sources = ica.get_sources(epochs).get_data()
        assert (ica.mixing_matrix_.shape == (2, 2))
        assert (ica.unmixing_matrix_.shape == (2, 2))
        assert (ica.pca_components_.shape == (4, 10))
        assert (sources.shape[1] == ica.n_components_)

        for exclude in [[], [0], np.array([1, 2, 3])]:
            ica.exclude = exclude
            ica.labels_ = {'foo': [0]}
            ica.save(test_ica_fname)
            ica_read = read_ica(test_ica_fname)
            assert (list(ica.exclude) == ica_read.exclude)
            assert_equal(ica.labels_, ica_read.labels_)
            ica.apply(raw)
            ica.exclude = []
            # an explicit ``exclude`` argument must not alter ica.exclude
            ica.apply(raw, exclude=[1])
            assert (ica.exclude == [])

            ica.exclude = [0, 1]
            ica.apply(raw, exclude=[1])
            assert (ica.exclude == [0, 1])

            # excluded components are marked as bad source channels
            ica_raw = ica.get_sources(raw)
            assert (ica.exclude == [
                ica_raw.ch_names.index(e) for e in ica_raw.info['bads']
            ])

        # test filtering
        d1 = ica_raw._data[0].copy()
        ica_raw.filter(4, 20, fir_design='firwin2')
        assert_equal(ica_raw.info['lowpass'], 20.)
        assert_equal(ica_raw.info['highpass'], 4.)
        assert ((d1 != ica_raw._data[0]).any())
        d1 = ica_raw._data[0].copy()
        ica_raw.notch_filter([10], trans_bandwidth=10, fir_design='firwin')
        assert ((d1 != ica_raw._data[0]).any())

        ica.n_pca_components = 2
        ica.method = 'fake'
        ica.save(test_ica_fname)
        ica_read = read_ica(test_ica_fname)
        assert (ica.n_pca_components == ica_read.n_pca_components)
        assert_equal(ica.method, ica_read.method)
        assert_equal(ica.labels_, ica_read.labels_)

        # check type consistency
        attrs = ('mixing_matrix_ unmixing_matrix_ pca_components_ '
                 'pca_explained_variance_ pre_whitener_')

        def f(x, y):
            return getattr(x, y).dtype

        for attr in attrs.split():
            assert_equal(f(ica_read, attr), f(ica, attr))

        ica.n_pca_components = 4
        ica_read.n_pca_components = 4

        ica.exclude = []
        ica.save(test_ica_fname)
        ica_read = read_ica(test_ica_fname)
        for attr in [
                'mixing_matrix_', 'unmixing_matrix_', 'pca_components_',
                'pca_mean_', 'pca_explained_variance_', 'pre_whitener_'
        ]:
            assert_array_almost_equal(getattr(ica, attr),
                                      getattr(ica_read, attr))

        assert (ica.ch_names == ica_read.ch_names)
        assert (isinstance(ica_read.info, Info))

        sources = ica.get_sources(raw)[:, :][0]
        sources2 = ica_read.get_sources(raw)[:, :][0]
        assert_array_almost_equal(sources, sources2)

        _raw1 = ica.apply(raw, exclude=[1])
        _raw2 = ica_read.apply(raw, exclude=[1])
        assert_array_almost_equal(_raw1[:, :][0], _raw2[:, :][0])

    os.remove(test_ica_fname)
    # check score funcs
    for name, func in get_score_funcs().items():
        if name in score_funcs_unsuited:
            continue
        scores = ica.score_sources(raw,
                                   target='EOG 061',
                                   score_func=func,
                                   start=0,
                                   stop=10)
        assert (ica.n_components_ == len(scores))

    # check univariate stats
    scores = ica.score_sources(raw, score_func=stats.skew)
    # check exception handling
    pytest.raises(ValueError, ica.score_sources, raw, target=np.arange(1))

    params = []
    params += [(None, -1, slice(2), [0, 1])]  # variance, kurtosis params
    params += [(None, 'MEG 1531')]  # ECG / EOG channel params
    for idx, ch_name in product(*params):
        ica.detect_artifacts(raw,
                             start_find=0,
                             stop_find=50,
                             ecg_ch=ch_name,
                             eog_ch=ch_name,
                             skew_criterion=idx,
                             var_criterion=idx,
                             kurt_criterion=idx)

    # snapshots used below to verify that the find_bads_* calls do not
    # modify their inputs in place
    evoked = epochs.average()
    evoked_data = evoked.data.copy()
    raw_data = raw[:][0].copy()
    epochs_data = epochs.get_data().copy()

    with pytest.warns(RuntimeWarning, match='longer'):
        idx, scores = ica.find_bads_ecg(raw, method='ctps')
    assert_equal(len(scores), ica.n_components_)
    with pytest.warns(RuntimeWarning, match='longer'):
        idx, scores = ica.find_bads_ecg(raw, method='correlation')
    assert_equal(len(scores), ica.n_components_)

    with pytest.warns(RuntimeWarning, match='longer'):
        idx, scores = ica.find_bads_eog(raw)
    assert_equal(len(scores), ica.n_components_)

    idx, scores = ica.find_bads_ecg(epochs, method='ctps')

    assert_equal(len(scores), ica.n_components_)
    pytest.raises(ValueError,
                  ica.find_bads_ecg,
                  epochs.average(),
                  method='ctps')
    pytest.raises(ValueError, ica.find_bads_ecg, raw, method='crazy-coupling')

    with pytest.warns(RuntimeWarning, match='longer'):
        idx, scores = ica.find_bads_eog(raw)
    assert_equal(len(scores), ica.n_components_)

    # re-type the channel preceding 'EOG 061' (kind 202) so that there are
    # two EOG channels; find_bads_eog then returns one score list per channel
    raw.info['chs'][raw.ch_names.index('EOG 061') - 1]['kind'] = 202
    with pytest.warns(RuntimeWarning, match='longer'):
        idx, scores = ica.find_bads_eog(raw)
    assert (isinstance(scores, list))
    assert_equal(len(scores[0]), ica.n_components_)

    idx, scores = ica.find_bads_eog(evoked, ch_name='MEG 1441')
    assert_equal(len(scores), ica.n_components_)

    idx, scores = ica.find_bads_ecg(evoked, method='correlation')
    assert_equal(len(scores), ica.n_components_)

    assert_array_equal(raw_data, raw[:][0])
    assert_array_equal(epochs_data, epochs.get_data())
    assert_array_equal(evoked_data, evoked.data)

    # check score funcs
    for name, func in get_score_funcs().items():
        if name in score_funcs_unsuited:
            continue
        scores = ica.score_sources(epochs_eog,
                                   target='EOG 061',
                                   score_func=func)
        assert (ica.n_components_ == len(scores))

    # check univariate stats
    scores = ica.score_sources(epochs, score_func=stats.skew)

    # check exception handling
    pytest.raises(ValueError, ica.score_sources, epochs, target=np.arange(1))

    # ecg functionality
    ecg_scores = ica.score_sources(raw,
                                   target='MEG 1531',
                                   score_func='pearsonr')

    with pytest.warns(RuntimeWarning, match='longer'):
        ecg_events = ica_find_ecg_events(raw,
                                         sources[np.abs(ecg_scores).argmax()])
    assert (ecg_events.ndim == 2)

    # eog functionality
    eog_scores = ica.score_sources(raw,
                                   target='EOG 061',
                                   score_func='pearsonr')
    with pytest.warns(RuntimeWarning, match='longer'):
        eog_events = ica_find_eog_events(raw,
                                         sources[np.abs(eog_scores).argmax()])
    assert (eog_events.ndim == 2)

    # Test ica fiff export
    ica_raw = ica.get_sources(raw, start=0, stop=100)
    assert (ica_raw.last_samp - ica_raw.first_samp == 100)
    assert_equal(len(ica_raw._filenames), 1)  # API consistency
    ica_chans = [ch for ch in ica_raw.ch_names if 'ICA' in ch]
    assert (ica.n_components_ == len(ica_chans))
    # written inside tempdir rather than the current working directory
    test_ica_fname = op.join(tempdir, 'test-ica_raw.fif')
    ica.n_components = np.int32(ica.n_components)
    ica_raw.save(test_ica_fname, overwrite=True)
    ica_raw2 = read_raw_fif(test_ica_fname, preload=True)
    assert_allclose(ica_raw._data, ica_raw2._data, rtol=1e-5, atol=1e-4)
    ica_raw2.close()
    os.remove(test_ica_fname)

    # Test ica epochs export
    ica_epochs = ica.get_sources(epochs)
    assert (ica_epochs.events.shape == epochs.events.shape)
    ica_chans = [ch for ch in ica_epochs.ch_names if 'ICA' in ch]
    assert (ica.n_components_ == len(ica_chans))
    assert (ica.n_components_ == ica_epochs.get_data().shape[1])
    assert (ica_epochs._raw is None)
    assert (ica_epochs.preload is True)

    # test float n pca components
    ica.pca_explained_variance_ = np.array([0.2] * 5)
    ica.n_components_ = 0
    for ncomps, expected in [[0.3, 1], [0.9, 4], [1, 1]]:
        ncomps_ = ica._check_n_pca_components(ncomps)
        assert (ncomps_ == expected)

    ica = ICA(method=method)
    # sometimes does not converge; ignore whatever warnings are emitted
    with warnings.catch_warnings():
        warnings.simplefilter('ignore')
        ica.fit(raw, picks=picks[:5])
    with pytest.warns(RuntimeWarning, match='longer'):
        ica.find_bads_ecg(raw)
    ica.find_bads_eog(epochs, ch_name='MEG 0121')
    assert_array_equal(raw_data, raw[:][0])

    raw.drop_channels(['MEG 0122'])
    pytest.raises(RuntimeError, ica.find_bads_eog, raw)
    with pytest.warns(RuntimeWarning, match='longer'):
        pytest.raises(RuntimeError, ica.find_bads_ecg, raw)
示例#29
0
def test_plot_ica_properties():
    """Test plotting of ICA properties.

    Covers ``ICA.plot_properties`` / ``plot_ica_properties`` on Raw and
    Epochs input. The cases are:

    - custom image/topomap/PSD arguments and user-supplied axes;
    - merged gradiometers and all-zero data;
    - Raw data carrying a ``BAD`` annotation, plotted with and without
      ``reject_by_annotation`` and fitted with and without it;
    - the ``TypeError`` / ``ValueError`` raised for invalid arguments.
    """
    import warnings

    res = 8
    raw = _get_raw(preload=True)
    raw.add_proj([], remove_existing=True)
    events = _get_events()
    picks = _get_picks(raw)[:6]
    pick_names = [raw.ch_names[k] for k in picks]
    raw.pick_channels(pick_names)
    reject = dict(grad=4000e-13, mag=4e-12)

    epochs = Epochs(raw,
                    events[:10],
                    event_id,
                    tmin,
                    tmax,
                    baseline=(None, 0),
                    preload=True)

    ica = ICA(noise_cov=read_cov(cov_fname),
              n_components=2,
              max_iter=1,
              max_pca_components=2,
              n_pca_components=2,
              random_state=0)
    with pytest.warns(RuntimeWarning, match='projection'):
        ica.fit(raw)

    # test _create_properties_layout
    fig, ax = _create_properties_layout()
    assert_equal(len(ax), 5)

    topoargs = dict(topomap_args={'res': res, 'contours': 0, "sensors": False})
    ica.plot_properties(raw, picks=0, **topoargs)
    ica.plot_properties(epochs, picks=1, dB=False, plot_std=1.5, **topoargs)
    ica.plot_properties(epochs,
                        picks=1,
                        image_args={'sigma': 1.5},
                        topomap_args={
                            'res': 10,
                            'colorbar': True
                        },
                        psd_args={'fmax': 65.},
                        plot_std=False,
                        figsize=[4.5, 4.5],
                        reject=reject)
    plt.close('all')

    # invalid argument types / values
    pytest.raises(TypeError, ica.plot_properties, epochs, dB=list('abc'))
    pytest.raises(TypeError, ica.plot_properties, ica)
    pytest.raises(TypeError, ica.plot_properties, [0.2])
    pytest.raises(TypeError, plot_ica_properties, epochs, epochs)
    pytest.raises(TypeError, ica.plot_properties, epochs, psd_args='not dict')
    pytest.raises(ValueError, ica.plot_properties, epochs, plot_std=[])

    # user-supplied axes: five of the six subplots
    fig, ax = plt.subplots(2, 3)
    ax = ax.ravel()[:-1]
    ica.plot_properties(epochs, picks=1, axes=ax, **topoargs)
    fig = ica.plot_properties(raw, picks=[0, 1], **topoargs)
    assert_equal(len(fig), 2)
    # one set of axes cannot hold the properties of two components
    pytest.raises(TypeError,
                  plot_ica_properties,
                  epochs,
                  ica,
                  picks=[0, 1],
                  axes=ax)
    pytest.raises(ValueError, ica.plot_properties, epochs, axes='not axes')
    plt.close('all')

    # Test merging grads.
    pick_names = raw.ch_names[:15:2] + raw.ch_names[1:15:2]
    raw = _get_raw(preload=True).pick_channels(pick_names)
    raw.info.normalize_proj()
    ica = ICA(random_state=0, max_iter=1)
    with pytest.warns(UserWarning, match='did not converge'):
        ica.fit(raw)
    ica.plot_properties(raw)
    plt.close('all')

    # Test handling of zeros. Usually this emits
    # "UserWarning: Infinite value .* for epo". Whatever is emitted is
    # ignored here, because ``pytest.warns(None)`` is no longer supported
    # by pytest.
    raw._data[:] = 0
    with warnings.catch_warnings():
        warnings.simplefilter('ignore')
        ica.plot_properties(raw)
    ica = ICA(random_state=0, max_iter=1)
    epochs.pick_channels(pick_names)
    with pytest.warns(UserWarning, match='did not converge'):
        ica.fit(epochs)
    epochs._data[0] = 0
    with warnings.catch_warnings():
        warnings.simplefilter('ignore')
        ica.plot_properties(epochs)
    plt.close('all')

    # Test Raw with annotations
    annot = Annotations(onset=[1], duration=[1], description=['BAD'])
    raw_annot = _get_raw(preload=True).set_annotations(annot)

    with pytest.warns(UserWarning, match='did not converge'):
        ica.fit(raw_annot)
    # drop bad data segments
    ica.plot_properties(raw_annot)
    # don't drop
    ica.plot_properties(raw_annot, reject_by_annotation=False)
    # fitting with bad data
    with pytest.warns(UserWarning, match='did not converge'):
        ica.fit(raw_annot, reject_by_annotation=False)
    # drop bad data when plotting
    ica.plot_properties(raw_annot)
    # don't drop bad data when plotting
    ica.plot_properties(raw_annot, reject_by_annotation=False)
    plt.close('all')
示例#30
0
def test_annotate_breaks(meas_date):
    """Test annotate_break.

    The test covers break detection between annotations, the ``ignore``
    parameter, and break detection between events. It also checks the
    errors raised for implausible parameters, an empty events array, an
    invalid ``ignore`` value and missing annotations. Each scenario runs
    both with the file's measurement date and with ``meas_date=None``.
    """
    raw = read_raw_fif(raw_fname, allow_maxshield='yes')
    # exercise both the absolute (meas_date set) and the relative time base
    if meas_date is None:
        raw.set_meas_date(None)

    # Annotations span [12, 13], [15, 16], [16, 17], [20, 22] and
    # [21, 21.5] s. The last one lies entirely within [20, 22].
    annots = Annotations(onset=[12, 15, 16, 20, 21],
                         duration=[1, 1, 1, 2, 0.5],
                         description=['test'],
                         orig_time=raw.info['meas_date'])

    # without a meas_date, shift onsets by first_time (presumably so they
    # map to the same data positions as in the meas_date case)
    if raw.info['meas_date'] is None:
        annots.onset -= raw.first_time

    raw.set_annotations(annots)

    min_break_duration = 0.5
    t_start_after_previous = 0.1
    t_stop_before_next = 0.1

    # Expected breaks:
    # - from the start of the recording to 12 s;
    # - from 13 to 15 s and from 17 to 20 s;
    # - from 22 s to the end of the recording.
    # The zero-length gap at 16 s is shorter than min_break_duration.
    expected_onsets = np.array(
        [
            raw.first_time,
            13 + t_start_after_previous,
            17 + t_start_after_previous,
            22 + t_start_after_previous
        ]
    )

    if raw.info['meas_date'] is None:
        expected_onsets -= raw.first_time

    # The first break is not shifted at its start, and the last break is
    # not trimmed at its end.
    expected_durations = np.array(
        [
            12 - raw.first_time - t_stop_before_next,
            15 - 13 - t_start_after_previous - t_stop_before_next,
            20 - 17 - t_start_after_previous - t_stop_before_next,
            raw._last_time - 22 - t_start_after_previous
        ]
    )

    break_annots = annotate_break(
        raw=raw,
        min_break_duration=min_break_duration,
        t_start_after_previous=t_start_after_previous,
        t_stop_before_next=t_stop_before_next
    )

    assert break_annots.orig_time == raw.info["meas_date"]
    assert_allclose(break_annots.onset, expected_onsets)
    assert_allclose(break_annots.duration, expected_durations)
    assert all(description == 'BAD_break'
               for description in break_annots.description)

    # try setting the annotations, this should not omit anything
    # NOTE(review): there is no explicit assert here. This presumably relies
    # on warnings being turned into errors to catch annotations being
    # dropped.
    raw.set_annotations(break_annots)
    current_annotations = raw.annotations
    if raw.info['meas_date'] is None:
        current_annotations.onset -= raw.first_time
    raw.set_annotations(current_annotations + break_annots)

    # reset before next test
    raw.set_annotations(annots)

    # `ignore` parameter should be respected
    # The expectations below show that a 'BAD_'-prefixed annotation is
    # ignored by default. The [12, 13] annotation therefore disappears, and
    # the first break runs until 15 s.
    raw.annotations.description[0] = 'BAD_'
    break_annots = annotate_break(
        raw=raw,
        min_break_duration=min_break_duration,
        t_start_after_previous=t_start_after_previous,
        t_stop_before_next=t_stop_before_next
    )

    assert_allclose(break_annots.onset,
                    expected_onsets[[True, False, True, True]])
    assert_allclose(
        break_annots.duration,
        [15 - raw.first_time - t_stop_before_next] +
        list(expected_durations[2:])
    )

    # try setting the annotations, this should not omit anything
    raw.set_annotations(break_annots)
    current_annotations = raw.annotations
    if raw.info['meas_date'] is None:
        current_annotations.onset -= raw.first_time
    raw.set_annotations(current_annotations + break_annots)

    # Restore annotations for next test
    raw.set_annotations(annots)
    raw.annotations.description[0] = 'test'

    # Test with events
    # Events are derived from the annotation onsets (12, 15, 16, 20, 21 s).
    # Each event is a single point, so a break follows every event.
    events, _ = events_from_annotations(raw=raw)
    raw.set_annotations(None)

    expected_onsets = np.array(
        [
            raw.first_time,
            12 + t_start_after_previous,
            15 + t_start_after_previous,
            16 + t_start_after_previous,
            20 + t_start_after_previous,
            21 + t_start_after_previous
        ]
    )

    expected_durations = np.array(
        [
            12 - raw.first_time - t_stop_before_next,
            15 - 12 - t_start_after_previous - t_stop_before_next,
            16 - 15 - t_start_after_previous - t_stop_before_next,
            20 - 16 - t_start_after_previous - t_stop_before_next,
            21 - 20 - t_start_after_previous - t_stop_before_next,
            raw._last_time - 21 - t_start_after_previous
        ]
    )

    break_annots = annotate_break(
        raw=raw,
        events=events,
        min_break_duration=min_break_duration,
        t_start_after_previous=t_start_after_previous,
        t_stop_before_next=t_stop_before_next
    )

    if raw.info['meas_date'] is None:
        expected_onsets -= raw.first_time

    assert_allclose(break_annots.onset, expected_onsets)
    assert_allclose(break_annots.duration, expected_durations)

    # try setting the annotations, this should not omit anything
    raw.set_annotations(break_annots)
    current_annotations = raw.annotations
    if raw.info['meas_date'] is None:
        current_annotations.onset -= raw.first_time
    raw.set_annotations(current_annotations + break_annots)

    # reset before next test
    raw.set_annotations(annots)

    # Not finding any break periods
    # (no gap between events comes close to 1000 s)
    break_annots = annotate_break(
        raw=raw,
        events=events,
        min_break_duration=1000,
    )

    assert len(break_annots) == 0

    # Implausible parameters (would produce break annot of duration < 0)
    # 5 - 5 - 5 < 0
    with pytest.raises(ValueError, match='must be greater than 0'):
        annotate_break(
            raw=raw,
            min_break_duration=5,
            t_start_after_previous=5,
            t_stop_before_next=5
        )

    # Empty events array
    with pytest.raises(ValueError, match='events array must not be empty'):
        annotate_break(raw=raw, events=np.array([]))

    # Invalid `ignore` value
    with pytest.raises(TypeError, match='must be an instance of str'):
        annotate_break(raw=raw, ignore=('foo', 1))

    # No annotations to work with
    raw.set_annotations(None)
    with pytest.raises(ValueError, match='Could not find.*annotations'):
        annotate_break(raw=raw)