def test_object_padding(self):
    """
    Exercise `padding` on `AnalogData` objects with `create_new=False`

    With `create_new=False` the test expects `padding` to hand back one
    dictionary per trial (holding at least the keys `"pad_width"` and
    `"constant_values"`) instead of padded data. The padded trial length
    is reconstructed from `"pad_width"` and compared to the requested
    length for time-based `'absolute'` padding and for all pre/post
    flavors of `'maxlen'` padding. Finally, illegal keyword combinations
    for `'maxlen'` must raise a `SPYTypeError`.
    """

    # `AnalogData` object whose trials differ in length
    adata = generate_artificial_data(nTrials=7, nChannels=16,
                                     equidistant=False, inmemory=False)
    timeAxis = adata.dimord.index("time")

    # `'absolute'` padding specified in seconds: every trial should end up
    # with a padded length of `total_time` seconds (1 sample tolerance)
    total_time = 30
    pad_list = padding(adata, "zero", pad="absolute", padlength=total_time,
                       unit="time", create_new=False)
    for tk, trl in enumerate(adata.trials):
        pad_spec = pad_list[tk]
        for key in ("pad_width", "constant_values"):
            assert key in pad_spec.keys()
        n_padded = pad_spec["pad_width"][timeAxis, :].sum() + trl.shape[timeAxis]
        trl_time = n_padded / adata.samplerate
        # NOTE(review): this only bounds the padded duration from above; a
        # trial padded (much) too short would still pass - confirm whether a
        # two-sided check is intended
        assert trl_time - total_time < 1 / adata.samplerate

    # Same kind of object but with reversed dimensional order
    adata2 = generate_artificial_data(nTrials=7, nChannels=16,
                                      equidistant=False, inmemory=False,
                                      dimord=adata.dimord[::-1])
    timeAxis2 = adata2.dimord.index("time")

    # Sample-count of the longest trial
    maxtrllen = max((trl.shape[timeAxis2] for trl in adata2.trials), default=0)

    def _padded_lengths(**location):
        # Lazily yield the padded sample-count of each trial of `adata2`
        # for `'maxlen'` padding using the given pre/post keywords
        specs = padding(adata2, "zero", pad="maxlen", create_new=False, **location)
        for tk, trl in enumerate(adata2.trials):
            yield specs[tk]["pad_width"][timeAxis2, :].sum() + trl.shape[timeAxis2]

    # symmetric `maxlen` padding (no location flags or both): 1 sample tolerance
    # NOTE(review): one-sided check as well (only "too long" is caught)
    for location in ({}, {"prepadlength": True, "postpadlength": True}):
        for trl_len in _padded_lengths(**location):
            assert (trl_len - maxtrllen) <= 1

    # pure pre- or pure post- `maxlen` padding: no tolerance
    for location in ({"prepadlength": True}, {"postpadlength": True}):
        for trl_len in _padded_lengths(**location):
            assert trl_len == maxtrllen

    # `maxlen`-specific errors: `padlength` wrong type, `prepadlength` wrong
    # type, wrong combination of `padlength` and `prepadlength`
    illegal = (
        {"padlength": self.ns},
        {"prepadlength": self.ns},
        {"padlength": self.ns, "prepadlength": True},
    )
    for badkws in illegal:
        with pytest.raises(SPYTypeError):
            padding(adata, "zero", pad="maxlen", create_new=False, **badkws)
def mtmfft(trl_dat, samplerate=None, foi=None, nTaper=1, timeAxis=0,
           taper=spwin.hann, taperopt={},
           pad="nextpow2", padtype="zero", padlength=None,
           keeptapers=True, polyremoval=None, output_fmt="pow",
           noCompute=False, chunkShape=None):
    """
    (Multi-)tapered Fourier transform of a multi-channel time series

    Parameters
    ----------
    trl_dat : 2D :class:`numpy.ndarray`
        Uniformly sampled multi-channel time-series
    samplerate : float
        Samplerate of `trl_dat` in Hz
    foi : 1D :class:`numpy.ndarray`
        Frequencies of interest (Hz) for output. Desired frequencies are
        matched against the attainable frequency axis (which depends on the
        padded data length); if an exact match is not possible the closest
        attainable frequencies are used.
    nTaper : int
        Number of filter windows to use
    timeAxis : int
        Index of running time axis in `trl_dat` (0 or 1). Any value other
        than 0 makes the routine work on the transpose of `trl_dat`.
    taper : callable
        Taper function to use, one of :data:`~syncopy.specest.freqanalysis.availableTapers`.
        Invoked as ``taper(nSamples, **taperopt)``.
    taperopt : dict
        Additional keyword arguments passed to the `taper` function. For further
        details, please refer to the
        `SciPy docs <https://docs.scipy.org/doc/scipy/reference/signal.windows.html>`_
    pad : str
        Padding mode; one of `'absolute'`, `'relative'`, `'maxlen'`, or `'nextpow2'`.
        A falsy value skips padding altogether.
        See :func:`syncopy.padding` for more information.
    padtype : str
        Values to be used for padding. Can be 'zero', 'nan', 'mean',
        'localmean', 'edge' or 'mirror'. See :func:`syncopy.padding` for
        more information.
    padlength : None, bool or positive scalar
        Number of samples to pad to data (if `pad` is 'absolute' or 'relative').
        See :func:`syncopy.padding` for more information.
    keeptapers : bool
        If `True`, results of Fourier transform are preserved for each taper,
        otherwise spectrum is averaged across tapers.
    polyremoval : int or None
        **FIXME: Not implemented yet** (the value is not used by this routine).
        Intended meaning: order of polynomial used for de-trending data in the
        time domain prior to spectral analysis (0 = subtract mean, 1 = remove
        linear trend, ``N > 1`` = subtract polynomial of order `N`, `None` =
        no de-trending).
    output_fmt : str
        Output of spectral estimation; one of
        :data:`~syncopy.specest.freqanalysis.availableOutputs`
    noCompute : bool
        Preprocessing flag. If `True`, do not perform actual calculation but
        instead return expected shape and :class:`numpy.dtype` of output
        array.
    chunkShape : None or tuple
        If not `None`, represents shape of output `spec` (respecting provided
        values of `nTaper`, `keeptapers` etc.). Not used inside this routine.

    Returns
    -------
    spec : :class:`numpy.ndarray`
        Complex or real spectrum of (padded) input data with dimensions
        time (= 1) x taper x freq x channel. If `keeptapers` is `False`
        the taper dimension has length one. If `noCompute` is `True`, the
        tuple ``(outShape, dtype)`` is returned instead.

    Notes
    -----
    This method is intended to be used as
    :meth:`~syncopy.shared.computational_routine.ComputationalRoutine.computeFunction`
    inside a :class:`~syncopy.shared.computational_routine.ComputationalRoutine`.
    Thus, input parameters are presumed to be forwarded from a parent metafunction.
    Consequently, this function does **not** perform any error checking and operates
    under the assumption that all inputs have been externally validated and cross-checked.

    The Fourier transform itself is computed by NumPy's real-input FFT
    :func:`numpy.fft.rfft` along the time axis.

    See also
    --------
    syncopy.freqanalysis : parent metafunction
    MultiTaperFFT : :class:`~syncopy.shared.computational_routine.ComputationalRoutine`
                    instance that calls this method as
                    :meth:`~syncopy.shared.computational_routine.ComputationalRoutine.computeFunction`
    numpy.fft.rfft : NumPy's FFT implementation for real input
    """

    # Work with time running along axis 0 (transposing only creates a view)
    signal = trl_dat if timeAxis == 0 else trl_dat.T

    # Optional padding changes the number of samples
    if pad:
        signal = padding(signal, padtype, pad=pad, padlength=padlength,
                         prepadlength=True)
    nSamples = signal.shape[0]
    nChannels = signal.shape[1]

    # Attainable frequency axis and indices of the wanted frequencies on it
    nFreqAll = int(np.floor(nSamples / 2) + 1)
    freqAxis = np.linspace(0, samplerate / 2, nFreqAll)
    _, fidx = best_match(freqAxis, foi, squash_duplicates=True)
    nFreq = fidx.size

    # Shape of output: time=1 x taper x freq x channel
    outShape = (1, max(1, nTaper * keeptapers), nFreq, nChannels)
    outDType = freq.spectralDTypes[output_fmt]

    # Dry-run for initializing the computational routine
    if noCompute:
        return outShape, outDType

    # Always allocate room for all tapers; averaging (if wanted) happens last
    spec = np.full((1, nTaper, nFreq, nChannels), np.nan, dtype=outDType)
    freqChanIdx = (slice(None, nFreq), slice(None, nChannels))

    # One row per taper window
    windows = np.atleast_2d(taper(nSamples, **taperopt))
    convert = freq.spectralConversions[output_fmt]
    for taperIdx, window in enumerate(windows):
        if signal.ndim > 1:
            # replicate the window for every channel (samples x channels)
            window = np.tile(window, (nChannels, 1)).T
        spectrum = np.fft.rfft(signal * window, axis=0)
        spec[(0, taperIdx) + freqChanIdx] = convert(spectrum[fidx, :])

    # Collapse the taper dimension unless individual tapers are requested
    return spec if keeptapers else spec.mean(axis=1, keepdims=True)
