Example No. 1
def _interpolate_bads_eeg_epochs(epochs, bad_channels_by_epoch=None):
    """Interpolate bad channels per epoch

    Parameters
    ----------
    epochs : mne.Epochs
        The epochs to interpolate. Must be preloaded.
    bad_channels_by_epoch : list of list of str
        Bad channel names specified for each epoch. For example, for an Epochs
        instance containing 3 epochs: ``[['F1'], [], ['F3', 'FZ']]``
    """
    if len(bad_channels_by_epoch) != len(epochs):
        raise ValueError("Unequal length of epochs (%i) and "
                         "bad_channels_by_epoch (%i)" %
                         (len(epochs), len(bad_channels_by_epoch)))

    interp_cache = {}
    for i, bad_channels in enumerate(bad_channels_by_epoch):
        if not bad_channels:
            continue

        # find interpolation matrix
        key = tuple(sorted(bad_channels))
        if key in interp_cache:
            goods_idx, bads_idx, interpolation = interp_cache[key]
        else:
            goods_idx, bads_idx, interpolation = interp_cache[key] \
                                = _make_interpolator(epochs, key)

        # apply interpolation
        logger.info('Interpolating %i sensors on epoch %i', bads_idx.sum(), i)
        epochs._data[i, bads_idx, :] = np.dot(interpolation,
                                              epochs._data[i, goods_idx, :])
Example No. 2
def _raw_to_epochs_array(x, sfreq, events, tmin, tmax):
    """Aux function to create epochs from a 2D array"""
    if events.ndim != 1:
        raise ValueError('events must be 1D')
    if events.dtype != int:
        raise ValueError('events must be of dtype int')

    # Check that events won't be cut off
    n_times = x.shape[-1]
    min_ix = 0 - sfreq * tmin
    max_ix = n_times - sfreq * tmax
    msk_keep = np.logical_and(events > min_ix, events < max_ix)

    if not all(msk_keep):
        logger.info('Some event windows extend beyond data limits,'
                    ' and will be cut off...')
        events = events[msk_keep]

    # Pull events from the raw data
    epochs = []
    for ix in events:
        ix_min, ix_max = [ix + int(i_tlim * sfreq)
                          for i_tlim in [tmin, tmax]]
        epochs.append(x[np.newaxis, :, ix_min:ix_max])
    epochs = np.concatenate(epochs, axis=0)
    times = np.arange(epochs.shape[-1]) / float(sfreq) + tmin
    return epochs, times, msk_keep
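A minimal usage sketch, assuming the helper above and a module-level logger are defined in the current namespace; the data and event samples are synthetic.

import numpy as np

sfreq = 100.0                        # sampling rate in Hz
x = np.random.randn(3, 1000)         # 3 channels, 10 s of fake data
events = np.array([150, 400, 990])   # onsets in samples; the last one is too close
                                     # to the end of the data and should be dropped
epochs, times, msk_keep = _raw_to_epochs_array(x, sfreq, events,
                                               tmin=-0.2, tmax=0.5)
print(epochs.shape)   # (2, 3, 70): 2 kept events, 3 channels, 0.7 s of samples
print(msk_keep)       # [ True  True False]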
Example No. 3
def _save_annotations(*, annotations, bids_path, verbose):
    # Attach the new Annotations to our raw data so we can easily convert them
    # to events, which will be stored in the *_events.tsv sidecar.
    extra_params = dict()
    if bids_path.extension == '.fif':
        extra_params['allow_maxshield'] = True

    raw = read_raw_bids(bids_path=bids_path,
                        extra_params=extra_params,
                        verbose='warning')
    raw.set_annotations(annotations)
    events, durs, descrs = _read_events(events_data=None,
                                        event_id=None,
                                        raw=raw,
                                        verbose=False)

    # Write sidecar – or remove it if no events are left.
    events_tsv_fname = (bids_path.copy().update(suffix='events',
                                                extension='.tsv').fpath)

    if len(events) > 0:
        _events_tsv(events=events,
                    durations=durs,
                    raw=raw,
                    fname=events_tsv_fname,
                    trial_type=descrs,
                    overwrite=True,
                    verbose=verbose)
    elif events_tsv_fname.exists():
        logger.info(f'No events remaining after interactive inspection, '
                    f'removing {events_tsv_fname.name}')
        events_tsv_fname.unlink()
Example No. 4
def _handle_coordsystem_reading(coordsystem_fpath, datatype):
    """Read associated coordsystem.json.

    Handle reading the coordinate frame and coordinate unit
    of each electrode.
    """
    # open coordinate system sidecar json
    with open(coordsystem_fpath, 'r', encoding='utf-8-sig') as fin:
        coordsystem_json = json.load(fin)

    if datatype == 'meg':
        coord_frame = coordsystem_json['MEGCoordinateSystem']
        coord_unit = coordsystem_json['MEGCoordinateUnits']
        coord_frame_desc = coordsystem_json.get('MEGCoordinateDescription',
                                                None)
    elif datatype == 'eeg':
        coord_frame = coordsystem_json['EEGCoordinateSystem']
        coord_unit = coordsystem_json['EEGCoordinateUnits']
        coord_frame_desc = coordsystem_json.get('EEGCoordinateDescription',
                                                None)
    elif datatype == 'ieeg':
        coord_frame = coordsystem_json['iEEGCoordinateSystem']
        coord_unit = coordsystem_json['iEEGCoordinateUnits']
        coord_frame_desc = coordsystem_json.get('iEEGCoordinateDescription',
                                                None)

    logger.info(f"Reading in coordinate system frame {coord_frame}: "
                f"{coord_frame_desc}.")

    return coord_frame, coord_unit
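A self-contained usage sketch, assuming the function above and a module-level logger are available: a minimal EEG coordsystem.json is written to a temporary file and read back.

import json
import tempfile

sidecar = {
    'EEGCoordinateSystem': 'CapTrak',
    'EEGCoordinateUnits': 'm',
    'EEGCoordinateDescription': 'RAS orientation with nasion/LPA/RPA fiducials',
}
with tempfile.NamedTemporaryFile('w', suffix='_coordsystem.json',
                                 delete=False) as f:
    json.dump(sidecar, f)

coord_frame, coord_unit = _handle_coordsystem_reading(f.name, datatype='eeg')
print(coord_frame, coord_unit)  # CapTrak m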
Example No. 5
def _create_folds(X, y, n_folds=None):
    """Split the observations in X into stratified folds."""
    if y is None:
        # No folding
        return X[np.newaxis, ...]

    y = np.asarray(y)
    if len(y) != len(X):
        raise ValueError(f'The length of y ({len(y)}) does not match the '
                         f'number of items ({len(X)}).')

    y_one_hot = _convert_to_one_hot(y)
    n_items = y_one_hot.shape[1]

    if n_folds is None:
        # Set n_folds to maximum value
        n_folds = len(X) // n_items
        logger.info(f'Automatic determination of folds: {n_folds}' +
                    (' (no cross-validation)' if n_folds == 1 else ''))

    if n_folds == 1:
        # Making one fold is easy
        folds = [_compute_item_means(X, y_one_hot)]
    else:
        folds = []
        for _, fold in StratifiedKFold(n_folds).split(X, y):
            folds.append(_compute_item_means(X, y_one_hot, fold))
    return np.array(folds)
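A hedged usage sketch: _create_folds relies on the private helpers _convert_to_one_hot and _compute_item_means, on sklearn's StratifiedKFold and on a logger, all assumed to be importable from the same module. X holds repeated observations of three items whose labels are given in y.

import numpy as np

X = np.random.randn(12, 64)          # e.g. 12 trials x 64 features
y = np.repeat(['a', 'b', 'c'], 4)    # 4 repetitions of each of 3 items
folds = _create_folds(X, y, n_folds=2)
# Under these assumptions each fold holds one mean pattern per item,
# so folds.shape is expected to be (2, 3, 64).
print(folds.shape)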
Example No. 6
def chop_raw_data(raw, start_time=60.0, stop_time=360.0, save=True):
    '''
    This function extracts a specified duration of raw data
    and writes it to a FIF file.
    Five minutes of data are extracted by default.

    Parameters
    ----------

    raw: Raw object or raw file name as a string.
    start_time: Time to extract data from in seconds. Default is 60.0 seconds.
    stop_time: Time up to which data is to be extracted. Default is 360.0 seconds.
    save: bool, If True the raw file is written to disk.

    '''
    if isinstance(raw, str):
        print('Raw file name provided, loading raw object...')
        raw = mne.io.Raw(raw, preload=True)
    # Check if data is longer than required chop duration.
    if (raw.n_times / (raw.info['sfreq'])) < (stop_time + start_time):
        logger.info("The data is not long enough for file %s.") % (raw.info['filename'])
        return
    # Obtain indexes for start and stop times.
    assert start_time < stop_time, "Start time is greater than stop time."
    start_idx = raw.time_as_index(start_time)
    stop_idx = raw.time_as_index(stop_time)
    data, times = raw[:, start_idx:stop_idx]
    raw._data, raw._times = data, times
    dur = int((stop_time - start_time) / 60)
    if save:
        #raw.save(raw.info['filename'].split('/')[-1].split('.')[0] + '_' + str(dur) + 'm-raw.fif')
        raw.save(raw.info['filename'].split('-raw.fif')[0] + ',' + str(dur) + 'm-raw.fif')
    raw.close()
    return
Example No. 7
def chop_raw_data(raw, start_time=60.0, stop_time=360.0):
    '''
    This function extracts a specified duration of raw data
    and writes it to a FIF file.
    Five minutes of data are extracted by default.

    Parameters
    ----------

    raw: Raw object.
    start_time: Time to extract data from in seconds. Default is 60.0 seconds.
    stop_time: Time up to which data is to be extracted. Default is 360.0 seconds.

    '''
    # Check if data is longer than required chop duration.
    if (raw.n_times / (raw.info['sfreq'])) < (stop_time + 60.0):
        logger.info("The data is not long enough.")
        return
    # Obtain indexes for start and stop times.
    assert start_time < stop_time, "Start time is greater than stop time."
    start_idx = raw.time_as_index(start_time)
    stop_idx = raw.time_as_index(stop_time)
    data, times = raw[:, start_idx:stop_idx]
    raw._data, raw._times = data, times
    dur = int((stop_time - start_time) / 60)
    raw.save(raw.info['filename'].split('/')[-1].split('.')[0] + '_' +
             str(dur) + 'm.fif')
    # For the moment, simply warn.
    logger.warning('The file name is not saved in standard form.')
    return
Example No. 8
    def save(self, fname, overwrite=False):
        if not isinstance(fname, Path):
            fname = Path(fname)
        self._save_info(fname, overwrite=overwrite)
        save_vars = self._get_save_vars(
            exclude=['data_', 'parent', 'tmin', 'tmax'])

        parent_name = self.parent._get_title()

        save_vars['parent_name_'] = parent_name

        has_parent = False

        with h5py.File(fname) as h5fid:
            if parent_name in h5fid:
                has_parent = True
                logger.info('Parent already present in HDF5 file, '
                            'will not be overwritten')

        if not has_parent:
            logger.info('Writing numerator to HDF5 file')
            self.parent.save(fname, overwrite=overwrite)

        write_hdf5(
            fname, save_vars, overwrite=overwrite,
            title=self._get_title(), slash='replace')
Example No. 9
def chop_raw_data(raw, start_time=60.0, stop_time=360.0, save=True, return_chop=False):
    '''
    This function extracts a specified duration of raw data
    and writes it to a FIF file.
    Five minutes of data are extracted by default.

    Parameters
    ----------

    raw: Raw object or raw file name as a string.
    start_time: Time to extract data from in seconds. Default is 60.0 seconds.
    stop_time: Time up to which data is to be extracted. Default is 360.0 seconds.
    save: bool, If True the raw file is written to disk. (default: True)
    return_chop: bool, Return the chopped raw object. (default: False)

    '''
    if isinstance(raw, str):
        print('Raw file name provided, loading raw object...')
        raw = mne.io.Raw(raw, preload=True)
    # Check if data is longer than required chop duration.
    if (raw.n_times / (raw.info['sfreq'])) < (stop_time + start_time):
        logger.info("The data is not long enough for file %s.") % (raw.filenames[0])
        return
    # Obtain indexes for start and stop times.
    assert start_time < stop_time, "Start time is greater than stop time."
    crop = raw.copy().crop(tmin=start_time, tmax=stop_time)
    dur = int((stop_time - start_time) / 60)
    if save:
        crop.save(crop.filenames[0].split('-raw.fif')[0] + ',' + str(dur) + 'm-raw.fif')
    raw.close()
    if return_chop:
        return crop
    else:
        crop.close()
        return
Example No. 10
def create_epochs(raw: Raw) -> Epochs:
    """
    Create non-overlapping segments from Raw data with a fixed duration.
    Note that temporal filtering should be done before creating the epochs.
    The duration of epochs is defined in the configuration file (config.py).
    Parameters
    ----------
    raw: the continuous data to be segmented into non-overlapping epochs

    Returns
    -------
    Epochs instance
    """

    epoch_duration_in_seconds = settings["epochs"]["duration"]
    logger.info("Creating epochs from continuous data ...")
    events = make_fixed_length_events(raw,
                                      id=1,
                                      first_samp=True,
                                      duration=epoch_duration_in_seconds)

    epochs = Epochs(
        raw=raw,
        events=events,
        picks="all",
        event_id=list(np.unique(events[..., 2])),
        baseline=None,
        tmin=0.0,
        tmax=epoch_duration_in_seconds,  # - (1 / raw.info['sfreq'])
        preload=False,
    )

    return epochs
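A hedged usage sketch: it assumes the module above has imported make_fixed_length_events, Epochs, np and logger, and that its config.py defines settings = {"epochs": {"duration": 1.0}}. A short synthetic Raw object stands in for real data.

import numpy as np
import mne

info = mne.create_info(['EEG 001', 'EEG 002'], sfreq=100.0, ch_types='eeg')
raw = mne.io.RawArray(np.random.randn(2, 1000) * 1e-6, info)  # 10 s of fake data

epochs = create_epochs(raw)
print(epochs)  # fixed-length, non-overlapping 1-second epochs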
Example No. 11
def _rectify_resolution_matrix(resmat):
    """
    Ensure resolution matrix is square matrix.

    If resmat is not a square matrix, it is assumed that the inverse operator
    had free or loose orientation constraint, i.e. multiple values per source
    location. The Euclidean length for values at each location is computed to
    make resmat a square matrix.
    """
    shape = resmat.shape
    if not shape[0] == shape[1]:
        if shape[0] < shape[1]:
            raise ValueError('Number of target sources (%d) cannot be lower '
                             'than number of input sources (%d)'
                             % (shape[0], shape[1]))

        if np.mod(shape[0], shape[1]):  # if ratio not integer
            raise ValueError('Number of target sources (%d) must be a '
                             'multiple of the number of input sources (%d)'
                             % (shape[0], shape[1]))

        ns = shape[0] // shape[1]  # number of source components per vertex

        # Combine rows of resolution matrix
        resmatl = [np.sqrt((resmat[ns * i:ns * (i + 1), :]**2).sum(axis=0))
                   for i in np.arange(0, shape[1], dtype=int)]

        resmat = np.array(resmatl)

        logger.info('Rectified resolution matrix from (%d, %d) to (%d, %d).' %
                    (shape[0], shape[1], resmat.shape[0], resmat.shape[1]))

    return resmat
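A minimal usage sketch, assuming the function above and a logger are defined in the current module: a 6 x 3 "free orientation" resolution matrix (two components per source location) is collapsed into a square 3 x 3 matrix.

import numpy as np

resmat = np.random.randn(6, 3)
square = _rectify_resolution_matrix(resmat)
print(square.shape)  # (3, 3)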
Example No. 12
def chop_raw_data(raw, start_time=60.0, stop_time=360.0):
    ''' 
    This function extracts a specified duration of raw data
    and writes it to a FIF file.
    Five minutes of data are extracted by default.

    Parameters
    ----------

    raw: Raw object. 
    start_time: Time to extract data from in seconds. Default is 60.0 seconds. 
    stop_time: Time up to which data is to be extracted. Default is 360.0 seconds.

    '''
    # Check if data is longer than required chop duration.
    if (raw.n_times / (raw.info['sfreq'])) < (stop_time + 60.0):
        logger.info("The data is not long enough.")
        return
    # Obtain indexes for start and stop times.
    assert start_time < stop_time, "Start time is greater than stop time."
    start_idx = raw.time_as_index(start_time)
    stop_idx = raw.time_as_index(stop_time)
    data, times = raw[:, start_idx:stop_idx]
    raw._data, raw._times = data, times
    dur = int((stop_time - start_time) / 60)
    raw.save(raw.info['filename'].split('/')[-1].split('.')[0] + '_' +
             str(dur) + 'm.fif')
    # For the moment, simply warn.
    logger.warning('The file name is not saved in standard form.')
    return
Example No. 13
def _handle_electrodes_reading(electrodes_fname, coord_frame, coord_unit):
    """Read associated electrodes.tsv and populate raw.

    Handle xyz coordinates and coordinate frame of each channel.
    """
    logger.info('Reading electrode '
                'coords from {}.'.format(electrodes_fname))
    electrodes_dict = _from_tsv(electrodes_fname)
    ch_names_tsv = electrodes_dict['name']

    def _float_or_nan(val):
        if val == "n/a":
            return np.nan
        else:
            return float(val)

    # convert coordinates to float and create list of tuples
    electrodes_dict['x'] = [_float_or_nan(x) for x in electrodes_dict['x']]
    electrodes_dict['y'] = [_float_or_nan(x) for x in electrodes_dict['y']]
    electrodes_dict['z'] = [_float_or_nan(x) for x in electrodes_dict['z']]
    ch_names_raw = [
        x for i, x in enumerate(ch_names_tsv)
        if electrodes_dict['x'][i] != "n/a"
    ]
    ch_locs = np.c_[electrodes_dict['x'], electrodes_dict['y'],
                    electrodes_dict['z']]

    # convert coordinates to meters
    ch_locs = _scale_coord_to_meters(ch_locs, coord_unit)

    # create mne.DigMontage
    ch_pos = dict(zip(ch_names_raw, ch_locs))
    montage = mne.channels.make_dig_montage(ch_pos=ch_pos,
                                            coord_frame=coord_frame)
    return montage
Example No. 14
    def __init__(self,
                 tmin=None,
                 tmax=None,
                 fmin=None,
                 fmax=None,
                 method_params=None,
                 n_jobs='auto',
                 comment='default'):
        BaseMarkerSandbox.__init__(self, tmin=None, tmax=None, comment=comment)
        if method_params is None:
            method_params = {}
        if fmax is None:
            fmax = np.inf
        self.fmin = fmin
        self.fmax = fmax
        self.method_params = method_params
        if n_jobs == 'auto':
            try:
                import multiprocessing as mp
                import mkl
                n_jobs = int(mp.cpu_count() / mkl.get_max_threads())
                logger.info('Autodetected number of jobs {}'.format(n_jobs))
            except Exception:
                logger.info('Cannot autodetect number of jobs')
                n_jobs = 1
        self.n_jobs = n_jobs
Example No. 15
def faster_bad_components(ica, epochs, thres=3, use_metrics=None):
    """Implements the third step of the FASTER algorithm.
    
    This function attempts to automatically mark bad ICA components by
    performing outlier detection.
    Parameters
    ----------
    ica : Instance of ICA
        The ICA operator, already fitted to the supplied Epochs object.
    epochs : Instance of Epochs
        The untransformed epochs to analyze.
    thres : float
        The threshold value, in standard deviations, to apply. A component
        crossing this threshold value is marked as bad. Defaults to 3.
    use_metrics : list of str
        List of metrics to use. Can be any combination of:
            'eog_correlation', 'kurtosis', 'power_gradient', 'hurst',
            'median_gradient', 'line_noise'
        Defaults to all of them.
    Returns
    -------
    bads : list of int
        The indices of the bad components.
    See also
    --------
    ICA.find_bads_ecg
    ICA.find_bads_eog
    """
    source_data = ica.get_sources(epochs).get_data().transpose(1, 0, 2)
    source_data = source_data.reshape(source_data.shape[0], -1)

    metrics = {
        'eog_correlation':
        lambda x: x.find_bads_eog(epochs)[1],
        'kurtosis':
        lambda x: kurtosis(np.dot(x.mixing_matrix_.T,
                                  x.pca_components_[:x.n_components_]),
                           axis=1),
        'power_gradient':
        lambda x: _power_gradient(x, source_data),
        'hurst':
        lambda x: hurst(source_data),
        'median_gradient':
        lambda x: np.median(np.abs(np.diff(source_data)), axis=1),
        'line_noise':
        lambda x: _freqs_power(source_data, epochs.info['sfreq'], [50, 60]),
    }

    if use_metrics is None:
        use_metrics = metrics.keys()

    bads = []
    for m in use_metrics:
        scores = np.atleast_2d(metrics[m](ica))
        for s in scores:
            b = find_outliers(s, thres)
            logger.info('Bad by %s:\n\t%s' % (m, b))
            bads.append(b)

    return np.unique(np.concatenate(bads)).tolist()
Example No. 16
def _raw_to_epochs_array(x, sfreq, events, tmin, tmax):
    """Aux function to create epochs from a 2D array"""
    if events.ndim != 1:
        raise ValueError('events must be 1D')
    if events.dtype != int:
        raise ValueError('events must be of dtype int')

    # Check that events won't be cut off
    n_times = x.shape[-1]
    min_ix = 0 - sfreq * tmin
    max_ix = n_times - sfreq * tmax
    msk_keep = np.logical_and(events > min_ix, events < max_ix)

    if not all(msk_keep):
        logger.info('Some event windows extend beyond data limits,'
                    ' and will be cut off...')
        events = events[msk_keep]

    # Pull events from the raw data
    epochs = []
    for ix in events:
        ix_min, ix_max = [ix + int(i_tlim * sfreq) for i_tlim in [tmin, tmax]]
        epochs.append(x[np.newaxis, :, ix_min:ix_max])
    epochs = np.concatenate(epochs, axis=0)
    times = np.arange(epochs.shape[-1]) / float(sfreq) + tmin
    return epochs, times, msk_keep
Example No. 17
def cv_decode_sliding(X,
                      y,
                      clf=None,
                      cv=None,
                      class_weight=None,
                      scoring='roc_auc',
                      random_state=None,
                      picks=None,
                      n_jobs=-1):
    all_scores = []
    if not isinstance(random_state, list):
        random_state = [random_state]

    # Select channels once, before looping over random states, so the picks
    # are not applied repeatedly to an already-reduced array.
    if picks is not None:
        logger.info('Picking channels')
        X = X[:, picks, :]

    for t_random in random_state:
        logger.info('Using random state {}'.format(t_random))
        clf, cv = _check_clf(clf, cv, class_weight, t_random)

        se = mne.decoding.SlidingEstimator(clf, scoring=scoring)
        scores = mne.decoding.cross_val_multiscore(se,
                                                   X,
                                                   y,
                                                   cv=cv,
                                                   n_jobs=n_jobs)
        all_scores.append(scores)
    return np.concatenate(all_scores, axis=0)
Example No. 18
def decode_generalization(X_train,
                          y_train,
                          X_test,
                          y_test,
                          clf=None,
                          class_weight=None,
                          scoring='roc_auc',
                          random_state=None,
                          picks=None,
                          n_jobs=-1):

    clf, _ = _check_clf(clf, None, class_weight, random_state)

    ge = mne.decoding.GeneralizingEstimator(clf,
                                            scoring=scoring,
                                            n_jobs=n_jobs)

    if picks is not None:
        logger.info('Picking channels')
        X_train = X_train[:, picks, :]
        X_test = X_test[:, picks, :]

    ge.fit(X_train, y_train)
    scores = ge.score(X_test, y_test)

    return scores[None, :]
Example No. 19
def _interpolate_bads_eeg_epochs(epochs, bad_channels_by_epoch=None):
    """Interpolate bad channels per epoch

    Parameters
    ----------
    epochs : mne.Epochs
        The epochs to interpolate. Must be preloaded.
    bad_channels_by_epoch : list of list of str
        Bad channel names specified for each epoch. For example, for an Epochs
        instance containing 3 epochs: ``[['F1'], [], ['F3', 'FZ']]``
    """
    if len(bad_channels_by_epoch) != len(epochs):
        raise ValueError("Unequal length of epochs (%i) and "
                         "bad_channels_by_epoch (%i)"
                         % (len(epochs), len(bad_channels_by_epoch)))

    interp_cache = {}
    for i, bad_channels in enumerate(bad_channels_by_epoch):
        if not bad_channels:
            continue

        # find interpolation matrix
        key = tuple(sorted(bad_channels))
        if key in interp_cache:
            goods_idx, bads_idx, interpolation = interp_cache[key]
        else:
            goods_idx, bads_idx, interpolation = interp_cache[key] \
                                = _make_interpolator(epochs, key)

        # apply interpolation
        logger.info('Interpolating %i sensors on epoch %i', bads_idx.sum(), i)
        epochs._data[i, bads_idx, :] = np.dot(interpolation,
                                              epochs._data[i, goods_idx, :])
Example No. 20
def cluster_threshold(con,
                      src,
                      min_size=20,
                      max_spread=0.013,
                      method='single',
                      verbose=None):
    """Threshold connectivity using clustering.

    First, connections are grouped into "bundles". A bundle is a group of
    connections whose start and end points are close together. Then, only
    bundles with a sufficient number of connections are retained.

    Parameters
    ----------
    con : instance of Connectivity
        Connectivity to threshold.
    src : instance of SourceSpace
        The source space for which the connectivity is defined.
    min_size : int
        Minimum number of connections that a bundle must contain in order to be
        accepted.
    max_spread : float
        Maximum amount the position (in metres) of the start and end points
        of the connections may vary in order for them to be considered part of
        the same "bundle". Defaults to 0.013.
    method : str
        Linkage method for fclusterdata. Defaults to 'single'. See
        documentation for ``scipy.cluster.hierarchy.fclusterdata`` for for more
        information.
    verbose : bool | str | int | None
        If not ``None``, override default verbose level
        (see :func:`mne.verbose` and :ref:`Logging documentation <tut_logging>`
        for more).

    Returns
    -------
    thresholded_connectivity : instance of Connectivity
        Instance of connectivity with the thresholded data.
    """
    grid_points = np.vstack([s['rr'][v] for s, v in zip(src, con.vertices)])
    X = np.hstack([grid_points[inds] for inds in con.pairs])
    clust_no = fclusterdata(X, max_spread, criterion='distance', method=method)

    # Remove clusters that do not pass the threshold
    clusters, counts = np.unique(clust_no, return_counts=True)
    big_clusters = clusters[counts >= min_size]
    logger.info('Found %d bundles, of which %d are of sufficient size.' %
                (len(clusters), len(big_clusters)))

    # Restrict the connections to only those found in the big bundles
    mask = np.in1d(clust_no, big_clusters)
    data = con.data[mask]
    pairs = [p[mask] for p in con.pairs]

    return VertexConnectivity(data=data,
                              pairs=pairs,
                              vertices=con.vertices,
                              vertex_degree=con.source_degree,
                              subject=con.subject)
Example No. 21
def faster_bad_channels_in_epochs(epochs,
                                  picks=None,
                                  thres=3,
                                  use_metrics=None):
    """Implements the fourth step of the FASTER algorithm.
    
    This function attempts to automatically mark bad channels in each epoch by
    performing outlier detection.
    Parameters
    ----------
    epochs : Instance of Epochs
        The epochs to analyze.
    picks : list of int | None
        Channels to operate on. Defaults to EEG channels.
    thres : float
        The threshold value, in standard deviations, to apply. A channel
        crossing this threshold value is marked as bad for that epoch.
        Defaults to 3.
    use_metrics : list of str
        List of metrics to use. Can be any combination of:
            'amplitude', 'variance', 'deviation', 'median_gradient',
            'line_noise'
        Defaults to all of them.
    Returns
    -------
    bads : list of lists of str
        For each epoch, the names of the bad channels.
    """

    metrics = {
        'amplitude': lambda x: np.ptp(x, axis=2),
        'deviation': lambda x: _deviation(x),
        'variance': lambda x: np.var(x, axis=2),
        'median_gradient': lambda x: np.median(np.abs(np.diff(x)), axis=2),
        'line_noise':
        lambda x: _freqs_power(x, epochs.info['sfreq'], [50, 60]),
    }

    if picks is None:
        picks = mne.pick_types(epochs.info,
                               meg=False,
                               eeg=True,
                               exclude='bads')
    if use_metrics is None:
        use_metrics = metrics.keys()

    data = epochs.get_data()[:, picks, :]

    bads = [[] for i in range(len(epochs))]
    for m in use_metrics:
        s_epochs = metrics[m](data)
        for i, s in enumerate(s_epochs):
            b = [epochs.ch_names[picks[j]] for j in find_outliers(s, thres)]
            logger.info('Epoch %d, Bad by %s:\n\t%s' % (i, m, b))
            bads[i].append(b)

    for i, b in enumerate(bads):
        if len(b) > 0:
            bads[i] = np.unique(np.concatenate(b)).tolist()

    return bads
Example No. 22
def _find_bad_channels(epochs, picks, use_metrics, thresh, max_iter):
    """Implements the first step of the FASTER algorithm.

    This function attempts to automatically mark bad EEG channels by performing
    outlier detection. It operates on epoched data to make sure only relevant
    data is analyzed.

    Additional Parameters
    ---------------------
    use_metrics : list of str
        List of metrics to use. Can be any combination of:
            'variance', 'correlation', 'hurst', 'kurtosis', 'line_noise'
        Defaults to all of them.
    thresh : float
        The threshold value, in standard deviations, to apply. A channel
        crossing this threshold value is marked as bad. Defaults to 3.
    max_iter : int
        The maximum number of iterations performed during outlier detection
        (defaults to 1, as in the original FASTER paper).
    """
    from scipy.stats import kurtosis
    metrics = {
        'variance':
        lambda x: np.var(x, axis=1),
        'correlation':
        lambda x: np.mean(np.ma.masked_array(np.corrcoef(x),
                                             np.identity(len(x), dtype=bool)),
                          axis=0),
        'hurst':
        lambda x: _hurst(x),
        'kurtosis':
        lambda x: kurtosis(x, axis=1),
        'line_noise':
        lambda x: _freqs_power(x, epochs.info['sfreq'], [50, 60]),
    }

    if use_metrics is None:
        use_metrics = metrics.keys()

    # Concatenate epochs in time
    data = epochs.get_data()[:, picks]
    data = data.transpose(1, 0, 2).reshape(data.shape[1], -1)

    # Find bad channels
    bads = defaultdict(list)
    info = pick_info(epochs.info, picks, copy=True)
    for ch_type, chs in _picks_by_type(info):
        logger.info('Bad channel detection on %s channels:' % ch_type.upper())
        for metric in use_metrics:
            scores = metrics[metric](data[chs])
            bad_channels = [
                epochs.ch_names[picks[chs[i]]]
                for i in find_outliers(scores, thresh, max_iter)
            ]
            logger.info('\tBad by %s: %s' % (metric, bad_channels))
            bads[metric].append(bad_channels)

    bads = dict((k, np.concatenate(v).tolist()) for k, v in bads.items())
    return bads
Example No. 23
    def _iter_spatial(self):
        """Generate spatial searchlight patches only."""
        logger.info('Creating spatial searchlight patches')
        for series in self.sel_series:
            patch = list(self.patch_template)  # Copy the template
            series_i = np.flatnonzero(self.dist[series] < self.spatial_radius)
            patch[self.series_dim] = series_i
            yield tuple(patch)
Example No. 24
    def _iter_temporal(self):
        """Generate temporal searchlight patches only."""
        logger.info('Creating temporal searchlight patches')
        for sample in self.time_centers:
            patch = list(self.patch_template)  # Copy the template
            patch[self.samples_dim] = slice(sample - self.temporal_radius,
                                            sample + self.temporal_radius + 1)
            yield tuple(patch)
Example No. 25
def _write_tsv(fname, dictionary, overwrite=False, verbose=None):
    """Write an ordered dictionary to a .tsv file."""
    if op.exists(fname) and not overwrite:
        raise FileExistsError(f'"{fname}" already exists. '
                              'Please set overwrite to True.')
    _to_tsv(dictionary, fname)

    logger.info(f"Writing '{fname}'...")
Example No. 26
    def _iter_spatial(self):
        """Generate spatial searchlight patches only."""
        logger.info('Creating spatial searchlight patches')
        patch = list(self.patch_template)  # Copy the template
        for series in self.sel_series:
            spat_ind = _get_in_radius(self.dist, series, self.spatial_radius)
            patch[self.series_dim] = spat_ind
            yield tuple(patch)
Example No. 27
def faster_bad_channels(epochs, picks=None, thres=3, use_metrics=None):
    """Implements the first step of the FASTER algorithm.
    
    This function attempts to automatically mark bad EEG channels by performing
    outlier detection. It operates on epoched data to make sure only relevant
    data is analyzed.
    Parameters
    ----------
    epochs : Instance of Epochs
        The epochs for which bad channels need to be marked
    picks : list of int | None
        Channels to operate on. Defaults to EEG channels.
    thres : float
        The threshold value, in standard deviations, to apply. A channel
        crossing this threshold value is marked as bad. Defaults to 3.
    use_metrics : list of str
        List of metrics to use. Can be any combination of:
            'variance', 'correlation', 'hurst', 'kurtosis', 'line_noise'
        Defaults to all of them.
    Returns
    -------
    bads : list of str
        The names of the bad EEG channels.
    """
    metrics = {
        'variance':
        lambda x: np.var(x, axis=1),
        'correlation':
        lambda x: nanmean(np.ma.masked_array(np.corrcoef(x),
                                             np.identity(len(x), dtype=bool)),
                          axis=0),
        'hurst':
        lambda x: hurst(x),
        'kurtosis':
        lambda x: kurtosis(x, axis=1),
        'line_noise':
        lambda x: _freqs_power(x, epochs.info['sfreq'], [50, 60]),
    }

    if picks is None:
        picks = mne.pick_types(epochs.info, meg=False, eeg=True, exclude=[])
    if use_metrics is None:
        use_metrics = metrics.keys()

    # Concatenate epochs in time
    data = epochs.get_data()
    data = data.transpose(1, 0, 2).reshape(data.shape[1], -1)
    data = data[picks]

    # Find bad channels
    bads = []
    for m in use_metrics:
        s = metrics[m](data)
        b = [epochs.ch_names[picks[i]] for i in find_outliers(s, thres)]
        logger.info('Bad by %s:\n\t%s' % (m, b))
        bads.append(b)

    return np.unique(np.concatenate(bads)).tolist()
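A hedged usage sketch, assuming the function above and its helpers (find_outliers and a logger) are importable from the same module: a synthetic EEG Epochs object is built and one channel is made artificially noisy so that the variance metric flags it.

import numpy as np
import mne

ch_names = ['EEG %03d' % i for i in range(1, 9)]
info = mne.create_info(ch_names, sfreq=100.0, ch_types='eeg')
data = np.random.randn(10, 8, 100) * 1e-6   # 10 epochs, 8 channels, 1 s each
data[:, 3] *= 50                            # channel 'EEG 004' becomes an outlier
epochs = mne.EpochsArray(data, info)

bads = faster_bad_channels(epochs, use_metrics=['variance'])
print(bads)  # expected to contain 'EEG 004'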
Example No. 28
def _read_antcnt_events(eeg, event_id=None, event_id_func='strip_to_integer'):
        """Create events array from ANT cnt structure
        An event array is constructed by looking up events in the
        event_id, trying to reduce them to their integer part otherwise, and
        entirely dropping them (with a warning) if this is impossible.
        Returns an empty (0, 3) array if no events are found."""
        if event_id_func == 'strip_to_integer':
            event_id_func = _strip_to_integer
        if event_id is None:
            event_id = dict()

        types = [eeg.get_trigger(i)[0] for i in range(eeg.get_trigger_count())]
        latencies = [eeg.get_trigger(i)[1]
                     for i in range(eeg.get_trigger_count())]

        if len(types) < 1:  # if there are 0 events, we can exit here
            logger.info('No events found, returning empty stim channel ...')
            return np.zeros((0, 3))

        not_in_event_id = set(x for x in types if x not in event_id)
        not_purely_numeric = set(x for x in not_in_event_id if not x.isdigit())
        no_numbers = set([x for x in not_purely_numeric
                          if not any([d.isdigit() for d in x])])
        have_integers = set([x for x in not_purely_numeric
                             if x not in no_numbers])
        if len(not_purely_numeric) > 0:
            basewarn = "Events like the following will be dropped"
            n_no_numbers, n_have_integers = len(no_numbers), len(have_integers)
            if n_no_numbers > 0:
                no_num_warn = " entirely: {0}, {1} in total"
                warn(basewarn + no_num_warn.format(list(no_numbers)[:5],
                                                   n_no_numbers))
            if n_have_integers > 0 and event_id_func is None:
                intwarn = (", but could be reduced to their integer part "
                           "instead with the default `event_id_func`: "
                           "{0}, {1} in total")
                warn(basewarn + intwarn.format(list(have_integers)[:5],
                                               n_have_integers))

        events = list()
        for tt, latency in zip(types, latencies):
            try:  # look up the event in event_id and if not, try event_id_func
                event_code = event_id[tt] if tt in event_id else event_id_func(tt)
                events.append([int(latency), 1, event_code])
            except (ValueError, TypeError):  # if event_id_func fails
                pass  # We're already raising warnings above, so we just drop

        if len(events) < len(types):
            missings = len(types) - len(events)
            msg = ("{0}/{1} event codes could not be mapped to integers. Use "
                   "the 'event_id' parameter to map such events manually.")
            warn(msg.format(missings, len(types)))
            if len(events) < 1:
                warn("As is, the trigger channel will consist entirely of zeros.")
                return np.zeros((0, 3))

        return np.asarray(events)
Example No. 29
def _compute_maps(data,
                  n_states=4,
                  max_iter=1000,
                  thresh=1e-6,
                  random_state=None,
                  verbose=None):
    """The modified K-means clustering algorithm.
    See :func:`segment` for the meaning of the parameters and return
    values.
    """
    if not isinstance(random_state, np.random.RandomState):
        random_state = np.random.RandomState(random_state)
    n_channels, n_samples = data.shape

    # Cache this value for later
    data_sum_sq = np.sum(data**2)

    # Select random timepoints for our initial topographic maps
    init_times = random_state.choice(n_samples, size=n_states, replace=False)
    maps = data[:, init_times].T
    maps /= np.linalg.norm(maps, axis=1, keepdims=True)  # Normalize the maps

    prev_residual = np.inf
    for _ in range(max_iter):
        # Assign each sample to the best matching microstate
        activation = maps.dot(data)
        segmentation = np.argmax(np.abs(activation), axis=0)

        # Recompute the topographic maps of the microstates, based on the
        # samples that were assigned to each state.
        for state in range(n_states):
            idx = (segmentation == state)
            if np.sum(idx) == 0:
                logger.info('Some microstates are never activated')
                maps[state] = 0
                continue
            # Find largest eigenvector
            # cov = data[:, idx].dot(data[:, idx].T)
            # _, vec = eigh(cov, eigvals=(n_channels - 1, n_channels - 1))
            # maps[state] = vec.ravel()
            maps[state] = data[:, idx].dot(activation[state, idx])
            maps[state] /= np.linalg.norm(maps[state])

        # Estimate residual noise
        act_sum_sq = np.sum(np.sum(maps[segmentation].T * data, axis=0)**2)
        residual = abs(data_sum_sq - act_sum_sq)
        residual /= float(n_samples * (n_channels - 1))

        # Have we converged?
        if (prev_residual - residual) < (thresh * residual):
            # logger.info('Converged at %d iterations.' % iteration)
            break

        prev_residual = residual
    else:
        logger.info('Modified K-means algorithm failed to converge.')

    return maps
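A self-contained usage sketch, assuming the function above and a logger are defined in this module: random "EEG" data is clustered into four microstate maps.

import numpy as np

data = np.random.randn(32, 5000)    # 32 channels, 5000 samples
maps = _compute_maps(data, n_states=4, max_iter=500, thresh=1e-6,
                     random_state=42)
print(maps.shape)                   # (4, 32); each map is normalised to unit length
print(np.linalg.norm(maps, axis=1))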
Example No. 30
    def __init__(self,
                 input_fname,
                 montage=None,
                 eog=None,
                 misc=(-4, -3, -2, -1),
                 stim_channel=None,
                 scale=1e-6,
                 sfreq=250,
                 missing_tol=1,
                 preload=True,
                 verbose=None):

        bci_info = {'missing_tol': missing_tol, 'stim_channel': stim_channel}
        if not eog:
            eog = list()
        if not misc:
            misc = list()

        nsamps, nchan = self._get_data_dims(input_fname)

        last_samps = [nsamps - 1]
        ch_names = ['EEG %03d' % num for num in range(1, nchan + 1)]
        ch_types = ['eeg'] * nchan
        if misc:
            misc_names = ['MISC %03d' % ii for ii in range(1, len(misc) + 1)]
            misc_types = ['misc'] * len(misc)
            for ii, mi in enumerate(misc):
                ch_names[mi] = misc_names[ii]
                ch_types[mi] = misc_types[ii]
        if eog:
            eog_names = ['EOG %03d' % ii for ii in range(len(eog))]
            eog_types = ['eog'] * len(eog)
            for ii, ei in enumerate(eog):
                ch_names[ei] = eog_names[ii]
                ch_types[ei] = eog_types[ii]
        if stim_channel:
            ch_names[stim_channel] = 'STI 014'
            ch_types[stim_channel] = 'stim'

        # mark last channel as the timestamp channel
        ch_names[-1] = "Timestamps"
        ch_types[-1] = "misc"

        # fix it for eog and misc marking
        info = create_info(ch_names, sfreq, ch_types, montage)
        info["buffer_size_sec"] = 1.
        super(RawOpenBCI, self).__init__(info,
                                         last_samps=last_samps,
                                         raw_extras=[bci_info],
                                         filenames=[input_fname],
                                         preload=False,
                                         verbose=verbose)
        # load data
        if preload:
            self.preload = preload
            logger.info('Reading raw data from %s...' % input_fname)
            self._data = self._read_segment()
Example No. 31
    def reduce_to_epochs(self, marker_params):
        """Reduce each marker of the collection to a single value per epoch.

        Parameters
        ----------
        marker_params : dict with reduction parameters
            Each key of the dict should be of the form MarkerClass or 
            MarkerClass/comment. Each value should be a dictionary with two
            keys: 'reduction_func' and 'picks'.

            reduction_func: list of dictionaries. Each dictionary should have 
                two keys: 'axis' and 'function'. The marker is going to be
                reduced following the order of the list. Selecting the
                corresponding axis and applying the corresponding function.
            picks: dictionary of axis to array. Before applying the reduction
                function, the corresponding axis will be subselected by picks.
                A value of None indicates all the elements.

            Example:
                marker_params = dict()
                marker_params['PowerSpectralDensity'] = {
                'reduction_func':
                    [{'axis': 'frequency', 'function': np.sum},
                     {'axis': 'channels', 'function': np.mean},
                     {'axis': 'epochs', 'function': np.mean}],
                'picks': {
                    'epochs': None,
                    'channels': np.arange(224)}}

        Returns
        -------
        out : dict
            Each marker of the collection will be a key, with a value
            representing the marker value for each epoch
            (np.ndarray of float, shape (n_epochs,)).
        """
        logger.info('Reducing to epochs')
        self._check_marker_params_keys(marker_params)
        ch_picks = mne.pick_types(self.ch_info_, eeg=True, meg=True)
        if ch_picks is not None:  # XXX think if info is needed down-stream
            info = mne.io.pick.pick_info(self.ch_info_, ch_picks, copy=True)
        else:
            info = self.ch_info_
        markers_to_epochs = [
            meas for meas in self.values()
            if isin_info(info_source=info, info_target=meas.ch_info_)
            and 'epochs' in meas._axis_map
        ]
        n_markers = len(markers_to_epochs)
        n_epochs = markers_to_epochs[0].data_.shape[
            markers_to_epochs[0]._axis_map['epochs']]
        out = OrderedDict()
        for ii, meas in enumerate(markers_to_epochs):
            logger.info('Reducing {}'.format(meas._get_title()))
            this_params = _get_reduction_params(marker_params, meas)
            out[meas._get_title()] = meas.reduce_to_epochs(**this_params)
        return out
Example No. 32
def epochs_compute_pe(epochs,
                      kernel,
                      tau,
                      tmin=None,
                      tmax=None,
                      backend='python',
                      method_params=None):
    """Compute Permutation Entropy (PE)

    Parameters
    ----------
    epochs : instance of mne.Epochs
        The epochs on which to compute the PE.
    kernel : int
        The number of samples to use to transform to a symbol
    tau : int
        The number of samples left between the ones that define a symbol.
    backend : {'python', 'c'}
        The backend to be used. Defaults to 'python'.
    """
    if method_params is None:
        method_params = {}

    freq = epochs.info['sfreq']

    picks = mne.io.pick.pick_types(epochs.info, meg=True, eeg=True)

    data = epochs.get_data()[:, picks, ...]
    n_epochs = len(data)

    data = np.hstack(data)

    if 'filter_freq' in method_params:
        filter_freq = method_params['filter_freq']
    else:
        filter_freq = np.double(freq) / kernel / tau
    logger.info('Filtering at %.2f Hz' % filter_freq)
    b, a = butter(6, 2.0 * filter_freq / np.double(freq), 'lowpass')

    fdata = np.transpose(
        np.array(np.split(filtfilt(b, a, data), n_epochs, axis=1)), [1, 2, 0])

    time_mask = _time_mask(epochs.times, tmin, tmax)
    fdata = fdata[:, time_mask, :]

    if backend == 'python':
        logger.info("Performing symbolic transformation")
        sym, count = _symb_python(fdata, kernel, tau)
        pe = np.nan_to_num(-np.nansum(count * np.log(count), axis=1))
    elif backend == 'c':
        from ..optimizations.jivaro import pe as jpe
        pe, sym = jpe(fdata, kernel, tau)
    else:
        raise ValueError('backend %s not supported for PE' % backend)
    nsym = math.factorial(kernel)
    pe = pe / np.log(nsym)
    return pe, sym
Example No. 33
def _handle_electrodes_reading(electrodes_fname, coord_frame, raw, verbose):
    """Read associated electrodes.tsv and populate raw.

    Handle xyz coordinates and coordinate frame of each channel.
    Assumes units of coordinates are in 'm'.
    """
    logger.info('Reading electrode '
                'coords from {}.'.format(electrodes_fname))
    electrodes_dict = _from_tsv(electrodes_fname)
    # First, make sure that ordering of names in channels.tsv matches the
    # ordering of names in the raw data. The "name" column is mandatory in BIDS
    ch_names_raw = list(raw.ch_names)
    ch_names_tsv = electrodes_dict['name']

    if ch_names_raw != ch_names_tsv:
        msg = ('Channels do not correspond between raw data and the '
               'channels.tsv file. For MNE-BIDS, the channel names in the '
               'tsv MUST be equal and in the same order as the channels in '
               'the raw data.\n\n'
               '{} channels in tsv file: "{}"\n\n --> {}\n\n'
               '{} channels in raw file: "{}"\n\n --> {}\n\n'
               .format(len(ch_names_tsv), electrodes_fname, ch_names_tsv,
                       len(ch_names_raw), raw.filenames, ch_names_raw)
               )

        # XXX: this could be due to MNE inserting a 'STI 014' channel as the
        # last channel: In that case, we can work. --> Can be removed soon,
        # because MNE will stop the synthesis of stim channels in the near
        # future
        if not (ch_names_raw[-1] == 'STI 014' and
                ch_names_raw[:-1] == ch_names_tsv):
            raise RuntimeError(msg)

    if verbose:
        print("The read in electrodes file is: \n", electrodes_dict)

    # convert coordinates to float and create list of tuples
    ch_names_raw = [x for i, x in enumerate(ch_names_raw)
                    if electrodes_dict['x'][i] != "n/a"]
    electrodes_dict['x'] = [float(x) for x in electrodes_dict['x']
                            if x != "n/a"]
    electrodes_dict['y'] = [float(x) for x in electrodes_dict['y']
                            if x != "n/a"]
    electrodes_dict['z'] = [float(x) for x in electrodes_dict['z']
                            if x != "n/a"]

    ch_locs = list(zip(electrodes_dict['x'],
                       electrodes_dict['y'],
                       electrodes_dict['z']))
    ch_pos = dict(zip(ch_names_raw, ch_locs))

    # create mne.DigMontage
    montage = mne.channels.make_dig_montage(ch_pos=ch_pos,
                                            coord_frame=coord_frame)
    raw.set_montage(montage)

    return raw
Example No. 34
    def fit(self, epochs):
        for meas in self.values():
            if isinstance(meas, BasePowerSpectralDensity):
                meas._check_freq_time_range(epochs)

        for meas in self.values():
            logger.info('Fitting {}'.format(meas._get_title()))
            meas.fit(epochs)
        self.ch_info_ = list(self.values())[0].ch_info_
Example No. 35
def label_svd(sub_leadfield, n_svd_comp, ch_names):
    """ Computes SVD of subleadfield for sensor types separately

    Parameters:
    -----------
    sub_leadfield: numpy array (n_sens x n_vert) with part of the 
                   leadfield matrix
    n_svd_comp: scalar, number of SVD components required
    ch_names: list of channel names

    Returns:
    --------
    this_label_lfd_summary: numpy array, n_svd_comp scaled SVD components
                            of subleadfield

    OH Aug 2015
    """

    logger.info("\nComputing SVD within labels, using %d component(s)"
                        % n_svd_comp)
    

    EEG_idx = [cc for cc in range(len(ch_names)) if ch_names[cc][:3]=='EEG']
    MAG_idx = [cc for cc in range(len(ch_names)) if (ch_names[cc][:3]=='MEG'
                                                and ch_names[cc][-1:]=='1')]
    GRA_idx = [cc for cc in range(len(ch_names)) if (ch_names[cc][:3]=='MEG'
                    and (ch_names[cc][-1:]=='2' or ch_names[cc][-1:]=='3'))]

    list_idx = []
    u_idx = -1 # keep track of which element of u_svd belongs t which sensor type
    if MAG_idx:
        list_idx.append(MAG_idx)
        u_idx += 1
        u_mag = u_idx
    if GRA_idx:
        list_idx.append(GRA_idx)
        u_idx += 1
        u_gra = u_idx
    if EEG_idx:
        list_idx.append(EEG_idx)
        u_idx += 1
        u_eeg = u_idx
    
    # compute SVD of sub-leadfield for individual sensor types
    u_svd = [get_svd_comps(sub_leadfield[ch_idx, :], n_svd_comp)
             for ch_idx in list_idx]

    # put sensor types back together
    this_label_lfd_summary = np.zeros([len(ch_names), u_svd[0].shape[1]])
    if MAG_idx:
        this_label_lfd_summary[MAG_idx] = u_svd[u_mag]
    if GRA_idx:
        this_label_lfd_summary[GRA_idx] = u_svd[u_gra]
    if EEG_idx:
        this_label_lfd_summary[EEG_idx] = u_svd[u_eeg]    

    return this_label_lfd_summary
Example No. 36
def _interpolate_bads_eeg(inst, picks=None, verbose=None):
    """ Interpolate bad EEG channels.

    Operates in place.

    Parameters
    ----------
    inst : mne.io.Raw, mne.Epochs or mne.Evoked
        The data to interpolate. Must be preloaded.
    picks: np.ndarray, shape(n_channels, ) | list | None
        The channel indices to be used for interpolation.
    """
    from mne import pick_types
    from mne.bem import _fit_sphere
    from mne.utils import logger, warn
    from mne.channels.interpolation import _do_interp_dots
    from mne.channels.interpolation import _make_interpolation_matrix
    import numpy as np

    if picks is None:
        picks = pick_types(inst.info, meg=False, eeg=True, exclude=[])

    bads_idx = np.zeros(len(inst.ch_names), dtype=bool)
    goods_idx = np.zeros(len(inst.ch_names), dtype=bool)

    inst.info._check_consistency()
    bads_idx[picks] = [inst.ch_names[ch] in inst.info['bads'] for ch in picks]

    if len(picks) == 0 or bads_idx.sum() == 0:
        return

    goods_idx[picks] = True
    goods_idx[bads_idx] = False

    pos = inst._get_channel_positions(picks)

    # Make sure only good EEG are used
    bads_idx_pos = bads_idx[picks]
    goods_idx_pos = goods_idx[picks]
    pos_good = pos[goods_idx_pos]
    pos_bad = pos[bads_idx_pos]

    # test spherical fit
    radius, center = _fit_sphere(pos_good)
    distance = np.sqrt(np.sum((pos_good - center) ** 2, 1))
    distance = np.mean(distance / radius)
    if np.abs(1. - distance) > 0.1:
        warn('Your spherical fit is poor, interpolation results are '
             'likely to be inaccurate.')

    logger.info('Computing interpolation matrix from {0} sensor '
                'positions'.format(len(pos_good)))

    interpolation = _make_interpolation_matrix(pos_good, pos_bad)

    logger.info('Interpolating {0} sensors'.format(len(pos_bad)))
    _do_interp_dots(inst, interpolation, goods_idx, bads_idx)
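A hedged usage sketch, assuming the function above is defined in the current module: a small synthetic EEG Epochs object with a standard montage is built, one channel is marked bad, and the interpolation is applied in place.

import numpy as np
import mne

ch_names = ['Fp1', 'Fp2', 'F3', 'Fz', 'F4', 'C3', 'Cz', 'C4',
            'P3', 'Pz', 'P4', 'O1', 'O2']
info = mne.create_info(ch_names, sfreq=100.0, ch_types='eeg')
epochs = mne.EpochsArray(np.random.randn(2, len(ch_names), 50) * 1e-6, info)
epochs.set_montage(mne.channels.make_standard_montage('standard_1020'))
epochs.info['bads'] = ['Fz']

_interpolate_bads_eeg(epochs)  # 'Fz' is rebuilt from the remaining good channels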
Example No. 37
def _check_fname(fname, overwrite=False, must_exist=False):
    """Check for file existence."""
    _validate_type(fname, 'str', 'fname')
    from mne.utils import logger
    if must_exist and not op.isfile(fname):
        raise IOError('File "%s" does not exist' % fname)
    if op.isfile(fname):
        if not overwrite:
            raise IOError('Destination file exists. Please use option '
                          '"overwrite=True" to force overwriting.')
        elif overwrite != 'read':
            logger.info('Overwriting existing file.')
Example No. 38
def _find_bad_channels(epochs, picks, use_metrics, thresh, max_iter):
    """Implements the first step of the FASTER algorithm.

    This function attempts to automatically mark bad EEG channels by performing
    outlier detection. It operates on epoched data to make sure only relevant
    data is analyzed.

    Additional Parameters
    ---------------------
    use_metrics : list of str
        List of metrics to use. Can be any combination of:
            'variance', 'correlation', 'hurst', 'kurtosis', 'line_noise'
        Defaults to all of them.
    thresh : float
        The threshold value, in standard deviations, to apply. A channel
        crossing this threshold value is marked as bad. Defaults to 3.
    max_iter : int
        The maximum number of iterations performed during outlier detection
        (defaults to 1, as in the original FASTER paper).
    """
    from scipy.stats import kurtosis
    metrics = {
        'variance': lambda x: np.var(x, axis=1),
        'correlation': lambda x: np.mean(
            np.ma.masked_array(np.corrcoef(x),
                               np.identity(len(x), dtype=bool)), axis=0),
        'hurst': lambda x: _hurst(x),
        'kurtosis': lambda x: kurtosis(x, axis=1),
        'line_noise': lambda x: _freqs_power(x, epochs.info['sfreq'],
                                             [50, 60]),
    }

    if use_metrics is None:
        use_metrics = metrics.keys()

    # Concatenate epochs in time
    data = epochs.get_data()[:, picks]
    data = data.transpose(1, 0, 2).reshape(data.shape[1], -1)

    # Find bad channels
    bads = defaultdict(list)
    info = pick_info(epochs.info, picks, copy=True)
    for ch_type, chs in _picks_by_type(info):
        logger.info('Bad channel detection on %s channels:' % ch_type.upper())
        for metric in use_metrics:
            scores = metrics[metric](data[chs])
            bad_channels = [epochs.ch_names[picks[chs[i]]]
                            for i in find_outliers(scores, thresh, max_iter)]
            logger.info('\tBad by %s: %s' % (metric, bad_channels))
            bads[metric].append(bad_channels)

    bads = dict((k, np.concatenate(v).tolist()) for k, v in bads.items())
    return bads
Example No. 39
def _find_bad_channels_in_epochs(epochs, picks, use_metrics, thresh, max_iter):
    """Implements the fourth step of the FASTER algorithm.

    This function attempts to automatically mark bad channels in each epoch by
    performing outlier detection.

    Additional Parameters
    ---------------------
    use_metrics : list of str
        List of metrics to use. Can be any combination of:
        'amplitude', 'variance', 'deviation', 'median_gradient'
        Defaults to all of them.
    thresh : float
        The threshold value, in standard deviations, to apply. A channel
        crossing this threshold value is marked as bad. Defaults to 3.
    max_iter : int
        The maximum number of iterations performed during outlier detection
        (defaults to 1, as in the original FASTER paper).
    """

    metrics = {
        'amplitude': lambda x: np.ptp(x, axis=2),
        'deviation': lambda x: _deviation(x),
        'variance': lambda x: np.var(x, axis=2),
        'median_gradient': lambda x: np.median(np.abs(np.diff(x)), axis=2),
        'line_noise': lambda x: _freqs_power(x, epochs.info['sfreq'],
                                             [50, 60]),
    }

    if use_metrics is None:
        use_metrics = metrics.keys()

    info = pick_info(epochs.info, picks, copy=True)
    data = epochs.get_data()[:, picks]
    bads = dict((m, np.zeros((len(data), len(picks)), dtype=bool)) for
                m in metrics)
    for ch_type, chs in _picks_by_type(info):
        ch_names = [info['ch_names'][k] for k in chs]
        chs = np.array(chs)
        for metric in use_metrics:
            logger.info('Bad channel-in-epoch detection on %s channels:'
                        % ch_type.upper())
            s_epochs = metrics[metric](data[:, chs])
            for i_epochs, epoch in enumerate(s_epochs):
                outliers = find_outliers(epoch, thresh, max_iter)
                if len(outliers) > 0:
                    bad_segment = [ch_names[k] for k in outliers]
                    logger.info('Epoch %d, Bad by %s:\n\t%s' % (
                        i_epochs, metric, bad_segment))
                    bads[metric][i_epochs, chs[outliers]] = True

    return bads
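# Illustration (added, not part of the original example): each per-epoch metric
# above maps an (n_epochs, n_channels, n_times) array to an
# (n_epochs, n_channels) score matrix, and outliers are then searched within
# each epoch. A small shape check with synthetic data:
import numpy as np

fake = np.random.RandomState(0).randn(5, 8, 200)  # 5 epochs, 8 chans, 200 samples
print(np.ptp(fake, axis=2).shape)                      # 'amplitude'       -> (5, 8)
print(np.var(fake, axis=2).shape)                      # 'variance'        -> (5, 8)
print(np.median(np.abs(np.diff(fake)), axis=2).shape)  # 'median_gradient' -> (5, 8)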
Ejemplo n.º 40
0
    def __init__(self, input_fname, montage=None, eog=None,
                 misc=(-4, -3, -2, -1), stim_channel=None, scale=1e-6, sfreq=250,
                 missing_tol=1, preload=True, verbose=None):

        bci_info = {'missing_tol': missing_tol, 'stim_channel': stim_channel}
        openbci_channames = ["FP1", "FP2", "C3", "C4", "P7", "P8", "O1", "O2", "F7", "F8", "F3", "F4", "T7", "T8", "P3", "P4"]
        if not eog:
            eog = list()
        if not misc:
            misc = list()
        nsamps, nchan = self._get_data_dims(input_fname)

        last_samps = [nsamps - 1]
        ch_names = ['EEG %03d' % num for num in range(1, nchan + 1)]
        ch_names[:nchan-4] = openbci_channames[:nchan-4]
        ch_types = ['eeg'] * nchan

        

        if misc:
            misc_names = ['MISC %03d' % ii for ii in range(1, len(misc) + 1)]
            misc_types = ['misc'] * len(misc)
            for ii, mi in enumerate(misc):
                ch_names[mi] = misc_names[ii]
                ch_types[mi] = misc_types[ii]
        if eog:
            eog_names = ['EOG %03d' % ii for ii in range(len(eog))]
            eog_types = ['eog'] * len(eog)
            for ii, ei in enumerate(eog):
                ch_names[ei] = eog_names[ii]
                ch_types[ei] = eog_types[ii]
        if stim_channel:
            ch_names[stim_channel] = 'STI 014'
            ch_types[stim_channel] = 'stim'

        # mark last channel as the timestamp channel
        ch_names[-1] = "Timestamps"
        ch_types[-1] = "misc"

        # fix it for eog and misc marking
        info = create_info(ch_names, sfreq, ch_types, montage)
        info["buffer_size_sec"] = 1.
        super(RawOpenBCI, self).__init__(info, last_samps=last_samps,
                                         raw_extras=[bci_info],
                                         filenames=[input_fname],
                                         preload=False, verbose=verbose)
        # load data
        if preload:
            self.preload = preload
            logger.info('Reading raw data from %s...' % input_fname)
            self._data = self._read_segment()
Ejemplo n.º 41
0
def read_info(subject, data_type, run_index=0, hcp_path=op.curdir):
    """Read info from unprocessed data

    Parameters
    ----------
    subject : str, file_map
        The subject
    data_type : str
        The kind of data to read. The following options are supported:
        'rest'
        'task_motor'
        'task_story_math'
        'task_working_memory'
        'noise_empty_room'
        'noise_subject'
    run_index : int
        The run index. For the first run, use 0, for the second, use 1.
        Also see HCP documentation for the number of runs for a given data
        type.
    hcp_path : str
        The HCP directory, defaults to op.curdir.

    Returns
    -------
    info : instance of mne.io.meas_info.Info
        The MNE channel info object.

    .. note::
        HCP MEG does not deliver all tasks: only 3 of the 5 task packages
        from MRI HCP are available.
    """
    raw, config = get_file_paths(
        subject=subject, data_type=data_type, output='raw',
        run_index=run_index, hcp_path=hcp_path)

    if not op.exists(raw):
        raw = None

    meg_info = _read_bti_info(raw, config)

    if raw is None:
        logger.info('Did not find Raw data. Guessing EMG, ECG and EOG '
                    'channels')
        rename_channels(meg_info, dict(_label_mapping))
    return meg_info
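# Illustration (added, not part of the original example): hypothetical usage
# sketch; the subject ID and hcp_path below are made-up values and assume an
# HCP-style directory layout.
# info = read_info(subject='100307', data_type='task_motor',
#                  run_index=0, hcp_path='/data/HCP')
# print(info['sfreq'], len(info['ch_names']))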
Ejemplo n.º 42
0
def _make_interpolator(inst, bad_channels):
    """Find indexes and interpolation matrix to interpolate bad channels

    Parameters
    ----------
    inst : mne.io.Raw, mne.Epochs or mne.Evoked
        The data to interpolate. Must be preloaded.
    """
    bads_idx = np.zeros(len(inst.ch_names), dtype=bool)
    goods_idx = np.zeros(len(inst.ch_names), dtype=bool)

    picks = pick_types(inst.info, meg=False, eeg=True, exclude=[])
    bads_idx[picks] = [inst.ch_names[ch] in bad_channels for ch in picks]
    goods_idx[picks] = True
    goods_idx[bads_idx] = False

    if bads_idx.sum() != len(bad_channels):
        logger.warning('Channel interpolation is currently only implemented '
                       'for EEG. The MEG channels marked as bad will remain '
                       'untouched.')

    pos = get_channel_positions(inst, picks)

    # Make sure only EEG are used
    bads_idx_pos = bads_idx[picks]
    goods_idx_pos = goods_idx[picks]

    pos_good = pos[goods_idx_pos]
    pos_bad = pos[bads_idx_pos]

    # test spherical fit
    radius, center = _fit_sphere(pos_good)
    distance = np.sqrt(np.sum((pos_good - center) ** 2, 1))
    distance = np.mean(distance / radius)
    if np.abs(1. - distance) > 0.1:
        logger.warning('Your spherical fit is poor, interpolation results are '
                       'likely to be inaccurate.')

    logger.info('Computing interpolation matrix from {0} sensor '
                'positions'.format(len(pos_good)))

    interpolation = _make_interpolation_matrix(pos_good, pos_bad)

    return goods_idx, bads_idx, interpolation
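# Illustration (added, not part of the original example): a self-contained
# sketch of the spherical-fit quality check above. For sensor positions lying
# close to a sphere, the mean of (distance-to-center / radius) is ~1; a
# deviation larger than 0.1 would trigger the warning.
import numpy as np

rng = np.random.RandomState(0)
center, radius = np.array([0.0, 0.0, 0.04]), 0.09        # head-like sphere (m)
pts = rng.randn(64, 3)
pts = center + radius * pts / np.linalg.norm(pts, axis=1, keepdims=True)
pts += 0.002 * rng.randn(*pts.shape)                      # small sensor jitter
distance = np.sqrt(np.sum((pts - center) ** 2, axis=1))
print(np.abs(1.0 - np.mean(distance / radius)) > 0.1)     # False: fit is fine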
Ejemplo n.º 43
0
def _find_bad_epochs(epochs, picks, use_metrics, thresh, max_iter):
    """Implements the second step of the FASTER algorithm.

    This function attempts to automatically mark bad epochs by performing
    outlier detection.

    Additional Parameters
    ---------------------
    use_metrics : list of str
        List of metrics to use. Can be any combination of:
        'amplitude', 'variance', 'deviation'. Defaults to all of them.
    thresh : float
        The threshold value, in standard deviations, to apply. A channel
        crossing this threshold value is marked as bad. Defaults to 3.
    max_iter : int
        The maximum number of iterations performed during outlier detection
        (defaults to 1, as in the original FASTER paper).
    """

    metrics = {
        'amplitude': lambda x: np.mean(np.ptp(x, axis=2), axis=1),
        'deviation': lambda x: np.mean(_deviation(x), axis=1),
        'variance': lambda x: np.mean(np.var(x, axis=2), axis=1),
    }

    if use_metrics is None:
        use_metrics = metrics.keys()

    info = pick_info(epochs.info, picks, copy=True)
    data = epochs.get_data()[:, picks]

    bads = defaultdict(list)
    for ch_type, chs in _picks_by_type(info):
        logger.info('Bad epoch detection on %s channels:' % ch_type.upper())
        for metric in use_metrics:
            scores = metrics[metric](data[:, chs])
            bad_epochs = find_outliers(scores, thresh, max_iter)
            logger.info('\tBad by %s: %s' % (metric, bad_epochs))
            bads[metric].append(bad_epochs)

    bads = dict((k, np.concatenate(v).tolist()) for k, v in bads.items())
    return bads
Ejemplo n.º 44
0
    def __init__(self, input_fname, montage=None, eog=None,
                 misc=(-3, -2, -1), stim_channel=None, scale=1e-6, sfreq=250,
                 missing_tol=1, preload=True, verbose=None):

        bci_info = {'missing_tol': missing_tol, 'stim_channel': stim_channel}
        if not eog:
            eog = list()
        if not misc:
            misc = list()
        nsamps, nchan = self._get_data_dims(input_fname)

        last_samps = [nsamps - 1]
        ch_names = ['EEG %03d' % num for num in range(1, nchan + 1)]
        ch_types = ['eeg'] * nchan
        if misc:
            misc_names = ['MISC %03d' % ii for ii in range(1, len(misc) + 1)]
            misc_types = ['misc'] * len(misc)
            for ii, mi in enumerate(misc):
                ch_names[mi] = misc_names[ii]
                ch_types[mi] = misc_types[ii]
        if eog:
            eog_names = ['EOG %03d' % ii for ii in range(len(eog))]
            eog_types = ['eog'] * len(eog)
            for ii, ei in enumerate(eog):
                ch_names[ei] = eog_names[ii]
                ch_types[ei] = eog_types[ii]
        if stim_channel:
            ch_names[stim_channel] = 'STI 014'
            ch_types[stim_channel] = 'stim'

        # fix it for eog and misc marking
        info = create_info(ch_names, sfreq, ch_types, montage)
        super(RawOpenBCI, self).__init__(info, last_samps=last_samps,
                                         raw_extras=[bci_info],
                                         filenames=[input_fname],
                                         preload=False, verbose=verbose)
        # load data
        if preload:
            self.preload = preload
            logger.info('Reading raw data from %s...' % input_fname)
            self._data, _ = self._read_segment()
Ejemplo n.º 45
0
def get_svd_comps(sub_leadfield, n_svd_comp):
    """ Compute SVD components of sub-leadfield for selected channels 
    (all channels in one SVD)
    Parameters:
    -----------
    sub_leadfield: numpy array (n_sens x n_vert) with part of the leadfield matrix
    n_svd_comp: scalar, number of SVD components required

    Returns:
    --------
    u_svd: numpy array, n_svd_comp scaled SVD components of subleadfield for 
           selected channels
    s_svd: corresponding singular values
    """

    u_svd, s_svd, _ = np.linalg.svd(sub_leadfield, full_matrices=False,
                                    compute_uv=True)

    # get desired first vectors of u_svd
    u_svd = u_svd[:, :n_svd_comp]
   
    # project SVD components on sub-leadfield, take sum over vertices
    u_svd_proj = u_svd.T.dot(sub_leadfield).sum(axis=1)
    # make sure overall projection has positive sign
    u_svd = u_svd.dot(np.sign(np.diag(u_svd_proj)))

    u_svd = u_svd * s_svd[:n_svd_comp][np.newaxis, :]

    logger.info("\nFirst 5 singular values (n=%d): %s" % (u_svd.shape[0], \
                                                         s_svd[0:5]))
    
    # explained variance by chosen components within sub-leadfield
    my_comps = s_svd[0:n_svd_comp]

    comp_var = (100. * np.sum(my_comps * my_comps) /
                np.sum(s_svd * s_svd))
    logger.info("Your %d component(s) explain(s) %.1f%% "
                "variance." % (n_svd_comp, comp_var)) 

    return u_svd
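# Illustration (added, not part of the original example): self-contained sketch
# of the SVD / explained-variance logic above, applied to a random "leadfield"
# instead of a real one:
import numpy as np

rng = np.random.RandomState(0)
sub_leadfield = rng.randn(60, 20)   # n_sens x n_vert (made-up dimensions)
u, s, _ = np.linalg.svd(sub_leadfield, full_matrices=False)
n_svd_comp = 3
comp_var = 100.0 * np.sum(s[:n_svd_comp] ** 2) / np.sum(s ** 2)
print('%d components explain %.1f%% of the variance' % (n_svd_comp, comp_var))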
Ejemplo n.º 46
0
def grid_search(epochs, n_interpolates, consensus_percs, prefix, n_folds=3):
    """Grid search to find optimal values of n_interpolate and consensus_perc.

    Parameters
    ----------
    epochs : instance of mne.Epochs
        The epochs object for which bad epochs must be found.
    n_interpolates : array
        The candidate numbers of sensors to interpolate.
    consensus_percs : array
        The candidate consensus percentages, i.e. the fraction of channels
        that must be marked bad for an epoch to be rejected.
    prefix : str
        Prefix for log messages.
    n_folds : int
        Number of folds for cross-validation.
    """
    cv = KFold(len(epochs), n_folds=n_folds, random_state=42)
    err_cons = np.zeros((len(consensus_percs), len(n_interpolates),
                         n_folds))

    auto_reject = ConsensusAutoReject()
    # The thresholds must be learnt from the entire data
    auto_reject.fit(epochs)

    for fold, (train, test) in enumerate(cv):
        for jdx, n_interp in enumerate(n_interpolates):
            for idx, consensus_perc in enumerate(consensus_percs):
                logger.info('%s[Val fold %d] Trying consensus '
                            'perc %0.2f, n_interp %d' % (
                                prefix, fold + 1, consensus_perc, n_interp))
                # set the params
                auto_reject.consensus_perc = consensus_perc
                auto_reject.n_interpolate = n_interp
                # now do the transform
                auto_reject.transform(epochs[train])
                # score using this param
                X = epochs[test].get_data()
                err_cons[idx, jdx, fold] = -auto_reject.score(X)

    return err_cons
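# Illustration (added, not part of the original example): hypothetical usage
# sketch, assuming `epochs` is a preloaded mne.Epochs instance and the old
# autoreject API (ConsensusAutoReject, KFold(n, n_folds=...)) that this
# function was written against:
# n_interpolates = np.array([1, 4, 32])
# consensus_percs = np.linspace(0, 1.0, 11)
# err_cons = grid_search(epochs, n_interpolates, consensus_percs,
#                        prefix='', n_folds=3)
# best = np.unravel_index(err_cons.mean(axis=-1).argmin(), err_cons.shape[:2])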
Ejemplo n.º 47
0
def _interpolate_bads_meg(epochs, bad_channels_by_epoch, mode='fast'):
    """Interpolate bad MEG channels per epoch

    Parameters
    ----------
    inst : mne.io.Raw, mne.Epochs or mne.Evoked
        The data to interpolate. Must be preloaded.
    bad_channels_by_epoch : list of list of str
        Bad channel names specified for each epoch. For example, for an Epochs
        instance containing 3 epochs: ``[['F1'], [], ['F3', 'FZ']]``

    Notes
    -----
    Based on mne 0.9.0 MEG channel interpolation.
    """
    if len(bad_channels_by_epoch) != len(epochs):
        raise ValueError("Unequal length of epochs (%i) and "
                         "bad_channels_by_epoch (%i)"
                         % (len(epochs), len(bad_channels_by_epoch)))

    interp_cache = {}
    for i, bad_channels in enumerate(bad_channels_by_epoch):
        if not bad_channels:
            continue

        # find interpolation matrix
        key = tuple(sorted(bad_channels))
        if key in interp_cache:
            picks_good, picks_bad, interpolation = interp_cache[key]
        else:
            picks_good = pick_types(epochs.info, ref_meg=False, exclude=key)
            picks_bad = pick_channels(epochs.ch_names, key)
            interpolation = _map_meg_channels(epochs, picks_good, picks_bad, mode)
            interp_cache[key] = picks_good, picks_bad, interpolation

        # apply interpolation
        logger.info('Interpolating sensors %s on epoch %s', picks_bad, i)
        epochs._data[i, picks_bad, :] = interpolation.dot(epochs._data[i, picks_good, :])
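# Illustration (added, not part of the original example): hypothetical call
# sketch. The channel names are made up; the list needs one (possibly empty)
# entry per epoch and `epochs` must be preloaded.
# bad_channels_by_epoch = [['MEG 0113'], [], ['MEG 2443', 'MEG 1032']]
# _interpolate_bads_meg(epochs, bad_channels_by_epoch, mode='fast')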
Ejemplo n.º 48
0
def _run(subjects_dir, subject, layers, ico, overwrite):
    this_env = copy.copy(os.environ)
    this_env['SUBJECTS_DIR'] = subjects_dir
    this_env['SUBJECT'] = subject

    if 'SUBJECTS_DIR' not in this_env:
        raise RuntimeError('The environment variable SUBJECTS_DIR should '
                           'be set')

    if not op.isdir(subjects_dir):
        raise RuntimeError('subjects directory %s not found, specify using '
                           'the environment variable SUBJECTS_DIR or '
                           'the command line option --subjects-dir'
                           % subjects_dir)

    if 'FREESURFER_HOME' not in this_env:
        raise RuntimeError('The FreeSurfer environment needs to be set up '
                           'for this script')

    subj_path = op.join(subjects_dir, subject)
    if not op.exists(subj_path):
        raise RuntimeError('%s does not exist. Please check your subject '
                           'directory path.' % subj_path)
    
    logger.info('1. Setting up MRI files...')
    if overwrite:
        run_subprocess(['mne_setup_mri', '--mri', 'T1', '--subject', subject, '--overwrite'], env=this_env)
    else:
        run_subprocess(['mne_setup_mri', '--mri', 'T1', '--subject', subject], env=this_env)

    logger.info('2. Setting up %d layer BEM...' % layers)
    if layers == 3:
        flash05 = op.join(subjects_dir, subject, 'nii/FLASH5.nii')
        flash30 = op.join(subjects_dir, subject, 'nii/FLASH30.nii')

        run_subprocess(['mne', 'flash_bem_model', '-s', subject, '-d', subjects_dir,
                        '--flash05', flash05, '--flash30', flash30, '-v'], env=this_env)
        for srf in ('inner_skull', 'outer_skull', 'outer_skin'):
            shutil.copy(op.join(subjects_dir, subject, 'bem/flash/%s.surf' % srf),
                        op.join(subjects_dir, subject, 'bem/%s.surf' % srf))
    else:
        if overwrite:
            run_subprocess(['mne', 'watershed_bem', '-s', subject, '-d', subjects_dir, '--overwrite'], env=this_env)
        else:
            run_subprocess(['mne', 'watershed_bem', '-s', subject, '-d', subjects_dir], env=this_env)

    # Create dense head surface and symbolic link to head.fif file
    logger.info('3. Creating high resolution skin surface for coregistration...')
    run_subprocess(['mne', 'make_scalp_surfaces', '--overwrite', '--subject', subject])
    if op.isfile(op.join(subjects_dir, subject, 'bem/%s-head.fif' % subject)):
        os.rename(op.join(subjects_dir, subject, 'bem/%s-head.fif' % subject),
                  op.join(subjects_dir, subject, 'bem/%s-head-sparse.fif' % subject))
    os.symlink((op.join(subjects_dir, subject, 'bem/%s-head-dense.fif' % subject)),
               (op.join(subjects_dir, subject, 'bem/%s-head.fif' % subject)))

    # Create source space
    run_subprocess(['mne_setup_source_space', '--subject', subject, '--spacing', '%.0f' % 5, '--cps'],
                   env=this_env)
Ejemplo n.º 49
0
def test_noise_reducer():

    data_path = os.environ['SUBJECTS_DIR']
    subject   = os.environ['SUBJECT']

    dname = data_path + '/' + 'empty_room_files' + '/109925_empty_room_file-raw.fif'
    subjects_dir = data_path + '/subjects'
    #
    checkresults = True
    exclart = False
    use_reffilter = True
    refflt_lpfreq = 52.
    refflt_hpfreq = 48.

    print "########## before of noisereducer call ##########"
    sigchanlist = ['MEG ..1', 'MEG ..3', 'MEG ..5', 'MEG ..7', 'MEG ..9']
    sigchanlist = None
    refchanlist = ['RFM 001', 'RFM 003', 'RFM 005', 'RFG ...']
    tmin = 15.
    noise_reducer(dname, signals=sigchanlist, noiseref=refchanlist, tmin=tmin,
                  reflp=refflt_lpfreq, refhp=refflt_hpfreq,
                  exclude_artifacts=exclart, complementary_signal=True)
    print "########## behind of noisereducer call ##########"

    print "########## Read raw data:"
    tc0 = time.clock()
    tw0 = time.time()
    raw = mne.io.Raw(dname, preload=True)
    tc1 = time.clock()
    tw1 = time.time()
    print "loading raw data  took %.1f ms (%.2f s walltime)" % (1000. * (tc1 - tc0), (tw1 - tw0))

    # Time window selection
    # weights are calc'd based on [tmin,tmax], but applied to the entire data set.
    # tstep is used in artifact detection
    tmax = raw.times[raw.last_samp]
    tstep = 0.2
    itmin = int(floor(tmin * raw.info['sfreq']))
    itmax = int(ceil(tmax * raw.info['sfreq']))
    itstep = int(ceil(tstep * raw.info['sfreq']))
    print ">>> Set time-range to [%7.3f,%7.3f]" % (tmin, tmax)

    if sigchanlist is None:
        sigpick = mne.pick_types(raw.info, meg='mag', eeg=False, stim=False, eog=False, exclude='bads')
    else:
        sigpick = channel_indices_from_list(raw.info['ch_names'][:], sigchanlist)
    nsig = len(sigpick)
    print "sigpick: %3d chans" % nsig
    if nsig == 0:
        raise ValueError("No channel selected for noise compensation")

    if refchanlist is None:
        # References are not limited to 4D ref-chans, but can be anything,
        # incl. ECG or powerline monitor.
        print ">>> Using all refchans."
        refexclude = "bads"
        refpick = mne.pick_types(raw.info, ref_meg=True, meg=False, eeg=False,
                                 stim=False, eog=False, exclude=refexclude)
    else:
        refpick = channel_indices_from_list(raw.info['ch_names'][:], refchanlist)
        print "refpick = '%s'" % refpick
    nref = len(refpick)
    print "refpick: %3d chans" % nref
    if nref == 0:
        raise ValueError("No channel selected as noise reference")

    print "########## Refchan geo data:"
    # This is just for info to locate special 4D-refs.
    for iref in refpick:
        print raw.info['chs'][iref]['ch_name'], raw.info['chs'][iref]['loc'][0:3]
    print ""

    if use_reffilter:
        print "########## Filter reference channels:"
        if refflt_lpfreq is not None:
            print " low-pass with cutoff-freq %.1f" % refflt_lpfreq
        if refflt_hpfreq is not None:
            print "high-pass with cutoff-freq %.1f" % refflt_hpfreq
        # Adapt the following drop-channels command to use 'all-but-refpick'
        droplist = [raw.info['ch_names'][k] for k in xrange(raw.info['nchan']) if not k in refpick]
        fltref = raw.drop_channels(droplist, copy=True)
        tct = time.clock()
        twt = time.time()
        fltref.filter(refflt_hpfreq, refflt_lpfreq, picks=np.array(xrange(nref)), method='iir')
        tc1 = time.clock()
        tw1 = time.time()
        print "filtering ref-chans  took %.1f ms (%.2f s walltime)" % (1000. * (tc1 - tct), (tw1 - twt))

    print "########## Calculating sig-ref/ref-ref-channel covariances:"
    # Calculate sig-ref/ref-ref-channel covariance:
    # (there is no need to calc inter-signal-chan cov,
    #  but there seems to be no appropriate function available)
    # Here we copy the idea from compute_raw_data_covariance()
    # and truncate it as appropriate.
    tct = time.clock()
    twt = time.time()
    # The following reject and info{sig,ref} entries are only
    # used in _is_good-calls.
    # _is_good() from mne-0.9.git-py2.7.egg/mne/epochs.py seems to
    # ignore ref-channels (not covered by dict) and checks individual
    # data segments - artifacts across a buffer boundary are not found.
    reject = dict(grad=4000e-13, # T / m (gradiometers)
                  mag=4e-12,     # T (magnetometers)
                  eeg=40e-6,     # uV (EEG channels)
                  eog=250e-6)    # uV (EOG channels)

    infosig = copy.copy(raw.info)
    infosig['chs'] = [raw.info['chs'][k] for k in sigpick]
    infosig['ch_names'] = [raw.info['ch_names'][k] for k in sigpick]
    infosig['nchan'] = len(sigpick)
    idx_by_typesig = channel_indices_by_type(infosig)

    # inforef not good w/ filtering, but anyway useless
    inforef = copy.copy(raw.info)
    inforef['chs'] = [raw.info['chs'][k] for k in refpick]
    inforef['ch_names'] = [raw.info['ch_names'][k] for k in refpick]
    inforef['nchan'] = len(refpick)
    idx_by_typeref = channel_indices_by_type(inforef)

    # Read data in chunks:
    sigmean = 0
    refmean = 0
    sscovdata = 0
    srcovdata = 0
    rrcovdata = 0
    n_samples = 0
    for first in range(itmin, itmax, itstep):
        last = first + itstep
        if last >= itmax:
            last = itmax
        raw_segmentsig, times = raw[sigpick, first:last]
        if use_reffilter:
            raw_segmentref, times = fltref[:, first:last]
        else:
            raw_segmentref, times = raw[refpick, first:last]
        # if True:
        # if _is_good(raw_segmentsig, infosig['ch_names'], idx_by_typesig, reject, flat=None,
        #            ignore_chs=raw.info['bads']) and _is_good(raw_segmentref,
        #              inforef['ch_names'], idx_by_typeref, reject, flat=None,
        #                ignore_chs=raw.info['bads']):
        if not exclart or \
           _is_good(raw_segmentsig, infosig['ch_names'], idx_by_typesig, reject,
                    flat=None, ignore_chs=raw.info['bads']):
            sigmean += raw_segmentsig.sum(axis=1)
            refmean += raw_segmentref.sum(axis=1)
            sscovdata += (raw_segmentsig * raw_segmentsig).sum(axis=1)
            srcovdata += np.dot(raw_segmentsig, raw_segmentref.T)
            rrcovdata += np.dot(raw_segmentref, raw_segmentref.T)
            n_samples += raw_segmentsig.shape[1]
        else:
            logger.info("Artefact detected in [%d, %d]" % (first, last))

    #_check_n_samples(n_samples, len(picks))
    sigmean /= n_samples
    refmean /= n_samples
    sscovdata -= n_samples * sigmean[:] * sigmean[:]
    sscovdata /= (n_samples - 1)
    srcovdata -= n_samples * sigmean[:, None] * refmean[None, :]
    srcovdata /= (n_samples - 1)
    rrcovdata -= n_samples * refmean[:, None] * refmean[None, :]
    rrcovdata /= (n_samples - 1)
    sscovinit = sscovdata
    print "Normalize srcov..."
    rrslopedata = copy.copy(rrcovdata)
    for iref in xrange(nref):
        dtmp = rrcovdata[iref][iref]
        if dtmp > TINY:
            for isig in xrange(nsig):
                srcovdata[isig][iref] /= dtmp
            for jref in xrange(nref):
                rrslopedata[jref][iref] /= dtmp
        else:
            for isig in xrange(nsig):
                srcovdata[isig][iref] = 0.
            for jref in xrange(nref):
                rrslopedata[jref][iref] = 0.
    logger.info("Number of samples used : %d" % n_samples)
    tc1 = time.clock()
    tw1 = time.time()
    print "sigrefchn covar-calc took %.1f ms (%.2f s walltime)" % (1000. * (tc1 - tct), (tw1 - twt))

    print "########## Calculating sig-ref/ref-ref-channel covariances (robust):"
    # Calculate sig-ref/ref-ref-channel covariance:
    # (using B. P. Welford, "Note on a method for calculating corrected sums
    #                        of squares and products", Technometrics 4 (1962) 419-420)
    # (there is no need to calc inter-signal-chan cov,
    #  but there seems to be no appropriate function available)
    # Here we copy the idea from compute_raw_data_covariance()
    # and truncate it as appropriate.
    tct = time.clock()
    twt = time.time()
    # The following reject and info{sig,ref} entries are only
    # used in _is_good-calls.
    # _is_good() from mne-0.9.git-py2.7.egg/mne/epochs.py seems to
    # ignore ref-channels (not covered by dict) and checks individual
    # data segments - artifacts across a buffer boundary are not found.
    reject = dict(grad=4000e-13, # T / m (gradiometers)
                  mag=4e-12,     # T (magnetometers)
                  eeg=40e-6,     # uV (EEG channels)
                  eog=250e-6)    # uV (EOG channels)

    infosig = copy.copy(raw.info)
    infosig['chs'] = [raw.info['chs'][k] for k in sigpick]
    infosig['ch_names'] = [raw.info['ch_names'][k] for k in sigpick]
    infosig['nchan'] = len(sigpick)
    idx_by_typesig = channel_indices_by_type(infosig)

    # inforef not good w/ filtering, but anyway useless
    inforef = copy.copy(raw.info)
    inforef['chs'] = [raw.info['chs'][k] for k in refpick]
    inforef['ch_names'] = [raw.info['ch_names'][k] for k in refpick]
    inforef['nchan'] = len(refpick)
    idx_by_typeref = channel_indices_by_type(inforef)

    # Read data in chunks:
    smean = np.zeros(nsig)
    smold = np.zeros(nsig)
    rmean = np.zeros(nref)
    rmold = np.zeros(nref)
    sscov = 0
    srcov = 0
    rrcov = np.zeros((nref, nref))
    srcov = np.zeros((nsig, nref))
    n_samples = 0
    for first in range(itmin, itmax, itstep):
        last = first + itstep
        if last >= itmax:
            last = itmax
        raw_segmentsig, times = raw[sigpick, first:last]
        if use_reffilter:
            raw_segmentref, times = fltref[:, first:last]
        else:
            raw_segmentref, times = raw[refpick, first:last]
        # if True:
        # if _is_good(raw_segmentsig, infosig['ch_names'], idx_by_typesig, reject, flat=None,
        #            ignore_chs=raw.info['bads']) and _is_good(raw_segmentref,
        #              inforef['ch_names'], idx_by_typeref, reject, flat=None,
        #                ignore_chs=raw.info['bads']):
        if not exclart or \
           _is_good(raw_segmentsig, infosig['ch_names'], idx_by_typesig, reject,
                    flat=None, ignore_chs=raw.info['bads']):
            for isl in xrange(raw_segmentsig.shape[1]):
                nsl = isl + n_samples + 1
                cnslm1dnsl = float((nsl - 1)) / float(nsl)
                sslsubmean = (raw_segmentsig[:, isl] - smold)
                rslsubmean = (raw_segmentref[:, isl] - rmold)
                smean = smold + sslsubmean / nsl
                rmean = rmold + rslsubmean / nsl
                sscov += sslsubmean * (raw_segmentsig[:, isl] - smean)
                srcov += cnslm1dnsl * np.dot(sslsubmean.reshape((nsig, 1)), rslsubmean.reshape((1, nref)))
                rrcov += cnslm1dnsl * np.dot(rslsubmean.reshape((nref, 1)), rslsubmean.reshape((1, nref)))
                smold = smean
                rmold = rmean
            n_samples += raw_segmentsig.shape[1]
        else:
            logger.info("Artefact detected in [%d, %d]" % (first, last))

    #_check_n_samples(n_samples, len(picks))
    sscov /= (n_samples - 1)
    srcov /= (n_samples - 1)
    rrcov /= (n_samples - 1)
    print "Normalize srcov..."
    rrslope = copy.copy(rrcov)
    for iref in xrange(nref):
        dtmp = rrcov[iref][iref]
        if dtmp > TINY:
            srcov[:, iref] /= dtmp
            rrslope[:, iref] /= dtmp
        else:
            srcov[:, iref] = 0.
            rrslope[:, iref] = 0.
    logger.info("Number of samples used : %d" % n_samples)
    print "Compare results with 'standard' values:"
    print "cmp(sigmean,smean):", np.allclose(smean, sigmean, atol=0.)
    print "cmp(refmean,rmean):", np.allclose(rmean, refmean, atol=0.)
    print "cmp(sscovdata,sscov):", np.allclose(sscov, sscovdata, atol=0.)
    print "cmp(srcovdata,srcov):", np.allclose(srcov, srcovdata, atol=0.)
    print "cmp(rrcovdata,rrcov):", np.allclose(rrcov, rrcovdata, atol=0.)
    tc1 = time.clock()
    tw1 = time.time()
    print "sigrefchn covar-calc took %.1f ms (%.2f s walltime)" % (1000. * (tc1 - tct), (tw1 - twt))

    if checkresults:
        print "########## Calculated initial signal channel covariance:"
        # Calculate initial signal channel covariance:
        # (only used as quality measure)
        print "initl rt(avg sig pwr) = %12.5e" % np.sqrt(np.mean(sscov))
        for i in xrange(5):
            print "initl signal-rms[%3d] = %12.5e" % (i, np.sqrt(sscov.flatten()[i]))
        print " "
    if nref < 6:
        print "rrslope-entries:"
        for i in xrange(nref):
            print rrslope[i][:]

    U, s, V = np.linalg.svd(rrslope, full_matrices=True)
    print s

    print "Applying cutoff for smallest SVs:"
    dtmp = s.max() * SVD_RELCUTOFF
    sinv = np.zeros(nref)
    for i in xrange(nref):
        if abs(s[i]) >= dtmp:
            sinv[i] = 1. / s[i]
        else:
            s[i] = 0.
    # s *= (abs(s)>=dtmp)
    # sinv = ???
    print s
    stat = np.allclose(rrslope, np.dot(U, np.dot(np.diag(s), V)))
    print ">>> Testing svd-result: %s" % stat
    if not stat:
        print "    (Maybe due to SV-cutoff?)"

    # Solve for inverse coefficients:
    print ">>> Setting RRinvtr=U diag(sinv) V"
    RRinvtr = np.zeros((nref, nref))
    RRinvtr = np.dot(U, np.dot(np.diag(sinv), V))
    if checkresults:
        # print ">>> RRinvtr-result:"
        # print RRinvtr
        stat = np.allclose(np.identity(nref), np.dot(rrslope.transpose(), RRinvtr))
        if stat:
            print ">>> Testing RRinvtr-result (shld be unit-matrix): ok"
        else:
            print ">>> Testing RRinvtr-result (shld be unit-matrix): failed"
            print np.dot(rrslope.transpose(), RRinvtr)
            # np.less_equal(np.abs(np.dot(rrslope.transpose(),RRinvtr)-np.identity(nref)),0.01*np.ones((nref,nref)))
        print ""

    print "########## Calc weight matrix..."
    # weights-matrix will be somewhat larger than necessary,
    # (to simplify indexing in compensation loop):
    weights = np.zeros((raw._data.shape[0], nref))
    for isig in xrange(nsig):
        for iref in xrange(nref):
            weights[sigpick[isig]][iref] = np.dot(srcov[isig][:], RRinvtr[iref][:])

    if np.allclose(np.zeros(weights.shape), np.abs(weights), atol=1.e-8):
        print ">>> all weights are small (<=1.e-8)."
    else:
        print ">>> largest weight %12.5e" % np.max(np.abs(weights))
        wlrg = np.where(np.abs(weights) >= 0.99 * np.max(np.abs(weights)))
        for iwlrg in xrange(len(wlrg[0])):
            print ">>> weights[%3d,%2d] = %12.5e" % \
                  (wlrg[0][iwlrg], wlrg[1][iwlrg], weights[wlrg[0][iwlrg], wlrg[1][iwlrg]])

    if nref < 5:
        print "weights-entries for first sigchans:"
        for i in xrange(5):
            print 'weights[sp(%2d)][r]=[' % i + ' '.join([' %+10.7f' %
                             val for val in weights[sigpick[i]][:]]) + ']'

    print "########## Compensating signal channels:"
    tct = time.clock()
    twt = time.time()
    # data,times = raw[:,raw.time_as_index(tmin)[0]:raw.time_as_index(tmax)[0]:]
    # Work on entire data stream:
    for isl in xrange(raw._data.shape[1]):
        slice = np.take(raw._data, [isl], axis=1)
        if use_reffilter:
            refslice = np.take(fltref._data, [isl], axis=1)
            refarr = refslice[:].flatten() - rmean
            # refarr = fltres[:,isl]-rmean
        else:
            refarr = slice[refpick].flatten() - rmean
        subrefarr = np.dot(weights[:], refarr)
        # data[:,isl] -= subrefarr   will not modify raw._data?
        raw._data[:, isl] -= subrefarr
        if isl%10000 == 0:
            print "\rProcessed slice %6d" % isl
    print "\nDone."
    tc1 = time.clock()
    tw1 = time.time()
    print "compensation loop took %.1f ms (%.2f s walltime)" % (1000. * (tc1 - tct), (tw1 - twt))

    if checkresults:
        print "########## Calculating final signal channel covariance:"
        # Calculate final signal channel covariance:
        # (only used as quality measure)
        tct = time.clock()
        twt = time.time()
        sigmean = 0
        sscovdata = 0
        n_samples = 0
        for first in range(itmin, itmax, itstep):
            last = first + itstep
            if last >= itmax:
                last = itmax
            raw_segmentsig, times = raw[sigpick, first:last]
            # Artifacts found here will probably differ from pre-noisered artifacts!
            if not exclart or \
               _is_good(raw_segmentsig, infosig['ch_names'], idx_by_typesig, reject,
                        flat=None, ignore_chs=raw.info['bads']):
                sigmean += raw_segmentsig.sum(axis=1)
                sscovdata += (raw_segmentsig * raw_segmentsig).sum(axis=1)
                n_samples += raw_segmentsig.shape[1]
        sigmean /= n_samples
        sscovdata -= n_samples * sigmean[:] * sigmean[:]
        sscovdata /= (n_samples - 1)
        print ">>> no channel got worse: ", np.all(np.less_equal(sscovdata, sscovinit))
        print "final rt(avg sig pwr) = %12.5e" % np.sqrt(np.mean(sscovdata))
        for i in xrange(5):
            print "final signal-rms[%3d] = %12.5e" % (i, np.sqrt(sscovdata.flatten()[i]))
        tc1 = time.clock()
        tw1 = time.time()
        print "signal covar-calc took %.1f ms (%.2f s walltime)" % (1000. * (tc1 - tct), (tw1 - twt))
        print " "

    nrname = dname[:dname.rfind('-raw.fif')] + ',nold-raw.fif'
    print "Saving '%s'..." % nrname
    raw.save(nrname, overwrite=True)
    tc1 = time.clock()
    tw1 = time.time()
    print "Total run         took %.1f ms (%.2f s walltime)" % (1000. * (tc1 - tc0), (tw1 - tw0))
Ejemplo n.º 50
0
def log_elapsed(t, verbose=None):
    """Log elapsed time."""
    logger.info('Report complete in %s seconds' % round(t, 1))
Ejemplo n.º 51
0
def _phase_amplitude_coupling(data, sfreq, f_phase, f_amp, ixs,
                              pac_func='ozkurt', events=None,
                              tmin=None, tmax=None, n_cycles_ph=3,
                              n_cycles_am=3, scale_amp_func=None,
                              return_data=False, concat_epochs=False,
                              n_jobs=1, verbose=None):
    """ Compute phase-amplitude coupling using pacpy.

    Parameters
    ----------
    data : array, shape ([n_epochs], n_channels, n_times)
        The data used to calculate PAC
    sfreq : float
        The sampling frequency of the data.
    f_phase : array, dtype float, shape (n_bands_phase, 2,)
        The frequency ranges to use for the phase carrier. PAC will be
        calculated between n_bands_phase * n_bands_amp frequencies.
    f_amp : array, dtype float, shape (n_bands_amp, 2,)
        The frequency ranges to use for the phase-modulated amplitude.
        PAC will be calculated between n_bands_phase * n_bands_amp frequencies.
    ixs : array-like, shape (n_ch_pairs x 2)
        The indices for low/high frequency channels. PAC will be estimated
        between n_ch_pairs of channels. Indices correspond to rows of `data`.
    pac_func : {'plv', 'glm', 'mi_canolty', 'mi_tort', 'ozkurt'} |
               list of strings
        The function for estimating PAC. Corresponds to functions in
        `pacpy.pac`. Defaults to 'ozkurt'. If multiple frequency bands are used
        then `plv` cannot be calculated.
    events : array, shape (n_events, 3) | array, shape (n_events,) | None
        MNE events array. To be supplied if data is 2D and output should be
        split by events. In this case, `tmin` and `tmax` must be provided. If
        `ndim == 1`, it is assumed to be event indices, and all events will be
        grouped together.
    tmin : float | list of floats, shape (n_pac_windows,) | None
        If `events` is not provided, it is the start time to use in `inst`.
        If `events` is provided, it is the time (in seconds) to include before
        each event index. If a list of floats is given, then PAC is calculated
        for each pair of `tmin` and `tmax`. Defaults to `min(inst.times)`.
    tmax : float | list of floats, shape (n_pac_windows,) | None
        If `events` is not provided, it is the stop time to use in `inst`.
        If `events` is provided, it is the time (in seconds) to include after
        each event index. If a list of floats is given, then PAC is calculated
        for each pair of `tmin` and `tmax`. Defaults to `max(inst.times)`.
    n_cycles_ph : float, int | array of floats, shape (n_bands_phase,)
        The number of cycles to be included in the window for each band-pass
        filter for phase. Defaults to 3.
    n_cycles_am : float, int | array of floats, shape (n_bands_amp,)
        The number of cycles to be included in the window for each band-pass
        filter for amplitude. Defaults to 3.
    scale_amp_func : None | function
        If not None, will be called on each amplitude signal in order to scale
        the values. Function must accept an N-D input and will operate on the
        last dimension. E.g., `sklearn.preprocessing.scale`.
        Defaults to no scaling.
    return_data : bool
        If False, output will be `[pac_out]`. If True, output will be,
        `[pac_out, phase_signal, amp_signal]`.
    concat_epochs : bool
        If True, epochs will be concatenated before calculating PAC values. If
        epochs are relatively short, this is a good idea in order to improve
        stability of the PAC metric.
    n_jobs : int
        Number of jobs to run in parallel. Defaults to 1.
    verbose : bool, str, int, or None
        If not None, override default verbose level (see `mne.verbose`).

    Returns
    -------
    pac_out : array, list of arrays, dtype float,
              shape([n_pac_funcs], n_epochs, n_channel_pairs,
                    n_freq_pairs, n_pac_windows).
        The computed phase-amplitude coupling between each pair of data sources
        given in ixs. If multiple pac metrics are specified, there will be one
        array per metric in the output list. If n_pac_funcs is 1, then the
        first dimension will be dropped.
    [phase_signal] : array, shape (n_phase_signals, n_times,)
        Only returned if `return_data` is True. The phase timeseries of the
        phase signals (first column of `ixs`).
    [amp_signal] : array, shape (n_amp_signals, n_times,)
        Only returned if `return_data` is True. The amplitude timeseries of the
        amplitude signals (second column of `ixs`).
    """
    from ..externals.pacpy import pac as ppac
    pac_func = np.atleast_1d(pac_func)
    for i_func in pac_func:
        if i_func not in _pac_funcs:
            raise ValueError("PAC function %s is not supported" % i_func)
    n_pac_funcs = pac_func.shape[0]
    ixs = np.array(ixs, ndmin=2)
    n_ch_pairs = ixs.shape[0]
    tmin = 0 if tmin is None else tmin
    tmin = np.atleast_1d(tmin)
    n_pac_windows = len(tmin)
    tmax = (data.shape[-1] - 1) / float(sfreq) if tmax is None else tmax
    tmax = np.atleast_1d(tmax)
    f_phase = np.atleast_2d(f_phase)
    f_amp = np.atleast_2d(f_amp)
    n_cycles_ph = np.atleast_1d(n_cycles_ph)
    n_cycles_am = np.atleast_1d(n_cycles_am)
    if n_cycles_ph.shape[0] == 1:
        n_cycles_ph = np.repeat(n_cycles_ph, f_phase.shape[0])
    if n_cycles_am.shape[0] == 1:
        n_cycles_am = np.repeat(n_cycles_am, f_amp.shape[0])

    if data.ndim != 2:
        raise ValueError('Data must be shape (n_channels, n_times)')
    if ixs.shape[1] != 2:
        raise ValueError('Indices must have a 2nd dimension of length 2')
    if f_phase.shape[-1] != 2 or f_amp.shape[-1] != 2:
        raise ValueError('Frequencies must be specified w/ a low/hi tuple')
    if len(tmin) != len(tmax):
        raise ValueError('tmin and tmax have differing lengths')
    if any(i_f.shape[0] > 1 and 'plv' in pac_func for i_f in (f_amp, f_phase)):
        raise ValueError('If calculating PLV, must use a single pair of freqs')
    for icyc, i_f in zip([n_cycles_ph, n_cycles_am], [f_phase, f_amp]):
        if icyc.shape[0] != i_f.shape[0]:
            raise ValueError("n_cycles must match n_freq_bands")
        if icyc.ndim > 1:
            raise ValueError("n_cycles must be 1-d, not {}d".format(icyc.ndim))

    logger.info('Pre-filtering data and extracting phase/amplitude...')
    hi_phase = np.unique([i_func in _hi_phase_funcs for i_func in pac_func])
    if len(hi_phase) != 1:
        raise ValueError("Can't mix pac funcs that use both hi-freq phase/amp")
    hi_phase = bool(hi_phase[0])
    data_ph, data_am, ix_map_ph, ix_map_am = _pre_filter_ph_am(
        data, sfreq, ixs, f_phase, f_amp, hi_phase=hi_phase,
        scale_amp_func=scale_amp_func, n_cycles_ph=n_cycles_ph,
        n_cycles_am=n_cycles_am)

    # So we know how big the PAC output will be
    if events is None:
        n_epochs = 1
    elif concat_epochs is True:
        if events.ndim == 1:
            n_epochs = 1
        else:
            n_epochs = np.unique(events[:, -1]).shape[0]
    else:
        n_epochs = events.shape[0]

    # Iterate through each pair of frequencies
    ixs_freqs = product(range(data_ph.shape[1]), range(data_am.shape[1]))
    ixs_freqs = np.atleast_2d(list(ixs_freqs))

    freq_pac = np.array([[f_phase[ii], f_amp[jj]] for ii, jj in ixs_freqs])
    n_f_pairs = len(ixs_freqs)
    pac = np.zeros([n_pac_funcs, n_epochs, n_ch_pairs,
                    n_f_pairs, n_pac_windows])
    for i_f_pair, (ix_f_ph, ix_f_am) in enumerate(ixs_freqs):
        # Second dimension is frequency
        i_f_data_ph = data_ph[:, ix_f_ph, ...]
        i_f_data_am = data_am[:, ix_f_am, ...]

        # Redefine indices to match the new data arrays
        ixs_new = [(ix_map_ph[i], ix_map_am[j]) for i, j in ixs]
        i_f_data_ph = mne.io.RawArray(
            i_f_data_ph, mne.create_info(i_f_data_ph.shape[0], sfreq))
        i_f_data_am = mne.io.RawArray(
            i_f_data_am, mne.create_info(i_f_data_am.shape[0], sfreq))

        # Turn into Epochs if we have defined events
        if events is not None:
            i_f_data_ph = _raw_to_epochs_mne(i_f_data_ph, events, tmin, tmax)
            i_f_data_am = _raw_to_epochs_mne(i_f_data_am, events, tmin, tmax)

        # Data is either Raw or Epochs
        pbar = ProgressBar(n_epochs)
        for itime, (i_tmin, i_tmax) in enumerate(zip(tmin, tmax)):
            # Pull times of interest
            with warnings.catch_warnings():  # To suppress a deprecation
                warnings.simplefilter("ignore")
                # Not sure how to do this w/o copying
                i_t_data_am = i_f_data_am.copy().crop(i_tmin, i_tmax)
                i_t_data_ph = i_f_data_ph.copy().crop(i_tmin, i_tmax)

            if concat_epochs is True:
                # Iterate through each event type and hstack
                con_data_ph = []
                con_data_am = []
                for i_ev in i_t_data_am.event_id.keys():
                    con_data_ph.append(np.hstack(i_t_data_ph[i_ev]._data))
                    con_data_am.append(np.hstack(i_t_data_am[i_ev]._data))
                i_t_data_ph = np.vstack(con_data_ph)
                i_t_data_am = np.vstack(con_data_am)
            else:
                # Just pull all epochs separately
                i_t_data_ph = i_t_data_ph._data
                i_t_data_am = i_t_data_am._data
            # Now make sure that inputs to the loop are ep x chan x time
            if i_t_data_am.ndim == 2:
                i_t_data_ph = i_t_data_ph[np.newaxis, ...]
                i_t_data_am = i_t_data_am[np.newaxis, ...]
            # Loop through epochs (or epoch grps), each index pair, and funcs
            data_iter = zip(i_t_data_ph, i_t_data_am)
            for iep, (ep_ph, ep_am) in enumerate(data_iter):
                for iix, (i_ix_ph, i_ix_am) in enumerate(ixs_new):
                    for ix_func, i_pac_func in enumerate(pac_func):
                        func = getattr(ppac, i_pac_func)
                        pac[ix_func, iep, iix, i_f_pair, itime] = func(
                            ep_ph[i_ix_ph], ep_am[i_ix_am],
                            f_phase, f_amp, filterfn=False)
            pbar.update_with_increment_value(1)
    if pac.shape[0] == 1:
        pac = pac[0]
    if return_data:
        return pac, freq_pac, data_ph, data_am
    else:
        return pac, freq_pac
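# Illustration (added, not part of the original example): a simplified,
# self-contained sketch of the phase-amplitude coupling idea using the Hilbert
# transform and a Canolty-style mean vector length, instead of the pacpy-based
# pipeline above. A 6 Hz rhythm modulates the amplitude of an 80 Hz carrier,
# which yields a clearly non-zero coupling value.
import numpy as np
from scipy.signal import butter, hilbert, sosfiltfilt

sfreq = 500.0
t = np.arange(0, 20.0, 1.0 / sfreq)
rng = np.random.RandomState(0)
slow = np.sin(2 * np.pi * 6 * t)
sig = slow + (1 + slow) * 0.3 * np.sin(2 * np.pi * 80 * t) + 0.1 * rng.randn(t.size)

def _bandpass(x, lo, hi, sfreq):
    sos = butter(4, [lo / (sfreq / 2.0), hi / (sfreq / 2.0)],
                 btype='band', output='sos')
    return sosfiltfilt(sos, x)

phase = np.angle(hilbert(_bandpass(sig, 4.0, 8.0, sfreq)))    # low-freq phase
amp = np.abs(hilbert(_bandpass(sig, 70.0, 90.0, sfreq)))      # high-freq amplitude
mvl = np.abs(np.mean(amp * np.exp(1j * phase)))               # mean vector length
print('PAC (mean vector length): %.3f' % mvl)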
Ejemplo n.º 52
0
def _run(subjects_dir, subject, force, overwrite, no_decimate, verbose=None):
    this_env = copy.copy(os.environ)
    subjects_dir = get_subjects_dir(subjects_dir, raise_error=True)
    this_env['SUBJECTS_DIR'] = subjects_dir
    this_env['SUBJECT'] = subject
    if 'FREESURFER_HOME' not in this_env:
        raise RuntimeError('The FreeSurfer environment needs to be set up '
                           'for this script')
    incomplete = 'warn' if force else 'raise'
    subj_path = op.join(subjects_dir, subject)
    if not op.exists(subj_path):
        raise RuntimeError('%s does not exist. Please check your subject '
                           'directory path.' % subj_path)

    mri = 'T1.mgz' if op.exists(op.join(subj_path, 'mri', 'T1.mgz')) else 'T1'

    logger.info('1. Creating a dense scalp tessellation with mkheadsurf...')

    def check_seghead(surf_path=op.join(subj_path, 'surf')):
        surf = None
        for k in ['lh.seghead', 'lh.smseghead']:
            this_surf = op.join(surf_path, k)
            if op.exists(this_surf):
                surf = this_surf
                break
        return surf

    my_seghead = check_seghead()
    if my_seghead is None:
        run_subprocess(['mkheadsurf', '-subjid', subject, '-srcvol', mri],
                       env=this_env)

    surf = check_seghead()
    if surf is None:
        raise RuntimeError('mkheadsurf did not produce the standard output '
                           'file.')

    bem_dir = op.join(subjects_dir, subject, 'bem')
    if not op.isdir(bem_dir):
        os.mkdir(bem_dir)
    dense_fname = op.join(bem_dir, '%s-head-dense.fif' % subject)
    logger.info('2. Creating %s ...' % dense_fname)
    _check_file(dense_fname, overwrite)
    surf = mne.bem._surfaces_to_bem(
        [surf], [mne.io.constants.FIFF.FIFFV_BEM_SURF_ID_HEAD], [1],
        incomplete=incomplete)[0]
    mne.write_bem_surfaces(dense_fname, surf)
    levels = 'medium', 'sparse'
    tris = [] if no_decimate else [30000, 2500]
    if os.getenv('_MNE_TESTING_SCALP', 'false') == 'true':
        tris = [len(surf['tris'])]  # don't actually decimate
    for ii, (n_tri, level) in enumerate(zip(tris, levels), 3):
        logger.info('%i. Creating %s tessellation...' % (ii, level))
        logger.info('%i.1 Decimating the dense tessellation...' % ii)
        with ETSContext():
            points, tris = mne.decimate_surface(points=surf['rr'],
                                                triangles=surf['tris'],
                                                n_triangles=n_tri)
        dec_fname = dense_fname.replace('dense', level)
        logger.info('%i.2 Creating %s' % (ii, dec_fname))
        _check_file(dec_fname, overwrite)
        dec_surf = mne.bem._surfaces_to_bem(
            [dict(rr=points, tris=tris)],
            [mne.io.constants.FIFF.FIFFV_BEM_SURF_ID_HEAD], [1], rescale=False,
            incomplete=incomplete)
        mne.write_bem_surfaces(dec_fname, dec_surf)
Ejemplo n.º 53
0
def setup_provenance(script, results_dir, config=None, use_agg=True,
                     run_id=None):
    """Setup provenance tracking

    Parameters
    ----------
    script : str
        The script that was executed.
    results_dir : str
        The results directory.
    config : None | str
        The name of the config file. By default, the function expects the
        config to be under `__script__/` and named `config.py`. It can also
        be another kind of textfile, e.g. .json.
    use_agg : bool
        Whether to use the 'Agg' backend for matplotlib or not.

    Returns
    -------
    report : mne.report.Report
        The mne report.

    Side-effects
    ------------
    - makes the results dir if it does not exist
    - sets the log file for stderr output
    - writes a log file with runtime information
    """
    if use_agg is True:
        import matplotlib
        matplotlib.use('Agg')

    if not callable(script):
        if not op.isfile(script):
            raise ValueError('sorry, this is not a script!')
    if not op.isdir(results_dir) and not callable(script):
        results_dir = op.join(op.dirname(op.dirname(script)), results_dir)
    else:
        results_dir = op.join(op.curdir, results_dir)

    if not callable(script):
        step = op.splitext(op.split(script)[1])[0]
    else:
        step = script.__name__

    if not op.isabs(results_dir):
        results_dir = op.abspath(results_dir)

    start_path = op.dirname(results_dir)
    results_dir = op.join(results_dir, step)
    if not op.exists(results_dir):
        logger.info('generating results dir')
        _forec_create_dir(results_dir, start=start_path)

    if run_id is None:
        run_id = create_run_id()
        logger.info('generated run id: %s' % run_id)
    else:
        logger.info('using existing run id: %s' % run_id)

    logger.info('preparing logging:')
    logging_dir = op.join(results_dir, run_id)
    if not op.exists(logging_dir):
        logger.info('... making logging directory: %s' % logging_dir)
        os.mkdir(logging_dir)
    else:
        logger.info('... using logging directory: %s' % logging_dir)
    modules = get_versions(sys)
    runtime_log = op.join(logging_dir, 'run_time.json')
    with open(runtime_log, 'w') as fid:
        json.dump(modules, fid)
    logger.info('... writing runtime info to: %s' % runtime_log)

    script_code_out = op.join(logging_dir, 'script.py')
    if callable(script):
        script_code_in = inspect.getsourcefile(script)
    else:
        script_code_in = script

    if not op.isfile(script_code_out):
        with open(script_code_out, 'w') as fid:
            with open(script_code_in) as script_fid:
                source_code = script_fid.read()
            fid.write(source_code)
    logger.info('... logging source code of calling script')

    if config is None:
        config = 'config.py'

    if op.isabs(config):
        config_fname = config
    else:
        config_fname = ''

    config_code = op.join(  # weird behavior of join if last arg is path
        results_dir, run_id, op.split(config_fname)[-1])
    if not op.isfile(config_fname):
        logger.info('... No config found. Logging nothing.')
    elif op.isfile(config_code):
        logger.info('... Config already written. I assume that you are using'
                    ' the same run_id for different runs of your script.')
    else:
        with open(config_code, 'w') as fid:
            with open(config_fname) as config_fid:
                source_code = config_fid.read()
            fid.write(source_code)
        logger.info('... logging source code of "%s".' % config_fname)

    logger.info('... preparing Report')
    report = Report(title=step)
    report.data_path = logging_dir
    std_logfile = op.join(logging_dir, 'run_output.log')
    logger.info('... setting logfile: %s' % std_logfile)
    set_log_file(std_logfile)

    return report, run_id, results_dir, logger
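# Illustration (added, not part of the original example): hypothetical usage
# sketch at the top of an analysis script; 'results' is a made-up directory
# name.
# report, run_id, results_dir, logger = setup_provenance(
#     script=__file__, results_dir='results')
# ...  # analysis code that logs via `logger` and adds figures to `report`
# report.save(op.join(results_dir, run_id, 'report.html'))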
Ejemplo n.º 54
0
def printer(x):
    logger.info('exec')
    return x
def _run(subjects_dir, subject, force, overwrite, verbose=None):
    this_env = copy.copy(os.environ)
    this_env['SUBJECTS_DIR'] = subjects_dir
    this_env['SUBJECT'] = subject

    if 'SUBJECTS_DIR' not in this_env:
        raise RuntimeError('The environment variable SUBJECTS_DIR should '
                           'be set')

    if not op.isdir(subjects_dir):
        raise RuntimeError('subjects directory %s not found, specify using '
                           'the environment variable SUBJECTS_DIR or '
                           'the command line option --subjects-dir'
                           % subjects_dir)

    if 'MNE_ROOT' not in this_env:
        raise RuntimeError('MNE_ROOT environment variable is not set')

    if 'FREESURFER_HOME' not in this_env:
        raise RuntimeError('The FreeSurfer environment needs to be set up '
                           'for this script')
    force = '--force' if force else '--check'
    subj_path = op.join(subjects_dir, subject)
    if not op.exists(subj_path):
        raise RuntimeError('%s does not exist. Please check your subject '
                           'directory path.' % subj_path)

    if op.exists(op.join(subj_path, 'mri', 'T1.mgz')):
        mri = 'T1.mgz'
    else:
        mri = 'T1'

    logger.info('1. Creating a dense scalp tessellation with mkheadsurf...')

    def check_seghead(surf_path=op.join(subj_path, 'surf')):
        for k in ['/lh.seghead', '/lh.smseghead']:
            surf = surf_path + k if op.exists(surf_path + k) else None
            if surf is not None:
                break
        return surf

    my_seghead = check_seghead()
    if my_seghead is None:
        run_subprocess(['mkheadsurf', '-subjid', subject, '-srcvol', mri],
                       env=this_env)

    surf = check_seghead()
    if surf is None:
        raise RuntimeError('mkheadsurf did not produce the standard output '
                           'file.')

    dense_fname = '{0}/{1}/bem/{1}-head-dense.fif'.format(subjects_dir,
                                                          subject)
    logger.info('2. Creating %s ...' % dense_fname)
    _check_file(dense_fname, overwrite)
    run_subprocess(['mne_surf2bem', '--surf', surf, '--id', '4', force,
                    '--fif', dense_fname], env=this_env)
    levels = 'medium', 'sparse'
    my_surf = mne.read_bem_surfaces(dense_fname)[0]
    tris = [30000, 2500]
    if os.getenv('_MNE_TESTING_SCALP', 'false') == 'true':
        tris = [len(my_surf['tris'])]  # don't actually decimate
    for ii, (n_tri, level) in enumerate(zip(tris, levels), 3):
        logger.info('%i. Creating %s tessellation...' % (ii, level))
        logger.info('%i.1 Decimating the dense tessellation...' % ii)
        points, tris = mne.decimate_surface(points=my_surf['rr'],
                                            triangles=my_surf['tris'],
                                            n_triangles=n_tri)
        other_fname = dense_fname.replace('dense', level)
        logger.info('%i.2 Creating %s' % (ii, other_fname))
        _check_file(other_fname, overwrite)
        tempdir = _TempDir()
        surf_fname = tempdir + '/tmp-surf.surf'
        # convert points to meters, make mne_analyze happy
        mne.write_surface(surf_fname, points * 1e3, tris)
        # XXX for some reason --check does not work here.
        try:
            run_subprocess(['mne_surf2bem', '--surf', surf_fname, '--id', '4',
                            '--force', '--fif', other_fname], env=this_env)
        finally:
            del tempdir
def compute_ica(raw, subject, n_components=0.99, picks=None, decim=None,
                reject=None, ecg_tmin=-0.5, ecg_tmax=0.5, eog_tmin=-0.5,
                eog_tmax=0.5, n_max_ecg=3, n_max_eog=1,
                n_max_ecg_epochs=200, show=True, img_scale=1.0,
                random_state=None, report=None, artifact_stats=None):
    """Run ICA in raw data

    Parameters
    ----------
    raw : instance of Raw
        Raw measurements to be decomposed.
    subject : str
        The name of the subject.
    n_components : int | float | None | 'rank'
        The number of components used for ICA decomposition. If int, it must
        be smaller than max_pca_components. If None, all PCA components will
        be used. If float between 0 and 1, components will be selected by the
        cumulative percentage of explained variance.
        If 'rank', the number of components equals the rank estimate.
        Defaults to 0.99.
    picks : array-like of int, shape(n_channels,) | None
        Channels to be included. This selection remains throughout the
        initialized ICA solution. If None, only good data channels are used.
        Defaults to None.
    decim : int | None
        Increment for selecting each nth time slice. If None, all samples
        within ``start`` and ``stop`` are used. Defaults to None.
    reject : dict | None
        Rejection parameters based on peak to peak amplitude.
        Valid keys are 'grad' | 'mag' | 'eeg' | 'eog' | 'ecg'.
        If reject is None then no rejection is done. You should
        use such parameters to reject big measurement artifacts
        and not EOG for example. It only applies if `inst` is of type Raw.
        Defaults to {'mag': 5e-12}
    ecg_tmin : float
        Start time before ECG event. Defaults to -0.5.
    ecg_tmax : float
        End time after ECG event. Defaults to 0.5.
    eog_tmin : float
        Start time before EOG event. Defaults to -0.5.
    eog_tmax : float
        End time after EOG event. Defaults to 0.5.
    n_max_ecg : int | None
        The maximum number of ECG components to exclude. Defaults to 3.
    n_max_eog : int | None
        The maximum number of EOG components to exclude. Defaults to 1.
    n_max_ecg_epochs : int
        The maximum number of ECG epochs to use for phase-consistency
        estimation. Defaults to 200.
    show : bool
        Show figures if True. Defaults to True.
    img_scale : float
        The scaling factor for the report. Defaults to 1.0.
    random_state : None | int | instance of np.random.RandomState
        np.random.RandomState to initialize the FastICA estimation.
        As the estimation is non-deterministic it can be useful to
        fix the seed to have reproducible results. Defaults to None.
    report : instance of Report | None
        The report object. If None, a new report will be generated.
    artifact_stats : None | dict
        A dict that contains info on amplitude ranges of artifacts and
        numbers of events, etc. by channel type.

    Returns
    -------
    ica : instance of ICA
        The ICA solution.
    report : dict
        A dict with the report object ('html') and artifact statistics
        ('stats').
    """
    if report is None:
        report = Report(subject=subject, title='ICA preprocessing')
    if n_components == 'rank':
        n_components = raw.estimate_rank(picks=picks)
    ica = ICA(n_components=n_components, max_pca_components=None,
              random_state=random_state, max_iter=256)
    ica.fit(raw, picks=picks, decim=decim, reject=reject)

    comment = []
    for ch in ('mag', 'grad', 'eeg'):
        if ch in ica:
            comment += [ch.upper()]
    if len(comment) > 0:
        comment = '+'.join(comment) + ' '
    else:
        comment = ''

    topo_ch_type = 'mag'
    if 'GRAD' in comment and 'MAG' not in comment:
        topo_ch_type = 'grad'
    elif 'EEG' in comment:
        topo_ch_type = 'eeg'

    ###########################################################################
    # 2) identify bad components by analyzing latent sources.

    title = '%s related to %s artifacts (red) ({})'.format(subject)

    # generate ECG epochs use detection via phase statistics
    reject_ = {'mag': 5e-12, 'grad': 5000e-13, 'eeg': 300e-6}
    if reject is not None:
        reject_.update(reject)
    for ch_type in ['mag', 'grad', 'eeg']:
        if ch_type not in ica:
            reject_.pop(ch_type)

    picks_ = np.array([raw.ch_names.index(k) for k in ica.ch_names])
    if 'eeg' in ica:
        if 'ecg' in raw:
            picks_ = np.append(picks_,
                               pick_types(raw.info, meg=False, ecg=True)[0])
        else:
            logger.info('There is no ECG channel, trying to guess ECG from '
                        'magnetometers')

    if artifact_stats is None:
        artifact_stats = dict()

    ecg_epochs = create_ecg_epochs(raw, tmin=ecg_tmin, tmax=ecg_tmax,
                                   keep_ecg=True, picks=picks_, reject=reject_)

    n_ecg_epochs_found = len(ecg_epochs.events)
    artifact_stats['ecg_n_events'] = n_ecg_epochs_found
    n_max_ecg_epochs = min(n_max_ecg_epochs, n_ecg_epochs_found)
    artifact_stats['ecg_n_used'] = n_max_ecg_epochs

    sel_ecg_epochs = np.arange(n_ecg_epochs_found)
    rng = np.random.RandomState(42)
    rng.shuffle(sel_ecg_epochs)
    ecg_ave = ecg_epochs.average()

    report.add_figs_to_section(ecg_ave.plot(), 'ECG-full', 'artifacts')
    ecg_epochs = ecg_epochs[sel_ecg_epochs[:n_max_ecg_epochs]]
    ecg_ave = ecg_epochs.average()
    report.add_figs_to_section(ecg_ave.plot(), 'ECG-used', 'artifacts')

    _put_artifact_range(artifact_stats, ecg_ave, kind='ecg')

    ecg_inds, scores = ica.find_bads_ecg(ecg_epochs, method='ctps')
    if len(ecg_inds) > 0:
        ecg_evoked = ecg_epochs.average()
        del ecg_epochs

        fig = ica.plot_scores(scores, exclude=ecg_inds, labels='ecg',
                              title='', show=show)

        report.add_figs_to_section(fig, 'scores ({})'.format(subject),
                                   section=comment + 'ECG',
                                   scale=img_scale)

        current_exclude = [e for e in ica.exclude]  # issue #2608 MNE
        fig = ica.plot_sources(raw, ecg_inds, exclude=ecg_inds,
                               title=title % ('sources', 'ecg'), show=show)

        report.add_figs_to_section(fig, 'sources ({})'.format(subject),
                                   section=comment + 'ECG',
                                   scale=img_scale)
        ica.exclude = current_exclude

        fig = ica.plot_components(ecg_inds, ch_type=topo_ch_type,
                                  title='', colorbar=True, show=show)
        report.add_figs_to_section(fig, title % ('components', 'ecg'),
                                   section=comment + 'ECG', scale=img_scale)
        ica.exclude = current_exclude

        ecg_inds = ecg_inds[:n_max_ecg]
        ica.exclude += ecg_inds
        fig = ica.plot_sources(ecg_evoked, exclude=ecg_inds, show=show)
        report.add_figs_to_section(fig, 'evoked sources ({})'.format(subject),
                                   section=comment + 'ECG',
                                   scale=img_scale)

        fig = ica.plot_overlay(ecg_evoked, exclude=ecg_inds, show=show)
        report.add_figs_to_section(fig,
                                   'rejection overlay ({})'.format(subject),
                                   section=comment + 'ECG',
                                   scale=img_scale)

    # detect EOG by correlation
    picks_eog = np.concatenate(
        [picks_, pick_types(raw.info, meg=False, eeg=False, ecg=False,
                            eog=True)])

    eog_epochs = create_eog_epochs(raw, tmin=eog_tmin, tmax=eog_tmax,
                                   picks=picks_eog, reject=reject_)
    artifact_stats['eog_n_events'] = len(eog_epochs.events)
    artifact_stats['eog_n_used'] = artifact_stats['eog_n_events']
    eog_ave = eog_epochs.average()
    report.add_figs_to_section(eog_ave.plot(), 'EOG-used', 'artifacts')
    _put_artifact_range(artifact_stats, eog_ave, kind='eog')

    eog_inds = None
    if len(eog_epochs.events) > 0:
        eog_inds, scores = ica.find_bads_eog(eog_epochs)

    if eog_inds is not None and len(eog_epochs.events) > 0:
        fig = ica.plot_scores(scores, exclude=eog_inds, labels='eog',
                              show=show, title='')
        report.add_figs_to_section(fig, 'scores ({})'.format(subject),
                                   section=comment + 'EOG',
                                   scale=img_scale)

        current_exclude = [e for e in ica.exclude]  # issue #2608 MNE
        fig = ica.plot_sources(raw, eog_inds, exclude=eog_inds,
                               title=title % ('sources', 'eog'), show=show)
        report.add_figs_to_section(fig, 'sources', section=comment + 'EOG',
                                   scale=img_scale)
        ica.exclude = current_exclude

        fig = ica.plot_components(eog_inds, ch_type=topo_ch_type,
                                  title='', colorbar=True, show=show)
        report.add_figs_to_section(fig, title % ('components', 'eog'),
                                   section=comment + 'EOG', scale=img_scale)
        ica.exclude = current_exclude

        eog_inds = eog_inds[:n_max_eog]
        ica.exclude += eog_inds

        eog_evoked = eog_epochs.average()
        fig = ica.plot_sources(eog_evoked, exclude=eog_inds, show=show)
        report.add_figs_to_section(
            fig, 'evoked sources ({})'.format(subject),
            section=comment + 'EOG', scale=img_scale)

        fig = ica.plot_overlay(eog_evoked, exclude=eog_inds, show=show)
        report.add_figs_to_section(
            fig, 'rejection overlay ({})'.format(subject),
            section=comment + 'EOG', scale=img_scale)
    else:
        del eog_epochs

    # check the amplitudes do not change
    if len(ica.exclude) > 0:
        fig = ica.plot_overlay(raw, show=show)  # EOG artifacts remain
        html = _render_components_table(ica)
        report.add_htmls_to_section(
            html, captions='excluded components',
            section='ICA rejection summary (%s)' % ch_type)
        report.add_figs_to_section(
            fig, 'rejection overlay ({})'.format(subject),
            section=comment + 'RAW', scale=img_scale)
    return ica, dict(html=report, stats=artifact_stats)
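A hedged usage sketch for compute_ica as defined above, assuming a preloaded Raw file; the file names, subject label, and rejection threshold are placeholders:

import mne
from mne import pick_types

raw = mne.io.read_raw_fif('sub-01_task-rest_meg.fif', preload=True)
raw.filter(1., 40.)  # ICA generally benefits from high-pass filtering

picks = pick_types(raw.info, meg=True, eeg=True, eog=False, ecg=False)
ica, out = compute_ica(raw, subject='sub-01', n_components=0.99,
                       picks=picks, decim=3, reject={'mag': 5e-12},
                       random_state=42, show=False)

# out['html'] holds the mne.Report instance built during the run
out['html'].save('sub-01_ica_report.html', overwrite=True,
                 open_browser=False)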
Ejemplo n.º 57
0
def _run(subjects_dir, subject, raw_dir, force, mp, volume):
    this_env = copy.copy(os.environ)
    this_env['SUBJECTS_DIR'] = subjects_dir
    this_env['SUBJECT'] = subject
    parrec_dir = op.join(subjects_dir, raw_dir, subject)

    if 'SUBJECTS_DIR' not in this_env:
        raise RuntimeError('The environment variable SUBJECTS_DIR should '
                           'be set')

    if not op.isdir(subjects_dir):
        raise RuntimeError('subjects directory %s not found, specify using '
                           'the environment variable SUBJECTS_DIR or '
                           'the command line option --subjects-dir'
                           % subjects_dir)
    
    if not op.isdir(parrec_dir):
        raise RuntimeError('%s directory not found, specify using '
                           'the command line option --raw-dir' % parrec_dir)

    if 'FREESURFER_HOME' not in this_env:
        raise RuntimeError('The FreeSurfer environment needs to be set up '
                           'for this script')

    if op.isdir(op.join(subjects_dir, subject)) and not force:
        raise RuntimeError('%s FreeSurfer directory exists. '
                           'Use command line option --force to overwrite '
                           'previous reconstruction results.' % subject)
    if force and op.isdir(op.join(subjects_dir, subject)):
        shutil.rmtree(op.join(subjects_dir, subject))

    os.mkdir(op.join(subjects_dir, subject))
    os.makedirs(op.join(subjects_dir, subject, 'mri/orig/'))
    os.mkdir(op.join(subjects_dir, subject, 'mri/nii'))
    fs_nii_dir = op.join(subjects_dir, subject, 'mri/nii')

    logger.info('1. Processing raw MRI data...')
    for root, _, filenames in os.walk(parrec_dir):
        for filename in fnmatch.filter(filenames, '*Quiet_Survey*'):
            os.remove(op.join(root, filename))
    parfiles = []
    for root, dirnames, filenames in os.walk(parrec_dir):
        for filename in fnmatch.filter(filenames, '*.PAR'):
            parfiles.append(op.join(root, filename))
    parfiles.sort()
    for pf in parfiles:
        if (volume in pf) or ('FLASH' in pf):
            print('Converting {0}'.format(pf))
            pimg = nibabel.load(pf)
            pr_hdr = pimg.header
            raw_data = pimg.dataobj.get_unscaled()
            affine = pr_hdr.get_affine(origin='fov')
            nimg = nibabel.Nifti1Image(raw_data, affine, pr_hdr)
            nimg.to_filename(op.join(parrec_dir,
                                     op.basename(pf)[:-4] + '.nii'))
            shutil.copy(nimg.get_filename(), fs_nii_dir)

    for ff in glob.glob(op.join(fs_nii_dir, '*.nii')):
        if volume in op.basename(ff):
            os.symlink(ff, op.join(fs_nii_dir, 'MPRAGE.nii'))
        elif 'FLASH5' in op.basename(ff):
            os.symlink(ff, op.join(fs_nii_dir, 'FLASH5.nii'))
        elif 'FLASH30' in op.basename(ff):
            os.symlink(ff, op.join(fs_nii_dir, 'FLASH30.nii'))

    logger.info('2. Starting FreeSurfer reconstruction process...')
    mri = op.join(fs_nii_dir, 'MPRAGE.nii')
    run_subprocess(['mri_concat', '--rms', '--i', mri,
                    '--o', op.join(subjects_dir, subject, 'mri/orig/001.mgz')],
                   env=this_env)
    run_subprocess(['recon-all', '-openmp', mp, '-subject', subject, '-all'],
                   env=this_env)
    for morph_to in ['fsaverage', subject]:
        run_subprocess(['mne_make_morph_maps', '--to', morph_to,
                        '--from', subject], env=this_env)
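The PAR/REC-to-NIfTI step above can be exercised on a single file. A minimal sketch under the same assumptions as the loop (nibabel with PAR/REC support); the file names are placeholders:

import nibabel

pimg = nibabel.load('sub-01_MPRAGE.PAR')  # hypothetical PAR/REC pair
# keep the unscaled data and the field-of-view affine, as in the loop above
raw_data = pimg.dataobj.get_unscaled()
affine = pimg.header.get_affine(origin='fov')
nibabel.Nifti1Image(raw_data, affine, pimg.header).to_filename(
    'sub-01_MPRAGE.nii')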
Ejemplo n.º 58
0
def _make_forward_solutions(info, mri, src, bem, bem_eog, dev_head_ts, mindist,
                            chpi_rrs, eog_rrs, ecg_rrs, n_jobs):
    """Calculate a forward solution for a subject

    Parameters
    ----------
    info : instance of mne.io.meas_info.Info | str
        If str, then it should be a filename to a Raw, Epochs, or Evoked
        file with measurement information. Otherwise it should be an Info
        instance (such as one from Raw, Epochs, or Evoked).
    mri : dict | str
        Either a transformation filename (usually made using mne_analyze)
        or an info dict (usually opened using read_trans()).
        If string, an ending of `.fif` or `.fif.gz` will be assumed to
        be in FIF format, any other ending will be assumed to be a text
        file with a 4x4 transformation matrix (like the `--trans` MNE-C
        option).
    src : str | instance of SourceSpaces
        If string, should be a source space filename. Can also be an
        instance of loaded or generated SourceSpaces.
    bem : str
        Filename of the BEM (e.g., "sample-5120-5120-5120-bem-sol.fif") to
        use.
    bem_eog : dict
        Spherical BEM to use for EOG (and ECG) simulation.
    dev_head_ts : list
        List of device<->head transforms.
    mindist : float
        Minimum distance of sources from inner skull surface (in mm).
    chpi_rrs : ndarray
        CHPI dipoles to simulate (magnetic dipoles).
    eog_rrs : ndarray
        EOG dipoles to simulate.
    ecg_rrs : ndarray
        ECG dipoles to simulate.
    n_jobs : int
        Number of jobs to run in parallel.

    Returns
    -------
    fwd : generator
        A generator for each forward solution in dev_head_ts.

    Notes
    -----
    Some of the forward solution calculation options from the C code
    (e.g., `--grad`, `--fixed`) are not implemented here. For those,
    consider using the C command line tools or the Python wrapper
    `do_forward_solution`.
    """
    mri_head_t, mri = _get_mri_head_t(mri)
    assert mri_head_t['from'] == FIFF.FIFFV_COORD_MRI

    if not isinstance(src, string_types):
        if not isinstance(src, SourceSpaces):
            raise TypeError('src must be a string or SourceSpaces')
    else:
        if not op.isfile(src):
            raise IOError('Source space file "%s" not found' % src)
    if isinstance(bem, dict):
        bem_extra = 'dict'
    else:
        bem_extra = bem
        if not op.isfile(bem):
            raise IOError('BEM file "%s" not found' % bem)
    if not isinstance(info, (dict, string_types)):
        raise TypeError('info should be a dict or string')
    if isinstance(info, string_types):
        info = read_info(info, verbose=False)

    # set default forward solution coordinate frame to HEAD
    # this could, in principle, be an option
    coord_frame = FIFF.FIFFV_COORD_HEAD

    # Report the setup
    logger.info('Setting up forward solutions')

    # Read the source locations
    if isinstance(src, string_types):
        src = read_source_spaces(src, verbose=False)
    else:
        # let's make a copy in case we modify something
        src = src.copy()
    nsource = sum(s['nuse'] for s in src)
    if nsource == 0:
        raise RuntimeError('No sources are active in these source spaces. '
                           '"do_all" option should be used.')
    logger.info('Read %d source spaces with a total of %d active source '
                'locations' % (len(src), nsource))

    # make a new dict with the relevant information
    mri_id = dict(machid=np.zeros(2, np.int32), version=0, secs=0, usecs=0)
    info = dict(nchan=info['nchan'], chs=info['chs'], comps=info['comps'],
                ch_names=info['ch_names'],
                mri_file='', mri_id=mri_id, meas_file='',
                meas_id=None, working_dir=os.getcwd(),
                command_line='', bads=info['bads'])

    # Only get the EEG channels here b/c we can do MEG later
    _, _, eegels, _, eegnames, _ = \
        _prep_channels(info, False, True, True, verbose=False)

    # Transform the source spaces into the appropriate coordinates
    # (will either be HEAD or MRI)
    for s in src:
        transform_surface_to(s, coord_frame, mri_head_t)

    # Prepare the BEM model
    bem = _setup_bem(bem, bem_extra, len(eegnames), mri_head_t, verbose=False)

    # Circumvent numerical problems by excluding points too close to the skull
    if not bem['is_sphere']:
        inner_skull = _bem_find_surface(bem, 'inner_skull')
        _filter_source_spaces(inner_skull, mindist, mri_head_t, src, n_jobs,
                              verbose=False)

    # Time to do the heavy lifting: EEG first, then MEG
    rr = np.concatenate([s['rr'][s['vertno']] for s in src])
    eegfwd = _compute_forwards(rr, bem, [eegels], [None],
                               [None], ['eeg'], n_jobs, verbose=False)[0]
    eegfwd = _to_forward_dict(eegfwd, None, eegnames, coord_frame,
                              FIFF.FIFFV_MNE_FREE_ORI)
    eegeog = _compute_forwards(eog_rrs, bem_eog, [eegels], [None],
                               [None], ['eeg'], n_jobs, verbose=False)[0]
    eegeog = _to_forward_dict(eegeog, None, eegnames, coord_frame,
                              FIFF.FIFFV_MNE_FREE_ORI)

    for ti, dev_head_t in enumerate(dev_head_ts):
        # could be *slightly* more efficient not to do this N times,
        # but the cost here is tiny compared to actual fwd calculation
        logger.info('Computing gain matrix for transform #%s/%s'
                    % (ti + 1, len(dev_head_ts)))
        info = deepcopy(info)
        info['dev_head_t'] = dev_head_t
        megcoils, compcoils, _, megnames, _, meg_info = \
            _prep_channels(info, True, False, False, verbose=False)

        # make sure our sensors are all outside our BEM
        coil_rr = np.array([coil['r0'] for coil in megcoils])
        if not bem['is_sphere']:
            idx = np.where(np.array([s['id'] for s in bem['surfs']]) ==
                           FIFF.FIFFV_BEM_SURF_ID_BRAIN)[0]
            assert len(idx) == 1
            bem_surf = transform_surface_to(bem['surfs'][idx[0]], coord_frame,
                                            mri_head_t)
            outside = _points_outside_surface(coil_rr, bem_surf, n_jobs,
                                              verbose=False)
        else:
            rad = bem['layers'][-1]['rad']
            outside = np.sqrt(np.sum((coil_rr - bem['r0']) ** 2,
                                     axis=1)) >= rad
        if not np.all(outside):
            raise RuntimeError('MEG sensors collided with inner skull '
                               'surface for transform %s' % ti)

        # compute forward
        megfwd = _compute_forwards(rr, bem, [megcoils], [compcoils],
                                   [meg_info], ['meg'], n_jobs,
                                   verbose=False)[0]
        megfwd = _to_forward_dict(megfwd, None, megnames, coord_frame,
                                  FIFF.FIFFV_MNE_FREE_ORI)
        fwd = _merge_meg_eeg_fwds(megfwd, eegfwd, verbose=False)

        # pick out final dict info
        nsource = fwd['sol']['data'].shape[1] // 3
        source_nn = np.tile(np.eye(3), (nsource, 1))
        fwd.update(dict(nchan=fwd['sol']['data'].shape[0], nsource=nsource,
                        info=info, src=src, source_nn=source_nn,
                        source_rr=rr, surf_ori=False, mri_head_t=mri_head_t))
        fwd['info']['mri_head_t'] = mri_head_t
        fwd['info']['dev_head_t'] = dev_head_t

        megeog = _compute_forwards(eog_rrs, bem_eog, [megcoils], [compcoils],
                                   [meg_info], ['meg'], n_jobs,
                                   verbose=False)[0]
        megeog = _to_forward_dict(megeog, None, megnames, coord_frame,
                                  FIFF.FIFFV_MNE_FREE_ORI)
        fwd_eog = _merge_meg_eeg_fwds(megeog, eegeog, verbose=False)
        megecg = _compute_forwards(ecg_rrs, bem_eog, [megcoils], [compcoils],
                                   [meg_info], ['meg'], n_jobs,
                                   verbose=False)[0]
        fwd_ecg = _to_forward_dict(megecg, None, megnames, coord_frame,
                                   FIFF.FIFFV_MNE_FREE_ORI)
        fwd_chpi = _magnetic_dipole_field_vec(chpi_rrs, megcoils).T
        yield fwd, fwd_eog, fwd_ecg, fwd_chpi
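The bookkeeping at the end of the loop assumes a free-orientation gain matrix, i.e. three columns per source with identity orientations stacked in source_nn. A small self-contained numpy sketch of that layout (toy sizes only):

import numpy as np

n_chan, n_source = 306, 10               # toy sizes
gain = np.zeros((n_chan, 3 * n_source))  # free-orientation gain matrix

# one unit vector per orientation, stacked per source, as in source_nn above
source_nn = np.tile(np.eye(3), (n_source, 1))
assert source_nn.shape == (3 * n_source, 3)
assert gain.shape[1] // 3 == n_source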
Ejemplo n.º 59
0
def noise_reducer(fname_raw, raw=None, signals=[], noiseref=[], detrending=None,
                  tmin=None, tmax=None, reflp=None, refhp=None, refnotch=None,
                  exclude_artifacts=True, checkresults=True, return_raw=False,
                  complementary_signal=False, fnout=None, verbose=False):

    """Apply noise reduction to signal channels using reference channels.

    Parameters
    ----------
    fname_raw : (list of) rawfile names
    raw : mne Raw objects
        Allows passing of raw object as well.
    signals : list of string
              List of channels to compensate using noiseref.
              If empty use the meg signal channels.
    noiseref : list of string | str
              List of channels to use as noise reference.
              If empty use the magnetic reference channsls (default).
    signals and noiseref may contain regexp, which are resolved
    using mne.pick_channels_regexp(). All other channels are copied.
    tmin : lower latency bound for weight-calc [start of trace]
    tmax : upper latency bound for weight-calc [ end  of trace]
           Weights are calc'd for (tmin,tmax), but applied to entire data set
    refhp : high-pass frequency for reference signal filter [None]
    reflp :  low-pass frequency for reference signal filter [None]
            reflp < refhp: band-stop filter
            reflp > refhp: band-pass filter
            reflp is not None, refhp is None: low-pass filter
            reflp is None, refhp is not None: high-pass filter
    refnotch : (base) notch frequency for reference signal filter [None]
               use raw(ref)-notched(ref) as reference signal
    exclude_artifacts: filter signal-channels thru _is_good() [True]
                       (parameters are at present hard-coded!)
    return_raw : bool
        If return_raw is true, the raw object is returned and raw file
        is not written to disk. It is suggested that this option be used in cases
        where the noise_reducer is applied multiple times. [False]
    complementary_signal : replaced signal by traces that would be subtracted [False]
                           (can be useful for debugging)
    detrending: boolean to ctrl subtraction of linear trend from all magn. chans [False]
    checkresults : boolean to control internal checks and overall success [True]

    Outputfile
    ----------
    <wawa>,nr-raw.fif for input <wawa>-raw.fif

    Returns
    -------
    If return_raw is True, then mne.io.Raw instance is returned.

    Bugs
    ----
    - artifact checking is incomplete (and with arb. window of tstep=0.2s)
    - no accounting of channels used as signal/reference
    - non existing input file handled ungracefully
    """

    if type(complementary_signal) != bool:
        raise ValueError("Argument complementary_signal must be of type bool")

    # handle error if Raw object passed with file list
    if raw and isinstance(fname_raw, list):
        raise ValueError('List of file names cannot be combined with one Raw object')

    # handle error if return_raw is requested with file list
    if return_raw and isinstance(fname_raw, list):
        raise ValueError('List of file names cannot be combined return_raw.'
                         'Please pass one file at a time.')

    # handle error if Raw object is passed with detrending option
    #TODO include perform_detrending for Raw objects
    if raw and detrending:
        raise ValueError('Please perform detrending on the raw file directly. Cannot perform'
                         'detrending on the raw object')

    fnraw = get_files_from_list(fname_raw)

    # loop across all filenames
    for fname in fnraw:

        if verbose:
            print "########## Read raw data:"

        tc0 = time.clock()
        tw0 = time.time()

        if raw is None:
            if detrending:
                raw = perform_detrending(fname, save=False)
            else:
                raw = mne.io.Raw(fname, preload=True)
        else:
            # perform sanity check to make sure Raw object and file are same
            if os.path.basename(fname) != os.path.basename(raw.info['filename']):
                warnings.warn('The file name within the Raw object and provided'
                              'fname are not the same. Please check again.')

        tc1 = time.clock()
        tw1 = time.time()

        if verbose:
            print ">>> loading raw data took %.1f ms (%.2f s walltime)" % (1000. * (tc1 - tc0), (tw1 - tw0))

        # Time window selection
        # weights are calc'd based on [tmin,tmax], but applied to the entire data set.
        # tstep is used in artifact detection
        # tmin,tmax variables must not be changed here!
        if tmin is None:
            itmin = 0
        else:
            itmin = int(floor(tmin * raw.info['sfreq']))
        if tmax is None:
            itmax = raw.last_samp
        else:
            itmax = int(ceil(tmax * raw.info['sfreq']))

        if itmax - itmin < 2:
            raise ValueError("Time-window for noise compensation empty or too short")

        if verbose:
            print ">>> Set time-range to [%7.3f,%7.3f]" % \
                  (raw.times[itmin], raw.times[itmax])

        if signals is None or len(signals) == 0:
            sigpick = mne.pick_types(raw.info, meg='mag', eeg=False, stim=False,
                                     eog=False, exclude='bads')
        else:
            sigpick = channel_indices_from_list(raw.info['ch_names'][:], signals,
                                                raw.info.get('bads'))
        nsig = len(sigpick)
        if nsig == 0:
            raise ValueError("No channel selected for noise compensation")

        if noiseref is None or len(noiseref) == 0:
            # References are not limited to 4D ref-chans, but can be anything,
            # incl. ECG or powerline monitor.
            if verbose:
                print ">>> Using all refchans."
            refexclude = "bads"
            refpick = mne.pick_types(raw.info, ref_meg=True, meg=False, eeg=False,
                                     stim=False, eog=False, exclude='bads')
        else:
            refpick = channel_indices_from_list(raw.info['ch_names'][:], noiseref,
                                                raw.info.get('bads'))
        nref = len(refpick)
        if nref == 0:
            raise ValueError("No channel selected as noise reference")

        if verbose:
            print ">>> sigpick: %3d chans, refpick: %3d chans" % (nsig, nref)

        if reflp is None and refhp is None and refnotch is None:
            use_reffilter = False
            use_refantinotch = False
        else:
            use_reffilter = True
            if verbose:
                print "########## Filter reference channels:"

            use_refantinotch = False
            if refnotch is not None:
                if reflp is None and refhp is None:
                    use_refantinotch = True
                    freqlast = np.min([5.01 * refnotch, 0.5 * raw.info['sfreq']])
                    if verbose:
                        print ">>> notches at freq %.1f and harmonics below %.1f" % (refnotch, freqlast)
                else:
                    raise ValueError("Cannot specify notch- and high-/low-pass"
                                     "reference filter together")
            else:
                if verbose:
                    if reflp is not None:
                        print ">>>  low-pass with cutoff-freq %.1f" % reflp
                    if refhp is not None:
                        print ">>> high-pass with cutoff-freq %.1f" % refhp

            # Adapt the following drop-channels command to use 'all-but-refpick'
            droplist = [raw.info['ch_names'][k] for k in xrange(raw.info['nchan'])
                        if k not in refpick]
            tct = time.clock()
            twt = time.time()
            fltref = raw.copy().drop_channels(droplist)
            if use_refantinotch:
                rawref = raw.copy().drop_channels(droplist)
                freqlast = np.min([5.01 * refnotch, 0.5 * raw.info['sfreq']])
                fltref.notch_filter(np.arange(refnotch, freqlast, refnotch),
                                    picks=np.array(xrange(nref)), method='iir')
                fltref._data = (rawref._data - fltref._data)
            else:
                fltref.filter(refhp, reflp, picks=np.array(xrange(nref)), method='iir')
            tc1 = time.clock()
            tw1 = time.time()
            if verbose:
                print ">>> filtering ref-chans  took %.1f ms (%.2f s walltime)" % (1000. * (tc1 - tct), (tw1 - twt))

        if verbose:
            print "########## Calculating sig-ref/ref-ref-channel covariances:"
        # Calculate sig-ref/ref-ref-channel covariance:
        # (there is no need to calc inter-signal-chan cov,
        #  but there seems to be no appropriate fct available)
        # Here we copy the idea from compute_raw_data_covariance()
        # and truncate it as appropriate.
        tct = time.clock()
        twt = time.time()
        # The following reject and infosig entries are only
        # used in _is_good-calls.
        # _is_good() from mne-0.9.git-py2.7.egg/mne/epochs.py seems to
        # ignore ref-channels (not covered by dict) and checks individual
        # data segments - artifacts across a buffer boundary are not found.
        reject = dict(grad=4000e-13,  # T / m (gradiometers)
                      mag=4e-12,      # T (magnetometers)
                      eeg=40e-6,      # V (EEG channels)
                      eog=250e-6)     # V (EOG channels)

        infosig = copy.copy(raw.info)
        infosig['chs'] = [raw.info['chs'][k] for k in sigpick]
        infosig['ch_names'] = [raw.info['ch_names'][k] for k in sigpick]
        infosig['nchan'] = len(sigpick)
        idx_by_typesig = channel_indices_by_type(infosig)

        # Read data in chunks:
        tstep = 0.2
        itstep = int(ceil(tstep * raw.info['sfreq']))
        sigmean = 0
        refmean = 0
        sscovdata = 0
        srcovdata = 0
        rrcovdata = 0
        n_samples = 0

        for first in range(itmin, itmax, itstep):
            last = first + itstep
            if last >= itmax:
                last = itmax
            raw_segmentsig, times = raw[sigpick, first:last]
            if use_reffilter:
                raw_segmentref, times = fltref[:, first:last]
            else:
                raw_segmentref, times = raw[refpick, first:last]

            if not exclude_artifacts or \
               _is_good(raw_segmentsig, infosig['ch_names'], idx_by_typesig, reject, flat=None,
                        ignore_chs=raw.info['bads']):
                sigmean += raw_segmentsig.sum(axis=1)
                refmean += raw_segmentref.sum(axis=1)
                sscovdata += (raw_segmentsig * raw_segmentsig).sum(axis=1)
                srcovdata += np.dot(raw_segmentsig, raw_segmentref.T)
                rrcovdata += np.dot(raw_segmentref, raw_segmentref.T)
                n_samples += raw_segmentsig.shape[1]
            else:
                logger.info("Artefact detected in [%d, %d]" % (first, last))
        if n_samples <= 1:
            raise ValueError('Too few samples to calculate weights')
        sigmean /= n_samples
        refmean /= n_samples
        sscovdata -= n_samples * sigmean[:] * sigmean[:]
        sscovdata /= (n_samples - 1)
        srcovdata -= n_samples * sigmean[:, None] * refmean[None, :]
        srcovdata /= (n_samples - 1)
        rrcovdata -= n_samples * refmean[:, None] * refmean[None, :]
        rrcovdata /= (n_samples - 1)
        sscovinit = np.copy(sscovdata)
        if verbose:
            print ">>> Normalize srcov..."

        rrslope = copy.copy(rrcovdata)
        for iref in xrange(nref):
            dtmp = rrcovdata[iref, iref]
            if dtmp > TINY:
                srcovdata[:, iref] /= dtmp
                rrslope[:, iref] /= dtmp
            else:
                srcovdata[:, iref] = 0.
                rrslope[:, iref] = 0.

        if verbose:
            print ">>> Number of samples used : %d" % n_samples
            tc1 = time.clock()
            tw1 = time.time()
            print ">>> sigrefchn covar-calc took %.1f ms (%.2f s walltime)" % (1000. * (tc1 - tct), (tw1 - twt))

        if checkresults:
            if verbose:
                print "########## Calculated initial signal channel covariance:"
                # Calculate initial signal channel covariance:
                # (only used as quality measure)
                print ">>> initl rt(avg sig pwr) = %12.5e" % np.sqrt(np.mean(sscovdata))
                for i in xrange(5):
                    print ">>> initl signal-rms[%3d] = %12.5e" % (i, np.sqrt(sscovdata.flatten()[i]))
                print ">>>"

        U, s, V = np.linalg.svd(rrslope, full_matrices=True)
        if verbose:
            print ">>> singular values:"
            print s
            print ">>> Applying cutoff for smallest SVs:"

        dtmp = s.max() * SVD_RELCUTOFF
        s *= (abs(s) >= dtmp)
        sinv = [1. / s[k] if s[k] != 0. else 0. for k in xrange(nref)]
        if verbose:
            print ">>> singular values (after cutoff):"
            print s

        stat = np.allclose(rrslope, np.dot(U, np.dot(np.diag(s), V)))
        if verbose:
            print ">>> Testing svd-result: %s" % stat
            if not stat:
                print "    (Maybe due to SV-cutoff?)"

        # Solve for inverse coefficients:
        # Set RRinv.tr=U diag(sinv) V
        RRinv = np.transpose(np.dot(U, np.dot(np.diag(sinv), V)))
        if checkresults:
            stat = np.allclose(np.identity(nref), np.dot(RRinv, rrslope))
            if stat:
                if verbose:
                    print ">>> Testing RRinv-result (should be unit-matrix): ok"
            else:
                print ">>> Testing RRinv-result (should be unit-matrix): failed"
                print np.transpose(np.dot(RRinv, rrslope))
                print ">>>"

        if verbose:
            print "########## Calc weight matrix..."

        # weights-matrix will be somewhat larger than necessary,
        # (to simplify indexing in compensation loop):
        weights = np.zeros((raw._data.shape[0], nref))
        for isig in xrange(nsig):
            for iref in xrange(nref):
                weights[sigpick[isig], iref] = np.dot(srcovdata[isig, :],
                                                      RRinv[:, iref])

        if verbose:
            print "########## Compensating signal channels:"
            if complementary_signal:
                print ">>> Caveat: REPLACING signal by compensation signal"

        tct = time.clock()
        twt = time.time()

        # Work on entire data stream:
        for isl in xrange(raw._data.shape[1]):
            slice = np.take(raw._data, [isl], axis=1)
            if use_reffilter:
                refslice = np.take(fltref._data, [isl], axis=1)
                refarr = refslice[:].flatten() - refmean
                # refarr = fltres[:,isl]-refmean
            else:
                refarr = slice[refpick].flatten() - refmean
            subrefarr = np.dot(weights[:], refarr)

            if not complementary_signal:
                raw._data[:, isl] -= subrefarr
            else:
                raw._data[:, isl] = subrefarr

            if (isl % 10000 == 0) and verbose:
                print "\rProcessed slice %6d" % isl

        if verbose:
            print "\nDone."
            tc1 = time.clock()
            tw1 = time.time()
            print ">>> compensation loop took %.1f ms (%.2f s walltime)" % (1000. * (tc1 - tct), (tw1 - twt))

        if checkresults:
            if verbose:
                print "########## Calculating final signal channel covariance:"
            # Calculate final signal channel covariance:
            # (only used as quality measure)
            tct = time.clock()
            twt = time.time()
            sigmean = 0
            sscovdata = 0
            n_samples = 0
            for first in range(itmin, itmax, itstep):
                last = first + itstep
                if last >= itmax:
                    last = itmax
                raw_segmentsig, times = raw[sigpick, first:last]
                # Artifacts found here will probably differ from pre-noisered artifacts!
                if not exclude_artifacts or \
                   _is_good(raw_segmentsig, infosig['ch_names'], idx_by_typesig, reject,
                            flat=None, ignore_chs=raw.info['bads']):
                    sigmean += raw_segmentsig.sum(axis=1)
                    sscovdata += (raw_segmentsig * raw_segmentsig).sum(axis=1)
                    n_samples += raw_segmentsig.shape[1]
            sigmean /= n_samples
            sscovdata -= n_samples * sigmean[:] * sigmean[:]
            sscovdata /= (n_samples - 1)
            if verbose:
                print ">>> no channel got worse: ", np.all(np.less_equal(sscovdata, sscovinit))
                print ">>> final rt(avg sig pwr) = %12.5e" % np.sqrt(np.mean(sscovdata))
                for i in xrange(5):
                    print ">>> final signal-rms[%3d] = %12.5e" % (i, np.sqrt(sscovdata.flatten()[i]))
                tc1 = time.clock()
                tw1 = time.time()
                print ">>> signal covar-calc took %.1f ms (%.2f s walltime)" % (1000. * (tc1 - tct), (tw1 - twt))
                print ">>>"

        if fnout is not None:
            fnoutloc = fnout
        else:
            fnoutloc = fname[:fname.rfind('-raw.fif')] + ',nr-raw.fif'

        if return_raw:
            return raw
        else:
            if verbose:
                print ">>> Saving '%s'..." % fnoutloc
            raw.save(fnoutloc, overwrite=True)

        tc1 = time.clock()
        tw1 = time.time()
        if verbose:
            print ">>> Total run took %.1f ms (%.2f s walltime)" % (1000. * (tc1 - tc0), (tw1 - tw0))
Ejemplo n.º 60
0
def combine_meeg(raw_fname, eeg_fname, flow=0.6, fhigh=200,
                 filter_order=2, njobs=-1):
    '''
    Combine MEG data with EEG data. This is done by:
        1. Adjusting the MEG and EEG data lengths.
        2. Resampling the EEG data channels to match the sampling
           frequency of the MEG signals.
        3. Writing the EEG channels into the MEG fif file and saving it
           to disk.

    Parameters
    ----------
    raw_fname: FIF file containing MEG data.
    eeg_fname: FIF file containing EEG data.
    flow, fhigh: Low and high frequency limits for filtering.
                 (default 0.6-200 Hz)
    filter_order: Order of the Butterworth filter used for filtering.
    njobs: Number of jobs.

    Warning: Please make sure that the filter settings provided
             are stable for both MEG and EEG data.
    Only channels ECG 001, EOG 001, EOG 002 and STI 014 are written.
    '''

    import numpy as np
    import mne
    from mne.utils import logger

    if not raw_fname.endswith('-meg.fif') and \
            not eeg_fname.endswith('-eeg.fif'):
        logger.warning('File names are not standard. '
                       'Please use standard file name extensions.')

    raw = mne.io.Raw(raw_fname, preload=True)
    eeg = mne.io.Raw(eeg_fname, preload=True)

    # Filter both signals
    filter_type = 'butter'
    logger.info('The MEG and EEG signals will be filtered from %s to %s' \
                % (flow, fhigh))
    picks_fil = mne.pick_types(raw.info, meg=True, eog=True, \
                               ecg=True, exclude='bads')
    raw.filter(flow, fhigh, picks=picks_fil, n_jobs=njobs, method='iir', \
               iir_params={'ftype': filter_type, 'order': filter_order})
    picks_fil = mne.pick_types(eeg.info, meg=False, eeg=True, exclude='bads')
    eeg.filter(flow, fhigh, picks=picks_fil, n_jobs=njobs, method='iir', \
               iir_params={'ftype': filter_type, 'order': filter_order})

    # Find sync pulse S128 in stim channel of EEG signal.
    start_idx_eeg = mne.find_events(eeg, stim_channel='STI 014', \
                                    output='onset')[0, 0]

    # Find sync pulse S128 in stim channel of MEG signal.
    start_idx_raw = mne.find_events(raw, stim_channel='STI 014', \
                                    output='onset')[0, 0]

    # Start times for both eeg and meg channels
    start_time_eeg = eeg.times[start_idx_eeg]
    start_time_raw = raw.times[start_idx_raw]

    # Stop times for both eeg and meg channels
    stop_time_eeg = eeg.times[eeg.last_samp]
    stop_time_raw = raw.times[raw.last_samp]

    # Choose the signal with the shorter duration (usually MEG)
    eeg_duration = stop_time_eeg - start_time_eeg
    meg_duration = stop_time_raw - start_time_raw
    diff_time = min(meg_duration, eeg_duration)

    # Reset both the channel times based on shortest duration
    end_time_eeg = diff_time + start_time_eeg
    end_time_raw = diff_time + start_time_raw

    # Calculate the index of the last time points
    stop_idx_eeg = eeg.time_as_index(round(end_time_eeg, 3))[0]
    stop_idx_raw = raw.time_as_index(round(end_time_raw, 3))[0]

    events = mne.find_events(eeg, stim_channel='STI 014', output='onset',
                             consecutive=True)
    events = events[np.where(events[:, 0] < stop_idx_eeg)[0], :]
    events = events[np.where(events[:, 0] > start_idx_eeg)[0], :]
    events[:, 0] -= start_idx_eeg

    eeg_data, eeg_times = eeg[:, start_idx_eeg:stop_idx_eeg]
    _, raw_times = raw[:, start_idx_raw:stop_idx_raw]

    # Resample eeg signal
    resamp_list = jumeg_resample(raw.info['sfreq'], eeg.info['sfreq'], \
                                 raw_times.shape[0], events=events)

    # Update eeg signal
    eeg._data, eeg._times = eeg_data[:, resamp_list], eeg_times[resamp_list]

    # Update meg signal
    raw._data, raw._times = raw[:, start_idx_raw:stop_idx_raw]
    raw._first_samps[0] = 0
    raw._last_samps[0] = raw._data.shape[1] - 1

    # Identify the ECG, EOG and STI channels in raw and replace them
    # with the relevant EEG data.
    logger.info('Only ECG 001, EOG 001, EOG 002 and STI 014 will be updated.')
    raw._data[raw.ch_names.index('ECG 001')] = eeg._data[0]
    raw._data[raw.ch_names.index('EOG 001')] = eeg._data[1]
    raw._data[raw.ch_names.index('EOG 002')] = eeg._data[2]
    raw._data[raw.ch_names.index('STI 014')] = eeg._data[3]

    # Write the combined FIF file to disk.
    raw.save(raw_fname.split('-')[0] + '-raw.fif', overwrite=True)
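A hedged usage sketch for combine_meeg as defined above; the file names are placeholders that follow the '-meg.fif' / '-eeg.fif' convention checked at the top of the function:

# Both recordings must share the sync pulse on 'STI 014'. The combined file
# is written next to the input (here it would end up as 'subject01-raw.fif').
combine_meeg('subject01-meg.fif', 'subject01-eeg.fif',
             flow=0.6, fhigh=200, filter_order=2, njobs=4)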