예제 #1
0
    def test_floatsource(self):
        """
        Compare `best_match` on float-valued sources with brute-force references

        Every pairing of the two float sources with the four selection
        fixtures is run through the following scenarios:

        1. plain matching: values/indices must equal the closest source
           entries found by exhaustive search
        2. ``squash_duplicates=True``: values/indices must equal the
           reference with repeated values removed (first occurrences kept,
           original order preserved)
        3. ``tol=1e-6``: a `SPYValueError` is expected
        4. ``span=True`` using the selection's extrema as window
        """

        def first_hits(haystack, needles):
            # Index of the first entry in `haystack` equal to each needle
            return np.array(
                [np.nonzero(haystack == needle)[0][0] for needle in needles])

        for source in (self.floatSource, self.randFloatSource):
            for selection in (self.intSelection, self.randIntSelection,
                              self.floatSelection, self.sortFloatSelection):

                # Scenario 1: exhaustive nearest-neighbor search as reference
                refVal = np.array([
                    source[np.abs(source - target).argmin()]
                    for target in selection
                ])
                refIdx = first_hits(source, refVal)
                val, idx = best_match(source, selection)
                assert np.array_equal(val, refVal)
                assert np.array_equal(idx, refIdx)

                # Scenario 2: keep only the first occurrence of each matched
                # value; sorting the first-occurrence positions restores the
                # order in which values appear in the reference
                val, idx = best_match(source,
                                      selection,
                                      squash_duplicates=True)
                keep = np.sort(np.unique(refVal, return_index=True)[1])
                assert np.array_equal(val, refVal[keep])
                assert np.array_equal(idx, refIdx[keep])

                # Scenario 3: a tolerance of 1e-6 is expected to be rejected
                with pytest.raises(SPYValueError):
                    best_match(source, selection, tol=1e-6)

                # Scenario 4: window spanned by the selection's extrema
                lower, upper = selection.min(), selection.max()
                val, idx = best_match(source, [lower, upper], span=True)
                refVal = np.array(
                    [entry for entry in source if lower <= entry <= upper])
                refIdx = first_hits(source, refVal)
                # NOTE(review): the visible code ends here without comparing
                # `val`/`idx` to the references - confirm whether assertions
                # follow in the full file
예제 #2
0
    def _get_time(self, trials, toi=None, toilim=None):
        """
        Translate a time-selection into by-trial sample indices

        Parameters
        ----------
        trials : list
            Trial-indices the selection is applied to
        toi : None or list
            Discrete time-points to select (matched against ``self.time`` of
            each trial). Only considered if `toilim` is `None`.
        toilim : None or list
            Time-window to select (matched against ``self.time`` of each
            trial). Takes precedence over `toi`.

        Returns
        -------
        timing : list
            One selector per entry of `trials`. A selector is a `slice` if
            the matched sample indices form a contiguous block (for `toilim`:
            if more than one sample was matched; for `toi`: if consecutive
            matches are exactly one sample apart), otherwise a list of
            sample indices. If both `toi` and `toilim` are `None`, every
            selector is the universal ``slice(None)``.

        Notes
        -----
        This method is an auxiliary routine intended to be used only by
        :class:`syncopy.datatype.base_data.Selector` objects. Input
        sanitization and error checking are therefore left to
        :class:`syncopy.datatype.base_data.Selector` and not repeated here.

        See also
        --------
        syncopy.datatype.base_data.Selector : Syncopy data selectors
        """
        # Nothing to select: use everything in every trial
        if toilim is None and toi is None:
            return [slice(None)] * len(trials)

        timing = []

        # Window selection: all samples inside `toilim`
        if toilim is not None:
            for trlno in trials:
                _, matched = best_match(self.time[trlno], toilim, span=True)
                samples = matched.tolist()
                timing.append(slice(samples[0], samples[-1] + 1, 1)
                              if len(samples) > 1 else samples)
            return timing

        # Point selection: samples closest to the entries of `toi`
        for trlno in trials:
            _, matched = best_match(self.time[trlno], toi)
            samples = matched.tolist()
            # Collapse to a slice only if the matches are strictly consecutive
            if len(samples) > 1 and (np.diff(samples) == 1).all():
                samples = slice(samples[0], samples[-1] + 1, 1)
            timing.append(samples)
        return timing
예제 #3
0
    def _get_freq(self, foi=None, foilim=None):
        """
        Translate a frequency-selection into indices of ``self.freq``

        Parameters
        ----------
        foi : None or list
            Discrete frequencies to select (matched against ``self.freq``).
            Only considered if `foilim` is `None`.
        foilim : None or list
            Frequency-window to select (matched against ``self.freq``).
            Takes precedence over `foi`.

        Returns
        -------
        selFreq : slice or list
            A `slice` if the matched indices form a contiguous block (for
            `foilim`: if more than one frequency was matched; for `foi`: if
            consecutive matches are exactly one index apart), otherwise a
            list of indices. If both `foi` and `foilim` are `None`, the
            universal ``slice(None)`` is returned.

        Notes
        -----
        Error checking is performed by the `Selector` class and not
        repeated here.
        """
        # Nothing to select: use all frequencies
        if foilim is None and foi is None:
            return slice(None)

        # Window selection: all frequencies inside `foilim`
        if foilim is not None:
            _, matched = best_match(self.freq, foilim, span=True)
            indices = matched.tolist()
            if len(indices) > 1:
                return slice(indices[0], indices[-1] + 1, 1)
            return indices

        # Point selection: frequencies closest to the entries of `foi`
        _, matched = best_match(self.freq, foi)
        indices = matched.tolist()
        # Collapse to a slice only if the matches are strictly consecutive
        if len(indices) > 1 and (np.diff(indices) == 1).all():
            return slice(indices[0], indices[-1] + 1, 1)
        return indices
