def test_fft():
    """Check the shape of ``Multitaper.fft()`` for a 3-D input time series."""
    n_time_samples, n_trials, n_signals, n_windows = 100, 10, 2, 1
    time_series = np.zeros((n_time_samples, n_trials, n_signals))
    m = Multitaper(time_series=time_series)

    expected_shape = (
        n_windows, n_trials, m.tapers.shape[1], m.n_fft_samples, n_signals)
    # Shapes are integer tuples, so compare them exactly. np.allclose would
    # apply a float tolerance and raise (instead of failing the assert) when
    # the number of dimensions differs.
    assert m.fft().shape == expected_shape
def test_tapers():
    """Check taper shapes for computed tapers and for user-supplied tapers."""
    n_time_samples, n_trials, n_signals = 100, 10, 2
    time_series = np.zeros((n_time_samples, n_trials, n_signals))

    # Computed tapers are expected to be (n_time_samples, n_tapers).
    m = Multitaper(time_series, is_low_bias=False)
    assert m.tapers.shape == (n_time_samples, m.n_tapers)

    # User-supplied tapers keep the shape they were given.
    m = Multitaper(time_series, tapers=np.zeros((10, 3)))
    # Exact tuple comparison: shapes are integers, and np.allclose would
    # raise rather than fail on a dimensionality mismatch.
    assert m.tapers.shape == (10, 3)
def test_n_trials():
    """``n_trials`` matches the trials axis of 3-D input and is 1 for 2-D input."""
    n_time_samples, n_trials, n_signals = 100, 10, 2

    with_trials = np.zeros((n_time_samples, n_trials, n_signals))
    multitaper = Multitaper(time_series=with_trials)
    assert multitaper.n_trials == n_trials

    # A (n_time_samples, n_signals) array carries no trials axis.
    without_trials = np.zeros((n_time_samples, n_signals))
    multitaper = Multitaper(time_series=without_trials)
    assert multitaper.n_trials == 1
def test_time_window_step(sampling_frequency, time_window_step,
                          expected_step):
    """``time_window_step`` on the Multitaper equals ``expected_step``.

    The arguments are presumably supplied by pytest parametrization (the
    decorator is not visible here).
    """
    data_shape = (100, 10, 2)  # (n_time_samples, n_trials, n_signals)
    multitaper = Multitaper(
        time_series=np.zeros(data_shape),
        sampling_frequency=sampling_frequency,
        time_window_step=time_window_step,
    )
    assert multitaper.time_window_step == expected_step
def test_n_time_samples(sampling_frequency, time_window_duration,
                        expected_n_time_samples_per_window):
    """``n_time_samples_per_window`` equals the expected per-window count.

    The arguments are presumably supplied by pytest parametrization (the
    decorator is not visible here).
    """
    data_shape = (100, 10, 2)  # (n_time_samples, n_trials, n_signals)
    multitaper = Multitaper(
        time_series=np.zeros(data_shape),
        sampling_frequency=sampling_frequency,
        time_window_duration=time_window_duration,
    )
    observed = multitaper.n_time_samples_per_window
    assert observed == expected_n_time_samples_per_window
def test_frequency_resolution(time_halfbandwidth_product,
                              time_window_duration,
                              expected_frequency_resolution):
    """``frequency_resolution`` equals ``expected_frequency_resolution``.

    The arguments are presumably supplied by pytest parametrization (the
    decorator is not visible here).
    """
    data_shape = (100, 10, 2)  # (n_time_samples, n_trials, n_signals)
    multitaper = Multitaper(
        time_series=np.zeros(data_shape),
        time_halfbandwidth_product=time_halfbandwidth_product,
        time_window_duration=time_window_duration,
    )
    assert multitaper.frequency_resolution == expected_frequency_resolution
