def test_annotate_amplitude_multiple_ch_types(meas_date, first_samp):
    """Test cases with several channel types.

    Channel 1 (eeg) is made flat over its last 200 samples and channel 5
    (grad) gets a steep ramp over its first 200 samples; ``picks`` then
    selects which of the two defects ends up annotated.
    """
    n_ch, n_times = 11, 1000
    data = np.random.RandomState(0).randn(n_ch, n_times)
    assert not (np.diff(data, axis=-1) == 0).any()  # nothing flat at first
    # channel types: 0-2 eeg, 3-4 mag, 5-8 grad, 9-10 eeg
    info = create_info(n_ch, 1000.,
                       ['eeg'] * 3 + ['mag'] * 2 + ['grad'] * 4 + ['eeg'] * 2)
    # from annotate_flat: test first_samp != 0 for gh-6295
    raw = RawArray(data, info, first_samp=first_samp)
    raw.info['bads'] = [raw.ch_names[-1]]
    raw.set_meas_date(meas_date)

    def _make_defective_raw():
        """Copy ``raw`` with flat eeg and drifting grad segments."""
        raw_ = raw.copy()
        raw_._data[1, 800:] = 0.  # eeg channel 1 flat from sample 800
        raw_._data[5, :200] = np.arange(0, 200 * 10, 10)  # grad ramp
        # shift later samples so the signal continues from the ramp's end
        raw_._data[5, 200:] += raw_._data[5, 199]
        return raw_

    # -- 2 channel types both to annotate --
    raw_ = _make_defective_raw()
    annots, bads = annotate_amplitude(raw_, peak=5, flat=0, bad_percent=50)
    assert len(annots) == 2
    assert len(bads) == 0
    # check annotation instances: exactly one flat and one peak annotation
    assert sorted(annot['description'] for annot in annots) == [
        'BAD_flat', 'BAD_peak']
    for annot in annots:
        start_idx = 0 if annot['description'] == 'BAD_peak' else 800
        stop_idx = 199 if annot['description'] == 'BAD_peak' else -1
        _check_annotation(raw_, annot, meas_date, first_samp, start_idx,
                          stop_idx)

    # -- 2 channel types, one flat picked, one not picked --
    raw_ = _make_defective_raw()
    annots, bads = annotate_amplitude(raw_, peak=5, flat=0, bad_percent=50,
                                      picks='eeg')
    assert len(annots) == 1
    assert len(bads) == 0
    # check annotation instance
    _check_annotation(raw_, annots[0], meas_date, first_samp, 800, -1)
    assert annots[0]['description'] == 'BAD_flat'

    # -- 2 channel types, one flat, one not picked, reverse --
    raw_ = _make_defective_raw()
    annots, bads = annotate_amplitude(raw_, peak=5, flat=0, bad_percent=50,
                                      picks='grad')
    assert len(annots) == 1
    assert len(bads) == 0
    # check annotation instance
    _check_annotation(raw_, annots[0], meas_date, first_samp, 0, 199)
    assert annots[0]['description'] == 'BAD_peak'