예제 #4
0
def freqanalysis(data, method='mtmfft', output='fourier',
                 keeptrials=True, foi=None, foilim=None,
                 pad_to_length=None, polyremoval=None,
                 taper="hann", tapsmofrq=None, nTaper=None, keeptapers=False,
                 toi="all", t_ftimwin=None, wavelet="Morlet", width=6, order=None,
                 order_max=None, order_min=1, c_1=3, adaptive=False,
                 out=None, **kwargs):
    """
    Perform (time-)frequency analysis of Syncopy :class:`~syncopy.AnalogData` objects

    **Usage Summary**

    Options available in all analysis methods:

    * **output** : one of :data:`~syncopy.specest.const_def.availableOutputs`;
      return power spectra, complex Fourier spectra or absolute values.
    * **foi**/**foilim** : frequencies of interest; either array of frequencies or
      frequency window (not both)
    * **keeptrials** : return individual trials or grand average
    * **polyremoval** : de-trending method to use (0 = mean, 1 = linear or `None`)

    List of available analysis methods and respective distinct options:

    "mtmfft" : (Multi-)tapered Fourier transform
        Perform frequency analysis on time-series trial data using either a single
        taper window (Hanning) or many tapers based on the discrete prolate
        spheroidal sequence (DPSS) that maximize energy concentration in the main
        lobe.

        * **taper** : one of :data:`~syncopy.shared.const_def.availableTapers`
        * **tapsmofrq** : spectral smoothing box for slepian tapers (in Hz)
        * **nTaper** : number of orthogonal tapers for slepian tapers
        * **keeptapers** : return individual tapers or average
        * **pad_to_length**: either pad to an absolute length or set to `'nextpow2'`

    "mtmconvol" : (Multi-)tapered sliding window Fourier transform
        Perform time-frequency analysis on time-series trial data based on a sliding
        window short-time Fourier transform using either a single Hanning taper or
        multiple DPSS tapers.

        * **taper** : one of :data:`~syncopy.specest.const_def.availableTapers`
        * **tapsmofrq** : spectral smoothing box for slepian tapers (in Hz)
        * **nTaper** : number of orthogonal tapers for slepian tapers
        * **keeptapers** : return individual tapers or average
        * **toi** : time-points of interest; can be either an array representing
          analysis window centroids (in sec), a scalar between 0 and 1 encoding
          the percentage of overlap between adjacent windows or "all" to center
          a window on every sample in the data.
        * **t_ftimwin** : sliding window length (in sec)

    "wavelet" : (Continuous non-orthogonal) wavelet transform
        Perform time-frequency analysis on time-series trial data using a non-orthogonal
        continuous wavelet transform.

        * **wavelet** : one of :data:`~syncopy.specest.const_def.availableWavelets`
        * **toi** : time-points of interest; can be either an array representing
          time points (in sec) or "all"(pre-trimming and subsampling of results)
        * **width** : Nondimensional frequency constant of Morlet wavelet function (>= 6)
        * **order** : Order of Paul wavelet function (>= 4) or derivative order
          of real-valued DOG wavelets (2 = mexican hat)

    "superlet" : Superlet transform
        Perform time-frequency analysis on time-series trial data using
        the super-resolution superlet transform (SLT) from [Moca2021]_.

        * **order_max** : Maximal order of the superlet
        * **order_min** : Minimal order of the superlet
        * **c_1** : Number of cycles of the base Morlet wavelet
        * **adaptive** : If set to `True` perform fractional adaptive SLT,
          otherwise perform multiplicative SLT

    **Full documentation below**

    Parameters
    ----------
    data : `~syncopy.AnalogData`
        A non-empty Syncopy :class:`~syncopy.datatype.AnalogData` object
    method : str
        Spectral estimation method, one of :data:`~syncopy.specest.const_def.availableMethods`
        (see below).
    output : str
        Output of spectral estimation. One of :data:`~syncopy.specest.const_def.availableOutputs` (see below);
        use `'pow'` for power spectrum (:obj:`numpy.float32`), `'fourier'` for complex
        Fourier coefficients (:obj:`numpy.complex64`) or `'abs'` for absolute
        values (:obj:`numpy.float32`).
    keeptrials : bool
        If `True` spectral estimates of individual trials are returned, otherwise
        results are averaged across trials.
    foi : array-like or None
        Frequencies of interest (Hz) for output. If desired frequencies cannot be
        matched exactly, the closest possible frequencies are used. If `foi` is `None`
        or ``foi = "all"``, all attainable frequencies (i.e., zero to Nyquist / 2)
        are selected.
    foilim : array-like (floats [fmin, fmax]) or None or "all"
        Frequency-window ``[fmin, fmax]`` (in Hz) of interest. Window
        specifications must be sorted (e.g., ``[90, 70]`` is invalid) and not NaN
        but may be unbounded (e.g., ``[-np.inf, 60.5]`` is valid). Edges `fmin`
        and `fmax` are included in the selection. If `foilim` is `None` or
        ``foilim = "all"``, all frequencies are selected.
    pad_to_length : int, None or 'nextpow2'
        Padding of the input data, if set to a number pads all trials
        to this absolute length. For instance ``pad_to_length = 2000`` pads all
        trials to an absolute length of 2000 samples, if and only if the longest
        trial contains at maximum 2000 samples.
        Alternatively if all trials have the same initial lengths
        setting `pad_to_length='nextpow2'` pads all trials to
        the next power of two.
        If `None` and trials have unequal lengths all trials are padded to match
        the longest trial.
    polyremoval : int or None
        Order of polynomial used for de-trending data in the time domain prior
        to spectral analysis. A value of 0 corresponds to subtracting the mean
        ("de-meaning"), ``polyremoval = 1`` removes linear trends (subtracting the
        least squares fit of a linear polynomial).
        If `polyremoval` is `None` (the default), no de-trending is performed.
        Note that for spectral estimation de-meaning (``polyremoval = 0``) is
        very advisable.
    taper : str
        Only valid if `method` is `'mtmfft'` or `'mtmconvol'`. Windowing function,
        one of :data:`~syncopy.specest.const_def.availableTapers` (see below).
    tapsmofrq : float
        Only valid if `method` is `'mtmfft'` or `'mtmconvol'` and `taper` is `'dpss'`.
        The amount of spectral smoothing through  multi-tapering (Hz).
        Note that smoothing frequency specifications are one-sided,
        i.e., 4 Hz smoothing means plus-minus 4 Hz, i.e., a 8 Hz smoothing box.
    nTaper : int or None
        Only valid if `method` is `'mtmfft'` or `'mtmconvol'` and `taper='dpss'`.
        Number of orthogonal tapers to use. It is not recommended to set the number
        of tapers manually! Leave at `None` for the optimal number to be set automatically.
    keeptapers : bool
        Only valid if `method` is `'mtmfft'` or `'mtmconvol'`.
        If `True`, return spectral estimates for each taper.
        Otherwise power spectrum is averaged across tapers,
        if and only if `output` is `pow`.
    toi : float or array-like or "all"
        **Mandatory input** for time-frequency analysis methods (`method` is either
        `"mtmconvol"` or `"wavelet"` or `"superlet"`).
        If `toi` is scalar, it must be a value between 0 and 1 indicating the
        percentage of overlap between time-windows specified by `t_ftimwin` (only
        valid if `method` is `'mtmconvol'`).
        If `toi` is an array it explicitly selects the centroids of analysis
        windows (in seconds), if `toi` is `"all"`, analysis windows are centered
        on all samples in the data for `method="mtmconvol"`. For wavelet based
        methods (`"wavelet"` or `"superlet"`) toi needs to be either an
        equidistant array of time points or "all".
    t_ftimwin : positive float
        Only valid if `method` is `'mtmconvol'`. Sliding window length (in seconds).
    wavelet : str
        Only valid if `method` is `'wavelet'`. Wavelet function to use, one of
        :data:`~syncopy.specest.const_def.availableWavelets` (see below).
    width : positive float
        Only valid if `method` is `'wavelet'` and `wavelet` is `'Morlet'`. Nondimensional
        frequency constant of Morlet wavelet function. This number should be >= 6,
        which corresponds to 6 cycles within the analysis window to ensure sufficient
        spectral sampling.
    order : positive int
        Only valid if `method` is `'wavelet'` and `wavelet` is `'Paul'` or `'DOG'`. Order
        of the wavelet function. If `wavelet` is `'Paul'`, `order` should be chosen
        >= 4 to ensure that the analysis window contains at least a single oscillation.
        At an order of 40, the Paul wavelet  exhibits about the same number of cycles
        as the Morlet wavelet with a `width` of 6.
        All other supported wavelets functions are *real-valued* derivatives of
        Gaussians (DOGs). Hence, if `wavelet` is `'DOG'`, `order` represents the derivative order.
        The special case of a second order DOG yields a function known as "Mexican Hat",
        "Marr" or "Ricker" wavelet, which can be selected alternatively by setting
        `wavelet` to `'Mexican_hat'`, `'Marr'` or `'Ricker'`. **Note**: A real-valued
        wavelet function encodes *only* information about peaks and discontinuities
        in the signal and does *not* provide any information about amplitude or phase.
    order_max : int
        Only valid if `method` is `'superlet'`.
        Maximal order of the superlet set. Controls the maximum
        number of cycles within a SL together
        with the `c_1` parameter: c_max = c_1 * order_max
    order_min : int
        Only valid if `method` is `'superlet'`.
        Minimal order of the superlet set. Controls
        the minimal number of cycles within a SL together
        with the `c_1` parameter: c_min = c_1 * order_min
        Note that for admissibility reasons c_min should be at least 3!
    c_1 : int
        Only valid if `method` is `'superlet'`.
        Number of cycles of the base Morlet wavelet. If set to lower
        than 3 increase `order_min` as to never have less than 3 cycles
        in a wavelet!
    adaptive : bool
        Only valid if `method` is `'superlet'`.
        Whether to perform multiplicative SLT or fractional adaptive SLT.
        If set to True, the order of the wavelet set will increase
        linearly with the frequencies of interest from `order_min`
        to `order_max`. If set to False the same SL will be used for
        all frequencies.
    out : None or :class:`SpectralData` object
        None if a new :class:`SpectralData` object is to be created, or an empty :class:`SpectralData` object


    Returns
    -------
    spec : :class:`~syncopy.SpectralData`
        (Time-)frequency spectrum of input data

    Notes
    -----
    .. [Moca2021] Moca, Vasile V., et al. "Time-frequency super-resolution with superlets."
       Nature communications 12.1 (2021): 1-18.

    **Options**

    .. autodata:: syncopy.specest.const_def.availableMethods

    .. autodata:: syncopy.specest.const_def.availableOutputs

    .. autodata:: syncopy.specest.const_def.availableTapers

    .. autodata:: syncopy.specest.const_def.availableWavelets

    Examples
    --------
    Coming soon...



    See also
    --------
    syncopy.specest.mtmfft.mtmfft : (multi-)tapered Fourier transform of multi-channel time series data
    syncopy.specest.mtmconvol.mtmconvol : time-frequency analysis of multi-channel time series data with a sliding window FFT
    syncopy.specest.wavelet.wavelet : time-frequency analysis of multi-channel time series data using a wavelet transform
    numpy.fft.fft : NumPy's reference FFT implementation
    scipy.signal.stft : SciPy's Short Time Fourier Transform
    """

    # Make sure our one mandatory input object can be processed
    try:
        data_parser(data, varname="data", dataclass="AnalogData",
                    writable=None, empty=False)
    except Exception as exc:
        raise exc
    timeAxis = data.dimord.index("time")

    # Get everything of interest in local namespace
    defaults = get_defaults(freqanalysis)
    lcls = locals()
    # check for ineffective additional kwargs
    check_passed_kwargs(lcls, defaults, frontend_name="freqanalysis")

    # Ensure a valid computational method was selected
    if method not in availableMethods:
        lgl = "'" + "or '".join(opt + "' " for opt in availableMethods)
        raise SPYValueError(legal=lgl, varname="method", actual=method)

    # Ensure a valid output format was selected
    if output not in spectralConversions.keys():
        lgl = "'" + "or '".join(opt + "' " for opt in spectralConversions.keys())
        raise SPYValueError(legal=lgl, varname="output", actual=output)

    # Parse all Boolean keyword arguments
    for vname in ["keeptrials", "keeptapers"]:
        if not isinstance(lcls[vname], bool):
            raise SPYTypeError(lcls[vname], varname=vname, expected="Bool")

    # If only a subset of `data` is to be processed, make some necessary adjustments
    # of the sampleinfo and trial lengths
    if data._selection is not None:
        sinfo = data._selection.trialdefinition[:, :2]
        trialList = data._selection.trials
    else:
        trialList = list(range(len(data.trials)))
        sinfo = data.sampleinfo
    lenTrials = np.diff(sinfo).squeeze()
    if not lenTrials.shape:
        lenTrials = lenTrials[None]
    numTrials = len(trialList)

    # check polyremoval
    if polyremoval is not None:
        scalar_parser(polyremoval, varname="polyremoval", ntype="int_like", lims=[0, 1])


    # --- Padding ---

    # Sliding window FFT does not support "fancy" padding
    if method == "mtmconvol" and isinstance(pad_to_length, str):
        msg = "method 'mtmconvol' only supports in-place padding for windows " +\
            "exceeding trial boundaries. Your choice of `pad_to_length = '{}'` will be ignored. "
        SPYWarning(msg.format(pad_to_length))

    if method == 'mtmfft':
        # the actual number of samples in case of later padding
        minSampleNum = validate_padding(pad_to_length, lenTrials)
    else:
        minSampleNum = lenTrials.min()

    # Compute length (in seconds) of shortest trial
    minTrialLength = minSampleNum / data.samplerate

    # Shortcut to data sampling interval
    dt = 1 / data.samplerate

    foi, foilim = validate_foi(foi, foilim, data.samplerate)

    # see also https://docs.obspy.org/_modules/obspy/signal/detrend.html#polynomial
    if polyremoval is not None:
        try:
            scalar_parser(polyremoval, varname="polyremoval", lims=[0, 1], ntype="int_like")
        except Exception as exc:
            raise exc

    # Prepare keyword dict for logging (use `lcls` to get actually provided
    # keyword values, not defaults set above)
    log_dct = {"method": method,
               "output": output,
               "keeptapers": keeptapers,
               "keeptrials": keeptrials,
               "polyremoval": polyremoval,
               "pad_to_length": pad_to_length}

    # --------------------------------
    # 1st: Check time-frequency inputs
    # to prepare/sanitize `toi`
    # --------------------------------

    if method in ["mtmconvol", "wavelet", "superlet"]:

        # Get start/end timing info respecting potential in-place selection
        if toi is None:
            raise SPYTypeError(toi, varname="toi", expected="scalar or array-like or 'all'")
        if data._selection is not None:
            tStart = data._selection.trialdefinition[:, 2] / data.samplerate
        else:
            tStart = data._t0 / data.samplerate
        tEnd = tStart + lenTrials / data.samplerate

    # for these methods only 'all' or an equidistant array
    # of time points (sub-sampling, trimming) are valid
    if method in ["wavelet", "superlet"]:

        valid = True
        if isinstance(toi, Number):
            valid = False

        elif isinstance(toi, str):
            if toi != "all":
                valid = False
            else:
                # take everything
                preSelect = [slice(None)] * numTrials
                postSelect = [slice(None)] * numTrials

        elif not iter(toi):
            valid = False

        # this is the sequence type - can only be an interval!
        else:
            try:
                array_parser(toi, varname="toi", hasinf=False, hasnan=False,
                             lims=[tStart.min(), tEnd.max()], dims=(None,))
            except Exception as exc:
                raise exc
            toi = np.array(toi)
            # check for equidistancy
            if not np.allclose(np.diff(toi, 2), np.zeros(len(toi) - 2)):
                valid = False
            # trim (preSelect) and subsample output (postSelect)
            else:
                preSelect = []
                postSelect = []
                # get sample intervals and relative indices from toi
                for tk in range(numTrials):
                    start = int(data.samplerate * (toi[0] - tStart[tk]))
                    stop = int(data.samplerate * (toi[-1] - tStart[tk]) + 1)
                    preSelect.append(slice(max(0, start), max(stop, stop - start)))
                    smpIdx = np.minimum(lenTrials[tk] - 1,
                                        data.samplerate * (toi - tStart[tk]) - start)
                    postSelect.append(smpIdx.astype(np.intp))

        # get out if sth wasn't right
        if not valid:
            lgl = "array of equidistant time-points or 'all' for wavelet based methods"
            raise SPYValueError(legal=lgl, varname="toi", actual=toi)


        # Update `log_dct` w/method-specific options (use `lcls` to get actually
        # provided keyword values, not defaults set in here)
        log_dct["toi"] = lcls["toi"]

    # --------------------------------------------
    # Check options specific to mtm*-methods
    # (particularly tapers and foi/freqs alignment)
    # --------------------------------------------

    if "mtm" in method:

        if method == "mtmconvol":
            # get the sliding window size
            try:
                scalar_parser(t_ftimwin, varname="t_ftimwin",
                              lims=[dt, minTrialLength])
            except Exception as exc:
                SPYInfo("Please specify 't_ftimwin' parameter.. exiting!")
                raise exc

            # this is the effective sliding window FFT sample size
            minSampleNum = int(t_ftimwin * data.samplerate)

        # Construct array of maximally attainable frequencies
        freqs = np.fft.rfftfreq(minSampleNum, dt)

        # Match desired frequencies as close as possible to
        # actually attainable freqs
        # these are the frequencies attached to the SpectralData by the CR!
        if foi is not None:
            foi, _ = best_match(freqs, foi, squash_duplicates=True)
        elif foilim is not None:
            foi, _ = best_match(freqs, foilim, span=True, squash_duplicates=True)
        else:
            msg = (f"Automatic FFT frequency selection from {freqs[0]:.1f}Hz to "
                   f"{freqs[-1]:.1f}Hz")
            SPYInfo(msg)
            foi = freqs
        log_dct["foi"] = foi

        # Abort if desired frequency selection is empty
        if foi.size == 0:
            lgl = "non-empty frequency specification"
            act = "empty frequency selection"
            raise SPYValueError(legal=lgl, varname="foi/foilim", actual=act)

        # sanitize taper selection and retrieve dpss settings
        taper_opt = validate_taper(taper,
                                   tapsmofrq,
                                   nTaper,
                                   keeptapers,
                                   foimax=foi.max(),
                                   samplerate=data.samplerate,
                                   nSamples=minSampleNum,
                                   output=output)

        # Update `log_dct` w/method-specific options
        log_dct["taper"] = taper
        # only dpss returns non-empty taper_opt dict
        if taper_opt:
            log_dct["nTaper"] = taper_opt["Kmax"]
            log_dct["tapsmofrq"] = tapsmofrq

    # -------------------------------------------------------
    # Now, prepare explicit compute-classes for chosen method
    # -------------------------------------------------------

    if method == "mtmfft":

        check_effective_parameters(MultiTaperFFT, defaults, lcls)

        # method specific parameters
        method_kwargs = {
            'samplerate': data.samplerate,
            'taper': taper,
            'taper_opt': taper_opt,
            'nSamples': minSampleNum
        }

        # Set up compute-class
        specestMethod = MultiTaperFFT(
            foi=foi,
            timeAxis=timeAxis,
            keeptapers=keeptapers,
            polyremoval=polyremoval,
            output_fmt=output,
            method_kwargs=method_kwargs)

    elif method == "mtmconvol":

        check_effective_parameters(MultiTaperFFTConvol, defaults, lcls)

        # Process `toi` for sliding window multi taper fft,
        # we have to account for three scenarios: (1) center sliding
        # windows on all samples in (selected) trials (2) `toi` was provided as
        # percentage indicating the degree of overlap b/w time-windows and (3) a set
        # of discrete time points was provided. These three cases are encoded in
        # `overlap`, i.e., `overlap > 1` => all, `0 <= overlap <= 1` => percentage,
        # `overlap < 0` => discrete `toi`

        # overlap = None
        if isinstance(toi, str):
            if toi != "all":
                lgl = "`toi = 'all'` to center analysis windows on all time-points"
                raise SPYValueError(legal=lgl, varname="toi", actual=toi)
            equidistant = True
            overlap = np.inf

        elif isinstance(toi, Number):
            try:
                scalar_parser(toi, varname="toi", lims=[0, 1])
            except Exception as exc:
                raise exc
            overlap = toi
            equidistant = True
        # this captures all other cases, i.e., toi is of sequence type
        else:
            overlap = -1
            try:
                array_parser(toi, varname="toi", hasinf=False, hasnan=False,
                             lims=[tStart.min(), tEnd.max()], dims=(None,))
            except Exception as exc:
                raise exc
            toi = np.array(toi)
            tSteps = np.diff(toi)
            if (tSteps < 0).any():
                lgl = "ordered list/array of time-points"
                act = "unsorted list/array"
                raise SPYValueError(legal=lgl, varname="toi", actual=act)
            # Account for round-off errors: if toi spacing is almost at sample interval
            # manually correct it
            if np.isclose(tSteps.min(), dt):
                tSteps[np.isclose(tSteps, dt)] = dt
            if tSteps.min() < dt:
                msg = f"`toi` selection too fine, max. time resolution is {dt}s"
                SPYWarning(msg)
            # This is imho a bug in NumPy - even `arange` and `linspace` may produce
            # arrays that are numerically not exactly equidistant - `unique` will
            # show several entries here - use `allclose` to identify "even" spacings
            equidistant = np.allclose(tSteps, [tSteps[0]] * tSteps.size)

        # If `toi` was 'all' or a percentage, use entire time interval of (selected)
        # trials and check if those trials have *approximately* equal length
        if toi is None:
            if not np.allclose(lenTrials, [minSampleNum] * lenTrials.size):
                msg = "processing trials of different lengths (min = {}; max = {} samples)" +\
                    " with `toi = 'all'`"
                SPYWarning(msg.format(int(minSampleNum), int(lenTrials.max())))

        # number of samples per window
        nperseg = int(t_ftimwin * data.samplerate)
        halfWin = int(nperseg / 2)
        postSelect = slice(None) # select all is the default

        if 0 <= overlap <= 1: # `toi` is percentage
            noverlap = min(nperseg - 1, int(overlap * nperseg))
        # windows get shifted exactly 1 sample
        # to get a spectral estimate at each sample
        else:
            noverlap = nperseg - 1

        # `toi` is array
        if overlap < 0:
            # Compute necessary padding at begin/end of trials to fit sliding windows
            offStart = ((toi[0] - tStart) * data.samplerate).astype(np.intp)
            padBegin = halfWin - offStart
            padBegin = ((padBegin > 0) * padBegin).astype(np.intp)
            offEnd = ((tEnd - toi[-1]) * data.samplerate).astype(np.intp)
            padEnd = halfWin - offEnd
            padEnd = ((padEnd > 0) * padEnd).astype(np.intp)

            # Compute sample-indices (one slice/list per trial) from time-selections
            soi = []
            if equidistant:
                # soi just trims the input data to the [toi[0], toi[-1]] interval
                # postSelect then subsamples the spectral estimate to the user given toi
                postSelect = []
                for tk in range(numTrials):
                    start = max(0, int(round(data.samplerate * (toi[0] - tStart[tk]) - halfWin)))
                    stop = int(round(data.samplerate * (toi[-1] - tStart[tk]) + halfWin + 1))
                    soi.append(slice(start, max(stop, stop - start)))

                # chosen toi subsampling interval in sample units, min. is 1;
                # compute `delta_idx` s.t. stop - start / delta_idx == toi.size
                delta_idx = int(round((soi[0].stop - soi[0].start) / toi.size))
                delta_idx = delta_idx if delta_idx > 1 else 1
                postSelect = slice(None, None, delta_idx)

            else:
                for tk in range(numTrials):
                    starts = (data.samplerate * (toi - tStart[tk]) - halfWin).astype(np.intp)
                    starts += padBegin[tk]
                    stops = (data.samplerate * (toi - tStart[tk]) + halfWin + 1).astype(np.intp)
                    stops += padBegin[tk]
                    stops = np.maximum(stops, stops - starts, dtype=np.intp)
                    soi.append([slice(start, stop) for start, stop in zip(starts, stops)])
                    # postSelect here remains slice(None), as resulting spectrum
                    # has exactly one entry for each soi

        # `toi` is percentage or "all"
        else:
            soi = [slice(None)] * numTrials


        # Collect keyword args for `mtmconvol` in dictionary
        method_kwargs = {"samplerate": data.samplerate,
                         "nperseg": nperseg,
                         "noverlap": noverlap,
                         "taper" : taper,
                         "taper_opt" : taper_opt}

        # Set up compute-class
        specestMethod = MultiTaperFFTConvol(
            soi,
            postSelect,
            equidistant=equidistant,
            toi=toi,
            foi=foi,
            timeAxis=timeAxis,
            keeptapers=keeptapers,
            polyremoval=polyremoval,
            output_fmt=output,
            method_kwargs=method_kwargs)

    elif method == "wavelet":

        check_effective_parameters(WaveletTransform, defaults, lcls)

        # Check wavelet selection
        if wavelet not in availableWavelets:
            lgl = "'" + "or '".join(opt + "' " for opt in availableWavelets)
            raise SPYValueError(legal=lgl, varname="wavelet", actual=wavelet)
        if wavelet not in ["Morlet", "Paul"]:
            msg = "the chosen wavelet '{}' is real-valued and does not provide " +\
                "any information about amplitude or phase of the data. This wavelet function " +\
                "may be used to isolate peaks or discontinuities in the signal. "
            SPYWarning(msg.format(wavelet))

        # Check for consistency of `width`, `order` and `wavelet`
        if wavelet == "Morlet":
            try:
                scalar_parser(width, varname="width", lims=[1, np.inf])
            except Exception as exc:
                raise exc
            wfun = getattr(spywave, wavelet)(w0=width)
        else:
            if width != lcls["width"]:
                msg = "option `width` has no effect for wavelet '{}'"
                SPYWarning(msg.format(wavelet))

        if wavelet == "Paul":
            try:
                scalar_parser(order, varname="order", lims=[4, np.inf], ntype="int_like")
            except Exception as exc:
                raise exc
            wfun = getattr(spywave, wavelet)(m=order)
        elif wavelet == "DOG":
            try:
                scalar_parser(order, varname="order", lims=[1, np.inf], ntype="int_like")
            except Exception as exc:
                raise exc
            wfun = getattr(spywave, wavelet)(m=order)
        else:
            if order is not None:
                msg = "option `order` has no effect for wavelet '{}'"
                SPYWarning(msg.format(wavelet))
            wfun = getattr(spywave, wavelet)()

        # automatic frequency selection
        if foi is None and foilim is None:
            scales = get_optimal_wavelet_scales(
                wfun.scale_from_period, # all availableWavelets sport one!
                int(minTrialLength * data.samplerate),
                dt)
            foi = 1 / wfun.fourier_period(scales)
            msg = (f"Setting frequencies of interest to {foi[0]:.1f}-"
                   f"{foi[-1]:.1f}Hz")
            SPYInfo(msg)
        else:
            if foilim is not None:
                foi = np.arange(foilim[0], foilim[1] + 1, dtype=float)
            # 0 frequency is not valid
            foi[foi < 0.01] = 0.01
            scales = wfun.scale_from_period(1 / foi)

        # Update `log_dct` w/method-specific options (use `lcls` to get actually
        # provided keyword values, not defaults set in here)
        log_dct["foi"] = foi
        log_dct["wavelet"] = lcls["wavelet"]
        log_dct["width"] = lcls["width"]
        log_dct["order"] = lcls["order"]

        # method specific parameters
        method_kwargs = {
            'samplerate' : data.samplerate,
            'scales' : scales,
            'wavelet' : wfun
        }

        # Set up compute-class
        specestMethod = WaveletTransform(
            preSelect,
            postSelect,
            toi=toi,
            timeAxis=timeAxis,
            polyremoval=polyremoval,
            output_fmt=output,
            method_kwargs=method_kwargs)

    elif method == "superlet":

        check_effective_parameters(SuperletTransform, defaults, lcls)

        # check and parse superlet specific arguments
        if order_max is None:
            lgl = "Positive integer needed for order_max"
            raise SPYValueError(legal=lgl, varname="order_max",
                                actual=None)
        else:
            scalar_parser(
                order_max,
                varname="order_max",
                lims=[1, np.inf],
                ntype="int_like"
            )

        scalar_parser(
            order_min, varname="order_min",
            lims=[1, order_max],
            ntype="int_like"
        )
        scalar_parser(c_1, varname="c_1", lims=[1, np.inf], ntype="int_like")

        # if no frequencies are user selected, take a sensitive default
        if foi is None and foilim is None:
            scales = get_optimal_wavelet_scales(
                superlet.scale_from_period,
                int(minTrialLength * data.samplerate),
                dt)
            foi = 1 / superlet.fourier_period(scales)
            msg = (f"Setting frequencies of interest to {foi[0]:.1f}-"
                   f"{foi[-1]:.1f}Hz")
            SPYInfo(msg)
        else:
            if foilim is not None:
                # frequency range in 1Hz steps
                foi = np.arange(foilim[0], foilim[1] + 1, dtype=float)
            # 0 frequency is not valid
            foi[foi < 0.01] = 0.01
            scales = superlet.scale_from_period(1. / foi)

        # FASLT needs ordered frequencies low - high
        # meaning the scales have to go high - low
        if adaptive:
            if len(scales) < 2:
                lgl = "A range of frequencies"
                act = "Single frequency"
                raise SPYValueError(legal=lgl, varname="foi", actual=act)
            if np.any(np.diff(scales) > 0):
                msg = "Sorting frequencies low to high for adaptive SLT.."
                SPYWarning(msg)
                scales = np.sort(scales)[::-1]

        log_dct["foi"] = foi
        log_dct["c_1"] = lcls["c_1"]
        log_dct["order_max"] = lcls["order_max"]
        log_dct["order_min"] = lcls["order_min"]

        # method specific parameters
        method_kwargs = {
            'samplerate' : data.samplerate,
            'scales' : scales,
            'order_max' : order_max,
            'order_min' : order_min,
            'c_1' : c_1,
            'adaptive' : adaptive
        }

        # Set up compute-class
        specestMethod = SuperletTransform(
            preSelect,
            postSelect,
            toi=toi,
            timeAxis=timeAxis,
            polyremoval=polyremoval,
            output_fmt=output,
            method_kwargs=method_kwargs)

    # -------------------------------------------------
    # Sanitize output and call the ComputationalRoutine
    # -------------------------------------------------

    # If provided, make sure output object is appropriate
    if out is not None:
        try:
            data_parser(out, varname="out", writable=True, empty=True,
                        dataclass="SpectralData",
                        dimord=SpectralData().dimord)
        except Exception as exc:
            raise exc
        new_out = False
    else:
        out = SpectralData(dimord=SpectralData._defaultDimord)
        new_out = True

    # Perform actual computation
    specestMethod.initialize(data,
                             out._stackingDim,
                             chan_per_worker=kwargs.get("chan_per_worker"),
                             keeptrials=keeptrials)
    specestMethod.compute(data, out, parallel=kwargs.get("parallel"), log_dict=log_dct)

    # Either return newly created output object or simply quit
    return out if new_out else None
