Example #1
def clean_by_interp(inst):
    """Clean epochs/evoked by LOOCV
    """
    inst_interp = inst.copy()
    mesg = 'Creating augmented epochs'
    pbar = ProgressBar(len(inst.info['ch_names']) - 1, mesg=mesg,
                       spinner=True)
    for ch_idx, ch in enumerate(inst.info['ch_names']):
        pbar.update(ch_idx + 1)
        if isinstance(inst, mne.Evoked):
            ch_orig = inst.data[ch_idx].copy()
        elif isinstance(inst, mne.Epochs):
            ch_orig = inst._data[:, ch_idx].copy()

        inst.info['bads'] = [ch]
        interpolate_bads(inst, reset_bads=True, mode='fast')

        if isinstance(inst, mne.Evoked):
            inst_interp.data[ch_idx] = inst.data[ch_idx]
            inst.data[ch_idx] = ch_orig
        elif isinstance(inst, mne.Epochs):
            inst_interp._data[:, ch_idx] = inst._data[:, ch_idx]
            inst._data[:, ch_idx] = ch_orig

    return inst_interp
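The same leave-one-out idea can be expressed with the public MNE API alone. A minimal sketch, assuming an mne.Evoked with channel positions available; interp_error is a hypothetical helper, not part of the code above:

import numpy as np

def interp_error(evoked, ch_name):
    """Return the RMSE between a channel and its interpolated version."""
    ch_idx = evoked.ch_names.index(ch_name)
    original = evoked.data[ch_idx].copy()
    ev = evoked.copy()
    ev.info['bads'] = [ch_name]           # mark the channel bad ...
    ev.interpolate_bads(reset_bads=True)  # ... and rebuild it from neighbours
    return np.sqrt(np.mean((ev.data[ch_idx] - original) ** 2))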
Example #2
    def _node_compute_mi(self, dataset, n_bins, n_perm, n_jobs, random_state):
        """Compute mi and permuted mi.

        Permutations are performed by randomizing the regressor variable. For
        the fixed effect, this randomization is performed across subjects. For
        the random effect, the randomization is performed per subject.
        """
        # get the function for computing mi
        mi_fun = get_core_mi_fun(self._mi_method)[self._mi_type]
        assert f"mi_{self._mi_method}_ephy_{self._mi_type}" == mi_fun.__name__
        # get x, y, z and subject names per roi
        if dataset._mi_type != self._mi_type:
            raise TypeError(f"Your dataset doesn't allow computing the mi "
                            f"{self._mi_type}. Allowed mi is "
                            f"{dataset._mi_type}")
        x, y, z, suj = dataset.x, dataset.y, dataset.z, dataset.suj_roi
        n_roi, inf = dataset.n_roi, self._inference
        # evaluate true mi
        logger.info(f"    Evaluate true and permuted mi (n_perm={n_perm}, "
                    f"n_jobs={n_jobs})")
        # parallel function for computing permutations
        parallel, p_fun = parallel_func(mi_fun, n_jobs=n_jobs, verbose=False)
        pbar = ProgressBar(range(n_roi), mesg='Estimating MI')
        # evaluate permuted mi
        with parallel as para:
            mi, mi_p = [], []
            for r in range(n_roi):
                # compute the true mi
                mi += [mi_fun(x[r], y[r], z[r], suj[r], inf, n_bins=n_bins)]

                # get the randomized version of y
                y_p = permute_mi_vector(y[r],
                                        suj[r],
                                        mi_type=self._mi_type,
                                        inference=self._inference,
                                        n_perm=n_perm)
                # run permutations using the randomized regressor
                _mi = para(
                    p_fun(x[r], y_p[p], z[r], suj[r], inf, n_bins=n_bins)
                    for p in range(n_perm))
                mi_p += [np.asarray(_mi)]
                pbar.update_with_increment_value(1)
        # smoothing
        if isinstance(self._kernel, np.ndarray):
            logger.info("    Apply smoothing to the true and permuted MI")
            for r in range(len(mi)):
                for s in range(mi[r].shape[0]):
                    mi[r][s, :] = np.convolve(mi[r][s, :],
                                              self._kernel,
                                              mode='same')
                    for p in range(mi_p[r].shape[0]):
                        mi_p[r][p, s, :] = np.convolve(mi_p[r][p, s, :],
                                                       self._kernel,
                                                       mode='same')

        self._mi, self._mi_p = mi, mi_p

        return mi, mi_p
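The docstring above distinguishes fixed-effect from random-effect randomization. A sketch of the two schemes with a hypothetical helper (the actual permute_mi_vector may differ in its details):

import numpy as np

def permute_regressor(y, suj, inference='rfx', seed=0):
    """Shuffle y across all samples (ffx) or within each subject (rfx)."""
    rng = np.random.default_rng(seed)
    y_p = y.copy()
    if inference == 'ffx':
        rng.shuffle(y_p)                      # randomize across subjects
    else:
        for s in np.unique(suj):
            idx = np.where(suj == s)[0]
            y_p[idx] = y_p[rng.permutation(idx)]  # randomize per subject
    return y_p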
Example #3
 def permutations():
     """Generator for the permutations with optional progress bar."""
     if verbose:
         progress = ProgressBar(len(sign_flips),
                                mesg='Performing permutations')
         for i, sign_flip in enumerate(sign_flips):
             progress.update(i)
             yield sign_flip
     else:
         for sign_flip in sign_flips:
             yield sign_flip
Example #4
    def _node_compute_mi(self, dataset, n_perm, n_jobs, random_state):
        """Compute mi and permuted mi.

        Permutations are performed by randomizing the target roi. For the fixed
        effect, this randomization is performed across subjects. For the random
        effect, the randomization is performed per subject.
        """
        # get the function for computing mi
        core_fun = self.estimator.get_function()
        # get the pairs for computing mi
        df_conn, _ = dataset.get_connectivity_pairs(
            directed=False, as_blocks=True)
        sources, targets = df_conn['sources'], df_conn['targets']
        self._pair_names = np.concatenate(df_conn['names'])
        n_pairs = len(self._pair_names)
        # parallel function for computing permutations
        parallel, p_fun = parallel_func(comod, n_jobs=n_jobs, verbose=False)
        pbar = ProgressBar(range(n_pairs), mesg='Estimating MI')
        # evaluate true mi
        mi, mi_p, inf = [], [], self._inference
        kw_get = dict(mi_type=self._mi_type, copnorm=self._copnorm,
                      gcrn_per_suj=self._gcrn)
        for n_s, s in enumerate(sources):
            # get source data
            da_s = dataset.get_roi_data(s, **kw_get)
            suj_s = da_s['subject'].data
            for t in targets[n_s]:
                # get target data
                da_t = dataset.get_roi_data(t, **kw_get)
                suj_t = da_t['subject'].data

                # compute mi
                _mi = comod(da_s.data, da_t.data, suj_s, suj_t, inf,
                            core_fun)
                # get the randomized version of y
                y_p = permute_mi_trials(suj_t, inference=inf, n_perm=n_perm)
                # run permutations using the randomized regressor
                _mi_p = parallel(p_fun(
                    da_s.data, da_t.data[..., y_p[p]], suj_s, suj_t, inf,
                    core_fun) for p in range(n_perm))
                _mi_p = np.asarray(_mi_p)

                # kernel smoothing
                if isinstance(self._kernel, np.ndarray):
                    _mi = kernel_smoothing(_mi, self._kernel, axis=-1)
                    _mi_p = kernel_smoothing(_mi_p, self._kernel, axis=-1)

                mi += [_mi]
                mi_p += [_mi_p]
                pbar.update_with_increment_value(1)

        self._mi, self._mi_p = mi, mi_p

        return mi, mi_p
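kernel_smoothing itself is not shown in these examples. A plausible stand-in, assuming it convolves every 1-D slice along the given axis with the kernel (consistent with the np.convolve loop in Example #2):

import numpy as np

def kernel_smoothing(x, kernel, axis=-1):
    """Convolve each 1-D slice of x along `axis` with `kernel`."""
    return np.apply_along_axis(
        lambda v: np.convolve(v, kernel, mode='same'), axis, x)

kernel = np.hanning(7)
kernel /= kernel.sum()
smoothed = kernel_smoothing(np.random.randn(10, 4, 100), kernel)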
Example #5
 def permutations():
     """Generator for the permutations with optional progress bar."""
     if verbose:
         progress = ProgressBar(len(Beh_perms),
                                mesg='Performing permutations')
         for i, Beh_perm in enumerate(Beh_perms):
             progress.update(i)
             yield Beh_perm
     else:
         for Beh_perm in Beh_perms:
             yield Beh_perm
Example #6
def test_progressbar_parallel_more(capsys):
    """Test ProgressBar with parallel computing, advanced version."""
    assert capsys.readouterr().out == ''
    # This must be "1" because "capsys" won't get stdout properly otherwise
    parallel, p_fun, _ = parallel_func(_identity_block_wide,
                                       n_jobs=1,
                                       verbose=False)
    arr = np.arange(10)
    with use_log_level(True):
        with ProgressBar(len(arr) * 2) as pb:
            out = parallel(
                p_fun(x, pb.subset(pb_idx))
                for pb_idx, x in array_split_idx(arr, 2, n_per_split=2))
            idxs = np.concatenate([o[1] for o in out])
            assert_array_equal(idxs, np.arange(len(arr) * 2))
            out = np.concatenate([o[0] for o in out])
            assert op.isfile(pb._mmap_fname)
            sum_ = np.memmap(pb._mmap_fname,
                             dtype='bool',
                             mode='r',
                             shape=len(arr) * 2).sum()
            assert sum_ == len(arr) * 2
    assert not op.isfile(pb._mmap_fname), '__exit__ not called?'
    cap = capsys.readouterr()
    out = cap.err
    assert '100%' in out
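The pattern under test: split the work, hand each job a ProgressBar subset backed by a shared memory map, and let the subsets report into one bar. A sketch with n_jobs=1; array_split_idx and pb.subset are MNE internals, so import paths and signatures may vary across versions:

import numpy as np
from mne.parallel import parallel_func
from mne.utils import ProgressBar, array_split_idx

def _double_block(x, pb):
    out = np.empty_like(x)
    for ii, xi in enumerate(x):
        out[ii] = 2 * xi
        pb.update(ii + 1)  # progress reported through the shared memmap
    return out

arr = np.arange(10)
parallel, p_fun, _ = parallel_func(_double_block, n_jobs=1, verbose=False)
with ProgressBar(len(arr)) as pb:
    out = parallel(p_fun(x, pb.subset(idx))
                   for idx, x in array_split_idx(arr, 2))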
Example #7
def _pbar(iterable, desc, leave=True, position=None, verbose='progressbar'):

    if verbose is not False and \
            verbose not in ['progressbar', 'tqdm', 'tqdm_notebook']:
        raise ValueError('verbose must be one of {progressbar, '
                         'tqdm, tqdm_notebook, False}. Got %s' % verbose)

    if verbose == 'progressbar':
        from mne.utils import ProgressBar
        pbar = ProgressBar(iterable, mesg=desc, spinner=True)
        print('')
    elif verbose == 'tqdm':
        from tqdm import tqdm
        pbar = tqdm(iterable,
                    desc=desc,
                    leave=leave,
                    position=position,
                    dynamic_ncols=True)
    elif verbose == 'tqdm_notebook':
        from tqdm import tqdm_notebook
        pbar = tqdm_notebook(iterable,
                             desc=desc,
                             leave=leave,
                             position=position,
                             dynamic_ncols=True)
    elif verbose is False:
        pbar = iterable
    return pbar
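A possible usage of _pbar, relying only on the definition above:

import time

for _ in _pbar(range(5), desc='Demo', verbose='progressbar'):
    time.sleep(0.1)  # stand-in for real per-item work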
Example #8
def _pbar(iterable, desc, leave=True, position=None, verbose='progressbar'):

    verbose = False if verbose == 0 else verbose
    if verbose is not False and \
            verbose not in ['progressbar', 'tqdm', 'tqdm_notebook']:
        raise ValueError('verbose must be one of {progressbar, '
                         'tqdm, tqdm_notebook, False}. Got %s' % verbose)

    # XXX: prefer tqdm when available; remove this after a few releases of
    # MNE since tqdm is natively supported by the MNE progressbar. Only
    # upgrade the default so that verbose=False and 'tqdm_notebook' are
    # honored.
    if verbose == 'progressbar':
        try:
            from tqdm import tqdm
            verbose = 'tqdm'
        except ImportError:
            pass

    if verbose == 'progressbar':
        from mne.utils import ProgressBar
        pbar = ProgressBar(iterable, mesg=desc)
    elif verbose == 'tqdm':
        from tqdm import tqdm
        pbar = tqdm(iterable, desc=desc, leave=leave, position=position,
                    dynamic_ncols=True)
    elif verbose == 'tqdm_notebook':
        from tqdm import tqdm_notebook
        pbar = tqdm_notebook(iterable, desc=desc, leave=leave,
                             position=position, dynamic_ncols=True)
    elif verbose is False:
        pbar = iterable
    return pbar
Example #9
def test_progressbar():
    a = np.arange(10)
    pbar = ProgressBar(a)
    assert_equal(a, pbar.iterable)
    assert_equal(10, pbar.max_value)

    pbar = ProgressBar(10)
    assert_equal(10, pbar.max_value)
    assert_true(pbar.iterable is None)

    # Make sure that non-iterable input raises an error
    def iter_func(a):
        for ii in a:
            pass

    assert_raises(ValueError, iter_func, ProgressBar(20))
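Besides wrapping an iterable, the bar can be driven manually, combining the context-manager and update() patterns seen elsewhere in this collection:

from mne.utils import ProgressBar

with ProgressBar(100, mesg='Processing') as pbar:
    for ii in range(100):
        pbar.update(ii + 1)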
Example #10
File: rsa.py Project: Fosca/umne
def _compute_dissimilarity(data1, data2, metric, debug=False):

    if metric == 'spearmanr':
        metric = _spearmanr

    assert data1.shape[1] == data2.shape[1], \
        "Expecting the same number of channels"
    assert data1.shape[2] == data2.shape[2], \
        "Expecting the same number of time points"

    n_timepoints = data1.shape[2]

    if debug:
        print(
            'Computing dissimilarity (method={}): computing a {}*{} dissimilarity matrix using correlations for each of {} time points...'
            .format(metric, data1.shape[0], data2.shape[0], n_timepoints))

    pb = ProgressBar(n_timepoints, mesg="Computing dissimilarity")

    def run_per_timepoint(t):
        d = pairwise_distances(data1[:, :, t],
                               data2[:, :, t],
                               metric=metric,
                               n_jobs=multiprocessing.cpu_count())
        pb.update(t + 1)
        return d

    dissim_matrix_per_timepoint = np.asarray(
        [run_per_timepoint(t) for t in range(n_timepoints)])

    return dissim_matrix_per_timepoint
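_spearmanr is not shown above. A plausible definition, assuming the dissimilarity is one minus the Spearman correlation of two feature vectors (sklearn's pairwise_distances accepts such a callable metric):

import numpy as np
from scipy.stats import spearmanr

def _spearmanr(u, v):
    rho, _ = spearmanr(u, v)
    return 1.0 - rho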
Example #11
def test_progressbar():
    """Test progressbar class."""
    a = np.arange(10)
    pbar = ProgressBar(a)
    assert a is pbar.iterable
    assert pbar.max_value == 10

    pbar = ProgressBar(10)
    assert pbar.max_value == 10
    assert pbar.iterable is None

    # Make sure that non-iterable input raises an error
    def iter_func(a):
        for ii in a:
            pass
    pytest.raises(ValueError, iter_func, ProgressBar(20))
Example #12
 def _upload_chunk(self, client, f_client, f_server):
     from StringIO import StringIO
     file_obj = open(f_client, 'rb')
     target_length = os.path.getsize(f_client)
     chunk_size = 10 * 1024 * 1024
     offset = 0
     uploader = client.get_chunked_uploader(file_obj, target_length)
     last_block = None
     params = dict()
     pbar = ProgressBar(target_length, spinner=True)
     error_count = 0
     while offset < target_length:
         if error_count > 3:
             raise RuntimeError
         pbar.update(offset)
         next_chunk_size = min(chunk_size, target_length - offset)
         # read data if last chunk passed
         if last_block is None:
             last_block = file_obj.read(next_chunk_size)
         # set parameters
         if offset > 0:
             params = dict(upload_id=uploader.upload_id, offset=offset)
         try:
             url, ignored_params, headers = client.request(
                 "/chunked_upload", params, method='PUT',
                 content_server=True)
             reply = client.rest_client.PUT(url, StringIO(last_block),
                                            headers)
             new_offset = reply['offset']
             uploader.upload_id = reply['upload_id']
             # avoid reading data if last chunk didn't pass
             if new_offset > offset:
                 offset = new_offset
                 last_block = None
                 error_count = 0
             else:
                 error_count += 1
         except Exception:
             error_count += 1
     if target_length > 0:
         pbar.update(target_length)
     print('')
     file_obj.close()
     uploader.finish(f_server, overwrite=True)
Example #13
    def _interpolate_bad_epochs(self, epochs, ch_type):
        """interpolate the bad epochs.

        Parameters
        ----------
        epochs : instance of mne.Epochs
            The epochs object which must be fixed.
        """
        drop_log = self.drop_log
        # 1: bad segment, # 2: interpolated, # 3: dropped
        self.fix_log = self.drop_log.copy()
        ch_names = drop_log.columns.values
        n_consensus = self.consensus_perc * len(ch_names)
        pbar = ProgressBar(len(epochs) - 1,
                           mesg='Repairing epochs: ',
                           spinner=True)
        # TODO: raise error if preload is not True
        for epoch_idx in range(len(epochs)):
            pbar.update(epoch_idx + 1)
            # ch_score = self.scores_[ch_type][epoch_idx]
            # sorted_ch_idx = np.argsort(ch_score)
            n_bads = drop_log.iloc[epoch_idx].sum()
            if n_bads == 0 or n_bads > n_consensus:
                continue
            else:
                if n_bads <= self.n_interpolate:
                    bad_chs = drop_log.iloc[epoch_idx].values == 1
                else:
                    # get peak-to-peak for channels in that epoch
                    data = epochs[epoch_idx].get_data()[0, :, :]
                    peaks = np.ptp(data, axis=-1)
                    # find channels which are bad by rejection threshold
                    bad_chs = np.where(drop_log.iloc[epoch_idx].values == 1)[0]
                    # find the ordering of channels amongst the bad channels
                    sorted_ch_idx = np.argsort(peaks[bad_chs])[::-1]
                    # then select only the worst n_interpolate channels
                    bad_chs = bad_chs[sorted_ch_idx[:self.n_interpolate]]

            self.fix_log.iloc[epoch_idx, bad_chs] = 2
            bad_chs = ch_names[bad_chs].tolist()
            epoch = epochs[epoch_idx]
            epoch.info['bads'] = bad_chs
            interpolate_bads(epoch, reset_bads=True)
            epochs._data[epoch_idx] = epoch._data
Example #15
def test_progressbar():
    """Test progressbar class."""
    a = np.arange(10)
    pbar = ProgressBar(a)
    assert a is pbar.iterable
    assert pbar.max_value == 10

    pbar = ProgressBar(10)
    assert pbar.max_value == 10
    assert pbar.iterable is None

    # Make sure that non-iterable input raises an error
    def iter_func(a):
        for ii in a:
            pass
    pytest.raises(Exception, iter_func, ProgressBar(20))

    # Make sure different progress bars can be used
    with catch_logging() as log, modified_env(MNE_TQDM='tqdm'), \
            use_log_level('debug'), ProgressBar(np.arange(3)) as pbar:
        for p in pbar:
            pass
    log = log.getvalue()
    assert 'Using ProgressBar with tqdm\n' in log
    with modified_env(MNE_TQDM='broken'), pytest.raises(ValueError):
        ProgressBar(np.arange(3))
    with modified_env(MNE_TQDM='tqdm.broken'), pytest.raises(AttributeError):
        ProgressBar(np.arange(3))
Example #16
def ica_all():
    """Filter all of the EEG data in a directory and save.

    Parameters
    ----------
    l_freq : float
        Low-pass frequency (Hz).
    h_freq : float
        High-pass frequency (Hz).
    read_dir : str
        Directory from which to read the data.
    save_dir : str
        Directory in which to save the filtered data.

    """
    parser = argparse.ArgumentParser(prog='1_filter_all.py',
                                     description=__doc__)
    parser.add_argument('-i', '--input', type=str, required=True,
                        help="Directory of files to be filtered.")
    parser.add_argument('-o', '--output', type=str, required=True,
                        help="Directory in which to save filtered files.")
    parser.add_argument('-m', '--method', type=str, default='extended-infomax',
                        help='ICA method to use.')
    parser.add_argument('-v', '--verbose', type=str, default='error')
    args = parser.parse_args()

    input_dir = op.abspath(args.input)
    output_dir = op.abspath(args.output)
    ica_method = args.method

    if not op.exists(input_dir):
        sys.exit("Input directory not found.")
    if not op.exists(output_dir):
        sys.exit("Output directory not found.")

    set_log_level(verbose=args.verbose)

    input_fnames = op.join(input_dir, '*.fif')
    input_fnames = glob(input_fnames)
    n_files = len(input_fnames)

    print("Preparing to ICA {n} files".format(n=n_files))
    # Initialize a progress bar.
    progress = ProgressBar(n_files, mesg='Performing ICA')
    progress.update_with_increment_value(0)
    for fname in input_fnames:
        # Open file.
        raw = io.read_raw_fif(fname, preload=True, add_eeg_ref=False)
        # Perform ICA.
        ica = ICA(method=ica_method).fit(raw)
        # Save file.
        save_fname = op.splitext(op.split(fname)[-1])[0]
        save_fname += '-ica'
        save_fname = op.join(output_dir, save_fname)
        ica.save(save_fname + '.fif')
        # Update progress bar.
        progress.update_with_increment_value(1)

    print("")  # Get onto new line once progressbar completes.
Example #17
def _ransac_by_window(data, interpolation_mats, win_size, win_count, matlab_strict):
    """Calculate correlations of channels with their RANSAC-predicted values.

    This function calculates RANSAC correlations for each RANSAC window
    individually, requiring RAM equivalent to [channels * sample rate * seconds
    per RANSAC window] to run. Generally, this method will use less RAM than
    :func:`_ransac_by_channel`, with the exception of short recordings with high
    electrode counts.

    Parameters
    ----------
    data : np.ndarray
        A 2-D array containing the EEG signals from all currently-good channels.
    interpolation_mats : list of np.ndarray
        A list of interpolation matrices, one for each RANSAC sample of channels.
    win_size : int
        Number of frames/samples of EEG data in each RANSAC correlation window.
    win_count : int
        Number of RANSAC correlation windows.
    matlab_strict : bool
        Whether or not RANSAC should strictly follow MATLAB PREP's internal
        math, ignoring any improvements made in PyPREP over the original code.

    Returns
    -------
    correlations : np.ndarray
        Correlations of the given channels to their predicted values within each
        RANSAC window.

    """
    ch_count = data.shape[0]
    correlations = np.ones((win_count, ch_count))

    pb = ProgressBar(range(win_count))
    for window in pb:
        # Get the current window of EEG data
        start = window * win_size
        end = (window + 1) * win_size
        actual = data[:, start:end]

        # Get the median RANSAC-predicted signal for each channel
        predicted = _predict_median_signals(actual, interpolation_mats, matlab_strict)

        # Calculate the actual vs predicted signal correlation for each channel
        correlations[window, :] = _correlate_arrays(actual, predicted, matlab_strict)

    return correlations
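_correlate_arrays is defined elsewhere in PyPREP; a hypothetical equivalent that correlates each channel's actual window with its predicted window:

import numpy as np

def correlate_rows(actual, predicted):
    """Row-wise Pearson correlation of `actual` against `predicted`."""
    a = actual - actual.mean(axis=1, keepdims=True)
    b = predicted - predicted.mean(axis=1, keepdims=True)
    num = (a * b).sum(axis=1)
    denom = np.sqrt((a ** 2).sum(axis=1) * (b ** 2).sum(axis=1))
    return num / denom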
Example #18
def test_progressbar_parallel_advanced(capsys):
    """Test ProgressBar with parallel computing, advanced version."""
    assert capsys.readouterr().out == ''
    # This must be "1" because "capsys" won't get stdout properly otherwise
    parallel, p_fun, _ = parallel_func(_identity_block, n_jobs=1,
                                       verbose=False)
    arr = np.arange(10)
    with ProgressBar(len(arr), verbose_bool=True) as pb:
        out = parallel(p_fun(x, pb.subset(pb_idx))
                       for pb_idx, x in array_split_idx(arr, 2))
        assert op.isfile(pb._mmap_fname)
        sum_ = np.memmap(pb._mmap_fname, dtype='bool', mode='r',
                         shape=10).sum()
        assert sum_ == len(arr)
    assert not op.isfile(pb._mmap_fname), '__exit__ not called?'
    out = np.concatenate(out)
    assert_array_equal(out, arr)
    assert '100.00%' in capsys.readouterr().out
Example #19
def average_maps(img_fnames, target_img, desc):
    avg = np.zeros_like(target_img.get_data())
    n_images = 0
    print('')
    for ii, image_fname in enumerate(
            ProgressBar(img_fnames, mesg=desc, spinner=True)):
        collection, name = image_fname.split('/')[-2:]
        img = read_resampled_img(image_fname)
        try:
            data = img.get_data()
        except IOError:
            continue
        # print('Image %d (max = %f)'
        #       % (ii, img.get_data().max()))
        if not np.any(np.isnan(data / data.std())):
            avg += data / data.std()
            n_images += 1
        else:
            continue
    avg /= n_images
    return avg
Example #21
def test_progressbar(monkeypatch):
    """Test progressbar class."""
    a = np.arange(10)
    pbar = ProgressBar(a)
    assert a is pbar.iterable
    assert pbar.max_value == 10

    pbar = ProgressBar(10)
    assert pbar.max_value == 10
    assert pbar.iterable is None

    # Make sure that non-iterable input raises an error
    def iter_func(a):
        for ii in a:
            pass

    with pytest.raises(TypeError, match='not iterable'):
        iter_func(pbar)

    # Make sure different progress bars can be used
    monkeypatch.setenv('MNE_TQDM', 'tqdm')
    with catch_logging('debug') as log, ProgressBar(np.arange(3)) as pbar:
        for p in pbar:
            pass
    log = log.getvalue()
    assert 'Using ProgressBar with tqdm\n' in log
    monkeypatch.setenv('MNE_TQDM', 'broken')
    with pytest.raises(ValueError, match='Invalid value for the'):
        ProgressBar(np.arange(3))
    monkeypatch.setenv('MNE_TQDM', 'tqdm.broken')
    with pytest.raises(ValueError, match='Unknown tqdm'):
        ProgressBar(np.arange(3))
    # off
    monkeypatch.setenv('MNE_TQDM', 'off')
    with catch_logging('debug') as log, ProgressBar(np.arange(3)) as pbar:
        for p in pbar:
            pass
    log = log.getvalue()
    assert 'Using ProgressBar with off\n' == log
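Based on the behaviours this test exercises, the backend can be selected per process through the MNE_TQDM environment variable before a bar is created:

import os
import numpy as np

os.environ['MNE_TQDM'] = 'off'   # or 'tqdm'; invalid values raise ValueError
from mne.utils import ProgressBar

for _ in ProgressBar(np.arange(3)):
    pass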
Example #22
def filter_all():
    """Filter all of the EEG data in a directory and save.

    Parameters
    ----------
    l_freq : float
        Low-pass frequency (Hz).
    h_freq : float
        High-pass frequency (Hz).
    read_dir : str
        Directory from which to read the data.
    save_dir : str
        Directory in which to save the filtered data.

    """
    parser = argparse.ArgumentParser(prog='1_filter_all.py',
                                     description=__doc__)
    parser.add_argument('-i', '--input', type=str, required=True,
                        help="Directory of files to be filtered.")
    parser.add_argument('-o', '--output', type=str, required=True,
                        help="Directory in which to save filtered files.")
    parser.add_argument('-lp', '--lowpass', type=float, required=True,
                        help="Low-pass frequency (Hz).")
    parser.add_argument('-hp', '--highpass', type=float, required=True,
                        help="High-pass frequency (Hz).")
    parser.add_argument('-m', '--montage', default='Enobio32',
                        help='Electrode montage.')
    parser.add_argument('-ow', '--overwrite', action='store_true',
                        help='If given, overwrite the file if it exists.')
    parser.add_argument('-v', '--verbose', default='error')
    args = parser.parse_args()

    input_dir = op.abspath(args.input)
    output_dir = op.abspath(args.output)
    l_freq, h_freq = args.highpass, args.lowpass
    montage = args.montage
    overwrite = args.overwrite

    if not op.exists(input_dir):
        sys.exit("Input directory not found.")
    if not op.exists(output_dir):
        sys.exit("Output directory not found.")

    set_log_level(verbose=args.verbose)

    input_fnames = op.join(input_dir, '*.easy')
    input_fnames = glob(input_fnames)
    n_files = len(input_fnames)

    print("Preparing to filter {n} files".format(n=n_files))
    # Initialize a progress bar.
    progress = ProgressBar(n_files, mesg='Filtering')

    failed_files = []
    for fname in input_fnames:
        try:
            raw = read_raw_enobio(fname, montage=montage)
        except UserWarning:
            failed_files.append(fname)
            progress.update_with_increment_value(1)
            continue

        # High- and low-pass filter separately.
        raw.filter(l_freq=l_freq, h_freq=None, phase='zero',
                   fir_window='hamming', l_trans_bandwidth='auto',
                   h_trans_bandwidth='auto', filter_length='auto')
        raw.filter(l_freq=None, h_freq=h_freq, phase='zero',
                   fir_window='hamming', l_trans_bandwidth='auto',
                   h_trans_bandwidth='auto', filter_length='auto')

        # Create a new name for the filtered file.
        new_fname = op.split(fname)[-1]  # Remove path.
        new_fname = op.splitext(new_fname)[0]  # Remove extension.
        new_fname = new_fname[15:]  # Remove timestamp.
        new_fname = new_fname.replace("_Protocol 1", "")  # Remove Protocol 1.

        new_fname += '-firfilt'  # Indicate that we filtered.

        # Check for duplicates.
        base_name_to_check = op.join(output_dir, new_fname)
        if op.isfile(base_name_to_check + '.fif'):
            i = 1
            while op.isfile(base_name_to_check + '_{}.fif'.format(i)):
                i += 1
            new_fname += "_{}".format(i)

        raw.info['filename'] = new_fname  # Add this to the raw info dictionary.

        # Save the filtered file with a new name.
        save_fname = op.join(output_dir, new_fname)
        raw.save(save_fname + '.fif', overwrite=overwrite)

        # Update progress bar.
        progress.update_with_increment_value(1)

    print("")  # Get onto new line once progressbar completes.
    print("Failed on these files: {}".format(failed_files))
Example #23
def _phase_amplitude_coupling(data,
                              sfreq,
                              f_phase,
                              f_amp,
                              ixs,
                              pac_func='ozkurt',
                              events=None,
                              tmin=None,
                              tmax=None,
                              n_cycles_ph=3,
                              n_cycles_am=3,
                              scale_amp_func=None,
                              return_data=False,
                              concat_epochs=False,
                              n_jobs=1,
                              verbose=None):
    """ Compute phase-amplitude coupling using pacpy.

    Parameters
    ----------
    data : array, shape ([n_epochs], n_channels, n_times)
        The data used to calculate PAC
    sfreq : float
        The sampling frequency of the data.
    f_phase : array, dtype float, shape (n_bands_phase, 2,)
        The frequency ranges to use for the phase carrier. PAC will be
        calculated between n_bands_phase * n_bands_amp frequencies.
    f_amp : array, dtype float, shape (n_bands_amp, 2,)
        The frequency ranges to use for the phase-modulated amplitude.
        PAC will be calculated between n_bands_phase * n_bands_amp frequencies.
    ixs : array-like, shape (n_ch_pairs x 2)
        The indices for low/high frequency channels. PAC will be estimated
        between n_ch_pairs of channels. Indices correspond to rows of `data`.
    pac_func : {'plv', 'glm', 'mi_canolty', 'mi_tort', 'ozkurt'} |
               list of strings
        The function for estimating PAC. Corresponds to functions in
        `pacpy.pac`. Defaults to 'ozkurt'. If multiple frequency bands are used
        then `plv` cannot be calculated.
    events : array, shape (n_events, 3) | array, shape (n_events,) | None
        MNE events array. To be supplied if data is 2D and output should be
        split by events. In this case, `tmin` and `tmax` must be provided. If
        `ndim == 1`, it is assumed to be event indices, and all events will be
        grouped together.
    tmin : float | list of floats, shape (n_pac_windows,) | None
        If `events` is not provided, it is the start time to use in `inst`.
        If `events` is provided, it is the time (in seconds) to include before
        each event index. If a list of floats is given, then PAC is calculated
        for each pair of `tmin` and `tmax`. Defaults to `min(inst.times)`.
    tmax : float | list of floats, shape (n_pac_windows,) | None
        If `events` is not provided, it is the stop time to use in `inst`.
        If `events` is provided, it is the time (in seconds) to include after
        each event index. If a list of floats is given, then PAC is calculated
        for each pair of `tmin` and `tmax`. Defaults to `max(inst.n_times)`.
    n_cycles_ph : float, int | array of floats, shape (n_bands_phase,)
        The number of cycles to be included in the window for each band-pass
        filter for phase. Defaults to 3.
    n_cycles_am : float, int | array of floats, shape (n_bands_amp,)
        The number of cycles to be included in the window for each band-pass
        filter for amplitude. Defaults to 3.
    scale_amp_func : None | function
        If not None, will be called on each amplitude signal in order to scale
        the values. Function must accept an N-D input and will operate on the
        last dimension. E.g., `sklearn.preprocessing.scale`.
        Defaults to no scaling.
    return_data : bool
        If False, output will be `[pac_out]`. If True, output will be,
        `[pac_out, phase_signal, amp_signal]`.
    concat_epochs : bool
        If True, epochs will be concatenated before calculating PAC values. If
        epochs are relatively short, this is a good idea in order to improve
        stability of the PAC metric.
    n_jobs : int
        Number of jobs to run in parallel. Defaults to 1.
    verbose : bool, str, int, or None
        If not None, override default verbose level (see `mne.verbose`).

    Returns
    -------
    pac_out : array, list of arrays, dtype float,
              shape([n_pac_funcs], n_epochs, n_channel_pairs,
                    n_freq_pairs, n_pac_windows).
        The computed phase-amplitude coupling between each pair of data sources
        given in ixs. If multiple pac metrics are specified, there will be one
        array per metric in the output list. If n_pac_funcs is 1, then the
        first dimension will be dropped.
    [phase_signal] : array, shape (n_phase_signals, n_times,)
        Only returned if `return_data` is True. The phase timeseries of the
        phase signals (first column of `ixs`).
    [amp_signal] : array, shape (n_amp_signals, n_times,)
        Only returned if `return_data` is True. The amplitude timeseries of the
        amplitude signals (second column of `ixs`).
    """
    from ..externals.pacpy import pac as ppac
    pac_func = np.atleast_1d(pac_func)
    for i_func in pac_func:
        if i_func not in _pac_funcs:
            raise ValueError("PAC function %s is not supported" % i_func)
    n_pac_funcs = pac_func.shape[0]
    ixs = np.array(ixs, ndmin=2)
    n_ch_pairs = ixs.shape[0]
    tmin = 0 if tmin is None else tmin
    tmin = np.atleast_1d(tmin)
    n_pac_windows = len(tmin)
    tmax = (data.shape[-1] - 1) / float(sfreq) if tmax is None else tmax
    tmax = np.atleast_1d(tmax)
    f_phase = np.atleast_2d(f_phase)
    f_amp = np.atleast_2d(f_amp)
    n_cycles_ph = np.atleast_1d(n_cycles_ph)
    n_cycles_am = np.atleast_1d(n_cycles_am)
    if n_cycles_ph.shape[0] == 1:
        n_cycles_ph = np.repeat(n_cycles_ph, f_phase.shape[0])
    if n_cycles_am.shape[0] == 1:
        n_cycles_am = np.repeat(n_cycles_am, f_amp.shape[0])

    if data.ndim != 2:
        raise ValueError('Data must be shape (n_channels, n_times)')
    if ixs.shape[1] != 2:
        raise ValueError('Indices must have have a 2nd dimension of length 2')
    if f_phase.shape[-1] != 2 or f_amp.shape[-1] != 2:
        raise ValueError('Frequencies must be specified w/ a low/hi tuple')
    if len(tmin) != len(tmax):
        raise ValueError('tmin and tmax have differing lengths')
    if any(i_f.shape[0] > 1 and 'plv' in pac_func for i_f in (f_amp, f_phase)):
        raise ValueError('If calculating PLV, must use a single pair of freqs')
    for icyc, i_f in zip([n_cycles_ph, n_cycles_am], [f_phase, f_amp]):
        if icyc.shape[0] != i_f.shape[0]:
            raise ValueError("n_cycles must match n_freq_bands")
        if icyc.ndim > 1:
            raise ValueError("n_cycles must be 1-d, not {}d".format(icyc.ndim))

    logger.info('Pre-filtering data and extracting phase/amplitude...')
    hi_phase = np.unique([i_func in _hi_phase_funcs for i_func in pac_func])
    if len(hi_phase) != 1:
        raise ValueError("Can't mix pac funcs that use both hi-freq phase/amp")
    hi_phase = bool(hi_phase[0])
    data_ph, data_am, ix_map_ph, ix_map_am = _pre_filter_ph_am(
        data,
        sfreq,
        ixs,
        f_phase,
        f_amp,
        hi_phase=hi_phase,
        scale_amp_func=scale_amp_func,
        n_cycles_ph=n_cycles_ph,
        n_cycles_am=n_cycles_am)

    # So we know how big the PAC output will be
    if events is None:
        n_epochs = 1
    elif concat_epochs is True:
        if events.ndim == 1:
            n_epochs = 1
        else:
            n_epochs = np.unique(events[:, -1]).shape[0]
    else:
        n_epochs = events.shape[0]

    # Iterate through each pair of frequencies
    ixs_freqs = product(range(data_ph.shape[1]), range(data_am.shape[1]))
    ixs_freqs = np.atleast_2d(list(ixs_freqs))

    freq_pac = np.array([[f_phase[ii], f_amp[jj]] for ii, jj in ixs_freqs])
    n_f_pairs = len(ixs_freqs)
    pac = np.zeros(
        [n_pac_funcs, n_epochs, n_ch_pairs, n_f_pairs, n_pac_windows])
    for i_f_pair, (ix_f_ph, ix_f_am) in enumerate(ixs_freqs):
        # Second dimension is frequency
        i_f_data_ph = data_ph[:, ix_f_ph, ...]
        i_f_data_am = data_am[:, ix_f_am, ...]

        # Redefine indices to match the new data arrays
        ixs_new = [(ix_map_ph[i], ix_map_am[j]) for i, j in ixs]
        i_f_data_ph = mne.io.RawArray(
            i_f_data_ph, mne.create_info(i_f_data_ph.shape[0], sfreq))
        i_f_data_am = mne.io.RawArray(
            i_f_data_am, mne.create_info(i_f_data_am.shape[0], sfreq))

        # Turn into Epochs if we have defined events
        if events is not None:
            i_f_data_ph = _raw_to_epochs_mne(i_f_data_ph, events, tmin, tmax)
            i_f_data_am = _raw_to_epochs_mne(i_f_data_am, events, tmin, tmax)

        # Data is either Raw or Epochs
        pbar = ProgressBar(n_epochs)
        for itime, (i_tmin, i_tmax) in enumerate(zip(tmin, tmax)):
            # Pull times of interest
            with warnings.catch_warnings():  # To suppress a deprecation
                warnings.simplefilter("ignore")
                # Not sure how to do this w/o copying
                i_t_data_am = i_f_data_am.copy().crop(i_tmin, i_tmax)
                i_t_data_ph = i_f_data_ph.copy().crop(i_tmin, i_tmax)

            if concat_epochs is True:
                # Iterate through each event type and hstack
                con_data_ph = []
                con_data_am = []
                for i_ev in i_t_data_am.event_id.keys():
                    con_data_ph.append(np.hstack(i_t_data_ph[i_ev]._data))
                    con_data_am.append(np.hstack(i_t_data_am[i_ev]._data))
                i_t_data_ph = np.vstack(con_data_ph)
                i_t_data_am = np.vstack(con_data_am)
            else:
                # Just pull all epochs separately
                i_t_data_ph = i_t_data_ph._data
                i_t_data_am = i_t_data_am._data
            # Now make sure that inputs to the loop are ep x chan x time
            if i_t_data_am.ndim == 2:
                i_t_data_ph = i_t_data_ph[np.newaxis, ...]
                i_t_data_am = i_t_data_am[np.newaxis, ...]
            # Loop through epochs (or epoch grps), each index pair, and funcs
            data_iter = zip(i_t_data_ph, i_t_data_am)
            for iep, (ep_ph, ep_am) in enumerate(data_iter):
                for iix, (i_ix_ph, i_ix_am) in enumerate(ixs_new):
                    for ix_func, i_pac_func in enumerate(pac_func):
                        func = getattr(ppac, i_pac_func)
                        pac[ix_func, iep, iix, i_f_pair,
                            itime] = func(ep_ph[i_ix_ph],
                                          ep_am[i_ix_am],
                                          f_phase,
                                          f_amp,
                                          filterfn=False)
            pbar.update_with_increment_value(1)
    if pac.shape[0] == 1:
        pac = pac[0]
    if return_data:
        return pac, freq_pac, data_ph, data_am
    else:
        return pac, freq_pac
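At the core of any such PAC metric: the phase of the slow band, the envelope of the fast band, and a coupling statistic between them. A minimal mean-vector-length sketch (Canolty-style) on pre-filtered signals; an illustration, not pacpy's implementation:

import numpy as np
from scipy.signal import hilbert

def pac_mvl(x_low, x_high):
    """Mean vector length PAC between two pre-filtered 1-D signals."""
    phase = np.angle(hilbert(x_low))   # instantaneous phase, slow band
    amp = np.abs(hilbert(x_high))      # amplitude envelope, fast band
    return np.abs(np.mean(amp * np.exp(1j * phase)))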
Example #24
def epoch_all(main_event_id=EVENT_ID):
    """Epoch EEG data and save the epochs.

    Parameters
    ----------
    input : str
        Directory of files to be epoched.
    """

    parser = argparse.ArgumentParser(prog='1_filter_all.py',
                                     description=__doc__)
    parser.add_argument('-i',
                        '--input',
                        type=str,
                        required=True,
                        help="Directory of files to be epoched.")
    parser.add_argument('-o',
                        '--output',
                        type=str,
                        required=True,
                        help="Directory in which to save filtered files.")
    parser.add_argument('-ed',
                        '--epoch-duration',
                        type=float,
                        required=True,
                        help='Duration of each epoch.')
    parser.add_argument('-co',
                        '--crop-out',
                        type=float,
                        default=0.,
                        help='Duration to crop in the beginning of recording.')
    parser.add_argument('-v', '--verbose', type=str, default='error')
    args = parser.parse_args()

    input_dir = op.abspath(args.input)
    output_dir = op.abspath(args.output)
    epoch_duration = args.epoch_duration
    crop_out = args.crop_out

    if not op.exists(input_dir):
        sys.exit("Input directory not found.")
    if not op.exists(output_dir):
        sys.exit("Output directory not found.")

    set_log_level(verbose=args.verbose)

    input_fnames = op.join(input_dir, '*.fif')
    input_fnames = glob(input_fnames)
    n_files = len(input_fnames)
    failed_files = []  # Put fnames that create errors in here.

    print("Epoching {n} files ...".format(n=n_files))
    # Initialize a progress bar.
    progress = ProgressBar(n_files, mesg='Epoching')
    progress.update_with_increment_value(0)
    for fname in input_fnames:

        raw = io.read_raw_fif(fname,
                              preload=True,
                              add_eeg_ref=False,
                              verbose=args.verbose)

        # Get the condition and event_id of this file.
        this_event_id = {}
        for key in main_event_id:
            if key in op.split(fname)[1]:
                this_event_id[key] = main_event_id[key]
                this_condition = key

        try:
            # Make the events ndarray.
            this_events = make_fixed_length_events(
                raw,
                this_event_id[this_condition],
                start=crop_out,
                stop=None,
                duration=epoch_duration)

            # Create the instance of Epochs.
            this_epochs = Epochs(raw,
                                 this_events,
                                 event_id=this_event_id,
                                 tmin=0.0,
                                 tmax=epoch_duration,
                                 baseline=None,
                                 preload=True,
                                 detrend=0,
                                 add_eeg_ref=False)

            # Append -epo to filename and save.
            save_fname = op.splitext(op.split(fname)[-1])[0]
            this_epochs.info['filename'] = save_fname
            this_epochs_fname = op.join(output_dir,
                                        this_epochs.info['filename'])
            this_epochs.save(this_epochs_fname + '.fif')

        except ValueError:
            failed_files.append(op.split(fname)[-1])

        # Update progress bar.
        progress.update_with_increment_value(1)

    print("\nFailed on:")
    for file_ in failed_files:
        print(file_)
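The epoching core above, reduced to a runnable sketch on synthetic data with the current public API (add_eeg_ref was removed from MNE long ago):

import numpy as np
import mne

info = mne.create_info(['EEG 001'], sfreq=100., ch_types='eeg')
raw = mne.io.RawArray(np.random.randn(1, 3000), info)
events = mne.make_fixed_length_events(raw, id=1, start=0., duration=1.)
epochs = mne.Epochs(raw, events, event_id={'fixed': 1}, tmin=0., tmax=1.,
                    baseline=None, preload=True)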
Example #25
def _phase_amplitude_coupling(data,
                              sfreq,
                              f_phase,
                              f_amp,
                              ixs,
                              pac_func='plv',
                              ev=None,
                              ev_grouping=None,
                              tmin=None,
                              tmax=None,
                              baseline=None,
                              baseline_kind='mean',
                              scale_amp_func=None,
                              use_times=None,
                              npad='auto',
                              return_data=False,
                              concat_epochs=True,
                              n_jobs=1,
                              verbose=None):
    """ Compute phase-amplitude coupling using pacpy.

    Parameters
    ----------
    data : array, shape ([n_epochs], n_channels, n_times)
        The data used to calculate PAC
    sfreq : float
        The sampling frequency of the data
    f_phase : array, dtype float, shape (2,)
        The frequency range to use for low-frequency phase carrier.
    f_amp : array, dtype float, shape (2,)
        The frequency range to use for high-frequency amplitude modulation.
    ixs : array-like, shape (n_pairs x 2)
        The indices for low/high frequency channels. PAC will be estimated
        between n_pairs of channels. Indices correspond to rows of `data`.
    pac_func : string, ['plv', 'glm', 'mi_canolty', 'mi_tort', 'ozkurt']
        The function for estimating PAC. Corresponds to functions in pacpy.pac
    ev : array-like, shape (n_events,) | None
        Indices for events. To be supplied if data is 2D and output should be
        split by events. In this case, tmin and tmax must be provided
    ev_grouping : array-like, shape (n_events,) | None
        Calculate PAC in each group separately, the output will then be of
        length unique(ev)
    tmin : float | None
        If ev is not provided, it is the start time to use in inst. If ev
        is provided, it is the time (in seconds) to include before each
        event index.
    tmax : float | None
        If ev is not provided, it is the stop time to use in inst. If ev
        is provided, it is the time (in seconds) to include after each
        event index.
    baseline : array, shape (2,) | None
        If ev is provided, it is the min/max time (in seconds) to include in
        the amplitude baseline. If None, no baseline is applied.
    baseline_kind : str
        What kind of baseline to use. See mne.baseline.rescale for options.
    scale_amp_func : None | function
        If not None, will be called on each amplitude signal in order to scale
        the values. Function must accept an N-D input and will operate on the
        last dimension. E.g., skl.preprocessing.scale
    use_times : array, shape (2,) | None
        If ev is provided, it is the min/max time (in seconds) to include in
        the PAC analysis. If None, the whole window (tmin to tmax) is used.
    npad : int | 'auto'
        The amount to pad each signal by before calculating phase/amplitude if
        the input signal is type Raw. If 'auto' the signal will be padded to
        the next power of 2 in length.
    return_data : bool
        If True, return the phase and amplitude data along with the PAC values.
    concat_epochs : bool
        If True, epochs will be concatenated before calculating PAC values. If
        epochs are relatively short, this is a good idea in order to improve
        stability of the PAC metric.
    n_jobs : int
        Number of CPUs to use in the computation.
    verbose : bool, str, int, or None
        If not None, override default verbose level (see mne.verbose).

    Returns
    -------
    pac_out : array, dtype float, shape (n_pairs, [n_events])
        The computed phase-amplitude coupling between each pair of data sources
        given in ixs.
    """
    from pacpy import pac as ppac
    if pac_func not in _pac_funcs:
        raise ValueError("PAC function {0} is not supported".format(pac_func))
    func = getattr(ppac, pac_func)
    ixs = np.array(ixs, ndmin=2)
    f_phase = np.atleast_2d(f_phase)
    f_amp = np.atleast_2d(f_amp)

    if data.ndim != 2:
        raise ValueError('Data must be shape (n_channels, n_times)')
    if ixs.shape[1] != 2:
        raise ValueError('Indices must have have a 2nd dimension of length 2')
    for ifreqs in [f_phase, f_amp]:
        if ifreqs.ndim > 2:
            raise ValueError('frequencies must be of shape (n_freq, 2)')
        if ifreqs.shape[1] != 2:
            raise ValueError('Phase frequencies must be of length 2')

    print('Pre-filtering data and extracting phase/amplitude...')
    hi_phase = pac_func in _hi_phase_funcs
    data_ph, data_am, ix_map_ph, ix_map_am = _pre_filter_ph_am(
        data, sfreq, ixs, f_phase, f_amp, npad=npad, hi_phase=hi_phase)
    ixs_new = [(ix_map_ph[i], ix_map_am[j]) for i, j in ixs]

    if ev is not None:
        use_times = [tmin, tmax] if use_times is None else use_times
        ev_grouping = np.ones_like(ev) if ev_grouping is None else ev_grouping
        data_ph, times, msk_ev = _array_raw_to_epochs(data_ph, sfreq, ev, tmin,
                                                      tmax)
        data_am, times, msk_ev = _array_raw_to_epochs(data_am, sfreq, ev, tmin,
                                                      tmax)

        # In case we cut off any events
        ev, ev_grouping = [i[msk_ev] for i in [ev, ev_grouping]]

        # Baselining before returning
        rescale(data_am, times, baseline, baseline_kind, copy=False)
        msk_time = _time_mask(times, *use_times)
        data_am, data_ph = [i[..., msk_time] for i in [data_am, data_ph]]

        # Stack epochs to a single trace if specified
        if concat_epochs is True:
            ev_unique = np.unique(ev_grouping)
            concat_data = []
            for i_ev in ev_unique:
                msk_events = ev_grouping == i_ev
                concat_data.append(
                    [np.hstack(i[msk_events]) for i in [data_am, data_ph]])
            data_am, data_ph = zip(*concat_data)
    else:
        data_ph = np.array([data_ph])
        data_am = np.array([data_am])
    data_ph = list(data_ph)
    data_am = list(data_am)

    if scale_amp_func is not None:
        for i in range(len(data_am)):
            data_am[i] = scale_amp_func(data_am[i], axis=-1)

    n_ep = len(data_ph)
    pac = np.zeros([n_ep, len(ixs_new)])
    pbar = ProgressBar(n_ep)
    for iep, (ep_ph, ep_am) in enumerate(zip(data_ph, data_am)):
        for iix, (i_ix_ph, i_ix_am) in enumerate(ixs_new):
            # f_phase and f_amp won't be used in this case
            pac[iep, iix] = func(ep_ph[i_ix_ph],
                                 ep_am[i_ix_am],
                                 f_phase,
                                 f_amp,
                                 filterfn=False)
        pbar.update_with_increment_value(1)
    if return_data:
        return pac, data_ph, data_am
    else:
        return pac
Example #26
def plot_chpi_snr_raw(raw, win_length, n_harmonics=None, show=True, *,
                      verbose=None):
    """Compute and plot cHPI SNR from raw data

    Parameters
    ----------
    win_length : float
        Length of window to use for SNR estimates (seconds). A longer window
        will naturally include more low frequency power, resulting in lower
        SNR.
    n_harmonics : int or None
        Number of line frequency harmonics to include in the model. If None,
        use all harmonics up to the MEG analog lowpass corner.
    show : bool
        Show figure if True.
    %(verbose)s

    Returns
    -------
    fig : instance of matplotlib.figure.Figure
        cHPI SNR as function of time, residual variance.

    Notes
    -----
    A general linear model including cHPI and line frequencies is fit into
    each data window. The cHPI power obtained from the model is then divided
    by the residual variance (variance of signal unexplained by the model) to
    obtain the SNR.

    The SNR may decrease either due to decrease of cHPI amplitudes (e.g.
    head moving away from the helmet), or due to increase in the residual
    variance. In case of broadband interference that overlaps with the cHPI
    frequencies, the resulting decreased SNR accurately reflects the true
    situation. However, increased narrowband interference outside the cHPI
    and line frequencies would also cause an increase in the residual variance,
    even though it wouldn't necessarily affect estimation of the cHPI
    amplitudes. Thus, this method is intended for a rough overview of cHPI
    signal quality. A more accurate picture of cHPI quality (at an increased
    computational cost) can be obtained by examining the goodness-of-fit of
    the cHPI coil fits.
    """
    import matplotlib.pyplot as plt
    try:
        from mne.chpi import get_chpi_info
    except ImportError:
        from mne.chpi import _get_hpi_info as get_chpi_info

    # plotting parameters
    legend_fontsize = 6
    title_fontsize = 10
    tick_fontsize = 10
    label_fontsize = 10

    # get some info from fiff
    sfreq = raw.info['sfreq']
    linefreq = raw.info['line_freq']
    if n_harmonics is not None:
        linefreqs = (np.arange(n_harmonics + 1) + 1) * linefreq
    else:
        linefreqs = np.arange(linefreq, raw.info['lowpass'], linefreq)
    buflen = int(win_length * sfreq)
    if buflen <= 0:
        raise ValueError('Window length should be >0')
    cfreqs = get_chpi_info(raw.info, verbose=False)[0]
    logger.info(f'Nominal cHPI frequencies: {cfreqs} Hz')
    logger.info(f'Sampling frequency: {sfreq:0.1f} Hz')
    logger.info(f'Using line freqs: {linefreqs} Hz')
    logger.info(f'Using buffers of {buflen} samples = '
                f'{buflen / sfreq:0.3f} seconds')

    pick_meg = pick_types(raw.info, meg=True, exclude=[])
    pick_mag = pick_types(raw.info, meg='mag', exclude=[])
    pick_grad = pick_types(raw.info, meg='grad', exclude=[])
    nchan = len(pick_meg)
    # grad and mag indices into an array that already has meg channels only
    pick_mag_ = np.in1d(pick_meg, pick_mag).nonzero()[0]
    pick_grad_ = np.in1d(pick_meg, pick_grad).nonzero()[0]

    # create general linear model for the data
    t = np.arange(buflen) / float(sfreq)
    model = np.empty((len(t), 2 + 2 * (len(linefreqs) + len(cfreqs))))
    model[:, 0] = t
    model[:, 1] = np.ones(t.shape)
    # add sine and cosine term for each freq
    allfreqs = np.concatenate([linefreqs, cfreqs])
    model[:, 2::2] = np.cos(2 * np.pi * t[:, np.newaxis] * allfreqs)
    model[:, 3::2] = np.sin(2 * np.pi * t[:, np.newaxis] * allfreqs)
    inv_model = linalg.pinv(model)

    # drop last buffer to avoid overrun
    bufs = np.arange(0, raw.n_times, buflen)[:-1]
    tvec = bufs / sfreq
    snr_avg_grad = np.zeros([len(cfreqs), len(bufs)])
    hpi_pow_grad = np.zeros([len(cfreqs), len(bufs)])
    snr_avg_mag = np.zeros([len(cfreqs), len(bufs)])
    resid_vars = np.zeros([nchan, len(bufs)])
    pb = ProgressBar(bufs, mesg='Buffer')
    for ind, buf0 in enumerate(pb):
        megbuf = raw[pick_meg, buf0:buf0 + buflen][0].T
        coeffs = np.dot(inv_model, megbuf)
        coeffs_hpi = coeffs[2 + 2 * len(linefreqs):]
        resid_vars[:, ind] = np.var(megbuf - np.dot(model, coeffs), 0)
        # get total power by combining sine and cosine terms
        # a sinusoid of amplitude A has power A**2 / 2
        hpi_pow = (coeffs_hpi[0::2, :] ** 2 + coeffs_hpi[1::2, :] ** 2) / 2
        hpi_pow_grad[:, ind] = hpi_pow[:, pick_grad_].mean(1)
        # divide average HPI power by average variance
        snr_avg_grad[:, ind] = hpi_pow_grad[:, ind] / \
            resid_vars[pick_grad_, ind].mean()
        snr_avg_mag[:, ind] = hpi_pow[:, pick_mag_].mean(1) / \
            resid_vars[pick_mag_, ind].mean()
    logger.info('[done]')

    cfreqs_legend = ['%s Hz' % fre for fre in cfreqs]
    fig, axs = plt.subplots(4, 1, sharex=True)

    # SNR plots for gradiometers and magnetometers
    ax = axs[0]
    lines1 = ax.plot(tvec, 10 * np.log10(snr_avg_grad.T))
    lines1_med = ax.plot(tvec, 10 * np.log10(np.median(snr_avg_grad, axis=0)),
                         lw=2, ls=':', color='k')
    ax.set_xlim([tvec.min(), tvec.max()])
    ax.set(ylabel='SNR (dB)')
    ax.yaxis.label.set_fontsize(label_fontsize)
    ax.set_title('Mean cHPI power / mean residual variance, gradiometers',
                 fontsize=title_fontsize)
    ax.tick_params(axis='both', which='major', labelsize=tick_fontsize)
    ax = axs[1]
    lines2 = ax.plot(tvec, 10 * np.log10(snr_avg_mag.T))
    lines2_med = ax.plot(tvec, 10 * np.log10(np.median(snr_avg_mag, axis=0)),
                         lw=2, ls=':', color='k')
    ax.set_xlim([tvec.min(), tvec.max()])
    ax.set(ylabel='SNR (dB)')
    ax.yaxis.label.set_fontsize(label_fontsize)
    ax.set_title('Mean cHPI power / mean residual variance, magnetometers',
                 fontsize=title_fontsize)
    ax.tick_params(axis='both', which='major', labelsize=tick_fontsize)
    ax = axs[2]
    lines3 = ax.plot(tvec, hpi_pow_grad.T)
    lines3_med = ax.plot(tvec, np.median(hpi_pow_grad, axis=0),
                         lw=2, ls=':', color='k')
    ax.set_xlim([tvec.min(), tvec.max()])
    ax.set(ylabel='Power (T/m)$^2$')
    ax.yaxis.label.set_fontsize(label_fontsize)
    ax.set_title('Mean cHPI power, gradiometers',
                 fontsize=title_fontsize)
    ax.tick_params(axis='both', which='major', labelsize=tick_fontsize)
    # residual (unexplained) variance as function of time
    ax = axs[3]
    cls = plt.get_cmap('plasma')(np.linspace(0., 0.7, len(pick_meg)))
    ax.set_prop_cycle(color=cls)
    ax.semilogy(tvec, resid_vars[pick_grad_, :].T, alpha=.4)
    ax.set_xlim([tvec.min(), tvec.max()])
    ax.set(ylabel='Var. (T/m)$^2$', xlabel='Time (s)')
    ax.xaxis.label.set_fontsize(label_fontsize)
    ax.yaxis.label.set_fontsize(label_fontsize)
    ax.set_title('Residual (unexplained) variance, all gradiometer channels',
                 fontsize=title_fontsize)
    ax.tick_params(axis='both', which='major', labelsize=tick_fontsize)
    tight_layout(pad=.5, w_pad=.1, h_pad=.2)  # from mne.viz
    # tight_layout will screw these up
    ax = axs[0]
    box = ax.get_position()
    ax.set_position([box.x0, box.y0, box.width * 0.8, box.height])
    # order curve legends according to mean of data
    sind = np.argsort(snr_avg_grad.mean(axis=1))[::-1]
    handles = [lines1[i] for i in sind]
    handles.append(lines1_med[0])
    labels = [cfreqs_legend[i] for i in sind]
    labels.append('Median')
    leg_kwargs = dict(
        prop={'size': legend_fontsize}, bbox_to_anchor=(1.02, 0.5, ),
        loc='center left', borderpad=1, handlelength=1,
    )
    ax.legend(handles, labels, **leg_kwargs)
    ax = axs[1]
    box = ax.get_position()
    ax.set_position([box.x0, box.y0, box.width * 0.8, box.height])
    sind = np.argsort(snr_avg_mag.mean(axis=1))[::-1]
    handles = [lines2[i] for i in sind]
    handles.append(lines2_med[0])
    labels = [cfreqs_legend[i] for i in sind]
    labels.append('Median')
    ax.legend(handles, labels, **leg_kwargs)
    ax = axs[2]
    box = ax.get_position()
    ax.set_position([box.x0, box.y0, box.width * 0.8, box.height])
    sind = np.argsort(hpi_pow_grad.mean(axis=1))[::-1]
    handles = [lines3[i] for i in sind]
    handles.append(lines3_med[0])
    labels = [cfreqs_legend[i] for i in sind]
    labels.append('Median')
    ax.legend(handles, labels, **leg_kwargs)
    ax = axs[3]
    box = ax.get_position()
    ax.set_position([box.x0, box.y0, box.width * 0.8, box.height])
    if show:
        plt.show()

    return fig
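
A minimal usage sketch (the file name is hypothetical; assumes MNE is installed and the recording contains cHPI signals):

import mne

raw = mne.io.read_raw_fif('sample_chpi_raw.fif', allow_maxshield=True)
fig = plot_chpi_snr_raw(raw, win_length=2.0, n_harmonics=2, show=False)
fig.savefig('chpi_snr.png')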
Beispiel #27
0
    def _node_compute_mi(self, dataset, n_perm, n_jobs, random_state):
        """Compute mi and permuted mi.

        Permutations are performed by randomizing the regressor variable. For
        the fixed effect, this randomization is performed across subjects. For
        the random effect, the randomization is performed per subject.
        """
        # get the function for computing mi
        mi_fun = self.estimator.get_function()
        # get x, y, z and subject names per roi
        if dataset._mi_type != self._mi_type:
            assert TypeError(f"Your dataset doesn't allow to compute the mi "
                             f"{self._mi_type}. Allowed mi is "
                             f"{dataset._mi_type}")
        # get data variables
        n_roi, inf = len(self._roi), self._inference
        # evaluate true mi
        logger.info(f"    Evaluate true and permuted mi (n_perm={n_perm}, "
                    f"n_jobs={n_jobs})")
        # parallel function for computing permutations
        parallel, p_fun = parallel_func(mi_fun, n_jobs=n_jobs, verbose=False)
        pbar = ProgressBar(range(n_roi), mesg='Estimating MI')
        # evaluate permuted mi
        mi, mi_p = [], []
        for r in range(n_roi):
            # get the data of selected roi
            da = dataset.get_roi_data(self._roi[r],
                                      copnorm=self._copnorm,
                                      mi_type=self._mi_type,
                                      gcrn_per_suj=self._gcrn)
            x, y, suj = da.data, da['y'].data, da['subject'].data
            kw_mi = dict()
            # cmi and categorical MI
            if 'z' in list(da.coords):
                kw_mi['z'] = da['z'].data
            if self._inference == 'rfx':
                kw_mi['categories'] = suj

            # compute the true mi
            _mi = mi_fun(x, y, **kw_mi)
            # get the randomized version of y
            y_p = permute_mi_vector(y,
                                    suj,
                                    mi_type=self._mi_type,
                                    inference=self._inference,
                                    n_perm=n_perm)
            # run permutations using the randomized regressor
            _mi_p = parallel(p_fun(x, y_p[p], **kw_mi) for p in range(n_perm))
            _mi_p = np.asarray(_mi_p)

            # kernel smoothing
            if isinstance(self._kernel, np.ndarray):
                _mi = kernel_smoothing(_mi, self._kernel, axis=-1)
                _mi_p = kernel_smoothing(_mi_p, self._kernel, axis=-1)

            mi += [_mi]
            mi_p += [_mi_p]
            pbar.update_with_increment_value(1)

        self._mi, self._mi_p = mi, mi_p

        return mi, mi_p
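
A minimal sketch (illustrative only; frites' dedicated stats workflow normally handles inference) of how the permuted MI in mi_p can yield a non-parametric p-value for one ROI:

import numpy as np

mi_r, mi_p_r = mi[0], mi_p[0]  # true MI and the (n_perm, ...) permuted MI
# fraction of permutations whose MI meets or exceeds the true MI
pval = (mi_p_r >= mi_r[np.newaxis, ...]).mean(axis=0)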
Beispiel #28
0
def parallel_progress(op_iter):
    # `parallel`, `total` and `mesg` are closed over from the enclosing scope
    return parallel(ProgressBar(iterable=op_iter, max_value=total,
                                mesg=mesg))
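
A minimal sketch of the enclosing scope this closure assumes, using MNE's parallel_func (which returns a (parallel, delayed_func, n_jobs) triple); the job function is a placeholder:

from mne.parallel import parallel_func
from mne.utils import ProgressBar

def _job(x):
    return x ** 2  # placeholder work

parallel, my_func, _ = parallel_func(_job, n_jobs=1)
total, mesg = 100, 'Processing'

def parallel_progress(op_iter):
    return parallel(ProgressBar(iterable=op_iter, max_value=total,
                                mesg=mesg))

results = parallel_progress(my_func(i) for i in range(total))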
Beispiel #29
0
def _phase_amplitude_coupling(data, sfreq, f_phase, f_amp, ixs,
                              pac_func='ozkurt', events=None,
                              tmin=None, tmax=None, n_cycles_ph=3,
                              n_cycles_am=3, scale_amp_func=None,
                              return_data=False, concat_epochs=False,
                              n_jobs=1, verbose=None):
    """ Compute phase-amplitude coupling using pacpy.

    Parameters
    ----------
    data : array, shape ([n_epochs], n_channels, n_times)
        The data used to calculate PAC
    sfreq : float
        The sampling frequency of the data.
    f_phase : array, dtype float, shape (n_bands_phase, 2,)
        The frequency ranges to use for the phase carrier. PAC will be
        calculated between n_bands_phase * n_bands_amp frequencies.
    f_amp : array, dtype float, shape (n_bands_amp, 2,)
        The frequency ranges to use for the phase-modulated amplitude.
        PAC will be calculated between n_bands_phase * n_bands_amp frequencies.
    ixs : array-like, shape (n_ch_pairs x 2)
        The indices for low/high frequency channels. PAC will be estimated
        between n_ch_pairs of channels. Indices correspond to rows of `data`.
    pac_func : {'plv', 'glm', 'mi_canolty', 'mi_tort', 'ozkurt'} |
               list of strings
        The function for estimating PAC. Corresponds to functions in
        `pacpy.pac`. Defaults to 'ozkurt'. If multiple frequency bands are used
        then `plv` cannot be calculated.
    events : array, shape (n_events, 3) | array, shape (n_events,) | None
        MNE events array. To be supplied if data is 2D and output should be
        split by events. In this case, `tmin` and `tmax` must be provided. If
        `ndim == 1`, it is assumed to be event indices, and all events will be
        grouped together.
    tmin : float | list of floats, shape (n_pac_windows,) | None
        If `events` is not provided, it is the start time to use in `inst`.
        If `events` is provided, it is the time (in seconds) to include before
        each event index. If a list of floats is given, then PAC is calculated
        for each pair of `tmin` and `tmax`. Defaults to `min(inst.times)`.
    tmax : float | list of floats, shape (n_pac_windows,) | None
        If `events` is not provided, it is the stop time to use in `inst`.
        If `events` is provided, it is the time (in seconds) to include after
        each event index. If a list of floats is given, then PAC is calculated
        for each pair of `tmin` and `tmax`. Defaults to `max(inst.times)`.
    n_cycles_ph : float, int | array of floats, shape (n_bands_phase,)
        The number of cycles to be included in the window for each band-pass
        filter for phase. Defaults to 3.
    n_cycles_am : float, int | array of floats, shape (n_bands_amp,)
        The number of cycles to be included in the window for each band-pass
        filter for amplitude. Defaults to 3.
    scale_amp_func : None | function
        If not None, will be called on each amplitude signal in order to scale
        the values. Function must accept an N-D input and will operate on the
        last dimension. E.g., `sklearn.preprocessing.scale`.
        Defaults to no scaling.
    return_data : bool
        If False, output will be `[pac_out, freq_pac]`. If True, output will
        be `[pac_out, freq_pac, phase_signal, amp_signal]`.
    concat_epochs : bool
        If True, epochs will be concatenated before calculating PAC values. If
        epochs are relatively short, this is a good idea in order to improve
        stability of the PAC metric.
    n_jobs : int
        Number of jobs to run in parallel. Defaults to 1.
    verbose : bool, str, int, or None
        If not None, override default verbose level (see `mne.verbose`).

    Returns
    -------
    pac_out : array, list of arrays, dtype float,
              shape([n_pac_funcs], n_epochs, n_channel_pairs,
                    n_freq_pairs, n_pac_windows).
        The computed phase-amplitude coupling between each pair of data sources
        given in ixs. If multiple pac metrics are specified, there will be one
        array per metric in the output list. If n_pac_funcs is 1, then the
        first dimension will be dropped.
    freq_pac : array, shape (n_freq_pairs, 2, 2)
        The (phase, amplitude) frequency band pairs corresponding to the
        n_freq_pairs dimension of `pac_out`.
    [phase_signal] : array, shape (n_phase_signals, n_times,)
        Only returned if `return_data` is True. The phase timeseries of the
        phase signals (first column of `ixs`).
    [amp_signal] : array, shape (n_amp_signals, n_times,)
        Only returned if `return_data` is True. The amplitude timeseries of the
        amplitude signals (second column of `ixs`).
    """
    from ..externals.pacpy import pac as ppac
    pac_func = np.atleast_1d(pac_func)
    for i_func in pac_func:
        if i_func not in _pac_funcs:
            raise ValueError("PAC function %s is not supported" % i_func)
    n_pac_funcs = pac_func.shape[0]
    ixs = np.array(ixs, ndmin=2)
    n_ch_pairs = ixs.shape[0]
    tmin = 0 if tmin is None else tmin
    tmin = np.atleast_1d(tmin)
    n_pac_windows = len(tmin)
    tmax = (data.shape[-1] - 1) / float(sfreq) if tmax is None else tmax
    tmax = np.atleast_1d(tmax)
    f_phase = np.atleast_2d(f_phase)
    f_amp = np.atleast_2d(f_amp)
    n_cycles_ph = np.atleast_1d(n_cycles_ph)
    n_cycles_am = np.atleast_1d(n_cycles_am)
    if n_cycles_ph.shape[0] == 1:
        n_cycles_ph = np.repeat(n_cycles_ph, f_phase.shape[0])
    if n_cycles_am.shape[0] == 1:
        n_cycles_am = np.repeat(n_cycles_am, f_amp.shape[0])

    if data.ndim != 2:
        raise ValueError('Data must be shape (n_channels, n_times)')
    if ixs.shape[1] != 2:
        raise ValueError('Indices must have a 2nd dimension of length 2')
    if f_phase.shape[-1] != 2 or f_amp.shape[-1] != 2:
        raise ValueError('Frequencies must be specified w/ a low/hi tuple')
    if len(tmin) != len(tmax):
        raise ValueError('tmin and tmax have differing lengths')
    if any(i_f.shape[0] > 1 and 'plv' in pac_func for i_f in (f_amp, f_phase)):
        raise ValueError('If calculating PLV, must use a single pair of freqs')
    for icyc, i_f in zip([n_cycles_ph, n_cycles_am], [f_phase, f_amp]):
        if icyc.shape[0] != i_f.shape[0]:
            raise ValueError("n_cycles must match n_freq_bands")
        if icyc.ndim > 1:
            raise ValueError("n_cycles must be 1-d, not {}d".format(icyc.ndim))

    logger.info('Pre-filtering data and extracting phase/amplitude...')
    hi_phase = np.unique([i_func in _hi_phase_funcs for i_func in pac_func])
    if len(hi_phase) != 1:
        raise ValueError("Can't mix pac funcs that use both hi-freq phase/amp")
    hi_phase = bool(hi_phase[0])
    data_ph, data_am, ix_map_ph, ix_map_am = _pre_filter_ph_am(
        data, sfreq, ixs, f_phase, f_amp, hi_phase=hi_phase,
        scale_amp_func=scale_amp_func, n_cycles_ph=n_cycles_ph,
        n_cycles_am=n_cycles_am)

    # So we know how big the PAC output will be
    if events is None:
        n_epochs = 1
    elif concat_epochs is True:
        if events.ndim == 1:
            n_epochs = 1
        else:
            n_epochs = np.unique(events[:, -1]).shape[0]
    else:
        n_epochs = events.shape[0]

    # Iterate through each pair of frequencies
    ixs_freqs = product(range(data_ph.shape[1]), range(data_am.shape[1]))
    ixs_freqs = np.atleast_2d(list(ixs_freqs))

    freq_pac = np.array([[f_phase[ii], f_amp[jj]] for ii, jj in ixs_freqs])
    n_f_pairs = len(ixs_freqs)
    pac = np.zeros([n_pac_funcs, n_epochs, n_ch_pairs,
                    n_f_pairs, n_pac_windows])
    for i_f_pair, (ix_f_ph, ix_f_am) in enumerate(ixs_freqs):
        # Second dimension is frequency
        i_f_data_ph = data_ph[:, ix_f_ph, ...]
        i_f_data_am = data_am[:, ix_f_am, ...]

        # Redefine indices to match the new data arrays
        ixs_new = [(ix_map_ph[i], ix_map_am[j]) for i, j in ixs]
        i_f_data_ph = mne.io.RawArray(
            i_f_data_ph, mne.create_info(i_f_data_ph.shape[0], sfreq))
        i_f_data_am = mne.io.RawArray(
            i_f_data_am, mne.create_info(i_f_data_am.shape[0], sfreq))

        # Turn into Epochs if we have defined events
        if events is not None:
            i_f_data_ph = _raw_to_epochs_mne(i_f_data_ph, events, tmin, tmax)
            i_f_data_am = _raw_to_epochs_mne(i_f_data_am, events, tmin, tmax)

        # Data is either Raw or Epochs
        pbar = ProgressBar(n_epochs)
        for itime, (i_tmin, i_tmax) in enumerate(zip(tmin, tmax)):
            # Pull times of interest
            with warnings.catch_warnings():  # To suppress a deprecation warning
                warnings.simplefilter("ignore")
                # Not sure how to do this w/o copying
                i_t_data_am = i_f_data_am.copy().crop(i_tmin, i_tmax)
                i_t_data_ph = i_f_data_ph.copy().crop(i_tmin, i_tmax)

            if concat_epochs is True:
                # Iterate through each event type and hstack
                con_data_ph = []
                con_data_am = []
                for i_ev in i_t_data_am.event_id.keys():
                    con_data_ph.append(np.hstack(i_t_data_ph[i_ev]._data))
                    con_data_am.append(np.hstack(i_t_data_am[i_ev]._data))
                i_t_data_ph = np.vstack(con_data_ph)
                i_t_data_am = np.vstack(con_data_am)
            else:
                # Just pull all epochs separately
                i_t_data_ph = i_t_data_ph._data
                i_t_data_am = i_t_data_am._data
            # Now make sure that inputs to the loop are ep x chan x time
            if i_t_data_am.ndim == 2:
                i_t_data_ph = i_t_data_ph[np.newaxis, ...]
                i_t_data_am = i_t_data_am[np.newaxis, ...]
            # Loop through epochs (or epoch grps), each index pair, and funcs
            data_iter = zip(i_t_data_ph, i_t_data_am)
            for iep, (ep_ph, ep_am) in enumerate(data_iter):
                for iix, (i_ix_ph, i_ix_am) in enumerate(ixs_new):
                    for ix_func, i_pac_func in enumerate(pac_func):
                        func = getattr(ppac, i_pac_func)
                        pac[ix_func, iep, iix, i_f_pair, itime] = func(
                            ep_ph[i_ix_ph], ep_am[i_ix_am],
                            f_phase, f_amp, filterfn=False)
            pbar.update_with_increment_value(1)
    if pac.shape[0] == 1:
        pac = pac[0]
    if return_data:
        return pac, freq_pac, data_ph, data_am
    else:
        return pac, freq_pac
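
A minimal synthetic-data sketch (calling this private helper directly is only for illustration; a public wrapper would normally be the entry point):

import numpy as np

sfreq = 1000.
t = np.arange(0, 10, 1. / sfreq)
# 6 Hz phase signal whose envelope modulates an 80 Hz carrier
lo = np.sin(2 * np.pi * 6 * t)
hi = (1 + lo) * np.sin(2 * np.pi * 80 * t)
data = np.vstack([lo, hi]) + 0.1 * np.random.randn(2, t.size)

pac, freq_pac = _phase_amplitude_coupling(
    data, sfreq, f_phase=[4, 8], f_amp=[60, 100], ixs=[[0, 1]],
    pac_func='ozkurt')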
Beispiel #30
0
def conn_dfc(data,
             win_sample,
             times=None,
             roi=None,
             n_jobs=1,
             gcrn=True,
             verbose=None):
    """Single trial Dynamic Functional Connectivity.

    This function computes the Dynamic Functional Connectivity (DFC) using the
    Gaussian Copula Mutual Information (GCMI). The DFC is computed across time
    points for each trial. Note that the DFC can either be computed on windows
    manually defined or on sliding windows.

    Parameters
    ----------
    data : array_like
        Electrophysiological data array of a single subject organized as
        (n_epochs, n_roi, n_times)
    win_sample : array_like
        Array of shape (n_windows, 2) describing where each window starts and
        finishes. You can use the function :func:`frites.conn.define_windows`
        to define windows either manually or as sliding windows.
    times : array_like | None
        Time vector array of shape (n_times,)
    roi : array_like | None
        ROI names of a single subject
    n_jobs : int | 1
        Number of jobs to use for parallel computing (use -1 to use all
        jobs). The parallel loop is set at the pair level.
    gcrn : bool | True
        Specify whether the Gaussian Copula Rank Normalization should be
        applied. If the data are already normalized (e.g. z-scored), this
        parameter can be set to False because the data can be considered
        Gaussian over time.

    Returns
    -------
    dfc : array_like
        The DFC array of shape (n_epochs, n_pairs, n_windows)

    See also
    --------
    define_windows, conn_covgc
    """
    set_log_level(verbose)
    # -------------------------------------------------------------------------
    # inputs conversion
    data, trials, roi, times, attrs = conn_io(data,
                                              roi=roi,
                                              times=times,
                                              verbose=verbose)

    # -------------------------------------------------------------------------
    # data checking
    n_epochs, n_roi, n_pts = data.shape
    assert (len(roi) == n_roi) and (len(times) == n_pts)
    assert isinstance(win_sample, np.ndarray) and (win_sample.ndim == 2)
    assert win_sample.dtype in CONFIG['INT_DTYPE']
    n_win = win_sample.shape[0]
    # get the non-directed pairs
    x_s, x_t = np.triu_indices(n_roi, k=1)
    n_pairs = len(x_s)
    pairs = np.c_[x_s, x_t]
    # build roi pairs names
    roi_p = [f"{roi[s]}-{roi[t]}" for s, t in zip(x_s, x_t)]

    # -------------------------------------------------------------------------
    # compute dfc
    logger.info(f'Computing DFC between {n_pairs} pairs (gcrn={gcrn})')
    # get the parallel function
    parallel, p_fun = parallel_func(mi_nd_gg,
                                    n_jobs=n_jobs,
                                    verbose=verbose,
                                    prefer='threads')
    pbar = ProgressBar(range(n_win), mesg='Estimating DFC')

    dfc = np.zeros((n_epochs, n_pairs, n_win), dtype=np.float32)
    with parallel as para:
        for n_w, w in enumerate(win_sample):
            # select the data in the window and copnorm across time points
            data_w = data[..., w[0]:w[1]]
            # apply gcrn over time
            if gcrn:
                data_w = copnorm_nd(data_w, axis=2)
            # compute mi between pairs
            _dfc = para(
                p_fun(data_w[:, [s], :], data_w[:, [t], :],
                      **CONFIG["KW_GCMI"]) for s, t in zip(x_s, x_t))
            dfc[..., n_w] = np.stack(_dfc, axis=1)
            pbar.update_with_increment_value(1)

    # -------------------------------------------------------------------------
    # dataarray conversion
    win_times = times[win_sample]
    dfc = xr.DataArray(dfc,
                       dims=('trials', 'roi', 'times'),
                       name='dfc',
                       coords=(trials, roi_p, win_times.mean(1)))
    # add the windows used in the attributes
    cfg = dict(win_sample=np.r_[tuple(win_sample)],
               win_times=np.r_[tuple(win_times)],
               type='dfc')
    dfc.attrs = {**cfg, **attrs}

    return dfc
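
A minimal sketch, assuming the frites API shown above; define_windows is assumed to return the (n_windows, 2) sample array as its first value:

import numpy as np
from frites.conn import conn_dfc, define_windows

n_epochs, n_roi, n_times = 10, 3, 400
data = np.random.randn(n_epochs, n_roi, n_times)
times = np.linspace(-1, 1, n_times)
# 200 ms sliding windows, stepped every 20 ms
win_sample = define_windows(times, slwin_len=.2, slwin_step=.02)[0]
dfc = conn_dfc(data, win_sample, times=times, roi=['r0', 'r1', 'r2'])
print(dfc.shape)  # (n_epochs, n_pairs, n_windows)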
Beispiel #31
0
    def fit_ica(self, data, when='next', warm_start=False):
        """Conduct Independent Components Analysis (ICA) on a segment of data.

        The fitted ICA object is stored in the variable ica. Noisy components
        can be selected in the ICA, and then the ICA can be applied to incoming
        data to remove noise. Once fitted, ICA is applied by default to data
        when using the methods make_raw() or make_epochs().

        Components marked for removal can be accessed with self.ica.exclude.

        Parameters
        ----------
        data : int, float, mne.RawArray
            The duration of previous or incoming data to use to fit the ICA, or
            an mne.RawArray object of data.
        when : {'previous', 'next'} (defaults to 'next')
            Whether to compute ICA on the previous or next X seconds of data.
            Can be 'next' or 'previous'. If data is type mne.RawArray, this
            parameter is ignored.
        warm_start : bool (defaults to False)
            If True, will include the EEG data from the previous fit. If False,
            will only use the data specified in the parameter data.
        """
        # Re-define ICA variable to start ICA from scratch if the ICA was
        # already fitted and user wants to fit again.
        if self.ica.current_fit != 'unfitted':
            self.ica = ICA(method='extended-infomax')

        if isinstance(data, io.RawArray):
            self.raw_for_ica = data

        elif isinstance(data, numbers.Number):
            user_index = int(data * self.info['sfreq'])
            when = when.lower()
            if when not in ['previous', 'next']:
                raise ValueError("when must be 'previous' or 'next'. {} was "
                                 "passed.".format(when))
            elif when == 'previous':
                end_index = len(self.data)
                start_index = end_index - user_index
                # TODO: Check if out of bounds.

            elif when == 'next':
                start_index = len(self.data)
                end_index = start_index + user_index
                # Wait until the data is available.
                pbar = ProgressBar(end_index - start_index,
                                   mesg="Collecting data")
                while len(self.data) <= end_index:
                    # Sometimes sys.stdout.flush() raises ValueError. Is it
                    # because the while loop iterates too quickly for I/O?
                    try:
                        pbar.update(len(self.data) - start_index)
                    except ValueError:
                        pass
                print("")  # Get onto new line after progress bar finishes.

            _data = np.array([r[:] for r in
                              self.data[start_index:end_index]]).T

            # Now we have the data array in _data. Use it to make instance of
            # mne.RawArray, and then we can compute the ICA on that instance.
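            # Zero out the last row, which presumably carries marker/timestamp
            # values rather than EEG data, so it does not influence the ICA.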
            _data[-1, :] = 0

            # Use previous data in addition to the specified data when fitting
            # the ICA, if the user requested this.
            if warm_start and self.raw_for_ica is not None:
                self.raw_for_ica = concatenate_raws(
                    [self.raw_for_ica, io.RawArray(_data, self.info)])
            else:
                self.raw_for_ica = io.RawArray(_data, self.info)

        logger.info("Computing ICA solution ...")
        t_0 = local_clock()
        self.ica.fit(self.raw_for_ica.copy())  # Fits in-place.
        logger.info("Finished in {:.2f} s".format(local_clock() - t_0))
Beispiel #32
0
def _phase_amplitude_coupling(data, sfreq, f_phase, f_amp, ixs,
                              pac_func='plv', ev=None, ev_grouping=None,
                              tmin=None, tmax=None,
                              baseline=None, baseline_kind='mean',
                              scale_amp_func=None, use_times=None, npad='auto',
                              return_data=False, concat_epochs=True, n_jobs=1,
                              verbose=None):
    """ Compute phase-amplitude coupling using pacpy.

    Parameters
    ----------
    data : array, shape ([n_epochs], n_channels, n_times)
        The data used to calculate PAC
    sfreq : float
        The sampling frequency of the data
    f_phase : array, dtype float, shape (2,)
        The frequency range to use for low-frequency phase carrier.
    f_amp : array, dtype float, shape (2,)
        The frequency range to use for high-frequency amplitude modulation.
    ixs : array-like, shape (n_pairs x 2)
        The indices for low/high frequency channels. PAC will be estimated
        between n_pairs of channels. Indices correspond to rows of `data`.
    pac_func : string, ['plv', 'glm', 'mi_canolty', 'mi_tort', 'ozkurt']
        The function for estimating PAC. Corresponds to functions in pacpy.pac
    ev : array-like, shape (n_events,) | None
        Indices for events. To be supplied if data is 2D and output should be
        split by events. In this case, tmin and tmax must be provided.
    ev_grouping : array-like, shape (n_events,) | None
        Calculate PAC in each group separately; the output will then have one
        entry per unique value of `ev_grouping`.
    tmin : float | None
        If ev is not provided, it is the start time to use in inst. If ev
        is provided, it is the time (in seconds) to include before each
        event index.
    tmax : float | None
        If ev is not provided, it is the stop time to use in inst. If ev
        is provided, it is the time (in seconds) to include after each
        event index.
    baseline : array, shape (2,) | None
        If ev is provided, it is the min/max time (in seconds) to include in
        the amplitude baseline. If None, no baseline is applied.
    baseline_kind : str
        What kind of baseline to use. See mne.baseline.rescale for options.
    scale_amp_func : None | function
        If not None, will be called on each amplitude signal in order to scale
        the values. Function must accept an N-D input and will operate on the
        last dimension. E.g., skl.preprocessing.scale
    use_times : array, shape (2,) | None
        If ev is provided, it is the min/max time (in seconds) to include in
        the PAC analysis. If None, the whole window (tmin to tmax) is used.
    npad : int | 'auto'
        The amount to pad each signal by before calculating phase/amplitude if
        the input signal is type Raw. If 'auto' the signal will be padded to
        the next power of 2 in length.
    return_data : bool
        If True, return the phase and amplitude data along with the PAC values.
    concat_epochs : bool
        If True, epochs will be concatenated before calculating PAC values. If
        epochs are relatively short, this is a good idea in order to improve
        stability of the PAC metric.
    n_jobs : int
        Number of CPUs to use in the computation.
    verbose : bool, str, int, or None
        If not None, override default verbose level (see mne.verbose).

    Returns
    -------
    pac_out : array, dtype float, shape (n_pairs, [n_events])
        The computed phase-amplitude coupling between each pair of data sources
        given in ixs.
    """
    from pacpy import pac as ppac
    if pac_func not in _pac_funcs:
        raise ValueError("PAC function {0} is not supported".format(pac_func))
    func = getattr(ppac, pac_func)
    ixs = np.array(ixs, ndmin=2)
    f_phase = np.atleast_2d(f_phase)
    f_amp = np.atleast_2d(f_amp)

    if data.ndim != 2:
        raise ValueError('Data must be shape (n_channels, n_times)')
    if ixs.shape[1] != 2:
        raise ValueError('Indices must have a 2nd dimension of length 2')
    for ifreqs in [f_phase, f_amp]:
        if ifreqs.ndim > 2:
            raise ValueError('frequencies must be of shape (n_freq, 2)')
        if ifreqs.shape[1] != 2:
            raise ValueError('Frequency ranges must be of length 2')

    print('Pre-filtering data and extracting phase/amplitude...')
    hi_phase = pac_func in _hi_phase_funcs
    data_ph, data_am, ix_map_ph, ix_map_am = _pre_filter_ph_am(
        data, sfreq, ixs, f_phase, f_amp, npad=npad, hi_phase=hi_phase)
    ixs_new = [(ix_map_ph[i], ix_map_am[j]) for i, j in ixs]

    if ev is not None:
        use_times = [tmin, tmax] if use_times is None else use_times
        ev_grouping = np.ones_like(ev) if ev_grouping is None else ev_grouping
        data_ph, times, msk_ev = _array_raw_to_epochs(
            data_ph, sfreq, ev, tmin, tmax)
        data_am, times, msk_ev = _array_raw_to_epochs(
            data_am, sfreq, ev, tmin, tmax)

        # In case we cut off any events
        ev, ev_grouping = [i[msk_ev] for i in [ev, ev_grouping]]

        # Baselining before returning
        rescale(data_am, times, baseline, baseline_kind, copy=False)
        msk_time = _time_mask(times, *use_times)
        data_am, data_ph = [i[..., msk_time] for i in [data_am, data_ph]]

        # Stack epochs to a single trace if specified
        if concat_epochs is True:
            ev_unique = np.unique(ev_grouping)
            concat_data = []
            for i_ev in ev_unique:
                msk_events = ev_grouping == i_ev
                concat_data.append([np.hstack(i[msk_events])
                                    for i in [data_am, data_ph]])
            data_am, data_ph = zip(*concat_data)
    else:
        data_ph = np.array([data_ph])
        data_am = np.array([data_am])
    data_ph = list(data_ph)
    data_am = list(data_am)

    if scale_amp_func is not None:
        for i in range(len(data_am)):
            data_am[i] = scale_amp_func(data_am[i], axis=-1)

    n_ep = len(data_ph)
    pac = np.zeros([n_ep, len(ixs_new)])
    pbar = ProgressBar(n_ep)
    for iep, (ep_ph, ep_am) in enumerate(zip(data_ph, data_am)):
        for iix, (i_ix_ph, i_ix_am) in enumerate(ixs_new):
            # f_phase and f_amp won't be used in this case
            pac[iep, iix] = func(ep_ph[i_ix_ph], ep_am[i_ix_am],
                                 f_phase, f_amp, filterfn=False)
        pbar.update_with_increment_value(1)
    if return_data:
        return pac, data_ph, data_am
    else:
        return pac
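
A minimal sketch of the event-split path (hypothetical values; data is a 2D (n_channels, n_times) array as required above):

import numpy as np

ev = np.array([1000, 3000, 5000])    # event sample indices
ev_grouping = np.array([0, 0, 1])    # PAC is computed per group
pac = _phase_amplitude_coupling(
    data, sfreq, f_phase=[4, 8], f_amp=[80, 150], ixs=[[0, 1]],
    pac_func='ozkurt', ev=ev, ev_grouping=ev_grouping,
    tmin=-0.5, tmax=1.0)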