def test_flat_bad_acq_skip():
    """Test that acquisition skips are handled properly.

    Checks a real recording containing skips first, then synthetic flat
    segments that overlap a ``bad_acq_skip`` annotation on either edge or
    span it entirely.
    """
    # -- file with a couple of skip and flat channels --
    raw = read_raw_fif(skip_fname, preload=True)
    annots, bads = annotate_amplitude(raw, flat=0)
    assert len(annots) == 0
    # MaxFilter finds the same 21 channels
    channel_numbers = ('141 331 421 431 611 641 1011 1021 1031 1241 1421 '
                       '1741 1841 2011 2131 2141 2241 2531 2541 2611 2621')
    expected_bads = [f'MEG{int(num):04d}' for num in channel_numbers.split()]
    assert bads == expected_bads

    # -- overlap of flat segment with bad_acq_skip --
    n_ch, n_times = 11, 1000
    data = np.random.RandomState(0).randn(n_ch, n_times)
    assert not (np.diff(data, axis=-1) == 0).any()  # nothing flat at first
    info = create_info(n_ch, 1000., 'eeg')
    raw = RawArray(data, info, first_samp=0)
    raw.info['bads'] = [raw.ch_names[-1]]
    # the skip starts at 0.5 s and lasts 0.2 s (samples 500-699 at 1 kHz)
    bad_acq_skip = Annotations([0.5], [0.2], ['bad_acq_skip'], orig_time=None)
    raw.set_annotations(bad_acq_skip)

    # flat channel overlapping the left, then the right, edge of the skip;
    # each entry: (flat samples, expected start index, expected stop index)
    edge_cases = (
        (slice(400, 600), 400, 499),
        (slice(600, 800), 700, 799),
    )
    for flat_span, start_idx, stop_idx in edge_cases:
        raw_ = raw.copy()
        raw_._data[0, flat_span] = 0.
        annots, bads = annotate_amplitude(raw_, peak=None, flat=0,
                                          bad_percent=25)
        assert len(annots) == 1
        assert len(bads) == 0
        # check annotation instance
        assert annots[0]['description'] == 'BAD_flat'
        _check_annotation(raw_, annots[0], None, 0, start_idx, stop_idx)

    # flat channel overlapping entirely with the skip: two annotations result
    raw_ = raw.copy()
    raw_._data[0, 200:800] = 0.
    annots, bads = annotate_amplitude(raw_, peak=None, flat=0,
                                      bad_percent=41)
    assert len(annots) == 2
    assert len(bads) == 0
    # check annotation instances, ordered by onset
    annots = sorted(annots, key=lambda annot: annot['onset'])
    assert all(annot['description'] == 'BAD_flat' for annot in annots)
    _check_annotation(raw_, annots[0], None, 0, 200, 500)
    _check_annotation(raw_, annots[1], None, 0, 700, 799)
def test_annotate_amplitude_with_overlap(meas_date, first_samp):
    """Test cases with overlap between annotations."""
    n_ch, n_times = 11, 1000
    data = np.random.RandomState(0).randn(n_ch, n_times)
    assert not (np.diff(data, axis=-1) == 0).any()  # nothing flat at first
    info = create_info(n_ch, 1000., 'eeg')
    # from annotate_flat: test first_samp != 0 for gh-6295
    raw = RawArray(data, info, first_samp=first_samp)
    raw.info['bads'] = [raw.ch_names[-1]]
    raw.set_meas_date(meas_date)

    # -- overlap between peak and flat --
    raw_ = raw.copy()
    raw_._data[0, 800:] = 0.
    raw_._data[1, 700:900] = np.arange(0, 200 * 10, 10)
    raw_._data[1, 900:] += raw_._data[1, 899]  # add offset for next samples
    annots, bads = annotate_amplitude(raw_, peak=5, flat=0, bad_percent=25)
    assert len(annots) == 2
    assert len(bads) == 0
    # check annotation instances: exactly one flat and one peak annotation
    assert sorted(annot['description'] for annot in annots) == [
        'BAD_flat', 'BAD_peak']
    for annot in annots:
        start_idx = 700 if annot['description'] == 'BAD_peak' else 800
        stop_idx = 899 if annot['description'] == 'BAD_peak' else -1
        _check_annotation(raw_, annot, meas_date, first_samp, start_idx,
                          stop_idx)

    # -- overlap between peak and peak on same channel --
    raw_ = raw.copy()
    raw_._data[0, 700:900] = np.arange(0, 200 * 10, 10)
    # continue the same ramp (value 1000 at sample 800) up to the end
    raw_._data[0, 800:] = np.arange(1000, 300 * 10, 10)
    annots, bads = annotate_amplitude(raw_, peak=5, flat=None, bad_percent=50)
    assert len(annots) == 1
    assert len(bads) == 0
    # check annotation instance
    assert annots[0]['description'] == 'BAD_peak'
    _check_annotation(raw_, annots[0], meas_date, first_samp, 700, -1)

    # -- overlap between flat and flat on different channel --
    raw_ = raw.copy()
    raw_._data[0, 700:900] = 0.
    raw_._data[1, 800:] = 0.
    annots, bads = annotate_amplitude(raw_, peak=None, flat=0.01,
                                      bad_percent=50)
    assert len(annots) == 1
    assert len(bads) == 0
    # check annotation instance
    assert annots[0]['description'] == 'BAD_flat'
    _check_annotation(raw_, annots[0], meas_date, first_samp, 700, -1)