def test_absolute_nextpow2_array_padding(self):
    """
    Zero-pad `self.data` using `pad = 'absolute'` and `pad = 'nextpow2'`

    For both modes the array is padded symmetrically, only in front, only
    at the end and with both location flags set; the padded regions have
    to be zero and the padded array must comprise the expected number of
    samples. Afterwards, illegal keyword choices for both modes are
    expected to raise `SPYValueError`/`SPYTypeError`.
    """

    nextpow2_total = int(2**np.ceil(np.log2(self.ns)))

    # (padding mode, `padlength` handed to `padding`, expected sample-count)
    scenarios = [
        ("absolute", self.ns + 20, self.ns + 20),
        ("nextpow2", None, nextpow2_total),
    ]

    for pad, padlength, n_total in scenarios:
        n_fillin = n_total - self.ns
        n_half = n_fillin // 2

        # NOTE(review): if `n_fillin` (or `n_half`) is 0, the trailing slices
        # start at -0 and thus select the *entire* array, not an empty range
        front_half, back_half = slice(None, n_half), slice(-n_half, None)
        front_full, back_full = slice(None, n_fillin), slice(-n_fillin, None)

        # (location keywords, regions of the result that must be all-zero)
        layouts = [
            ({}, (front_half, back_half)),
            ({"prepadlength": True}, (front_full,)),
            ({"postpadlength": True}, (back_full,)),
            ({"prepadlength": True, "postpadlength": True}, (front_half, back_half)),
        ]
        for location, zero_regions in layouts:
            arr = padding(self.data, "zero", pad=pad, padlength=padlength, **location)
            for region in zero_regions:
                assert np.all(arr[region, :] == 0)
            assert arr.shape[0] == n_total

    # 'absolute'-specific error: `padlength` too short
    with pytest.raises(SPYValueError):
        padding(self.data, "zero", pad="absolute", padlength=self.ns - 1)

    # 'absolute'- and 'nextpow2'-specific errors: wrong type of `padlength`/
    # `prepadlength`, wrong combination of `padlength` and `prepadlength`
    illegal = [
        ("absolute", {"prepadlength": self.ns}),
        ("absolute", {"padlength": nextpow2_total, "prepadlength": nextpow2_total}),
        ("nextpow2", {"padlength": self.ns}),
        ("nextpow2", {"prepadlength": self.ns}),
        ("nextpow2", {"padlength": nextpow2_total, "prepadlength": True}),
    ]
    for pad, badkws in illegal:
        with pytest.raises(SPYTypeError):
            padding(self.data, "zero", pad=pad, **badkws)
def test_relative_array_padding(self):
    """
    Pad `self.data` with `pad = 'relative'` using all available `padtype`'s

    Samples are added symmetrically (`padlength`), in front
    (`prepadlength`), at the end (`postpadlength`) or on both sides with
    individual counts. For every combination of location and `padtype` the
    padded sections are compared to values computed directly from
    `self.data` and the resulting sample-count is verified. Subsequently,
    over-determined keyword combinations, non-integer sample counts and
    time-based padding of arrays are expected to raise errors.
    """

    # no. of samples to pad
    n_center = 5
    n_pre = 2
    n_post = 3
    n_half = int(n_center / 2)

    # keywords for calling `padding`
    lockws = {
        "center": {"padlength": n_center},
        "pre": {"prepadlength": n_pre},
        "post": {"postpadlength": n_post},
        "prepost": {"prepadlength": n_pre, "postpadlength": n_post},
    }

    # padded sections of the result per padding location, each given as
    # (side of the array, no. of padded samples on that side)
    sections = {
        "center": [("front", n_half), ("back", n_half)],
        "pre": [("front", n_pre)],
        "post": [("back", n_post)],
        "prepost": [("front", n_pre), ("back", n_post)],
    }

    def padded_region(side, count):
        # Slice selecting the `count` padded samples at the given side
        return slice(None, count) if side == "front" else slice(-count, None)

    def expected_fill(ptype, side, count):
        # Values `padding` is expected to insert for `ptype` at the given side
        if ptype == "zero":
            return 0
        if ptype == "mean":
            return np.tile(self.data.mean(axis=0), (count, 1))
        in_front = side == "front"
        if ptype == "localmean":
            chunk = self.data[:count, :] if in_front else self.data[-count:, :]
            return np.tile(chunk.mean(axis=0), (count, 1))
        if ptype == "edge":
            row = self.data[0, :] if in_front else self.data[-1, :]
            return np.tile(row, (count, 1))
        # "mirror": samples next to the boundary in reversed order
        chunk = self.data[1:1 + count, :] if in_front else self.data[-1 - count:-1, :]
        return chunk[::-1]

    # happy padding
    for loc, kws in lockws.items():
        n_expected = self.data.shape[0] + sum(count for _, count in sections[loc])
        for ptype in ["zero", "mean", "localmean", "edge", "mirror"]:
            arr = padding(self.data, ptype, pad="relative", **kws)
            for side, count in sections[loc]:
                region = padded_region(side, count)
                assert np.all(arr[region, :] == expected_fill(ptype, side, count))
            assert arr.shape[0] == n_expected

        # NaNs never compare equal, use `isnan` to check padded sections
        arr = padding(self.data, "nan", pad="relative", **kws)
        for side, count in sections[loc]:
            assert np.all(np.isnan(arr[padded_region(side, count), :]))
        assert arr.shape[0] == n_expected

    # overdetermined padding
    overdetermined = (
        {"padlength": 5, "prepadlength": 2},
        {"padlength": 5, "postpadlength": 2},
        {"padlength": 5, "prepadlength": 2, "postpadlength": 2},
    )
    for badkws in overdetermined:
        with pytest.raises(SPYTypeError):
            padding(self.data, "zero", pad="relative", **badkws)

    # float input for sample counts
    fractional = (
        {"padlength": 2.5},
        {"prepadlength": 2.5},
        {"postpadlength": 2.5},
        {"prepadlength": 2.5, "postpadlength": 2.5},
    )
    for badkws in fractional:
        with pytest.raises(SPYValueError):
            padding(self.data, "zero", pad="relative", **badkws)

    # time-based padding w/array input
    with pytest.raises(SPYValueError):
        padding(self.data, "zero", pad="relative", padlength=2, unit="time")
