コード例 #1
0
def test_fft():
    """Check that ``Multitaper.fft`` returns an array of the expected shape.

    The expected layout asserted here is
    ``(n_windows, n_trials, m.tapers.shape[1], m.n_fft_samples, n_signals)``
    for an all-zero input of shape (n_time_samples, n_trials, n_signals).
    """
    n_time_samples, n_trials, n_signals, n_windows = 100, 10, 2, 1
    time_series = np.zeros((n_time_samples, n_trials, n_signals))
    m = Multitaper(time_series=time_series)
    expected_shape = (n_windows, n_trials, m.tapers.shape[1],
                      m.n_fft_samples, n_signals)
    # Compare the shapes exactly. np.allclose broadcasts its arguments, so a
    # result with the wrong number of dimensions would either be compared
    # against the wrong entries or raise instead of failing the assertion.
    assert tuple(m.fft().shape) == expected_shape
コード例 #2
0
def test_tapers():
    """Check the shape of ``Multitaper.tapers``.

    Two cases are covered: tapers produced by ``Multitaper`` itself (with
    ``is_low_bias=False``) and tapers passed in explicitly by the caller.
    """
    n_time_samples, n_trials, n_signals = 100, 10, 2
    time_series = np.zeros((n_time_samples, n_trials, n_signals))

    # Tapers generated by Multitaper: one row per time sample, one column
    # per taper. Shapes are compared exactly because np.allclose broadcasts
    # and therefore cannot reliably detect a wrong number of dimensions.
    m = Multitaper(time_series, is_low_bias=False)
    assert tuple(m.tapers.shape) == (n_time_samples, m.n_tapers)

    # Tapers supplied by the caller must keep their shape.
    m = Multitaper(time_series, tapers=np.zeros((10, 3)))
    assert tuple(m.tapers.shape) == (10, 3)
コード例 #3
0
def test_n_trials():
    """Check ``n_trials`` for input with and without a trial axis."""
    n_time_samples, n_trials, n_signals = 100, 10, 2

    # Three-dimensional input: the middle axis carries the trials.
    with_trials = np.zeros((n_time_samples, n_trials, n_signals))
    assert Multitaper(time_series=with_trials).n_trials == n_trials

    # Two-dimensional input: expected to be reported as a single trial.
    without_trials = np.zeros((n_time_samples, n_signals))
    assert Multitaper(time_series=without_trials).n_trials == 1
コード例 #4
0
def test_time_window_step(sampling_frequency, time_window_step, expected_step):
    """Check that ``time_window_step`` matches the expected value.

    The arguments are presumably supplied by a parametrization that is not
    visible in this excerpt.
    """
    shape = (100, 10, 2)  # (n_time_samples, n_trials, n_signals)
    multitaper = Multitaper(
        time_series=np.zeros(shape),
        sampling_frequency=sampling_frequency,
        time_window_step=time_window_step,
    )
    assert multitaper.time_window_step == expected_step
コード例 #5
0
def test_n_time_samples(sampling_frequency, time_window_duration,
                        expected_n_time_samples_per_window):
    """Check ``n_time_samples_per_window`` against the expected value.

    The arguments are presumably supplied by a parametrization that is not
    visible in this excerpt.
    """
    shape = (100, 10, 2)  # (n_time_samples, n_trials, n_signals)
    multitaper = Multitaper(
        time_series=np.zeros(shape),
        sampling_frequency=sampling_frequency,
        time_window_duration=time_window_duration,
    )
    actual = multitaper.n_time_samples_per_window
    assert actual == expected_n_time_samples_per_window
コード例 #6
0
def test_frequency_resolution(time_halfbandwidth_product, time_window_duration,
                              expected_frequency_resolution):
    """Check ``frequency_resolution`` against the expected value.

    The arguments are presumably supplied by a parametrization that is not
    visible in this excerpt. The comparison is an exact equality.
    """
    shape = (100, 10, 2)  # (n_time_samples, n_trials, n_signals)
    multitaper = Multitaper(
        time_series=np.zeros(shape),
        time_halfbandwidth_product=time_halfbandwidth_product,
        time_window_duration=time_window_duration,
    )
    actual = multitaper.frequency_resolution
    assert actual == expected_frequency_resolution
