示例#1
0
def complex_tfr(x,
                time,
                est_val=None,
                est_key=None,
                sf=600.,
                foi=None,
                cycles=None,
                time_bandwidth=None,
                n_jobs=1,
                decim=10):
    """Compute the complex multitaper time-frequency transform of ``x``.

    Parameters
    ----------
    x : array
        Input data. A 2-D array gets a leading singleton axis prepended
        before it is handed to ``_compute_tfr``.
    time : sequence
        Time axis; it is subsampled with ``time[::decim]`` for the return
        value.
    est_val, est_key : object
        Passed through to the return value untouched.
    sf : float
        Forwarded to ``_compute_tfr`` as ``sfreq``.
    foi : array-like
        Forwarded to ``_compute_tfr`` as its second positional argument.
    cycles, time_bandwidth, n_jobs, decim
        Forwarded to ``_compute_tfr`` as ``n_cycles``, ``time_bandwidth``,
        ``n_jobs`` and ``decim`` respectively.

    Returns
    -------
    tuple
        ``(tfr, time[::decim], est_val, est_key)`` where ``tfr`` is the
        ``_compute_tfr`` result obtained with ``output='complex'``.
    """
    # Promote a 2-D input to 3-D by adding a leading axis.
    if x.ndim == 2:
        x = x[np.newaxis]

    multitaper_opts = dict(
        sfreq=sf,
        method='multitaper',
        decim=decim,
        n_cycles=cycles,
        zero_mean=True,
        time_bandwidth=time_bandwidth,
        n_jobs=n_jobs,
        use_fft=True,
        output='complex',
    )
    tfr = _compute_tfr(x, foi, **multitaper_opts)

    # Subsample the time axis with the same integer step used for the TFR.
    decimated_time = time[::decim]
    return tfr, decimated_time, est_val, est_key
示例#2
0
def _extract_phase_and_amp(data_ph, data_am, sfreq, freqs_phase,
                           freqs_amp, scale=True):
    """Extract the phase and amplitude of two signals for PAC viz.

    Parameters
    ----------
    data_ph : array, shape (n_epochs, n_times)
        Signal from which the phase (and the real part) is extracted.
    data_am : array, shape (n_epochs, n_times)
        Signal from which the amplitude is extracted.
    sfreq : float
        Sampling frequency, forwarded to ``_compute_tfr``.
    freqs_phase : array-like
        Frequencies used for the Morlet transform of ``data_ph``.
    freqs_amp : array-like
        Frequencies used for the Morlet transform of ``data_am``.
    scale : bool
        When truthy, the amplitude is passed through
        ``sklearn.preprocessing.scale(amp, axis=1)`` before being returned.

    Returns
    -------
    angle_ph : array
        ``np.angle`` of the phase-band transform, epochs stacked with
        ``np.hstack``.
    band_ph_stacked : array
        Real part of the phase-band transform, stacked the same way.
    amp : array
        Squared magnitude of the amplitude-band transform, stacked the same
        way and optionally scaled.
    """
    # Morlet transform to get complex representation.
    # The amplitude band must come from ``data_am``; it was previously
    # computed from ``data_ph``, which left ``data_am`` unused.
    band_ph = _compute_tfr([data_ph], freqs_phase, sfreq, method='morlet')[0]
    band_amp = _compute_tfr([data_am], freqs_amp, sfreq, method='morlet')[0]

    # Calculate the phase/amplitude of relevant signals across epochs
    band_ph_stacked = np.hstack(np.real(band_ph))
    angle_ph = np.hstack(np.angle(band_ph))
    amp = np.hstack(np.abs(band_amp) ** 2)

    # Scale the amplitude for viz so low freqs don't dominate highs
    if scale:
        # Import under an alias: importing the function as ``scale`` would
        # shadow the boolean argument and silently disable this branch.
        from sklearn.preprocessing import scale as _standardize
        amp = _standardize(amp, axis=1)
    return angle_ph, band_ph_stacked, amp
