def complex_tfr(x, time, est_val=None, est_key=None, sf=600., foi=None,
                cycles=None, time_bandwidth=None, n_jobs=1, decim=10):
    """Compute a complex multitaper time-frequency transform of ``x``.

    Parameters
    ----------
    x : array
        Input data. A 2D array gets a leading singleton axis prepended
        before it is handed to ``_compute_tfr``; any other dimensionality
        is passed through unchanged.
    time : sequence
        Time axis belonging to ``x``; returned subsampled as
        ``time[::decim]``.
    est_val, est_key
        Opaque values that are returned untouched.
    sf : float
        Sampling frequency, forwarded as ``sfreq``.
    foi
        Frequencies of interest, forwarded as the second positional
        argument of ``_compute_tfr``.
    cycles, time_bandwidth, n_jobs, decim
        Forwarded to ``_compute_tfr`` as ``n_cycles``, ``time_bandwidth``,
        ``n_jobs`` and ``decim`` respectively.

    Returns
    -------
    tuple
        ``(tfr, time[::decim], est_val, est_key)`` where ``tfr`` is whatever
        ``_compute_tfr`` returns for ``output='complex'``.
    """
    # Promote a single 2D array to 3D by adding a leading axis.
    if len(x.shape) == 2:
        x = x[None, :, :]

    tfr_options = dict(
        sfreq=sf,
        method='multitaper',
        decim=decim,
        n_cycles=cycles,
        zero_mean=True,
        time_bandwidth=time_bandwidth,
        n_jobs=n_jobs,
        use_fft=True,
        output='complex',
    )
    coefficients = _compute_tfr(x, foi, **tfr_options)

    # Subsample the time axis with the same step as the decimation factor.
    decimated_time = time[::decim]
    return coefficients, decimated_time, est_val, est_key
def _extract_phase_and_amp(data_ph, data_am, sfreq, freqs_phase,
                           freqs_amp, scale=True):
    """Extract the phase and amplitude of two signals for PAC viz.

    Parameters
    ----------
    data_ph : array
        Signal the phase is extracted from. Per the original note, data
        should be shape (n_epochs, n_times).
    data_am : array
        Signal the amplitude is extracted from (same layout as ``data_ph``).
    sfreq : float
        Sampling frequency forwarded to ``_compute_tfr``.
    freqs_phase, freqs_amp
        Frequencies used for the phase and the amplitude transform.
    scale : bool
        When exactly ``True``, the amplitude is passed through
        ``sklearn.preprocessing.scale(amp, axis=1)`` before returning.

    Returns
    -------
    angle_ph : array
        ``np.angle`` of the phase-band transform, stacked with ``np.hstack``.
    band_ph_stacked : array
        Real part of the phase-band transform, stacked with ``np.hstack``.
    amp : array
        Squared magnitude of the amplitude-band transform, stacked with
        ``np.hstack`` and optionally scaled.
    """
    # Morlet transform to get complex representation. Each signal is wrapped
    # in a list so it gains a leading axis; ``[0]`` strips that axis again.
    band_ph = _compute_tfr([data_ph], freqs_phase, sfreq, method='morlet')[0]
    # The amplitude band must come from ``data_am`` (it was previously
    # computed from ``data_ph``, leaving ``data_am`` unused).
    band_amp = _compute_tfr([data_am], freqs_amp, sfreq, method='morlet')[0]

    # Calculate the phase/amplitude of relevant signals across epochs
    band_ph_stacked = np.hstack(np.real(band_ph))
    angle_ph = np.hstack(np.angle(band_ph))
    amp = np.hstack(np.abs(band_amp) ** 2)

    # Scale the amplitude for viz so low freqs don't dominate highs
    if scale is True:
        # Imported under an alias: importing it as ``scale`` would rebind
        # the ``scale`` argument and make the check above always false.
        from sklearn.preprocessing import scale as _sk_scale
        amp = _sk_scale(amp, axis=1)
    return angle_ph, band_ph_stacked, amp