def test_n_samples_per_time_step(time_window_step, n_time_samples_per_step,
                                 expected_n_samples_per_time_step):
    """``n_time_samples_per_step`` resolves to the expected sample count.

    Uses a fixed 0.10 window duration. The arguments are presumably supplied
    by pytest parametrization (the decorator is not visible here).
    """
    data_shape = (100, 10, 2)  # (n_time_samples, n_trials, n_signals)
    multitaper = Multitaper(
        time_series=np.zeros(data_shape),
        time_window_duration=0.10,
        time_window_step=time_window_step,
        n_time_samples_per_step=n_time_samples_per_step,
    )
    observed = multitaper.n_time_samples_per_step
    assert observed == expected_n_samples_per_time_step
def test_frequencies():
    """FFT frequencies follow the standard FFT bin ordering."""
    data_shape = (100, 10, 2)  # (n_time_samples, n_trials, n_signals)
    multitaper = Multitaper(
        time_series=np.zeros(data_shape),
        sampling_frequency=1000,
        n_fft_samples=4,
    )
    # Bin spacing is sampling_frequency / n_fft_samples = 250; non-negative
    # bins come first, then the negative ones (same as np.fft.fftfreq).
    expected = np.array([0, 250, -500, -250])
    assert np.allclose(multitaper.frequencies, expected)
def test_time(time_window_duration):
    """Window start times step by ``time_window_duration`` from start_time.

    ``time_window_duration`` is presumably supplied by pytest parametrization
    (the decorator is not visible here).
    """
    sampling_frequency = 1500
    start_time, end_time = -2.4, 2.4
    n_trials, n_signals = 10, 2
    n_time_samples = int((end_time - start_time) * sampling_frequency) + 1
    time_series = np.zeros((n_time_samples, n_trials, n_signals))

    expected_time = np.arange(start_time, end_time, time_window_duration)
    # Keep the final window start only if that window ends (approximately)
    # exactly at end_time; otherwise it would run past the data.
    final_window_end = expected_time[-1] + time_window_duration
    if not np.allclose(final_window_end, end_time):
        expected_time = expected_time[:-1]

    multitaper = Multitaper(
        sampling_frequency=sampling_frequency,
        time_series=time_series,
        start_time=start_time,
        time_window_duration=time_window_duration,
    )
    assert np.allclose(multitaper.time, expected_time)
def multitaper_connectivity(time_series, sampling_frequency,
                            time_window_duration=None,
                            method='coherence_magnitude', signal_names=None,
                            squeeze=False, connectivity_kwargs=None,
                            **kwargs):
    """Transform time series to multitaper and calculate connectivity.

    Returns an xarray.DataArray with dimensions of
    ['Time', 'Frequency', 'Source', 'Target'], or ['Time', 'Frequency'] if
    ``squeeze=True``.

    NOTE(review): a second ``multitaper_connectivity`` definition appears later
    in this source; if both live in the same module, the later one replaces
    this one at import time.

    Parameters
    ----------
    time_series : array, shape (n_time_samples, n_trials, n_signals) or
                  (n_time_samples, n_signals)
    sampling_frequency : float
        Number of samples per time unit the signal(s) are recorded at.
    time_window_duration : float, optional
        Duration of sliding window in which to compute the fft. Defaults to
        the entire time if not set.
    method : str, optional
        Method used for connectivity calculation.
    signal_names : iterable of strings, optional
        Names of time series used to name the 'Source' and 'Target' axes of
        xarray.
    squeeze : bool, optional
        Whether to only take the first and last source and target time
        series. Only makes sense for one pair of signals and symmetrical
        measures.
    connectivity_kwargs : dict, optional
        Arguments to pass to connectivity calculation.
    **kwargs
        Additional keyword arguments forwarded to ``Multitaper``.

    Returns
    -------
    connectivity : xarray.DataArray
        Whatever ``connectivity_to_xarray`` returns for ``method``.
    """
    if connectivity_kwargs is None:
        connectivity_kwargs = {}  # avoid a shared mutable default
    m = Multitaper(time_series=time_series,
                   sampling_frequency=sampling_frequency,
                   time_window_duration=time_window_duration,
                   **kwargs)
    return connectivity_to_xarray(m, method, signal_names, squeeze,
                                  **connectivity_kwargs)