示例#3
0
def test_compute_tfr_correct(method, decim):
    """Check that the TFR peaks at the frequency of the input sinusoid."""
    sfreq = 1000.
    f = 50.
    times = np.arange(1000) / sfreq

    # Hann-windowed sine at ``f`` Hz, reshaped to two leading singleton axes.
    signal = np.sin(2 * np.pi * f * times) * np.hanning(times.size)
    data = signal[np.newaxis, np.newaxis]

    # The candidate frequencies must contain the true one.
    freqs = np.arange(10, 111, 10)
    assert f in freqs

    tfr = _compute_tfr(data, freqs, sfreq, method=method, decim=decim,
                       n_cycles=2)
    # Average magnitude over the last axis, then pick the strongest frequency.
    mean_magnitude = np.abs(tfr[0, 0]).mean(-1)
    peak_freq = freqs[np.argmax(mean_magnitude)]
    assert peak_freq == f
示例#4
0
def _extract_phase_and_amp(data_ph,
                           data_am,
                           sfreq,
                           freqs_phase,
                           freqs_amp,
                           scale=True):
    """Extract the phase and amplitude of two signals for PAC viz.

    Parameters
    ----------
    data_ph : array, shape (n_epochs, n_times)
        Signal from which the phase (and the real part) is extracted.
    data_am : array, shape (n_epochs, n_times)
        Signal from which the amplitude is extracted.
    sfreq : float
        Sampling frequency, forwarded to ``_compute_tfr``.
    freqs_phase : array-like
        Frequencies used for the Morlet transform of ``data_ph``.
    freqs_amp : array-like
        Frequencies used for the Morlet transform of ``data_am``.
    scale : bool
        When truthy, the amplitude is passed through
        ``sklearn.preprocessing.scale(amp, axis=1)`` before being returned.

    Returns
    -------
    angle_ph : array
        ``np.angle`` of the phase-band transform, epochs stacked with
        ``np.hstack``.
    band_ph_stacked : array
        Real part of the phase-band transform, stacked the same way.
    amp : array
        Squared magnitude of the amplitude-band transform, stacked the same
        way and optionally scaled.
    """
    # Morlet transform to get complex representation.
    # The amplitude band must come from ``data_am``; it was previously
    # computed from ``data_ph``, which left ``data_am`` unused.
    band_ph = _compute_tfr([data_ph], freqs_phase, sfreq, method='morlet')[0]
    band_amp = _compute_tfr([data_am], freqs_amp, sfreq, method='morlet')[0]

    # Calculate the phase/amplitude of relevant signals across epochs
    band_ph_stacked = np.hstack(np.real(band_ph))
    angle_ph = np.hstack(np.angle(band_ph))
    amp = np.hstack(np.abs(band_amp)**2)

    # Scale the amplitude for viz so low freqs don't dominate highs
    if scale:
        # Import under an alias: importing the function as ``scale`` would
        # shadow the boolean argument and silently disable this branch.
        from sklearn.preprocessing import scale as _standardize
        amp = _standardize(amp, axis=1)
    return angle_ph, band_ph_stacked, amp
def array_tfr(epochs,
              sf=600,
              foi=None,
              cycles=None,
              time_bandwidth=None,
              decim=10,
              n_jobs=4,
              output='power'):
    """Run MNE's multitaper ``_compute_tfr`` on ``epochs``.

    Parameters
    ----------
    epochs : array
        Data handed to ``_compute_tfr`` as its first positional argument.
    sf : float
        Forwarded as ``sfreq``.
    foi : array-like
        Forwarded as the second positional argument.
    cycles, time_bandwidth, decim, n_jobs, output
        Forwarded as ``n_cycles``, ``time_bandwidth``, ``decim``, ``n_jobs``
        and ``output`` respectively.

    Returns
    -------
    The value returned by ``_compute_tfr`` (FFT-based, zero-mean wavelets).
    """
    from mne.time_frequency.tfr import _compute_tfr

    multitaper_opts = dict(
        sfreq=sf,
        method='multitaper',
        decim=decim,
        n_cycles=cycles,
        zero_mean=True,
        time_bandwidth=time_bandwidth,
        n_jobs=n_jobs,
        use_fft=True,
        output=output,
    )
    return _compute_tfr(epochs, foi, **multitaper_opts)