예제 #5
0
def connectivityanalysis(data,
                         method="coh",
                         keeptrials=False,
                         output="abs",
                         foi=None,
                         foilim=None,
                         pad_to_length=None,
                         polyremoval=None,
                         taper="hann",
                         tapsmofrq=None,
                         nTaper=None,
                         out=None,
                         **kwargs):
    """
    Perform connectivity analysis of Syncopy :class:`~syncopy.AnalogData` objects

    **Usage Summary**

    Options available in all analysis methods:

    * **foi**/**foilim** : frequencies of interest; either array of frequencies or
      frequency window (not both)
    * **polyremoval** : de-trending method to use (0 = mean, 1 = linear or `None`)

    List of available analysis methods and respective distinct options:

    "coh" : (Multi-) tapered coherency estimate
        Compute the normalized cross spectral densities
        between all channel combinations

        * **output** : one of ('abs', 'pow', 'fourier')
        * **taper** : one of :data:`~syncopy.shared.const_def.availableTapers`
        * **tapsmofrq** : spectral smoothing box for slepian tapers (in Hz)
        * **nTaper** : (optional) number of orthogonal tapers for slepian tapers
        * **pad_to_length**: either pad to an absolute length or set to `'nextpow2'`

    "corr" : Cross-correlations
        Computes the one sided (positive lags) cross-correlations
        between all channel combinations. The maximal lag is half
        the trial lengths.

        * **keeptrials** : set to `True` for single trial cross-correlations

    "granger" : Spectral Granger-Geweke causality
        Computes linear causality estimates between
        all channel combinations. The intermediate cross-spectral
        densities can be computed via multi-tapering.

        * **taper** : one of :data:`~syncopy.shared.const_def.availableTapers`
        * **tapsmofrq** : spectral smoothing box for slepian tapers (in Hz)
        * **nTaper** : (optional, not recommended) number of slepian tapers
        * **pad_to_length**: either pad to an absolute length or set to `'nextpow2'`

    Parameters
    ----------
    data : `~syncopy.AnalogData`
        A non-empty Syncopy :class:`~syncopy.datatype.AnalogData` object
    method : str
        Connectivity estimation method, one of 'coh', 'corr', 'granger'
    output : str
        Relevant for cross-spectral density estimation (`method='coh'`)
        Use `'pow'` for absolute squared coherence, `'abs'` for absolute value of coherence
        and `'fourier'` for the complex valued coherency.
    keeptrials : bool
        Relevant for cross-correlations (`method='corr'`).
        If `True` single-trial cross-correlations are returned.
        Has to be `False` for `method='coh'` and `method='granger'`.
    foi : array-like or None
        Frequencies of interest (Hz) for output. If desired frequencies cannot be
        matched exactly, the closest possible frequencies are used. If `foi` is `None`
        or ``foi = "all"``, all attainable frequencies (i.e., zero to half the
        samplerate) are selected.
    foilim : array-like (floats [fmin, fmax]) or None or "all"
        Frequency-window ``[fmin, fmax]`` (in Hz) of interest. The
        `foi` array will be constructed in 1Hz steps from `fmin` to
        `fmax` (inclusive).
    pad_to_length : int, None or 'nextpow2'
        Padding of the (tapered) signal, if set to a number pads all trials
        to this absolute length. E.g. `pad_to_length=2000` pads all
        trials to 2000 samples, if and only if the longest trial is
        at maximum 2000 samples.

        Alternatively if all trials have the same initial lengths
        setting `pad_to_length='nextpow2'` pads all trials to
        the next power of two.
        If `None` and trials have unequal lengths all trials are padded to match
        the longest trial.
        Not allowed for `method='corr'`.
    polyremoval : int or None
        De-trending prior to the analysis: 0 removes the mean, 1 removes
        linear trends, `None` disables de-trending.
    taper : str
        Only valid if `method` is `'coh'` or `'granger'`. Windowing function,
        one of :data:`~syncopy.specest.const_def.availableTapers`
    tapsmofrq : float
        Only valid if `method` is `'coh'` or `'granger'` and `taper` is `'dpss'`.
        The amount of spectral smoothing through  multi-tapering (Hz).
        Note that smoothing frequency specifications are one-sided,
        i.e., 4 Hz smoothing means plus-minus 4 Hz, i.e., a 8 Hz smoothing box.
    nTaper : int or None
        Only valid if `method` is `'coh'` or `'granger'` and ``taper = 'dpss'``.
        Number of orthogonal tapers to use. It is not recommended to set the number
        of tapers manually! Leave at `None` for the optimal number to be set automatically.
    out : None or :class:`~syncopy.CrossSpectralData`
        If provided, has to be an empty and writable `CrossSpectralData` object
        that receives the (trial-averaged) result; nothing is returned in this
        case. Not supported for single-trial cross-correlations
        (``method='corr'`` with ``keeptrials=True``), where a warning is issued
        and the result is returned directly.
    **kwargs
        Additional keywords; only `parallel` is read in this function and
        forwarded to the single-trial computation.

    Returns
    -------
    result : :class:`~syncopy.CrossSpectralData` or None
        Newly created result object if `out` is `None` (or for single-trial
        cross-correlations), otherwise `None`.

    Raises
    ------
    SPYValueError
        If `method` is not available, if the active in-place selection picks
        discrete time-points, if `pad_to_length` is combined with
        ``method='corr'`` or if `keeptrials` is not `False` for `'coh'`/`'granger'`.

    Examples
    --------
    Coming soon...
    """

    # Make sure our one mandatory input object can be processed
    data_parser(data,
                varname="data",
                dataclass="AnalogData",
                writable=None,
                empty=False)
    timeAxis = data.dimord.index("time")

    # Get everything of interest in local namespace
    defaults = get_defaults(connectivityanalysis)
    lcls = locals()
    # check for ineffective additional kwargs
    check_passed_kwargs(lcls, defaults, frontend_name="connectivity")

    # Ensure a valid computational method was selected
    if method not in availableMethods:
        lgl = "'" + "or '".join(opt + "' " for opt in availableMethods)
        raise SPYValueError(legal=lgl, varname="method", actual=method)

    # if a subset selection is present
    # get sampleinfo and check for equidistancy
    if data._selection is not None:
        sinfo = data._selection.trialdefinition[:, :2]
        # user picked discrete set of time points
        if isinstance(data._selection.time[0], list):
            lgl = "equidistant time points (toi) or time slice (toilim)"
            actual = "non-equidistant set of time points"
            raise SPYValueError(legal=lgl, varname="select", actual=actual)
    else:
        sinfo = data.sampleinfo
    lenTrials = np.diff(sinfo).squeeze()

    # check polyremoval
    if polyremoval is not None:
        scalar_parser(polyremoval,
                      varname="polyremoval",
                      ntype="int_like",
                      lims=[0, 1])

    # --- Padding ---

    if method == "corr" and pad_to_length:
        lgl = "`None`, no padding needed/allowed for cross-correlations"
        actual = f"{pad_to_length}"
        raise SPYValueError(legal=lgl, varname="pad_to_length", actual=actual)

    # the actual number of samples in case of later padding
    nSamples = validate_padding(pad_to_length, lenTrials)

    # --- Basic foi sanitization ---

    foi, foilim = validate_foi(foi, foilim, data.samplerate)

    # only now set foi array for foilim in 1Hz steps; from here on
    # a given `foilim` is fully represented by `foi`
    if foilim is not None:
        foi = np.arange(foilim[0], foilim[1] + 1, dtype=float)

    # Prepare keyword dict for logging
    log_dict = {
        "method": method,
        "output": output,
        "keeptrials": keeptrials,
        "polyremoval": polyremoval,
        "pad_to_length": pad_to_length
    }

    # --- Setting up specific Methods ---

    if method in ['coh', 'granger']:

        # --- set up computation of the single trial CSDs ---

        if keeptrials is not False:
            lgl = "False, trial averaging needed!"
            act = keeptrials
            raise SPYValueError(lgl, varname="keeptrials", actual=act)

        # Construct array of maximally attainable frequencies
        freqs = np.fft.rfftfreq(nSamples, 1 / data.samplerate)

        # Match desired frequencies as close as possible to
        # actually attainable freqs
        # these are the frequencies attached to the SpectralData by the CR!
        if foi is not None:
            foi, _ = best_match(freqs, foi, squash_duplicates=True)
        else:
            # neither `foi` nor `foilim` given: use all attainable frequencies
            msg = (f"Setting frequencies of interest to {freqs[0]:.1f}-"
                   f"{freqs[-1]:.1f}Hz")
            SPYInfo(msg)
            foi = freqs

        # sanitize taper selection and retrieve dpss settings
        taper_opt = validate_taper(
            taper,
            tapsmofrq,
            nTaper,
            keeptapers=False,  # ST_CSD's always average tapers
            foimax=foi.max(),
            samplerate=data.samplerate,
            nSamples=nSamples,
            output="pow")  # ST_CSD's always have this unit/norm

        log_dict["foi"] = foi
        log_dict["taper"] = taper
        # only dpss returns non-empty taper_opt dict
        if taper_opt:
            log_dict["nTaper"] = taper_opt["Kmax"]
            log_dict["tapsmofrq"] = tapsmofrq

        check_effective_parameters(ST_CrossSpectra, defaults, lcls)
        # parallel computation over trials
        st_compRoutine = ST_CrossSpectra(samplerate=data.samplerate,
                                         nSamples=nSamples,
                                         taper=taper,
                                         taper_opt=taper_opt,
                                         polyremoval=polyremoval,
                                         timeAxis=timeAxis,
                                         foi=foi)
        # hard coded as class attribute
        st_dimord = ST_CrossSpectra.dimord

    if method == 'coh':
        # final normalization after trial averaging
        av_compRoutine = NormalizeCrossSpectra(output=output)

    if method == 'granger':
        # after trial averaging
        # hardcoded numerical parameters
        av_compRoutine = GrangerCausality(rtol=1e-8, nIter=100, cond_max=1e4)

    if method == 'corr':
        # `lcls` still holds the `foi` value as passed by the caller
        if lcls['foi'] is not None:
            msg = 'Parameter `foi` has no effect for `corr`'
            SPYWarning(msg)
        check_effective_parameters(ST_CrossCovariance, defaults, lcls)

        # single trial cross-correlations
        if keeptrials:
            av_compRoutine = None  # no trial average
            norm = True  # normalize individual trials within the ST CR
        else:
            av_compRoutine = NormalizeCrossCov()
            norm = False

        # parallel computation over trials
        st_compRoutine = ST_CrossCovariance(samplerate=data.samplerate,
                                            polyremoval=polyremoval,
                                            timeAxis=timeAxis,
                                            norm=norm)
        # hard coded as class attribute
        st_dimord = ST_CrossCovariance.dimord

    # -------------------------------------------------
    # Call the chosen single trial ComputationalRoutine
    # -------------------------------------------------

    # the single trial results need a new DataSet
    st_out = CrossSpectralData(dimord=st_dimord)

    # Perform the trial-parallelized computation of the matrix quantity
    st_compRoutine.initialize(
        data,
        st_out._stackingDim,
        chan_per_worker=None,  # no parallelisation over channels possible
        keeptrials=keeptrials)  # we most likely need trial averaging!
    st_compRoutine.compute(data,
                           st_out,
                           parallel=kwargs.get("parallel"),
                           log_dict=log_dict)

    # if ever needed..
    # for single trial cross-corr results <-> keeptrials is True
    if keeptrials and av_compRoutine is None:
        if out is not None:
            msg = "Single trial processing does not support `out` argument but directly returns the results"
            SPYWarning(msg)
        return st_out

    # ----------------------------------------------------------------------------------
    # Sanitize output and call the chosen ComputationalRoutine on the averaged ST output
    # ----------------------------------------------------------------------------------

    # If provided, make sure output object is appropriate
    new_out = out is None
    if new_out:
        out = CrossSpectralData(dimord=st_dimord)
    else:
        data_parser(out,
                    varname="out",
                    writable=True,
                    empty=True,
                    dataclass="CrossSpectralData",
                    dimord=st_dimord)

    # now take the trial average from the single trial CR as input
    av_compRoutine.initialize(st_out, out._stackingDim, chan_per_worker=None)
    av_compRoutine.pre_check()  # make sure we got a trial_average
    av_compRoutine.compute(st_out, out, parallel=False, log_dict=log_dict)

    # Either return newly created output object or simply quit
    return out if new_out else None