def _inspect_raw(*, bids_path, l_freq, h_freq, find_flat, show_annotations):
    """Raw data inspection.

    Load the recording at ``bids_path`` via ``read_raw_bids``, optionally
    annotate flat segments / channels, and open the MNE raw browser. When
    the browser is closed (or the MNE close key is pressed), the original
    and edited bad channels and annotations are passed to
    ``_save_raw_if_changed``.

    Parameters
    ----------
    bids_path : BIDSPath
        Path to the recording. Its ``extension``, ``datatype``,
        ``root.name`` and ``basename`` attributes are read here.
    l_freq, h_freq : float | None
        Passed as ``highpass`` / ``lowpass`` to ``raw.plot``.
    find_flat : bool
        If True, run flat detection and add the result to the annotations
        and ``info['bads']`` before plotting.
    show_annotations : bool
        If False, annotations are cleared from the figure and the annotation
        window toggle is disabled.
    """
    # Delay the import
    import matplotlib
    import matplotlib.pyplot as plt

    extra_params = dict()
    if bids_path.extension == '.fif':
        extra_params['allow_maxshield'] = True
    raw = read_raw_bids(bids_path, extra_params=extra_params, verbose='error')
    # Snapshot the on-disk state so edits made in the browser can be
    # detected when it closes.
    old_bads = raw.info['bads'].copy()
    old_annotations = raw.annotations.copy()

    if find_flat:
        raw.load_data()  # Speeds up processing dramatically
        # Prefer `annotate_amplitude` when that is what `_annotate_flat_func`
        # refers to; `annotate_flat` is the fallback branch.
        if _annotate_flat_func.__name__ == 'annotate_amplitude':
            flat_annot, flat_chans = annotate_amplitude(
                raw=raw, flat=0, min_duration=0.05, bad_percent=5)
        else:  # pragma: no cover
            flat_annot, flat_chans = annotate_flat(
                raw=raw, min_duration=0.05)
        new_annot = raw.annotations + flat_annot
        raw.set_annotations(new_annot)
        # Merge without duplicates; note that going through `set` does not
        # preserve channel order.
        raw.info['bads'] = list(set(raw.info['bads'] + flat_chans))
        del new_annot, flat_annot
    else:
        flat_chans = []

    show_options = bids_path.datatype == 'meg'
    fig = raw.plot(title=f'{bids_path.root.name}: {bids_path.basename}',
                   highpass=l_freq, lowpass=h_freq,
                   show_options=show_options,
                   block=False, show=False, verbose='warning')

    # Add our own event handlers so that when the MNE Raw Browser is being
    # closed, our dialog box will pop up, asking whether to save changes.
    def _handle_close(event):
        mne_raw_fig = event.canvas.figure
        # Bads alterations are only transferred to `inst` once the figure is
        # closed; Annotation changes are immediately reflected in `inst`
        new_bads = mne_raw_fig.mne.info['bads'].copy()
        new_annotations = mne_raw_fig.mne.inst.annotations.copy()

        if not new_annotations:
            # Ensure it's not an empty list, but an empty set of Annotations.
            new_annotations = mne.Annotations(
                onset=[], duration=[], description=[],
                orig_time=mne_raw_fig.mne.info['meas_date'])
        # `old_annotations` is looked up when this callback runs (closure
        # late binding), so a reassignment further down in the enclosing
        # function is what gets compared here.
        _save_raw_if_changed(old_bads=old_bads,
                             new_bads=new_bads,
                             flat_chans=flat_chans,
                             old_annotations=old_annotations,
                             new_annotations=new_annotations,
                             bids_path=bids_path)
        _global_vars['raw_fig'] = None

    def _keypress_callback(event):
        # Treat the MNE browser's close key like a close event.
        if event.key == _global_vars['mne_close_key']:
            _handle_close(event)

    fig.canvas.mpl_connect('close_event', _handle_close)
    fig.canvas.mpl_connect('key_press_event', _keypress_callback)

    if not show_annotations:
        # Remove annotations and kill `_toggle_annotation_fig` method, since
        # we cannot directly and easily remove the associated `a` keyboard
        # event callback.
        fig._clear_annotations()
        fig._toggle_annotation_fig = lambda: None
        # Ensure it's not an empty list, but an empty set of Annotations.
        # NOTE(review): presumably the empty baseline matches the figure's
        # now-cleared annotations, so merely hiding them is not reported
        # as a change on close; confirm against `_save_raw_if_changed`.
        old_annotations = mne.Annotations(onset=[], duration=[],
                                          description=[],
                                          orig_time=raw.info['meas_date'])

    if matplotlib.get_backend() != 'agg':
        plt.show(block=True)

    # NOTE(review): with a non-'agg' backend, `plt.show(block=True)` returns
    # only after the window is closed. That means `mne_close_key` is stored
    # after the key-press callback could have fired, and `raw_fig` is set
    # after `_handle_close` reset it to None. The callback therefore relies
    # on a value from module init or a previous call. Verify this ordering
    # is intended.
    _global_vars['raw_fig'] = fig
    _global_vars['mne_close_key'] = fig.mne.close_key