示例#6
0
文件: comp_fun.py 项目: wronk/rsn
def tfr_split(data, processing_params):
    """Compute wavelet power for ``data`` with one ``_compute_tfr`` call.

    Despite the name, the data is currently processed all at once. A batched
    variant (intended to work around memory issues) used to live here as a
    disabled string block; it has been removed as dead code.

    NOTE(review): if batching is reinstated, the removed draft preallocated
    the output with the undecimated time length and called ``.append`` on a
    ``range`` object -- both would need fixing first.

    Parameters
    ----------
    data : array
        Data handed to ``_compute_tfr`` as its first positional argument.
    processing_params : dict
        Must provide the keys ``'cwt_frequencies'``, ``'sfreq'``,
        ``'n_cycles'``, ``'post_decim'`` and ``'n_jobs'``.

    Returns
    -------
    The value returned by ``_compute_tfr`` with ``output='power'``.
    """
    return _compute_tfr(data,
                        frequencies=processing_params['cwt_frequencies'],
                        sfreq=processing_params['sfreq'],
                        n_cycles=processing_params['n_cycles'],
                        decim=processing_params['post_decim'],
                        n_jobs=processing_params['n_jobs'],
                        output='power')
示例#7
0
def test_compute_tfr():
    """Exercise the array TFR functions and ``_compute_tfr`` validation."""
    # Epoching parameters
    event_id = 1
    tmin, tmax = -0.2, 0.498  # tmax allows exhaustive decimation testing

    # Load the raw recording and its events
    raw = read_raw_fif(raw_fname)
    events = read_events(event_fname)

    # Exclude the marked bad channels plus two more
    exclude = raw.info['bads'] + ['MEG 2443', 'EEG 053']

    # Keep only the first two MEG gradiometers
    picks = pick_types(raw.info, meg='grad', eeg=False, stim=False,
                       include=[], exclude=exclude)[:2]
    epochs = Epochs(raw, events, event_id, tmin, tmax, picks=picks)
    data = epochs.get_data()
    sfreq = epochs.info['sfreq']
    freqs = np.arange(10, 20, 3).astype(float)

    # Expected single-trial output shape when no decimation is applied
    full_shape = np.r_[data.shape[:2], len(freqs), data.shape[2]]

    # Run every combination of function / use_fft / zero_mean / output
    tfr_funcs = (tfr_array_multitaper, tfr_array_morlet)
    on_off = (False, True)
    outputs = ('complex', 'power', 'phase', 'avg_power_itc', 'avg_power',
               'itc')
    complex_outputs = ('complex', 'avg_power_itc')
    for func, use_fft, zero_mean, output in product(tfr_funcs, on_off,
                                                    on_off, outputs):
        # Phase output is not implemented for multitaper
        if func is tfr_array_multitaper and output == 'phase':
            with pytest.raises(NotImplementedError):
                func(data, sfreq=sfreq, freqs=freqs, output=output)
            continue

        out = func(data, sfreq=sfreq, freqs=freqs, use_fft=use_fft,
                   zero_mean=zero_mean, n_cycles=2., output=output)

        # Averaged / ITC outputs have no leading epochs axis
        across_epochs = 'avg' in output or 'itc' in output
        expected_shape = full_shape[1:] if across_epochs else full_shape
        assert_array_equal(expected_shape, out.shape)

        # Complex dtype only for the complex-valued outputs
        if output in complex_outputs:
            expected_dtype = np.complex128
        else:
            expected_dtype = np.float64
        assert_equal(expected_dtype, out.dtype)
        assert np.isfinite(out).all()

    # Invalid positional arguments must raise ValueError
    for bad_data in (None, 'foo', data[0]):
        with pytest.raises(ValueError):
            _compute_tfr(bad_data, freqs, sfreq)
    for bad_freqs in (None, 'foo', [[0]]):
        with pytest.raises(ValueError):
            _compute_tfr(data, bad_freqs, sfreq)
    for bad_sfreq in (None, 'foo'):
        with pytest.raises(ValueError):
            _compute_tfr(data, freqs, bad_sfreq)

    # Invalid keyword arguments must raise ValueError
    checked_keys = ('output', 'method', 'use_fft', 'decim', 'n_jobs')
    for key, value in product(checked_keys, (None, 'foo')):
        with pytest.raises(ValueError):
            _compute_tfr(data, freqs, sfreq, **{key: value})
    with pytest.raises(ValueError, match='above Nyquist'):
        _compute_tfr(data, [sfreq], sfreq)

    # No time_bandwidth param in morlet
    with pytest.raises(ValueError):
        _compute_tfr(data, freqs, sfreq, method='morlet', time_bandwidth=1)
    # No phase in multitaper XXX Check ?
    with pytest.raises(NotImplementedError):
        _compute_tfr(data, freqs, sfreq, method='multitaper', output='phase')

    # Inter-trial coherence must lie strictly between 0 and 1
    itc = _compute_tfr(data, freqs, sfreq, output='itc', n_cycles=2.)
    assert np.sum(itc >= 1) == 0
    assert np.sum(itc <= 0) == 0

    # Check decim shapes
    # 2: multiple of len(times) even
    # 3: multiple odd
    # 8: not multiple, even
    # 9: not multiple, odd
    for decim in (2, 3, 8, 9, slice(0, 2), slice(1, 3), slice(2, 4)):
        if isinstance(decim, int):
            time_slice = slice(None, None, decim)
        else:
            time_slice = decim
        n_time = len(np.arange(data.shape[2])[time_slice])
        shape = np.r_[data.shape[:2], len(freqs), n_time]
        for method in ('multitaper', 'morlet'):
            common = dict(method=method, decim=decim, n_cycles=2.)
            # Single trials
            out = _compute_tfr(data, freqs, sfreq, **common)
            assert_array_equal(shape, out.shape)
            # Averages
            out = _compute_tfr(data, freqs, sfreq, output='avg_power',
                               **common)
            assert_array_equal(shape[1:], out.shape)