예제 #6
0
def cross_spectra_cF(trl_dat,
                     samplerate=1,
                     nSamples=None,
                     foi=None,
                     taper="hann",
                     taper_opt=None,
                     polyremoval=False,
                     timeAxis=0,
                     chunkShape=None,
                     noCompute=False):
    """
    Single trial Fourier cross spectral estimates between all channels
    of the input data. First all the individual Fourier transforms
    are calculated via a (multi-)tapered FFT, then the pairwise
    cross-spectra are computed.

    Averaging over tapers is done implicitly
    for multi-taper analysis with `taper="dpss"`.

    Output consists of all (nChannels x (nChannels + 1)) / 2 different complex
    estimates arranged in a symmetric fashion (``CS_ij == CS_ji*``). The
    elements on the main diagonal (`CS_ii`) are the (real) auto-spectra.

    This is NOT the same as what is commonly referred to as
    "cross spectral density" as there is no (time) averaging!!
    Multi-tapering alone is not necessarily sufficient to get enough
    statistical power for a robust csd estimate. No normalization to
    coherence is performed here (this function has no `norm` option).

    Parameters
    ----------
    trl_dat : (K, N) :class:`numpy.ndarray`
        Uniformly sampled multi-channel time-series data
        The 1st dimension is interpreted as the time axis,
        columns represent individual channels.
        Dimensions can be transposed to `(N, K)` with the `timeAxis` parameter.
    samplerate : float
        Samplerate in Hz
    nSamples : int or None
        Absolute length of the (potentially to be padded) signal or
        `None` for no padding
    foi : 1D :class:`numpy.ndarray` or None, optional
        Frequencies of interest  (Hz) for output. If desired frequencies
        cannot be matched exactly the closest possible frequencies (respecting
        data length and padding) are used. If `None`, all frequencies
        are returned.
    taper : str or None
        Taper function to use, one of scipy.signal.windows
        Set to `None` for no tapering.
    taper_opt : dict, optional
        Additional keyword arguments passed to the `taper` function.
        For multi-tapering with `taper='dpss'` set the keys
        `'Kmax'` and `'NW'`.
        For further details, please refer to the
        `SciPy docs <https://docs.scipy.org/doc/scipy/reference/signal.windows.html>`_
    polyremoval : int, None or False
        Order of polynomial used for de-trending data in the time domain prior
        to spectral analysis. A value of 0 corresponds to subtracting the mean
        ("de-meaning"), ``polyremoval = 1`` removes linear trends (subtracting the
        least squares fit of a linear polynomial).
        If `polyremoval` is `None` or `False` (the default), no de-trending
        is performed.
    timeAxis : int, optional
        Index of running time axis in `trl_dat` (0 or 1)
    chunkShape : None or tuple
        Not used inside this function.
    noCompute : bool
        Preprocessing flag. If `True`, do not perform actual calculation but
        instead return expected shape and :class:`numpy.dtype` of output
        array.

    Returns
    -------
    CS_ij : (1, nFreq, N, N) :class:`numpy.ndarray`
        Complex cross spectra for all channel combinations ``i,j``.
        `N` corresponds to number of input channels.
        If `noCompute` is `True`, the tuple ``(outShape, dtype)`` is
        returned instead.

    Notes
    -----
    This method is intended to be used as
    :meth:`~syncopy.shared.computational_routine.ComputationalRoutine.computeFunction`
    inside a :class:`~syncopy.shared.computational_routine.ComputationalRoutine`.
    Thus, input parameters are presumed to be forwarded from a parent metafunction.
    Consequently, this function does **not** perform any error checking and operates
    under the assumption that all inputs have been externally validated and cross-checked.

    See also
    --------
    csd : :func:`~syncopy.connectivity.csd.csd`
             Cross-spectra backend function
    normalize_csd : :func:`~syncopy.connectivity.csd.normalize_csd`
             Coherence from trial averages
    mtmfft : :func:`~syncopy.specest.mtmfft.mtmfft`
             (Multi-)tapered Fourier analysis

    """

    # Re-arrange array if necessary and get dimensional information
    if timeAxis != 0:
        dat = trl_dat.T  # does not copy but creates view of `trl_dat`
    else:
        dat = trl_dat

    if nSamples is None:
        nSamples = dat.shape[0]

    nChannels = dat.shape[1]

    freqs = np.fft.rfftfreq(nSamples, 1 / samplerate)

    if foi is not None:
        _, freq_idx = best_match(freqs, foi, squash_duplicates=True)
        nFreq = freq_idx.size
    else:
        freq_idx = slice(None)
        nFreq = freqs.size

    # we always average over tapers here
    outShape = (1, nFreq, nChannels, nChannels)

    # For initialization of computational routine,
    # just return output shape and dtype
    # cross spectra are complex!
    if noCompute:
        return outShape, spectralDTypes["fourier"]

    # detrend
    # NB: `False == 0` evaluates to `True` in Python, so the signature's
    # default `polyremoval=False` has to be excluded explicitly, otherwise
    # the "off" default would silently de-mean the data
    if polyremoval is not None and polyremoval is not False:
        if polyremoval == 0:
            # original author's note: SciPy's overwrite_data not working
            # for type='constant' :/
            dat = detrend(dat, type='constant', axis=0, overwrite_data=True)
        elif polyremoval == 1:
            # NOTE(review): `overwrite_data=True` allows in-place de-trending,
            # i.e., `trl_dat` itself may get modified - confirm callers are fine
            dat = detrend(dat, type='linear', axis=0, overwrite_data=True)

    CS_ij = csd(dat, samplerate, nSamples, taper=taper, taper_opt=taper_opt)

    # where does freqs go/come from -
    # we will eventually solve this issue..
    # prepend singleton axis and pick the requested frequencies
    return CS_ij[None, freq_idx, ...]
예제 #7
0
    def _get_time(self, trials, toi=None, toilim=None):
        """
        Get relative by-trial indices of time-selections

        Parameters
        ----------
        trials : list
            List of trial-indices to perform selection on
        toi : None or list
            Time-points to be selected (in seconds) on a by-trial scale.
        toilim : None or list
            Time-window to be selected (in seconds) on a by-trial scale

        Returns
        -------
        timing : list of lists
            List of by-trial sample-indices corresponding to provided
            time-selection. If both `toi` and `toilim` are `None`, `timing`
            is a list of universal (i.e., ``slice(None)``) selectors.

        Notes
        -----
        This class method is intended to be solely used by
        :class:`syncopy.datatype.base_data.Selector` objects and thus has purely
        auxiliary character. Therefore, all input sanitization and error checking
        is left to :class:`syncopy.datatype.base_data.Selector` and not
        performed here.

        See also
        --------
        syncopy.datatype.base_data.Selector : Syncopy data selectors
        """
        timing = []
        if toilim is not None:
            allTrials = self.trialtime
            for trlno in trials:
                thisTrial = self.data[self.trialid == trlno, self.dimord.index("sample")]
                trlSample = np.arange(*self.sampleinfo[trlno, :])
                trlTime = np.array(list(allTrials[np.where(self.trialid == trlno)[0][0]]))
                minSample = trlSample[np.where(trlTime >= toilim[0])[0][0]]
                maxSample = trlSample[np.where(trlTime <= toilim[1])[0][-1]]
                selSample, _ = best_match(trlSample, [minSample, maxSample], span=True)
                idxList = []
                for smp in selSample:
                    idxList += list(np.where(thisTrial == smp)[0])
                if len(idxList) > 1:
                    sampSteps = np.diff(idxList)
                    if sampSteps.min() == sampSteps.max() == 1:
                        idxList = slice(idxList[0], idxList[-1] + 1, 1)
                timing.append(idxList)

        elif toi is not None:
            allTrials = self.trialtime
            for trlno in trials:
                thisTrial = self.data[self.trialid == trlno, self.dimord.index("sample")]
                trlSample = np.arange(*self.sampleinfo[trlno, :])
                trlTime = np.array(list(allTrials[np.where(self.trialid == trlno)[0][0]]))
                _, selSample = best_match(trlTime, toi)
                for k, idx in enumerate(selSample):
                    if np.abs(trlTime[idx - 1] - toi[k]) < np.abs(trlTime[idx] - toi[k]):
                        selSample[k] = trlSample[idx -1]
                    else:
                        selSample[k] = trlSample[idx]
                idxList = []
                for smp in selSample:
                    idxList += list(np.where(thisTrial == smp)[0])
                if len(idxList) > 1:
                    sampSteps = np.diff(idxList)
                    if sampSteps.min() == sampSteps.max() == 1:
                        idxList = slice(idxList[0], idxList[-1] + 1, 1)
                timing.append(idxList)

        else:
            timing = [slice(None)] * len(trials)

        return timing