def freqanalysis(data, method='mtmfft', output='fourier', keeptrials=True, foi=None, foilim=None, pad=None, padtype='zero', padlength=None, prepadlength=None, postpadlength=None, polyremoval=None, taper="hann", tapsmofrq=None, keeptapers=False, toi=None, t_ftimwin=None, wav="Morlet", width=6, order=None, out=None, **kwargs): """ Perform (time-)frequency analysis of Syncopy :class:`~syncopy.AnalogData` objects **Usage Summary** Options available in all analysis methods: * **output** : one of :data:`~.availableOutputs`; return power spectra, complex Fourier spectra or absolute values. * **foi**/**foilim** : frequencies of interest; either array of frequencies or frequency window (not both) * **keeptrials** : return individual trials or grand average * **polyremoval** : de-trending method to use (0 = mean, 1 = linear, 2 = quadratic, 3 = cubic, etc.) List of available analysis methods and respective distinct options: :func:`~syncopy.specest.mtmfft.mtmfft` : (Multi-)tapered Fourier transform Perform frequency analysis on time-series trial data using either a single taper window (Hanning) or many tapers based on the discrete prolate spheroidal sequence (DPSS) that maximize energy concentration in the main lobe. * **taper** : one of :data:`~.availableTapers` * **tapsmofrq** : spectral smoothing box for tapers (in Hz) * **keeptapers** : return individual tapers or average * **pad** : padding method to use (`None`, `True`, `False`, `'absolute'`, `'relative'`, `'maxlen'` or `'nextpow2'`). If `None`, then `'nextpow2'` is selected by default. 
* **padtype** : values to pad data with (`'zero'`, `'nan'`, `'mean'`, `'localmean'`, `'edge'` or `'mirror'`) * **padlength** : number of samples to pre-pend and/or append to each trial * **prepadlength** : number of samples to pre-pend to each trial * **postpadlength** : number of samples to append to each trial :func:`~syncopy.specest.mtmconvol.mtmconvol` : (Multi-)tapered sliding window Fourier transform Perform time-frequency analysis on time-series trial data based on a sliding window short-time Fourier transform using either a single Hanning taper or multiple DPSS tapers. * **taper** : one of :data:`~.availableTapers` * **tapsmofrq** : spectral smoothing box for tapers (in Hz) * **keeptapers** : return individual tapers or average * **pad** : flag indicating, whether or not to pad trials. If `None`, trials are padded only if sliding window centroids are too close to trial boundaries for the entire window to cover available data-points. * **toi** : time-points of interest; can be either an array representing analysis window centroids (in sec), a scalar between 0 and 1 encoding the percentage of overlap between adjacent windows or "all" to center a window on every sample in the data. * **t_ftimwin** : sliding window length (in sec) :func:`~syncopy.specest.wavelet.wavelet` : (Continuous non-orthogonal) wavelet transform Perform time-frequency analysis on time-series trial data using a non-orthogonal continuous wavelet transform. * **wav** : one of :data:`~.availableWavelets` * **toi** : time-points of interest; can be either an array representing time points (in sec) to center wavelets on or "all" to center a wavelet on every sample in the data. 
* **width** : Nondimensional frequency constant of Morlet wavelet function (>= 6) * **order** : Order of Paul wavelet function (>= 4) or derivative order of real-valued DOG wavelets (2 = mexican hat) **Full documentation below** Parameters ---------- data : `~syncopy.AnalogData` A non-empty Syncopy :class:`~syncopy.datatype.AnalogData` object method : str Spectral estimation method, one of :data:`~.availableMethods` (see below). output : str Output of spectral estimation. One of :data:`~.availableOutputs` (see below); use `'pow'` for power spectrum (:obj:`numpy.float32`), `'fourier'` for complex Fourier coefficients (:obj:`numpy.complex128`) or `'abs'` for absolute values (:obj:`numpy.float32`). keeptrials : bool If `True` spectral estimates of individual trials are returned, otherwise results are averaged across trials. foi : array-like or None Frequencies of interest (Hz) for output. If desired frequencies cannot be matched exactly, the closest possible frequencies are used. If `foi` is `None` or ``foi = "all"``, all attainable frequencies (i.e., zero to Nyquist / 2) are selected. foilim : array-like (floats [fmin, fmax]) or None or "all" Frequency-window ``[fmin, fmax]`` (in Hz) of interest. Window specifications must be sorted (e.g., ``[90, 70]`` is invalid) and not NaN but may be unbounded (e.g., ``[-np.inf, 60.5]`` is valid). Edges `fmin` and `fmax` are included in the selection. If `foilim` is `None` or ``foilim = "all"``, all frequencies are selected. pad : str or None or bool One of `None`, `True`, `False`, `'absolute'`, `'relative'`, `'maxlen'` or `'nextpow2'`. If `pad` is `None` or ``pad = True``, then method-specific defaults are chosen. Specifically, if `method` is `'mtmfft'` then `pad` is set to `'nextpow2'` so that all trials in `data` are padded to the next power of two higher than the sample-count of the longest (selected) trial in `data`. 
Conversely, time-frequency analysis methods (`'mtmconvol'` and `'wavelet'`), only perform padding if necessary, i.e., if time-window centroids are chosen too close to trial boundaries for the entire window to cover available data-points. If `pad` is `False`, then no padding is performed. Then in case of ``method = 'mtmfft'`` all trials have to have approximately the same length (up to the next even sample-count), if ``method = 'mtmconvol'`` or ``method = 'wavelet'``, window-centroids have to keep sufficient distance from trial boundaries. For more details on the padding methods `'absolute'`, `'relative'`, `'maxlen'` and `'nextpow2'` see :func:`syncopy.padding`. padtype : str Values to be used for padding. Can be `'zero'`, `'nan'`, `'mean'`, `'localmean'`, `'edge'` or `'mirror'`. See :func:`syncopy.padding` for more information. padlength : None, bool or positive int Only valid if `method` is `'mtmfft'` and `pad` is `'absolute'` or `'relative'`. Number of samples to pad data with. See :func:`syncopy.padding` for more information. prepadlength : None or bool or int Only valid if `method` is `'mtmfft'` and `pad` is `'relative'`. Number of samples to pre-pend to each trial. See :func:`syncopy.padding` for more information. postpadlength : None or bool or int Only valid if `method` is `'mtmfft'` and `pad` is `'relative'`. Number of samples to append to each trial. See :func:`syncopy.padding` for more information. polyremoval : int or None **FIXME: Not implemented yet** Order of polynomial used for de-trending data in the time domain prior to spectral analysis. A value of 0 corresponds to subtracting the mean ("de-meaning"), ``polyremoval = 1`` removes linear trends (subtracting the least squares fit of a linear polynomial), ``polyremoval = N`` for `N > 1` subtracts a polynomial of order `N` (``N = 2`` quadratic, ``N = 3`` cubic etc.). If `polyremoval` is `None`, no de-trending is performed. taper : str Only valid if `method` is `'mtmfft'` or `'mtmconvol'`. 
Windowing function, one of :data:`~.availableTapers` (see below). tapsmofrq : float Only valid if `method` is `'mtmfft'` or `'mtmconvol'`. The amount of spectral smoothing through multi-tapering (Hz). Note that smoothing frequency specifications are one-sided, i.e., 4 Hz smoothing means plus-minus 4 Hz, i.e., a 8 Hz smoothing box. keeptapers : bool Only valid if `method` is `'mtmfft'` or `'mtmconvol'`. If `True`, return spectral estimates for each taper, otherwise results are averaged across tapers. toi : float or array-like or "all" **Mandatory input** for time-frequency analysis methods (`method` is either `"mtmconvol"` or `"wavelet"`). If `toi` is scalar, it must be a value between 0 and 1 indicating the percentage of overlap between time-windows specified by `t_ftimwin` (only valid if `method` is `'mtmconvol'`, invalid for `'wavelet'`). If `toi` is an array it explicitly selects the centroids of analysis windows (in seconds). If `toi` is `"all"`, analysis windows are centered on all samples in the data. t_ftimwin : positive float Only valid if `method` is `'mtmconvol'`. Sliding window length (in seconds). wav : str Only valid if `method` is `'wavelet'`. Wavelet function to use, one of :data:`~.availableWavelets` (see below). width : positive float Only valid if `method` is `'wavelet'` and `wav` is `'Morlet'`. Nondimensional frequency constant of Morlet wavelet function. This number should be >= 6, which corresponds to 6 cycles within the analysis window to ensure sufficient spectral sampling. order : positive int Only valid if `method` is `'wavelet'` and `wav` is `'Paul'` or `'DOG'`. Order of the wavelet function. If `wav` is `'Paul'`, `order` should be chosen >= 4 to ensure that the analysis window contains at least a single oscillation. At an order of 40, the Paul wavelet exhibits about the same number of cycles as the Morlet wavelet with a `width` of 6. All other supported wavelets functions are *real-valued* derivatives of Gaussians (DOGs). 
Hence, if `wav` is `'DOG'`, `order` represents the derivative order. The special case of a second order DOG yields a function known as "Mexican Hat", "Marr" or "Ricker" wavelet, which can be selected alternatively by setting `wav` to `'Mexican_hat'`, `'Marr'` or `'Ricker'`. **Note**: A real-valued wavelet function encodes *only* information about peaks and discontinuities in the signal and does *not* provide any information about amplitude or phase. out : None or :class:`SpectralData` object None if a new :class:`SpectralData` object is to be created, or an empty :class:`SpectralData` object Returns ------- spec : :class:`~syncopy.SpectralData` (Time-)frequency spectrum of input data Notes ----- Coming soon... Examples -------- Coming soon... .. autodata:: syncopy.specest.freqanalysis.availableMethods .. autodata:: syncopy.specest.freqanalysis.availableOutputs .. autodata:: syncopy.specest.freqanalysis.availableTapers .. autodata:: syncopy.specest.freqanalysis.availableWavelets See also -------- syncopy.specest.mtmfft.mtmfft : (multi-)tapered Fourier transform of multi-channel time series data syncopy.specest.mtmconvol.mtmconvol : time-frequency analysis of multi-channel time series data with a sliding window FFT syncopy.specest.wavelet.wavelet : time-frequency analysis of multi-channel time series data using a wavelet transform numpy.fft.fft : NumPy's reference FFT implementation scipy.signal.stft : SciPy's Short Time Fourier Transform """ # Make sure our one mandatory input object can be processed try: data_parser(data, varname="data", dataclass="AnalogData", writable=None, empty=False) except Exception as exc: raise exc timeAxis = data.dimord.index("time") # Get everything of interest in local namespace defaults = get_defaults(freqanalysis) lcls = locals() # Ensure a valid computational method was selected if method not in availableMethods: lgl = "'" + "or '".join(opt + "' " for opt in availableMethods) raise SPYValueError(legal=lgl, varname="method", 
actual=method) # Ensure a valid output format was selected if output not in spectralConversions.keys(): lgl = "'" + "or '".join(opt + "' " for opt in spectralConversions.keys()) raise SPYValueError(legal=lgl, varname="output", actual=output) # Parse all Boolean keyword arguments for vname in ["keeptrials", "keeptapers"]: if not isinstance(lcls[vname], bool): raise SPYTypeError(lcls[vname], varname=vname, expected="Bool") # If only a subset of `data` is to be processed, make some necessary adjustments # and compute minimal sample-count across (selected) trials if data._selection is not None: trialList = data._selection.trials sinfo = np.zeros((len(trialList), 2)) for tk, trlno in enumerate(trialList): trl = data._preview_trial(trlno) tsel = trl.idx[timeAxis] if isinstance(tsel, list): sinfo[tk, :] = [0, len(tsel)] else: sinfo[tk, :] = [ trl.idx[timeAxis].start, trl.idx[timeAxis].stop ] else: trialList = list(range(len(data.trials))) sinfo = data.sampleinfo lenTrials = np.diff(sinfo).squeeze() numTrials = len(trialList) # Set default padding options: after this, `pad` is either `None`, `False` or `str` defaultPadding = {"mtmfft": "nextpow2", "mtmconvol": None, "wavelet": None} if pad is None or pad is True: pad = defaultPadding[method] # Sliding window FFT does not support "fancy" padding if method == "mtmconvol" and isinstance(pad, str): msg = "method 'mtmconvol' only supports in-place padding for windows " +\ "exceeding trial boundaries. Your choice of `pad = '{}'` will be ignored. 
" SPYWarning(msg.format(pad)) pad = None # Ensure padding selection makes sense: do not pad on a by-trial basis but # use the longest trial as reference and compute `padlength` from there # (only relevant for "global" padding options such as `maxlen` or `nextpow2`) if pad: if not isinstance(pad, str): raise SPYTypeError(pad, varname="pad", expected="str or None") if pad == "maxlen": padlength = lenTrials.max() prepadlength = True postpadlength = False elif pad == "nextpow2": padlength = 0 for ltrl in lenTrials: padlength = max(padlength, _nextpow2(ltrl)) pad = "absolute" prepadlength = True postpadlength = False padding(data._preview_trial(trialList[0]), padtype, pad=pad, padlength=padlength, prepadlength=prepadlength, postpadlength=postpadlength) # Compute `minSampleNum` accounting for padding minSamplePos = lenTrials.argmin() minSampleNum = padding(data._preview_trial(trialList[minSamplePos]), padtype, pad=pad, padlength=padlength, prepadlength=True).shape[timeAxis] else: if method == "mtmfft" and np.unique( (np.floor(lenTrials / 2))).size > 1: lgl = "trials of approximately equal length for method 'mtmfft'" act = "trials of unequal length" raise SPYValueError(legal=lgl, varname="data", actual=act) minSampleNum = lenTrials.min() # Compute length (in samples) of shortest trial minTrialLength = minSampleNum / data.samplerate # Basic sanitization of frequency specifications if foi is not None: if isinstance(foi, str): if foi == "all": foi = None else: raise SPYValueError(legal="'all' or `None` or list/array", varname="foi", actual=foi) else: try: array_parser(foi, varname="foi", hasinf=False, hasnan=False, lims=[0, data.samplerate / 2], dims=(None, )) except Exception as exc: raise exc foi = np.array(foi, dtype="float") if foilim is not None: if isinstance(foilim, str): if foilim == "all": foilim = None else: raise SPYValueError(legal="'all' or `None` or `[fmin, fmax]`", varname="foilim", actual=foilim) else: try: array_parser(foilim, varname="foilim", hasinf=False, 
hasnan=False, lims=[0, data.samplerate / 2], dims=(2, )) except Exception as exc: raise exc if foi is not None and foilim is not None: lgl = "either `foi` or `foilim` specification" act = "both" raise SPYValueError(legal=lgl, varname="foi/foilim", actual=act) # FIXME: implement detrending # see also https://docs.obspy.org/_modules/obspy/signal/detrend.html#polynomial if polyremoval is not None: raise NotImplementedError("Detrending has not been implemented yet.") try: scalar_parser(polyremoval, varname="polyremoval", lims=[0, 8], ntype="int_like") except Exception as exc: raise exc # Prepare keyword dict for logging (use `lcls` to get actually provided # keyword values, not defaults set above) log_dct = { "method": method, "output": output, "keeptapers": keeptapers, "keeptrials": keeptrials, "polyremoval": polyremoval, "pad": lcls["pad"], "padtype": lcls["padtype"], "padlength": lcls["padlength"], "foi": lcls["foi"] } # 1st: Check time-frequency inputs to prepare/sanitize `toi` if method in ["mtmconvol", "wavelet"]: # Get start/end timing info respecting potential in-place selection if toi is None: raise SPYTypeError(toi, varname="toi", expected="scalar or array-like or 'all'") if data._selection is not None: tStart = data._selection.trialdefinition[:, 2] / data.samplerate else: tStart = data._t0 / data.samplerate tEnd = tStart + lenTrials / data.samplerate # Process `toi`: we have to account for three scenarios: (1) center sliding # windows on all samples in (selected) trials (2) `toi` was provided as # percentage indicating the degree of overlap b/w time-windows and (3) a set # of discrete time points was provided. 
These three cases are encoded in # `overlap, i.e., ``overlap > 1` => all, `0 < overlap < 1` => percentage, # `overlap < 0` => discrete `toi` if isinstance(toi, str): if toi != "all": lgl = "`toi = 'all'` to center analysis windows on all time-points" raise SPYValueError(legal=lgl, varname="toi", actual=toi) overlap = 1.1 toi = None equidistant = True elif isinstance(toi, Number): if method == "wavelet": lgl = "array of time-points wavelets are to be centered on" act = "scalar value" raise SPYValueError(legal=lgl, varname="toi", actual=act) try: scalar_parser(toi, varname="toi", lims=[0, 1]) except Exception as exc: raise exc overlap = toi equidistant = True else: overlap = -1 try: array_parser(toi, varname="toi", hasinf=False, hasnan=False, lims=[tStart.min(), tEnd.max()], dims=(None, )) except Exception as exc: raise exc toi = np.array(toi) tSteps = np.diff(toi) if (tSteps < 0).any(): lgl = "ordered list/array of time-points" act = "unsorted list/array" raise SPYValueError(legal=lgl, varname="toi", actual=act) # This is imho a bug in NumPy - even `arange` and `linspace` may produce # arrays that are numerically not exactly equidistant - `unique` will # show several entries here - use `allclose` to identify "even" spacings equidistant = np.allclose(tSteps, [tSteps[0]] * tSteps.size) # If `toi` was 'all' or a percentage, use entire time interval of (selected) # trials and check if those trials have *approximately* equal length if toi is None: if not np.allclose(lenTrials, [minSampleNum] * lenTrials.size): msg = "processing trials of different lengths (min = {}; max = {} samples)" +\ " with `toi = 'all'`" SPYWarning(msg.format(int(minSampleNum), int(lenTrials.max()))) if pad is False: lgl = "`pad` to be `None` or `True` to permit zero-padding " +\ "at trial boundaries to accommodate windows if `0 < toi < 1` " +\ "or if `toi` is 'all'" act = "False" raise SPYValueError(legal=lgl, actual=act, varname="pad") # Code recycling: `overlap`, `equidistant` etc. 
are really only relevant # for `mtmconvol`, but we use padding calc below for `wavelet` as well if method == "mtmconvol": try: scalar_parser(t_ftimwin, varname="t_ftimwin", lims=[1 / data.samplerate, minTrialLength]) except Exception as exc: raise exc else: t_ftimwin = 0 nperseg = int(t_ftimwin * data.samplerate) minSampleNum = nperseg halfWin = int(nperseg / 2) # `mtmconvol`: compute no. of samples overlapping across adjacent windows if overlap < 0: # `toi` is equidistant range or disjoint points noverlap = nperseg - max(1, int(tSteps[0] * data.samplerate)) elif 0 <= overlap <= 1: # `toi` is percentage noverlap = min(nperseg - 1, int(overlap * nperseg)) else: # `toi` is "all" noverlap = nperseg - 1 # `toi` is array if overlap < 0: # Compute necessary padding at begin/end of trials to fit sliding windows offStart = ((toi[0] - tStart) * data.samplerate).astype(np.intp) padBegin = halfWin - offStart padBegin = ((padBegin > 0) * padBegin).astype(np.intp) offEnd = ((tEnd - toi[-1]) * data.samplerate).astype(np.intp) padEnd = halfWin - offEnd padEnd = ((padEnd > 0) * padEnd).astype(np.intp) # Abort if padding was explicitly forbidden if pad is False and (np.any(padBegin) or np.any(padBegin)): lgl = "windows within trial bounds" act = "windows exceeding trials no. 
" +\ "".join(str(trlno) + ", "\ for trlno in np.array(trialList)[(padBegin + padEnd) > 0])[:-2] raise SPYValueError(legal=lgl, varname="pad", actual=act) # Compute sample-indices (one slice/list per trial) from time-selections soi = [] if not equidistant: for tk in range(numTrials): starts = (data.samplerate * (toi - tStart[tk]) - halfWin).astype(np.intp) starts += padBegin[tk] stops = (data.samplerate * (toi - tStart[tk]) + halfWin + 1).astype(np.intp) stops += padBegin[tk] stops = np.maximum(stops, stops - starts, dtype=np.intp) soi.append([ slice(start, stop) for start, stop in zip(starts, stops) ]) else: for tk in range(numTrials): start = int(data.samplerate * (toi[0] - tStart[tk]) - halfWin) stop = int(data.samplerate * (toi[-1] - tStart[tk]) + halfWin + 1) soi.append(slice(max(0, start), max(stop, stop - start))) # `toi` is percentage or "all" else: padBegin = np.zeros((numTrials, )) padEnd = np.zeros((numTrials, )) soi = [slice(None)] * numTrials # For wavelets, we need to first trim the data (via `preSelect`), then # extract the wanted time-points (`postSelect`) if method == "wavelet": # Simply recycle the indexing work done for `mtmconvol` (i.e., `soi`) preSelect = [] if not equidistant: for tk in range(numTrials): preSelect.append(slice(soi[tk][0].start, soi[tk][-1].stop)) else: preSelect = soi # If `toi` is an array, convert "global" indices to "local" ones # (select within `preSelect`'s selection), otherwise just take all if overlap < 0: postSelect = [] for tk in range(numTrials): smpIdx = np.minimum( lenTrials[tk] - 1, data.samplerate * (toi - tStart[tk]) - offStart[tk] + padBegin[tk]) postSelect.append(smpIdx.astype(np.intp)) else: postSelect = [slice(None)] * numTrials # Update `log_dct` w/method-specific options (use `lcls` to get actually # provided keyword values, not defaults set in here) if toi is None: toi = "all" log_dct["toi"] = lcls["toi"] # Check options specific to mtm*-methods (particularly tapers and foi/freqs alignment) if "mtm" in 
method: # See if taper choice is supported if taper not in availableTapers: lgl = "'" + "or '".join(opt + "' " for opt in availableTapers) raise SPYValueError(legal=lgl, varname="taper", actual=taper) taper = getattr(spwin, taper) # Advanced usage: see if `taperopt` was provided - if not, leave it empty taperopt = kwargs.get("taperopt", {}) if not isinstance(taperopt, dict): raise SPYTypeError(taperopt, varname="taperopt", expected="dictionary") # Construct array of maximally attainable frequencies nFreq = int(np.floor(minSampleNum / 2) + 1) freqs = np.linspace(0, data.samplerate / 2, nFreq) # Match desired frequencies as close as possible to actually attainable freqs if foi is not None: foi, _ = best_match(freqs, foi, squash_duplicates=True) elif foilim is not None: foi, _ = best_match(freqs, foilim, span=True, squash_duplicates=True) else: foi = freqs # Abort if desired frequency selection is empty if foi.size == 0: lgl = "non-empty frequency specification" act = "empty frequency selection" raise SPYValueError(legal=lgl, varname="foi/foilim", actual=act) # Set/get `tapsmofrq` if we're working w/Slepian tapers if taper.__name__ == "dpss": # Try to derive "sane" settings by using 3/4 octave smoothing of highest `foi` # following Hipp et al. "Oscillatory Synchronization in Large-Scale # Cortical Networks Predicts Perception", Neuron, 2011 if tapsmofrq is None: foimax = foi.max() tapsmofrq = (foimax * 2**(3 / 4 / 2) - foimax * 2**(-3 / 4 / 2)) / 2 else: try: scalar_parser(tapsmofrq, varname="tapsmofrq", lims=[1, np.inf]) except Exception as exc: raise exc # Get/compute number of tapers to use (at least 1 and max. 
50) nTaper = taperopt.get("Kmax", 1) if not taperopt: nTaper = int( max( 2, min( 50, np.floor(tapsmofrq * minSampleNum * 1 / data.samplerate)))) taperopt = {"NW": tapsmofrq, "Kmax": nTaper} else: nTaper = 1 # Warn the user in case `tapsmofrq` has no effect if tapsmofrq is not None and taper.__name__ != "dpss": msg = "`tapsmofrq` is only used if `taper` is `dpss`!" SPYWarning(msg) # Update `log_dct` w/method-specific options (use `lcls` to get actually # provided keyword values, not defaults set in here) log_dct["taper"] = lcls["taper"] log_dct["tapsmofrq"] = lcls["tapsmofrq"] log_dct["nTaper"] = nTaper # Check for non-default values of options not supported by chosen method kwdict = {"wav": wav, "width": width} for name, kwarg in kwdict.items(): if kwarg is not lcls[name]: msg = "option `{}` has no effect in methods `mtmfft` and `mtmconvol`!" SPYWarning(msg.format(name)) # Now, prepare explicit compute-classes for chosen method if method == "mtmfft": # Check for non-default values of options not supported by chosen method kwdict = {"t_ftimwin": t_ftimwin, "toi": toi} for name, kwarg in kwdict.items(): if kwarg is not lcls[name]: msg = "option `{}` has no effect in method `mtmfft`!" 
SPYWarning(msg.format(name)) # Set up compute-class specestMethod = MultiTaperFFT(samplerate=data.samplerate, foi=foi, nTaper=nTaper, timeAxis=timeAxis, taper=taper, taperopt=taperopt, tapsmofrq=tapsmofrq, pad=pad, padtype=padtype, padlength=padlength, keeptapers=keeptapers, polyremoval=polyremoval, output_fmt=output) elif method == "mtmconvol": # Set up compute-class specestMethod = MultiTaperFFTConvol(soi, list(padBegin), list(padEnd), samplerate=data.samplerate, noverlap=noverlap, nperseg=nperseg, equidistant=equidistant, toi=toi, foi=foi, nTaper=nTaper, timeAxis=timeAxis, taper=taper, taperopt=taperopt, pad=pad, padtype=padtype, padlength=padlength, prepadlength=prepadlength, postpadlength=postpadlength, keeptapers=keeptapers, polyremoval=polyremoval, output_fmt=output) elif method == "wavelet": # Check for non-default values of `taper`, `tapsmofrq`, `keeptapers` and # `t_ftimwin` (set to 0 above) kwdict = { "taper": taper, "tapsmofrq": tapsmofrq, "keeptapers": keeptapers } for name, kwarg in kwdict.items(): if kwarg is not lcls[name]: msg = "option `{}` has no effect in method `wavelet`!" SPYWarning(msg.format(name)) if t_ftimwin != 0: msg = "option `t_ftimwin` has no effect in method `wavelet`!" SPYWarning(msg) # Check wavelet selection if wav not in availableWavelets: lgl = "'" + "or '".join(opt + "' " for opt in availableWavelets) raise SPYValueError(legal=lgl, varname="wav", actual=wav) if wav not in ["Morlet", "Paul"]: msg = "the chosen wavelet '{}' is real-valued and does not provide " +\ "any information about amplitude or phase of the data. This wavelet function " +\ "may be used to isolate peaks or discontinuities in the signal. 
" SPYWarning(msg.format(wav)) # Check for consistency of `width`, `order` and `wav` if wav == "Morlet": try: scalar_parser(width, varname="width", lims=[1, np.inf]) except Exception as exc: raise exc wfun = getattr(spywave, wav)(w0=width) else: if width != lcls["width"]: msg = "option `width` has no effect for wavelet '{}'" SPYWarning(msg.format(wav)) if wav == "Paul": try: scalar_parser(order, varname="order", lims=[4, np.inf], ntype="int_like") except Exception as exc: raise exc wfun = getattr(spywave, wav)(m=order) elif wav == "DOG": try: scalar_parser(order, varname="order", lims=[1, np.inf], ntype="int_like") except Exception as exc: raise exc wfun = getattr(spywave, wav)(m=order) else: if order is not None: msg = "option `order` has no effect for wavelet '{}'" SPYWarning(msg.format(wav)) wfun = getattr(spywave, wav)() # Process frequency selection (`toi` was taken care of above): `foilim` # selections are wrapped into `foi` thus the seemingly weird if construct # Note: SLURM workers don't like monkey-patching, so let's pretend # `get_optimal_wavelet_scales` is a class method by passing `wfun` as its # first argument if foi is None: scales = _get_optimal_wavelet_scales( wfun, int(minTrialLength * data.samplerate), 1 / data.samplerate) if foilim is not None: foi = np.arange(foilim[0], foilim[1] + 1) if foi is not None: foi[foi < 0.01] = 0.01 scales = wfun.scale_from_period(1 / foi) scales = scales[:: -1] # FIXME: this only makes sense if `foi` was sorted -> cf Issue #94 # Update `log_dct` w/method-specific options (use `lcls` to get actually # provided keyword values, not defaults set in here) log_dct["wav"] = lcls["wav"] log_dct["width"] = lcls["width"] log_dct["order"] = lcls["order"] # Set up compute-class specestMethod = WaveletTransform(preSelect, postSelect, list(padBegin), list(padEnd), samplerate=data.samplerate, toi=toi, scales=scales, timeAxis=timeAxis, wav=wfun, polyremoval=polyremoval, output_fmt=output) # If provided, make sure output object is 
appropriate if out is not None: try: data_parser(out, varname="out", writable=True, empty=True, dataclass="SpectralData", dimord=SpectralData().dimord) except Exception as exc: raise exc new_out = False else: out = SpectralData(dimord=SpectralData._defaultDimord) new_out = True # Perform actual computation specestMethod.initialize(data, chan_per_worker=kwargs.get("chan_per_worker"), keeptrials=keeptrials) specestMethod.compute(data, out, parallel=kwargs.get("parallel"), log_dict=log_dct) # Either return newly created output object or simply quit return out if new_out else None
def mtmconvol(trl_dat, soi, padbegin, padend,
              samplerate=None, noverlap=None, nperseg=None, equidistant=True,
              toi=None, foi=None, nTaper=1, timeAxis=0,
              taper=signal.windows.hann, taperopt={},
              keeptapers=True, polyremoval=None, output_fmt="pow",
              noCompute=False, chunkShape=None):
    """
    Perform time-frequency analysis on multi-channel time series data using
    a sliding window FFT

    Parameters
    ----------
    trl_dat : 2D :class:`numpy.ndarray`
        Uniformly sampled multi-channel time-series
    soi : list of slices or slice
        Samples of interest; either a single slice encoding begin- to end-samples
        to perform analysis on (if sliding window centroids are equidistant)
        or list of slices with each slice corresponding to coverage of a single
        analysis window (if spacing between windows is not constant)
    padbegin : int
        Number of samples to pre-pend to `trl_dat`
    padend : int
        Number of samples to append to `trl_dat`
    samplerate : float
        Samplerate of `trl_dat` in Hz
    noverlap : int
        Number of samples covered by two adjacent analysis windows
    nperseg : int
        Size of analysis windows (in samples)
    equidistant : bool
        If `True`, spacing of window-centroids is equidistant.
    toi : 1D :class:`numpy.ndarray` or float or str
        Either time-points to center windows on if `toi` is a :class:`numpy.ndarray`,
        or percentage of overlap between windows if `toi` is a scalar or `"all"`
        to center windows on all samples in `trl_dat`. Please refer to
        :func:`~syncopy.freqanalysis` for further details. **Note**: The value
        of `toi` has to agree with provided padding and window settings. See
        Notes for more information.
    foi : 1D :class:`numpy.ndarray`
        Frequencies of interest (Hz) for output. If desired frequencies
        cannot be matched exactly the closest possible frequencies (respecting
        data length and padding) are used.
    nTaper : int
        Number of tapers to use
    timeAxis : int
        Index of running time axis in `trl_dat` (0 or 1)
    taper : callable
        Taper function to use, one of :data:`~syncopy.specest.freqanalysis.availableTapers`
    taperopt : dict
        Additional keyword arguments passed to `taper` (see above). For further
        details, please refer to the
        `SciPy docs <https://docs.scipy.org/doc/scipy/reference/signal.windows.html>`_.
        The dictionary is only unpacked into the call to `taper`; it is not
        modified by this function.
    keeptapers : bool
        If `True`, results of Fourier transform are preserved for each taper,
        otherwise spectrum is averaged across tapers.
    polyremoval : int
        **FIXME: Not implemented yet**
        Order of polynomial used for de-trending. A value of 0 corresponds to
        subtracting the mean ("de-meaning"), ``polyremoval = 1`` removes linear
        trends (subtracting the least squares fit of a linear function),
        ``polyremoval = N`` for `N > 1` subtracts a polynomial of order `N` (``N = 2``
        quadratic, ``N = 3`` cubic etc.). If `polyremoval` is `None`, no de-trending
        is performed.
    output_fmt : str
        Output of spectral estimation; one of
        :data:`~syncopy.specest.freqanalysis.availableOutputs`
    noCompute : bool
        Preprocessing flag. If `True`, do not perform actual calculation but
        instead return expected shape and :class:`numpy.dtype` of output
        array.
    chunkShape : None or tuple
        If not `None`, represents shape of output object `spec` (respecting
        provided values of `nTaper`, `keeptapers` etc.). Not referenced in the
        body of this function.

    Returns
    -------
    spec : :class:`numpy.ndarray`
        Complex or real time-frequency representation of (padded) input data
        with axes (time, taper, frequency, channel). The taper axis has length
        one if `keeptapers` is `False`.
        If `noCompute` is `True`, the tuple ``(outShape, dtype)`` is returned
        instead and no transform is computed.

    Notes
    -----
    This method is intended to be used as
    :meth:`~syncopy.shared.computational_routine.ComputationalRoutine.computeFunction`
    inside a :class:`~syncopy.shared.computational_routine.ComputationalRoutine`.
    Thus, input parameters are presumed to be forwarded from a parent metafunction.
    Consequently, this function does **not** perform any error checking and operates
    under the assumption that all inputs have been externally validated and cross-checked.

    The computational heavy lifting in this code is performed by SciPy's Short Time
    Fourier Transform (STFT) implementation :func:`scipy.signal.stft`.

    See also
    --------
    syncopy.freqanalysis : parent metafunction
    MultiTaperFFTConvol : :class:`~syncopy.shared.computational_routine.ComputationalRoutine`
                          instance that calls this method as
                          :meth:`~syncopy.shared.computational_routine.ComputationalRoutine.computeFunction`
    scipy.signal.stft : SciPy's STFT implementation
    """

    # Re-arrange array if necessary and get dimensional information
    if timeAxis != 0:
        dat = trl_dat.T       # does not copy but creates view of `trl_dat`
    else:
        dat = trl_dat

    # Pad input array if necessary
    if padbegin > 0 or padend > 0:
        dat = padding(dat, "zero", pad="relative", padlength=None,
                      prepadlength=padbegin, postpadlength=padend)

    # Get shape of output for dry-run phase
    nChannels = dat.shape[1]
    if isinstance(toi, np.ndarray):     # `toi` is an array of time-points
        nTime = toi.size
        stftBdry = None
        stftPad = False
    else:                               # `toi` is either 'all' or a percentage
        nTime = np.ceil(dat.shape[0] / (nperseg - noverlap)).astype(np.intp)
        stftBdry = "zeros"
        stftPad = True
    nFreq = foi.size
    outShape = (nTime, max(1, nTaper * keeptapers), nFreq, nChannels)
    if noCompute:
        return outShape, spyfreq.spectralDTypes[output_fmt]

    # In case tapers aren't preserved allocate `spec` "too big" and average afterwards
    spec = np.full((nTime, nTaper, nFreq, nChannels), np.nan,
                   dtype=spyfreq.spectralDTypes[output_fmt])

    # Collect keyword args for `stft` in dictionary
    stftKw = {"fs": samplerate,
              "nperseg": nperseg,
              "noverlap": noverlap,
              "return_onesided": True,
              "boundary": stftBdry,
              "padded": stftPad,
              "axis": 0}

    # Conversion of raw Fourier coefficients to the requested output format
    convert = spyfreq.spectralConversions[output_fmt]

    # Call `stft` w/first taper to get freq/time indices: transpose resulting `pxx`
    # to have a time x freq x channel array
    win = np.atleast_2d(taper(nperseg, **taperopt))
    stftKw["window"] = win[0, :]
    if equidistant:
        freq, _, pxx = signal.stft(dat[soi, :], **stftKw)
        _, fIdx = best_match(freq, foi, squash_duplicates=True)
        spec[:, 0, ...] = convert(pxx.transpose(2, 0, 1))[:nTime, fIdx, :]
    else:
        # Each slice in `soi` covers one analysis window. Only the leading
        # (segment) axis is squeezed out: a bare `squeeze()` would also drop
        # the channel axis for single-channel data and break `[fIdx, :]`
        freq, _, pxx = signal.stft(dat[soi[0], :], **stftKw)
        _, fIdx = best_match(freq, foi, squash_duplicates=True)
        spec[0, 0, ...] = convert(pxx.transpose(2, 0, 1).squeeze(axis=0))[fIdx, :]
        for tk in range(1, len(soi)):
            pxx = signal.stft(dat[soi[tk], :], **stftKw)[2]
            spec[tk, 0, ...] = convert(pxx.transpose(2, 0, 1).squeeze(axis=0))[fIdx, :]

    # Compute FT using determined indices above for the remaining tapers (if any)
    for taperIdx in range(1, win.shape[0]):
        stftKw["window"] = win[taperIdx, :]
        if equidistant:
            pxx = signal.stft(dat[soi, :], **stftKw)[2]
            spec[:, taperIdx, ...] = convert(pxx.transpose(2, 0, 1))[:nTime, fIdx, :]
        else:
            for tk, sample in enumerate(soi):
                pxx = signal.stft(dat[sample, :], **stftKw)[2]
                spec[tk, taperIdx, ...] = \
                    convert(pxx.transpose(2, 0, 1).squeeze(axis=0))[fIdx, :]

    # Average across tapers if wanted
    if not keeptapers:
        return np.nanmean(spec, axis=1, keepdims=True)
    return spec