示例#8
0
def test_compute_tfr():
    """Test _compute_tfr function.

    Covers, in order: every combination of TFR function / ``use_fft`` /
    ``zero_mean`` / ``output`` (shape, dtype, finiteness), rejection of
    invalid arguments, the value range of inter-trial coherence, and output
    shapes under integer and slice decimation.
    """
    # Set parameters
    event_id = 1
    tmin = -0.2
    tmax = 0.498  # Allows exhaustive decimation testing

    # Setup for reading the raw data
    raw = read_raw_fif(raw_fname)
    events = read_events(event_fname)

    exclude = raw.info['bads'] + ['MEG 2443', 'EEG 053']  # bads + 2 more

    # picks MEG gradiometers
    picks = pick_types(raw.info, meg='grad', eeg=False,
                       stim=False, include=[], exclude=exclude)

    # Only the first two picked channels are used
    picks = picks[:2]
    epochs = Epochs(raw, events, event_id, tmin, tmax, picks=picks)
    data = epochs.get_data()
    sfreq = epochs.info['sfreq']
    freqs = np.arange(10, 20, 3).astype(float)

    # Check all combination of options
    for func, use_fft, zero_mean, output in product(
        (tfr_array_multitaper, tfr_array_morlet), (False, True), (False, True),
        ('complex', 'power', 'phase',
         'avg_power_itc', 'avg_power', 'itc')):
        # Check exception
        if (func == tfr_array_multitaper) and (output == 'phase'):
            pytest.raises(NotImplementedError, func, data, sfreq=sfreq,
                          freqs=freqs, output=output)
            continue

        # Check runs
        out = func(data, sfreq=sfreq, freqs=freqs, use_fft=use_fft,
                   zero_mean=zero_mean, n_cycles=2., output=output)
        # Check shapes: averaged / ITC outputs drop the leading axis
        shape = np.r_[data.shape[:2], len(freqs), data.shape[2]]
        if ('avg' in output) or ('itc' in output):
            assert_array_equal(shape[1:], out.shape)
        else:
            assert_array_equal(shape, out.shape)

        # Check types. Use the explicit NumPy scalar types: the ``np.complex``
        # and ``np.float`` aliases were removed from NumPy and raise
        # AttributeError there.
        if output in ('complex', 'avg_power_itc'):
            assert_equal(np.complex128, out.dtype)
        else:
            assert_equal(np.float64, out.dtype)
        assert (np.all(np.isfinite(out)))

    # Check errors params
    for _data in (None, 'foo', data[0]):
        pytest.raises(ValueError, _compute_tfr, _data, freqs, sfreq)
    for _freqs in (None, 'foo', [[0]]):
        pytest.raises(ValueError, _compute_tfr, data, _freqs, sfreq)
    for _sfreq in (None, 'foo'):
        pytest.raises(ValueError, _compute_tfr, data, freqs, _sfreq)
    for key in ('output', 'method', 'use_fft', 'decim', 'n_jobs'):
        for value in (None, 'foo'):
            kwargs = {key: value}  # FIXME pep8
            pytest.raises(ValueError, _compute_tfr, data, freqs, sfreq,
                          **kwargs)

    # No time_bandwidth param in morlet
    pytest.raises(ValueError, _compute_tfr, data, freqs, sfreq,
                  method='morlet', time_bandwidth=1)
    # No phase in multitaper XXX Check ?
    pytest.raises(NotImplementedError, _compute_tfr, data, freqs, sfreq,
                  method='multitaper', output='phase')

    # Inter-trial coherence tests: every value strictly between 0 and 1
    out = _compute_tfr(data, freqs, sfreq, output='itc', n_cycles=2.)
    assert np.sum(out >= 1) == 0
    assert np.sum(out <= 0) == 0

    # Check decim shapes
    # 2: multiple of len(times) even
    # 3: multiple odd
    # 8: not multiple, even
    # 9: not multiple, odd
    for decim in (2, 3, 8, 9, slice(0, 2), slice(1, 3), slice(2, 4)):
        # Integer decimation is a step; slices are applied to the time axis
        _decim = slice(None, None, decim) if isinstance(decim, int) else decim
        n_time = len(np.arange(data.shape[2])[_decim])
        shape = np.r_[data.shape[:2], len(freqs), n_time]
        for method in ('multitaper', 'morlet'):
            # Single trials
            out = _compute_tfr(data, freqs, sfreq, method=method, decim=decim,
                               n_cycles=2.)
            assert_array_equal(shape, out.shape)
            # Averages
            out = _compute_tfr(data, freqs, sfreq, method=method, decim=decim,
                               output='avg_power', n_cycles=2.)
            assert_array_equal(shape[1:], out.shape)