예제 #8
0
def mtmfft(trl_dat, samplerate=None, foi=None, nTaper=1, timeAxis=0,
           taper=spwin.hann, taperopt={},
           pad="nextpow2", padtype="zero", padlength=None,
           keeptapers=True, polyremoval=None, output_fmt="pow",
           noCompute=False, chunkShape=None):
    """
    (Multi-)tapered Fourier transform of one multi-channel trial

    Parameters
    ----------
    trl_dat : 2D :class:`numpy.ndarray`
        Uniformly sampled multi-channel time-series
    samplerate : float
        Sampling rate of `trl_dat` in Hz
    foi : 1D :class:`numpy.ndarray`
        Frequencies of interest (Hz). Requested values are mapped onto the
        closest frequencies attainable for the (padded) signal length.
    nTaper : int
        Number of taper windows; determines the size of the taper axis of
        the allocated spectrum
    timeAxis : int
        Index of the time axis in `trl_dat` (0 or 1). Any non-zero value
        causes `trl_dat` to be transposed so that time runs along axis 0.
    taper : callable
        Window function, invoked as ``taper(nSamples, **taperopt)``; one of
        :data:`~syncopy.specest.freqanalysis.availableTapers`
    taperopt : dict
        Extra keyword arguments handed to `taper`. See the
        `SciPy docs <https://docs.scipy.org/doc/scipy/reference/signal.windows.html>`_
        for the options each window supports.
    pad : str
        Padding mode; one of `'absolute'`, `'relative'`, `'maxlen'`, or
        `'nextpow2'` (see :func:`syncopy.padding`). A falsy value skips
        padding altogether.
    padtype : str
        Fill values used for padding: 'zero', 'nan', 'mean', 'localmean',
        'edge' or 'mirror' (see :func:`syncopy.padding`)
    padlength : None, bool or positive scalar
        Number of samples to pad (relevant if `pad` is 'absolute' or
        'relative'; see :func:`syncopy.padding`)
    keeptapers : bool
        If `True`, the per-taper spectra are returned; otherwise the
        spectrum is averaged across the taper axis.
    polyremoval : int or None
        **FIXME: Not implemented yet** (the value is not used by this
        function). Intended meaning: order of the polynomial used for
        de-trending in the time domain before spectral analysis
        (0 = de-meaning, 1 = linear, ``N > 1`` = polynomial of order `N`,
        `None` = no de-trending).
    output_fmt : str
        Spectral output type; one of
        :data:`~syncopy.specest.freqanalysis.availableOutputs`
    noCompute : bool
        Dry-run flag. If `True`, skip the calculation and only report the
        expected output shape and :class:`numpy.dtype`.
    chunkShape : None or tuple
        Shape of the output `spec` (respecting `nTaper`, `keeptapers` etc.)
        as supplied by the caller; not referenced inside this function.

    Returns
    -------
    spec : :class:`numpy.ndarray`
        Complex or real spectrum of the (padded) input, laid out as
        ``(1, nTaper, nFreq, nChannels)`` if `keeptapers` is `True` and as
        ``(1, 1, nFreq, nChannels)`` otherwise. If `noCompute` is `True`,
        a tuple ``(outShape, dtype)`` is returned instead.

    Notes
    -----
    Designed to serve as
    :meth:`~syncopy.shared.computational_routine.ComputationalRoutine.computeFunction`
    of a :class:`~syncopy.shared.computational_routine.ComputationalRoutine`.
    All arguments are expected to arrive pre-validated from the parent
    metafunction; this routine performs **no** error checking of its own.

    The transform itself is delegated to NumPy's real-input FFT
    :func:`numpy.fft.rfft`.

    See also
    --------
    syncopy.freqanalysis : parent metafunction
    MultiTaperFFT : :class:`~syncopy.shared.computational_routine.ComputationalRoutine`
                    instance that calls this method as
                    :meth:`~syncopy.shared.computational_routine.ComputationalRoutine.computeFunction`
    numpy.fft.rfft : NumPy's FFT for real-valued input
    """

    # NOTE: the `{}` default of `taperopt` is only ever unpacked below, never
    # mutated, so sharing it across calls is harmless here

    # Put time on the leading axis (`.T` hands back a view, no data is copied)
    dat = trl_dat.T if timeAxis != 0 else trl_dat

    # Optional padding; the sample count below refers to the padded signal
    if pad:
        dat = padding(dat, padtype, pad=pad, padlength=padlength, prepadlength=True)
    nSamples = dat.shape[0]
    nChannels = dat.shape[1]

    # Frequency axis of the one-sided spectrum and the subset picked by `foi`
    # NOTE(review): rfft bins sit at k * samplerate / nSamples, hence the last
    # bin only equals samplerate / 2 for even `nSamples`; for odd lengths this
    # axis is slightly off. Confirm that the parent metafunction uses the same
    # convention before swapping in `np.fft.rfftfreq`.
    attainable = np.linspace(0, samplerate / 2, int(np.floor(nSamples / 2) + 1))
    _, fidx = best_match(attainable, foi, squash_duplicates=True)
    nFreq = fidx.size

    # Output layout: (time=1 x taper x freq x channel); the taper axis
    # collapses to length one if tapers are not kept
    outShape = (1, max(1, nTaper * keeptapers), nFreq, nChannels)
    specDType = freq.spectralDTypes[output_fmt]

    # Dry run for setting up the computational routine: shape and dtype only
    if noCompute:
        return outShape, specDType

    # Always allocate room for every taper; averaging (if any) happens last
    spec = np.full((1, nTaper, nFreq, nChannels), np.nan, dtype=specDType)
    trailing = (slice(None, nFreq), slice(None, nChannels))

    # One windowed FFT per taper
    windows = np.atleast_2d(taper(nSamples, **taperopt))
    for taperIdx, window in enumerate(windows):
        if dat.ndim > 1:
            # replicate the 1D window so that it lines up with every channel
            window = np.tile(window, (nChannels, 1)).T
        convert = freq.spectralConversions[output_fmt]
        spectrum = np.fft.rfft(dat * window, axis=0)
        spec[(0, taperIdx) + trailing] = convert(spectrum[fidx, :])

    # Either hand back all tapers or their average (taper axis is retained)
    return spec if keeptapers else spec.mean(axis=1, keepdims=True)
