Example #1
def test_check(tmp_path):
    """Test checking functions."""
    pytest.raises(ValueError, check_random_state, 'foo')
    pytest.raises(TypeError, _check_fname, 1)
    _check_fname(Path('./foo'))
    fname = tmp_path / 'foo'
    with open(fname, 'wb'):
        pass
    assert op.isfile(fname)
    _check_fname(fname, overwrite='read', must_exist=True)
    orig_perms = os.stat(fname).st_mode
    os.chmod(fname, 0)
    if not sys.platform.startswith('win'):
        with pytest.raises(PermissionError, match='read permissions'):
            _check_fname(fname, overwrite='read', must_exist=True)
    os.chmod(fname, orig_perms)
    os.remove(fname)
    assert not op.isfile(fname)
    pytest.raises(IOError, check_fname, 'foo', 'tets-dip.x', (), ('.fif', ))
    pytest.raises(ValueError, _check_subject, None, None)
    pytest.raises(TypeError, _check_subject, None, 1)
    pytest.raises(TypeError, _check_subject, 1, None)
    # smoke tests for permitted types
    check_random_state(None).choice(1)
    check_random_state(0).choice(1)
    check_random_state(np.random.RandomState(0)).choice(1)
    if check_version('numpy', '1.17'):
        check_random_state(np.random.default_rng(0)).choice(1)
Example #2
def test_check(tmpdir):
    """Test checking functions."""
    pytest.raises(ValueError, check_random_state, 'foo')
    pytest.raises(TypeError, _check_fname, 1)
    _check_fname(Path('./'))
    fname = str(tmpdir.join('foo'))
    with open(fname, 'wb'):
        pass
    assert op.isfile(fname)
    _check_fname(fname, overwrite='read', must_exist=True)
    orig_perms = os.stat(fname).st_mode
    os.chmod(fname, 0)
    if not sys.platform.startswith('win'):
        with pytest.raises(PermissionError, match='read permissions'):
            _check_fname(fname, overwrite='read', must_exist=True)
    os.chmod(fname, orig_perms)
    os.remove(fname)
    assert not op.isfile(fname)
    pytest.raises(IOError, check_fname, 'foo', 'tets-dip.x', (), ('.fif', ))
    pytest.raises(ValueError, _check_subject, None, None)
    pytest.raises(TypeError, _check_subject, None, 1)
    pytest.raises(TypeError, _check_subject, 1, None)
    # smoke tests for permitted types
    check_random_state(None).choice(1)
    check_random_state(0).choice(1)
    check_random_state(np.random.RandomState(0)).choice(1)
    if check_version('numpy', '1.17'):
        check_random_state(np.random.default_rng(0)).choice(1)

    # _meg.fif is a valid ending and should not raise an error
    new_fname = str(
        tmpdir.join(op.basename(fname_raw).replace('_raw.', '_meg.')))
    shutil.copyfile(fname_raw, new_fname)
    mne.io.read_raw_fif(new_fname)
Example #3
    def __init__(
        self,
        raw,
        params,
        ransac=True,
        channel_wise=False,
        max_chunk_size=None,
        random_state=None,
        matlab_strict=False,
    ):
        """Initialize the class."""
        raw.load_data()
        self.raw = raw.copy()
        self.ch_names = self.raw.ch_names
        self.raw.pick_types(eeg=True, eog=False, meg=False)
        self.ch_names_eeg = self.raw.ch_names
        self.EEG = self.raw.get_data()
        self.reference_channels = params["ref_chs"]
        self.rereferenced_channels = params["reref_chs"]
        self.sfreq = self.raw.info["sfreq"]
        self.ransac_settings = {
            "ransac": ransac,
            "channel_wise": channel_wise,
            "max_chunk_size": max_chunk_size,
        }
        self.random_state = check_random_state(random_state)
        self._extra_info = {}
        self.matlab_strict = matlab_strict
Example #4
    def get_ransac_pred(self,
                        chn_pos,
                        chn_pos_good,
                        good_chn_labs,
                        n_pred_chns,
                        data,
                        random_state=None):
        """Perform RANSAC prediction.

        Parameters
        ----------
        chn_pos : ndarray
            3-D coordinates of the electrode position
        chn_pos_good : ndarray
            3-D coordinates of all the channels not detected noisy so far
        good_chn_labs : array_like
            channel labels for the chn_pos_good channels
        n_pred_chns : int
            number of channels to use for interpolation in RANSAC
        data : ndarray
            2-D EEG data
        random_state : int | None | np.random.mtrand.RandomState
            If random_state is an int, it will be used as a seed for RandomState.
            If None, the seed will be obtained from the operating system
            (see RandomState for details). Default is None.

        Returns
        -------
        ransac_pred : ndarray
            Single RANSAC prediction

        Title: noisy
        Author: Stefan Appelhoff
        Date: 2018
        Availability: https://github.com/sappelhoff/pyprep/blob/master/pyprep/noisy.py
        """
        rng = check_random_state(random_state)

        # Pick a subset of clean channels for reconstruction
        reconstr_idx = rng.choice(np.arange(chn_pos_good.shape[0]),
                                  size=n_pred_chns,
                                  replace=False)

        # Get positions and according labels
        reconstr_labels = good_chn_labs[reconstr_idx]
        reconstr_pos = chn_pos_good[reconstr_idx, :]

        # Map the labels to their indices within the complete data
        # Do not use mne.pick_channels, because it will return a sorted list.
        reconstr_picks = [
            list(self.ch_names_new).index(chn_lab)
            for chn_lab in reconstr_labels
        ]

        # Interpolate
        interpol_mat = _make_interpolation_matrix(reconstr_pos, chn_pos)
        ransac_pred = np.matmul(interpol_mat, data[reconstr_picks, :])
        return ransac_pred
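The method above relies on class state (self.ch_names_new) and MNE's private
_make_interpolation_matrix, so it cannot run on its own. A minimal,
self-contained sketch of the same single-prediction idea, with a hypothetical
inverse-distance stand-in for the interpolation matrix:

import numpy as np
from mne.utils import check_random_state

def toy_interpolation_matrix(pos_from, pos_to):
    # Hypothetical stand-in for MNE's private _make_interpolation_matrix:
    # inverse-distance weights, with rows normalized to sum to one.
    d = np.linalg.norm(pos_to[:, None, :] - pos_from[None, :, :], axis=-1)
    w = 1.0 / np.maximum(d, 1e-12)
    return w / w.sum(axis=1, keepdims=True)

rng = check_random_state(0)
chn_pos_good = rng.rand(8, 3)   # positions of 8 clean channels
chn_pos = rng.rand(1, 3)        # position of the channel to predict
data = rng.randn(8, 1000)       # 2-D EEG data for the clean channels

# pick a random subset of clean channels and interpolate, as above
reconstr_idx = rng.choice(np.arange(8), size=4, replace=False)
interpol_mat = toy_interpolation_matrix(chn_pos_good[reconstr_idx], chn_pos)
ransac_pred = np.matmul(interpol_mat, data[reconstr_idx])  # shape (1, 1000)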
Example #5
    def __init__(self,
                 raw,
                 do_detrend=True,
                 random_state=None,
                 matlab_strict=False):
        # Make sure that we got an MNE object
        assert isinstance(raw, mne.io.BaseRaw)

        raw.load_data()
        self.raw_mne = raw.copy()
        self.raw_mne.pick_types(eeg=True)
        self.sample_rate = raw.info["sfreq"]
        if do_detrend:
            self.raw_mne._data = removeTrend(self.raw_mne.get_data(),
                                             self.sample_rate,
                                             matlab_strict=matlab_strict)
        self.matlab_strict = matlab_strict

        # Extra data for debugging
        self._extra_info = {
            "bad_by_deviation": {},
            "bad_by_hf_noise": {},
            "bad_by_correlation": {},
            "bad_by_dropout": {},
            "bad_by_ransac": {},
        }

        # random_state
        self.random_state = check_random_state(random_state)

        # The identified bad channels
        self.bad_by_nan = []
        self.bad_by_flat = []
        self.bad_by_deviation = []
        self.bad_by_hf_noise = []
        self.bad_by_correlation = []
        self.bad_by_SNR = []
        self.bad_by_dropout = []
        self.bad_by_ransac = []

        # Get original EEG channel names, channel count & samples
        ch_names = np.asarray(self.raw_mne.info["ch_names"])
        self.ch_names_original = ch_names
        self.n_chans_original = len(ch_names)
        self.n_samples = raw._data.shape[1]

        # Before anything else, flag bad-by-NaNs and bad-by-flats
        self.find_bad_by_nan_flat()
        bads_by_nan_flat = self.bad_by_nan + self.bad_by_flat

        # Make a subset of the data containing only usable EEG channels
        self.usable_idx = np.isin(ch_names, bads_by_nan_flat, invert=True)
        self.EEGData = self.raw_mne.get_data(picks=ch_names[self.usable_idx])
        self.EEGFiltered = None

        # Get usable EEG channel names & channel counts
        self.ch_names_new = np.asarray(ch_names[self.usable_idx])
        self.n_chans_new = len(self.ch_names_new)
Example #6
def add_atom(data, atom, low, high, random_state=None):
    rng = check_random_state(random_state)

    support = atom.shape[0]
    n_samples = data.shape[0]

    # rng.random_integers was removed from NumPy; randint with high + 1
    # keeps the same inclusive upper bound.
    starts = rng.randint(low=low, high=high + 1, size=n_samples)
    for i in range(n_samples):
        start = starts[i]
        data[i, start: start + support] = atom
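A quick usage sketch for the fixed function; the shapes and the Hanning atom
are illustrative only:

import numpy as np

data = np.zeros((5, 300))
atom = np.hanning(20)  # a 20-sample atom
# place the atom at a random onset in each of the 5 signals
add_atom(data, atom, low=0, high=300 - 20, random_state=0)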
Example #7
    def __init__(
        self,
        raw,
        prep_params,
        montage,
        ransac=True,
        channel_wise=False,
        max_chunk_size=None,
        random_state=None,
        filter_kwargs=None,
        matlab_strict=False,
    ):
        """Initialize PREP class."""
        raw.load_data()
        self.raw_eeg = raw.copy()

        # split EEG and non-EEG channels
        self.ch_names_all = raw.ch_names.copy()
        self.ch_types_all = raw.get_channel_types()
        self.ch_names_eeg = [
            self.ch_names_all[i] for i in range(len(self.ch_names_all))
            if self.ch_types_all[i] == "eeg"
        ]
        self.ch_names_non_eeg = list(
            set(self.ch_names_all) - set(self.ch_names_eeg))
        self.raw_eeg.pick_channels(self.ch_names_eeg)
        if self.ch_names_non_eeg == []:
            self.raw_non_eeg = None
        else:
            self.raw_non_eeg = raw.copy()
            self.raw_non_eeg.pick_channels(self.ch_names_non_eeg)

        self.raw_eeg.set_montage(montage)
        # raw_non_eeg may not be compatible with the montage
        # so it is not set for that object

        self.EEG_raw = self.raw_eeg.get_data()
        self.prep_params = prep_params
        if self.prep_params["ref_chs"] == "eeg":
            self.prep_params["ref_chs"] = self.ch_names_eeg
        if self.prep_params["reref_chs"] == "eeg":
            self.prep_params["reref_chs"] = self.ch_names_eeg
        if "max_iterations" not in prep_params.keys():
            self.prep_params["max_iterations"] = 4
        self.sfreq = self.raw_eeg.info["sfreq"]
        self.ransac_settings = {
            "ransac": ransac,
            "channel_wise": channel_wise,
            "max_chunk_size": max_chunk_size,
        }
        self.random_state = check_random_state(random_state)
        self.filter_kwargs = filter_kwargs
        self.matlab_strict = matlab_strict
Example #8
    def __init__(self, raw, params, ransac=True, random_state=None):
        """Initialize the class."""
        self.raw = raw.copy()
        self.ch_names = self.raw.ch_names
        self.raw.pick_types(eeg=True, eog=False, meg=False)
        self.ch_names_eeg = self.raw.ch_names
        self.EEG = self.raw.get_data() * 1e6
        self.reference_channels = params["ref_chs"]
        self.rereferenced_channels = params["reref_chs"]
        self.sfreq = self.raw.info["sfreq"]
        self.ransac = ransac
        self.random_state = check_random_state(random_state)
Example #9
    def __init__(self, raw, do_detrend=True, random_state=None):
        """Initialize the class.

        Parameters
        ----------
        raw : mne.io.Raw
            The MNE raw object.
        do_detrend : bool
            Whether or not to remove a trend from the data upon initializing the
            `NoisyChannels` object. Defaults to True.
        random_state : int | None | np.random.RandomState
            The random seed at which to initialize the class. If random_state
            is an int, it will be used as a seed for RandomState.
            If None, the seed will be obtained from the operating system
            (see RandomState for details). Default is None.

        """
        # Make sure that we got an MNE object
        assert isinstance(raw, mne.io.BaseRaw)

        self.raw_mne = raw.copy()
        self.sample_rate = raw.info["sfreq"]
        if do_detrend:
            self.raw_mne._data = removeTrend(
                self.raw_mne.get_data(), sample_rate=self.sample_rate
            )

        self.EEGData = self.raw_mne.get_data(picks="eeg")
        self.EEGData_beforeFilt = self.EEGData
        self.ch_names_original = np.asarray(raw.info["ch_names"])
        self.n_chans_original = len(self.ch_names_original)
        self.n_chans_new = self.n_chans_original
        self.signal_len = len(self.raw_mne.times)
        self.original_dimensions = np.shape(self.EEGData)
        self.new_dimensions = self.original_dimensions
        self.original_channels = np.arange(self.original_dimensions[0])
        self.new_channels = self.original_channels
        self.ch_names_new = self.ch_names_original
        self.channels_interpolate = self.original_channels

        # random_state
        self.random_state = check_random_state(random_state)

        # The identified bad channels
        self.bad_by_nan = []
        self.bad_by_flat = []
        self.bad_by_deviation = []
        self.bad_by_hf_noise = []
        self.bad_by_correlation = []
        self.bad_by_SNR = []
        self.bad_by_dropout = []
        self.bad_by_ransac = []
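The random_state semantics described in the docstring can be checked directly;
a small sketch, assuming mne is installed:

import numpy as np
from mne.utils import check_random_state

# the same int seed yields identical draws
a = check_random_state(0).randint(0, 100, size=5)
b = check_random_state(0).randint(0, 100, size=5)
assert np.array_equal(a, b)

# an existing RandomState instance is passed through unchanged
rng = np.random.RandomState(42)
assert check_random_state(rng) is rng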
Example #10
def runautoreject(epochs,
                  fiffile,
                  senstype,
                  bads=None,
                  n_interpolates=np.array([1, 4, 32]),
                  consensus_percs=np.linspace(0, 1, 11)):

    bads = [] if bads is None else bads  # avoid a mutable default argument
    check_random_state(42)  # note: the returned RandomState is unused here

    raw = mne.io.read_raw_fif(fiffile, preload=True)
    raw.info['bads'] = list()
    raw.pick_types(meg=True)
    raw.info['projs'] = list()
    epochs.info = raw.info  # required since epochs carry no channel info

    del raw

    picks = mne.pick_types(epochs.info,
                           meg=senstype,
                           eeg=False,
                           stim=False,
                           eog=False,
                           include=[],
                           exclude=bads)

    epochs.verbose = False
    epochs.baseline = (None, 0)
    epochs.preload = True
    epochs.detrend = 0

    ar = AutoReject(n_interpolates,
                    consensus_percs,
                    picks=picks,
                    thresh_method='bayesian_optimization',
                    random_state=42,
                    verbose=False)

    epochs, reject_log = ar.fit_transform(epochs, return_log=True)
    return reject_log
Example #11
def test_check():
    """Test checking functions."""
    pytest.raises(ValueError, check_random_state, 'foo')
    pytest.raises(TypeError, _check_fname, 1)
    pytest.raises(IOError, check_fname, 'foo', 'tets-dip.x', (), ('.fif', ))
    pytest.raises(ValueError, _check_subject, None, None)
    pytest.raises(TypeError, _check_subject, None, 1)
    pytest.raises(TypeError, _check_subject, 1, None)
    # smoke tests for permitted types
    check_random_state(None).choice(1)
    check_random_state(0).choice(1)
    check_random_state(np.random.RandomState(0)).choice(1)
    if check_version('numpy', '1.17'):
        check_random_state(np.random.default_rng(0)).choice(1)

    # _meg.fif is a valid ending and should not raise an error
    sh.copyfile(fname_raw, fname_raw.replace('_raw.', '_meg.'))
    mne.io.read_raw_fif(fname_raw.replace('_raw.', '_meg.'))
Example #12
    def get_ransac_pred(self, chn_pos, chn_pos_good, good_chn_labs,
                        n_pred_chns, data):
        """Perform RANSAC prediction.

        Parameters
        ----------
        chn_pos : ndarray
            3-D coordinates of the electrode position
        chn_pos_good : ndarray
            3-D coordinates of all the channels not detected noisy so far
        good_chn_labs : array_like
            channel labels for the chn_pos_good channels
        n_pred_chns : int
            number of channels to use for interpolation in RANSAC
        data : ndarray
            2-D EEG data

        Returns
        -------
        ransac_pred : ndarray
            Single RANSAC prediction

        """
        rng = check_random_state(self.random_state)

        # Pick a subset of clean channels for reconstruction
        reconstr_idx = rng.choice(np.arange(chn_pos_good.shape[0]),
                                  size=n_pred_chns,
                                  replace=False)

        # Get positions and according labels
        reconstr_labels = good_chn_labs[reconstr_idx]
        reconstr_pos = chn_pos_good[reconstr_idx, :]

        # Map the labels to their indices within the complete data
        # Do not use mne.pick_channels, because it will return a sorted list.
        reconstr_picks = [
            list(self.ch_names_new).index(chn_lab)
            for chn_lab in reconstr_labels
        ]

        # Interpolate
        interpol_mat = _make_interpolation_matrix(reconstr_pos, chn_pos)
        ransac_pred = np.matmul(interpol_mat, data[reconstr_picks, :])
        return ransac_pred
Example #13
    def fit(self, epochs):
        self.picks = _handle_picks(info=epochs.info, picks=self.picks)
        _check_data(epochs,
                    picks=self.picks,
                    ch_constraint='single_channel_type',
                    verbose=self.verbose)
        self.ch_type = _get_channel_type(epochs, self.picks)
        n_epochs = len(epochs)

        n_jobs = check_n_jobs(self.n_jobs)
        parallel = Parallel(n_jobs, verbose=10)
        my_iterator = delayed(_iterate_epochs)
        if self.verbose is not False and self.n_jobs > 1:
            print('Iterating epochs ...')
        verbose = False if self.n_jobs > 1 else self.verbose
        rng = check_random_state(self.random_state)
        base_random_state = rng.randint(np.iinfo(np.int16).max)
        self.ch_subsets_ = [
            self._get_random_subsets(epochs.info,
                                     base_random_state + random_state)
            for random_state in np.arange(0, n_epochs, n_jobs)
        ]
        epoch_idxs = np.array_split(np.arange(n_epochs), n_jobs)
        corrs = parallel(
            my_iterator(self, epochs, idxs, chs, verbose)
            for idxs, chs in zip(epoch_idxs, self.ch_subsets_))
        self.corr_ = np.concatenate(corrs)
        if self.verbose is not False and self.n_jobs > 1:
            print('[Done]')

        # compute in how many windows a sensor is RANSAC-bad
        self.bad_log = np.zeros_like(self.corr_)
        self.bad_log[self.corr_ < self.min_corr] = 1
        bad_log = self.bad_log.sum(axis=0)

        bad_idx = np.where(bad_log > self.unbroken_time * n_epochs)[0]
        if len(bad_idx) > 0:
            self.bad_chs_ = [
                epochs.info['ch_names'][self.picks[p]] for p in bad_idx
            ]
        else:
            self.bad_chs_ = []
        return self
Example #14
    def __init__(
        self,
        raw,
        prep_params,
        montage,
        ransac=True,
        random_state=None,
        filter_kwargs=None,
    ):
        """Initialize PREP class."""
        self.raw_eeg = raw.copy()

        # split EEG and non-EEG channels
        self.ch_names_all = raw.ch_names.copy()
        self.ch_types_all = raw.get_channel_types()
        self.ch_names_eeg = [
            self.ch_names_all[i]
            for i in range(len(self.ch_names_all))
            if self.ch_types_all[i] == "eeg"
        ]
        self.ch_names_non_eeg = list(set(self.ch_names_all) - set(self.ch_names_eeg))
        self.raw_eeg.pick_channels(self.ch_names_eeg)
        if self.ch_names_non_eeg == []:
            self.raw_non_eeg = None
        else:
            self.raw_non_eeg = raw.copy()
            self.raw_non_eeg.pick_channels(self.ch_names_non_eeg)

        self.raw_eeg.set_montage(montage)
        # raw_non_eeg may not be compatible with the montage
        # so it is not set for that object

        self.EEG_raw = self.raw_eeg.get_data() * 1e6
        self.prep_params = prep_params
        if self.prep_params["ref_chs"] == "eeg":
            self.prep_params["ref_chs"] = self.ch_names_eeg
        if self.prep_params["reref_chs"] == "eeg":
            self.prep_params["reref_chs"] = self.ch_names_eeg
        self.sfreq = self.raw_eeg.info["sfreq"]
        self.ransac = ransac
        self.random_state = check_random_state(random_state)
        self.filter_kwargs = filter_kwargs
Example #15
def init_signal(parcels,
                raw_fname,
                fwd_fname,
                subject,
                n_sources_max=3,
                random_state=None,
                signal_type='eeg'):
    '''Simulate a signal with 1 to n_sources_max randomly activated parcels.'''
    # randomly choose how many parcels will be activated (between 1 and
    # n_sources_max) and which vertex within each parcel
    rng = check_random_state(random_state)

    n_parcels = rng.randint(n_sources_max, size=1)[0] + 1
    to_activate = []
    parcels_selected = []

    # do this so that the same label is not selected twice
    deck = list(rng.permutation(len(parcels)))
    # deck_rh = list(rng.permutation(len(parcels_rh)))
    for idx in range(n_parcels):
        parcel_selected = deck.pop()
        parcel_used = parcels[parcel_selected]
        l1_source = parcels[parcel_selected].copy()
        l1_source.vertices = [rng.choice(parcel_used.vertices)]

        to_activate.append(l1_source)
        parcels_selected.append(parcel_used)

    # activate selected parcels
    events, _, raw = generate_signal(raw_fname,
                                     fwd_fname,
                                     subject,
                                     parcels=to_activate,
                                     signal_type=signal_type,
                                     random_state=rng)

    evoked = mne.Epochs(raw, events, tmax=0.3).average()
    data = evoked.data[:, np.argmax((evoked.data**2).sum(axis=0))]

    names_parcels_selected = [parcel.name for parcel in parcels_selected]
    return data, names_parcels_selected, to_activate
Example #16
    def _get_random_subsets(self, info):
        """ Get random channels"""
        # have to set the seed here
        rng = check_random_state(self.random_state)
        n_channels = len(info['ch_names'])

        # number of channels to interpolate from
        n_samples = int(np.round(self.min_channels * n_channels))

        # get picks for resamples
        picks = []
        for idx in range(self.n_resample):
            pick = rng.permutation(n_channels)[:n_samples].copy()
            picks.append(pick)

        # get channel subsets as lists
        ch_subsets = []
        for pick in picks:
            ch_subsets.append([info['ch_names'][p] for p in pick])

        return ch_subsets
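Stripped of the class plumbing, the sampling above reduces to a few lines; the
min_channels and n_resample values below are illustrative:

import numpy as np
from mne.utils import check_random_state

rng = check_random_state(42)
ch_names = ['EEG %03d' % i for i in range(32)]
n_samples = int(np.round(0.25 * len(ch_names)))  # min_channels = 0.25
ch_subsets = [[ch_names[p] for p in rng.permutation(len(ch_names))[:n_samples]]
              for _ in range(5)]                 # n_resample = 5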
Example #17
def mock_data(n_samples=100, support=20, noise_level=0.1, random_state=None):
    rng = check_random_state(random_state)

    n_times = 300
    data = np.zeros((n_samples, n_times))

    print('Computing morlet')
    # morl = signal.morlet(support).real
    # tri = np.hstack((np.linspace(0, 1, support), np.zeros(support)))
    tri = np.hstack((np.linspace(0, 1, support),
                     np.linspace(0, 1, support)[::-1]))
    tri -= np.mean(tri)

    # normalize data
    tri = tri / np.linalg.norm(tri)
    print('[Done]')

    square_u = np.ones((support))
    # square_d = -2 * np.ones((support / 2))
    square_d = -np.ones((support))
    square = np.hstack((square_u, square_d))

    # normalize data
    square = np.hstack((square[::2], square[::2]))
    square = square / np.linalg.norm(square)

    low = 0
    high = n_times // 2 - tri.shape[0]
    add_atom(data, atom=tri, low=low, high=high, random_state=random_state)
    low = n_times // 2
    high = n_times - square.shape[0]
    add_atom(data, atom=square, low=low, high=high, random_state=random_state)

    # add some random noise
    data = data + noise_level * rng.rand(*data.shape)
    print('[Done]')

    sfreq = 1

    return data, sfreq
Example #18
    def _get_random_subsets(self, info):
        """ Get random channels"""
        # have to set the seed here
        rng = check_random_state(self.random_state)
        picked_info = mne.io.pick.pick_info(info, self.picks)
        n_channels = len(picked_info['ch_names'])

        # number of channels to interpolate from
        n_samples = int(np.round(self.min_channels * n_channels))

        # get picks for resamples
        picks = []
        for idx in range(self.n_resample):
            pick = rng.permutation(n_channels)[:n_samples].copy()
            picks.append(pick)

        # get channel subsets as lists
        ch_subsets = []
        for pick in picks:
            ch_subsets.append([picked_info['ch_names'][p] for p in pick])

        return ch_subsets
Example #19
from mne.utils import check_random_state  # noqa
from mne.datasets import sample  # noqa

###############################################################################
# Now, we can import the class required for rejecting and repairing bad
# epochs. :func:`autoreject.compute_thresholds` is a callable which must be
# provided to the :class:`autoreject.LocalAutoRejectCV` class for computing
# the channel-level thresholds.

from autoreject import (LocalAutoRejectCV, compute_thresholds,
                        set_matplotlib_defaults)  # noqa
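
###############################################################################
# As a hedged sketch (the ``thresh_func`` keyword is assumed from this older
# autoreject API), the two imports would later be combined roughly as::
#
#     ar = LocalAutoRejectCV(thresh_func=compute_thresholds)
#     epochs_clean = ar.fit_transform(epochs)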

###############################################################################
# Let us now read in the raw `fif` file for the MNE sample dataset.

check_random_state(42)  # note: the returned RandomState is unused here

data_path = sample.data_path()
raw_fname = data_path + '/MEG/sample/sample_audvis_filt-0-40_raw.fif'
raw = mne.io.read_raw_fif(raw_fname, preload=True)

###############################################################################
# We can then read in the events

event_fname = data_path + ('/MEG/sample/sample_audvis_filt-0-40_raw-'
                           'eve.fif')
event_id = {'Auditory/Left': 1, 'Auditory/Right': 2}
tmin, tmax = -0.2, 0.5

events = mne.read_events(event_fname)
Example #20
def select_source_in_label(src, label, random_state=None, location='random',
                           subject=None, subjects_dir=None, surf='sphere'):
    """Select source positions using a label

    Parameters
    ----------
    src : list of dict
        The source space
    label : Label
        the label (read with mne.read_label)
    random_state : None | int | np.random.RandomState
        To specify the random generator state.
    location : str
        The label location to choose. Can be 'random' (default) or 'center'
        to use :func:`mne.Label.center_of_mass` (restricting to vertices
        both in the label and in the source space). Note that for 'center'
        mode the label values are used as weights.

        .. versionadded:: 0.13

    subject : string | None
        The subject the label is defined for.
        Only used with ``location='center'``.

        .. versionadded:: 0.13

    subjects_dir : str, or None
        Path to the SUBJECTS_DIR. If None, the path is obtained by using
        the environment variable SUBJECTS_DIR.
        Only used with ``location='center'``.

        .. versionadded:: 0.13

    surf : str
        The surface to use for Euclidean distance center of mass
        finding. The default here is "sphere", which finds the center
        of mass on the spherical surface to help avoid potential issues
        with cortical folding.

        .. versionadded:: 0.13

    Returns
    -------
    lh_vertno : list
        selected source vertex numbers on the left hemisphere
    rh_vertno : list
        selected source vertex numbers on the right hemisphere
    """
    lh_vertno = list()
    rh_vertno = list()
    if not isinstance(location, string_types) or \
            location not in ('random', 'center'):
        raise ValueError('location must be "random" or "center", got %s'
                         % (location,))

    rng = check_random_state(random_state)
    if label.hemi == 'lh':
        vertno = lh_vertno
        hemi_idx = 0
    else:
        vertno = rh_vertno
        hemi_idx = 1
    src_sel = np.intersect1d(src[hemi_idx]['vertno'], label.vertices)
    if location == 'random':
        idx = src_sel[rng.randint(0, len(src_sel), 1)[0]]
    else:  # 'center'
        idx = label.center_of_mass(
            subject, restrict_vertices=src_sel, subjects_dir=subjects_dir,
            surf=surf)
    vertno.append(idx)
    return lh_vertno, rh_vertno
Example #21
def simulate_training_data(src, n_dipoles, times, data_fun=lambda t: 1e-7 * np.sin(20 * np.pi * t), labels=None,
                           random_state=None, location='random', subject=None, subjects_dir=None, surf='sphere'):

    """Generate sparse (n_dipoles) sources time courses from data_fun

    This function has a pre-determined training procedure that generates ''n_dipoles' vertices in the whole cortex
    or one single vertex (randomly in or in the center of) each label if ''labels i not None''. It uses ''data_fun''
    to generate waveforms for each vertex.

    Parameters
    ----------
    src : instance of SourceSpaces
        The source space.
    n_dipoles : int
        Number of dipoles to simulate.
    times : array
        Time array
    data_fun : callable
        Function to generate the waveforms. The default is a 100 nAm, 10 Hz
        sinusoid as ``1e-7 * np.sin(20 * pi * t)``. The function should take
        as input the array of time samples in seconds and return an array of
        the same length containing the time courses.
    labels : None | list of Labels
        The labels. The default is None, otherwise its size must be n_dipoles.
    random_state : None | int | np.random.RandomState
        To specify the random generator state.
    location : str
        The label location to choose. Can be 'random' (default) or 'center'
        to use :func:`mne.Label.center_of_mass`. Note that for 'center'
        mode the label values are used as weights.

        .. versionadded:: 0.13

    subject : string | None
        The subject the label is defined for.
        Only used with ``location='center'``.

        .. versionadded:: 0.13

    subjects_dir : str, or None
        Path to the SUBJECTS_DIR. If None, the path is obtained by using
        the environment variable SUBJECTS_DIR.
        Only used with ``location='center'``.

        .. versionadded:: 0.13

    surf : str
        The surface to use for Euclidean distance center of mass
        finding. The default here is "sphere", which finds the center
        of mass on the spherical surface to help avoid potential issues
        with cortical folding.

        .. versionadded:: 0.13

    Returns
    -------
    stc : SourceEstimate
        The generated source time courses.
    """
    rng = check_random_state(random_state)
    src = _ensure_src(src, verbose=False)
    subject_src = src[0].get('subject_his_id')
    if subject is None:
        subject = subject_src
    elif subject_src is not None and subject != subject_src:
        raise ValueError('subject argument (%s) did not match the source '
                         'space subject_his_id (%s)' % (subject, subject_src))
    data = np.zeros((n_dipoles, len(times)))
    for i_dip in range(n_dipoles):
        data[i_dip, :] = data_fun(times)

    if labels is None:
        # can be vol or surface source space
        offsets = np.linspace(0, n_dipoles, len(src) + 1).astype(int)
        n_dipoles_ss = np.diff(offsets)
        # don't use .choice b/c not on old numpy
        vs = [s['vertno'][np.sort(rng.permutation(np.arange(s['nuse']))[:n])]
              for n, s in zip(n_dipoles_ss, src)]
        datas = data
    else:
        if n_dipoles != len(labels):
            warn('The number of labels is different from the number of '
                 'dipoles. %s dipole(s) will be generated.'
                 % min(n_dipoles, len(labels)))
        labels = labels[:n_dipoles] if n_dipoles < len(labels) else labels

        vertno = [[], []]
        lh_data = [np.empty((0, data.shape[1]))]
        rh_data = [np.empty((0, data.shape[1]))]
        for i, label in enumerate(labels):
            lh_vertno, rh_vertno = select_source_in_label(
                src, label, rng, location, subject, subjects_dir, surf)
            vertno[0] += lh_vertno
            vertno[1] += rh_vertno
            if len(lh_vertno) != 0:
                lh_data.append(data[i][np.newaxis])
            elif len(rh_vertno) != 0:
                rh_data.append(data[i][np.newaxis])
            else:
                raise ValueError('No vertno found.')
        vs = [np.array(v) for v in vertno]
        datas = [np.concatenate(d) for d in [lh_data, rh_data]]
        # need to sort each hemi by vertex number
        for ii in range(2):
            order = np.argsort(vs[ii])
            vs[ii] = vs[ii][order]
            if len(order) > 0:  # fix for old numpy
                datas[ii] = datas[ii][order]
        datas = np.concatenate(datas)

    tmin, tstep = times[0], np.diff(times[:2])[0]
    assert datas.shape == data.shape
    cls = SourceEstimate if len(vs) == 2 else VolSourceEstimate
    stc = cls(datas, vertices=vs, tmin=tmin, tstep=tstep, subject=subject)
    return stc
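A hedged sketch of a custom data_fun matching the contract in the docstring
(it takes the array of time samples in seconds and returns a waveform of the
same length):

import numpy as np

def data_fun(times):
    # a 50 nAm, 5 Hz sinusoid instead of the 100 nAm, 10 Hz default
    return 5e-8 * np.sin(2 * np.pi * 5. * times)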
Example #22
def simulate_movement(raw,
                      pos,
                      stc,
                      trans,
                      src,
                      bem,
                      cov='simple',
                      mindist=1.0,
                      interp='linear',
                      random_state=None,
                      n_jobs=1,
                      verbose=None):
    """Simulate raw data with head movements

    Parameters
    ----------
    raw : instance of Raw
        The raw instance to use. The measurement info, including the
        head positions, will be used to simulate data.
    pos : str | dict | None
        Name of the position estimates file. Should be in the format of
        the files produced by MaxFilter. If dict, keys should be the
        time points and entries should be 4x4 ``dev_head_t`` matrices.
        If None, the original head position (from
        ``raw.info['dev_head_t']``) will be used.
    stc : instance of SourceEstimate
        The source estimate to use to simulate data. Must have the same
        sample rate as the raw data.
    trans : dict | str
        Either a transformation filename (usually made using mne_analyze)
        or an info dict (usually opened using read_trans()).
        If string, an ending of `.fif` or `.fif.gz` will be assumed to
        be in FIF format, any other ending will be assumed to be a text
        file with a 4x4 transformation matrix (like the `--trans` MNE-C
        option).
    src : str | instance of SourceSpaces
        If string, should be a source space filename. Can also be an
        instance of loaded or generated SourceSpaces.
    bem : str
        Filename of the BEM (e.g., "sample-5120-5120-5120-bem-sol.fif").
    cov : instance of Covariance | 'simple' | None
        The sensor covariance matrix used to generate noise. If None,
        no noise will be added. If 'simple', a basic (diagonal) ad-hoc
        noise covariance will be used.
    mindist : float
        Minimum distance between sources and the inner skull boundary
        to use during forward calculation.
    interp : str
        Either 'linear' or 'zero', the type of forward-solution
        interpolation to use between provided time points.
    random_state : None | int | np.random.RandomState
        To specify the random generator state.
    n_jobs : int
        Number of jobs to use.
    verbose : bool, str, int, or None
        If not None, override default verbose level (see mne.verbose).

    Returns
    -------
    raw : instance of Raw
        The simulated raw file.

    Notes
    -----
    Events coded with the number of the forward solution used will be placed
    in the raw files in the trigger channel STI101 at the t=0 times of the
    SourceEstimates.

    The resulting SNR will be determined by the structure of the noise
    covariance, and the amplitudes of the SourceEstimate. Note that this
    will vary as a function of position.
    """
    if isinstance(raw, string_types):
        with warnings.catch_warnings(record=True):
            raw = Raw(raw, allow_maxshield=True, preload=True, verbose=False)
    else:
        raw = raw.copy()

    if not isinstance(stc, _BaseSourceEstimate):
        raise TypeError('stc must be a SourceEstimate')
    if not np.allclose(raw.info['sfreq'], 1. / stc.tstep):
        raise ValueError('stc and raw must have same sample rate')
    rng = check_random_state(random_state)
    if interp not in ('linear', 'zero'):
        raise ValueError('interp must be "linear" or "zero"')

    if pos is None:  # use pos from file
        dev_head_ts = [raw.info['dev_head_t']] * 2
        offsets = np.array([0, raw.n_times])
        interp = 'zero'
    else:
        if isinstance(pos, string_types):
            pos = get_chpi_positions(pos, verbose=False)
        if isinstance(pos, tuple):  # can be an already-loaded pos file
            transs, rots, ts = pos
            ts -= raw.first_samp / raw.info['sfreq']  # MF files need reref
            dev_head_ts = [
                np.r_[np.c_[r, t[:, np.newaxis]], [[0, 0, 0, 1]]]
                for r, t in zip(rots, transs)
            ]
            del transs, rots
        elif isinstance(pos, dict):
            ts = np.array(list(pos.keys()), float)
            ts.sort()
            dev_head_ts = [pos[float(tt)] for tt in ts]
        else:
            raise TypeError('unknown pos type %s' % type(pos))
        if not (ts >= 0).all():  # pathological if not
            raise RuntimeError('Cannot have t < 0 in transform file')
        tend = raw.times[-1]
        assert not (ts < 0).any()
        assert not (ts > tend).any()
        if ts[0] > 0:
            ts = np.r_[[0.], ts]
            dev_head_ts.insert(0, raw.info['dev_head_t']['trans'])
        dev_head_ts = [{
            'trans': d,
            'to': raw.info['dev_head_t']['to'],
            'from': raw.info['dev_head_t']['from']
        } for d in dev_head_ts]
        if ts[-1] < tend:
            dev_head_ts.append(dev_head_ts[-1])
            ts = np.r_[ts, [tend]]
        offsets = raw.time_as_index(ts)
        offsets[-1] = raw.n_times  # fix for roundoff error
        assert offsets[-2] != offsets[-1]
        del ts
    if isinstance(cov, string_types):
        assert cov == 'simple'
        cov = make_ad_hoc_cov(raw.info, verbose=False)
    assert np.array_equal(offsets, np.unique(offsets))
    assert len(offsets) == len(dev_head_ts)
    approx_events = int(
        (raw.n_times / raw.info['sfreq']) / (stc.times[-1] - stc.times[0]))
    logger.info('Provided parameters will provide approximately %s event%s' %
                (approx_events, '' if approx_events == 1 else 's'))

    # get HPI freqs and reorder
    hpi_freqs = np.array(
        [x['custom_ref'][0] for x in raw.info['hpi_meas'][0]['hpi_coils']])
    n_freqs = len(hpi_freqs)
    order = [x['number'] - 1 for x in raw.info['hpi_meas'][0]['hpi_coils']]
    assert np.array_equal(np.unique(order), np.arange(n_freqs))
    hpi_freqs = hpi_freqs[order]
    hpi_order = raw.info['hpi_results'][0]['order'] - 1
    assert np.array_equal(np.unique(hpi_order), np.arange(n_freqs))
    hpi_freqs = hpi_freqs[hpi_order]

    # extract necessary info
    picks = pick_types(raw.info, meg=True, eeg=True)  # for simulation
    meg_picks = pick_types(raw.info, meg=True, eeg=False)  # for CHPI
    fwd_info = pick_info(raw.info, picks)
    fwd_info['projs'] = []
    logger.info('Setting up raw data simulation using %s head position%s' %
                (len(dev_head_ts), 's' if len(dev_head_ts) != 1 else ''))
    raw.preload_data(verbose=False)

    if isinstance(stc, VolSourceEstimate):
        verts = [stc.vertices]
    else:
        verts = stc.vertices
    src = _restrict_source_space_to(src, verts)

    # figure out our cHPI, ECG, and EOG dipoles
    dig = raw.info['dig']
    assert all([
        d['coord_frame'] == FIFF.FIFFV_COORD_HEAD for d in dig
        if d['kind'] == FIFF.FIFFV_POINT_HPI
    ])
    chpi_rrs = [d['r'] for d in dig if d['kind'] == FIFF.FIFFV_POINT_HPI]
    R, r0 = fit_sphere_to_headshape(raw.info, verbose=False)[:2]
    R /= 1000.
    r0 /= 1000.
    ecg_rr = np.array([[-R, 0, -3 * R]])
    eog_rr = [
        d['r'] for d in raw.info['dig']
        if d['ident'] == FIFF.FIFFV_POINT_NASION
    ][0]
    eog_rr = eog_rr - r0
    eog_rr = (eog_rr / np.sqrt(np.sum(eog_rr * eog_rr)) * 0.98 *
              R)[np.newaxis, :]
    eog_rr += r0
    eog_bem = make_sphere_model(r0,
                                head_radius=R,
                                relative_radii=(0.99, 1.),
                                sigmas=(0.33, 0.33),
                                verbose=False)
    # let's oscillate between resting (17 bpm) and reading (4.5 bpm) rate
    # http://www.ncbi.nlm.nih.gov/pubmed/9399231
    blink_rate = np.cos(2 * np.pi * 1. / 60. * raw.times)
    blink_rate *= 12.5 / 60.
    blink_rate += 4.5 / 60.
    blink_data = rng.rand(raw.n_times) < blink_rate / raw.info['sfreq']
    blink_data = blink_data * (rng.rand(raw.n_times) + 0.5)  # vary amplitudes
    blink_kernel = np.hanning(int(0.25 * raw.info['sfreq']))
    eog_data = np.convolve(blink_data, blink_kernel, 'same')[np.newaxis, :]
    eog_data += rng.randn(eog_data.shape[1]) * 0.05
    eog_data *= 100e-6
    del blink_data

    max_beats = int(np.ceil(raw.times[-1] * 70. / 60.))
    cardiac_idx = np.cumsum(
        rng.uniform(60. / 70., 60. / 50., max_beats) *
        raw.info['sfreq']).astype(int)
    cardiac_idx = cardiac_idx[cardiac_idx < raw.n_times]
    cardiac_data = np.zeros(raw.n_times)
    cardiac_data[cardiac_idx] = 1
    cardiac_kernel = np.concatenate([
        2 * np.hanning(int(0.04 * raw.info['sfreq'])),
        -0.3 * np.hanning(int(0.05 * raw.info['sfreq'])),
        0.2 * np.hanning(int(0.26 * raw.info['sfreq']))
    ],
                                    axis=-1)
    ecg_data = np.convolve(cardiac_data, cardiac_kernel, 'same')[np.newaxis, :]
    ecg_data += rng.randn(ecg_data.shape[1]) * 0.05
    ecg_data *= 3e-4
    del cardiac_data

    # Add to data file, then rescale for simulation
    for data, scale, exg_ch in zip([eog_data, ecg_data], [1e-3, 5e-4],
                                   ['EOG062', 'ECG063']):
        ch = pick_channels(raw.ch_names, [exg_ch])
        if len(ch) == 1:
            raw._data[ch[0], :] = data
        data *= scale

    evoked = EvokedArray(np.zeros((len(picks), len(stc.times))),
                         fwd_info,
                         stc.tmin,
                         verbose=False)
    stc_event_idx = np.argmin(np.abs(stc.times))
    event_ch = pick_channels(raw.info['ch_names'], ['STI101'])[0]
    used = np.zeros(raw.n_times, bool)
    stc_indices = np.arange(raw.n_times) % len(stc.times)
    raw._data[event_ch, ].fill(0)
    hpi_mag = 25e-9
    last_fwd = last_fwd_chpi = last_fwd_eog = last_fwd_ecg = src_sel = None
    for fi, (fwd, fwd_eog, fwd_ecg, fwd_chpi) in \
        enumerate(_make_forward_solutions(
            fwd_info, trans, src, bem, eog_bem, dev_head_ts, mindist,
            chpi_rrs, eog_rr, ecg_rr, n_jobs)):
        # must be fixed orientation
        fwd = convert_forward_solution(fwd,
                                       surf_ori=True,
                                       force_fixed=True,
                                       verbose=False)
        # just use one arbitrary direction
        fwd_eog = fwd_eog['sol']['data'][:, ::3]
        fwd_ecg = fwd_ecg['sol']['data'][:, ::3]
        fwd_chpi = fwd_chpi[:, ::3]

        if src_sel is None:
            src_sel = _stc_src_sel(fwd['src'], stc)
            if isinstance(stc, VolSourceEstimate):
                verts = [stc.vertices]
            else:
                verts = stc.vertices
            diff_ = sum([len(v) for v in verts]) - len(src_sel)
            if diff_ != 0:
                warnings.warn(
                    '%s STC vertices omitted due to fwd calculation' %
                    (diff_, ))
        if last_fwd is None:
            last_fwd, last_fwd_eog, last_fwd_ecg, last_fwd_chpi = \
                fwd, fwd_eog, fwd_ecg, fwd_chpi
            continue
        n_time = offsets[fi] - offsets[fi - 1]

        time_slice = slice(offsets[fi - 1], offsets[fi])
        assert not used[time_slice].any()
        stc_idxs = stc_indices[time_slice]
        event_idxs = np.where(stc_idxs == stc_event_idx)[0] + offsets[fi - 1]
        used[time_slice] = True
        logger.info('  Simulating data for %0.3f-%0.3f sec with %s event%s' %
                    (tuple(offsets[fi - 1:fi + 1] / raw.info['sfreq']) +
                     (len(event_idxs), '' if len(event_idxs) == 1 else 's')))

        # simulate brain data
        stc_data = stc.data[:, stc_idxs][src_sel]
        data = _interp(last_fwd['sol']['data'], fwd['sol']['data'], stc_data,
                       interp)
        simulated = EvokedArray(data, evoked.info, 0)
        if cov is not None:
            noise = generate_noise_evoked(simulated, cov, [1, -1, 0.2], rng)
            simulated.data += noise.data
        assert simulated.data.shape[0] == len(picks)
        assert simulated.data.shape[1] == len(stc_idxs)
        raw._data[picks, time_slice] = simulated.data

        # add ECG, EOG, and CHPI traces
        raw._data[picks, time_slice] += \
            _interp(last_fwd_eog, fwd_eog, eog_data[:, time_slice], interp)
        raw._data[meg_picks, time_slice] += \
            _interp(last_fwd_ecg, fwd_ecg, ecg_data[:, time_slice], interp)
        this_t = np.arange(offsets[fi - 1], offsets[fi]) / raw.info['sfreq']
        sinusoids = np.zeros((n_freqs, n_time))
        for hi, freq in enumerate(hpi_freqs):  # don't shadow the outer fi
            sinusoids[hi] = 2 * np.pi * freq * this_t
            sinusoids[hi] = hpi_mag * np.sin(sinusoids[hi])
        raw._data[meg_picks, time_slice] += \
            _interp(last_fwd_chpi, fwd_chpi, sinusoids, interp)

        # add events
        raw._data[event_ch, event_idxs] = fi

        # prepare for next iteration
        last_fwd, last_fwd_eog, last_fwd_ecg, last_fwd_chpi = \
            fwd, fwd_eog, fwd_ecg, fwd_chpi
    assert used.all()
    logger.info('Done')
    return raw
Example #23
def texture_ERB(n_freqs=20, n_coh=None, rho=1., seq=('inc', 'nb', 'inc', 'nb'),
                fs=24414.0625, dur=1., SAM_freq=7., random_state=None,
                freq_lims=(200, 8000), verbose=True):
    """Create ERB texture stimulus

    Parameters
    ----------
    n_freqs : int
        Number of tones in mixture (default 20).
    n_coh : int | None
        Number of tones to be temporally coherent. Default (None) is
        ``int(np.round(n_freqs * 0.8))``.
    rho : float
        Correlation between the envelopes of grouped tones (default is 1.0).
    seq : list
        Sequence of incoherent ('inc'), coherent noise envelope ('nb'), and
        SAM ('sam') mixtures. Default is ``('inc', 'nb', 'inc', 'nb')``.
    fs : float
        Sampling rate in Hz.
    dur : float
        Duration (in seconds) of each token in seq (default is 1.0).
    SAM_freq : float
        The SAM frequency to use.
    random_state : None | int | np.random.RandomState
        The random generator state used for band selection and noise
        envelope generation.
    freq_lims : tuple
        The lower and upper frequency limits (default is (200, 8000)).
    verbose : bool
        If True, print the resulting ERB spacing.

    Returns
    -------
    x : ndarray, shape (n_samples,)
        The stimulus, where ``n_samples = len(seq) * (fs * dur)``
        (approximately).

    Notes
    -----
    This function requires MNE.
    """
    from mne.time_frequency.multitaper import dpss_windows
    from mne.utils import check_random_state
    if not isinstance(seq, (list, tuple, np.ndarray)):
        raise TypeError('seq must be list, tuple, or ndarray, got %s'
                        % type(seq))
    known_seqs = ('inc', 'nb', 'sam')
    for si, s in enumerate(seq):
        if s not in known_seqs:
            raise ValueError('all entries in seq must be one of %s, got '
                             'seq[%s]=%s' % (known_seqs, si, s))
    fs = float(fs)
    rng = check_random_state(random_state)
    n_coh = int(np.round(n_freqs * 0.8)) if n_coh is None else n_coh
    rise = 0.002
    t = np.arange(int(round(dur * fs))) / fs

    f_min, f_max = freq_lims
    n_ERBs = _cams(f_max) - _cams(f_min)
    del f_max
    spacing_ERBs = n_ERBs / float(n_freqs - 1)
    if verbose:
        print('This stim will have successive tones separated by %2.2f ERBs'
              % spacing_ERBs)
    if spacing_ERBs < 1.0:
        warnings.warn('The spacing between tones is LESS THAN 1 ERB!')

    # Make a filter whose impulse response is purely positive (to avoid phase
    # jumps) so that the filtered envelope is purely positive. Use a DPSS
    # window to minimize sidebands. For a bandwidth of bw, to get the shortest
    # filterlength, we need to restrict time-bandwidth product to a minimum.
    # Thus we need a length*bw = 2 => length = 2/bw (second). Hence filter
    # coefficients are calculated as follows:
    b = dpss_windows(int(np.floor(2 * fs / 100.)), 1., 1)[0][0]
    b -= b[0]
    b /= b.sum()

    # Incoherent
    envrate = 14
    bw = 20
    incoh = 0.
    for k in range(n_freqs):
        f = _inv_cams(_cams(f_min) + spacing_ERBs * k)
        env = _make_narrow_noise(bw, envrate, dur, fs, rise, rng)
        env[env < 0] = 0
        env = np.convolve(b, env)[:len(t)]
        incoh += _scale_sound(window_edges(
            env * np.sin(2 * np.pi * f * t), fs, rise, window='dpss'))
    incoh /= rms(incoh)

    # Coherent (noise band)
    stims = dict(inc=0., nb=0., sam=0.)
    group = np.sort(rng.permutation(np.arange(n_freqs))[:n_coh])
    for kind in known_seqs:
        if kind == 'nb':  # noise band
            env_coh = _make_narrow_noise(bw, envrate, dur, fs, rise, rng)
        else:  # 'inc' or 'sam'
            env_coh = 0.5 + np.sin(2 * np.pi * SAM_freq * t) / 2.
            env_coh = window_edges(env_coh, fs, rise, window='dpss')
        env_coh[env_coh < 0] = 0
        env_coh = np.convolve(b, env_coh)[:len(t)]
        if kind == 'inc':
            use_group = []  # no coherent ones
        else:  # 'nb' or 'sam'
            use_group = group
        for k in range(n_freqs):
            f = _inv_cams(_cams(f_min) + spacing_ERBs * k)
            env_inc = _make_narrow_noise(bw, envrate, dur, fs, rise, rng)
            env_inc[env_inc < 0] = 0.
            env_inc = np.convolve(b, env_inc)[:len(t)]
            if k in use_group:
                env = np.sqrt(rho) * env_coh + np.sqrt(1 - rho ** 2) * env_inc
            else:
                env = env_inc
            stims[kind] += _scale_sound(window_edges(
                env * np.sin(2 * np.pi * f * t), fs, rise, window='dpss'))
        stims[kind] /= rms(stims[kind])
    stim = np.concatenate([stims[s] for s in seq])
    stim = 0.01 * stim / rms(stim)
    return stim
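A hedged call sketch; the private helpers (_cams, _make_narrow_noise, etc.)
live in the same module, so this only runs there:

x = texture_ERB(n_freqs=10, n_coh=8, seq=('inc', 'nb'), dur=0.5,
                random_state=0, verbose=False)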
Example #24
def simulate_movement(raw, pos, stc, trans, src, bem, cov='simple',
                      mindist=1.0, interp='linear', random_state=None,
                      n_jobs=1, verbose=None):
    """Simulate raw data with head movements

    Parameters
    ----------
    raw : instance of Raw
        The raw instance to use. The measurement info, including the
        head positions, will be used to simulate data.
    pos : str | dict | None
        Name of the position estimates file. Should be in the format of
        the files produced by MaxFilter. If dict, keys should be the
        time points and entries should be 4x4 ``dev_head_t`` matrices.
        If None, the original head position (from
        ``raw.info['dev_head_t']``) will be used.
    stc : instance of SourceEstimate
        The source estimate to use to simulate data. Must have the same
        sample rate as the raw data.
    trans : dict | str
        Either a transformation filename (usually made using mne_analyze)
        or an info dict (usually opened using read_trans()).
        If string, an ending of `.fif` or `.fif.gz` will be assumed to
        be in FIF format, any other ending will be assumed to be a text
        file with a 4x4 transformation matrix (like the `--trans` MNE-C
        option).
    src : str | instance of SourceSpaces
        If string, should be a source space filename. Can also be an
        instance of loaded or generated SourceSpaces.
    bem : str
        Filename of the BEM (e.g., "sample-5120-5120-5120-bem-sol.fif").
    cov : instance of Covariance | 'simple' | None
        The sensor covariance matrix used to generate noise. If None,
        no noise will be added. If 'simple', a basic (diagonal) ad-hoc
        noise covariance will be used.
    mindist : float
        Minimum distance between sources and the inner skull boundary
        to use during forward calculation.
    interp : str
        Either 'linear' or 'zero', the type of forward-solution
        interpolation to use between provided time points.
    random_state : None | int | np.random.RandomState
        To specify the random generator state.
    n_jobs : int
        Number of jobs to use.
    verbose : bool, str, int, or None
        If not None, override default verbose level (see mne.verbose).

    Returns
    -------
    raw : instance of Raw
        The simulated raw file.

    Notes
    -----
    Events coded with the number of the forward solution used will be placed
    in the raw files in the trigger channel STI101 at the t=0 times of the
    SourceEstimates.

    The resulting SNR will be determined by the structure of the noise
    covariance, and the amplitudes of the SourceEstimate. Note that this
    will vary as a function of position.
    """
    if isinstance(raw, string_types):
        with warnings.catch_warnings(record=True):
            raw = Raw(raw, allow_maxshield=True, preload=True, verbose=False)
    else:
        raw = raw.copy()

    if not isinstance(stc, _BaseSourceEstimate):
        raise TypeError('stc must be a SourceEstimate')
    if not np.allclose(raw.info['sfreq'], 1. / stc.tstep):
        raise ValueError('stc and raw must have same sample rate')
    rng = check_random_state(random_state)
    if interp not in ('linear', 'zero'):
        raise ValueError('interp must be "linear" or "zero"')

    if pos is None:  # use pos from file
        dev_head_ts = [raw.info['dev_head_t']] * 2
        offsets = np.array([0, raw.n_times])
        interp = 'zero'
    else:
        if isinstance(pos, string_types):
            pos = get_chpi_positions(pos, verbose=False)
        if isinstance(pos, tuple):  # can be an already-loaded pos file
            transs, rots, ts = pos
            ts -= raw.first_samp / raw.info['sfreq']  # MF files need reref
            dev_head_ts = [np.r_[np.c_[r, t[:, np.newaxis]], [[0, 0, 0, 1]]]
                           for r, t in zip(rots, transs)]
            del transs, rots
        elif isinstance(pos, dict):
            ts = np.array(list(pos.keys()), float)
            ts.sort()
            dev_head_ts = [pos[float(tt)] for tt in ts]
        else:
            raise TypeError('unknown pos type %s' % type(pos))
        if not (ts >= 0).all():  # pathological if not
            raise RuntimeError('Cannot have t < 0 in transform file')
        tend = raw.times[-1]
        assert not (ts < 0).any()
        assert not (ts > tend).any()
        if ts[0] > 0:
            ts = np.r_[[0.], ts]
            dev_head_ts.insert(0, raw.info['dev_head_t']['trans'])
        dev_head_ts = [{'trans': d, 'to': raw.info['dev_head_t']['to'],
                        'from': raw.info['dev_head_t']['from']}
                       for d in dev_head_ts]
        if ts[-1] < tend:
            dev_head_ts.append(dev_head_ts[-1])
            ts = np.r_[ts, [tend]]
        offsets = raw.time_as_index(ts)
        offsets[-1] = raw.n_times  # fix for roundoff error
        assert offsets[-2] != offsets[-1]
        del ts
    if isinstance(cov, string_types):
        assert cov == 'simple'
        cov = make_ad_hoc_cov(raw.info, verbose=False)
    assert np.array_equal(offsets, np.unique(offsets))
    assert len(offsets) == len(dev_head_ts)
    approx_events = int((raw.n_times / raw.info['sfreq']) /
                        (stc.times[-1] - stc.times[0]))
    logger.info('Provided parameters will provide approximately %s event%s'
                % (approx_events, '' if approx_events == 1 else 's'))

    # get HPI freqs and reorder
    hpi_freqs = np.array([x['custom_ref'][0]
                          for x in raw.info['hpi_meas'][0]['hpi_coils']])
    n_freqs = len(hpi_freqs)
    order = [x['number'] - 1 for x in raw.info['hpi_meas'][0]['hpi_coils']]
    assert np.array_equal(np.unique(order), np.arange(n_freqs))
    hpi_freqs = hpi_freqs[order]
    hpi_order = raw.info['hpi_results'][0]['order'] - 1
    assert np.array_equal(np.unique(hpi_order), np.arange(n_freqs))
    hpi_freqs = hpi_freqs[hpi_order]

    # extract necessary info
    picks = pick_types(raw.info, meg=True, eeg=True)  # for simulation
    meg_picks = pick_types(raw.info, meg=True, eeg=False)  # for CHPI
    fwd_info = pick_info(raw.info, picks)
    fwd_info['projs'] = []
    logger.info('Setting up raw data simulation using %s head position%s'
                % (len(dev_head_ts), 's' if len(dev_head_ts) != 1 else ''))
    raw.preload_data(verbose=False)

    if isinstance(stc, VolSourceEstimate):
        verts = [stc.vertices]
    else:
        verts = stc.vertices
    src = _restrict_source_space_to(src, verts)

    # figure out our cHPI, ECG, and EOG dipoles
    dig = raw.info['dig']
    assert all([d['coord_frame'] == FIFF.FIFFV_COORD_HEAD
                for d in dig if d['kind'] == FIFF.FIFFV_POINT_HPI])
    chpi_rrs = [d['r'] for d in dig if d['kind'] == FIFF.FIFFV_POINT_HPI]
    R, r0 = fit_sphere_to_headshape(raw.info, verbose=False)[:2]
    R /= 1000.
    r0 /= 1000.
    ecg_rr = np.array([[-R, 0, -3 * R]])
    eog_rr = [d['r'] for d in raw.info['dig']
              if d['ident'] == FIFF.FIFFV_POINT_NASION][0]
    eog_rr = eog_rr - r0
    eog_rr = (eog_rr / np.sqrt(np.sum(eog_rr * eog_rr)) *
              0.98 * R)[np.newaxis, :]
    eog_rr += r0
    eog_bem = make_sphere_model(r0, head_radius=R, relative_radii=(0.99, 1.),
                                sigmas=(0.33, 0.33), verbose=False)
    # let's oscillate between resting (17 bpm) and reading (4.5 bpm) rate
    # http://www.ncbi.nlm.nih.gov/pubmed/9399231
    # (the (1 + cos) / 2 rescaling keeps the rate non-negative)
    blink_rate = (1 + np.cos(2 * np.pi * 1. / 60. * raw.times)) / 2.
    blink_rate *= 12.5 / 60.
    blink_rate += 4.5 / 60.
    blink_data = rng.rand(raw.n_times) < blink_rate / raw.info['sfreq']
    blink_data = blink_data * (rng.rand(raw.n_times) + 0.5)  # vary amplitudes
    blink_kernel = np.hanning(int(0.25 * raw.info['sfreq']))
    eog_data = np.convolve(blink_data, blink_kernel, 'same')[np.newaxis, :]
    eog_data += rng.randn(eog_data.shape[1]) * 0.05
    eog_data *= 100e-6
    del blink_data

    max_beats = int(np.ceil(raw.times[-1] * 70. / 60.))
    cardiac_idx = np.cumsum(rng.uniform(60. / 70., 60. / 50., max_beats) *
                            raw.info['sfreq']).astype(int)
    cardiac_idx = cardiac_idx[cardiac_idx < raw.n_times]
    cardiac_data = np.zeros(raw.n_times)
    cardiac_data[cardiac_idx] = 1
    cardiac_kernel = np.concatenate([
        2 * np.hanning(int(0.04 * raw.info['sfreq'])),
        -0.3 * np.hanning(int(0.05 * raw.info['sfreq'])),
        0.2 * np.hanning(int(0.26 * raw.info['sfreq']))], axis=-1)
    ecg_data = np.convolve(cardiac_data, cardiac_kernel, 'same')[np.newaxis, :]
    ecg_data += rng.randn(ecg_data.shape[1]) * 0.05
    ecg_data *= 3e-4
    del cardiac_data

    # Add to data file, then rescale for simulation
    for data, scale, exg_ch in zip([eog_data, ecg_data],
                                   [1e-3, 5e-4],
                                   ['EOG062', 'ECG063']):
        ch = pick_channels(raw.ch_names, [exg_ch])
        if len(ch) == 1:
            raw._data[ch[0], :] = data
        data *= scale

    evoked = EvokedArray(np.zeros((len(picks), len(stc.times))), fwd_info,
                         stc.tmin, verbose=False)
    stc_event_idx = np.argmin(np.abs(stc.times))
    event_ch = pick_channels(raw.info['ch_names'], ['STI101'])[0]
    used = np.zeros(raw.n_times, bool)
    stc_indices = np.arange(raw.n_times) % len(stc.times)
    raw._data[event_ch, :].fill(0)
    hpi_mag = 25e-9
    last_fwd = last_fwd_chpi = last_fwd_eog = last_fwd_ecg = src_sel = None
    for fi, (fwd, fwd_eog, fwd_ecg, fwd_chpi) in \
        enumerate(_make_forward_solutions(
            fwd_info, trans, src, bem, eog_bem, dev_head_ts, mindist,
            chpi_rrs, eog_rr, ecg_rr, n_jobs)):
        # must be fixed orientation
        fwd = convert_forward_solution(fwd, surf_ori=True,
                                       force_fixed=True, verbose=False)
        # just use one arbitrary direction
        fwd_eog = fwd_eog['sol']['data'][:, ::3]
        fwd_ecg = fwd_ecg['sol']['data'][:, ::3]
        fwd_chpi = fwd_chpi[:, ::3]

        if src_sel is None:
            src_sel = _stc_src_sel(fwd['src'], stc)
            if isinstance(stc, VolSourceEstimate):
                verts = [stc.vertices]
            else:
                verts = stc.vertices
            diff_ = sum([len(v) for v in verts]) - len(src_sel)
            if diff_ != 0:
                warnings.warn('%s STC vertices omitted due to fwd calculation'
                              % (diff_,))
        if last_fwd is None:
            last_fwd, last_fwd_eog, last_fwd_ecg, last_fwd_chpi = \
                fwd, fwd_eog, fwd_ecg, fwd_chpi
            continue
        n_time = offsets[fi] - offsets[fi-1]

        time_slice = slice(offsets[fi-1], offsets[fi])
        assert not used[time_slice].any()
        stc_idxs = stc_indices[time_slice]
        event_idxs = np.where(stc_idxs == stc_event_idx)[0] + offsets[fi-1]
        used[time_slice] = True
        logger.info('  Simulating data for %0.3f-%0.3f sec with %s event%s'
                    % (tuple(offsets[fi-1:fi+1] / raw.info['sfreq']) +
                       (len(event_idxs), '' if len(event_idxs) == 1 else 's')))

        # simulate brain data
        stc_data = stc.data[:, stc_idxs][src_sel]
        data = _interp(last_fwd['sol']['data'], fwd['sol']['data'],
                       stc_data, interp)
        simulated = EvokedArray(data, evoked.info, 0)
        if cov is not None:
            noise = generate_noise_evoked(simulated, cov, [1, -1, 0.2], rng)
            simulated.data += noise.data
        assert simulated.data.shape[0] == len(picks)
        assert simulated.data.shape[1] == len(stc_idxs)
        raw._data[picks, time_slice] = simulated.data

        # add ECG, EOG, and CHPI traces
        raw._data[picks, time_slice] += \
            _interp(last_fwd_eog, fwd_eog, eog_data[:, time_slice], interp)
        raw._data[meg_picks, time_slice] += \
            _interp(last_fwd_ecg, fwd_ecg, ecg_data[:, time_slice], interp)
        this_t = np.arange(offsets[fi-1], offsets[fi]) / raw.info['sfreq']
        sinusoids = np.zeros((n_freqs, n_time))
        # use a distinct loop variable so the outer ``fi`` (the head-position
        # index, written to the trigger channel below) is not shadowed
        for freq_i, freq in enumerate(hpi_freqs):
            sinusoids[freq_i] = 2 * np.pi * freq * this_t
            sinusoids[freq_i] = hpi_mag * np.sin(sinusoids[freq_i])
        raw._data[meg_picks, time_slice] += \
            _interp(last_fwd_chpi, fwd_chpi, sinusoids, interp)

        # add events
        raw._data[event_ch, event_idxs] = fi

        # prepare for next iteration
        last_fwd, last_fwd_eog, last_fwd_ecg, last_fwd_chpi = \
            fwd, fwd_eog, fwd_ecg, fwd_chpi
    assert used.all()
    logger.info('Done')
    return raw
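# A minimal usage sketch for the simulation routine above. The enclosing
# function's name is not visible in this excerpt, so ``simulate_movement``
# is assumed purely for illustration, as are the file paths (``Raw`` is
# already in scope; ``read_source_estimate`` would come from mne):
raw = Raw('sample_audvis_raw.fif', allow_maxshield=True, preload=True)
stc = read_source_estimate('simulated')  # the SourceEstimate to replay
raw_sim = simulate_movement(raw, stc=stc, trans='sample-trans.fif',
                            src='sample-src.fif', bem='sample-bem-sol.fif',
                            cov='simple', pos=None, interp='linear',
                            mindist=1.0, random_state=0, n_jobs=1)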
Exemplo n.º 25
0
def texture_ERB(n_freqs=20,
                n_coh=None,
                rho=1.,
                seq=('inc', 'nb', 'inc', 'nb'),
                fs=24414.0625,
                dur=1.,
                SAM_freq=7.,
                random_state=None,
                freq_lims=(200, 8000),
                verbose=True):
    """Create ERB texture stimulus

    Parameters
    ----------
    n_freqs : int
        Number of tones in mixture (default 20).
    n_coh : int | None
        Number of tones to be temporally coherent. Default (None) is
        ``int(np.round(n_freqs * 0.8))``.
    rho : float
        Correlation between the envelopes of grouped tones (default is 1.0).
    seq : list
        Sequence of incoherent ('inc'), coherent noise envelope ('nb'), and
        SAM ('sam') mixtures. Default is ``('inc', 'nb', 'inc', 'nb')``.
    fs : float
        Sampling rate in Hz.
    dur : float
        Duration (in seconds) of each token in seq (default is 1.0).
    SAM_freq : float
        The SAM frequency to use.
    random_state : None | int | np.random.RandomState
        The random generator state used for band selection and noise
        envelope generation.
    freq_lims : tuple
        The lower and upper frequency limits (default is (200, 8000)).
    verbose : bool
        If True, print the resulting ERB spacing.

    Returns
    -------
    x : ndarray, shape (n_samples,)
        The stimulus, where ``n_samples = len(seq) * (fs * dur)``
        (approximately).

    Notes
    -----
    This function requires MNE.
    """
    from mne.time_frequency.multitaper import dpss_windows
    from mne.utils import check_random_state
    if not isinstance(seq, (list, tuple, np.ndarray)):
        raise TypeError('seq must be list, tuple, or ndarray, got %s' %
                        type(seq))
    known_seqs = ('inc', 'nb', 'sam')
    for si, s in enumerate(seq):
        if s not in known_seqs:
            raise ValueError('all entries in seq must be one of %s, got '
                             'seq[%s]=%s' % (known_seqs, si, s))
    fs = float(fs)
    rng = check_random_state(random_state)
    n_coh = int(np.round(n_freqs * 0.8)) if n_coh is None else n_coh
    rise = 0.002
    t = np.arange(int(round(dur * fs))) / fs

    f_min, f_max = freq_lims
    n_ERBs = _cams(f_max) - _cams(f_min)
    del f_max
    spacing_ERBs = n_ERBs / float(n_freqs - 1)
    if verbose:
        print('This stim will have successive tones separated by %2.2f ERBs' %
              spacing_ERBs)
    if spacing_ERBs < 1.0:
        warnings.warn('The spacing between tones is LESS THAN 1 ERB!')

    # Make a filter whose impulse response is purely positive (to avoid phase
    # jumps) so that the filtered envelope is purely positive. Use a DPSS
    # window to minimize sidebands. For a bandwidth of bw, the shortest
    # filter length requires restricting the time-bandwidth product to its
    # minimum, i.e. length * bw = 2 => length = 2 / bw (seconds). Hence the
    # filter coefficients are calculated as follows:
    b = dpss_windows(int(np.floor(2 * fs / 100.)), 1., 1)[0][0]
    b -= b[0]
    b /= b.sum()

    # Incoherent
    envrate = 14
    bw = 20
    incoh = 0.
    for k in range(n_freqs):
        f = _inv_cams(_cams(f_min) + spacing_ERBs * k)
        env = _make_narrow_noise(bw, envrate, dur, fs, rise, rng)
        env[env < 0] = 0
        env = np.convolve(b, env)[:len(t)]
        incoh += _scale_sound(
            window_edges(env * np.sin(2 * np.pi * f * t),
                         fs,
                         rise,
                         window='dpss'))
    incoh /= rms(incoh)

    # Coherent (noise band)
    stims = dict(inc=0., nb=0., sam=0.)
    group = np.sort(rng.permutation(np.arange(n_freqs))[:n_coh])
    for kind in known_seqs:
        if kind == 'nb':  # noise band
            env_coh = _make_narrow_noise(bw, envrate, dur, fs, rise, rng)
        else:  # 'sam' (or 'inc', where env_coh goes unused)
            env_coh = 0.5 + np.sin(2 * np.pi * SAM_freq * t) / 2.
            env_coh = window_edges(env_coh, fs, rise, window='dpss')
        env_coh[env_coh < 0] = 0
        env_coh = np.convolve(b, env_coh)[:len(t)]
        if kind == 'inc':
            use_group = []  # no coherent ones
        else:  # 'nb' or 'sam'
            use_group = group
        for k in range(n_freqs):
            f = _inv_cams(_cams(f_min) + spacing_ERBs * k)
            env_inc = _make_narrow_noise(bw, envrate, dur, fs, rise, rng)
            env_inc[env_inc < 0] = 0.
            env_inc = np.convolve(b, env_inc)[:len(t)]
            if k in use_group:
                env = np.sqrt(rho) * env_coh + np.sqrt(1 - rho**2) * env_inc
            else:
                env = env_inc
            stims[kind] += _scale_sound(
                window_edges(env * np.sin(2 * np.pi * f * t),
                             fs,
                             rise,
                             window='dpss'))
        stims[kind] /= rms(stims[kind])
    stim = np.concatenate([stims[s] for s in seq])
    stim = 0.01 * stim / rms(stim)
    return stim
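# A minimal usage sketch for texture_ERB (assumes the private helpers it
# relies on, e.g. _cams, _inv_cams, _make_narrow_noise, _scale_sound,
# window_edges, and rms, are available in this module):
x = texture_ERB(n_freqs=20, n_coh=16, seq=('inc', 'nb'), dur=0.5,
                random_state=0, verbose=False)
# x now holds roughly len(seq) * fs * dur samples, RMS-scaled to 0.01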
Exemplo n.º 26
0
def learn_conv_sparse_coder(b, size_kernel, max_it, tol,
                            known_d=None,
                            beta=np.float64(1.0),
                            random_state=None):
    """
    Main function to solve the convolutional sparse coding
    Parameters for this function
    - b               : the signal dataset with size (num_signals, length)
    - size_kernel     : the size of each kernel (num_kernels, length)
    - beta            : the trade-off between sparsity and reconstruction error (as mentioned in the paper)
    - max_it          : the maximum iterations of the outer loop
    - tol             : the minimal difference in filters and codes after each iteration to continue
    - known_d         : the predefined filters (if possible)

    Important variables used in the code:
    - u_D, u_Z        : pair of proximal values for d-step and z-step
    - d_D, d_Z        : pair of Lagrange multipliers in the ADMM algo for d-step and z-step
    - v_D, v_Z        : pair of initial value pairs (Zd, d) for d-step and (Dz, z) for z-step
    """
    rng = check_random_state(random_state)

    k = size_kernel[0]
    n = b.shape[0]

    psf_radius = int(np.floor(size_kernel[1] / 2))

    size_x = [n, b.shape[1] + 2 * psf_radius]
    size_z = [n, k, size_x[1]]
    size_k_full = [k, size_x[1]]

    # M is MtM and Mtb is M.T @ b; the mask M and the signal are zero-padded
    # with psf_radius columns on each side
    M = np.pad(np.ones(b.shape, dtype=real_type),
               ((0, 0), (psf_radius, psf_radius)),
               mode='constant', constant_values=0)
    Mtb = np.pad(b, ((0, 0), (psf_radius, psf_radius)),
                 mode='constant', constant_values=0)

    """Penalty parameters, including the calculation of augmented Lagrange multipliers"""
    lambda_residual = np.float64(1.0)
    lambda_prior = np.float64(1.0)
    lambdas = [lambda_residual, lambda_prior]
    gamma_heuristic = 60 * lambda_prior * 1 / np.amax(b)
    gammas_D = [gamma_heuristic / 5000, gamma_heuristic]
    gammas_Z = [gamma_heuristic / 500, gamma_heuristic]
    rho = gammas_D[1] / gammas_D[0]

    """Initialize variables for the d-step"""
    if known_d is None:
        varsize_D = [size_x, size_k_full]
        xi_D = [np.zeros(varsize_D[0], dtype=real_type),
                np.zeros(varsize_D[1], dtype=real_type)]

        xi_D_hat = [np.zeros(varsize_D[0], dtype=imaginary_type),
                    np.zeros(varsize_D[1], dtype=imaginary_type)]

        u_D = [np.zeros(varsize_D[0], dtype=real_type),
               np.zeros(varsize_D[1], dtype=real_type)]

        # Lagrange multipliers
        d_D = [np.zeros(varsize_D[0], dtype=real_type),
               np.zeros(varsize_D[1], dtype=real_type)]

        v_D = [np.zeros(varsize_D[0], dtype=real_type),
               np.zeros(varsize_D[1], dtype=real_type)]

        d = rng.normal(size=size_kernel)
    else:
        d = known_d

    # Initialize the filters and their FFT after rolling them to center the
    # PSF
    d = np.pad(d, ((0, 0),
                   (0, size_x[1] - size_kernel[1])),
               mode='constant', constant_values=0)
    d = np.roll(d, -int(psf_radius), axis=1)
    d_hat = fft(d)

    """Initialize variables for the z-step"""
    varsize_Z = [size_x, size_z]
    xi_Z = [np.zeros(varsize_Z[0], dtype=real_type),
            np.zeros(varsize_Z[1], dtype=real_type)]

    xi_Z_hat = [np.zeros(varsize_Z[0], dtype=imaginary_type),
                np.zeros(varsize_Z[1], dtype=imaginary_type)]

    u_Z = [np.zeros(varsize_Z[0], dtype=real_type),
           np.zeros(varsize_Z[1], dtype=real_type)]

    # Lagrange multipliers
    d_Z = [np.zeros(varsize_Z[0], dtype=real_type),
           np.zeros(varsize_Z[1], dtype=real_type)]

    v_Z = [np.zeros(varsize_Z[0], dtype=real_type),
           np.zeros(varsize_Z[1], dtype=real_type)]

    # Initialize the codes and their FFT
    z = rng.normal(size=size_z)
    z_hat = fft(z)

    """Initial objective function (usually very large)"""
    obj_val = obj_func(z_hat, d_hat, b,
                       lambda_residual, lambda_prior,
                       psf_radius, size_z, size_x)

    # Back-and-forth local iteration for d and z
    if known_d is None:
        max_it_d = 10
    else:
        max_it_d = 0

    max_it_z = 10

    obj_val_filter = obj_val
    obj_val_z = obj_val

    list_obj_val = np.zeros(max_it)
    list_obj_val_filter = np.zeros(max_it)
    list_obj_val_z = np.zeros(max_it)

    """Start the main algorithm"""
    for i in range(max_it):

        """D-STEP"""
        if known_d is None:
            # Precompute what is necessary for later
            [zhat_mat, zhat_inv_mat] = precompute_D_step(z_hat, size_z, rho)
            d_old = d

            for i_d in range(max_it_d):

               # Compute v = [Zd, d]
                d_hat_dot_z_hat = np.multiply(d_hat, z_hat)
                v_D[0] = np.real(
                    ifft(np.sum(d_hat_dot_z_hat, axis=1).reshape(size_x)))
                v_D[1] = d

                # Compute proximal updates
                u = v_D[0] - d_D[0]
                theta = lambdas[0] / gammas_D[0]
                u_D[0] = np.divide((Mtb + 1.0 / theta * u),
                                   (M + 1.0 / theta * np.ones(size_x)))

                u = v_D[1] - d_D[1]
                u_D[1] = KernelConstraintProj(u, size_k_full, psf_radius)

                # Update Lagrange multipliers
                d_D[0] = d_D[0] + (u_D[0] - v_D[0])
                d_D[1] = d_D[1] + (u_D[1] - v_D[1])

                # Compute new xi=u+d and transform to fft
                xi_D[0] = u_D[0] + d_D[0]
                xi_D[1] = u_D[1] + d_D[1]
                xi_D_hat[0] = fft(xi_D[0])
                xi_D_hat[1] = fft(xi_D[1])

                # Solve convolutional inverse
                d_hat = solve_conv_term_D(
                    zhat_mat, zhat_inv_mat, xi_D_hat, rho, size_z)
                d = np.real(ifft(d_hat))

                if i_d == max_it_d - 1:
                    obj_val = obj_func(z_hat, d_hat, b,
                                       lambda_residual, lambda_prior,
                                       psf_radius, size_z, size_x)
                    print('--> Obj %3.3f' % obj_val)

                obj_val_filter = obj_val

                # Debug progress
                d_diff = d - d_old
                d_comp = d

                if i_d == max_it_d - 1:
                    print('Iter D %d, Obj %3.3f, Diff %5.5f' %
                          (i, obj_val,
                           linalg.norm(d_diff) / linalg.norm(d_comp)))

        """Z-STEP"""
        # Precompute what is necessary for later
        [dhat_flat, dhatTdhat_flat] = precompute_Z_step(d_hat, size_x)
        dhatT_flat = np.ma.conjugate(dhat_flat.T)

        z_old = z

        for i_z in range(max_it_z):

            # Compute v = [Dz,z]
            d_hat_dot_z_hat = np.multiply(d_hat, z_hat)
            v_Z[0] = np.real(
                ifft(np.sum(d_hat_dot_z_hat, axis=1).reshape(size_x)))
            v_Z[1] = z

            # Compute proximal updates
            u = v_Z[0] - d_Z[0]
            theta = lambdas[0] / gammas_Z[0]
            u_Z[0] = np.divide((Mtb + 1.0 / theta * u),
                               (M + 1.0 / theta * np.ones(size_x)))

            u = v_Z[1] - d_Z[1]
            theta = lambdas[1] / gammas_Z[1] * np.ones(u.shape)
            u_Z[1] = np.multiply(np.maximum(
                0, 1 - np.divide(theta, np.abs(u))), u)

            # Update Lagrange multipliers
            d_Z[0] = d_Z[0] + (u_Z[0] - v_Z[0])
            d_Z[1] = d_Z[1] + (u_Z[1] - v_Z[1])

            # Compute new xi=u+d and transform to fft
            xi_Z[0] = u_Z[0] + d_Z[0]
            xi_Z[1] = u_Z[1] + d_Z[1]

            xi_Z_hat[0] = fft(xi_Z[0])
            xi_Z_hat[1] = fft(xi_Z[1])

            # Solve convolutional inverse
            z_hat = solve_conv_term_Z(
                dhatT_flat, dhatTdhat_flat, xi_Z_hat, gammas_Z, size_z)
            z = np.real(ifft(z_hat))

            if (i_z == max_it_z - 1):
                obj_val = obj_func(z_hat, d_hat, b,
                                   lambda_residual, lambda_prior,
                                   psf_radius, size_z, size_x)

                print('--> Obj %3.3f' % obj_val)

        obj_val_z = obj_val

        list_obj_val[i] = obj_val
        list_obj_val_filter[i] = obj_val_filter
        list_obj_val_z[i] = obj_val_z

        # Debug progress
        z_diff = z - z_old
        z_comp = z

        print('Iter Z %d, Obj %3.3f, Diff %5.5f' %
              (i, obj_val, linalg.norm(z_diff) / linalg.norm(z_comp)))

        # Termination (skip the filter criterion when filters are fixed,
        # since d_diff/d_comp are only computed inside the d-step)
        z_converged = linalg.norm(z_diff) / linalg.norm(z_comp) < tol
        d_converged = (known_d is not None or
                       linalg.norm(d_diff) / linalg.norm(d_comp) < tol)
        if z_converged and d_converged:
            break

    """Final estimate"""
    z_res = z

    d_res = d
    d_res = np.roll(d_res, psf_radius, axis=1)
    d_res = d_res[:, 0:psf_radius * 2 + 1]

    fft_dot = np.multiply(d_hat, z_hat)
    Dz = np.real(ifft(np.sum(fft_dot, axis=1).reshape(size_x)))

    obj_val = obj_func(z_hat, d_hat, b,
                       lambda_residual, lambda_prior,
                       psf_radius, size_z, size_x)
    print('Final objective function %f' % (obj_val))

    reconstr_err = reconstruction_err(z_hat, d_hat, b, psf_radius, size_x)
    print('Final reconstruction error %f' % reconstr_err)

    return [d_res, z_res, Dz, list_obj_val, list_obj_val_filter, list_obj_val_z, reconstr_err]
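# A minimal smoke-test sketch for the ADMM solver above (assumes this
# module defines real_type/imaginary_type, fft/ifft, and the helpers
# obj_func, precompute_D_step, precompute_Z_step, KernelConstraintProj,
# solve_conv_term_D, solve_conv_term_Z, and reconstruction_err):
rng = np.random.RandomState(0)
b = rng.randn(4, 128)  # 4 signals of length 128
(d_res, z_res, Dz, obj_hist, obj_hist_d, obj_hist_z,
 rec_err) = learn_conv_sparse_coder(b, size_kernel=(3, 11), max_it=5,
                                    tol=1e-3, random_state=0)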
Exemplo n.º 27
0
def find_bad_by_ransac(
    data,
    sample_rate,
    signal_len,
    complete_chn_labs,
    chn_pos,
    exclude,
    n_samples=50,
    fraction_good=0.25,
    corr_thresh=0.75,
    fraction_bad=0.4,
    corr_window_secs=5.0,
    channel_wise=False,
    random_state=None,
):
    """Detect channels that are not predicted well by other channels.

    Here, a RANSAC approach (see [1]_, and a short discussion in [2]_) is
    adopted to predict a "clean EEG" dataset. After identifying clean EEG
    channels through the other methods, the clean EEG dataset is
    constructed by repeatedly sampling a small subset of clean EEG channels
    and interpolating the complete data. The median of all those
    repetitions forms the clean EEG dataset. In a second step, the original
    and the RANSAC-predicted data are correlated, and channels that do not
    correlate well with themselves across the two datasets are considered
    `bad_by_ransac`.

    Parameters
    ----------
    data : np.ndarray
        2-D EEG data; should be detrended.
    sample_rate : float
        Sample rate of the EEG data, in Hz.
    signal_len : int
        Total number of samples in the signal (i.e., its length).
    complete_chn_labs : array_like
        Labels of the channels in `data`, in the same order.
    chn_pos : np.ndarray
        3-D coordinates of all the channels, in the order of `data`.
    exclude : list
        Labels of the channels to ignore in the RANSAC, e.g. bad channels
        identified by other methods.
    n_samples : int
        Number of samples used for computation of ransac.
    fraction_good : float
        Fraction of channels used for robust reconstruction of the signal.
        This needs to be strictly between 0 and 1, as neither endpoint
        would make sense.
    corr_thresh : float
        The minimum correlation threshold that should be attained within a
        data window.
    fraction_bad : float
        If the fraction of data windows in which the correlation threshold
        is not surpassed exceeds this value, classify the channel as
        `bad_by_ransac`.
    corr_window_secs : float
        Size of the correlation window in seconds.
    channel_wise : bool
        If True, RANSAC is run one channel at a time; if False, it is run
        on as many channels at once as memory allows.
    random_state : int | None | np.random.RandomState
        Seed or generator used to draw the random channel samples.

    Returns
    -------
    bad_by_ransac : list
        List of channel labels marked bad by RANSAC.
    channel_correlations : np.ndarray
        Array of shape (windows, channels) holding the correlations of
        the channels to their predicted RANSAC values in each window.

    References
    ----------
    .. [1] Fischler, M.A., Bolles, R.C. (1981). Random sample consensus: A
        Paradigm for Model Fitting with Applications to Image Analysis and
        Automated Cartography. Communications of the ACM, 24, 381-395
    .. [2] Jas, M., Engemann, D.A., Bekhti, Y., Raimondo, F., Gramfort, A.
        (2017). Autoreject: Automated Artifact Rejection for MEG and EEG
        Data. NeuroImage, 159, 417-429
    """
    # First, check that the argument types are valid
    if type(n_samples) != int:
        err = "Argument 'n_samples' must be an int (got {0})"
        raise TypeError(err.format(type(n_samples).__name__))

    # Get all channel positions and the position subset of "clean channels"
    # Exclude should be the bad channels from other methods
    # That is, identify all bad channels by other means
    good_idx = mne.pick_channels(list(complete_chn_labs),
                                 include=[],
                                 exclude=exclude)
    good_chn_labs = complete_chn_labs[good_idx]
    n_chans_good = good_idx.shape[0]
    chn_pos_good = chn_pos[good_idx, :]

    # Check if we have enough remaining channels
    # after exclusion of bad channels
    n_pred_chns = int(np.ceil(fraction_good * n_chans_good))

    if n_pred_chns <= 3:
        raise IOError("Too few channels available to reliably perform"
                      " ransac. Perhaps, too many channels have failed"
                      " quality tests.")

    # Before running, make sure we have enough memory when using the
    # smallest possible chunk size
    verify_free_ram(data, n_samples, 1)

    # Generate random channel picks for each RANSAC sample
    random_ch_picks = []
    good_chans = np.arange(chn_pos_good.shape[0])
    rng = check_random_state(random_state)
    for i in range(n_samples):
        # Pick a random subset of clean channels to use for interpolation
        picks = rng.choice(good_chans, size=n_pred_chns, replace=False)
        random_ch_picks.append(picks)

    # Correlation windows setup
    correlation_frames = corr_window_secs * sample_rate
    correlation_window = np.arange(correlation_frames)
    n = correlation_window.shape[0]
    correlation_offsets = np.arange(0, (signal_len - correlation_frames),
                                    correlation_frames)
    w_correlation = correlation_offsets.shape[0]

    # Preallocate
    n_chans_complete = len(complete_chn_labs)
    channel_correlations = np.ones((w_correlation, n_chans_complete))
    # Note: data.shape[0] equals n_chans_complete here, since both come from
    # the same set of retained channels

    print("Executing RANSAC\nThis may take a while, so be patient...")

    # Calculate smallest chunk size for each possible chunk count
    chunk_sizes = []
    chunk_count = 0
    for i in range(1, n_chans_complete + 1):
        n_chunks = int(np.ceil(n_chans_complete / i))
        if n_chunks != chunk_count:
            chunk_count = n_chunks
            chunk_sizes.append(i)

    chunk_size = chunk_sizes.pop()
    mem_error = True
    job = list(range(n_chans_complete))

    if channel_wise:
        chunk_size = 1
    while mem_error:
        try:
            channel_chunks = split_list(job, chunk_size)
            total_chunks = len(channel_chunks)
            current = 1
            for chunk in channel_chunks:
                channel_correlations[:, chunk] = _ransac_correlations(
                    chunk,
                    random_ch_picks,
                    chn_pos,
                    chn_pos_good,
                    good_chn_labs,
                    complete_chn_labs,
                    data,
                    n_samples,
                    n,
                    w_correlation,
                )
                if chunk == channel_chunks[0]:
                    # Reaching this point means this chunk size fits in RAM
                    print("Finding optimal chunk size :", chunk_size)
                    print("Total # of chunks:", total_chunks)
                    print("Current chunk:", end=" ", flush=True)

                print(current, end=" ", flush=True)
                current = current + 1

            mem_error = False  # All chunks processed, hurray!
            del current
        except MemoryError:
            if len(chunk_sizes):
                chunk_size = chunk_sizes.pop()
            else:  # pragma: no cover
                raise MemoryError(
                    "Even processing one channel at a time, the data does "
                    "not fit in RAM. You could downsample the data or "
                    "reduce the number of requested samples.")

    # Thresholding
    thresholded_correlations = channel_correlations < corr_thresh
    frac_bad_corr_windows = np.mean(thresholded_correlations, axis=0)

    # find the corresponding channel names and return
    bad_ransac_channels_idx = np.argwhere(frac_bad_corr_windows > fraction_bad)
    bad_ransac_channels_name = complete_chn_labs[
        bad_ransac_channels_idx.astype(int)]
    bad_by_ransac = [i[0] for i in bad_ransac_channels_name]
    print("\nRANSAC done!")

    return bad_by_ransac, channel_correlations
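# A minimal usage sketch (``eeg_data``, ``ch_names``, and ``ch_positions``
# are placeholders the caller would supply, e.g. from an mne.io.Raw object;
# the helpers verify_free_ram, split_list, and _ransac_correlations are
# assumed to live in this module):
bads, corrs = find_bad_by_ransac(
    data=eeg_data,                        # shape (n_channels, n_samples)
    sample_rate=256.,
    signal_len=eeg_data.shape[1],
    complete_chn_labs=np.asarray(ch_names),
    chn_pos=ch_positions,                 # shape (n_channels, 3)
    exclude=['TP9'],                      # channels already known to be bad
    random_state=42,
)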
Exemplo n.º 28
0
def learn_conv_sparse_coder(b,
                            size_kernel,
                            max_it,
                            tol,
                            lambda_prior=1.0,
                            lambda_residual=1.0,
                            random_state=None,
                            ds_init=None,
                            feasible_evaluation=True,
                            stopping_pobj=None,
                            verbose=1,
                            max_it_d=10,
                            max_it_z=10):
    """
    Solve the convolutional sparse coding problem.

    Parameters
    ----------
    - b               : the signal dataset with size (num_signals, length)
    - size_kernel     : the size of each kernel (num_kernels, length)
    - max_it          : the maximum number of iterations of the outer loop
    - tol             : the minimal difference in filters and codes after an
                        iteration required to continue
    - lambda_prior, lambda_residual : sparsity and reconstruction penalties
    - random_state    : seed or RandomState for reproducible initialization
    - ds_init         : initial filters (random if None)
    - feasible_evaluation : whether to evaluate the objective on a
                        norm-constrained (feasible) version of the filters
    - stopping_pobj   : if given, stop once the objective drops below this
    - verbose         : verbosity level (0 is silent)
    - max_it_d, max_it_z : number of inner iterations of the d- and z-steps

    Important variables used in the code:
    - u_D, u_Z        : pair of proximal values for d-step and z-step
    - d_D, d_Z        : pair of Lagrange multipliers in the ADMM algorithm
                        for d-step and z-step
    - v_D, v_Z        : pair of initial value pairs (Zd, d) for d-step and
                        (Dz, z) for z-step
    """
    rng = check_random_state(random_state)

    k = size_kernel[0]
    n = b.shape[0]

    psf_radius = int(np.floor(size_kernel[1] / 2))

    size_x = [n, b.shape[1] + 2 * psf_radius]
    size_z = [n, k, size_x[1]]
    size_k_full = [k, size_x[1]]

    # M is MtM and Mtb is M.T @ b; the mask M and the signal are zero-padded
    # with psf_radius columns on each side
    M = np.pad(np.ones(b.shape, dtype=real_type),
               ((0, 0), (psf_radius, psf_radius)),
               mode='constant',
               constant_values=0)
    Mtb = np.pad(b, ((0, 0), (psf_radius, psf_radius)),
                 mode='constant',
                 constant_values=0)
    """Penalty parameters, including the calculation of augmented
       Lagrange multipliers"""
    lambdas = [lambda_residual, lambda_prior]
    gamma_heuristic = 60 * lambda_prior * 1 / np.amax(b)
    gammas_D = [gamma_heuristic / 5000, gamma_heuristic]
    gammas_Z = [gamma_heuristic / 500, gamma_heuristic]
    rho = gammas_D[1] / gammas_D[0]
    """Initialize variables for the d-step"""
    varsize_D = [size_x, size_k_full]
    xi_D = [
        np.zeros(varsize_D[0], dtype=real_type),
        np.zeros(varsize_D[1], dtype=real_type)
    ]

    xi_D_hat = [
        np.zeros(varsize_D[0], dtype=imaginary_type),
        np.zeros(varsize_D[1], dtype=imaginary_type)
    ]

    u_D = [
        np.zeros(varsize_D[0], dtype=real_type),
        np.zeros(varsize_D[1], dtype=real_type)
    ]

    # Lagrange multipliers
    d_D = [
        np.zeros(varsize_D[0], dtype=real_type),
        np.zeros(varsize_D[1], dtype=real_type)
    ]

    v_D = [
        np.zeros(varsize_D[0], dtype=real_type),
        np.zeros(varsize_D[1], dtype=real_type)
    ]

    # d = rng.normal(size=size_kernel)
    if ds_init is None:
        d = rng.randn(*size_kernel)
    else:
        d = ds_init.copy()
    d_norm = np.linalg.norm(d, axis=1)
    d /= d_norm[:, None]

    # Initialize the filters and their FFT after rolling them to center the
    # PSF
    d = np.pad(d, ((0, 0), (0, size_x[1] - size_kernel[1])),
               mode='constant',
               constant_values=0)
    d = np.roll(d, -int(psf_radius), axis=1)
    d_hat = fft(d)

    # Initialize variables for the z-step
    varsize_Z = [size_x, size_z]
    xi_Z = [
        np.zeros(varsize_Z[0], dtype=real_type),
        np.zeros(varsize_Z[1], dtype=real_type)
    ]

    xi_z_hat = [
        np.zeros(varsize_Z[0], dtype=imaginary_type),
        np.zeros(varsize_Z[1], dtype=imaginary_type)
    ]

    u_Z = [
        np.zeros(varsize_Z[0], dtype=real_type),
        np.zeros(varsize_Z[1], dtype=real_type)
    ]

    # Lagrange multipliers
    d_Z = [
        np.zeros(varsize_Z[0], dtype=real_type),
        np.zeros(varsize_Z[1], dtype=real_type)
    ]

    v_Z = [
        np.zeros(varsize_Z[0], dtype=real_type),
        np.zeros(varsize_Z[1], dtype=real_type)
    ]

    # Initialize the codes (zeros here) and their FFT
    # z = rng.normal(size=size_z)
    z = np.zeros(size_z)
    z_hat = fft(z)
    """Initial objective function (usually very large)"""
    # obj_val = obj_func(z_hat, d_hat, b,
    #                    lambda_residual, lambda_prior,
    #                    psf_radius, size_z, size_x)
    obj_val = obj_func_2(z, d, b, lambda_prior, psf_radius,
                         feasible_evaluation)

    if verbose > 0:
        print('Init, Obj %3.3f' % (obj_val, ))

    times = list()
    times.append(0.)
    list_obj_val = list()
    list_obj_val.append(obj_val)
    """Start the main algorithm"""
    for i in range(max_it):

        start = time.time()
        z, z_hat = update_z(z, z_hat, d_hat, u_Z, v_Z, d_Z, lambdas, gammas_Z,
                            Mtb, M, size_x, size_z, xi_Z, xi_z_hat, b,
                            lambda_prior, lambda_residual, psf_radius, verbose,
                            max_it_z)
        times.append(time.time() - start)

        # obj_val = obj_func(z_hat, d_hat, b,
        #                    lambda_residual, lambda_prior,
        #                    psf_radius, size_z, size_x)
        obj_val = obj_func_2(z, d, b, lambda_prior, psf_radius,
                             feasible_evaluation)

        if verbose > 0:
            print('Iter Z %d/%d, Obj %3.3f' % (i, max_it, obj_val))

        start = time.time()
        d, d_hat = update_d(z_hat, d_hat, size_z, size_x, rho, d, v_D, d_D,
                            lambdas, gammas_D, Mtb, u_D, M, size_k_full,
                            psf_radius, xi_D, xi_D_hat, verbose, max_it_d)
        times.append(time.time() - start)

        # obj_val = obj_func(z_hat, d_hat, b,
        #                    lambda_residual, lambda_prior,
        #                    psf_radius, size_z, size_x)
        obj_val = obj_func_2(z, d, b, lambda_prior, psf_radius,
                             feasible_evaluation)

        if verbose > 0:
            print('Iter D %d/%d, Obj %3.3f' % (i, max_it, obj_val))

        list_obj_val.append(obj_val)

        # Debug progress
        # z_comp = z

        # Termination
        # if (linalg.norm(z_diff) / linalg.norm(z_comp) < tol and
        #         linalg.norm(d_diff) / linalg.norm(d_comp) < tol):
        #     break
        if stopping_pobj is not None and obj_val < stopping_pobj:
            break
    """Final estimate"""
    z_res = z

    d_res = d
    d_res = np.roll(d_res, psf_radius, axis=1)
    d_res = d_res[:, 0:psf_radius * 2 + 1]

    Dz = np.real(ifft(np.einsum('ijk,jk->ik', z_hat, d_hat)))

    # obj_val = obj_func(z_hat, d_hat, b,
    #                    lambda_residual, lambda_prior,
    #                    psf_radius, size_z, size_x)
    # if verbose > 0:
    #     print('Final objective function %f' % obj_val)
    #
    # reconstr_err = reconstruction_err(z_hat, d_hat, b, psf_radius, size_x)
    # if verbose > 0:
    #     print('Final reconstruction error %f' % reconstr_err)

    return d_res, z_res, Dz, np.array(list_obj_val), times
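# A minimal smoke-test sketch for this variant (assumes obj_func_2,
# update_z, update_d, real_type/imaginary_type, and fft/ifft are defined in
# this module):
rng = np.random.RandomState(0)
b = rng.randn(4, 128)
d_res, z_res, Dz, pobj, times = learn_conv_sparse_coder(
    b, size_kernel=(3, 11), max_it=5, tol=1e-3, random_state=0, verbose=0)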
Exemplo n.º 29
0
import mne  # noqa
from mne.utils import check_random_state  # noqa
from mne.datasets import sample  # noqa

###############################################################################
# Now, we can import the class required for rejecting and repairing bad
# epochs. :func:`autoreject.compute_thresholds` is a callable which must be
# provided to the :class:`autoreject.LocalAutoRejectCV` class for computing
# the channel-level thresholds.

from autoreject import (LocalAutoRejectCV, compute_thresholds,
                        set_matplotlib_defaults)  # noqa

###############################################################################
# Let us now read in the raw `fif` file for the MNE sample dataset.

check_random_state(42)

data_path = sample.data_path()
raw_fname = data_path + '/MEG/sample/sample_audvis_filt-0-40_raw.fif'
raw = mne.io.read_raw_fif(raw_fname, preload=True)

###############################################################################
# We can then read in the events

event_fname = data_path + ('/MEG/sample/sample_audvis_filt-0-40_raw-'
                           'eve.fif')
event_id = {'Auditory/Left': 1, 'Auditory/Right': 2}
tmin, tmax = -0.2, 0.5

events = mne.read_events(event_fname)
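###############################################################################
# From here, a typical continuation (a sketch; the epoching parameters and
# the ``thresh_func`` keyword are assumptions based on the old autoreject
# API, not part of this excerpt) would be to epoch the data and let
# :class:`autoreject.LocalAutoRejectCV` repair the bad epochs:

epochs = mne.Epochs(raw, events, event_id, tmin, tmax,
                    baseline=(None, 0), preload=True)
ar = LocalAutoRejectCV(thresh_func=compute_thresholds)
epochs_clean = ar.fit_transform(epochs)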
Exemplo n.º 30
0
def find_bad_by_ransac(
    data,
    sample_rate,
    complete_chn_labs,
    chn_pos,
    exclude,
    n_samples=50,
    sample_prop=0.25,
    corr_thresh=0.75,
    frac_bad=0.4,
    corr_window_secs=5.0,
    channel_wise=False,
    max_chunk_size=None,
    random_state=None,
    matlab_strict=False,
):
    """Detect channels that are not predicted well by other channels.

    Here, a RANSAC approach (see [1]_, and a short discussion in [2]_) is
    adopted to predict a "clean EEG" dataset. After identifying clean EEG
    channels through the other methods, the clean EEG dataset is
    constructed by repeatedly sampling a small subset of clean EEG channels
    and interpolating the complete data. The median of all those
    repetitions forms the clean EEG dataset. In a second step, the original
    and the RANSAC-predicted data are correlated, and channels that do not
    correlate well with themselves across the two datasets are considered
    `bad_by_ransac`.

    Parameters
    ----------
    data : np.ndarray
        A 2-D array of detrended EEG data, with bad-by-flat and bad-by-NaN
        channels removed.
    sample_rate : float
        The sample rate (in Hz) of the EEG data.
    complete_chn_labs : array_like
        Labels for all channels in `data`, in the same order as they appear
        in `data`.
    chn_pos : np.ndarray
        3-D electrode coordinates for all channels in `data`, in the same order
        as they appear in `data`.
    exclude : list
        Labels of channels to exclude as signal predictors during RANSAC
        (i.e., channels already flagged as bad by metrics other than HF noise).
    n_samples : int, optional
        Number of random channel samples to use for RANSAC. Defaults to ``50``.
    sample_prop : float, optional
        Proportion of total channels to use for signal prediction per RANSAC
        sample. This needs to be in the range [0, 1], where 0 would mean no
        channels would be used and 1 would mean all channels would be used
        (neither of which would be useful values). Defaults to ``0.25`` (e.g.,
        16 channels per sample for a 64-channel dataset).
    corr_thresh : float, optional
        The minimum predicted vs. actual signal correlation for a channel to
        be considered good within a given RANSAC window. Defaults to ``0.75``.
    frac_bad : float, optional
        The minimum fraction of bad (i.e., below-threshold) RANSAC windows for a
        channel to be considered bad-by-RANSAC. Defaults to ``0.4``.
    corr_window_secs : float, optional
        The duration (in seconds) of each RANSAC correlation window. Defaults to
        5 seconds.
    channel_wise : bool, optional
        Whether RANSAC should predict signals for chunks of channels over the
        entire signal length ("channel-wise RANSAC", see `max_chunk_size`
        parameter). If ``False``, RANSAC will instead predict signals for all
        channels at once but over a number of smaller time windows instead of
        over the entire signal length ("window-wise RANSAC"). Channel-wise
        RANSAC generally has higher RAM demands than window-wise RANSAC
        (especially if `max_chunk_size` is ``None``), but can be faster on
        systems with lots of RAM to spare. Defaults to ``False``.
    max_chunk_size : {int, None}, optional
        The maximum number of channels to predict at once during channel-wise
        RANSAC. If ``None``, RANSAC will use the largest chunk size that will
        fit into the available RAM, which may slow down other programs on the
        host system. If using window-wise RANSAC (the default), this parameter
        has no effect. Defaults to ``None``.
    random_state : {int, None, np.random.RandomState}, optional
        The random seed with which to generate random samples of channels during
        RANSAC. If random_state is an int, it will be used as a seed for RandomState.
        If ``None``, the seed will be obtained from the operating system
        (see RandomState for details). Defaults to ``None``.
    matlab_strict : bool, optional
        Whether or not RANSAC should strictly follow MATLAB PREP's internal
        math, ignoring any improvements made in PyPREP over the original code
        (see :ref:`matlab-diffs` for more details). Defaults to ``False``.

    Returns
    -------
    bad_by_ransac : list
        List containing the labels of all channels flagged as bad by RANSAC.
    channel_correlations : np.ndarray
        Array of shape (windows, channels) containing the correlations of
        the channels with their predicted RANSAC values for each window.

    References
    ----------
    .. [1] Fischler, M.A., Bolles, R.C. (1981). Random sample consensus: A
        Paradigm for Model Fitting with Applications to Image Analysis and
        Automated Cartography. Communications of the ACM, 24, 381-395
    .. [2] Jas, M., Engemann, D.A., Bekhti, Y., Raimondo, F., Gramfort, A.
        (2017). Autoreject: Automated Artifact Rejection for MEG and EEG
        Data. NeuroImage, 159, 417-429

    """
    # First, check that the argument types are valid
    if type(n_samples) != int:
        err = "Argument 'n_samples' must be an int (got {0})"
        raise TypeError(err.format(type(n_samples).__name__))

    # Get all channel positions and the position subset of "clean channels"
    # Exclude should be the bad channels from other methods
    # That is, identify all bad channels by other means
    good_idx = mne.pick_channels(list(complete_chn_labs), include=[], exclude=exclude)
    n_chans_good = good_idx.shape[0]
    chn_pos_good = chn_pos[good_idx, :]

    # Check if we have enough remaining channels
    # after exclusion of bad channels
    n_chans = data.shape[0]
    n_pred_chns = int(np.around(sample_prop * n_chans))

    if n_pred_chns <= 3:
        sample_pct = int(sample_prop * 100)
        e = "Too few channels in the original data to reliably perform RANSAC "
        e += "(minimum {0} for a sample size of {1}%)."
        raise IOError(e.format(int(np.floor(4.0 / sample_prop)), sample_pct))
    elif n_chans_good < (n_pred_chns + 1):
        e = "Too many noisy channels in the data to reliably perform RANSAC "
        e += "(only {0} good channels remaining, need at least {1})."
        raise IOError(e.format(n_chans_good, n_pred_chns + 1))

    # Before running, make sure we have enough memory when using the
    # smallest possible chunk size
    if channel_wise:
        _verify_free_ram(data, n_samples, 1)
    else:
        window_size = int(sample_rate * corr_window_secs)
        _verify_free_ram(data[:, :window_size], n_samples, n_chans_good)

    # Generate random channel picks for each RANSAC sample
    random_ch_picks = []
    good_chans = np.arange(chn_pos_good.shape[0])
    rng = check_random_state(random_state)
    for i in range(n_samples):
        # Pick a random subset of clean channels to use for interpolation
        picks = _get_random_subset(good_chans, n_pred_chns, rng)
        random_ch_picks.append(picks)

    # Generate interpolation matrix for each RANSAC sample
    interp_mats = _make_interpolation_matrices(random_ch_picks, chn_pos_good)

    # Calculate the size (in frames) and count of correlation windows
    correlation_frames = corr_window_secs * sample_rate
    signal_frames = data.shape[1]
    correlation_offsets = np.arange(
        0, (signal_frames - correlation_frames), correlation_frames
    )
    win_size = int(correlation_frames)
    win_count = correlation_offsets.shape[0]

    # Preallocate RANSAC correlation matrix
    n_chans_complete = len(complete_chn_labs)
    channel_correlations = np.ones((win_count, n_chans_complete))
    # Note: data.shape[0] equals n_chans_complete here, since both come from
    # the same set of retained channels

    logger.info("Executing RANSAC\nThis may take a while, so be patient...")

    # If enabled, run window-wise RANSAC
    if not channel_wise:
        # Get correlations between actual vs predicted signals for each RANSAC window
        channel_correlations[:, good_idx] = _ransac_by_window(
            data[good_idx, :], interp_mats, win_size, win_count, matlab_strict
        )

    # Calculate smallest chunk size for each possible chunk count
    chunk_sizes = []
    chunk_count = 0
    for i in range(1, n_chans_good + 1):
        n_chunks = int(np.ceil(n_chans_good / i))
        if n_chunks != chunk_count:
            chunk_count = n_chunks
            if not max_chunk_size or i <= max_chunk_size:
                chunk_sizes.append(i)

    chunk_size = chunk_sizes.pop()
    mem_error = True
    job = list(range(n_chans_good))

    # If not using window-wise RANSAC, do channel-wise RANSAC
    while mem_error and channel_wise:
        try:
            channel_chunks = _split_list(job, chunk_size)
            total_chunks = len(channel_chunks)
            current = 1
            for chunk in channel_chunks:
                interp_mats_for_chunk = [mat[chunk, :] for mat in interp_mats]
                channel_correlations[:, good_idx[chunk]] = _ransac_by_channel(
                    data[good_idx, :],
                    interp_mats_for_chunk,
                    win_size,
                    win_count,
                    chunk,
                    random_ch_picks,
                    matlab_strict,
                )
                if chunk == channel_chunks[0]:
                    # Reaching this point means this chunk size fits in RAM
                    logger.info("Finding optimal chunk size : %s", chunk_size)
                    logger.info("Total # of chunks: %s", total_chunks)
                    logger.info("Current chunk:")

                logger.info(current)
                current = current + 1

            mem_error = False  # All chunks processed, hurray!
            del current
        except MemoryError:
            if len(chunk_sizes):
                chunk_size = chunk_sizes.pop()
            else:  # pragma: no cover
                raise MemoryError(
                    "Even processing one channel at a time, the data does "
                    "not fit in RAM. You could downsample the data or "
                    "reduce the number of requested samples."
                )

    # Calculate fractions of bad RANSAC windows for each channel
    thresholded_correlations = channel_correlations < corr_thresh
    frac_bad_corr_windows = np.mean(thresholded_correlations, axis=0)

    # find the corresponding channel names and return
    bad_ransac_channels_idx = np.argwhere(frac_bad_corr_windows > frac_bad)
    bad_ransac_channels_name = complete_chn_labs[bad_ransac_channels_idx.astype(int)]
    bad_by_ransac = [i[0] for i in bad_ransac_channels_name]
    logger.info("\nRANSAC done!")

    return bad_by_ransac, channel_correlations
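# A minimal usage sketch (``eeg_data``, ``ch_names``, and ``ch_positions``
# are placeholders the caller would supply; the private helpers
# _get_random_subset, _make_interpolation_matrices, _ransac_by_window,
# _ransac_by_channel, _split_list, and _verify_free_ram are assumed to live
# in this module):
bads, corrs = find_bad_by_ransac(
    data=eeg_data,                  # detrended, shape (n_channels, n_samples)
    sample_rate=256.,
    complete_chn_labs=np.asarray(ch_names),
    chn_pos=ch_positions,           # shape (n_channels, 3)
    exclude=['TP9'],
    channel_wise=True,              # predict chunks of channels at a time
    max_chunk_size=8,               # cap RAM use during channel-wise RANSAC
    random_state=42,
)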