示例#9
0
def test_compute_tfr():
    """Test _compute_tfr function.

    Covers, in order: every combination of ``method`` / ``use_fft`` /
    ``zero_mean`` / ``output`` (shape, dtype, finiteness), equivalence with
    ``single_trial_power`` and ``cwt_morlet``, rejection of invalid
    arguments, the value range of inter-trial coherence, and output shapes
    under integer and slice decimation.
    """
    # Set parameters
    event_id = 1
    tmin = -0.2
    tmax = 0.498  # Allows exhaustive decimation testing

    # Setup for reading the raw data
    raw = io.read_raw_fif(raw_fname)
    events = read_events(event_fname)

    exclude = raw.info['bads'] + ['MEG 2443', 'EEG 053']  # bads + 2 more

    # picks MEG gradiometers
    picks = pick_types(raw.info, meg='grad', eeg=False,
                       stim=False, include=[], exclude=exclude)

    # Only the first two picked channels are used
    picks = picks[:2]
    epochs = Epochs(raw, events, event_id, tmin, tmax, picks=picks,
                    baseline=(None, 0))
    data = epochs.get_data()
    sfreq = epochs.info['sfreq']
    freqs = np.arange(10, 20, 3).astype(float)

    # Check all combination of options
    for method, use_fft, zero_mean, output in product(
        ('multitaper', 'morlet'), (False, True), (False, True),
        ('complex', 'power', 'phase',
         'avg_power_itc', 'avg_power', 'itc')):
        # Check exception
        if (method == 'multitaper') and (output == 'phase'):
            assert_raises(NotImplementedError, _compute_tfr, data, freqs,
                          sfreq, method=method, output=output)
            continue

        # Check runs
        out = _compute_tfr(data, freqs, sfreq, method=method,
                           use_fft=use_fft, zero_mean=zero_mean,
                           n_cycles=2., output=output)
        # Check shapes: averaged / ITC outputs drop the leading axis
        shape = np.r_[data.shape[:2], len(freqs), data.shape[2]]
        if ('avg' in output) or ('itc' in output):
            assert_array_equal(shape[1:], out.shape)
        else:
            assert_array_equal(shape, out.shape)

        # Check types. Use the explicit NumPy scalar types: the ``np.complex``
        # and ``np.float`` aliases were removed from NumPy and raise
        # AttributeError there.
        if output in ('complex', 'avg_power_itc'):
            assert_equal(np.complex128, out.dtype)
        else:
            assert_equal(np.float64, out.dtype)
        assert np.all(np.isfinite(out))

    # Check that functions are equivalent to
    # i) single_trial_power: X, shape (n_signals, n_chans, n_times)
    old_power = single_trial_power(data, sfreq, freqs, n_cycles=2.)
    new_power = _compute_tfr(data, freqs, sfreq, n_cycles=2.,
                             method='morlet', output='power')
    assert_array_almost_equal(old_power, new_power)
    # NOTE(review): the baseline-rescaled powers below are computed but never
    # compared -- presumably an assertion was intended; confirm before adding.
    old_power = single_trial_power(data, sfreq, freqs, n_cycles=2.,
                                   times=epochs.times, baseline=(-.100, 0),
                                   baseline_mode='ratio')
    new_power = rescale(new_power, epochs.times, (-.100, 0), 'ratio')

    # ii) cwt_morlet: X, shape (n_signals, n_times)
    old_complex = cwt_morlet(data[0], sfreq, freqs, n_cycles=2.)
    new_complex = _compute_tfr(data[[0]], freqs, sfreq, n_cycles=2.,
                               method='morlet', output='complex')
    assert_array_almost_equal(old_complex, new_complex[0])

    # Check errors params
    for _data in (None, 'foo', data[0]):
        assert_raises(ValueError, _compute_tfr, _data, freqs, sfreq)
    for _freqs in (None, 'foo', [[0]]):
        assert_raises(ValueError, _compute_tfr, data, _freqs, sfreq)
    for _sfreq in (None, 'foo'):
        assert_raises(ValueError, _compute_tfr, data, freqs, _sfreq)
    for key in ('output', 'method', 'use_fft', 'decim', 'n_jobs'):
        for value in (None, 'foo'):
            kwargs = {key: value}  # FIXME pep8
            assert_raises(ValueError, _compute_tfr, data, freqs, sfreq,
                          **kwargs)

    # No time_bandwidth param in morlet
    assert_raises(ValueError, _compute_tfr, data, freqs, sfreq,
                  method='morlet', time_bandwidth=1)
    # No phase in multitaper XXX Check ?
    assert_raises(NotImplementedError, _compute_tfr, data, freqs, sfreq,
                  method='multitaper', output='phase')

    # Inter-trial coherence tests: every value strictly between 0 and 1
    out = _compute_tfr(data, freqs, sfreq, output='itc', n_cycles=2.)
    assert np.sum(out >= 1) == 0
    assert np.sum(out <= 0) == 0

    # Check decim shapes
    # 2: multiple of len(times) even
    # 3: multiple odd
    # 8: not multiple, even
    # 9: not multiple, odd
    for decim in (2, 3, 8, 9, slice(0, 2), slice(1, 3), slice(2, 4)):
        # Integer decimation is a step; slices are applied to the time axis
        _decim = slice(None, None, decim) if isinstance(decim, int) else decim
        n_time = len(np.arange(data.shape[2])[_decim])
        shape = np.r_[data.shape[:2], len(freqs), n_time]
        for method in ('multitaper', 'morlet'):
            # Single trials
            out = _compute_tfr(data, freqs, sfreq, method=method,
                               decim=decim, n_cycles=2.)
            assert_array_equal(shape, out.shape)
            # Averages
            out = _compute_tfr(data, freqs, sfreq, method=method,
                               decim=decim, output='avg_power',
                               n_cycles=2.)
            assert_array_equal(shape[1:], out.shape)