def test_n_tapers(time_halfbandwidth_product, expected_n_tapers):
    """``n_tapers`` equals the expected count for a half-bandwidth product.

    The arguments are presumably supplied by pytest parametrization (the
    decorator is not visible here).
    """
    data_shape = (100, 10, 2)  # (n_time_samples, n_trials, n_signals)
    multitaper = Multitaper(
        time_series=np.zeros(data_shape),
        time_halfbandwidth_product=time_halfbandwidth_product,
    )
    assert multitaper.n_tapers == expected_n_tapers
def multitaper_connectivity(time_series, sampling_frequency,
                            time_window_duration=None, method=None,
                            signal_names=None, squeeze=False,
                            connectivity_kwargs=None, **kwargs):
    """Transform time series to multitaper and calculate connectivity.

    Each requested connectivity measure becomes a data variable of an
    xarray.Dataset with dimensions ['Time', 'Frequency', 'Source', 'Target'],
    or ['Time', 'Frequency'] if ``squeeze=True``. If ``method`` is a single
    string, the corresponding xarray.DataArray is returned instead.

    Parameters
    ----------
    time_series : array, shape (n_time_samples, n_trials, n_signals) or
                  (n_time_samples, n_signals)
    sampling_frequency : float
        Number of samples per time unit the signal(s) are recorded at.
    time_window_duration : float, optional
        Duration of sliding window in which to compute the fft. Defaults to
        the entire time if not set.
    method : str or iterable of strings, optional
        Method(s) used for connectivity calculation. If None, every public
        ``Connectivity`` attribute except a fixed exclusion list is computed.
    signal_names : iterable of strings, optional
        Names of time series used to name the 'Source' and 'Target' axes of
        xarray.
    squeeze : bool, optional
        Whether to only take the first and last source and target time
        series. Only makes sense for one pair of signals and symmetrical
        measures.
    connectivity_kwargs : dict, optional
        Arguments to pass to connectivity calculation.
    **kwargs
        Additional keyword arguments forwarded to ``Multitaper``.

    Returns
    -------
    connectivities : xarray.Dataset or xarray.DataArray
        A Dataset with one data variable per successfully computed measure,
        or a single DataArray when ``method`` was passed as a string.

    Raises
    ------
    NotImplementedError
        If exactly one method was requested and it is not implemented for
        xarray output. When several methods are requested, unimplemented
        ones are skipped with a logged warning.
    """
    if connectivity_kwargs is None:
        connectivity_kwargs = {}

    # A bare string means "one measure, give me the DataArray directly".
    return_dataarray = isinstance(method, str)

    if method is None:
        # All public Connectivity attributes except these excluded names.
        # NOTE(review): presumably non-measure attributes or measures that
        # connectivity_to_xarray cannot handle here - confirm.
        # TODO is there a better way to get all Connectivity methods?
        bad_methods = {
            'delay', 'n_observations', 'frequencies', 'from_multitaper',
            'phase_slope_index',
        }
        method = [name for name in dir(Connectivity)
                  if not name.startswith('_') and name not in bad_methods]
    elif return_dataarray:
        method = [method]
    else:
        # Materialize any iterable (e.g. a generator) so len() and indexing
        # below work and the loop does not exhaust it.
        method = list(method)

    m = Multitaper(time_series=time_series,
                   sampling_frequency=sampling_frequency,
                   time_window_duration=time_window_duration,
                   **kwargs)

    cons = xr.Dataset()
    for this_method in method:
        try:
            con = connectivity_to_xarray(m, this_method, signal_names,
                                         squeeze, **connectivity_kwargs)
        except NotImplementedError:
            if len(method) == 1:
                raise  # the only requested measure: surface the error
            # One measure among many: skip it, but say so.
            logger.warning('%s is not implemented in xarray', this_method)
        else:
            cons[this_method] = con  # add as a data variable

    if return_dataarray and method[0] in cons:
        return cons[method[0]]
    return cons