def test_invalid_arguments():
    """Test error messages raised by invalid arguments.

    Every expected message is passed through ``re.escape`` so the literal
    text (periods, parentheses) is matched exactly rather than as a regex.
    """
    n_ch, n_times = 2, 100
    data = np.random.RandomState(0).randn(n_ch, n_times)
    info = create_info(n_ch, 100., 'eeg')
    raw = RawArray(data, info, first_samp=0)

    # negative PTP threshold given as a single number
    with pytest.raises(ValueError, match=re.escape(
            "Argument 'flat' should define a positive threshold. "
            "Provided: '-1'.")):
        annotate_amplitude(raw, peak=None, flat=-1)
    with pytest.raises(ValueError, match=re.escape(
            "Argument 'peak' should define a positive threshold. "
            "Provided: '-1'.")):
        annotate_amplitude(raw, peak=-1, flat=None)

    # negative PTP threshold for one channel type
    with pytest.raises(ValueError, match=re.escape(
            "Argument 'flat' should define positive thresholds. "
            "Provided for channel type 'eog': '-1'.")):
        annotate_amplitude(raw, peak=None, flat=dict(eeg=1, eog=-1))
    with pytest.raises(ValueError, match=re.escape(
            "Argument 'peak' should define positive thresholds. "
            "Provided for channel type 'eog': '-1'.")):
        annotate_amplitude(raw, peak=dict(eeg=1, eog=-1), flat=None)

    # test both PTP set to None
    with pytest.raises(ValueError, match=re.escape(
            "At least one of the arguments 'peak' or 'flat' "
            "must not be None.")):
        annotate_amplitude(raw, peak=None, flat=None)

    # bad_percent outside [0, 100]
    with pytest.raises(ValueError, match=re.escape(
            "Argument 'bad_percent' should define a percentage between 0% "
            "and 100%. Provided: -1.0%.")):
        annotate_amplitude(raw, peak=dict(eeg=1), flat=None, bad_percent=-1)

    # min_duration negative
    with pytest.raises(ValueError, match=re.escape(
            "Argument 'min_duration' should define a positive duration in "
            "seconds. Provided: '-1.0' seconds.")):
        annotate_amplitude(raw, peak=dict(eeg=1), flat=None, min_duration=-1)

    # min_duration equal to, then longer than, the raw duration
    for min_duration, provided in ((1., '1.0'), (10, '10.0')):
        with pytest.raises(ValueError, match=re.escape(
                "Argument 'min_duration' should define a positive duration "
                "in seconds shorter than the raw duration (1.0 seconds). "
                f"Provided: '{provided}' seconds.")):
            annotate_amplitude(raw, peak=dict(eeg=1), flat=None,
                               min_duration=min_duration)