예제 #9
0
def freqanalysis(data,
                 method='mtmfft',
                 output='fourier',
                 keeptrials=True,
                 foi=None,
                 foilim=None,
                 pad=None,
                 padtype='zero',
                 padlength=None,
                 prepadlength=None,
                 postpadlength=None,
                 polyremoval=None,
                 taper="hann",
                 tapsmofrq=None,
                 keeptapers=False,
                 toi=None,
                 t_ftimwin=None,
                 wav="Morlet",
                 width=6,
                 order=None,
                 out=None,
                 **kwargs):
    """
    Perform (time-)frequency analysis of Syncopy :class:`~syncopy.AnalogData` objects
    
    **Usage Summary**
    
    Options available in all analysis methods:
    
    * **output** : one of :data:`~.availableOutputs`; return power spectra, complex 
      Fourier spectra or absolute values. 
    * **foi**/**foilim** : frequencies of interest; either array of frequencies or 
      frequency window (not both)
    * **keeptrials** : return individual trials or grand average
    * **polyremoval** : de-trending method to use (0 = mean, 1 = linear, 2 = quadratic, 
      3 = cubic, etc.)
            
    List of available analysis methods and respective distinct options:
    
    :func:`~syncopy.specest.mtmfft.mtmfft` : (Multi-)tapered Fourier transform
        Perform frequency analysis on time-series trial data using either a single 
        taper window (Hanning) or many tapers based on the discrete prolate 
        spheroidal sequence (DPSS) that maximize energy concentration in the main
        lobe. 
        
        * **taper** : one of :data:`~.availableTapers`
        * **tapsmofrq** : spectral smoothing box for tapers (in Hz)
        * **keeptapers** : return individual tapers or average
        * **pad** : padding method to use (`None`, `True`, `False`, `'absolute'`, 
          `'relative'`, `'maxlen'` or `'nextpow2'`). If `None`, then `'nextpow2'`
          is selected by default. 
        * **padtype** : values to pad data with (`'zero'`, `'nan'`, `'mean'`, `'localmean'`, 
          `'edge'` or `'mirror'`)
        * **padlength** : number of samples to pre-pend and/or append to each trial 
        * **prepadlength** : number of samples to pre-pend to each trial 
        * **postpadlength** : number of samples to append to each trial 

    :func:`~syncopy.specest.mtmconvol.mtmconvol` : (Multi-)tapered sliding window Fourier transform
        Perform time-frequency analysis on time-series trial data based on a sliding 
        window short-time Fourier transform using either a single Hanning taper or 
        multiple DPSS tapers. 
        
        * **taper** : one of :data:`~.availableTapers`
        * **tapsmofrq** : spectral smoothing box for tapers (in Hz)
        * **keeptapers** : return individual tapers or average
        * **pad** : flag indicating, whether or not to pad trials. If `None`, 
          trials are padded only if sliding window centroids are too close
          to trial boundaries for the entire window to cover available data-points. 
        * **toi** : time-points of interest; can be either an array representing 
          analysis window centroids (in sec), a scalar between 0 and 1 encoding 
          the percentage of overlap between adjacent windows or "all" to center 
          a window on every sample in the data. 
        * **t_ftimwin** : sliding window length (in sec)

    :func:`~syncopy.specest.wavelet.wavelet` : (Continuous non-orthogonal) wavelet transform
        Perform time-frequency analysis on time-series trial data using a non-orthogonal
        continuous wavelet transform. 
        
        * **wav** : one of :data:`~.availableWavelets`
        * **toi** : time-points of interest; can be either an array representing 
          time points (in sec) to center wavelets on or "all" to center a wavelet 
          on every sample in the data. 
        * **width** : Nondimensional frequency constant of Morlet wavelet function (>= 6)
        * **order** : Order of Paul wavelet function (>= 4) or derivative order
          of real-valued DOG wavelets (2 = mexican hat)

    **Full documentation below** 
    
    Parameters
    ----------
    data : `~syncopy.AnalogData`
        A non-empty Syncopy :class:`~syncopy.datatype.AnalogData` object
    method : str
        Spectral estimation method, one of :data:`~.availableMethods` 
        (see below).
    output : str
        Output of spectral estimation. One of :data:`~.availableOutputs` (see below); 
        use `'pow'` for power spectrum (:obj:`numpy.float32`), `'fourier'` for complex 
        Fourier coefficients (:obj:`numpy.complex128`) or `'abs'` for absolute 
        values (:obj:`numpy.float32`).
    keeptrials : bool
        If `True` spectral estimates of individual trials are returned, otherwise
        results are averaged across trials. 
    foi : array-like or None
        Frequencies of interest (Hz) for output. If desired frequencies cannot be 
        matched exactly, the closest possible frequencies are used. If `foi` is `None`
        or ``foi = "all"``, all attainable frequencies (i.e., zero to Nyquist / 2) 
        are selected. 
    foilim : array-like (floats [fmin, fmax]) or None or "all"
        Frequency-window ``[fmin, fmax]`` (in Hz) of interest. Window 
        specifications must be sorted (e.g., ``[90, 70]`` is invalid) and not NaN 
        but may be unbounded (e.g., ``[-np.inf, 60.5]`` is valid). Edges `fmin` 
        and `fmax` are included in the selection. If `foilim` is `None` or 
        ``foilim = "all"``, all frequencies are selected. 
    pad : str or None or bool
        One of `None`, `True`, `False`, `'absolute'`, `'relative'`, `'maxlen'` or
        `'nextpow2'`. 
        If `pad` is `None` or ``pad = True``, then method-specific defaults are 
        chosen. Specifically, if `method` is `'mtmfft'` then `pad` is set to 
        `'nextpow2'` so that all trials in `data` are padded to the next power of 
        two higher than the sample-count of the longest (selected) trial in `data`. Conversely, 
        time-frequency analysis methods (`'mtmconvol'` and `'wavelet'`), only perform
        padding if necessary, i.e., if time-window centroids are chosen too close
        to trial boundaries for the entire window to cover available data-points. 
        If `pad` is `False`, then no padding is performed. Then in case of 
        ``method = 'mtmfft'`` all trials have to have approximately the same 
        length (up to the next even sample-count), if ``method = 'mtmconvol'`` or 
        ``method = 'wavelet'``, window-centroids have to keep sufficient
        distance from trial boundaries. For more details on the padding methods 
        `'absolute'`, `'relative'`, `'maxlen'` and `'nextpow2'` see :func:`syncopy.padding`. 
    padtype : str
        Values to be used for padding. Can be `'zero'`, `'nan'`, `'mean'`, 
        `'localmean'`, `'edge'` or `'mirror'`. See :func:`syncopy.padding` for 
        more information.
    padlength : None, bool or positive int
        Only valid if `method` is `'mtmfft'` and `pad` is `'absolute'` or `'relative'`. 
        Number of samples to pad data with. See :func:`syncopy.padding` for more 
        information.
    prepadlength : None or bool or int
        Only valid if `method` is `'mtmfft'` and `pad` is `'relative'`. Number of 
        samples to pre-pend to each trial. See :func:`syncopy.padding` for more 
        information.
    postpadlength : None or bool or int
        Only valid if `method` is `'mtmfft'` and `pad` is `'relative'`. Number of 
        samples to append to each trial. See :func:`syncopy.padding` for more 
        information.
    polyremoval : int or None
        **FIXME: Not implemented yet**
        Order of polynomial used for de-trending data in the time domain prior 
        to spectral analysis. A value of 0 corresponds to subtracting the mean 
        ("de-meaning"), ``polyremoval = 1`` removes linear trends (subtracting the 
        least squares fit of a linear polynomial), ``polyremoval = N`` for `N > 1` 
        subtracts a polynomial of order `N` (``N = 2`` quadratic, ``N = 3`` cubic 
        etc.). If `polyremoval` is `None`, no de-trending is performed. 
    taper : str
        Only valid if `method` is `'mtmfft'` or `'mtmconvol'`. Windowing function, 
        one of :data:`~.availableTapers` (see below).
    tapsmofrq : float
        Only valid if `method` is `'mtmfft'` or `'mtmconvol'`. The amount of spectral 
        smoothing through  multi-tapering (Hz). Note that smoothing frequency 
        specifications are one-sided, i.e., 4 Hz smoothing means plus-minus 4 Hz, 
        i.e., a 8 Hz smoothing box.
    keeptapers : bool
        Only valid if `method` is `'mtmfft'` or `'mtmconvol'`. If `True`, return 
        spectral estimates for each taper, otherwise results are averaged across
        tapers. 
    toi : float or array-like or "all"
        **Mandatory input** for time-frequency analysis methods (`method` is either 
        `"mtmconvol"` or `"wavelet"`). 
        If `toi` is scalar, it must be a value between 0 and 1 indicating the 
        percentage of overlap between time-windows specified by `t_ftimwin` (only
        valid if `method` is `'mtmconvol'`, invalid for `'wavelet'`). 
        If `toi` is an array it explicitly selects the centroids of analysis 
        windows (in seconds). If `toi` is `"all"`, analysis windows are centered
        on all samples in the data. 
    t_ftimwin : positive float
        Only valid if `method` is `'mtmconvol'`. Sliding window length (in seconds). 
    wav : str
        Only valid if `method` is `'wavelet'`. Wavelet function to use, one of 
        :data:`~.availableWavelets` (see below).
    width : positive float
        Only valid if `method` is `'wavelet'` and `wav` is `'Morlet'`. Nondimensional 
        frequency constant of Morlet wavelet function. This number should be >= 6, 
        which corresponds to 6 cycles within the analysis window to ensure sufficient 
        spectral sampling. 
    order : positive int
        Only valid if `method` is `'wavelet'` and `wav` is `'Paul'` or `'DOG'`. Order 
        of the wavelet function. If `wav` is `'Paul'`, `order` should be chosen
        >= 4 to ensure that the analysis window contains at least a single oscillation. 
        At an order of 40, the Paul wavelet  exhibits about the same number of cycles 
        as the Morlet wavelet with a `width` of 6. 
        All other supported wavelets functions are *real-valued* derivatives of 
        Gaussians (DOGs). Hence, if `wav` is `'DOG'`, `order` represents the derivative order. 
        The special case of a second order DOG yields a function known as "Mexican Hat", 
        "Marr" or "Ricker" wavelet, which can be selected alternatively by setting
        `wav` to `'Mexican_hat'`, `'Marr'` or `'Ricker'`. **Note**: A real-valued
        wavelet function encodes *only* information about peaks and discontinuities 
        in the signal and does *not* provide any information about amplitude or phase. 
    out : None or :class:`SpectralData` object
        None if a new :class:`SpectralData` object is to be created, or an empty :class:`SpectralData` object
        

    Returns
    -------
    spec : :class:`~syncopy.SpectralData`
        (Time-)frequency spectrum of input data
        
    Notes
    -----
    Coming soon...
    
    Examples
    --------
    Coming soon...
        

    .. autodata:: syncopy.specest.freqanalysis.availableMethods

    .. autodata:: syncopy.specest.freqanalysis.availableOutputs

    .. autodata:: syncopy.specest.freqanalysis.availableTapers

    .. autodata:: syncopy.specest.freqanalysis.availableWavelets
    
    See also
    --------
    syncopy.specest.mtmfft.mtmfft : (multi-)tapered Fourier transform of multi-channel time series data
    syncopy.specest.mtmconvol.mtmconvol : time-frequency analysis of multi-channel time series data with a sliding window FFT
    syncopy.specest.wavelet.wavelet : time-frequency analysis of multi-channel time series data using a wavelet transform
    numpy.fft.fft : NumPy's reference FFT implementation
    scipy.signal.stft : SciPy's Short Time Fourier Transform
    """

    # Make sure our one mandatory input object can be processed
    try:
        data_parser(data,
                    varname="data",
                    dataclass="AnalogData",
                    writable=None,
                    empty=False)
    except Exception as exc:
        raise exc
    timeAxis = data.dimord.index("time")

    # Get everything of interest in local namespace
    defaults = get_defaults(freqanalysis)
    lcls = locals()

    # Ensure a valid computational method was selected
    if method not in availableMethods:
        lgl = "'" + "or '".join(opt + "' " for opt in availableMethods)
        raise SPYValueError(legal=lgl, varname="method", actual=method)

    # Ensure a valid output format was selected
    if output not in spectralConversions.keys():
        lgl = "'" + "or '".join(opt + "' "
                                for opt in spectralConversions.keys())
        raise SPYValueError(legal=lgl, varname="output", actual=output)

    # Parse all Boolean keyword arguments
    for vname in ["keeptrials", "keeptapers"]:
        if not isinstance(lcls[vname], bool):
            raise SPYTypeError(lcls[vname], varname=vname, expected="Bool")

    # If only a subset of `data` is to be processed, make some necessary adjustments
    # and compute minimal sample-count across (selected) trials
    if data._selection is not None:
        trialList = data._selection.trials
        sinfo = np.zeros((len(trialList), 2))
        for tk, trlno in enumerate(trialList):
            trl = data._preview_trial(trlno)
            tsel = trl.idx[timeAxis]
            if isinstance(tsel, list):
                sinfo[tk, :] = [0, len(tsel)]
            else:
                sinfo[tk, :] = [
                    trl.idx[timeAxis].start, trl.idx[timeAxis].stop
                ]
    else:
        trialList = list(range(len(data.trials)))
        sinfo = data.sampleinfo
    lenTrials = np.diff(sinfo).squeeze()
    numTrials = len(trialList)

    # Set default padding options: after this, `pad` is either `None`, `False` or `str`
    defaultPadding = {"mtmfft": "nextpow2", "mtmconvol": None, "wavelet": None}
    if pad is None or pad is True:
        pad = defaultPadding[method]

    # Sliding window FFT does not support "fancy" padding
    if method == "mtmconvol" and isinstance(pad, str):
        msg = "method 'mtmconvol' only supports in-place padding for windows " +\
            "exceeding trial boundaries. Your choice of `pad = '{}'` will be ignored. "
        SPYWarning(msg.format(pad))
        pad = None

    # Ensure padding selection makes sense: do not pad on a by-trial basis but
    # use the longest trial as reference and compute `padlength` from there
    # (only relevant for "global" padding options such as `maxlen` or `nextpow2`)
    if pad:
        if not isinstance(pad, str):
            raise SPYTypeError(pad, varname="pad", expected="str or None")
        if pad == "maxlen":
            padlength = lenTrials.max()
            prepadlength = True
            postpadlength = False
        elif pad == "nextpow2":
            padlength = 0
            for ltrl in lenTrials:
                padlength = max(padlength, _nextpow2(ltrl))
            pad = "absolute"
            prepadlength = True
            postpadlength = False
        padding(data._preview_trial(trialList[0]),
                padtype,
                pad=pad,
                padlength=padlength,
                prepadlength=prepadlength,
                postpadlength=postpadlength)

        # Compute `minSampleNum` accounting for padding
        minSamplePos = lenTrials.argmin()
        minSampleNum = padding(data._preview_trial(trialList[minSamplePos]),
                               padtype,
                               pad=pad,
                               padlength=padlength,
                               prepadlength=True).shape[timeAxis]
    else:
        if method == "mtmfft" and np.unique(
            (np.floor(lenTrials / 2))).size > 1:
            lgl = "trials of approximately equal length for method 'mtmfft'"
            act = "trials of unequal length"
            raise SPYValueError(legal=lgl, varname="data", actual=act)
        minSampleNum = lenTrials.min()

    # Compute length (in samples) of shortest trial
    minTrialLength = minSampleNum / data.samplerate

    # Basic sanitization of frequency specifications
    if foi is not None:
        if isinstance(foi, str):
            if foi == "all":
                foi = None
            else:
                raise SPYValueError(legal="'all' or `None` or list/array",
                                    varname="foi",
                                    actual=foi)
        else:
            try:
                array_parser(foi,
                             varname="foi",
                             hasinf=False,
                             hasnan=False,
                             lims=[0, data.samplerate / 2],
                             dims=(None, ))
            except Exception as exc:
                raise exc
            foi = np.array(foi, dtype="float")
    if foilim is not None:
        if isinstance(foilim, str):
            if foilim == "all":
                foilim = None
            else:
                raise SPYValueError(legal="'all' or `None` or `[fmin, fmax]`",
                                    varname="foilim",
                                    actual=foilim)
        else:
            try:
                array_parser(foilim,
                             varname="foilim",
                             hasinf=False,
                             hasnan=False,
                             lims=[0, data.samplerate / 2],
                             dims=(2, ))
            except Exception as exc:
                raise exc
    if foi is not None and foilim is not None:
        lgl = "either `foi` or `foilim` specification"
        act = "both"
        raise SPYValueError(legal=lgl, varname="foi/foilim", actual=act)

    # FIXME: implement detrending
    # see also https://docs.obspy.org/_modules/obspy/signal/detrend.html#polynomial
    if polyremoval is not None:
        raise NotImplementedError("Detrending has not been implemented yet.")
        try:
            scalar_parser(polyremoval,
                          varname="polyremoval",
                          lims=[0, 8],
                          ntype="int_like")
        except Exception as exc:
            raise exc

    # Prepare keyword dict for logging (use `lcls` to get actually provided
    # keyword values, not defaults set above)
    log_dct = {
        "method": method,
        "output": output,
        "keeptapers": keeptapers,
        "keeptrials": keeptrials,
        "polyremoval": polyremoval,
        "pad": lcls["pad"],
        "padtype": lcls["padtype"],
        "padlength": lcls["padlength"],
        "foi": lcls["foi"]
    }

    # 1st: Check time-frequency inputs to prepare/sanitize `toi`
    if method in ["mtmconvol", "wavelet"]:

        # Get start/end timing info respecting potential in-place selection
        if toi is None:
            raise SPYTypeError(toi,
                               varname="toi",
                               expected="scalar or array-like or 'all'")
        if data._selection is not None:
            tStart = data._selection.trialdefinition[:, 2] / data.samplerate
        else:
            tStart = data._t0 / data.samplerate
        tEnd = tStart + lenTrials / data.samplerate

        # Process `toi`: we have to account for three scenarios: (1) center sliding
        # windows on all samples in (selected) trials (2) `toi` was provided as
        # percentage indicating the degree of overlap b/w time-windows and (3) a set
        # of discrete time points was provided. These three cases are encoded in
        # `overlap, i.e., ``overlap > 1` => all, `0 < overlap < 1` => percentage,
        # `overlap < 0` => discrete `toi`
        if isinstance(toi, str):
            if toi != "all":
                lgl = "`toi = 'all'` to center analysis windows on all time-points"
                raise SPYValueError(legal=lgl, varname="toi", actual=toi)
            overlap = 1.1
            toi = None
            equidistant = True
        elif isinstance(toi, Number):
            if method == "wavelet":
                lgl = "array of time-points wavelets are to be centered on"
                act = "scalar value"
                raise SPYValueError(legal=lgl, varname="toi", actual=act)
            try:
                scalar_parser(toi, varname="toi", lims=[0, 1])
            except Exception as exc:
                raise exc
            overlap = toi
            equidistant = True
        else:
            overlap = -1
            try:
                array_parser(toi,
                             varname="toi",
                             hasinf=False,
                             hasnan=False,
                             lims=[tStart.min(), tEnd.max()],
                             dims=(None, ))
            except Exception as exc:
                raise exc
            toi = np.array(toi)
            tSteps = np.diff(toi)
            if (tSteps < 0).any():
                lgl = "ordered list/array of time-points"
                act = "unsorted list/array"
                raise SPYValueError(legal=lgl, varname="toi", actual=act)
            # This is imho a bug in NumPy - even `arange` and `linspace` may produce
            # arrays that are numerically not exactly equidistant - `unique` will
            # show several entries here - use `allclose` to identify "even" spacings
            equidistant = np.allclose(tSteps, [tSteps[0]] * tSteps.size)

        # If `toi` was 'all' or a percentage, use entire time interval of (selected)
        # trials and check if those trials have *approximately* equal length
        if toi is None:
            if not np.allclose(lenTrials, [minSampleNum] * lenTrials.size):
                msg = "processing trials of different lengths (min = {}; max = {} samples)" +\
                    " with `toi = 'all'`"
                SPYWarning(msg.format(int(minSampleNum), int(lenTrials.max())))
            if pad is False:
                lgl = "`pad` to be `None` or `True` to permit zero-padding " +\
                    "at trial boundaries to accommodate windows if `0 < toi < 1` " +\
                    "or if `toi` is 'all'"
                act = "False"
                raise SPYValueError(legal=lgl, actual=act, varname="pad")

        # Code recycling: `overlap`, `equidistant` etc. are really only relevant
        # for `mtmconvol`, but we use padding calc below for `wavelet` as well
        if method == "mtmconvol":
            try:
                scalar_parser(t_ftimwin,
                              varname="t_ftimwin",
                              lims=[1 / data.samplerate, minTrialLength])
            except Exception as exc:
                raise exc
        else:
            t_ftimwin = 0
        nperseg = int(t_ftimwin * data.samplerate)
        minSampleNum = nperseg
        halfWin = int(nperseg / 2)

        # `mtmconvol`: compute no. of samples overlapping across adjacent windows
        if overlap < 0:  # `toi` is equidistant range or disjoint points
            noverlap = nperseg - max(1, int(tSteps[0] * data.samplerate))
        elif 0 <= overlap <= 1:  # `toi` is percentage
            noverlap = min(nperseg - 1, int(overlap * nperseg))
        else:  # `toi` is "all"
            noverlap = nperseg - 1

        # `toi` is array
        if overlap < 0:

            # Compute necessary padding at begin/end of trials to fit sliding windows
            offStart = ((toi[0] - tStart) * data.samplerate).astype(np.intp)
            padBegin = halfWin - offStart
            padBegin = ((padBegin > 0) * padBegin).astype(np.intp)

            offEnd = ((tEnd - toi[-1]) * data.samplerate).astype(np.intp)
            padEnd = halfWin - offEnd
            padEnd = ((padEnd > 0) * padEnd).astype(np.intp)

            # Abort if padding was explicitly forbidden
            if pad is False and (np.any(padBegin) or np.any(padBegin)):
                lgl = "windows within trial bounds"
                act = "windows exceeding trials no. " +\
                    "".join(str(trlno) + ", "\
                        for trlno in np.array(trialList)[(padBegin + padEnd) > 0])[:-2]
                raise SPYValueError(legal=lgl, varname="pad", actual=act)

            # Compute sample-indices (one slice/list per trial) from time-selections
            soi = []
            if not equidistant:
                for tk in range(numTrials):
                    starts = (data.samplerate * (toi - tStart[tk]) -
                              halfWin).astype(np.intp)
                    starts += padBegin[tk]
                    stops = (data.samplerate * (toi - tStart[tk]) + halfWin +
                             1).astype(np.intp)
                    stops += padBegin[tk]
                    stops = np.maximum(stops, stops - starts, dtype=np.intp)
                    soi.append([
                        slice(start, stop)
                        for start, stop in zip(starts, stops)
                    ])
            else:
                for tk in range(numTrials):
                    start = int(data.samplerate * (toi[0] - tStart[tk]) -
                                halfWin)
                    stop = int(data.samplerate * (toi[-1] - tStart[tk]) +
                               halfWin + 1)
                    soi.append(slice(max(0, start), max(stop, stop - start)))

        # `toi` is percentage or "all"
        else:

            padBegin = np.zeros((numTrials, ))
            padEnd = np.zeros((numTrials, ))
            soi = [slice(None)] * numTrials

        # For wavelets, we need to first trim the data (via `preSelect`), then
        # extract the wanted time-points (`postSelect`)
        if method == "wavelet":

            # Simply recycle the indexing work done for `mtmconvol` (i.e., `soi`)
            preSelect = []
            if not equidistant:
                for tk in range(numTrials):
                    preSelect.append(slice(soi[tk][0].start, soi[tk][-1].stop))
            else:
                preSelect = soi

            # If `toi` is an array, convert "global" indices to "local" ones
            # (select within `preSelect`'s selection), otherwise just take all
            if overlap < 0:
                postSelect = []
                for tk in range(numTrials):
                    smpIdx = np.minimum(
                        lenTrials[tk] - 1,
                        data.samplerate * (toi - tStart[tk]) - offStart[tk] +
                        padBegin[tk])
                    postSelect.append(smpIdx.astype(np.intp))
            else:
                postSelect = [slice(None)] * numTrials

        # Update `log_dct` w/method-specific options (use `lcls` to get actually
        # provided keyword values, not defaults set in here)
        if toi is None:
            toi = "all"
        log_dct["toi"] = lcls["toi"]

    # Check options specific to mtm*-methods (particularly tapers and foi/freqs alignment)
    if "mtm" in method:

        # See if taper choice is supported
        if taper not in availableTapers:
            lgl = "'" + "or '".join(opt + "' " for opt in availableTapers)
            raise SPYValueError(legal=lgl, varname="taper", actual=taper)
        taper = getattr(spwin, taper)

        # Advanced usage: see if `taperopt` was provided - if not, leave it empty
        taperopt = kwargs.get("taperopt", {})
        if not isinstance(taperopt, dict):
            raise SPYTypeError(taperopt,
                               varname="taperopt",
                               expected="dictionary")

        # Construct array of maximally attainable frequencies
        nFreq = int(np.floor(minSampleNum / 2) + 1)
        freqs = np.linspace(0, data.samplerate / 2, nFreq)

        # Match desired frequencies as close as possible to actually attainable freqs
        if foi is not None:
            foi, _ = best_match(freqs, foi, squash_duplicates=True)
        elif foilim is not None:
            foi, _ = best_match(freqs,
                                foilim,
                                span=True,
                                squash_duplicates=True)
        else:
            foi = freqs

        # Abort if desired frequency selection is empty
        if foi.size == 0:
            lgl = "non-empty frequency specification"
            act = "empty frequency selection"
            raise SPYValueError(legal=lgl, varname="foi/foilim", actual=act)

        # Set/get `tapsmofrq` if we're working w/Slepian tapers
        if taper.__name__ == "dpss":

            # Try to derive "sane" settings by using 3/4 octave smoothing of highest `foi`
            # following Hipp et al. "Oscillatory Synchronization in Large-Scale
            # Cortical Networks Predicts Perception", Neuron, 2011
            if tapsmofrq is None:
                foimax = foi.max()
                tapsmofrq = (foimax * 2**(3 / 4 / 2) -
                             foimax * 2**(-3 / 4 / 2)) / 2
            else:
                try:
                    scalar_parser(tapsmofrq,
                                  varname="tapsmofrq",
                                  lims=[1, np.inf])
                except Exception as exc:
                    raise exc

            # Get/compute number of tapers to use (at least 1 and max. 50)
            nTaper = taperopt.get("Kmax", 1)
            if not taperopt:
                nTaper = int(
                    max(
                        2,
                        min(
                            50,
                            np.floor(tapsmofrq * minSampleNum * 1 /
                                     data.samplerate))))
                taperopt = {"NW": tapsmofrq, "Kmax": nTaper}

        else:
            nTaper = 1

        # Warn the user in case `tapsmofrq` has no effect
        if tapsmofrq is not None and taper.__name__ != "dpss":
            msg = "`tapsmofrq` is only used if `taper` is `dpss`!"
            SPYWarning(msg)

        # Update `log_dct` w/method-specific options (use `lcls` to get actually
        # provided keyword values, not defaults set in here)
        log_dct["taper"] = lcls["taper"]
        log_dct["tapsmofrq"] = lcls["tapsmofrq"]
        log_dct["nTaper"] = nTaper

        # Check for non-default values of options not supported by chosen method
        kwdict = {"wav": wav, "width": width}
        for name, kwarg in kwdict.items():
            if kwarg is not lcls[name]:
                msg = "option `{}` has no effect in methods `mtmfft` and `mtmconvol`!"
                SPYWarning(msg.format(name))

    # Now, prepare explicit compute-classes for chosen method
    if method == "mtmfft":

        # Check for non-default values of options not supported by chosen method
        kwdict = {"t_ftimwin": t_ftimwin, "toi": toi}
        for name, kwarg in kwdict.items():
            if kwarg is not lcls[name]:
                msg = "option `{}` has no effect in method `mtmfft`!"
                SPYWarning(msg.format(name))

        # Set up compute-class
        specestMethod = MultiTaperFFT(samplerate=data.samplerate,
                                      foi=foi,
                                      nTaper=nTaper,
                                      timeAxis=timeAxis,
                                      taper=taper,
                                      taperopt=taperopt,
                                      tapsmofrq=tapsmofrq,
                                      pad=pad,
                                      padtype=padtype,
                                      padlength=padlength,
                                      keeptapers=keeptapers,
                                      polyremoval=polyremoval,
                                      output_fmt=output)

    elif method == "mtmconvol":

        # Set up compute-class
        specestMethod = MultiTaperFFTConvol(soi,
                                            list(padBegin),
                                            list(padEnd),
                                            samplerate=data.samplerate,
                                            noverlap=noverlap,
                                            nperseg=nperseg,
                                            equidistant=equidistant,
                                            toi=toi,
                                            foi=foi,
                                            nTaper=nTaper,
                                            timeAxis=timeAxis,
                                            taper=taper,
                                            taperopt=taperopt,
                                            pad=pad,
                                            padtype=padtype,
                                            padlength=padlength,
                                            prepadlength=prepadlength,
                                            postpadlength=postpadlength,
                                            keeptapers=keeptapers,
                                            polyremoval=polyremoval,
                                            output_fmt=output)

    elif method == "wavelet":

        # Check for non-default values of `taper`, `tapsmofrq`, `keeptapers` and
        # `t_ftimwin` (set to 0 above)
        kwdict = {
            "taper": taper,
            "tapsmofrq": tapsmofrq,
            "keeptapers": keeptapers
        }
        for name, kwarg in kwdict.items():
            if kwarg is not lcls[name]:
                msg = "option `{}` has no effect in method `wavelet`!"
                SPYWarning(msg.format(name))
        if t_ftimwin != 0:
            msg = "option `t_ftimwin` has no effect in method `wavelet`!"
            SPYWarning(msg)

        # Check wavelet selection
        if wav not in availableWavelets:
            lgl = "'" + "or '".join(opt + "' " for opt in availableWavelets)
            raise SPYValueError(legal=lgl, varname="wav", actual=wav)
        if wav not in ["Morlet", "Paul"]:
            msg = "the chosen wavelet '{}' is real-valued and does not provide " +\
                "any information about amplitude or phase of the data. This wavelet function " +\
                "may be used to isolate peaks or discontinuities in the signal. "
            SPYWarning(msg.format(wav))

        # Check for consistency of `width`, `order` and `wav`
        if wav == "Morlet":
            try:
                scalar_parser(width, varname="width", lims=[1, np.inf])
            except Exception as exc:
                raise exc
            wfun = getattr(spywave, wav)(w0=width)
        else:
            if width != lcls["width"]:
                msg = "option `width` has no effect for wavelet '{}'"
                SPYWarning(msg.format(wav))

        if wav == "Paul":
            try:
                scalar_parser(order,
                              varname="order",
                              lims=[4, np.inf],
                              ntype="int_like")
            except Exception as exc:
                raise exc
            wfun = getattr(spywave, wav)(m=order)
        elif wav == "DOG":
            try:
                scalar_parser(order,
                              varname="order",
                              lims=[1, np.inf],
                              ntype="int_like")
            except Exception as exc:
                raise exc
            wfun = getattr(spywave, wav)(m=order)
        else:
            if order is not None:
                msg = "option `order` has no effect for wavelet '{}'"
                SPYWarning(msg.format(wav))
            wfun = getattr(spywave, wav)()

        # Process frequency selection (`toi` was taken care of above): `foilim`
        # selections are wrapped into `foi` thus the seemingly weird if construct
        # Note: SLURM workers don't like monkey-patching, so let's pretend
        # `get_optimal_wavelet_scales` is a class method by passing `wfun` as its
        # first argument
        if foi is None:
            scales = _get_optimal_wavelet_scales(
                wfun, int(minTrialLength * data.samplerate),
                1 / data.samplerate)
        if foilim is not None:
            foi = np.arange(foilim[0], foilim[1] + 1)
        if foi is not None:
            foi[foi < 0.01] = 0.01
            scales = wfun.scale_from_period(1 / foi)
            scales = scales[::
                            -1]  # FIXME: this only makes sense if `foi` was sorted -> cf Issue #94

        # Update `log_dct` w/method-specific options (use `lcls` to get actually
        # provided keyword values, not defaults set in here)
        log_dct["wav"] = lcls["wav"]
        log_dct["width"] = lcls["width"]
        log_dct["order"] = lcls["order"]

        # Set up compute-class
        specestMethod = WaveletTransform(preSelect,
                                         postSelect,
                                         list(padBegin),
                                         list(padEnd),
                                         samplerate=data.samplerate,
                                         toi=toi,
                                         scales=scales,
                                         timeAxis=timeAxis,
                                         wav=wfun,
                                         polyremoval=polyremoval,
                                         output_fmt=output)

    # If provided, make sure output object is appropriate
    if out is not None:
        try:
            data_parser(out,
                        varname="out",
                        writable=True,
                        empty=True,
                        dataclass="SpectralData",
                        dimord=SpectralData().dimord)
        except Exception as exc:
            raise exc
        new_out = False
    else:
        out = SpectralData(dimord=SpectralData._defaultDimord)
        new_out = True

    # Perform actual computation
    specestMethod.initialize(data,
                             chan_per_worker=kwargs.get("chan_per_worker"),
                             keeptrials=keeptrials)
    specestMethod.compute(data,
                          out,
                          parallel=kwargs.get("parallel"),
                          log_dict=log_dct)

    # Either return newly created output object or simply quit
    return out if new_out else None