def test_compute_tfr_correct(method, decim):
    """Check that the TFR peaks at the frequency of a pure sine input.

    A Hanning-windowed 50 Hz sine sampled at 1000 Hz is transformed with
    ``_compute_tfr`` over 10..110 Hz in 10 Hz steps; the frequency with the
    largest time-averaged magnitude must be the sine frequency.
    """
    sfreq = 1000.
    f = 50.
    times = np.arange(1000) / sfreq

    # Build the windowed test tone.
    tone = np.sin(2 * np.pi * f * t_) if False else np.sin(2 * np.pi * f * times)
    tone *= np.hanning(tone.size)
    # Add two leading singleton axes before handing the data over.
    data = tone[np.newaxis, np.newaxis]

    freqs = np.arange(10, 111, 10)
    assert f in freqs

    full_tfr = _compute_tfr(data, freqs, sfreq, method=method, decim=decim,
                            n_cycles=2)
    tfr = full_tfr[0, 0]
    mean_magnitude = np.abs(tfr).mean(-1)
    peak_index = np.argmax(mean_magnitude)
    assert freqs[peak_index] == f
def _extract_phase_and_amp(data_ph, data_am, sfreq, freqs_phase,
                           freqs_amp, scale=True):
    """Extract the phase and amplitude of two signals for PAC viz.

    Parameters
    ----------
    data_ph : array
        Signal the phase is extracted from. Per the original note, data
        should be shape (n_epochs, n_times).
    data_am : array
        Signal the amplitude is extracted from (same layout as ``data_ph``).
    sfreq : float
        Sampling frequency forwarded to ``_compute_tfr``.
    freqs_phase, freqs_amp
        Frequencies used for the phase and the amplitude transform.
    scale : bool
        When exactly ``True``, the amplitude is passed through
        ``sklearn.preprocessing.scale(amp, axis=1)`` before returning.

    Returns
    -------
    angle_ph : array
        ``np.angle`` of the phase-band transform, stacked with ``np.hstack``.
    band_ph_stacked : array
        Real part of the phase-band transform, stacked with ``np.hstack``.
    amp : array
        Squared magnitude of the amplitude-band transform, stacked with
        ``np.hstack`` and optionally scaled.
    """
    # Morlet transform to get complex representation. Each signal is wrapped
    # in a list so it gains a leading axis; ``[0]`` strips that axis again.
    band_ph = _compute_tfr([data_ph], freqs_phase, sfreq, method='morlet')[0]
    # The amplitude band must come from ``data_am`` (it was previously
    # computed from ``data_ph``, leaving ``data_am`` unused).
    band_amp = _compute_tfr([data_am], freqs_amp, sfreq, method='morlet')[0]

    # Calculate the phase/amplitude of relevant signals across epochs
    band_ph_stacked = np.hstack(np.real(band_ph))
    angle_ph = np.hstack(np.angle(band_ph))
    amp = np.hstack(np.abs(band_amp) ** 2)

    # Scale the amplitude for viz so low freqs don't dominate highs
    if scale is True:
        # Imported under an alias: importing it as ``scale`` would rebind
        # the ``scale`` argument and make the check above always false.
        from sklearn.preprocessing import scale as _sk_scale
        amp = _sk_scale(amp, axis=1)
    return angle_ph, band_ph_stacked, amp
def array_tfr(epochs, sf=600, foi=None, cycles=None, time_bandwidth=None,
              decim=10, n_jobs=4, output='power'):
    """Run MNE's multitaper ``_compute_tfr`` on an array of epochs.

    Parameters
    ----------
    epochs
        Data handed to ``_compute_tfr`` as its first positional argument.
    sf
        Sampling frequency, forwarded as ``sfreq``.
    foi
        Frequencies of interest, forwarded as the second positional
        argument.
    cycles, time_bandwidth, decim, n_jobs, output
        Forwarded as ``n_cycles``, ``time_bandwidth``, ``decim``,
        ``n_jobs`` and ``output``.

    Returns
    -------
    The unmodified return value of ``_compute_tfr``.
    """
    # Local import keeps the MNE dependency out of module import time.
    from mne.time_frequency.tfr import _compute_tfr

    multitaper_options = {
        'sfreq': sf,
        'method': 'multitaper',
        'decim': decim,
        'n_cycles': cycles,
        'zero_mean': True,
        'time_bandwidth': time_bandwidth,
        'n_jobs': n_jobs,
        'use_fft': True,
        'output': output,
    }
    return _compute_tfr(epochs, foi, **multitaper_options)