def test_annotate_amplitude(meas_date, first_samp):
    """Test automatic annotation for segments based on peak-to-peak value."""
    n_ch, n_times = 11, 1000
    data = np.random.RandomState(0).randn(n_ch, n_times)
    assert not (np.diff(data, axis=-1) == 0).any()  # nothing flat at first
    info = create_info(n_ch, 1000., 'eeg')
    # from annotate_flat: test first_samp != 0 for gh-6295
    raw = RawArray(data, info, first_samp=first_samp)
    raw.info['bads'] = [raw.ch_names[-1]]
    raw.set_meas_date(meas_date)

    # -- test bad channels spatial marking --
    for perc, dur in itertools.product((5, 99.9, 100.), (0.005, 0.95, 0.99)):
        kwargs = dict(bad_percent=perc, min_duration=dur)

        # test entire channel flat
        raw_ = raw.copy()
        raw_._data[0] = 0.
        annots, bads = annotate_amplitude(raw_, peak=None, flat=0., **kwargs)
        assert len(annots) == 0
        assert bads == ['0']

        # test multiple channels flat
        raw_ = raw.copy()
        raw_._data[0] = 0.
        raw_._data[2] = 0.
        annots, bads = annotate_amplitude(raw_, peak=None, flat=0., **kwargs)
        assert len(annots) == 0
        assert bads == ['0', '2']

        # test entire channel drifting
        raw_ = raw.copy()
        raw_._data[0] = np.arange(0, raw.times.size * 10, 10)
        annots, bads = annotate_amplitude(raw_, peak=5, flat=None, **kwargs)
        assert len(annots) == 0
        assert bads == ['0']

        # test multiple channels drifting
        raw_ = raw.copy()
        raw_._data[0] = np.arange(0, raw.times.size * 10, 10)
        raw_._data[2] = np.arange(0, raw.times.size * 10, 10)
        annots, bads = annotate_amplitude(raw_, peak=5, flat=None, **kwargs)
        assert len(annots) == 0
        assert bads == ['0', '2']

    # -- test bad channels temporal marking --
    # flat channel for the 20% last points
    n_good_times = int(round(0.8 * n_times))
    raw_ = raw.copy()
    raw_._data[0, n_good_times:] = 0.
    # a 20% flat segment marks the channel bad while bad_percent <= 20 ...
    for perc in (5, 20):
        annots, bads = annotate_amplitude(raw_, peak=None, flat=0.,
                                          bad_percent=perc)
        assert len(annots) == 0
        assert bads == ['0']
    # ... and is annotated instead once bad_percent exceeds 20
    annots, bads = annotate_amplitude(raw_, peak=None, flat=0.,
                                      bad_percent=20.1)
    assert len(annots) == 1
    assert len(bads) == 0
    # check annotation instance
    assert annots[0]['description'] == 'BAD_flat'
    _check_annotation(raw_, annots[0], meas_date, first_samp, n_good_times, -1)

    # test multiple channels flat and multiple channels drift
    # (channels 0 and 2 are defective for 20% of samples, 1 and 3 for 10%)
    raw_ = raw.copy()
    raw_._data[0, 800:] = 0.
    raw_._data[1, 850:950] = 0.
    raw_._data[2, :200] = np.arange(0, 200 * 10, 10)
    raw_._data[2, 200:] += raw_._data[2, 199]  # add offset for next samples
    raw_._data[3, 50:150] = np.arange(0, 100 * 10, 10)
    raw_._data[3, 150:] += raw_._data[3, 149]  # add offset for next samples
    for perc in (5, 10):
        annots, bads = annotate_amplitude(raw_, peak=5, flat=0.,
                                          bad_percent=perc)
        assert len(annots) == 0
        assert bads == ['0', '1', '2', '3']
    for perc in (10.1, 20):
        annots, bads = annotate_amplitude(raw_, peak=5, flat=0.,
                                          bad_percent=perc)
        assert len(annots) == 2
        assert bads == ['0', '2']
        # check annotation instances: exactly one flat and one peak
        assert sorted(annot['description'] for annot in annots) == [
            'BAD_flat', 'BAD_peak']
        for annot in annots:
            start_idx = 50 if annot['description'] == 'BAD_peak' else 850
            stop_idx = 149 if annot['description'] == 'BAD_peak' else 949
            _check_annotation(raw_, annot, meas_date, first_samp, start_idx,
                              stop_idx)
    annots, bads = annotate_amplitude(raw_, peak=5, flat=0., bad_percent=20.1)
    assert len(annots) == 2
    assert len(bads) == 0
    # check annotation instances: exactly one flat and one peak
    assert sorted(annot['description'] for annot in annots) == [
        'BAD_flat', 'BAD_peak']
    for annot in annots:
        start_idx = 0 if annot['description'] == 'BAD_peak' else 800
        stop_idx = 199 if annot['description'] == 'BAD_peak' else -1
        _check_annotation(raw_, annot, meas_date, first_samp, start_idx,
                          stop_idx)

    # test flat on already marked bad channel
    raw_ = raw.copy()
    raw_._data[-1, :] = 0.
    # this channel is already in info['bads']
    annots, bads = annotate_amplitude(raw_, peak=None, flat=0., bad_percent=5)
    assert len(annots) == 0
    assert len(bads) == 0

    # test drift on already marked bad channel
    raw_ = raw.copy()
    raw_._data[-1, :] = np.arange(0, raw.times.size * 10, 10)
    annots, bads = annotate_amplitude(raw_, peak=5, flat=None, bad_percent=5)
    assert len(annots) == 0
    assert len(bads) == 0