예제 #10
0
def mtmfft_cF(trl_dat,
              foi=None,
              timeAxis=0,
              keeptapers=True,
              polyremoval=None,
              output_fmt="pow",
              noCompute=False,
              chunkShape=None,
              method_kwargs=None):
    """
    Compute (multi-)tapered Fourier transform of multi-channel time series data

    Parameters
    ----------
    trl_dat : 2D :class:`numpy.ndarray`
        Uniformly sampled multi-channel time-series
    foi : 1D :class:`numpy.ndarray`
        Frequencies of interest (Hz) for output. Requested frequencies are
        matched against the frequency grid of the FFT via `best_match`
        (called with ``squash_duplicates=True``), i.e., the closest available
        frequencies are used.
    timeAxis : int
        Index of running time axis in `trl_dat` (0 or 1). For any value other
        than 0 the transposed array is processed.
    keeptapers : bool
        If `True`, return spectral estimates for each taper.
        Otherwise the result is averaged across tapers, which is only a
        valid spectral estimate if `output_fmt` is `pow`.
    polyremoval : int or None
        Order of polynomial used for de-trending data in the time domain prior
        to spectral analysis. A value of 0 corresponds to subtracting the mean
        ("de-meaning"), ``polyremoval = 1`` removes linear trends (subtracting
        the least squares fit of a linear polynomial). Any other value
        (including `None`) skips de-trending.
    output_fmt : str
        Output of spectral estimation; used as key into `spectralDTypes` and
        `spectralConversions`
        (one of :data:`~syncopy.specest.const_def.availableOutputs`)
    noCompute : bool
        Preprocessing flag. If `True`, do not perform actual calculation but
        instead return expected shape and :class:`numpy.dtype` of output
        array.
    chunkShape : None or tuple
        If not `None`, represents shape of output `spec` (respecting provided
        values of `nTaper`, `keeptapers` etc.). Not used inside this function.
    method_kwargs : dict
        Keyword arguments passed to :func:`~syncopy.specest.mtmfft.mtmfft`
        controlling the spectral estimation method. The entries read directly
        in here are ``'nSamples'`` (number of samples defining the frequency
        grid; if `None`, the length of the time axis is used - presumably
        this is the padded signal length, confirm against caller),
        ``'samplerate'`` and ``'taper_opt'`` (its ``'Kmax'`` entry, default 1,
        is taken as number of tapers).

    Returns
    -------
    spec : :class:`numpy.ndarray`
        Spectrum of the input data converted to `output_fmt`, arranged as
        (time=1 x taper x freq x channel). If `keeptapers` is `False`, the
        taper axis has length one. If `noCompute` is `True`, a tuple
        ``(outShape, dtype)`` is returned instead.

    Notes
    -----
    This method is intended to be used as
    :meth:`~syncopy.shared.computational_routine.ComputationalRoutine.computeFunction`
    inside a :class:`~syncopy.shared.computational_routine.ComputationalRoutine`.
    Thus, input parameters are presumed to be forwarded from a parent metafunction.
    Consequently, this function does **not** perform any error checking and operates
    under the assumption that all inputs have been externally validated and cross-checked.

    De-trending is requested with ``overwrite_data=True``, so the contents of
    `trl_dat` may be altered in place.

    See also
    --------
    syncopy.freqanalysis : parent metafunction
    MultiTaperFFT : :class:`~syncopy.shared.computational_routine.ComputationalRoutine` instance
                     that calls this method as :meth:`~syncopy.shared.computational_routine.ComputationalRoutine.computeFunction`
    numpy.fft.rfft : NumPy's FFT implementation
    """

    # Work on a time x channel layout (transposing only yields a view)
    dat = trl_dat if timeAxis == 0 else trl_dat.T

    # Number of samples defining the frequency grid: take the explicitly
    # provided value, fall back to the actual trial length otherwise
    nSamples = method_kwargs['nSamples']
    if nSamples is None:
        nSamples = dat.shape[0]
    nChannels = dat.shape[1]

    # Map the requested frequencies onto the FFT grid and assemble the shape
    # of the output (time=1 x taper x freq x channel)
    fftFreqs = np.fft.rfftfreq(nSamples, 1 / method_kwargs["samplerate"])
    _, freq_idx = best_match(fftFreqs, foi, squash_duplicates=True)
    nTaper = method_kwargs["taper_opt"].get('Kmax', 1)
    outShape = (1, max(1, nTaper * keeptapers), freq_idx.size, nChannels)

    # Dry-run: only report what the output is going to look like
    if noCompute:
        return outShape, spectralDTypes[output_fmt]

    # Translate `polyremoval` to the corresponding detrending mode
    # (detrending does not work with 'FauxTrial' data)
    detrendType = None
    if polyremoval == 0:
        detrendType = 'constant'
    elif polyremoval == 1:
        detrendType = 'linear'
    if detrendType is not None:
        dat = signal.detrend(dat,
                             type=detrendType,
                             axis=0,
                             overwrite_data=True)

    # Let the backend do the actual spectral estimation
    res, _ = mtmfft(dat, **method_kwargs)

    # Pick wanted frequencies, prepend singleton time-axis and convert
    selected = res[np.newaxis, :, freq_idx, :]
    spec = spectralConversions[output_fmt](selected)

    # Averaging across tapers is only a valid spectral estimate
    # if output_fmt == 'pow'! (gets checked in parent meta)
    return spec if keeptapers else spec.mean(axis=1, keepdims=True)