def tfr_split(data, processing_params):
    """Compute time-frequency power for ``data`` with ``_compute_tfr``.

    NOTE: despite the name, the whole array is processed in one call; no
    batching is performed here.

    Parameters
    ----------
    data : array
        Data handed to ``_compute_tfr`` as its first positional argument.
    processing_params : dict
        Must provide the keys ``'cwt_frequencies'``, ``'sfreq'``,
        ``'n_cycles'``, ``'post_decim'`` and ``'n_jobs'``.

    Returns
    -------
    The return value of ``_compute_tfr`` called with ``output='power'``.

    Raises
    ------
    KeyError
        If one of the required keys is missing from ``processing_params``.
    """
    # The frequencies are passed positionally (second argument) because the
    # keyword name of this parameter has not been stable across MNE
    # releases (``frequencies`` vs. ``freqs``).
    return _compute_tfr(data,
                        processing_params['cwt_frequencies'],
                        sfreq=processing_params['sfreq'],
                        n_cycles=processing_params['n_cycles'],
                        decim=processing_params['post_decim'],
                        n_jobs=processing_params['n_jobs'],
                        output='power')
def test_compute_tfr():
    """Exercise ``_compute_tfr`` and the ``tfr_array_*`` wrappers.

    Covers: every combination of wrapper / ``use_fft`` / ``zero_mean`` /
    ``output`` (shape, dtype, finiteness), rejection of invalid arguments,
    the open interval (0, 1) of inter-trial coherence values, and output
    shapes for integer and slice decimation.
    """
    # Epoching parameters; tmax = 0.498 allows exhaustive decimation testing.
    event_id = 1
    tmin = -0.2
    tmax = 0.498

    # Load the raw recording and its events.
    raw = read_raw_fif(raw_fname)
    events = read_events(event_fname)
    exclude = raw.info['bads'] + ['MEG 2443', 'EEG 053']  # bads + 2 more

    # Keep only the first two MEG gradiometers.
    picks = pick_types(raw.info, meg='grad', eeg=False, stim=False,
                       include=[], exclude=exclude)
    picks = picks[:2]

    epochs = Epochs(raw, events, event_id, tmin, tmax, picks=picks)
    data = epochs.get_data()
    sfreq = epochs.info['sfreq']
    freqs = np.arange(10, 20, 3).astype(float)

    # Expected shape of an undecimated single-trial result.
    full_shape = data.shape[:2] + (len(freqs), data.shape[2])

    wrappers = (tfr_array_multitaper, tfr_array_morlet)
    flags = (False, True)
    outputs = ('complex', 'power', 'phase', 'avg_power_itc', 'avg_power',
               'itc')

    # Check all combinations of options.
    for func, use_fft, zero_mean, output in product(wrappers, flags, flags,
                                                    outputs):
        # Multitaper cannot return phase.
        if func == tfr_array_multitaper and output == 'phase':
            with pytest.raises(NotImplementedError):
                func(data, sfreq=sfreq, freqs=freqs, output=output)
            continue

        # Check that it runs.
        out = func(data, sfreq=sfreq, freqs=freqs, use_fft=use_fft,
                   zero_mean=zero_mean, n_cycles=2., output=output)

        # Averaged outputs drop the leading (epochs) axis.
        averaged = ('avg' in output) or ('itc' in output)
        expected = full_shape[1:] if averaged else full_shape
        assert_array_equal(expected, out.shape)

        # Check dtypes.
        if output in ('complex', 'avg_power_itc'):
            assert_equal(np.complex128, out.dtype)
        else:
            assert_equal(np.float64, out.dtype)
        assert np.all(np.isfinite(out))

    # Invalid positional arguments must be rejected.
    for _data in (None, 'foo', data[0]):
        with pytest.raises(ValueError):
            _compute_tfr(_data, freqs, sfreq)
    for _freqs in (None, 'foo', [[0]]):
        with pytest.raises(ValueError):
            _compute_tfr(data, _freqs, sfreq)
    for _sfreq in (None, 'foo'):
        with pytest.raises(ValueError):
            _compute_tfr(data, freqs, _sfreq)

    # Invalid keyword arguments must be rejected as well.
    for key in ('output', 'method', 'use_fft', 'decim', 'n_jobs'):
        for value in (None, 'foo'):
            with pytest.raises(ValueError):
                _compute_tfr(data, freqs, sfreq, **{key: value})

    with pytest.raises(ValueError, match='above Nyquist'):
        _compute_tfr(data, [sfreq], sfreq)

    # No time_bandwidth param in morlet
    with pytest.raises(ValueError):
        _compute_tfr(data, freqs, sfreq, method='morlet', time_bandwidth=1)

    # No phase in multitaper XXX Check ?
    with pytest.raises(NotImplementedError):
        _compute_tfr(data, freqs, sfreq, method='multitaper',
                     output='phase')

    # Inter-trial coherence must lie strictly between 0 and 1.
    out = _compute_tfr(data, freqs, sfreq, output='itc', n_cycles=2.)
    assert np.sum(out >= 1) == 0
    assert np.sum(out <= 0) == 0

    # Check decim shapes
    # 2: multiple of len(times) even
    # 3: multiple odd
    # 8: not multiple, even
    # 9: not multiple, odd
    decimations = (2, 3, 8, 9, slice(0, 2), slice(1, 3), slice(2, 4))
    for decim in decimations:
        # An integer factor is equivalent to slicing with that step.
        if isinstance(decim, int):
            _decim = slice(None, None, decim)
        else:
            _decim = decim
        n_time = len(np.arange(data.shape[2])[_decim])
        decim_shape = data.shape[:2] + (len(freqs), n_time)

        for method in ('multitaper', 'morlet'):
            # Single trials
            out = _compute_tfr(data, freqs, sfreq, method=method,
                               decim=decim, n_cycles=2.)
            assert_array_equal(decim_shape, out.shape)

            # Averages
            out = _compute_tfr(data, freqs, sfreq, method=method,
                               decim=decim, output='avg_power', n_cycles=2.)
            assert_array_equal(decim_shape[1:], out.shape)