def wavelet(trl_dat, preselect, postselect, padbegin, padend,
            samplerate=None, toi=None, scales=None, timeAxis=0, wav=None,
            polyremoval=None, output_fmt="pow",
            noCompute=False, chunkShape=None):
    """
    Compute a wavelet-based time-frequency representation of multi-channel
    time series data

    Parameters
    ----------
    trl_dat : 2D :class:`numpy.ndarray`
        Uniformly sampled multi-channel time-series
    preselect : slice
        Begin- to end-samples the transform is computed on (the data is
        trimmed to this interval first). See Notes for details.
    postselect : list of slices or list of 1D NumPy arrays
        Time-points of interest inside the interval given by `preselect`;
        applied to the time axis of the transform result. See Notes for details.
    padbegin : int
        Number of samples to pre-pend to `trl_dat`
    padend : int
        Number of samples to append to `trl_dat`
    samplerate : float
        Samplerate of `trl_dat` in Hz
    toi : 1D :class:`numpy.ndarray` or str
        Either time-points to center wavelets on if `toi` is a
        :class:`numpy.ndarray`, or `"all"` to center wavelets on all samples
        in `trl_dat`. Please refer to :func:`~syncopy.freqanalysis` for further
        details. **Note**: The value of `toi` has to agree with provided
        padding values. See Notes for more information.
    scales : 1D :class:`numpy.ndarray`
        Set of scales to use in wavelet transform.
    timeAxis : int
        Index of running time axis in `trl_dat` (0 or 1)
    wav : callable
        Wavelet function to use, one of
        :data:`~syncopy.specest.freqanalysis.availableWavelets`
    polyremoval : int
        **FIXME: Not implemented yet**
        Order of polynomial used for de-trending. A value of 0 corresponds to
        subtracting the mean ("de-meaning"), ``polyremoval = 1`` removes linear
        trends (subtracting the least squares fit of a linear function),
        ``polyremoval = N`` for `N > 1` subtracts a polynomial of order `N` (``N = 2``
        quadratic, ``N = 3`` cubic etc.). If `polyremoval` is `None`, no de-trending
        is performed.
    output_fmt : str
        Output of spectral estimation; one of
        :data:`~syncopy.specest.freqanalysis.availableOutputs`
    noCompute : bool
        Preprocessing flag. If `True`, skip the transform and only report the
        expected shape and :class:`numpy.dtype` of the output array.
    chunkShape : None or tuple
        If not `None`, represents shape of output object `spec` (respecting
        provided values of `scales`, `preselect`, `postselect` etc.)

    Returns
    -------
    spec : :class:`numpy.ndarray`
        Complex or real time-frequency representation of (padded) input data.
        If `noCompute` is `True`, the tuple ``(outShape, dtype)`` is returned
        instead.

    Notes
    -----
    This method is intended to be used as
    :meth:`~syncopy.shared.computational_routine.ComputationalRoutine.computeFunction`
    inside a :class:`~syncopy.shared.computational_routine.ComputationalRoutine`.
    Thus, input parameters are presumed to be forwarded from a parent metafunction.
    Consequently, this function does **not** perform any error checking and operates
    under the assumption that all inputs have been externally validated and cross-checked.

    The processing order is: trim `trl_dat` to the interval of interest
    (`preselect`), run the wavelet transform on the trimmed data, then pick
    the wanted time-points from the result (`postselect`).

    See also
    --------
    syncopy.freqanalysis : parent metafunction
    WaveletTransform : :class:`~syncopy.shared.computational_routine.ComputationalRoutine`
                       instance that calls this method as
                       :meth:`~syncopy.shared.computational_routine.ComputationalRoutine.computeFunction`
    """

    # Work on a time x channel layout (transposing yields a view, not a copy)
    series = trl_dat if timeAxis == 0 else trl_dat.T

    # Zero-pad at trial boundaries if any padding was requested
    if padbegin > 0 or padend > 0:
        series = padding(series, "zero", pad="relative", padlength=None,
                         prepadlength=padbegin, postpadlength=padend)

    # Expected output shape (needed for the dry-run phase): an array `toi`
    # fixes the number of time-points, otherwise every (padded) sample is used
    channel_count = series.shape[1]
    time_count = toi.size if isinstance(toi, np.ndarray) else series.shape[0]
    scale_count = scales.size
    expected_shape = (time_count, 1, scale_count, channel_count)
    if noCompute:
        return expected_shape, spyfreq.spectralDTypes[output_fmt]

    # Transform the trimmed data, swap the first two result axes and pick
    # the requested time-points along the (new) leading axis
    transformed = cwt(series[preselect, :],
                      axis=0,
                      wavelet=wav,
                      widths=scales,
                      dt=1 / samplerate)
    picked = transformed.transpose(1, 0, 2)[postselect, :, :]

    # Insert a singleton axis at position 1 and convert to the wanted output format
    to_output = spyfreq.spectralConversions[output_fmt]
    return to_output(picked[:, np.newaxis, :, :])