コード例 #7
0
def test_n_samples_per_time_step(time_window_step, n_time_samples_per_step,
                                 expected_n_samples_per_time_step):
    """Check ``n_time_samples_per_step`` against the expected value.

    The window duration is fixed at 0.10; the remaining arguments are
    presumably supplied by a parametrization not visible in this excerpt.
    """
    shape = (100, 10, 2)  # (n_time_samples, n_trials, n_signals)
    multitaper = Multitaper(
        time_series=np.zeros(shape),
        time_window_duration=0.10,
        time_window_step=time_window_step,
        n_time_samples_per_step=n_time_samples_per_step,
    )
    actual = multitaper.n_time_samples_per_step
    assert actual == expected_n_samples_per_time_step
コード例 #8
0
def test_frequencies():
    """Check ``frequencies`` for 4 FFT samples at a sampling rate of 1000."""
    shape = (100, 10, 2)  # (n_time_samples, n_trials, n_signals)
    multitaper = Multitaper(
        time_series=np.zeros(shape),
        sampling_frequency=1000,
        n_fft_samples=4,
    )
    # Expected bins are in standard FFT order: zero, positive, then negative.
    expected_frequencies = np.array([0, 250, -500, -250])
    assert np.allclose(multitaper.frequencies, expected_frequencies)
コード例 #9
0
def test_time(time_window_duration):
    """Check ``Multitaper.time`` against the expected window start times.

    ``time_window_duration`` is presumably supplied by a parametrization
    that is not visible in this excerpt.
    """
    sampling_frequency = 1500
    start_time, end_time = -2.4, 2.4
    n_trials, n_signals = 10, 2
    n_time_samples = int((end_time - start_time) * sampling_frequency) + 1

    multitaper = Multitaper(
        time_series=np.zeros((n_time_samples, n_trials, n_signals)),
        sampling_frequency=sampling_frequency,
        start_time=start_time,
        time_window_duration=time_window_duration,
    )

    # Candidate window starts; the last one is dropped unless its window
    # ends at end_time.
    window_starts = np.arange(start_time, end_time, time_window_duration)
    last_window_end = window_starts[-1] + time_window_duration
    if not np.allclose(last_window_end, end_time):
        window_starts = window_starts[:-1]

    assert np.allclose(multitaper.time, window_starts)
コード例 #10
0
def multitaper_connectivity(time_series, sampling_frequency,
                            time_window_duration=None,
                            method='coherence_magnitude', signal_names=None,
                            squeeze=False, connectivity_kwargs=None, **kwargs):
    """
    Build a `Multitaper` from the time series and compute the connectivity
    measure named by `method`.

    The result is whatever `connectivity_to_xarray` returns: an
    xarray.DataArray with dimensions ['Time', 'Frequency', 'Source',
    'Target'], or ['Time', 'Frequency'] if squeeze=True.

    Parameters
    ----------
    time_series : array, shape (n_time_samples, n_trials, n_signals) or
                               (n_time_samples, n_signals)
    sampling_frequency : float
        Number of samples per time unit the signal(s) are recorded at.
    time_window_duration : float, optional
        Duration of sliding window in which to compute the fft. Defaults to
        the entire time if not set.
    method : str
        Method used for connectivity calculation.
    signal_names : iterable of strings
        Names of time series used to name the 'Source' and 'Target' axes of
        xarray.
    squeeze : bool
        Whether to only take the first and last source and target time series.
        Only makes sense for one pair of signals and symmetrical measures.
    connectivity_kwargs : dict
        Arguments to pass to connectivity calculation.
    **kwargs
        Additional keyword arguments forwarded to `Multitaper`.

    """
    extra_connectivity_args = (
        {} if connectivity_kwargs is None else connectivity_kwargs)
    multitaper = Multitaper(
        time_series=time_series,
        sampling_frequency=sampling_frequency,
        time_window_duration=time_window_duration,
        **kwargs)
    return connectivity_to_xarray(
        multitaper, method, signal_names, squeeze,
        **extra_connectivity_args)