def test_compute_tfr():
    """Test _compute_tfr function.

    Covers: every combination of wrapper / ``use_fft`` / ``zero_mean`` /
    ``output`` (shape, dtype, finiteness), rejection of invalid arguments,
    the open interval (0, 1) of inter-trial coherence values, and output
    shapes for integer and slice decimation.
    """
    # Set parameters
    event_id = 1
    tmin = -0.2
    tmax = 0.498  # Allows exhaustive decimation testing

    # Setup for reading the raw data
    raw = read_raw_fif(raw_fname)
    events = read_events(event_fname)

    exclude = raw.info['bads'] + ['MEG 2443', 'EEG 053']  # bads + 2 more

    # picks MEG gradiometers
    picks = pick_types(raw.info, meg='grad', eeg=False,
                       stim=False, include=[], exclude=exclude)

    picks = picks[:2]
    epochs = Epochs(raw, events, event_id, tmin, tmax, picks=picks)
    data = epochs.get_data()
    sfreq = epochs.info['sfreq']
    freqs = np.arange(10, 20, 3).astype(float)

    # Check all combination of options
    for func, use_fft, zero_mean, output in product(
        (tfr_array_multitaper, tfr_array_morlet), (False, True),
        (False, True),
        ('complex', 'power', 'phase',
         'avg_power_itc', 'avg_power', 'itc')):
        # Check exception
        if (func == tfr_array_multitaper) and (output == 'phase'):
            pytest.raises(NotImplementedError, func, data, sfreq=sfreq,
                          freqs=freqs, output=output)
            continue

        # Check runs
        out = func(data, sfreq=sfreq, freqs=freqs, use_fft=use_fft,
                   zero_mean=zero_mean, n_cycles=2., output=output)
        # Check shapes; averaged outputs drop the leading (epochs) axis.
        shape = np.r_[data.shape[:2], len(freqs), data.shape[2]]
        if ('avg' in output) or ('itc' in output):
            assert_array_equal(shape[1:], out.shape)
        else:
            assert_array_equal(shape, out.shape)

        # Check types. The explicit sized dtypes are used because the
        # ``np.complex`` / ``np.float`` aliases no longer exist in current
        # NumPy releases.
        if output in ('complex', 'avg_power_itc'):
            assert_equal(np.complex128, out.dtype)
        else:
            assert_equal(np.float64, out.dtype)
        assert (np.all(np.isfinite(out)))

    # Check errors params
    for _data in (None, 'foo', data[0]):
        pytest.raises(ValueError, _compute_tfr, _data, freqs, sfreq)
    for _freqs in (None, 'foo', [[0]]):
        pytest.raises(ValueError, _compute_tfr, data, _freqs, sfreq)
    for _sfreq in (None, 'foo'):
        pytest.raises(ValueError, _compute_tfr, data, freqs, _sfreq)
    for key in ('output', 'method', 'use_fft', 'decim', 'n_jobs'):
        for value in (None, 'foo'):
            kwargs = {key: value}  # FIXME pep8
            pytest.raises(ValueError, _compute_tfr, data, freqs, sfreq,
                          **kwargs)

    # No time_bandwidth param in morlet
    pytest.raises(ValueError, _compute_tfr, data, freqs, sfreq,
                  method='morlet', time_bandwidth=1)
    # No phase in multitaper XXX Check ?
    pytest.raises(NotImplementedError, _compute_tfr, data, freqs, sfreq,
                  method='multitaper', output='phase')

    # Inter-trial coherence tests: values must lie strictly in (0, 1).
    out = _compute_tfr(data, freqs, sfreq, output='itc', n_cycles=2.)
    assert np.sum(out >= 1) == 0
    assert np.sum(out <= 0) == 0

    # Check decim shapes
    # 2: multiple of len(times) even
    # 3: multiple odd
    # 8: not multiple, even
    # 9: not multiple, odd
    for decim in (2, 3, 8, 9, slice(0, 2), slice(1, 3), slice(2, 4)):
        # An integer factor is equivalent to slicing with that step.
        _decim = slice(None, None, decim) if isinstance(decim, int) else decim
        n_time = len(np.arange(data.shape[2])[_decim])
        shape = np.r_[data.shape[:2], len(freqs), n_time]
        for method in ('multitaper', 'morlet'):
            # Single trials
            out = _compute_tfr(data, freqs, sfreq, method=method,
                               decim=decim, n_cycles=2.)
            assert_array_equal(shape, out.shape)
            # Averages
            out = _compute_tfr(data, freqs, sfreq, method=method,
                               decim=decim, output='avg_power',
                               n_cycles=2.)
            assert_array_equal(shape[1:], out.shape)