예제 #11
0
def mtmconvol_cF(trl_dat,
                 soi,
                 postselect,
                 equidistant=True,
                 toi=None,
                 foi=None,
                 nTaper=1,
                 tapsmofrq=None,
                 timeAxis=0,
                 keeptapers=True,
                 polyremoval=0,
                 output_fmt="pow",
                 noCompute=False,
                 chunkShape=None,
                 method_kwargs=None):
    """
    Perform time-frequency analysis on multi-channel time series data using a sliding window FFT

    Parameters
    ----------
    trl_dat : 2D :class:`numpy.ndarray`
        Uniformly sampled multi-channel time-series
    soi : list of slices or slice
        Samples of interest; either a single slice encoding begin- to end-samples
        to perform analysis on (if sliding window centroids are equidistant)
        or list of slices with each slice corresponding to coverage of a single
        analysis window (if spacing between windows is not constant)
    postselect : slice or index array
        Index applied along the first (time) axis of the array returned by
        the sliding-window backend. Only used if `equidistant` is `True`.
        Presumably picks the windows corresponding to `toi` - confirm
        against caller.
    equidistant : bool
        If `True`, spacing of window-centroids is equidistant.
    toi : 1D :class:`numpy.ndarray` or float or str
        Either time-points to center windows on if `toi` is a :class:`numpy.ndarray`,
        or percentage of overlap between windows if `toi` is a scalar or `"all"`
        to center windows on all samples in `trl_dat`. Please refer to
        :func:`~syncopy.freqanalysis` for further details. **Note**: The value
        of `toi` has to agree with provided padding and window settings.
    foi : 1D :class:`numpy.ndarray`
        Frequencies of interest (Hz) for output. If desired frequencies
        cannot be matched exactly the closest possible frequencies are used.
    nTaper : int
        Number of tapers to use. Superseded by ``taper_opt["Kmax"]`` if
        `method_kwargs` contains a non-empty `taper_opt` with that entry.
    tapsmofrq : None or float
        Not used inside this function.
    timeAxis : int
        Index of running time axis in `trl_dat` (0 or 1)
    keeptapers : bool
        If `True`, results of Fourier transform are preserved for each taper,
        otherwise spectrum is averaged across tapers.
    polyremoval : int or None
        Order of polynomial used for de-trending data in the time domain prior
        to spectral analysis. A value of 0 corresponds to subtracting the mean
        ("de-meaning"), ``polyremoval = 1`` removes linear trends (subtracting the
        least squares fit of a linear polynomial). Any other value (including
        `None`) disables de-trending. The setting is forwarded as `detrend`
        to the sliding-window backend, i.e., it only takes effect if
        `equidistant` is `True`.
    output_fmt : str
        Output of spectral estimation; one of :data:`~syncopy.specest.const_def.availableOutputs`
    noCompute : bool
        Preprocessing flag. If `True`, do not perform actual calculation but
        instead return expected shape and :class:`numpy.dtype` of output
        array.
    chunkShape : None or tuple
        If not `None`, represents shape of output object `spec` (respecting provided
        values of `nTaper`, `keeptapers` etc.). Not used inside this function.
    method_kwargs : dict
        Keyword arguments passed to :func:`~syncopy.specest.mtmconvol.mtmconvol`
        controlling the spectral estimation method. Entries read in here:
        ``'nperseg'`` (size of analysis windows in samples), ``'noverlap'``
        (number of samples covered by two adjacent analysis windows),
        ``'taper_opt'`` (additional keyword arguments for the taper),
        ``'samplerate'`` and ``'taper'`` (the latter two only if `equidistant`
        is `False`). The dictionary itself is left unmodified.

    Returns
    -------
    spec : :class:`numpy.ndarray`
        Complex or real time-frequency representation of the input data
        arranged as (time x taper x freq x channel). If `keeptapers` is
        `False`, the taper axis has length one. If `noCompute` is `True`,
        a tuple ``(outShape, dtype)`` is returned instead.

    Notes
    -----
    This method is intended to be used as
    :meth:`~syncopy.shared.computational_routine.ComputationalRoutine.computeFunction`
    inside a :class:`~syncopy.shared.computational_routine.ComputationalRoutine`.
    Thus, input parameters are presumed to be forwarded from a parent metafunction.
    Consequently, this function does **not** perform any error checking and operates
    under the assumption that all inputs have been externally validated and cross-checked.

    The `boundary`, `padded` and `detrend` settings handed to the backend
    follow the keyword conventions of :func:`scipy.signal.stft`.

    See also
    --------
    syncopy.freqanalysis : parent metafunction
    MultiTaperFFTConvol : :class:`~syncopy.shared.computational_routine.ComputationalRoutine`
                          instance that calls this method as
                          :meth:`~syncopy.shared.computational_routine.ComputationalRoutine.computeFunction`
    scipy.signal.stft : SciPy's STFT implementation
    """

    # Re-arrange array if necessary and get dimensional information
    if timeAxis != 0:
        dat = trl_dat.T  # does not copy but creates view of `trl_dat`
    else:
        dat = trl_dat

    # Get shape of output for dry-run phase
    nChannels = dat.shape[1]
    if isinstance(toi, np.ndarray):  # `toi` is an array of time-points
        nTime = toi.size
        stftBdry = None
        stftPad = False
    else:  # `toi` is either 'all' or a percentage
        hopSize = method_kwargs['nperseg'] - method_kwargs['noverlap']
        nTime = np.ceil(dat.shape[0] / hopSize).astype(np.intp)
        stftBdry = "zeros"
        stftPad = True
    nFreq = foi.size

    # A non-empty `taper_opt` may dictate the taper count; if it does not
    # come with a "Kmax" entry, stick to the provided `nTaper` instead of
    # failing with a `KeyError`
    taper_opt = method_kwargs['taper_opt']
    if taper_opt:
        nTaper = taper_opt.get("Kmax", nTaper)
    outShape = (nTime, max(1, nTaper * keeptapers), nFreq, nChannels)
    if noCompute:
        return outShape, spectralDTypes[output_fmt]

    convert = spectralConversions[output_fmt]

    if equidistant:

        # detrending options for each segment
        if polyremoval == 0:
            detrend = 'constant'
        elif polyremoval == 1:
            detrend = 'linear'
        else:
            detrend = False

        # Merge the additional `stft` keywords into a *copy* of
        # `method_kwargs`: the caller's dictionary may be shared across
        # invocations and must not be altered as a side effect
        stft_kwargs = dict(method_kwargs,
                           boundary=stftBdry,
                           padded=stftPad,
                           detrend=detrend)

        ftr, freqs = mtmconvol(dat[soi, :], **stft_kwargs)
        _, fIdx = best_match(freqs, foi, squash_duplicates=True)
        spec = convert(ftr[postselect, :, fIdx, :])

    else:
        # in this case only a single window gets centered on
        # every individual soi, so we can use mtmfft!
        # NOTE(review): `polyremoval` is not applied in this branch, no
        # detrend setting is passed to `mtmfft` - confirm this is intended
        samplerate = method_kwargs['samplerate']
        taper = method_kwargs['taper']

        # In case tapers aren't preserved allocate `spec` "too big"
        # and average afterwards
        spec = np.full((nTime, nTaper, nFreq, nChannels),
                       np.nan,
                       dtype=spectralDTypes[output_fmt])

        # Center one window on every soi; the frequency matching is only
        # performed once (using the first window) and re-used afterwards.
        # An empty `soi` simply leaves `spec` untouched.
        fIdx = None
        for tk, window in enumerate(soi):
            ftr, freqs = mtmfft(dat[window, :],
                                samplerate,
                                taper=taper,
                                taper_opt=taper_opt)
            if fIdx is None:
                _, fIdx = best_match(freqs, foi, squash_duplicates=True)
            spec[tk, ...] = convert(ftr[:, fIdx, :])

    # Average across tapers if wanted
    # only valid if output_fmt='pow' !
    if not keeptapers:
        return np.nanmean(spec, axis=1, keepdims=True)
    return spec
예제 #12
0
def mtmconvol(trl_dat,
              soi,
              padbegin,
              padend,
              samplerate=None,
              noverlap=None,
              nperseg=None,
              equidistant=True,
              toi=None,
              foi=None,
              nTaper=1,
              timeAxis=0,
              taper=signal.windows.hann,
              taperopt={},
              keeptapers=True,
              polyremoval=None,
              output_fmt="pow",
              noCompute=False,
              chunkShape=None):
    """
    Perform time-frequency analysis on multi-channel time series data using a sliding window FFT

    Parameters
    ----------
    trl_dat : 2D :class:`numpy.ndarray`
        Uniformly sampled multi-channel time-series
    soi : list of slices or slice
        Samples of interest; either a single slice encoding begin- to end-samples
        to perform analysis on (if sliding window centroids are equidistant)
        or list of slices with each slice corresponding to coverage of a single
        analysis window (if spacing between windows is not constant)
    padbegin : int
        Number of samples to pre-pend to `trl_dat`
    padend : int
        Number of samples to append to `trl_dat`
    samplerate : float
        Samplerate of `trl_dat` in Hz
    noverlap : int
        Number of samples covered by two adjacent analysis windows
    nperseg : int
        Size of analysis windows (in samples)
    equidistant : bool
        If `True`, spacing of window-centroids is equidistant and `soi` is
        used as a single index into the (padded) data. Otherwise `soi` is
        iterated and one STFT is computed per entry.
    toi : 1D :class:`numpy.ndarray` or float or str
        Either time-points to center windows on if `toi` is a :class:`numpy.ndarray`,
        or percentage of overlap between windows if `toi` is a scalar or `"all"`
        to center windows on all samples in `trl_dat`. Please refer to
        :func:`~syncopy.freqanalysis` for further details. **Note**: The value
        of `toi` has to agree with provided padding and window settings.
        Inside this function `toi` only determines the number of output
        time-points and whether the STFT is computed with zero-padded
        boundaries (no boundary extension/padding if `toi` is an array).
    foi : 1D :class:`numpy.ndarray`
        Frequencies of interest  (Hz) for output. If desired frequencies
        cannot be matched exactly the closest possible frequencies (respecting
        data length and padding) are used.
    nTaper : int
        Number of tapers to use
    timeAxis : int
        Index of running time axis in `trl_dat` (0 or 1)
    taper : callable
        Taper function to use, one of :data:`~syncopy.specest.freqanalysis.availableTapers`
    taperopt : dict
        Additional keyword arguments passed to `taper` (see above). For further
        details, please refer to the
        `SciPy docs <https://docs.scipy.org/doc/scipy/reference/signal.windows.html>`_
        The dictionary is only unpacked, never modified.
    keeptapers : bool
        If `True`, results of Fourier transform are preserved for each taper,
        otherwise spectrum is averaged across tapers.
    polyremoval : int
        **FIXME: Not implemented yet** (the value is currently not used by
        this function).
        Order of polynomial used for de-trending. A value of 0 corresponds to
        subtracting the mean ("de-meaning"), ``polyremoval = 1`` removes linear
        trends (subtracting the least squares fit of a linear function),
        ``polyremoval = N`` for `N > 1` subtracts a polynomial of order `N` (``N = 2``
        quadratic, ``N = 3`` cubic etc.). If `polyremoval` is `None`, no de-trending
        is performed.
    output_fmt : str
        Output of spectral estimation; one of :data:`~syncopy.specest.freqanalysis.availableOutputs`
    noCompute : bool
        Preprocessing flag. If `True`, do not perform actual calculation but
        instead return expected shape and :class:`numpy.dtype` of output
        array.
    chunkShape : None or tuple
        If not `None`, represents shape of output object `spec` (respecting provided
        values of `nTaper`, `keeptapers` etc.). Not used inside this function.

    Returns
    -------
    spec : :class:`numpy.ndarray`
        Complex or real time-frequency representation of (padded) input data,
        arranged as time x taper x frequency x channel. If `keeptapers` is
        `False`, the taper axis has length one (NaN-aware mean across tapers).
        If `noCompute` is `True`, a tuple ``(shape, dtype)`` is returned instead.

    Notes
    -----
    This method is intended to be used as
    :meth:`~syncopy.shared.computational_routine.ComputationalRoutine.computeFunction`
    inside a :class:`~syncopy.shared.computational_routine.ComputationalRoutine`.
    Thus, input parameters are presumed to be forwarded from a parent metafunction.
    Consequently, this function does **not** perform any error checking and operates
    under the assumption that all inputs have been externally validated and cross-checked.

    The computational heavy lifting in this code is performed by SciPy's Short Time
    Fourier Transform (STFT) implementation :func:`scipy.signal.stft`.

    See also
    --------
    syncopy.freqanalysis : parent metafunction
    MultiTaperFFTConvol : :class:`~syncopy.shared.computational_routine.ComputationalRoutine`
                          instance that calls this method as
                          :meth:`~syncopy.shared.computational_routine.ComputationalRoutine.computeFunction`
    scipy.signal.stft : SciPy's STFT implementation
    """

    # Re-arrange array if necessary and get dimensional information
    if timeAxis != 0:
        dat = trl_dat.T  # does not copy but creates view of `trl_dat`
    else:
        dat = trl_dat

    # Pad input array if necessary
    if padbegin > 0 or padend > 0:
        dat = padding(dat,
                      "zero",
                      pad="relative",
                      padlength=None,
                      prepadlength=padbegin,
                      postpadlength=padend)

    # Get shape of output for dry-run phase
    nChannels = dat.shape[1]
    if isinstance(toi, np.ndarray):  # `toi` is an array of time-points
        nTime = toi.size
        stftBdry = None
        stftPad = False
    else:  # `toi` is either 'all' or a percentage
        nTime = np.ceil(dat.shape[0] / (nperseg - noverlap)).astype(np.intp)
        stftBdry = "zeros"
        stftPad = True
    nFreq = foi.size
    outShape = (nTime, max(1, nTaper * keeptapers), nFreq, nChannels)
    if noCompute:
        return outShape, spyfreq.spectralDTypes[output_fmt]

    # In case tapers aren't preserved allocate `spec` "too big" and average afterwards
    spec = np.full((nTime, nTaper, nFreq, nChannels),
                   np.nan,
                   dtype=spyfreq.spectralDTypes[output_fmt])

    # Collect keyword args for `stft` in dictionary; the "window" entry is
    # (re-)set for every taper in the loop below
    stftKw = {
        "fs": samplerate,
        "nperseg": nperseg,
        "noverlap": noverlap,
        "return_onesided": True,
        "boundary": stftBdry,
        "padded": stftPad,
        "axis": 0
    }

    convert = spyfreq.spectralConversions[output_fmt]

    # One row per taper (a single 1D window is promoted to one row)
    windows = np.atleast_2d(taper(nperseg, **taperopt))

    # Frequency indices are determined from the very first `stft` call and
    # re-used afterwards (the frequency axis only depends on `stftKw`)
    fIdx = None

    for taperIdx, window in enumerate(windows):
        stftKw["window"] = window
        if equidistant:
            # `stft` w/`axis=0` yields a freq x channel x time array: transpose
            # to time x freq x channel before converting/selecting
            freq, _, pxx = signal.stft(dat[soi, :], **stftKw)
            if fIdx is None:
                _, fIdx = best_match(freq, foi, squash_duplicates=True)
            spec[:, taperIdx, ...] = \
                convert(pxx.transpose(2, 0, 1))[:nTime, fIdx, :]
        else:
            for tk, sample in enumerate(soi):
                freq, _, pxx = signal.stft(dat[sample, :], **stftKw)
                if fIdx is None:
                    _, fIdx = best_match(freq, foi, squash_duplicates=True)
                # Every entry of `soi` is expected to yield exactly one STFT
                # segment: drop ONLY that leading time axis. A bare `squeeze()`
                # would also remove the channel axis of single-channel data
                # and break the 2D indexing below.
                spec[tk, taperIdx, ...] = \
                    convert(pxx.transpose(2, 0, 1).squeeze(axis=0))[fIdx, :]

    # Average across tapers if wanted
    if not keeptapers:
        return np.nanmean(spec, axis=1, keepdims=True)
    return spec