コード例 #11
0
def test_n_tapers(time_halfbandwidth_product, expected_n_tapers):
    """Check ``n_tapers`` for a given time-halfbandwidth product.

    The arguments are presumably supplied by a parametrization that is not
    visible in this excerpt.
    """
    shape = (100, 10, 2)  # (n_time_samples, n_trials, n_signals)
    multitaper = Multitaper(
        time_series=np.zeros(shape),
        time_halfbandwidth_product=time_halfbandwidth_product,
    )
    assert multitaper.n_tapers == expected_n_tapers
コード例 #12
0
def multitaper_connectivity(time_series,
                            sampling_frequency,
                            time_window_duration=None,
                            method=None,
                            signal_names=None,
                            squeeze=False,
                            connectivity_kwargs=None,
                            **kwargs):
    """
    Build a `Multitaper` from the time series and compute one or more
    connectivity measures.

    Each measure is computed with `connectivity_to_xarray` and has dimensions
    ['Time', 'Frequency', 'Source', 'Target'], or ['Time', 'Frequency'] if
    squeeze=True.

    Parameters
    ----------
    time_series : array, shape (n_time_samples, n_trials, n_signals) or
                               (n_time_samples, n_signals)
    sampling_frequency : float
        Number of samples per time unit the signal(s) are recorded at.
    time_window_duration : float, optional
        Duration of sliding window in which to compute the fft. Defaults to
        the entire time if not set.
    method : str or iterable of strings, optional
        Method(s) used for connectivity calculation. If None, every public
        attribute of `Connectivity` is tried, except a fixed list of names
        that are not connectivity measures.
    signal_names : iterable of strings
        Names of time series used to name the 'Source' and 'Target' axes of
        xarray.
    squeeze : bool
        Whether to only take the first and last source and target time series.
        Only makes sense for one pair of signals and symmetrical measures.
    connectivity_kwargs : dict
        Arguments to pass to connectivity calculation.
    **kwargs
        Additional keyword arguments forwarded to `Multitaper`.

    Returns
    -------
    connectivities : xarray.Dataset or xarray.DataArray
        A Dataset with one data variable per computed measure. If `method`
        was given as a single string, only that measure is returned, as
        stored under its name in the Dataset.

    Raises
    ------
    NotImplementedError
        If exactly one method was requested and it is not implemented.
        When several methods are requested, unimplemented ones are skipped
        with a logged warning instead.

    """
    if connectivity_kwargs is None:
        connectivity_kwargs = {}

    # A single method name yields a single measure; anything else a Dataset.
    # isinstance (rather than an exact type comparison) also accepts str
    # subclasses, which would otherwise be iterated character by character.
    return_dataarray = isinstance(method, str)
    if method is None:
        # All implemented methods except internal
        # TODO is there a better way to get all Connectivity methods?
        bad_methods = {
            'delay', 'n_observations', 'frequencies', 'from_multitaper',
            'phase_slope_index'
        }
        methods = [
            name for name in dir(Connectivity)
            if not name.startswith('_') and name not in bad_methods
        ]
    elif return_dataarray:
        methods = [method]
    else:
        # Materialize so that generators and other one-shot or unsized
        # iterables work with the len() and indexing used below.
        methods = list(method)

    m = Multitaper(time_series=time_series,
                   sampling_frequency=sampling_frequency,
                   time_window_duration=time_window_duration,
                   **kwargs)
    cons = xr.Dataset()  # Initialize
    for this_method in methods:
        try:
            con = connectivity_to_xarray(m, this_method, signal_names, squeeze,
                                         **connectivity_kwargs)
        except NotImplementedError:
            if len(methods) == 1:
                raise  # If that was the only method requested
            # If one measure among many, just warn
            logger.warning(f'{this_method} is not implemented in xarray')
        else:
            cons[this_method] = con  # Add data variable

    if return_dataarray and methods[0] in cons:
        return cons[methods[0]]
    return cons