def test_compute_tfr():
    """Test _compute_tfr function.

    Covers: every combination of ``method`` / ``use_fft`` / ``zero_mean`` /
    ``output`` (shape, dtype, finiteness), agreement with the older
    ``single_trial_power`` and ``cwt_morlet`` functions, rejection of
    invalid arguments, the open interval (0, 1) of inter-trial coherence
    values, and output shapes for integer and slice decimation.
    """
    # Set parameters
    event_id = 1
    tmin = -0.2
    tmax = 0.498  # Allows exhaustive decimation testing

    # Setup for reading the raw data
    raw = io.read_raw_fif(raw_fname)
    events = read_events(event_fname)

    exclude = raw.info['bads'] + ['MEG 2443', 'EEG 053']  # bads + 2 more

    # picks MEG gradiometers
    picks = pick_types(raw.info, meg='grad', eeg=False,
                       stim=False, include=[], exclude=exclude)

    picks = picks[:2]
    epochs = Epochs(raw, events, event_id, tmin, tmax, picks=picks,
                    baseline=(None, 0))
    data = epochs.get_data()
    sfreq = epochs.info['sfreq']
    freqs = np.arange(10, 20, 3).astype(float)

    # Check all combination of options
    for method, use_fft, zero_mean, output in product(
        ('multitaper', 'morlet'), (False, True), (False, True),
        ('complex', 'power', 'phase',
         'avg_power_itc', 'avg_power', 'itc')):
        # Check exception
        if (method == 'multitaper') and (output == 'phase'):
            assert_raises(NotImplementedError, _compute_tfr, data, freqs,
                          sfreq, method=method, output=output)
            continue

        # Check runs
        out = _compute_tfr(data, freqs, sfreq, method=method,
                           use_fft=use_fft, zero_mean=zero_mean,
                           n_cycles=2., output=output)
        # Check shapes; averaged outputs drop the leading (epochs) axis.
        shape = np.r_[data.shape[:2], len(freqs), data.shape[2]]
        if ('avg' in output) or ('itc' in output):
            assert_array_equal(shape[1:], out.shape)
        else:
            assert_array_equal(shape, out.shape)

        # Check types. The explicit sized dtypes are used because the
        # ``np.complex`` / ``np.float`` aliases no longer exist in current
        # NumPy releases.
        if output in ('complex', 'avg_power_itc'):
            assert_equal(np.complex128, out.dtype)
        else:
            assert_equal(np.float64, out.dtype)
        assert_true(np.all(np.isfinite(out)))

    # Check that functions are equivalent to
    # i) single_trial_power: X, shape (n_signals, n_chans, n_times)
    old_power = single_trial_power(data, sfreq, freqs, n_cycles=2.)
    new_power = _compute_tfr(data, freqs, sfreq, n_cycles=2.,
                             method='morlet', output='power')
    assert_array_almost_equal(old_power, new_power)
    # NOTE(review): the baseline-rescaled results below are computed but
    # never compared; presumably an ``assert_array_almost_equal`` was
    # intended here -- confirm before adding one.
    old_power = single_trial_power(data, sfreq, freqs, n_cycles=2.,
                                   times=epochs.times, baseline=(-.100, 0),
                                   baseline_mode='ratio')
    new_power = rescale(new_power, epochs.times, (-.100, 0), 'ratio')

    # ii) cwt_morlet: X, shape (n_signals, n_times)
    old_complex = cwt_morlet(data[0], sfreq, freqs, n_cycles=2.)
    new_complex = _compute_tfr(data[[0]], freqs, sfreq, n_cycles=2.,
                               method='morlet', output='complex')
    assert_array_almost_equal(old_complex, new_complex[0])

    # Check errors params
    for _data in (None, 'foo', data[0]):
        assert_raises(ValueError, _compute_tfr, _data, freqs, sfreq)
    for _freqs in (None, 'foo', [[0]]):
        assert_raises(ValueError, _compute_tfr, data, _freqs, sfreq)
    for _sfreq in (None, 'foo'):
        assert_raises(ValueError, _compute_tfr, data, freqs, _sfreq)
    for key in ('output', 'method', 'use_fft', 'decim', 'n_jobs'):
        for value in (None, 'foo'):
            kwargs = {key: value}  # FIXME pep8
            assert_raises(ValueError, _compute_tfr, data, freqs, sfreq,
                          **kwargs)

    # No time_bandwidth param in morlet
    assert_raises(ValueError, _compute_tfr, data, freqs, sfreq,
                  method='morlet', time_bandwidth=1)
    # No phase in multitaper XXX Check ?
    assert_raises(NotImplementedError, _compute_tfr, data, freqs, sfreq,
                  method='multitaper', output='phase')

    # Inter-trial coherence tests: values must lie strictly in (0, 1).
    out = _compute_tfr(data, freqs, sfreq, output='itc', n_cycles=2.)
    assert_true(np.sum(out >= 1) == 0)
    assert_true(np.sum(out <= 0) == 0)

    # Check decim shapes
    # 2: multiple of len(times) even
    # 3: multiple odd
    # 8: not multiple, even
    # 9: not multiple, odd
    for decim in (2, 3, 8, 9, slice(0, 2), slice(1, 3), slice(2, 4)):
        # An integer factor is equivalent to slicing with that step.
        _decim = slice(None, None, decim) if isinstance(decim, int) else decim
        n_time = len(np.arange(data.shape[2])[_decim])
        shape = np.r_[data.shape[:2], len(freqs), n_time]
        for method in ('multitaper', 'morlet'):
            # Single trials
            out = _compute_tfr(data, freqs, sfreq, method=method,
                               decim=decim, n_cycles=2.)
            assert_array_equal(shape, out.shape)
            # Averages
            out = _compute_tfr(data, freqs, sfreq, method=method,
                               decim=decim, output='avg_power',
                               n_cycles=2.)
            assert_array_equal(shape[1:], out.shape)