Example #1
def cluster_cleanup(client=None):
    """
    Stop and close dangling parallel processing jobs
    
    Parameters
    ----------
    client : dask distributed computing client or None
        Either a concrete `dask client object <https://distributed.dask.org/en/latest/client.html>`_
        or `None`. If `None`, a global client is queried for and, if found,
        shut down (without confirmation!).
    
    Returns
    -------
    Nothing : None
    
    See also
    --------
    esi_cluster_setup : Launch SLURM jobs on the ESI compute cluster
    """
    
    # For later reference: dynamically fetch name of current function
    funcName = "Syncopy <{}>".format(inspect.currentframe().f_code.co_name)
    
    # Attempt to establish connection to dask client
    if client is None:
        try:
            client = get_client()
        except ValueError:
            msg = "No dangling clients or clusters found."
            SPYWarning(msg)
            return
    else:
        if not isinstance(client, Client):
            raise SPYTypeError(client, varname="client", expected="dask client object")
    
    # Prepare message for prompt
    if client.cluster.__class__.__name__ == "LocalCluster":
        userClust = "LocalCluster hosted on {}".format(client.scheduler_info()["address"])
    else:
        userName = getpass.getuser()
        outDir = client.cluster.job_header.partition("--output=")[-1]
        jobID = outDir.partition("{}_".format(userName))[-1].split(os.sep)[0]
        userClust = "cluster {0}_{1}".format(userName, jobID)
    nWorkers = len(client.cluster.workers)
    
    # If connection was successful, first close the client, then the cluster
    client.close()
    client.cluster.close()
    
    # Communicate what just happened and get outta here
    msg = "{fname:s} Successfully shut down {cname:s} containing {nj:d} workers"
    print(msg.format(fname=funcName,
                     nj=nWorkers,
                     cname=userClust))

    return
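
A minimal usage sketch for `cluster_cleanup` (hedged: assumes `dask.distributed` is installed; the `LocalCluster` setup below is illustrative and stands in for Syncopy's own `esi_cluster_setup` launcher):

from dask.distributed import Client, LocalCluster

cluster = LocalCluster(n_workers=2)
client = Client(cluster)

# Shut down an explicitly provided client (closes the client, then the cluster)...
cluster_cleanup(client=client)

# ...or let the routine discover a dangling global client; with nothing left
# running, this merely emits the "No dangling clients or clusters found" warning
cluster_cleanup()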
Example #2
    def parallel_client_detector(*args, **kwargs):

        # Extract `parallel` keyword: if `parallel` is `False`, nothing happens
        parallel = kwargs.get("parallel")

        # Detect if dask client is running and set `parallel` keyword accordingly
        results = []
        cleanup = False
        if parallel is None or parallel is True:
            if spy.__dask__:
                try:
                    dd.get_client()
                    parallel = True
                except ValueError:
                    if parallel is True:
                        objList = []
                        argList = list(args)
                        nTrials = 0
                        for arg in args:
                            if hasattr(arg, "trials"):
                                objList.append(arg)
                                nTrials = max(nTrials, len(arg.trials))
                                argList.remove(arg)
                        nObs = len(objList)
                        msg = "Syncopy <{fname:s}> Launching parallel computing client " +\
                            "to process {no:d} objects..."
                        print(msg.format(fname=func.__name__, no=nObs))
                        client = esi_cluster_setup(n_jobs=nTrials,
                                                   interactive=False)
                        cleanup = True
                        if len(objList) > 1:
                            for obj in objList:
                                results.append(func(obj, *argList, **kwargs))
                    else:
                        parallel = False
            else:
                wrng = \
                "dask does not seem to be installed on this system. " +\
                "Parallel processing capabilities cannot be used. "
                SPYWarning(wrng)
                parallel = False

        # Add/update `parallel` to/in keyword args
        kwargs["parallel"] = parallel

        if len(results) == 0:
            results = func(*args, **kwargs)
        if cleanup:
            cluster_cleanup(client=client)

        return results
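
The snippet above only shows the inner wrapper; a hedged sketch of the enclosing decorator it presumably lives in (the name `detect_parallel_client` and the use of `functools.wraps` are assumptions, not verbatim Syncopy source):

import functools

def detect_parallel_client(func):  # hypothetical decorator name
    @functools.wraps(func)
    def parallel_client_detector(*args, **kwargs):
        # ... body as shown above (closes over `func`) ...
        return func(*args, **kwargs)
    return parallel_client_detector

@detect_parallel_client
def some_frontend(data, parallel=None, **kwargs):  # hypothetical frontend
    ...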
Example #3
def validate_padding(pad_to_length, lenTrials):
    """
    Simplified padding
    """
    # supported padding options
    not_valid = False
    if not isinstance(pad_to_length, (Number, str, type(None))):
        not_valid = True
    elif isinstance(pad_to_length, str) and pad_to_length not in availablePaddingOpt:
        not_valid = True
    if isinstance(pad_to_length, bool):
        # bool is an int subclass, check for it separately...
        not_valid = True
    if not_valid:
        lgl = "`None`, 'nextpow2' or an integer-like number"
        actual = f"{pad_to_length}"
        raise SPYValueError(legal=lgl, varname="pad_to_length", actual=actual)

    # If trials have unequal lengths and no absolute padding length was specified,
    # fall back to a rough 'maxlen' padding; 'nextpow2' is overruled in this case
    if lenTrials.min() != lenTrials.max() and not isinstance(pad_to_length, Number):
        abs_pad = int(lenTrials.max())
        msg = f"Unequal trial lengths present, automatic padding to {abs_pad} samples"
        SPYWarning(msg)
        pad_to_length = None  # ensure 'nextpow2' cannot override the fallback below

    # zero padding of ALL trials the same way
    if isinstance(pad_to_length, Number):

        scalar_parser(pad_to_length,
                      varname='pad_to_length',
                      ntype='int_like',
                      lims=[lenTrials.max(), np.inf])
        abs_pad = pad_to_length

    # or pad to the next power of two for optimal FFT lengths
    # (only applicable to trials of equal length, see above)
    elif pad_to_length == 'nextpow2':
        abs_pad = _nextpow2(int(lenTrials.min()))
    # no padding requested, or the unequal-length fallback from above
    elif pad_to_length is None:
        abs_pad = int(lenTrials.max())

    # `abs_pad` is now the (soon to be padded) signal length in samples

    return abs_pad
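
A quick sketch of the expected call pattern (assumes `lenTrials` is a NumPy array of per-trial sample counts, as the function body implies, and that Syncopy's `_nextpow2`/`scalar_parser` helpers are importable):

import numpy as np

lenTrials = np.array([1000, 1000, 1000])
validate_padding(None, lenTrials)        # -> 1000, no padding needed
validate_padding('nextpow2', lenTrials)  # -> 1024, next power of two
validate_padding(1500, lenTrials)        # -> 1500, absolute padding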
Example #4
def _anyplot(*data, overlay=None, method=None, **kwargs):
    """
    Local management routine that invokes respective class methods based on 
    caller (`obj.singlepanelplot` or `obj.multipanelplot`)
    
    This is an auxiliary method that is intended purely for internal use. Please
    refer to the user-exposed methods :func:`~syncopy.singlepanelplot` and/or
    :func:`~syncopy.multipanelplot` to actually generate plots of Syncopy data objects. 
    """

    # The only error-checking done in here: ensure `overlay` is Boolean and assert
    # `data` contains only non-empty Syncopy objects
    if not isinstance(overlay, bool):
        raise SPYTypeError(overlay, varname="overlay", expected="bool")
    for obj in data:
        data_parser(obj, varname="data", empty=False)
        # FIXME: while plotting is still WIP
        if obj.__class__.__name__ not in ["AnalogData", "SpectralData"]:
            errmsg = "Plotting currently only supported for `AnalogData` and `SpectralData` objects"
            raise NotImplementedError(errmsg)

    # See if figure was provided
    start = 0
    nData = len(data)
    fig = kwargs.pop("fig", None)
    if not overlay and fig is not None and nData > 1:
        msg = "User-provided figures not supported for non-overlay visualization " +\
            "of {} datasets. Supplied figure will not be used. "
        SPYWarning(msg.format(nData))
        fig = None

    # If we're overlaying, preserve initial figure object to plot over iteratively
    if overlay:
        fig = getattr(data[0], method)(fig=fig, **kwargs)
        start = 1
    figList = []
    for n in range(start, nData):
        figList.append(getattr(data[n], method)(fig=fig, **kwargs))

    # Return single figure object (if `overlay` is `True`) or list of multiple figs
    if overlay:
        return fig
    return figList
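
For orientation, the user-exposed wrappers might delegate to `_anyplot` roughly like this (a hedged sketch inferred from the docstring, not verbatim Syncopy source):

def singlepanelplot(*data, overlay=False, **kwargs):
    return _anyplot(*data, overlay=overlay, method="singlepanelplot", **kwargs)

def multipanelplot(*data, overlay=False, **kwargs):
    return _anyplot(*data, overlay=overlay, method="multipanelplot", **kwargs)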
Example #5
def check_passed_kwargs(lcls, defaults, frontend_name):
    '''
    Catch additional kwargs passed to the frontends
    which have no effect
    '''

    # unpack **kwargs of frontend call which
    # might contain arbitrary kws passed from the user
    kw_dict = lcls.get("kwargs")

    # nothing to do..
    if not kw_dict:
        return

    relevant = list(kw_dict.keys())
    expected = list(defaults)

    for name in relevant:
        if name not in expected:
            msg = f"option `{name}` has no effect in `{frontend_name}`!"
            SPYWarning(msg, caller=__name__.split('.')[-1])
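
A sketch of the intended call site (`some_frontend` is hypothetical; `get_defaults` is the helper referenced in Example #6's docstring):

def some_frontend(data, foi=None, **kwargs):  # hypothetical frontend
    defaults = get_defaults(some_frontend)    # parameter name -> default value
    # `locals()` contains the captured `kwargs` dict inspected above
    check_passed_kwargs(locals(), defaults, frontend_name="some_frontend")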
Example #6
def check_effective_parameters(CR, defaults, lcls):
    """
    For a given ComputationalRoutine, compare set parameters
    (*lcls*) with the accepted parameters and the frontend
    meta function *defaults* to warn if any ineffective parameters are set.

    Parameters
    ----------
    CR : :class:`~syncopy.shared.computational_routine.ComputationalRoutine`
        Needs to have a `valid_kws` attribute
    defaults : dict
        Result of :func:`~syncopy.shared.tools.get_defaults`: the frontend
        parameter names mapped to their default values
    lcls : dict
        Result of `locals()`, all names and values of the local (frontend) namespace
    """
    # list of possible parameter names of the CR
    expected = CR.valid_kws + ["parallel", "select"]
    relevant = [name for name in defaults if name not in generalParameters]
    for name in relevant:
        if name not in expected and (lcls[name] != defaults[name]):
            msg = f"option `{name}` has no effect in method `{CR.__name__}`!"
            SPYWarning(msg, caller=__name__.split('.')[-1])
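
Complementing Example #5, a hedged sketch of the call site (`SomeRoutine` stands in for any ComputationalRoutine exposing a `valid_kws` attribute):

def some_frontend(data, foi=None, polyremoval=None, **kwargs):  # hypothetical frontend
    defaults = get_defaults(some_frontend)
    # warn about parameters that were set but are ignored by the chosen routine
    check_effective_parameters(SomeRoutine, defaults, locals())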
Example #7
def multipanelplot(self,
                   trials="all",
                   channels="all",
                   toilim=None,
                   avg_channels=False,
                   avg_trials=True,
                   title=None,
                   grid=None,
                   fig=None,
                   **kwargs):
    """
    Plot contents of :class:`~syncopy.AnalogData` objects using multi-panel figure(s)
    
    Please refer to :func:`syncopy.multipanelplot` for detailed usage information. 
    
    Examples
    --------
    Use :func:`~syncopy.tests.misc.generate_artificial_data` to create two synthetic 
    :class:`~syncopy.AnalogData` objects. 
    
    >>> from syncopy.tests.misc import generate_artificial_data
    >>> adata = generate_artificial_data(nTrials=10, nChannels=32) 
    >>> bdata = generate_artificial_data(nTrials=5, nChannels=16) 
    
    Show overview of first 5 channels, averaged across trials 2, 4, and 6:
    
    >>> fig = spy.multipanelplot(adata, channels=range(5), trials=[2, 4, 6])
   
    Overlay last 5 channels, averaged across trials 1, 3, 5: 

    >>> fig = spy.multipanelplot(adata, channels=range(27, 32), trials=[1, 3, 5], fig=fig)
    
    Do not average trials:
    
    >>> fig = spy.multipanelplot(adata, channels=range(27, 32), trials=[1, 3, 5], avg_trials=False)
    
    Plot `adata` and `bdata` simultaneously in two separate figures:
    
    >>> fig1, fig2 = spy.multipanelplot(adata, bdata, channels=range(5), overlay=False)
    
    Overlay `adata` and `bdata`; use channel and trial selections that are valid 
    for both datasets:
    
    >>> fig3 = spy.multipanelplot(adata, bdata, channels=range(5), trials=[1, 2, 3], avg_trials=False)
    
    See also
    --------
    syncopy.multipanelplot : visualize Syncopy data objects using multi-panel plots
    """

    # Collect input arguments in dict `inputArgs` and process them
    inputArgs = locals()
    inputArgs.pop("self")
    dimArrs, dimCounts, idx, timeIdx, chanIdx = _prep_analog_plots(
        self, "singlepanelplot", **inputArgs)
    (nTrials, nChan) = dimCounts
    (trList, chArr) = dimArrs

    # Get trial/channel count ("raw" plotting constitutes a special case)
    if trials is None:
        nTrials = 0
        if avg_trials:
            msg = "`trials` is `None` but `avg_trials` is `True`. " +\
                "Cannot perform trial averaging without trial specification - " +\
                "setting ``avg_trials = False``. "
            SPYWarning(msg)
            avg_trials = False
        if avg_channels:
            msg = "Averaging across channels w/o trial specifications results in " +\
                "single-panel plot. Please use `singlepanelplot` instead"
            SPYWarning(msg)
            return

    # If we're overlaying, ensure settings match up
    if hasattr(fig, "singlepanelplot"):
        lgl = "overlay-figure generated by `multipanelplot`"
        act = "figure generated by `singlepanelplot`"
        raise SPYValueError(legal=lgl,
                            varname="fig/singlepanelplot",
                            actual=act)
    if hasattr(fig, "nTrialPanels"):
        if nTrials != fig.nTrialPanels:
            lgl = "number of trials to plot matching existing panels in figure"
            act = "{} panels but {} trials for plotting".format(
                fig.nTrialPanels, nTrials)
            raise SPYValueError(legal=lgl,
                                varname="trials/figure panels",
                                actual=act)
        if avg_trials:
            lgl = "overlay of multi-trial plot"
            act = "trial averaging was requested for multi-trial plot overlay"
            raise SPYValueError(legal=lgl,
                                varname="trials/avg_trials",
                                actual=act)
        if trials is None:
            lgl = "`trials` to be not `None` to append to multi-trial plot"
            act = "multi-trial plot overlay was requested but `trials` is `None`"
            raise SPYValueError(legal=lgl,
                                varname="trials/overlay",
                                actual=act)
        if not avg_channels and not hasattr(fig, "chanOffsets"):
            lgl = "single-channel or channel-averages for appending to multi-trial plot"
            act = "multi-trial multi-channel plot overlay was requested"
            raise SPYValueError(legal=lgl,
                                varname="avg_channels/overlay",
                                actual=act)
    if hasattr(fig, "nChanPanels"):
        if nChan != fig.nChanPanels:
            lgl = "number of channels to plot matching existing panels in figure"
            act = "{} panels but {} channels for plotting".format(
                fig.nChanPanels, nChan)
            raise SPYValueError(legal=lgl,
                                varname="channels/figure panels",
                                actual=act)
        if avg_channels:
            lgl = "overlay of multi-channel plot"
            act = "channel averaging was requested for multi-channel plot overlay"
            raise SPYValueError(legal=lgl,
                                varname="channels/avg_channels",
                                actual=act)
        if not avg_trials:
            lgl = "overlay of multi-channel plot"
            act = "mulit-trial plot was requested for multi-channel plot overlay"
            raise SPYValueError(legal=lgl,
                                varname="channels/avg_trials",
                                actual=act)
    if hasattr(fig, "chanOffsets"):
        if avg_channels:
            lgl = "multi-channel plot"
            act = "channel averaging was requested for multi-channel plot overlay"
            raise SPYValueError(legal=lgl,
                                varname="channels/avg_channels",
                                actual=act)
        if nChan != len(fig.chanOffsets):
            lgl = "channel-count matching existing multi-channel panels in figure"
            act = "{} channels per panel but {} channels for plotting".format(
                len(fig.chanOffsets), nChan)
            raise SPYValueError(legal=lgl,
                                varname="channels/channels per panel",
                                actual=act)

    # Generic title for overlay figures
    overlayTitle = "Overlay of {} datasets"

    # Either construct subplot panel layout/vet provided layout or fetch existing
    if fig is None:

        # Determine no. of required panels
        if avg_trials and not avg_channels:
            npanels = nChan
        elif not avg_trials and avg_channels:
            npanels = nTrials
        elif not avg_trials and not avg_channels:
            npanels = int(nTrials == 0) * nChan + nTrials
        else:
            msg = "Averaging across both trials and channels results in " +\
                "single-panel plot. Please use `singlepanelplot` instead"
            SPYWarning(msg)
            return

        # Although `_setup_figure` could call `_layout_subplot_panels` for us,
        # we need `nrow` and `ncol` below, so do it here
        if nTrials > 0:
            xLabel = "Time [s]"
        else:
            xLabel = "Samples"
        nrow = kwargs.get("nrow", None)
        ncol = kwargs.get("ncol", None)
        nrow, ncol = _layout_subplot_panels(npanels, nrow=nrow, ncol=ncol)
        fig, ax_arr = _setup_figure(npanels,
                                    nrow=nrow,
                                    ncol=ncol,
                                    xLabel=xLabel,
                                    grid=grid)
        fig.analogPlot = True

    # Get existing layout
    else:
        ax_arr = fig.get_axes()
        nrow, ncol = ax_arr[0].numRows, ax_arr[0].numCols

    # Panels correspond to channels
    if avg_trials and not avg_channels:

        # Ensure provided timing selection can actually be averaged (leverage
        # the fact that `toilim` selections exclusively generate slices)
        tLengths = _prep_toilim_avg(self)

        # Compute trial-averaged time-courses: 2D array with slice/list
        # selection does not require fancy indexing - no need to check this here
        pltArr = np.zeros((tLengths[0], nChan), dtype=self.data.dtype)
        for k, trlno in enumerate(trList):
            idx[timeIdx] = self._selection.time[k]
            pltArr += np.swapaxes(
                self._get_trial(trlno)[tuple(idx)], timeIdx, 0)
        pltArr /= nTrials

        # Cycle through channels and plot trial-averaged time-courses (time-
        # axis must be identical for all channels, set up `idx` just once)
        idx[timeIdx] = self._selection.time[0]
        time = self.time[trList[0]][self._selection.time[0]]
        for k, chan in enumerate(chArr):
            ax_arr[k].plot(time,
                           pltArr[:, k],
                           label=os.path.basename(self.filename))

        # If we're overlaying datasets, adjust panel- and sup-titles: include
        # legend in top-right axis (note: `ax_arr` is row-major flattened)
        if fig.objCount == 0:
            for k, chan in enumerate(chArr):
                ax_arr[k].set_title(chan, size=pltConfig["multiTitleSize"])
            fig.nChanPanels = nChan
            if title is None:
                if nTrials > 1:
                    title = "Average of {} trials".format(nTrials)
                else:
                    title = "Trial #{}".format(trList[0])
            fig.suptitle(title, size=pltConfig["singleTitleSize"])
        else:
            for k, chan in enumerate(chArr):
                ax_arr[k].set_title("{0}/{1}".format(ax_arr[k].get_title(),
                                                     chan))
            ax = ax_arr[ncol - 1]
            handles, labels = ax.get_legend_handles_labels()
            ax.legend(handles, labels)
            if title is None:
                title = overlayTitle.format(len(handles))
            fig.suptitle(title, size=pltConfig["singleTitleSize"])

    # Panels correspond to trials
    elif not avg_trials and avg_channels:

        # Cycle through panels to plot by-trial channel-averages
        for k, trlno in enumerate(trList):
            idx[timeIdx] = self._selection.time[k]
            time = self.time[trList[k]][self._selection.time[k]]
            ax_arr[k].plot(time,
                           self._get_trial(trlno)[tuple(idx)].mean(
                               axis=chanIdx).squeeze(),
                           label=os.path.basename(self.filename))

        # If we're overlaying datasets, adjust panel- and sup-titles: include
        # legend in top-right axis (note: `ax_arr` is row-major flattened)
        if fig.objCount == 0:
            for k, trlno in enumerate(trList):
                ax_arr[k].set_title("Trial #{}".format(trlno),
                                    size=pltConfig["multiTitleSize"])
            fig.nTrialPanels = nTrials
            if title is None:
                if nChan > 1:
                    title = "Average of {} channels".format(nChan)
                else:
                    title = chArr[0]
            fig.suptitle(title, size=pltConfig["singleTitleSize"])
        else:
            for k, trlno in enumerate(trList):
                ax_arr[k].set_title("{0}/#{1}".format(ax_arr[k].get_title(),
                                                      trlno))
            ax = ax_arr[ncol - 1]
            handles, labels = ax.get_legend_handles_labels()
            ax.legend(handles, labels)
            if title is None:
                title = overlayTitle.format(len(handles))
            fig.suptitle(title, size=pltConfig["singleTitleSize"])

    # Panels correspond to channels (if `trials` is `None`) otherwise trials
    elif not avg_trials and not avg_channels:

        # Plot each channel in separate panel
        if nTrials == 0:
            chanSec = np.arange(self.channel.size)[self._selection.channel]
            for k, chan in enumerate(chanSec):
                idx[chanIdx] = chan
                ax_arr[k].plot(self.data[tuple(idx)].squeeze(),
                               label=os.path.basename(self.filename))

            # If we're overlaying datasets, adjust panel- and sup-titles: include
            # legend in top-right axis (note: `ax_arr` is row-major flattened)
            if fig.objCount == 0:
                for k, chan in enumerate(chArr):
                    ax_arr[k].set_title(chan, size=pltConfig["multiTitleSize"])
                fig.nChanPanels = nChan
                if title is None:
                    title = "Entire Data Timecourse"
                fig.suptitle(title, size=pltConfig["singleTitleSize"])
            else:
                for k, chan in enumerate(chArr):
                    ax_arr[k].set_title("{0}/{1}".format(
                        ax_arr[k].get_title(), chan))
                ax = ax_arr[ncol - 1]
                handles, labels = ax.get_legend_handles_labels()
                ax.legend(handles, labels)
                if title is None:
                    title = overlayTitle.format(len(handles))
                fig.suptitle(title, size=pltConfig["singleTitleSize"])

        # Each trial gets its own panel w/multiple channels per panel
        else:

            # If required, compute max amplitude across provided trials + channels
            if not hasattr(fig, "chanOffsets"):
                maxAmps = np.zeros((nTrials, ), dtype=self.data.dtype)
                tickOffsets = maxAmps.copy()
                for k, trlno in enumerate(trList):
                    idx[timeIdx] = self._selection.time[k]
                    pltArr = np.abs(self._get_trial(trlno)[tuple(idx)])
                    maxAmps[k] = pltArr.max()
                    tickOffsets[k] = pltArr.mean()
                fig.chanOffsets = np.cumsum([0] + [maxAmps.max()] *
                                            (nChan - 1))
                fig.tickOffsets = fig.chanOffsets + tickOffsets.mean()

            # Cycle through panels to plot by-trial multi-channel time-courses
            for k, trlno in enumerate(trList):
                idx[timeIdx] = self._selection.time[k]
                time = self.time[trList[k]][self._selection.time[k]]
                pltArr = np.swapaxes(
                    self._get_trial(trlno)[tuple(idx)], timeIdx, 0)
                ax_arr[k].plot(
                    time, (pltArr + fig.chanOffsets.reshape(1, nChan)).reshape(
                        time.size, nChan),
                    color=plt.rcParams["axes.prop_cycle"].by_key()["color"][
                        fig.objCount],
                    label=os.path.basename(self.filename))

            # If we're overlaying datasets, adjust panel- and sup-titles: include
            # legend in top-right axis (note: `ax_arr` is row-major flattened)
            # Note: y-axis is shared across panels, so `yticks` need only be set once
            if fig.objCount == 0:
                for k, trlno in enumerate(trList):
                    ax_arr[k].set_title("Trial #{}".format(trlno),
                                        size=pltConfig["multiTitleSize"])
                ax_arr[0].set_yticks(fig.tickOffsets)
                ax_arr[0].set_yticklabels(chArr)
                fig.nTrialPanels = nTrials
                if title is None:
                    if nChan > 1:
                        title = "{} channels".format(nChan)
                    else:
                        title = chArr[0]
                fig.suptitle(title, size=pltConfig["singleTitleSize"])
            else:
                for k, trlno in enumerate(trList):
                    ax_arr[k].set_title("{0}/#{1}".format(
                        ax_arr[k].get_title(), trlno))
                ax_arr[0].set_yticklabels([" "] * chArr.size)
                ax = ax_arr[ncol - 1]
                handles, labels = ax.get_legend_handles_labels()
                ax.legend(handles[::(nChan + 1)], labels[::(nChan + 1)])
                if title is None:
                    title = overlayTitle.format(len(handles))
                fig.suptitle(title, size=pltConfig["singleTitleSize"])

    # Increment overlay-counter, draw figure and wipe data-selection slot
    fig.objCount += 1
    plt.draw()
    self._selection = None
    return fig
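
Continuing the docstring's `adata`/`bdata` example: the overlay bookkeeping above (`fig.objCount`, `fig.nChanPanels`, `fig.chanOffsets`) lets consecutive calls draw into the same panels, provided panel counts match (a hedged sketch):

fig = adata.multipanelplot(channels=range(5), trials=[0, 1])           # fig.objCount -> 1
fig = bdata.multipanelplot(channels=range(5), trials=[0, 1], fig=fig)  # overlay, objCount -> 2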
Example #8
    def compute(self, data, out, parallel=False, parallel_store=None,
                method=None, mem_thresh=0.5, log_dict=None, parallel_debug=False):
        """
        Central management and processing method

        Parameters
        ----------
        data : syncopy data object
           Syncopy data object to be processed (has to be the same object
           that was used by :meth:`initialize` in the pre-calculation 
           dry-run). 
        out : syncopy data object
           Empty object for holding results
        parallel : bool
           If `True`, processing is performed in parallel (i.e., 
           :meth:`computeFunction` is executed concurrently across trials). 
           If `parallel` is `False`, :meth:`computeFunction` is executed 
           consecutively trial after trial (i.e., the calculation realized
           in :meth:`computeFunction` is performed sequentially). 
        parallel_store : None or bool
           Flag controlling saving mechanism. If `None`,
           ``parallel_store = parallel``, i.e., the compute-paradigm 
           dictates the employed writing method. Thus, in case of parallel
           processing, results are written in a fully concurrent
           manner (each worker saves its own local result segment on disk as 
           soon as it is done with its part of the computation). If `parallel_store`
           is `False` and `parallel` is `True` the processing result is saved 
           sequentially using a mutex. If both `parallel` and `parallel_store`
           are `False` standard single-process HDF5 writing is employed for
           saving the result of the (sequential) computation. 
        method : None or str
           If `None` the predefined methods :meth:`compute_parallel` or
           :meth:`compute_sequential` are used to control the actual computation
           (specifically, calling :meth:`computeFunction`) depending on whether
           `parallel` is `True` or `False`, respectively. If `method` is a 
           string, it has to specify the name of an alternative (provided) 
           class method that is invoked using `getattr`.
        mem_thresh : float
           Fraction of available memory required to perform computation. By
           default, the largest single trial result must not occupy more than
           50% (``mem_thresh = 0.5``) of available single-machine or worker
           memory (if `parallel` is `False` or `True`, respectively).
        log_dict : None or dict
           If `None`, the `log` property of `out` is populated with the
           keyword arguments used in :meth:`computeFunction`.
           Otherwise, `out`'s `log` property is filled with items taken
           from `log_dict`.
        parallel_debug : bool
           If `True`, concurrent processing is performed using a single-threaded
           scheduler, i.e., all parallel computing tasks are run in the current
           Python thread permitting usage of tools like `pdb`/`ipdb`, `cProfile` 
           and the like in :meth:`computeFunction`.  
           Note that enabling parallel debugging effectively runs the given computation 
           on the calling local machine thereby requiring sufficient memory and 
           CPU capacity. 
        
        Returns
        -------
        Nothing : None
           The result of the computation is available in `out` once 
           :meth:`compute` terminated successfully. 

        Notes
        -----
        This routine calls several other class methods to perform all necessary
        pre- and post-processing steps in a fully automatic manner without
        requiring any user-input. Specifically, the following class methods
        are invoked consecutively (in the given order):
        
        1. :meth:`preallocate_output` allocates a (virtual) HDF5 dataset 
           of appropriate dimension for storing the result
        2. :meth:`compute_parallel` (or :meth:`compute_sequential`) performs
           the actual computation via concurrently (or sequentially) calling
           :meth:`computeFunction`
        3. :meth:`process_metadata` attaches all relevant meta-information to
           the result `out` after successful termination of the calculation
        4. :meth:`write_log` stores employed input arguments in `out.cfg`
           and `out.log` to reproduce all relevant computational steps that 
           generated `out`. 
        
        See also
        --------
        initialize : pre-calculation preparations
        preallocate_output : storage provisioning
        compute_parallel : concurrent computation using :meth:`computeFunction`
        compute_sequential : sequential computation using :meth:`computeFunction`
        process_metadata : management of meta-information
        write_log : log-entry organization
        """

        # By default, use VDS storage for parallel computing
        if parallel_store is None:
            parallel_store = parallel
            
        # Do not spill trials on disk if they're supposed to be removed anyway
        if parallel_store and not self.keeptrials:
            msg = "trial-averaging only supports sequential writing!"
            SPYWarning(msg)
            parallel_store = False

        # Concurrent processing requires some additional prep-work...
        if parallel:

            # First and foremost, make sure a dask client is accessible
            try:
                client = dd.get_client()
            except ValueError as exc:
                msg = "parallel computing client: {}"
                raise SPYIOError(msg.format(exc.args[0]))
            
            # Check if the underlying cluster hosts actually usable workers
            if not len(client.cluster.workers):
                raise SPYParallelError("No active workers found in distributed computing cluster",
                                       client=client)

            # Note: `dask_jobqueue` may not be available even if `__dask__` is `True`,
            # hence the `__name__` shenanigans instead of a simple `isinstance`
            if isinstance(client.cluster, dd.LocalCluster):
                memAttr = "memory_limit"
            elif client.cluster.__class__.__name__ == "SLURMCluster":
                memAttr = "worker_memory"
            else:
                msg = "`ComputationalRoutine` only supports `LocalCluster` and " +\
                    "`SLURMCluster` dask cluster objects. Proceed with caution. "
                SPYWarning(msg)
                memAttr = None

            # Check if trials actually fit into memory before we start computation
            if memAttr:
                wrk_size = max(getattr(wrkr, memAttr) for wrkr in client.cluster.workers.values())
                if self.chunkMem >= mem_thresh * wrk_size:
                    self.chunkMem /= 1024**3
                    wrk_size /= 1024**3  # same GB conversion as `chunkMem` above
                    msg = "Single-trial result sizes ({0:2.2f} GB) larger than available " +\
                        "worker memory ({1:2.2f} GB) currently not supported"
                    raise NotImplementedError(msg.format(self.chunkMem, wrk_size))

            # In some cases distributed dask workers suffer from spontaneous
            # dementia and forget the `sys.path` of their parent process. Fun!
            def init_syncopy(dask_worker):
                spy_path = os.path.abspath(os.path.split(__path__[0])[0])
                if spy_path not in sys.path:
                    sys.path.insert(0, spy_path)
            client.register_worker_callbacks(init_syncopy)
            
            # Store provided debugging state
            self.parallelDebug = parallel_debug

        # For sequential processing, just ensure enough memory is available
        else:
            mem_size = psutil.virtual_memory().available
            if self.chunkMem >= mem_thresh * mem_size:
                self.chunkMem /= 1024**3
                mem_size /= 1024**3
                msg = "Single-trial result sizes ({0:2.2f} GB) larger than available " +\
                      "memory ({1:2.2f} GB) currently not supported"
                raise NotImplementedError(msg.format(self.chunkMem, mem_size))

        # Create HDF5 dataset of appropriate dimension
        self.preallocate_output(out, parallel_store=parallel_store)
        
        # The `method` keyword can be used to override the `parallel` flag
        if method is None:
            if parallel:
                computeMethod = self.compute_parallel
            else:
                computeMethod = self.compute_sequential
        else:
            computeMethod = getattr(self, "compute_" + method, None)

        # Ensure `data` is opened read-only to permit (potentially concurrent)
        # reading access to backing device on disk
        data.mode = "r"
        
        # Take care of `VirtualData` objects 
        self.hdr = getattr(data, "hdr", None)
            
        # Perform actual computation
        computeMethod(data, out)

        # Reset data access mode
        data.mode = self.dataMode
        
        # Attach computed results to output object
        out.data = h5py.File(out.filename, mode="r+")[self.datasetName]

        # Store meta-data, write log and get outta here
        self.process_metadata(data, out)
        self.write_log(data, out, log_dict)
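
A hedged sketch of the documented lifecycle (`MyRoutine` is a hypothetical ComputationalRoutine subclass; `data` and `out` are Syncopy objects as described above):

cr = MyRoutine(*method_args, **method_kwargs)  # hypothetical CR subclass
cr.initialize(data, keeptrials=True)           # dry-run, see Example #9
cr.compute(data, out, parallel=False)          # sequential computation and storage
# `out` now holds the result; `out.cfg`/`out.log` document the computation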
Example #9
    def initialize(self, data, chan_per_worker=None, keeptrials=True):
        """
        Perform dry-run of calculation to determine output shape

        Parameters
        ----------
        data : syncopy data object
           Syncopy data object to be processed (has to be the same object 
           that is passed to :meth:`compute` for the actual calculation). 
        chan_per_worker : None or int
           Number of channels to be processed by each worker (only relevant in
           case of concurrent processing). If `chan_per_worker` is `None` (default) 
           by-trial parallelism is used, i.e., each worker processes 
           data corresponding to a full trial. If `chan_per_worker > 0`, trials 
           are split into channel-groups of size `chan_per_worker` (+ rest if the 
           number of channels is not divisible by `chan_per_worker` without 
           remainder) and workers are assigned by-trial channel-groups for 
           processing. 
        keeptrials : bool
            Flag indicating whether to return individual trials or average
        
        Returns
        -------
        Nothing : None
        
        Notes
        -----
        This class method **has** to be called prior to performing the actual
        computation realized in :meth:`computeFunction`. 

        See also
        --------
        compute : core routine performing the actual computation
        """
        
        # First store `keeptrials` keyword value (important for output shapes below)
        self.keeptrials = keeptrials
        
        # Determine if data-selection was provided; if so, extract trials and check
        # whether selection requires fancy array indexing
        if data._selection is not None:
            self.trialList = data._selection.trials
            self.useFancyIdx = data._selection._useFancy
        else:
            self.trialList = list(range(len(data.trials)))
            self.useFancyIdx = False
        numTrials = len(self.trialList)

        # If lists/tuples are in positional arguments, ensure `len == numTrials`
        # Scalars are duplicated to fit trials, e.g., ``self.argv = [3, [0, 1, 1]]``
        # then ``argv = [[3, 3, 3], [0, 1, 1]]``
        for ak, arg in enumerate(self.argv):

            # Ensure arguments are within reasonable size for distribution across workers
            # (protect against circular object references by imposing max. calls)
            self._callCount = 0
            argsize = self._sizeof(arg)
            if argsize > self._maxArgSize:
                lgl = "positional arguments less than 100 MB each"
                act = "positional argument with memory footprint of {0:4.2f} MB"
                raise SPYValueError(legal=lgl, varname="argv", actual=act.format(argsize))
            
            if isinstance(arg, (list, tuple)):
                if len(arg) != numTrials:
                    lgl = "list/tuple of positional arguments for each trial"
                    act = "length of list/tuple does not correspond to number of trials"
                    raise SPYValueError(legal=lgl, varname="argv", actual=act)
                continue
            elif isinstance(arg, np.ndarray):
                if arg.size == numTrials:
                    msg = "found NumPy array with size == #Trials. " +\
                        "Regardless, every worker will receive an identical copy " +\
                        "of this array. To propagate elements across workers, use " +\
                        "a list or tuple instead!"
                    SPYWarning(msg)
            self.argv[ak] = [arg] * numTrials
                
        # Prepare dryrun arguments and determine geometry of trials in output
        dryRunKwargs = copy(self.cfg)
        dryRunKwargs["noCompute"] = True
        chk_list = []
        dtp_list = []
        trials = []
        for tk, trialno in enumerate(self.trialList):
            trial = data._preview_trial(trialno)
            trlArg = tuple(arg[tk] for arg in self.argv)
            chunkShape, dtype = self.computeFunction(trial, 
                                                     *trlArg,
                                                     **dryRunKwargs)
            chk_list.append(list(chunkShape))
            dtp_list.append(dtype)
            trials.append(trial)
            
        # The aggregate shape is computed as max across all chunks                    
        chk_arr = np.array(chk_list)
        if np.unique(chk_arr[:, 0]).size > 1 and not self.keeptrials:
            err = "Averaging trials of unequal lengths in output currently not supported!"
            raise NotImplementedError(err)
        if np.any([dtp_list[0] != dtp for dtp in dtp_list]):
            lgl = "unique output dtype"
            act = "{} different output dtypes".format(np.unique(dtp_list).size)
            raise SPYValueError(legal=lgl, varname="dtype", actual=act)
        chunkShape = tuple(chk_arr.max(axis=0))
        self.outputShape = (chk_arr[:, 0].sum(),) + chunkShape[1:]
        self.cfg["chunkShape"] = chunkShape
        self.dtype = np.dtype(dtp_list[0])

        # Ensure channel parallelization can be done at all
        if chan_per_worker is not None and "channel" not in data.dimord:
            msg = "input object does not contain `channel` dimension for parallelization!"
            SPYWarning(msg)
            chan_per_worker = None
        if chan_per_worker is not None and self.keeptrials is False:
            msg = "trial-averaging does not support channel-block parallelization!"
            SPYWarning(msg)
            chan_per_worker = None
        if data._selection is not None:
            if chan_per_worker is not None and data._selection.channel != slice(None, None, 1):
                msg = "channel selection and simultaneous channel-block " +\
                    "parallelization not yet supported!"
                SPYWarning(msg)
                chan_per_worker = None
            
        # Allocate control variables
        trial = trials[0]
        trlArg0 = tuple(arg[0] for arg in self.argv)
        chunkShape0 = chk_arr[0, :]
        lyt = [slice(0, stop) for stop in chunkShape0]
        sourceLayout = []
        targetLayout = []
        targetShapes = []
        ArgV = []

        # If parallelization across channels is requested the first trial is 
        # split up into several chunks that need to be processed/allocated
        if chan_per_worker is not None:

            # Set up channel-chunking
            nChannels = data.channel.size
            rem = int(nChannels % chan_per_worker)
            n_blocks = [chan_per_worker] * int(nChannels//chan_per_worker) + [rem] * int(rem > 0)        
            inchanidx = data.dimord.index("channel")
            
            # Perform dry-run w/first channel-block of first trial to identify 
            # changes in output shape w.r.t. full-trial output (`chunkShape`)
            shp = list(trial.shape)
            idx = list(trial.idx)
            shp[inchanidx] = n_blocks[0]
            idx[inchanidx] = slice(0, n_blocks[0])
            trial.shape = tuple(shp)
            trial.idx = tuple(idx)
            res, _ = self.computeFunction(trial, *trlArg0, **dryRunKwargs)
            outchan = [dim for dim in res if dim not in chunkShape0]
            if len(outchan) != 1:
                lgl = "exactly one output dimension to scale w/channel count"
                act = "{0:d} dimensions affected by varying channel count".format(len(outchan))
                raise SPYValueError(legal=lgl, varname="chan_per_worker", actual=act)
            outchanidx = res.index(outchan[0])            
            
            # Get output chunks and grid indices for first trial
            chanstack = 0
            blockstack = 0
            for block in n_blocks:
                shp = list(trial.shape)
                idx = list(trial.idx)
                shp[inchanidx] = block
                idx[inchanidx] = slice(blockstack, blockstack + block)
                trial.shape = tuple(shp)
                trial.idx = tuple(idx)
                res, _ = self.computeFunction(trial, *trlArg0, **dryRunKwargs)
                lyt[outchanidx] = slice(chanstack, chanstack + res[outchanidx])
                targetLayout.append(tuple(lyt))
                targetShapes.append(tuple([slc.stop - slc.start for slc in lyt]))
                sourceLayout.append(trial.idx)
                ArgV.append(trlArg0)
                chanstack += res[outchanidx]
                blockstack += block

        # Simple: consume all channels simultaneously, i.e., just take the entire trial
        else:
            targetLayout.append(tuple(lyt))
            targetShapes.append(chunkShape0)
            sourceLayout.append(trial.idx)
            ArgV.append(trlArg0)
            
        # Construct dimensional layout of output
        stacking = targetLayout[0][0].stop
        for tk in range(1, len(self.trialList)):
            trial = trials[tk]
            trlArg = tuple(arg[tk] for arg in self.argv)
            chkshp = chk_list[tk]
            lyt = [slice(0, stop) for stop in chkshp]
            lyt[0] = slice(stacking, stacking + chkshp[0])
            stacking += chkshp[0]
            if chan_per_worker is None:
                targetLayout.append(tuple(lyt))
                targetShapes.append(tuple([slc.stop - slc.start for slc in lyt]))
                sourceLayout.append(trial.idx)
                ArgV.append(trlArg)
            else:
                chanstack = 0
                blockstack = 0
                for block in n_blocks:
                    shp = list(trial.shape)
                    idx = list(trial.idx)
                    shp[inchanidx] = block
                    idx[inchanidx] = slice(blockstack, blockstack + block)
                    trial.shape = tuple(shp)
                    trial.idx = tuple(idx)
                    res, _ = self.computeFunction(trial, *trlArg, **dryRunKwargs) # FauxTrial
                    lyt[outchanidx] = slice(chanstack, chanstack + res[outchanidx])
                    targetLayout.append(tuple(lyt))
                    targetShapes.append(tuple([slc.stop - slc.start for slc in lyt]))
                    sourceLayout.append(trial.idx)
                    chanstack += res[outchanidx]
                    blockstack += block
                    ArgV.append(trlArg)
                    
        # If the determined source layout contains unordered lists and/or index
        # repetitions, set `self.useFancyIdx` to `True` and prepare a separate
        # `sourceSelectors` list that is used in addition to `sourceLayout` for
        # data extraction.
        # In this case `sourceLayout` uses ABSOLUTE indices (indices w.r.t. the
        # size of the ENTIRE DATASET) that are SORTED W/O REPS to extract a NumPy
        # array of appropriate size from HDF5.
        # Then `sourceSelectors` uses RELATIVE indices (indices w.r.t. the size of
        # the CURRENT TRIAL) that can be UNSORTED W/REPS to actually perform the
        # requested selection on the NumPy array extracted w/`sourceLayout`.
        for grd in sourceLayout:
            if any([np.diff(sel).min() <= 0 if isinstance(sel, list) 
                    and len(sel) > 1 else False for sel in grd]):
                self.useFancyIdx = True 
                break
        if self.useFancyIdx:
            sourceSelectors = []
            for gk, grd in enumerate(sourceLayout):
                ingrid = list(grd)
                sigrid = []
                for sk, sel in enumerate(grd):
                    if isinstance(sel, list):
                        selarr = np.array(sel, dtype=np.intp)
                    else: # sel is a slice
                        step = sel.step
                        if sel.step is None:
                            step = 1
                        selarr = np.array(list(range(sel.start, sel.stop, step)), dtype=np.intp)
                    if selarr.size > 0:
                        sigrid.append(np.array(selarr) - selarr.min())
                        ingrid[sk] = slice(selarr.min(), selarr.max() + 1, 1)
                    else:
                        sigrid.append([])
                        ingrid[sk] = []
                sourceSelectors.append(tuple(sigrid))
                sourceLayout[gk] = tuple(ingrid)
        else:
            sourceSelectors = [Ellipsis] * len(sourceLayout)
            
        # Store determined shapes and grid layout
        self.sourceLayout = sourceLayout
        self.sourceSelectors = sourceSelectors
        self.targetLayout = targetLayout
        self.targetShapes = targetShapes
        self.ArgV = ArgV

        # Compute max. memory footprint of chunks
        if chan_per_worker is None:
            self.chunkMem = np.prod(self.cfg["chunkShape"]) * self.dtype.itemsize
        else:
            self.chunkMem = max([np.prod(shp) for shp in self.targetShapes]) * self.dtype.itemsize
        
        # Get data access mode (only relevant for parallel reading access)
        self.dataMode = data.mode
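
The channel-blocking arithmetic used above, shown standalone: 10 channels with `chan_per_worker = 4` split into blocks of [4, 4, 2], the remainder forming a final, smaller block:

nChannels, chan_per_worker = 10, 4
rem = int(nChannels % chan_per_worker)
n_blocks = [chan_per_worker] * int(nChannels // chan_per_worker) + [rem] * int(rem > 0)
assert n_blocks == [4, 4, 2]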
Example #10
def freqanalysis(data, method='mtmfft', output='fourier',
                 keeptrials=True, foi=None, foilim=None,
                 pad_to_length=None, polyremoval=None,
                 taper="hann", tapsmofrq=None, nTaper=None, keeptapers=False,
                 toi="all", t_ftimwin=None, wavelet="Morlet", width=6, order=None,
                 order_max=None, order_min=1, c_1=3, adaptive=False,
                 out=None, **kwargs):
    """
    Perform (time-)frequency analysis of Syncopy :class:`~syncopy.AnalogData` objects

    **Usage Summary**

    Options available in all analysis methods:

    * **output** : one of :data:`~syncopy.specest.const_def.availableOutputs`;
      return power spectra, complex Fourier spectra or absolute values.
    * **foi**/**foilim** : frequencies of interest; either array of frequencies or
      frequency window (not both)
    * **keeptrials** : return individual trials or grand average
    * **polyremoval** : de-trending method to use (0 = mean, 1 = linear, or `None`
      for no de-trending)

    List of available analysis methods and respective distinct options:

    "mtmfft" : (Multi-)tapered Fourier transform
        Perform frequency analysis on time-series trial data using either a single
        taper window (Hanning) or many tapers based on the discrete prolate
        spheroidal sequence (DPSS) that maximize energy concentration in the main
        lobe.

        * **taper** : one of :data:`~syncopy.shared.const_def.availableTapers`
        * **tapsmofrq** : spectral smoothing box for Slepian tapers (in Hz)
        * **nTaper** : number of orthogonal tapers for Slepian tapers
        * **keeptapers** : return individual tapers or average
        * **pad_to_length**: either pad to an absolute length or set to `'nextpow2'`

    "mtmconvol" : (Multi-)tapered sliding window Fourier transform
        Perform time-frequency analysis on time-series trial data based on a sliding
        window short-time Fourier transform using either a single Hanning taper or
        multiple DPSS tapers.

        * **taper** : one of :data:`~syncopy.specest.const_def.availableTapers`
        * **tapsmofrq** : spectral smoothing box for Slepian tapers (in Hz)
        * **nTaper** : number of orthogonal tapers for Slepian tapers
        * **keeptapers** : return individual tapers or average
        * **toi** : time-points of interest; can be either an array representing
          analysis window centroids (in sec), a scalar between 0 and 1 encoding
          the percentage of overlap between adjacent windows or "all" to center
          a window on every sample in the data.
        * **t_ftimwin** : sliding window length (in sec)

    "wavelet" : (Continuous non-orthogonal) wavelet transform
        Perform time-frequency analysis on time-series trial data using a non-orthogonal
        continuous wavelet transform.

        * **wavelet** : one of :data:`~syncopy.specest.const_def.availableWavelets`
        * **toi** : time-points of interest; can be either an array representing
          time points (in sec) or "all" (pre-trimming and subsampling of results)
        * **width** : Nondimensional frequency constant of Morlet wavelet function (>= 6)
        * **order** : Order of Paul wavelet function (>= 4) or derivative order
          of real-valued DOG wavelets (2 = mexican hat)

    "superlet" : Superlet transform
        Perform time-frequency analysis on time-series trial data using
        the super-resolution superlet transform (SLT) from [Moca2021]_.

        * **order_max** : Maximal order of the superlet
        * **order_min** : Minimal order of the superlet
        * **c_1** : Number of cycles of the base Morlet wavelet
        * **adaptive** : If set to `True` perform fractional adaptive SLT,
          otherwise perform multiplicative SLT

    **Full documentation below**

    Parameters
    ----------
    data : `~syncopy.AnalogData`
        A non-empty Syncopy :class:`~syncopy.datatype.AnalogData` object
    method : str
        Spectral estimation method, one of :data:`~syncopy.specest.const_def.availableMethods`
        (see below).
    output : str
        Output of spectral estimation. One of :data:`~syncopy.specest.const_def.availableOutputs` (see below);
        use `'pow'` for power spectrum (:obj:`numpy.float32`), `'fourier'` for complex
        Fourier coefficients (:obj:`numpy.complex64`) or `'abs'` for absolute
        values (:obj:`numpy.float32`).
    keeptrials : bool
        If `True` spectral estimates of individual trials are returned, otherwise
        results are averaged across trials.
    foi : array-like or None
        Frequencies of interest (Hz) for output. If desired frequencies cannot be
        matched exactly, the closest possible frequencies are used. If `foi` is `None`
        or ``foi = "all"``, all attainable frequencies (i.e., zero to Nyquist / 2)
        are selected.
    foilim : array-like (floats [fmin, fmax]) or None or "all"
        Frequency-window ``[fmin, fmax]`` (in Hz) of interest. Window
        specifications must be sorted (e.g., ``[90, 70]`` is invalid) and not NaN
        but may be unbounded (e.g., ``[-np.inf, 60.5]`` is valid). Edges `fmin`
        and `fmax` are included in the selection. If `foilim` is `None` or
        ``foilim = "all"``, all frequencies are selected.
    pad_to_length : int, None or 'nextpow2'
        Padding of the input data. If set to a number, all trials are padded
        to this absolute length. For instance, ``pad_to_length = 2000`` pads all
        trials to an absolute length of 2000 samples, provided the longest
        trial contains at most 2000 samples.
        Alternatively, if all trials share the same initial length, setting
        ``pad_to_length = 'nextpow2'`` pads all trials to the next power of two.
        If `None` and trials have unequal lengths, all trials are padded to match
        the longest trial.
    polyremoval : int or None
        Order of polynomial used for de-trending data in the time domain prior
        to spectral analysis. A value of 0 corresponds to subtracting the mean
        ("de-meaning"), ``polyremoval = 1`` removes linear trends (subtracting the
        least squares fit of a linear polynomial).
        If `polyremoval` is `None`, no de-trending is performed. Note that
        de-meaning is very advisable for spectral estimation.
    taper : str
        Only valid if `method` is `'mtmfft'` or `'mtmconvol'`. Windowing function,
        one of :data:`~syncopy.specest.const_def.availableTapers` (see below).
    tapsmofrq : float
        Only valid if `method` is `'mtmfft'` or `'mtmconvol'` and `taper` is `'dpss'`.
        The amount of spectral smoothing through multi-tapering (Hz).
        Note that smoothing frequency specifications are one-sided,
        i.e., 4 Hz smoothing means plus-minus 4 Hz, i.e., an 8 Hz smoothing box.
    nTaper : int or None
        Only valid if `method` is `'mtmfft'` or `'mtmconvol'` and `taper='dpss'`.
        Number of orthogonal tapers to use. It is not recommended to set the number
        of tapers manually! Leave at `None` for the optimal number to be set automatically.
    keeptapers : bool
        Only valid if `method` is `'mtmfft'` or `'mtmconvol'`.
        If `True`, return spectral estimates for each taper.
        Otherwise, the power spectrum is averaged across tapers
        (only possible if `output` is `'pow'`).
    toi : float or array-like or "all"
        **Mandatory input** for time-frequency analysis methods (`method` is either
        `"mtmconvol"` or `"wavelet"` or `"superlet"`).
        If `toi` is scalar, it must be a value between 0 and 1 indicating the
        percentage of overlap between time-windows specified by `t_ftimwin` (only
        valid if `method` is `'mtmconvol'`).
        If `toi` is an array it explicitly selects the centroids of analysis
        windows (in seconds), if `toi` is `"all"`, analysis windows are centered
        on all samples in the data for `method="mtmconvol"`. For wavelet-based
        methods (`"wavelet"` or `"superlet"`), `toi` needs to be either an
        equidistant array of time points or "all".
    t_ftimwin : positive float
        Only valid if `method` is `'mtmconvol'`. Sliding window length (in seconds).
    wavelet : str
        Only valid if `method` is `'wavelet'`. Wavelet function to use, one of
        :data:`~syncopy.specest.const_def.availableWavelets` (see below).
    width : positive float
        Only valid if `method` is `'wavelet'` and `wavelet` is `'Morlet'`. Nondimensional
        frequency constant of Morlet wavelet function. This number should be >= 6,
        which corresponds to 6 cycles within the analysis window to ensure sufficient
        spectral sampling.
    order : positive int
        Only valid if `method` is `'wavelet'` and `wavelet` is `'Paul'` or `'DOG'`. Order
        of the wavelet function. If `wavelet` is `'Paul'`, `order` should be chosen
        >= 4 to ensure that the analysis window contains at least a single oscillation.
        At an order of 40, the Paul wavelet exhibits about the same number of cycles
        as the Morlet wavelet with a `width` of 6.
        All other supported wavelet functions are *real-valued* derivatives of
        Gaussians (DOGs). Hence, if `wavelet` is `'DOG'`, `order` represents the derivative order.
        The special case of a second order DOG yields a function known as "Mexican Hat",
        "Marr" or "Ricker" wavelet, which can be selected alternatively by setting
        `wavelet` to `'Mexican_hat'`, `'Marr'` or `'Ricker'`. **Note**: A real-valued
        wavelet function encodes *only* information about peaks and discontinuities
        in the signal and does *not* provide any information about amplitude or phase.
    order_max : int
        Only valid if `method` is `'superlet'`.
        Maximal order of the superlet set. Together with the `c_1` parameter,
        controls the maximum number of cycles within a superlet:
        ``c_max = c_1 * order_max``.
    order_min : int
        Only valid if `method` is `'superlet'`.
        Minimal order of the superlet set. Together with the `c_1` parameter,
        controls the minimal number of cycles within a superlet:
        ``c_min = c_1 * order_min``.
        Note that for admissibility reasons `c_min` should be at least 3!
    c_1 : int
        Only valid if `method` is `'superlet'`.
        Number of cycles of the base Morlet wavelet. If set to less than 3,
        increase `order_min` so as to never have fewer than 3 cycles
        in a wavelet!
    adaptive : bool
        Only valid if `method` is `'superlet'`.
        Whether to perform multiplicative SLT or fractional adaptive SLT.
        If set to `True`, the order of the wavelet set increases
        linearly with the frequencies of interest from `order_min`
        to `order_max`. If set to `False`, the same superlet is used for
        all frequencies.
    out : None or :class:`SpectralData` object
        `None` if a new :class:`SpectralData` object is to be created, or an
        empty :class:`SpectralData` object to write the results to


    Returns
    -------
    spec : :class:`~syncopy.SpectralData`
        (Time-)frequency spectrum of input data

    Notes
    -----
    .. [Moca2021] Moca, Vasile V., et al. "Time-frequency super-resolution with superlets."
       Nature communications 12.1 (2021): 1-18.

    **Options**

    .. autodata:: syncopy.specest.const_def.availableMethods

    .. autodata:: syncopy.specest.const_def.availableOutputs

    .. autodata:: syncopy.specest.const_def.availableTapers

    .. autodata:: syncopy.specest.const_def.availableWavelets

    Examples
    --------
    Coming soon...
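
    A minimal usage sketch in the meantime (hypothetical: assumes ``adata`` is
    a suitably populated :class:`~syncopy.AnalogData` object):

    >>> import syncopy as spy
    >>> # multi-tapered power spectrum with 2 Hz spectral smoothing,
    >>> # restricted to frequencies between 1 and 100 Hz
    >>> spec = spy.freqanalysis(adata, method="mtmfft", taper="dpss",
    ...                         tapsmofrq=2, foilim=[1, 100], output="pow")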



    See also
    --------
    syncopy.specest.mtmfft.mtmfft : (multi-)tapered Fourier transform of multi-channel time series data
    syncopy.specest.mtmconvol.mtmconvol : time-frequency analysis of multi-channel time series data with a sliding window FFT
    syncopy.specest.wavelet.wavelet : time-frequency analysis of multi-channel time series data using a wavelet transform
    numpy.fft.fft : NumPy's reference FFT implementation
    scipy.signal.stft : SciPy's Short Time Fourier Transform
    """

    # Make sure our one mandatory input object can be processed
    try:
        data_parser(data, varname="data", dataclass="AnalogData",
                    writable=None, empty=False)
    except Exception as exc:
        raise exc
    timeAxis = data.dimord.index("time")

    # Get everything of interest in local namespace
    defaults = get_defaults(freqanalysis)
    lcls = locals()
    # check for ineffective additional kwargs
    check_passed_kwargs(lcls, defaults, frontend_name="freqanalysis")

    # Ensure a valid computational method was selected
    if method not in availableMethods:
        lgl = "'" + "or '".join(opt + "' " for opt in availableMethods)
        raise SPYValueError(legal=lgl, varname="method", actual=method)

    # Ensure a valid output format was selected
    if output not in spectralConversions.keys():
        lgl = "'" + "or '".join(opt + "' " for opt in spectralConversions.keys())
        raise SPYValueError(legal=lgl, varname="output", actual=output)

    # Parse all Boolean keyword arguments
    for vname in ["keeptrials", "keeptapers"]:
        if not isinstance(lcls[vname], bool):
            raise SPYTypeError(lcls[vname], varname=vname, expected="Bool")

    # If only a subset of `data` is to be processed, make some necessary adjustments
    # of the sampleinfo and trial lengths
    if data._selection is not None:
        sinfo = data._selection.trialdefinition[:, :2]
        trialList = data._selection.trials
    else:
        trialList = list(range(len(data.trials)))
        sinfo = data.sampleinfo
    lenTrials = np.diff(sinfo).squeeze()
    if not lenTrials.shape:
        lenTrials = lenTrials[None]
    numTrials = len(trialList)

    # check polyremoval, see also
    # https://docs.obspy.org/_modules/obspy/signal/detrend.html#polynomial
    if polyremoval is not None:
        scalar_parser(polyremoval, varname="polyremoval", ntype="int_like", lims=[0, 1])


    # --- Padding ---

    # Sliding window FFT does not support "fancy" padding
    if method == "mtmconvol" and isinstance(pad_to_length, str):
        msg = "method 'mtmconvol' only supports in-place padding for windows " +\
            "exceeding trial boundaries. Your choice of `pad_to_length = '{}'` will be ignored. "
        SPYWarning(msg.format(pad_to_length))

    if method == 'mtmfft':
        # the actual number of samples in case of later padding
        minSampleNum = validate_padding(pad_to_length, lenTrials)
    else:
        minSampleNum = lenTrials.min()

    # Compute length (in samples) of shortest trial
    minTrialLength = minSampleNum / data.samplerate

    # Shortcut to data sampling interval
    dt = 1 / data.samplerate

    foi, foilim = validate_foi(foi, foilim, data.samplerate)


    # Prepare keyword dict for logging (use `lcls` to get actually provided
    # keyword values, not defaults set above)
    log_dct = {"method": method,
               "output": output,
               "keeptapers": keeptapers,
               "keeptrials": keeptrials,
               "polyremoval": polyremoval,
               "pad_to_length": pad_to_length}

    # --------------------------------
    # 1st: Check time-frequency inputs
    # to prepare/sanitize `toi`
    # --------------------------------

    if method in ["mtmconvol", "wavelet", "superlet"]:

        # Get start/end timing info respecting potential in-place selection
        if toi is None:
            raise SPYTypeError(toi, varname="toi", expected="scalar or array-like or 'all'")
        if data._selection is not None:
            tStart = data._selection.trialdefinition[:, 2] / data.samplerate
        else:
            tStart = data._t0 / data.samplerate
        tEnd = tStart + lenTrials / data.samplerate

    # for these methods only 'all' or an equidistant array
    # of time points (sub-sampling, trimming) are valid
    if method in ["wavelet", "superlet"]:

        valid = True
        if isinstance(toi, Number):
            valid = False

        elif isinstance(toi, str):
            if toi != "all":
                valid = False
            else:
                # take everything
                preSelect = [slice(None)] * numTrials
                postSelect = [slice(None)] * numTrials

        elif not hasattr(toi, "__iter__"):
            valid = False

        # this is the sequence type - can only be an interval!
        else:
            try:
                array_parser(toi, varname="toi", hasinf=False, hasnan=False,
                             lims=[tStart.min(), tEnd.max()], dims=(None,))
            except Exception as exc:
                raise exc
            toi = np.array(toi)
            # check for equidistancy
            if not np.allclose(np.diff(toi, 2), np.zeros(len(toi) - 2)):
                valid = False
            # trim (preSelect) and subsample output (postSelect)
            else:
                preSelect = []
                postSelect = []
                # get sample intervals and relative indices from toi
                for tk in range(numTrials):
                    start = int(data.samplerate * (toi[0] - tStart[tk]))
                    stop = int(data.samplerate * (toi[-1] - tStart[tk]) + 1)
                    preSelect.append(slice(max(0, start), max(stop, stop - start)))
                    smpIdx = np.minimum(lenTrials[tk] - 1,
                                        data.samplerate * (toi - tStart[tk]) - start)
                    postSelect.append(smpIdx.astype(np.intp))

        # bail out if something wasn't right
        if not valid:
            lgl = "array of equidistant time-points or 'all' for wavelet based methods"
            raise SPYValueError(legal=lgl, varname="toi", actual=toi)


        # Update `log_dct` w/method-specific options (use `lcls` to get actually
        # provided keyword values, not defaults set in here)
        log_dct["toi"] = lcls["toi"]

    # --------------------------------------------
    # Check options specific to mtm*-methods
    # (particularly tapers and foi/freqs alignment)
    # --------------------------------------------

    if "mtm" in method:

        if method == "mtmconvol":
            # get the sliding window size
            try:
                scalar_parser(t_ftimwin, varname="t_ftimwin",
                              lims=[dt, minTrialLength])
            except Exception as exc:
                SPYInfo("Please specify 't_ftimwin' parameter.. exiting!")
                raise exc

            # this is the effective sliding window FFT sample size
            minSampleNum = int(t_ftimwin * data.samplerate)

        # Construct array of maximally attainable frequencies
        freqs = np.fft.rfftfreq(minSampleNum, dt)

        # Match desired frequencies as closely as possible to
        # actually attainable freqs
        # these are the frequencies attached to the SpectralData by the CR!
        if foi is not None:
            foi, _ = best_match(freqs, foi, squash_duplicates=True)
        elif foilim is not None:
            foi, _ = best_match(freqs, foilim, span=True, squash_duplicates=True)
        else:
            msg = (f"Automatic FFT frequency selection from {freqs[0]:.1f}Hz to "
                   f"{freqs[-1]:.1f}Hz")
            SPYInfo(msg)
            foi = freqs
        log_dct["foi"] = foi

        # Abort if desired frequency selection is empty
        if foi.size == 0:
            lgl = "non-empty frequency specification"
            act = "empty frequency selection"
            raise SPYValueError(legal=lgl, varname="foi/foilim", actual=act)

        # sanitize taper selection and retrieve dpss settings
        taper_opt = validate_taper(taper,
                                   tapsmofrq,
                                   nTaper,
                                   keeptapers,
                                   foimax=foi.max(),
                                   samplerate=data.samplerate,
                                   nSamples=minSampleNum,
                                   output=output)

        # Update `log_dct` w/method-specific options
        log_dct["taper"] = taper
        # only dpss returns non-empty taper_opt dict
        if taper_opt:
            log_dct["nTaper"] = taper_opt["Kmax"]
            log_dct["tapsmofrq"] = tapsmofrq

    # -------------------------------------------------------
    # Now, prepare explicit compute-classes for chosen method
    # -------------------------------------------------------

    if method == "mtmfft":

        check_effective_parameters(MultiTaperFFT, defaults, lcls)

        # method specific parameters
        method_kwargs = {
            'samplerate': data.samplerate,
            'taper': taper,
            'taper_opt': taper_opt,
            'nSamples': minSampleNum
        }

        # Set up compute-class
        specestMethod = MultiTaperFFT(
            foi=foi,
            timeAxis=timeAxis,
            keeptapers=keeptapers,
            polyremoval=polyremoval,
            output_fmt=output,
            method_kwargs=method_kwargs)

    elif method == "mtmconvol":

        check_effective_parameters(MultiTaperFFTConvol, defaults, lcls)

        # Process `toi` for the sliding window multi-taper FFT;
        # we have to account for three scenarios: (1) center sliding
        # windows on all samples in (selected) trials, (2) `toi` was provided as
        # percentage indicating the degree of overlap b/w time-windows and (3) a set
        # of discrete time points was provided. These three cases are encoded in
        # `overlap`, i.e., `overlap > 1` => all, `0 < overlap < 1` => percentage,
        # `overlap < 0` => discrete `toi`

        if isinstance(toi, str):
            if toi != "all":
                lgl = "`toi = 'all'` to center analysis windows on all time-points"
                raise SPYValueError(legal=lgl, varname="toi", actual=toi)
            equidistant = True
            overlap = np.inf

        elif isinstance(toi, Number):
            try:
                scalar_parser(toi, varname="toi", lims=[0, 1])
            except Exception as exc:
                raise exc
            overlap = toi
            equidistant = True
        # this captures all other cases, i.e., `toi` is of sequence type
        else:
            overlap = -1
            try:
                array_parser(toi, varname="toi", hasinf=False, hasnan=False,
                             lims=[tStart.min(), tEnd.max()], dims=(None,))
            except Exception as exc:
                raise exc
            toi = np.array(toi)
            tSteps = np.diff(toi)
            if (tSteps < 0).any():
                lgl = "ordered list/array of time-points"
                act = "unsorted list/array"
                raise SPYValueError(legal=lgl, varname="toi", actual=act)
            # Account for round-off errors: if toi spacing is almost at sample interval
            # manually correct it
            if np.isclose(tSteps.min(), dt):
                tSteps[np.isclose(tSteps, dt)] = dt
            if tSteps.min() < dt:
                msg = f"`toi` selection too fine, max. time resolution is {dt}s"
                SPYWarning(msg)
            # Floating-point caveat: even `arange` and `linspace` may produce
            # arrays that are numerically not exactly equidistant - `unique` would
            # show several entries here - use `allclose` to identify "even" spacings
            equidistant = np.allclose(tSteps, [tSteps[0]] * tSteps.size)

        # If `toi` was 'all' or a percentage, use entire time interval of (selected)
        # trials and check if those trials have *approximately* equal length
        if overlap >= 0:
            if not np.allclose(lenTrials, [minSampleNum] * lenTrials.size):
                msg = "processing trials of different lengths (min = {}; max = {} samples)" +\
                    " with `toi = 'all'` or a percentage `toi`"
                SPYWarning(msg.format(int(minSampleNum), int(lenTrials.max())))

        # number of samples per window
        nperseg = int(t_ftimwin * data.samplerate)
        halfWin = int(nperseg / 2)
        postSelect = slice(None) # select all is the default

        if 0 <= overlap <= 1: # `toi` is percentage
            noverlap = min(nperseg - 1, int(overlap * nperseg))
        # windows get shifted exactly 1 sample
        # to get a spectral estimate at each sample
        else:
            noverlap = nperseg - 1
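        # (illustrative numbers: with samplerate = 1000 Hz and t_ftimwin = 0.5 s,
        #  nperseg = 500; a percentage `toi = 0.5` then yields noverlap = 250,
        #  while `toi = 'all'` gives noverlap = 499)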

        # `toi` is array
        if overlap < 0:
            # Compute necessary padding at begin/end of trials to fit sliding windows
            offStart = ((toi[0] - tStart) * data.samplerate).astype(np.intp)
            padBegin = halfWin - offStart
            padBegin = ((padBegin > 0) * padBegin).astype(np.intp)
            offEnd = ((tEnd - toi[-1]) * data.samplerate).astype(np.intp)
            padEnd = halfWin - offEnd
            padEnd = ((padEnd > 0) * padEnd).astype(np.intp)

            # Compute sample-indices (one slice/list per trial) from time-selections
            soi = []
            if equidistant:
                # soi just trims the input data to the [toi[0], toi[-1]] interval
                # postSelect then subsamples the spectral estimate to the user given toi
                postSelect = []
                for tk in range(numTrials):
                    start = max(0, int(round(data.samplerate * (toi[0] - tStart[tk]) - halfWin)))
                    stop = int(round(data.samplerate * (toi[-1] - tStart[tk]) + halfWin + 1))
                    soi.append(slice(start, max(stop, stop - start)))

                # chosen toi subsampling interval in sample units, min. is 1;
                # compute `delta_idx` s.t. (stop - start) / delta_idx == toi.size
                delta_idx = int(round((soi[0].stop - soi[0].start) / toi.size))
                delta_idx = delta_idx if delta_idx > 1 else 1
                postSelect = slice(None, None, delta_idx)

            else:
                for tk in range(numTrials):
                    starts = (data.samplerate * (toi - tStart[tk]) - halfWin).astype(np.intp)
                    starts += padBegin[tk]
                    stops = (data.samplerate * (toi - tStart[tk]) + halfWin + 1).astype(np.intp)
                    stops += padBegin[tk]
                    stops = np.maximum(stops, stops - starts, dtype=np.intp)
                    soi.append([slice(start, stop) for start, stop in zip(starts, stops)])
                    # postSelect here remains slice(None), as resulting spectrum
                    # has exactly one entry for each soi

        # `toi` is percentage or "all"
        else:
            soi = [slice(None)] * numTrials


        # Collect keyword args for `mtmconvol` in dictionary
        method_kwargs = {"samplerate": data.samplerate,
                         "nperseg": nperseg,
                         "noverlap": noverlap,
                         "taper" : taper,
                         "taper_opt" : taper_opt}

        # Set up compute-class
        specestMethod = MultiTaperFFTConvol(
            soi,
            postSelect,
            equidistant=equidistant,
            toi=toi,
            foi=foi,
            timeAxis=timeAxis,
            keeptapers=keeptapers,
            polyremoval=polyremoval,
            output_fmt=output,
            method_kwargs=method_kwargs)

    elif method == "wavelet":

        check_effective_parameters(WaveletTransform, defaults, lcls)

        # Check wavelet selection
        if wavelet not in availableWavelets:
            lgl = "'" + "or '".join(opt + "' " for opt in availableWavelets)
            raise SPYValueError(legal=lgl, varname="wavelet", actual=wavelet)
        if wavelet not in ["Morlet", "Paul"]:
            msg = "the chosen wavelet '{}' is real-valued and does not provide " +\
                "any information about amplitude or phase of the data. This wavelet function " +\
                "may be used to isolate peaks or discontinuities in the signal. "
            SPYWarning(msg.format(wavelet))

        # Check for consistency of `width`, `order` and `wavelet`
        if wavelet == "Morlet":
            try:
                scalar_parser(width, varname="width", lims=[1, np.inf])
            except Exception as exc:
                raise exc
            wfun = getattr(spywave, wavelet)(w0=width)
        else:
            if width != lcls["width"]:
                msg = "option `width` has no effect for wavelet '{}'"
                SPYWarning(msg.format(wavelet))

        if wavelet == "Paul":
            try:
                scalar_parser(order, varname="order", lims=[4, np.inf], ntype="int_like")
            except Exception as exc:
                raise exc
            wfun = getattr(spywave, wavelet)(m=order)
        elif wavelet == "DOG":
            try:
                scalar_parser(order, varname="order", lims=[1, np.inf], ntype="int_like")
            except Exception as exc:
                raise exc
            wfun = getattr(spywave, wavelet)(m=order)
        else:
            if order is not None:
                msg = "option `order` has no effect for wavelet '{}'"
                SPYWarning(msg.format(wavelet))
            wfun = getattr(spywave, wavelet)()

        # automatic frequency selection
        if foi is None and foilim is None:
            scales = get_optimal_wavelet_scales(
                wfun.scale_from_period, # all availableWavelets sport one!
                int(minTrialLength * data.samplerate),
                dt)
            foi = 1 / wfun.fourier_period(scales)
            msg = (f"Setting frequencies of interest to {foi[0]:.1f}-"
                   f"{foi[-1]:.1f}Hz")
            SPYInfo(msg)
        else:
            if foilim is not None:
                foi = np.arange(foilim[0], foilim[1] + 1, dtype=float)
            # 0 frequency is not valid
            foi[foi < 0.01] = 0.01
            scales = wfun.scale_from_period(1 / foi)

        # Update `log_dct` w/method-specific options (use `lcls` to get actually
        # provided keyword values, not defaults set in here)
        log_dct["foi"] = foi
        log_dct["wavelet"] = lcls["wavelet"]
        log_dct["width"] = lcls["width"]
        log_dct["order"] = lcls["order"]

        # method specific parameters
        method_kwargs = {
            'samplerate' : data.samplerate,
            'scales' : scales,
            'wavelet' : wfun
        }

        # Set up compute-class
        specestMethod = WaveletTransform(
            preSelect,
            postSelect,
            toi=toi,
            timeAxis=timeAxis,
            polyremoval=polyremoval,
            output_fmt=output,
            method_kwargs=method_kwargs)

    elif method == "superlet":

        check_effective_parameters(SuperletTransform, defaults, lcls)

        # check and parse superlet specific arguments
        if order_max is None:
            lgl = "Positive integer needed for order_max"
            raise SPYValueError(legal=lgl, varname="order_max",
                                actual=None)
        else:
            scalar_parser(
                order_max,
                varname="order_max",
                lims=[1, np.inf],
                ntype="int_like"
            )

        scalar_parser(
            order_min, varname="order_min",
            lims=[1, order_max],
            ntype="int_like"
        )
        scalar_parser(c_1, varname="c_1", lims=[1, np.inf], ntype="int_like")

        # if no frequencies were selected by the user, take a sensible default
        if foi is None and foilim is None:
            scales = get_optimal_wavelet_scales(
                superlet.scale_from_period,
                int(minTrialLength * data.samplerate),
                dt)
            foi = 1 / superlet.fourier_period(scales)
            msg = (f"Setting frequencies of interest to {foi[0]:.1f}-"
                   f"{foi[-1]:.1f}Hz")
            SPYInfo(msg)
        else:
            if foilim is not None:
                # frequency range in 1Hz steps
                foi = np.arange(foilim[0], foilim[1] + 1, dtype=float)
            # 0 frequency is not valid
            foi[foi < 0.01] = 0.01
            scales = superlet.scale_from_period(1. / foi)

        # FASLT needs ordered frequencies low - high
        # meaning the scales have to go high - low
        if adaptive:
            if len(scales) < 2:
                lgl = "A range of frequencies"
                act = "Single frequency"
                raise SPYValueError(legal=lgl, varname="foi", actual=act)
            if np.any(np.diff(scales) > 0):
                msg = "Sorting frequencies low to high for adaptive SLT.."
                SPYWarning(msg)
                scales = np.sort(scales)[::-1]

        log_dct["foi"] = foi
        log_dct["c_1"] = lcls["c_1"]
        log_dct["order_max"] = lcls["order_max"]
        log_dct["order_min"] = lcls["order_min"]

        # method specific parameters
        method_kwargs = {
            'samplerate' : data.samplerate,
            'scales' : scales,
            'order_max' : order_max,
            'order_min' : order_min,
            'c_1' : c_1,
            'adaptive' : adaptive
        }

        # Set up compute-class
        specestMethod = SuperletTransform(
            preSelect,
            postSelect,
            toi=toi,
            timeAxis=timeAxis,
            polyremoval=polyremoval,
            output_fmt=output,
            method_kwargs=method_kwargs)

    # -------------------------------------------------
    # Sanitize output and call the ComputationalRoutine
    # -------------------------------------------------

    # If provided, make sure output object is appropriate
    if out is not None:
        try:
            data_parser(out, varname="out", writable=True, empty=True,
                        dataclass="SpectralData",
                        dimord=SpectralData().dimord)
        except Exception as exc:
            raise exc
        new_out = False
    else:
        out = SpectralData(dimord=SpectralData._defaultDimord)
        new_out = True

    # Perform actual computation
    specestMethod.initialize(data,
                             out._stackingDim,
                             chan_per_worker=kwargs.get("chan_per_worker"),
                             keeptrials=keeptrials)
    specestMethod.compute(data, out, parallel=kwargs.get("parallel"), log_dict=log_dct)

    # Either return newly created output object or simply quit
    return out if new_out else None
Example #11
def _prep_spectral_plots(self, name, **inputArgs):
    """
    Local helper that performs sanity checks and sets up data selection
    
    Parameters
    ----------
    self : :class:`~syncopy.SpectralData` object
        Syncopy :class:`~syncopy.SpectralData` object that is being processed by 
        the respective :meth:`.singlepanelplot` or :meth:`.multipanelplot` class methods
        defined in this module. 
    name : str
        Name of caller (i.e., "singlepanelplot" or "multipanelplot")
    inputArgs : dict
        Input arguments of caller (i.e., :meth:`.singlepanelplot` or :meth:`.multipanelplot`)
        collected in dictionary
        
    Returns
    -------
    dimArrs : tuple
        Four-element tuple containing (in this order): `trList`, list of (selected) 
        trials to visualize, `chArr`, 1D :class:`numpy.ndarray` of channel specifiers
        based on provided user selection, `freqArr`, 1D :class:`numpy.ndarray` of 
        frequency specifiers based on provided user selection, `tpArr`, 
        1D :class:`numpy.ndarray` of taper specifiers based on provided user selection. 
        Note that `"all"` and `None` selections are converted to arrays ready for
        indexing. 
    dimCounts : tuple
        Four-element tuple holding sizes of corresponding selection arrays comprised
        in `dimArrs`. Elements are (in this order): number of (selected) trials 
        `nTrials`, number of (selected) channels `nChan`, number of (selected) 
        frequencies `nFreq`, number of (selected) tapers `nTap`. 
    isTimeFrequency : bool
        If `True`, input object contains time-frequency data, `False` otherwise
    complexConversion : callable
        Lambda function that performs complex-to-float conversion of Fourier 
        coefficients (if necessary). 
    pltDtype : str or :class:`numpy.dtype`
        Numeric type of (potentially converted) complex Fourier coefficients. 
    dataLbl : str
        Caption for y-axis or colorbar (depending on value of `isTimeFrequency`). 
        
    Notes
    -----
    This is an auxiliary method that is intended purely for internal use. Please
    refer to the user-exposed methods :func:`~syncopy.singlepanelplot` and/or
    :func:`~syncopy.multipanelplot` to actually generate plots of Syncopy data objects. 
        
    See also
    --------
    :meth:`syncopy.plotting.spy_plotting._prep_plots` : General basic input parsing for all Syncopy plotting routines
    """

    # Basic sanity checks for all plotting routines w/any Syncopy object
    _prep_plots(self, name, **inputArgs)

    # Ensure our binary flags are actually binary
    if not isinstance(inputArgs["avg_channels"], bool):
        raise SPYTypeError(inputArgs["avg_channels"],
                           varname="avg_channels",
                           expected="bool")
    if not isinstance(inputArgs["avg_tapers"], bool):
        raise SPYTypeError(inputArgs["avg_tapers"],
                           varname="avg_tapers",
                           expected="bool")
    if not isinstance(inputArgs.get("avg_trials", True), bool):
        raise SPYTypeError(inputArgs["avg_trials"],
                           varname="avg_trials",
                           expected="bool")

    # Pass provided selections on to `Selector` class which performs error
    # checking and generates required indexing arrays
    self._selection = {
        "trials": inputArgs["trials"],
        "channels": inputArgs["channels"],
        "tapers": inputArgs["tapers"],
        "toilim": inputArgs["toilim"],
        "foilim": inputArgs["foilim"]
    }

    # Ensure any optional keywords controlling plotting appearance make sense
    if inputArgs["title"] is not None:
        if not isinstance(inputArgs["title"], str):
            raise SPYTypeError(inputArgs["title"],
                               varname="title",
                               expected="str")
    if inputArgs["grid"] is not None:
        if not isinstance(inputArgs["grid"], bool):
            raise SPYTypeError(inputArgs["grid"],
                               varname="grid",
                               expected="bool")

    # Get trial/channel/taper count and collect quantities in tuple
    trList = self._selection.trials
    nTrials = len(trList)
    chArr = self.channel[self._selection.channel]
    nChan = chArr.size
    freqArr = self.freq[self._selection.freq]
    nFreq = freqArr.size
    tpArr = np.arange(self.taper.size)[self._selection.taper]
    nTap = tpArr.size
    dimCounts = (nTrials, nChan, nFreq, nTap)
    dimArrs = (trList, chArr, freqArr, tpArr)

    # Determine whether we're dealing w/tf data
    isTimeFrequency = False
    if any([t.size > 1 for t in self.time]):
        isTimeFrequency = True

    # Ensure provided min/max range for plotting TF data makes sense
    vminmax = False
    if inputArgs.get("vmin", None) is not None:
        try:
            scalar_parser(inputArgs["vmin"], varname="vmin")
        except Exception as exc:
            raise exc
        vminmax = True
    if inputArgs.get("vmax", None) is not None:
        try:
            scalar_parser(inputArgs["vmax"], varname="vmax")
        except Exception as exc:
            raise exc
        vminmax = True
    if inputArgs.get("vmin", None) and inputArgs.get("vmax", None):
        if inputArgs["vmin"] >= inputArgs["vmax"]:
            lgl = "minimal data range bound to be less than provided maximum "
            act = "vmax < vmin"
            raise SPYValueError(legal=lgl, varname="vmin/vamx", actual=act)
    if vminmax and not isTimeFrequency:
        msg = "`vmin` and `vmax` is only used for time-frequency visualizations"
        SPYWarning(msg)

    # Check for complex entries in data and set datatype for plotting arrays
    # constructed below (always use floats w/same precision as data)
    if "complex" in self.data.dtype.name:
        msg = "Found complex Fourier coefficients - visualization will use absolute values."
        SPYWarning(msg)
        complexConversion = lambda x: np.absolute(x).real
        pltDtype = "f{}".format(self.data.dtype.itemsize)
        dataLbl = "Absolute Frequency [dB]"
    else:
        complexConversion = lambda x: x
        pltDtype = self.data.dtype
        dataLbl = "Power [dB]"

    return dimArrs, dimCounts, isTimeFrequency, complexConversion, pltDtype, dataLbl
Example #12
def connectivityanalysis(data,
                         method="coh",
                         keeptrials=False,
                         output="abs",
                         foi=None,
                         foilim=None,
                         pad_to_length=None,
                         polyremoval=None,
                         taper="hann",
                         tapsmofrq=None,
                         nTaper=None,
                         out=None,
                         **kwargs):
    """
    Perform connectivity analysis of Syncopy :class:`~syncopy.AnalogData` objects

    **Usage Summary**

    Options available in all analysis methods:

    * **foi**/**foilim** : frequencies of interest; either array of frequencies or
      frequency window (not both)
    * **polyremoval** : de-trending method to use (0 = mean, 1 = linear or `None`)

    List of available analysis methods and respective distinct options:

    "coh" : (Multi-) tapered coherency estimate
        Compute the normalized cross spectral densities
        between all channel combinations

        * **output** : one of ('abs', 'pow', 'fourier')
        * **taper** : one of :data:`~syncopy.shared.const_def.availableTapers`
        * **tapsmofrq** : spectral smoothing box for Slepian tapers (in Hz)
        * **nTaper** : (optional) number of orthogonal Slepian tapers
        * **pad_to_length**: either pad to an absolute length or set to `'nextpow2'`

    "corr" : Cross-correlations
        Computes the one sided (positive lags) cross-correlations
        between all channel combinations. The maximal lag is half
        the trial lengths.

        * **keeptrials** : set to `True` for single trial cross-correlations

    "granger" : Spectral Granger-Geweke causality
        Computes linear causality estimates between
        all channel combinations. The intermediate cross-spectral
        densities can be computed via multi-tapering.

        * **taper** : one of :data:`~syncopy.shared.const_def.availableTapers`
        * **tapsmofrq** : spectral smoothing box for Slepian tapers (in Hz)
        * **nTaper** : (optional, not recommended) number of Slepian tapers
        * **pad_to_length**: either pad to an absolute length or set to `'nextpow2'`

    Parameters
    ----------
    data : `~syncopy.AnalogData`
        A non-empty Syncopy :class:`~syncopy.datatype.AnalogData` object
    method : str
        Connectivity estimation method, one of 'coh', 'corr', 'granger'
    output : str
        Relevant for cross-spectral density estimation (``method = 'coh'``).
        Use `'pow'` for absolute squared coherence, `'abs'` for the absolute value
        of coherence and `'fourier'` for the complex-valued coherency.
    keeptrials : bool
        Relevant for cross-correlations (`method='corr'`).
        If `True` single-trial cross-correlations are returned.
    foi : array-like or None
        Frequencies of interest (Hz) for output. If desired frequencies cannot be
        matched exactly, the closest possible frequencies are used. If `foi` is `None`
        or ``foi = "all"``, all attainable frequencies (i.e., zero up to the
        Nyquist frequency) are selected.
    foilim : array-like (floats [fmin, fmax]) or None or "all"
        Frequency-window ``[fmin, fmax]`` (in Hz) of interest. The
        `foi` array will be constructed in 1Hz steps from `fmin` to
        `fmax` (inclusive).
    pad_to_length : int, None or 'nextpow2'
        Padding of the (tapered) signal. If set to a number, pads all trials
        to this absolute length. E.g., ``pad_to_length = 2000`` pads all
        trials to 2000 samples, provided the longest trial is
        at most 2000 samples long.

        Alternatively, if all trials have the same initial length,
        setting ``pad_to_length = 'nextpow2'`` pads all trials to
        the next power of two.
        If `None` and trials have unequal lengths, all trials are padded to match
        the longest trial.
    taper : str
        Only valid if `method` is `'coh'` or `'granger'`. Windowing function,
        one of :data:`~syncopy.specest.const_def.availableTapers`
    tapsmofrq : float
        Only valid if `method` is `'coh'` or `'granger'` and `taper` is `'dpss'`.
        The amount of spectral smoothing through multi-tapering (Hz).
        Note that smoothing frequency specifications are one-sided,
        i.e., 4 Hz smoothing means plus-minus 4 Hz, i.e., an 8 Hz smoothing box.
    nTaper : int or None
        Only valid if `method` is `'coh'` or `'granger'` and ``taper = 'dpss'``.
        Number of orthogonal tapers to use. It is not recommended to set the number
        of tapers manually! Leave at `None` for the optimal number to be set automatically.

    Examples
    --------
    Coming soon...
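
    A minimal usage sketch in the meantime (hypothetical: assumes ``adata`` is
    a suitably populated :class:`~syncopy.AnalogData` object):

    >>> import syncopy as spy
    >>> # multi-tapered coherence with 2 Hz spectral smoothing
    >>> coh = spy.connectivityanalysis(adata, method="coh", taper="dpss",
    ...                                tapsmofrq=2, foilim=[1, 100], output="abs")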
    """

    # Make sure our one mandatory input object can be processed
    try:
        data_parser(data,
                    varname="data",
                    dataclass="AnalogData",
                    writable=None,
                    empty=False)
    except Exception as exc:
        raise exc
    timeAxis = data.dimord.index("time")

    # Get everything of interest in local namespace
    defaults = get_defaults(connectivityanalysis)
    lcls = locals()
    # check for ineffective additional kwargs
    check_passed_kwargs(lcls, defaults, frontend_name="connectivity")

    # Ensure a valid computational method was selected
    if method not in availableMethods:
        lgl = "'" + "or '".join(opt + "' " for opt in availableMethods)
        raise SPYValueError(legal=lgl, varname="method", actual=method)

    # if a subset selection is present
    # get sampleinfo and check for equidistancy
    if data._selection is not None:
        sinfo = data._selection.trialdefinition[:, :2]
        trialList = data._selection.trials
        # user picked discrete set of time points
        if isinstance(data._selection.time[0], list):
            lgl = "equidistant time points (toi) or time slice (toilim)"
            actual = "non-equidistant set of time points"
            raise SPYValueError(legal=lgl, varname="select", actual=actual)
    else:
        trialList = list(range(len(data.trials)))
        sinfo = data.sampleinfo
    lenTrials = np.diff(sinfo).squeeze()
    if not lenTrials.shape:
        lenTrials = lenTrials[None]

    # check polyremoval
    if polyremoval is not None:
        scalar_parser(polyremoval,
                      varname="polyremoval",
                      ntype="int_like",
                      lims=[0, 1])

    # --- Padding ---

    if method == "corr" and pad_to_length:
        lgl = "`None`, no padding needed/allowed for cross-correlations"
        actual = f"{pad_to_length}"
        raise SPYValueError(legal=lgl, varname="pad_to_length", actual=actual)

    # the actual number of samples in case of later padding
    nSamples = validate_padding(pad_to_length, lenTrials)

    # --- Basic foi sanitization ---

    foi, foilim = validate_foi(foi, foilim, data.samplerate)

    # only now set foi array for foilim in 1Hz steps
    if foilim is not None:
        foi = np.arange(foilim[0], foilim[1] + 1, dtype=float)
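        # (illustrative: `foilim = [10, 20]` yields foi = [10., 11., ..., 20.])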

    # Prepare keyword dict for logging (use `lcls` to get actually provided
    # keyword values, not defaults set above)
    log_dict = {
        "method": method,
        "output": output,
        "keeptrials": keeptrials,
        "polyremoval": polyremoval,
        "pad_to_length": pad_to_length
    }

    # --- Setting up specific Methods ---

    if method in ['coh', 'granger']:

        # --- set up computation of the single trial CSDs ---

        if keeptrials is not False:
            lgl = "False, trial averaging needed!"
            act = keeptrials
            raise SPYValueError(lgl, varname="keeptrials", actual=act)

        # Construct array of maximally attainable frequencies
        freqs = np.fft.rfftfreq(nSamples, 1 / data.samplerate)

        # Match desired frequencies as closely as possible to
        # actually attainable freqs
        # these are the frequencies attached to the SpectralData by the CR!
        if foi is not None:
            foi, _ = best_match(freqs, foi, squash_duplicates=True)
        elif foilim is not None:
            foi, _ = best_match(freqs,
                                foilim,
                                span=True,
                                squash_duplicates=True)
        else:
            # neither `foi` nor `foilim` provided - use all attainable frequencies
            msg = (f"Setting frequencies of interest to {freqs[0]:.1f}-"
                   f"{freqs[-1]:.1f}Hz")
            SPYInfo(msg)
            foi = freqs

        # sanitize taper selection and retrieve dpss settings
        taper_opt = validate_taper(
            taper,
            tapsmofrq,
            nTaper,
            keeptapers=False,  # ST_CSD's always average tapers
            foimax=foi.max(),
            samplerate=data.samplerate,
            nSamples=nSamples,
            output="pow")  # ST_CSD's always have this unit/norm

        log_dict["foi"] = foi
        log_dict["taper"] = taper
        # only dpss returns non-empty taper_opt dict
        if taper_opt:
            log_dict["nTaper"] = taper_opt["Kmax"]
            log_dict["tapsmofrq"] = tapsmofrq

        check_effective_parameters(ST_CrossSpectra, defaults, lcls)
        # parallel computation over trials
        st_compRoutine = ST_CrossSpectra(samplerate=data.samplerate,
                                         nSamples=nSamples,
                                         taper=taper,
                                         taper_opt=taper_opt,
                                         polyremoval=polyremoval,
                                         timeAxis=timeAxis,
                                         foi=foi)
        # hard coded as class attribute
        st_dimord = ST_CrossSpectra.dimord

    if method == 'coh':
        # final normalization after trial averaging
        av_compRoutine = NormalizeCrossSpectra(output=output)

    if method == 'granger':
        # after trial averaging
        # hardcoded numerical parameters
        av_compRoutine = GrangerCausality(rtol=1e-8, nIter=100, cond_max=1e4)

    if method == 'corr':
        if lcls['foi'] is not None:
            msg = 'Parameter `foi` has no effect for `corr`'
            SPYWarning(msg)
        check_effective_parameters(ST_CrossCovariance, defaults, lcls)

        # single trial cross-correlations
        if keeptrials:
            av_compRoutine = None  # no trial average
            norm = True  # normalize individual trials within the ST CR
        else:
            av_compRoutine = NormalizeCrossCov()
            norm = False

        # parallel computation over trials
        st_compRoutine = ST_CrossCovariance(samplerate=data.samplerate,
                                            polyremoval=polyremoval,
                                            timeAxis=timeAxis,
                                            norm=norm)
        # hard coded as class attribute
        st_dimord = ST_CrossCovariance.dimord

    # -------------------------------------------------
    # Call the chosen single trial ComputationalRoutine
    # -------------------------------------------------

    # the single trial results need a new DataSet
    st_out = CrossSpectralData(dimord=st_dimord)

    # Perform the trial-parallelized computation of the matrix quantity
    st_compRoutine.initialize(
        data,
        st_out._stackingDim,
        chan_per_worker=None,  # no parallelisation over channels possible
        keeptrials=keeptrials)  # we most likely need trial averaging!
    st_compRoutine.compute(data,
                           st_out,
                           parallel=kwargs.get("parallel"),
                           log_dict=log_dict)

    # for single-trial cross-corr results (`keeptrials` is `True`)
    # no trial averaging is performed - return the single-trial output directly
    if keeptrials and av_compRoutine is None:
        if out is not None:
            msg = "Single trial processing does not support `out` argument but directly returns the results"
            SPYWarning(msg)
        return st_out

    # ----------------------------------------------------------------------------------
    # Sanitize output and call the chosen ComputationalRoutine on the averaged ST output
    # ----------------------------------------------------------------------------------

    # If provided, make sure output object is appropriate
    if out is not None:
        try:
            data_parser(out,
                        varname="out",
                        writable=True,
                        empty=True,
                        dataclass="CrossSpectralData",
                        dimord=st_dimord)
        except Exception as exc:
            raise exc
        new_out = False
    else:
        out = CrossSpectralData(dimord=st_dimord)
        new_out = True

    # now take the trial average from the single trial CR as input
    av_compRoutine.initialize(st_out, out._stackingDim, chan_per_worker=None)
    av_compRoutine.pre_check()  # make sure we got a trial_average
    av_compRoutine.compute(st_out, out, parallel=False, log_dict=log_dict)

    # Either return newly created output object or simply quit
    return out if new_out else None
Example #13
def validate_taper(taper, tapsmofrq, nTaper, keeptapers, foimax, samplerate,
                   nSamples, output):
    """
    General taper validation and Slepian/dpss input sanitization.
    The default is to max out `nTaper` to achieve the desired frequency
    smoothing bandwidth. For details about the Slepian settings see

    "The Effective Bandwidth of a Multitaper Spectral Estimator,
    A. T. Walden, E. J. McCoy and D. B. Percival"

    Parameters
    ----------
    taper : str
        Windowing function, one of :data:`~syncopy.shared.const_def.availableTapers`
    tapsmofrq : float or None
        Taper smoothing bandwidth for `taper='dpss'`
    nTaper : int_like or None
        Number of tapers to use for multi-tapering (manual selection is not recommended)

    Other Parameters
    ----------------
    keeptapers : bool
    foimax : float
        Maximum frequency for the analysis
    samplerate : float
        the samplerate in Hz
    nSamples : int
        Number of samples
    output : str, one of {'abs', 'pow', 'fourier'}
        Fourier transformation output type

    Returns
    -------
    dpss_opt : dict
        For multi-tapering (`taper='dpss'`) contains the
        parameters `NW` and `Kmax` for `scipy.signal.windows.dpss`.
        For all other tapers this is an empty dictionary.
    """

    # See if taper choice is supported
    if taper not in availableTapers:
        lgl = "'" + "or '".join(opt + "' " for opt in availableTapers)
        raise SPYValueError(legal=lgl, varname="taper", actual=taper)

    # Warn user about DPSS only settings
    if taper != "dpss":
        if tapsmofrq is not None:
            msg = "`tapsmofrq` is only used if `taper` is `dpss`!"
            SPYWarning(msg)
        if nTaper is not None:
            msg = "`nTaper` is only used if `taper` is `dpss`!"
            SPYWarning(msg)
        if keeptapers:
            msg = "`keeptapers` is only used if `taper` is `dpss`!"
            SPYWarning(msg)

        # empty dpss_opt, only Slepians have options
        return {}

    # direct mtm estimate (averaging) only valid for spectral power
    if taper == "dpss" and not keeptapers and output != "pow":
        lgl = "'pow', the only valid option for taper averaging"
        raise SPYValueError(legal=lgl, varname="output", actual=output)

    # Set/get `tapsmofrq` if we're working w/Slepian tapers
    elif taper == "dpss":

        # --- minimal smoothing bandwidth ---
        # --- such that Kmax/nTaper is at least 1
        minBw = 2 * samplerate / nSamples
        # -----------------------------------
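        # (illustrative: samplerate = 1000 Hz and nSamples = 500 give minBw = 4 Hz)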

        # user set tapsmofrq directly
        if tapsmofrq is not None:
            try:
                scalar_parser(tapsmofrq, varname="tapsmofrq", lims=[0, np.inf])
            except Exception as exc:
                raise exc

            if tapsmofrq < minBw:
                msg = f'Setting tapsmofrq to the minimal attainable bandwidth of {minBw:.2f}Hz'
                SPYInfo(msg)
                tapsmofrq = minBw

        # otherwise a user-provided smoothing bandwidth is required
        else:
            lgl = "smoothing bandwidth in Hz, typical values are in the range 1-10Hz"
            raise SPYValueError(legal=lgl,
                                varname="tapsmofrq",
                                actual=tapsmofrq)

            # Try to derive "sane" settings by using 3/4 octave
            # smoothing of highest `foi`
            # following Hill et al. "Oscillatory Synchronization in Large-Scale
            # Cortical Networks Predicts Perception", Neuron, 2011
            # FIX ME: This "sane setting" seems quite excessive (huuuge bwidths)

            # tapsmofrq = (foimax * 2**(3 / 4 / 2) - foimax * 2**(-3 / 4 / 2)) / 2
            # msg = f'Automatic setting of `tapsmofrq` to {tapsmofrq:.2f}'
            # SPYInfo(msg)

        # --------------------------------------------
        # set parameters for scipy.signal.windows.dpss
        NW = tapsmofrq * nSamples / (2 * samplerate)
        # due to the minBw setting, NW is always at least 1
        Kmax = int(2 * NW - 1)  # optimal number of tapers
        # --------------------------------------------
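
        # Worked example (illustrative numbers): samplerate = 1000 Hz,
        # nSamples = 1000 and tapsmofrq = 4 Hz give
        # NW = 4 * 1000 / (2 * 1000) = 2 and Kmax = int(2 * 2 - 1) = 3,
        # i.e., three Slepian tapers realize the requested +/- 4 Hz smoothing box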

        # the recommended way:
        # set nTaper automatically to achieve exact effective smoothing bandwidth
        if nTaper is None:
            msg = f'Using {Kmax} taper(s) for multi-tapering'
            SPYInfo(msg)
            dpss_opt = {'NW': NW, 'Kmax': Kmax}
            return dpss_opt

        else:
            try:
                scalar_parser(nTaper,
                              varname="nTaper",
                              ntype="int_like",
                              lims=[1, np.inf])
            except Exception as exc:
                raise exc

            if nTaper != Kmax:
                msg = f'''
                Manually setting the number of tapers is not recommended
                and may (strongly) distort the effective smoothing bandwidth!\n
                The optimal number of tapers is {Kmax}, you have chosen to use {nTaper}.
                '''
                SPYWarning(msg)

            dpss_opt = {'NW': NW, 'Kmax': nTaper}
            return dpss_opt
Example #14
def padding(data,
            padtype,
            pad="absolute",
            padlength=None,
            prepadlength=None,
            postpadlength=None,
            unit="samples",
            create_new=True):
    """
    Perform data padding on Syncopy object or :class:`numpy.ndarray`
    
    **Usage Summary**
    
    Depending on the value of `pad` the following padding length specifications
    are supported:
    
    +------------+----------------------+---------------+----------------------+----------------------+
    | `pad`      | `data`               | `padlength`   | `prepadlength`       | `postpadlength`      |
    +============+======================+===============+======================+======================+
    | 'absolute' | Syncopy object/array | number        | `None`/`bool`        | `None`/`bool`        |
    +------------+----------------------+---------------+----------------------+----------------------+
    | 'relative' | Syncopy object/array | number/`None` | number/`None`/`bool` | number/`None`/`bool` |
    +------------+----------------------+---------------+----------------------+----------------------+
    | 'maxlen'   | Syncopy object       | `None`/`bool` | `None`/`bool`        | `None`/`bool`        |
    +------------+----------------------+---------------+----------------------+----------------------+
    | 'nextpow2' | Syncopy object/array | `None`/`bool` | `None`/`bool`        | `None`/`bool`        |
    +------------+----------------------+---------------+----------------------+----------------------+
    
    * `data` can be either a Syncopy object containing multiple trials or a
      :class:`numpy.ndarray` representing a single trial
    * `padlength`/`prepadlength`/`postpadlength`: can each be either `None`,
      `True`/`False` or a positive number; a value of `True` indicates where to
      pad, e.g., using ``pad = 'maxlen'`` and ``prepadlength = True``, `data` is
      padded at the beginning of each trial. **Only** if `pad` is 'relative' are
      scalar values supported for `prepadlength` and `postpadlength`
    * ``pad = 'absolute'``: pad to desired absolute length, e.g., by using
      ``padlength = 5`` and ``unit = 'time'`` all trials are (if necessary) padded
      to 5s length. Here, `padlength` **has** to be provided, `prepadlength` and
      `postpadlength` can be `None` or `True`/`False`
    * ``pad = 'relative'``: pad by provided `padlength`, e.g., by using
      ``padlength = 20`` and ``unit = 'samples'``, 20 samples are padded
      symmetrically around (before and after) each trial. Use ``padlength = 20``
      and ``prepadlength = True`` **or** directly ``prepadlength = 20`` to pad
      before each trial. Here, at least one of `padlength`, `prepadlength` or
      `postpadlength` **has** to be provided. 
    * ``pad = 'maxlen'``: (only valid for **Syncopy objects**) pad up to maximal
      trial length found in `data`. All lengths have to be either Boolean
      indicating padding location or `None` (if all are `None`, symmetric
      padding is performed)
    * ``pad = 'nextpow2'``: pad each trial up to closest power of two. All
      lengths have to be either Boolean indicating padding location or `None`
      (if all are `None`, symmetric padding is performed)
    
    Full documentation below. 
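
    A minimal usage sketch (hypothetical: assumes ``adata`` is a suitably
    populated Syncopy :class:`~syncopy.AnalogData` object):

    >>> # symmetrically zero-pad every trial by 100 samples (50 before, 50 after)
    >>> padded = padding(adata, "zero", pad="relative", padlength=100, unit="samples")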
    
    Parameters 
    ----------
    data : Syncopy object or :class:`numpy.ndarray`
        Non-empty Syncopy data object or array representing numeric data to be
        padded. **NOTE**: if `data` is a :class:`numpy.ndarray`, it is assumed
        that it represents recordings from only a single trial, where its first
        axis corresponds to time. In other words, `data` is a
        'time'-by-'channel' array such that its rows reflect samples and its
        columns represent channels. If `data` is a Syncopy object, trial
        information and dimensional order are fetched from `data.trials` and
        `data.dimord`, respectively. 
    padtype : str
        Padding value(s) to be used. Available options are:

        * 'zero' : pad using zeros
        * 'nan' : pad using `np.nan`'s
        * 'mean' : pad with by-channel mean value across each trial
        * 'localmean' : pad with by-channel mean value using only `padlength` or
          `prepadlength`/`postpadlength` number of boundary-entries for averaging
        * 'edge' : pad with trial-boundary values
        * 'mirror' : pad with reflections of trial-boundary values
        
    pad : str
        Padding mode to be used. Available options are:
        
        * 'absolute' : pad each trial to achieve a desired absolute length such
          that all trials have identical length post padding. If `pad` is `absolute`
          a `padlength` **has** to be provided, `prepadlength` and `postpadlength`
          may be `True` or `False`, respectively (see Examples for details).
        * 'relative' : pad each trial by provided `padlength` such that all trials
          are extended by the same amount regardless of their original lengths.
          If `pad` is `relative`, `prepadlength` and `postpadlength` can either 
          be specified directly (using numerical values) or implicitly by only
          providing `padlength` and setting `prepadlength` and `postpadlength`
          to `True` or `False`, respectively (see Examples for details). If `pad`
          is `relative` at least one of `padlength`, `prepadlength` or `postpadlength`
          **has** to be provided. 
        * 'maxlen' : only usable if `data` is a Syncopy object. If `pad` is
          'maxlen' all trials are padded to achieve the length of the longest
          trial in `data`, i.e., post padding, all trials have the same length, 
          which equals the size of the longest trial pre-padding. For 
          ``pad = 'maxlen'``, `padlength`, `prepadlength` as well as `postpadlength` 
          have to be either Boolean or `None` indicating the preferred padding 
          location (pre-trial, post-trial or symmetrically pre- and post-trial). 
          If all are `None`, symmetric padding is performed (see Examples for 
          details). 
        * 'nextpow2' : pad each trial to achieve a length equal to the closest power
          of two of its original length. For ``pad = 'nextpow2'``, `padlength`, 
          `prepadlength` as well as `postpadlength` have to be either Boolean
          or `None` indicating the preferred padding location (pre-trial, post-trial 
          or symmetrically pre- and post-trial). If all are `None`, symmetric 
          padding is performed (see Examples for details). 

    padlength : None, bool or positive scalar
        Length to be padded to `data` (if `padlength` is scalar-valued) or
        padding location (if `padlength` is Boolean). Depending on the value of
        `pad`, `padlength` can be used to pre-pend (if `padlength` is a positive
        number and `prepadlength` is `True`) or append trials (if `padlength` is
        a positive number and `postpadlength` is `True`). If neither
        `prepadlength` nor `postpadlength` are specified (i.e., both are `None`),
        symmetric pre- and post-trial padding is performed (i.e., ``0.5 * padlength``
        before and after each trial - note that odd sample counts are rounded downward
        to the nearest even integer). If ``unit = 'time'``, `padlength` is assumed 
        to be given in seconds, otherwise (``unit = 'samples'``), `padlength` is 
        interpreted as sample-count. Note that only ``pad = 'relative'`` and 
        ``pad = 'absolute'`` support numeric values of `padlength`. 
    prepadlength : None, bool or positive scalar
        Length to be pre-pended before each trial (if `prepadlength` is
        scalar-valued) or pre-padding flag (if `prepadlength` is `True`). If
        `prepadlength` is `True`, pre-padding length is either directly inferred
        from `padlength` or implicitly derived from the chosen padding mode defined
        by `pad`. If ``unit = 'time'``, `prepadlength` is assumed to be given in
        seconds, otherwise (``unit = 'samples'``), `prepadlength` is interpreted
        as sample-count. Note that only ``pad = 'relative'`` supports numeric
        values of `prepadlength`. 
    postpadlength : None, bool or positive scalar
        Length to be appended after each trial (if `postpadlength` is
        scalar-valued) or post-padding flag (if `postpadlength` is `True`). If
        `postpadlength` is `True`, post-padding length is either directly inferred
        from `padlength` or implicitly derived from the chosen padding mode defined
        by `pad`. If ``unit = 'time'``, `postpadlength` is assumed to be given in
        seconds, otherwise (``unit = 'samples'``), `postpadlength` is interpreted
        as sample-count. Note that only ``pad = 'relative'`` supports numeric
        values of `postpadlength`. 
    unit : str
        Unit of numerical values given by `padlength` and/or `prepadlength`
        and/or `postpadlength`. If ``unit = 'time'``, `padlength`,
        `prepadlength`, and `postpadlength` are assumed to be given in seconds,
        otherwise (``unit = 'samples'``), `padlength`, `prepadlength`, and
        `postpadlength` are interpreted as sample-counts. **Note**: Providing
        padding lengths in seconds (i.e., ``unit = 'time'``) is only supported
        if `data` is a Syncopy object. 
    create_new : bool
        If `True`, a padded copy of the same type as `data` is returned (a
        :class:`numpy.ndarray` or Syncopy object). If `create_new` is `False`,
        either a single dictionary (if `data` is a :class:`numpy.ndarray`) or a
        ``len(data.trials)``-long list of dictionaries (if `data` is a Syncopy
        object) with all necessary options for performing the actual padding
        operation with :func:`numpy.pad` is returned.  
        
    Returns
    -------
    pad_dict : dict, if `data` is a :class:`numpy.ndarray` and ``create_new = False``
        Dictionary whose items contain all necessary parameters for calling
        :func:`numpy.pad` to perform the desired padding operation on `data`. 
    pad_dicts : list, if `data` is a Syncopy object and ``create_new = False``
        List of dictionaries for calling :func:`numpy.pad` to perform the
        desired padding operation on all trials found in `data`. 
    out : :class:`numpy.ndarray`, if `data` is a :class:`numpy.ndarray` and ``create_new = True``
        Padded version (deep copy) of `data`
    out : Syncopy object, if `data` is a Syncopy object and ``create_new = True``
        Padded version (deep copy) of `data`
        
    Notes
    -----
    This method emulates (and extends) FieldTrip's `ft_preproc_padding` by
    providing a convenience wrapper for NumPy's :func:`numpy.pad` that performs
    the actual heavy lifting. 
    
    Examples
    --------
    Consider the following small array representing a toy-problem-trial of `ns` 
    samples across `nc` channels:
    
    >>> nc = 7; ns = 30
    >>> trl = np.random.randn(ns, nc)
    
    We start by padding a total of 10 zeros symmetrically to `trl`
    
    >>> padded = spy.padding(trl, 'zero', pad='relative', padlength=10)
    >>> padded[:6, :]
    array([[ 0.    ,  0.    ,  0.    ,  0.    ,  0.    ,  0.    ,  0.    ],
       [ 0.    ,  0.    ,  0.    ,  0.    ,  0.    ,  0.    ,  0.    ],
       [ 0.    ,  0.    ,  0.    ,  0.    ,  0.    ,  0.    ,  0.    ],
       [ 0.    ,  0.    ,  0.    ,  0.    ,  0.    ,  0.    ,  0.    ],
       [ 0.    ,  0.    ,  0.    ,  0.    ,  0.    ,  0.    ,  0.    ],
       [-1.0866,  2.3358,  0.8758,  0.5196,  0.8049, -0.659 , -0.9173]])
    >>> padded[-6:, :]
    array([[ 0.027 ,  1.8069,  1.5249, -0.7953, -0.8933,  1.0202, -0.6862],
       [ 0.    ,  0.    ,  0.    ,  0.    ,  0.    ,  0.    ,  0.    ],
       [ 0.    ,  0.    ,  0.    ,  0.    ,  0.    ,  0.    ,  0.    ],
       [ 0.    ,  0.    ,  0.    ,  0.    ,  0.    ,  0.    ,  0.    ],
       [ 0.    ,  0.    ,  0.    ,  0.    ,  0.    ,  0.    ,  0.    ],
       [ 0.    ,  0.    ,  0.    ,  0.    ,  0.    ,  0.    ,  0.    ]])
    >>> padded.shape
    (40, 7)
    
    Note that the above call is equivalent to
    
    >>> padded_ident = spy.padding(trl, 'zero', pad='relative', padlength=10, prepadlength=True, postpadlength=True)
    >>> np.array_equal(padded_ident, padded)
    True
    >>> padded_ident = spy.padding(trl, 'zero', pad='relative', prepadlength=5, postpadlength=5)
    >>> np.array_equal(padded_ident, padded)
    True
    
    Similarly, 
    
    >>> prepad = spy.padding(trl, 'nan', pad='relative', prepadlength=10)
    
    is the same as
    
    >>> prepad_ident = spy.padding(trl, 'nan', pad='relative', padlength=10, prepadlength=True)
    >>> np.allclose(prepad, prepad_ident, equal_nan=True)
    True
    
    Define bogus trials on `trl` and create a dummy object with unit samplerate
    
    >>> tdf = np.vstack([np.arange(0, ns, 5),
                         np.arange(5, ns + 5, 5),
                         np.ones((int(ns / 5), )),
                         np.ones((int(ns / 5), )) * np.pi]).T
    >>> adata = spy.AnalogData(trl, trialdefinition=tdf, samplerate=1)

    Pad each trial to the closest power of two by appending by-trial channel 
    averages. However, do not perform actual padding, but only prepare dictionaries
    of parameters to be passed on to :func:`numpy.pad`
    
    >>> pad_dicts = spy.padding(adata, 'mean', pad='nextpow2', postpadlength=True, create_new=False)
    >>> len(pad_dicts) == len(adata.trials) 
    True
    >>> pad_dicts[0]
    {'pad_width': array([[0, 3],
        [0, 0]]), 'mode': 'mean'}
        
    Similarly, the following call generates a list of dictionaries preparing 
    absolute padding by prepending zeros with :func:`numpy.pad`
    
    >>> pad_dicts = spy.padding(adata, 'zero', pad='absolute', padlength=10, prepadlength=True, create_new=False)
    >>> pad_dicts[0]
    {'pad_width': array([[5, 0],
        [0, 0]]), 'mode': 'constant', 'constant_values': 0}
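
    As a final, hedged sketch (the output shown is what the parameter checks
    imply for `adata` as constructed above, whose trials all have identical
    length): with ``pad = 'maxlen'`` every trial is already as long as the
    longest one, so the prepared :func:`numpy.pad` parameters request no
    padding at all

    >>> pad_dicts = spy.padding(adata, 'edge', pad='maxlen', create_new=False)
    >>> all(np.all(pd['pad_width'] == 0) for pd in pad_dicts)
    True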
            
    See also
    --------
    numpy.pad : fast array padding in NumPy
    """

    # Detect whether input is data object or array-like
    if any(["BaseData" in str(base) for base in data.__class__.__mro__]):
        try:
            data_parser(data,
                        varname="data",
                        dataclass="AnalogData",
                        empty=False)
        except Exception as exc:
            raise exc
        timeAxis = data.dimord.index("time")
        spydata = True
    elif data.__class__.__name__ == "FauxTrial":
        if len(data.shape) != 2:
            lgl = "two-dimensional AnalogData trial segment"
            act = "{}-dimensional trial segment"
            raise SPYValueError(legal=lgl,
                                varname="data",
                                actual=act.format(len(data.shape)))
        timeAxis = data.dimord.index("time")
        spydata = False
    else:
        try:
            array_parser(data, varname="data", dims=2)
        except Exception as exc:
            raise exc
        timeAxis = 0
        spydata = False

    # FIXME: Creation of new spy-object currently not supported
    if not isinstance(create_new, bool):
        raise SPYTypeError(create_new, varname="create_new", expected="bool")
    if spydata and create_new:
        raise NotImplementedError(
            "Creation of padded spy objects currently not supported. ")

    # Use FT-compatible options (sans FT option 'remove')
    if not isinstance(padtype, str):
        raise SPYTypeError(padtype, varname="padtype", expected="string")
    options = ["zero", "nan", "mean", "localmean", "edge", "mirror"]
    if padtype not in options:
        lgl = "'" + "or '".join(opt + "' " for opt in options)
        raise SPYValueError(legal=lgl, varname="padtype", actual=padtype)

    # Check `pad` and ensure we can actually perform the requested operation
    if not isinstance(pad, str):
        raise SPYTypeError(pad, varname="pad", expected="string")
    options = ["absolute", "relative", "maxlen", "nextpow2"]
    if pad not in options:
        lgl = "'" + "or '".join(opt + "' " for opt in options)
        raise SPYValueError(legal=lgl, varname="pad", actual=pad)
    if pad == "maxlen" and not spydata:
        lgl = "syncopy data object when using option 'maxlen'"
        raise SPYValueError(legal=lgl, varname="pad", actual="maxlen")

    # Make sure a data object was provided if we're working with time values
    if not isinstance(unit, str):
        raise SPYTypeError(unit, varname="unit", expected="string")
    options = ["samples", "time"]
    if unit not in options:
        lgl = "'" + "or '".join(opt + "' " for opt in options)
        raise SPYValueError(legal=lgl, varname="unit", actual=unit)
    if unit == "time" and not spydata:
        raise SPYValueError(
            legal="syncopy data object when using option 'time'",
            varname="unit",
            actual="time")

    # Set up dictionary for type-checking of provided padding lengths
    nt_dict = {"samples": "int_like", "time": None}

    # If we're padding up to an absolute bound or the max. length across
    # trials, compute lower bound for padding (in samples or seconds)
    if pad in ["absolute", "maxlen"]:
        if spydata:
            maxTrialLen = np.diff(data.sampleinfo).max()
        else:
            # if ``pad = 'absolute'`` and `data` is an array
            maxTrialLen = data.shape[timeAxis]
    else:
        maxTrialLen = np.inf
    if unit == "time":
        padlim = maxTrialLen / data.samplerate
    else:
        padlim = maxTrialLen

    # To ease option processing, collect padding length keywords in dict
    plengths = {
        "padlength": padlength,
        "prepadlength": prepadlength,
        "postpadlength": postpadlength
    }

    # In case of relative padding, we need at least one scalar value to proceed
    if pad == "relative":

        # If `padlength = None`, pre- or post- need to be set; if `padlength`
        # is set, both pre- and post- need to be `None` or `True`/`False`.
        # After this code block, pre- and post- are guaranteed to be numeric.
        if padlength is None:
            for key in ["prepadlength", "postpadlength"]:
                if plengths[key] is not None:
                    try:
                        scalar_parser(plengths[key],
                                      varname=key,
                                      ntype=nt_dict[unit],
                                      lims=[0, np.inf])
                    except Exception as exc:
                        raise exc
                else:
                    plengths[key] = 0
        else:
            try:
                scalar_parser(padlength,
                              varname="padlength",
                              ntype=nt_dict[unit],
                              lims=[0, np.inf])
            except Exception as exc:
                raise exc
            for key in ["prepadlength", "postpadlength"]:
                if not isinstance(plengths[key], (bool, type(None))):
                    raise SPYTypeError(plengths[key],
                                       varname=key,
                                       expected="bool or None")

            if prepadlength is None and postpadlength is None:
                prepadlength = True
                postpadlength = True
            else:
                prepadlength = bool(prepadlength)
                postpadlength = bool(postpadlength)

            if prepadlength and postpadlength:
                plengths["prepadlength"] = padlength / 2
                plengths["postpadlength"] = padlength / 2
            else:
                plengths["prepadlength"] = prepadlength * padlength
                plengths["postpadlength"] = postpadlength * padlength

        # Under-determined: abort if requested padding length is 0
        if all(value == 0 for value in plengths.values() if value is not None):
            lgl = "either non-zero value of `padlength` or `prepadlength` " + \
                  "and/or `postpadlength` to be set"
            raise SPYValueError(legal=lgl,
                                varname="padlength",
                                actual="0|None")

    else:

        # For absolute padding, the desired length has to be >= max. trial length
        if pad == "absolute":
            try:
                scalar_parser(padlength,
                              varname="padlength",
                              ntype=nt_dict[unit],
                              lims=[padlim, np.inf])
            except Exception as exc:
                raise exc
            for key in ["prepadlength", "postpadlength"]:
                if not isinstance(plengths[key], (bool, type(None))):
                    raise SPYTypeError(plengths[key],
                                       varname=key,
                                       expected="bool or None")

        # For `maxlen` or `nextpow2` we don't want any numeric entries at all
        else:
            for key, value in plengths.items():
                if not isinstance(value, (bool, type(None))):
                    raise SPYTypeError(value,
                                       varname=key,
                                       expected="bool or None")

            # Warn of potential conflicts
            if padlength and (prepadlength or postpadlength):
                msg = "Found `padlength` and `prepadlength` and/or " +\
                    "`postpadlength`. Symmetric padding is performed. "
                SPYWarning(msg)

        # If both pre-/post- are `None`, set them to `True` to use symmetric
        # padding, otherwise convert `None` entries to `False`
        if prepadlength is None and postpadlength is None:
            plengths["prepadlength"] = True
            plengths["postpadlength"] = True
        else:
            plengths["prepadlength"] = plengths["prepadlength"] is not None
            plengths["postpadlength"] = plengths["postpadlength"] is not None

    # Update pre-/post-padding and (if required) convert time to samples
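    # (illustrative: with `samplerate = 1000` Hz, a relative `prepadlength`
    # of 0.5 s becomes 500 samples)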
    prepadlength = plengths["prepadlength"]
    postpadlength = plengths["postpadlength"]
    if unit == "time":
        if pad == "relative":
            prepadlength = int(prepadlength * data.samplerate)
            postpadlength = int(postpadlength * data.samplerate)
        elif pad == "absolute":
            padlength = int(padlength * data.samplerate)

    # Construct dict of keywords for ``np.pad`` depending on chosen `padtype`
    kws = {
        "zero": {
            "mode": "constant",
            "constant_values": 0
        },
        "nan": {
            "mode": "constant",
            "constant_values": np.nan
        },
        "localmean": {
            "mode": "mean",
            "stat_length": -1
        },
        "mean": {
            "mode": "mean"
        },
        "edge": {
            "mode": "edge"
        },
        "mirror": {
            "mode": "reflect"
        }
    }

    # If input was a syncopy data object, padding is done on a per-trial basis
    if spydata:

        # A list of input keywords for ``np.pad`` is constructed, no matter if
        # we actually want to build a new object or not
        pad_opts = []
        for trl in data.trials:
            nSamples = trl.shape[timeAxis]
            if pad == "absolute":
                padding = (padlength - nSamples) / (prepadlength +
                                                    postpadlength)
            elif pad == "relative":
                padding = True
            elif pad == "maxlen":
                padding = (maxTrialLen - nSamples) / (prepadlength +
                                                      postpadlength)
            elif pad == "nextpow2":
                padding = (_nextpow2(nSamples) - nSamples) / (prepadlength +
                                                              postpadlength)
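            # Assemble the 2x2 `pad_width` array for ``np.pad``; e.g., with
            # `timeAxis = 0` and 5 samples of symmetric padding this becomes
            # ``[[5, 5], [0, 0]]`` (illustrative values)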
            pw = np.zeros((2, 2), dtype=int)
            pw[timeAxis, :] = [prepadlength * padding, postpadlength * padding]
            pad_opts.append(dict({"pad_width": pw}, **kws[padtype]))
            if padtype == "localmean":
                pad_opts[-1]["stat_length"] = pw[timeAxis, :]

        if create_new:
            # unreachable for now: `spydata and create_new` is rejected above;
            # once creation of padded Syncopy objects is implemented, the new
            # object will be assembled here
            pass
        else:
            return pad_opts

    # Input was an array/FauxTrial (i.e., single trial) - we have to do the padding just once
    else:

        nSamples = data.shape[timeAxis]
        if pad == "absolute":
            padding = (padlength - nSamples) / (prepadlength + postpadlength)
        elif pad == "relative":
            padding = True
        elif pad == "nextpow2":
            padding = (_nextpow2(nSamples) - nSamples) / (prepadlength +
                                                          postpadlength)
        pw = np.zeros((2, 2), dtype=int)
        pw[timeAxis, :] = [prepadlength * padding, postpadlength * padding]
        pad_opts = dict({"pad_width": pw}, **kws[padtype])
        if padtype == "localmean":
            pad_opts["stat_length"] = pw[timeAxis, :]

        if create_new:
            if isinstance(data, np.ndarray):
                return np.pad(data, **pad_opts)
            else:  # FIXME: currently only supports FauxTrial
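                # FauxTrial objects only carry shape/index metadata, so "padding"
                # here means growing the time axis of `shp` and stretching the
                # corresponding index: slices are extended in place, explicit
                # index lists are padded by repeating their first/last entries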
                shp = list(data.shape)
                shp[timeAxis] += pw[timeAxis, :].sum()
                idx = list(data.idx)
                if isinstance(idx[timeAxis], slice):
                    idx[timeAxis] = slice(idx[timeAxis].start,
                                          idx[timeAxis].start + shp[timeAxis])
                else:
                    idx[timeAxis] = pw[timeAxis, 0] * [idx[timeAxis][0]] + idx[timeAxis] \
                                    + pw[timeAxis, 1] * [idx[timeAxis][-1]]
                return data.__class__(shp, idx, data.dtype, data.dimord)
        else:
            return pad_opts
Example #15
0
def multipanelplot(self,
                   trials="all",
                   channels="all",
                   tapers="all",
                   toilim=None,
                   foilim=None,
                   avg_channels=False,
                   avg_tapers=True,
                   avg_trials=True,
                   panels="channels",
                   interp="spline36",
                   cmap="plasma",
                   vmin=None,
                   vmax=None,
                   title=None,
                   grid=None,
                   fig=None,
                   **kwargs):
    """
    Plot contents of :class:`~syncopy.SpectralData` objects using multi-panel figure(s)
    
    Please refer to :func:`syncopy.multipanelplot` for detailed usage information. 
    
    Examples
    --------
    Use 16 panels to show frequency range 30-80 Hz of first 16 channels in `freqData` 
    averaged across trials 2, 4, and 6:
    
    >>> fig = spy.multipanelplot(freqData, trials=[2, 4, 6], channels=range(16),
                                 foilim=[30, 80], panels="channels")
                                  
    Same settings, but each panel represents a trial:
    
    >>> fig = spy.multipanelplot(freqData, trials=[2, 4, 6], channels=range(16),
                                 foilim=[30, 80], panels="trials", avg_trials=False, 
                                 avg_channels=True)
    
    Plot time-frequency contents of channels `'ecog_mua1'` and `'ecog_mua2'` of 
    `tfData` 
        
    >>> fig = spy.multipanelplot(tfData, channels=['ecog_mua1', 'ecog_mua2'])
    
    Note that multi-panel overlay plotting is **not** supported for 
    :class:`~syncopy.SpectralData` objects.
    
    See also
    --------
    syncopy.multipanelplot : visualize Syncopy data objects using multi-panel plots
    """

    # Collect input arguments in dict `inputArgs` and process them
    inputArgs = locals()
    inputArgs.pop("self")
    (dimArrs, dimCounts, isTimeFrequency, complexConversion, pltDtype,
     dataLbl) = _prep_spectral_plots(self, "multipanelplot", **inputArgs)
    (nTrials, nChan, nFreq, nTap) = dimCounts
    (trList, chArr, freqArr, tpArr) = dimArrs

    # No overlaying here...
    if hasattr(fig, "objCount"):
        msg = "Overlays of multi-panel `SpectralData` plots not supported"
        raise SPYError(msg)

    # Ensure panel-specification makes sense and is compatible w/averaging selection
    if not isinstance(panels, str):
        raise SPYTypeError(panels, varname="panels", expected="str")
    if panels not in availablePanels:
        lgl = "'" + "or '".join(opt + "' " for opt in availablePanels)
        raise SPYValueError(legal=lgl, varname="panels", actual=panels)
    if (panels == "channels" and avg_channels) or (panels == "trials" and avg_trials) \
        or (panels == "tapers" and avg_tapers):
        msg = "Cannot use `panels = {}` and average across {} at the same time. "
        SPYWarning(msg.format(panels, panels))
        return

    # Ensure the proper amount of averaging was specified
    avgFlags = [avg_channels, avg_trials, avg_tapers]
    if sum(avgFlags) == 0 and nTap * nTrials > 1:
        msg = "Need to average across at least one of tapers, channels or trials " +\
            "for visualization. "
        SPYWarning(msg)
        return
    if sum(avgFlags) == 3:
        msg = "Averaging across trials, channels and tapers results in " +\
            "single-panel plot. Please use `singlepanelplot` instead"
        SPYWarning(msg)
        return
    if isTimeFrequency:
        if sum(avgFlags) != 2:
            msg = "Multi-panel time-frequency visualization requires averaging across " +\
                "two out of three dimensions (tapers, channels trials)"
            SPYWarning(msg)
            return

    # Prepare figure (same for all cases)
    if panels == "channels":
        npanels = nChan
    elif panels == "trials":
        npanels = nTrials
    else:  # ``panels == "tapers"``
        npanels = nTap

    # Construct subplot panel layout or vet provided layout
    nrow = kwargs.get("nrow", None)
    ncol = kwargs.get("ncol", None)
    if not isTimeFrequency:
        fig, ax_arr = _setup_figure(npanels,
                                    nrow=nrow,
                                    ncol=ncol,
                                    xLabel="Frequency [Hz]",
                                    yLabel=dataLbl,
                                    grid=grid,
                                    include_colorbar=False,
                                    sharex=True,
                                    sharey=True)
    else:
        fig, ax_arr, cax = _setup_figure(npanels,
                                         nrow=nrow,
                                         ncol=ncol,
                                         xLabel="Time [s]",
                                         yLabel="Frequency [Hz]",
                                         grid=grid,
                                         include_colorbar=True,
                                         sharex=True,
                                         sharey=True)

    # Monkey-patch object-counter to newly created figure
    fig.spectralPlot = True

    # Start with the "simple" case: "regular" spectra, no time involved
    if not isTimeFrequency:

        # We're not dealing w/TF data here
        nTime = 1
        N = 1

        # For each panel stratification, set corresponding positional and
        # keyword args for iteratively calling `_compute_pltArr`
        if panels == "channels":

            panelVar = "channel"
            panelValues = chArr
            panelTitles = chArr

            if not avg_trials and avg_tapers:
                avgDim1 = "taper"
                avgDim2 = None
                innerVar = "trial"
                innerValues = trList
                majorTitle = "{} trials averaged across {} tapers".format(
                    nTrials, nTap)
                showLegend = True
            elif avg_trials and not avg_tapers:
                avgDim1 = None
                avgDim2 = None
                innerVar = "taper"
                innerValues = tpArr
                majorTitle = "{} tapers averaged across {} trials".format(
                    nTap, nTrials)
                showLegend = True
            else:  # `avg_trials` and `avg_tapers`
                avgDim1 = "taper"
                avgDim2 = None
                innerVar = "trial"
                innerValues = ["all"]
                majorTitle = " Average of {} tapers and {} trials".format(
                    nTap, nTrials)
                showLegend = False

        elif panels == "trials":

            panelVar = "trial"
            panelValues = trList
            panelTitles = ["Trial #{}".format(trlno) for trlno in trList]

            if not avg_channels and avg_tapers:
                avgDim1 = "taper"
                avgDim2 = None
                innerVar = "channel"
                innerValues = chArr
                majorTitle = "{} channels averaged across {} tapers".format(
                    nChan, nTap)
                showLegend = True
            elif avg_channels and not avg_tapers:
                avgDim1 = "channel"
                avgDim2 = None
                innerVar = "taper"
                innerValues = tpArr
                majorTitle = "{} tapers averaged across {} channels".format(
                    nTap, nChan)
                showLegend = True
            else:  # `avg_channels` and `avg_tapers`
                avgDim1 = "taper"
                avgDim2 = "channel"
                innerVar = "trial"
                innerValues = ["all"]
                majorTitle = " Average of {} channels and {} tapers".format(
                    nChan, nTap)
                showLegend = False

        else:  # panels = "tapers"

            panelVar = "taper"
            panelValues = tpArr
            panelTitles = ["Taper #{}".format(tpno) for tpno in tpArr]

            if not avg_trials and avg_channels:
                avgDim1 = "channel"
                avgDim2 = None
                innerVar = "trial"
                innerValues = trList
                majorTitle = "{} trials averaged across {} channels".format(
                    nTrials, nChan)
                showLegend = True
            elif avg_trials and not avg_channels:
                avgDim1 = None
                avgDim2 = None
                innerVar = "channel"
                innerValues = chArr
                majorTitle = "{} channels averaged across {} trials".format(
                    nChan, nTrials)
                showLegend = True
            else:  # `avg_trials` and `avg_channels`
                avgDim1 = "channel"
                avgDim2 = None
                innerVar = "trial"
                innerValues = ["all"]
                majorTitle = " Average of {} channels and {} trials".format(
                    nChan, nTrials)
                showLegend = False

        # Loop over panels, within each panel, loop over `innerValues` to (potentially)
        # plot multiple spectra per panel
        kwargs = {"avg1": avgDim1, "avg2": avgDim2}
        for panelCount, panelVal in enumerate(panelValues):
            kwargs[panelVar] = panelVal
            for innerVal in innerValues:
                kwargs[innerVar] = innerVal
                pltArr = _compute_pltArr(self, nFreq, N, nTime,
                                         complexConversion, pltDtype, **kwargs)
                ax_arr[panelCount].plot(freqArr,
                                        np.log10(pltArr),
                                        label=innerVar.capitalize() + " " +
                                        str(innerVal))
            ax_arr[panelCount].set_title(panelTitles[panelCount],
                                         size=pltConfig["multiTitleSize"])
        if showLegend:
            handles, labels = ax_arr[0].get_legend_handles_labels()
            ax_arr[0].legend(handles, labels)
        if title is None:
            fig.suptitle(majorTitle, size=pltConfig["singleTitleSize"])

    # Now, multi-panel time-frequency visualizations
    else:

        # Compute (and verify) length of selected time intervals
        tLengths = _prep_toilim_avg(self)
        nTime = tLengths[0]
        time = self.time[trList[0]][self._selection.time[0]]
        N = 1

        if panels == "channels":
            panelVar = "channel"
            panelValues = chArr
            panelTitles = chArr
            majorTitle = " Average of {} tapers and {} trials".format(
                nTap, nTrials)
            avgDim1 = "taper"
            avgDim2 = None

        elif panels == "trials":
            panelVar = "trial"
            panelValues = trList
            panelTitles = ["Trial #{}".format(trlno) for trlno in trList]
            majorTitle = " Average of {} channels and {} tapers".format(
                nChan, nTap)
            avgDim1 = "taper"
            avgDim2 = "channel"

        else:  # panels = "tapers"
            panelVar = "taper"
            panelValues = tpArr
            panelTitles = ["Taper #{}".format(tpno) for tpno in tpArr]
            majorTitle = " Average of {} channels and {} trials".format(
                nChan, nTrials)
            avgDim1 = "channel"
            avgDim2 = None

        # Loop over panels, within each panel, loop over `innerValues` to (potentially)
        # plot multiple spectra per panel
        kwargs = {"avg1": avgDim1, "avg2": avgDim2}
        vmins = []
        vmaxs = []
        for panelCount, panelVal in enumerate(panelValues):
            kwargs[panelVar] = panelVal
            pltArr = _compute_pltArr(self, nFreq, N, nTime, complexConversion,
                                     pltDtype, **kwargs)
            vmins.append(pltArr.min())
            vmaxs.append(pltArr.max())
            ax_arr[panelCount].imshow(pltArr,
                                      origin="lower",
                                      interpolation=interp,
                                      cmap=cmap,
                                      extent=(time[0], time[-1], freqArr[0],
                                              freqArr[-1]),
                                      aspect="auto")
            ax_arr[panelCount].set_title(panelTitles[panelCount],
                                         size=pltConfig["multiTitleSize"])

        # Render colorbar
        if vmin is None:
            vmin = min(vmins)
        if vmax is None:
            vmax = max(vmaxs)
        cbar = _setup_colorbar(fig,
                               ax_arr,
                               cax,
                               label=dataLbl.replace(" [dB]", ""),
                               outline=False,
                               vmin=vmin,
                               vmax=vmax)
        if title is None:
            fig.suptitle(majorTitle, size=pltConfig["singleTitleSize"])

    # Increment overlay-counter and draw figure
    fig.objCount += 1
    plt.draw()
    self._selection = None
    return fig
Example #16
0
def freqanalysis(data,
                 method='mtmfft',
                 output='fourier',
                 keeptrials=True,
                 foi=None,
                 foilim=None,
                 pad=None,
                 padtype='zero',
                 padlength=None,
                 prepadlength=None,
                 postpadlength=None,
                 polyremoval=None,
                 taper="hann",
                 tapsmofrq=None,
                 keeptapers=False,
                 toi=None,
                 t_ftimwin=None,
                 wav="Morlet",
                 width=6,
                 order=None,
                 out=None,
                 **kwargs):
    """
    Perform (time-)frequency analysis of Syncopy :class:`~syncopy.AnalogData` objects
    
    **Usage Summary**
    
    Options available in all analysis methods:
    
    * **output** : one of :data:`~.availableOutputs`; return power spectra, complex 
      Fourier spectra or absolute values. 
    * **foi**/**foilim** : frequencies of interest; either array of frequencies or 
      frequency window (not both)
    * **keeptrials** : return individual trials or grand average
    * **polyremoval** : de-trending method to use (0 = mean, 1 = linear, 2 = quadratic, 
      3 = cubic, etc.)
            
    List of available analysis methods and respective distinct options:
    
    :func:`~syncopy.specest.mtmfft.mtmfft` : (Multi-)tapered Fourier transform
        Perform frequency analysis on time-series trial data using either a single 
        taper window (Hanning) or many tapers based on the discrete prolate 
        spheroidal sequence (DPSS) that maximize energy concentration in the main
        lobe. 
        
        * **taper** : one of :data:`~.availableTapers`
        * **tapsmofrq** : spectral smoothing box for tapers (in Hz)
        * **keeptapers** : return individual tapers or average
        * **pad** : padding method to use (`None`, `True`, `False`, `'absolute'`, 
          `'relative'`, `'maxlen'` or `'nextpow2'`). If `None`, then `'nextpow2'`
          is selected by default. 
        * **padtype** : values to pad data with (`'zero'`, `'nan'`, `'mean'`, `'localmean'`, 
          `'edge'` or `'mirror'`)
        * **padlength** : number of samples to pre-pend and/or append to each trial 
        * **prepadlength** : number of samples to pre-pend to each trial 
        * **postpadlength** : number of samples to append to each trial 

    :func:`~syncopy.specest.mtmconvol.mtmconvol` : (Multi-)tapered sliding window Fourier transform
        Perform time-frequency analysis on time-series trial data based on a sliding 
        window short-time Fourier transform using either a single Hanning taper or 
        multiple DPSS tapers. 
        
        * **taper** : one of :data:`~.availableTapers`
        * **tapsmofrq** : spectral smoothing box for tapers (in Hz)
        * **keeptapers** : return individual tapers or average
        * **pad** : flag indicating, whether or not to pad trials. If `None`, 
          trials are padded only if sliding window centroids are too close
          to trial boundaries for the entire window to cover available data-points. 
        * **toi** : time-points of interest; can be either an array representing 
          analysis window centroids (in sec), a scalar between 0 and 1 encoding 
          the percentage of overlap between adjacent windows or "all" to center 
          a window on every sample in the data. 
        * **t_ftimwin** : sliding window length (in sec)

    :func:`~syncopy.specest.wavelet.wavelet` : (Continuous non-orthogonal) wavelet transform
        Perform time-frequency analysis on time-series trial data using a non-orthogonal
        continuous wavelet transform. 
        
        * **wav** : one of :data:`~.availableWavelets`
        * **toi** : time-points of interest; can be either an array representing 
          time points (in sec) to center wavelets on or "all" to center a wavelet 
          on every sample in the data. 
        * **width** : Nondimensional frequency constant of Morlet wavelet function (>= 6)
        * **order** : Order of Paul wavelet function (>= 4) or derivative order
          of real-valued DOG wavelets (2 = Mexican hat)

    **Full documentation below** 
    
    Parameters
    ----------
    data : `~syncopy.AnalogData`
        A non-empty Syncopy :class:`~syncopy.datatype.AnalogData` object
    method : str
        Spectral estimation method, one of :data:`~.availableMethods` 
        (see below).
    output : str
        Output of spectral estimation. One of :data:`~.availableOutputs` (see below); 
        use `'pow'` for power spectrum (:obj:`numpy.float32`), `'fourier'` for complex 
        Fourier coefficients (:obj:`numpy.complex128`) or `'abs'` for absolute 
        values (:obj:`numpy.float32`).
    keeptrials : bool
        If `True` spectral estimates of individual trials are returned, otherwise
        results are averaged across trials. 
    foi : array-like or None
        Frequencies of interest (Hz) for output. If desired frequencies cannot be 
        matched exactly, the closest possible frequencies are used. If `foi` is `None`
        or ``foi = "all"``, all attainable frequencies (i.e., zero to Nyquist / 2) 
        are selected. 
    foilim : array-like (floats [fmin, fmax]) or None or "all"
        Frequency-window ``[fmin, fmax]`` (in Hz) of interest. Window 
        specifications must be sorted (e.g., ``[90, 70]`` is invalid) and not NaN 
        but may be unbounded (e.g., ``[-np.inf, 60.5]`` is valid). Edges `fmin` 
        and `fmax` are included in the selection. If `foilim` is `None` or 
        ``foilim = "all"``, all frequencies are selected. 
    pad : str or None or bool
        One of `None`, `True`, `False`, `'absolute'`, `'relative'`, `'maxlen'` or
        `'nextpow2'`. 
        If `pad` is `None` or ``pad = True``, then method-specific defaults are 
        chosen. Specifically, if `method` is `'mtmfft'` then `pad` is set to 
        `'nextpow2'` so that all trials in `data` are padded to the next power of 
        two higher than the sample-count of the longest (selected) trial in `data`. Conversely, 
        time-frequency analysis methods (`'mtmconvol'` and `'wavelet'`), only perform
        padding if necessary, i.e., if time-window centroids are chosen too close
        to trial boundaries for the entire window to cover available data-points. 
        If `pad` is `False`, no padding is performed. In that case, for 
        ``method = 'mtmfft'`` all trials have to be of approximately the same 
        length (up to the next even sample-count), while for ``method = 'mtmconvol'`` or 
        ``method = 'wavelet'``, window-centroids have to keep sufficient
        distance from trial boundaries. For more details on the padding methods 
        `'absolute'`, `'relative'`, `'maxlen'` and `'nextpow2'` see :func:`syncopy.padding`. 
    padtype : str
        Values to be used for padding. Can be `'zero'`, `'nan'`, `'mean'`, 
        `'localmean'`, `'edge'` or `'mirror'`. See :func:`syncopy.padding` for 
        more information.
    padlength : None, bool or positive int
        Only valid if `method` is `'mtmfft'` and `pad` is `'absolute'` or `'relative'`. 
        Number of samples to pad data with. See :func:`syncopy.padding` for more 
        information.
    prepadlength : None or bool or int
        Only valid if `method` is `'mtmfft'` and `pad` is `'relative'`. Number of 
        samples to pre-pend to each trial. See :func:`syncopy.padding` for more 
        information.
    postpadlength : None or bool or int
        Only valid if `method` is `'mtmfft'` and `pad` is `'relative'`. Number of 
        samples to append to each trial. See :func:`syncopy.padding` for more 
        information.
    polyremoval : int or None
        **FIXME: Not implemented yet**
        Order of polynomial used for de-trending data in the time domain prior 
        to spectral analysis. A value of 0 corresponds to subtracting the mean 
        ("de-meaning"), ``polyremoval = 1`` removes linear trends (subtracting the 
        least squares fit of a linear polynomial), ``polyremoval = N`` for `N > 1` 
        subtracts a polynomial of order `N` (``N = 2`` quadratic, ``N = 3`` cubic 
        etc.). If `polyremoval` is `None`, no de-trending is performed. 
    taper : str
        Only valid if `method` is `'mtmfft'` or `'mtmconvol'`. Windowing function, 
        one of :data:`~.availableTapers` (see below).
    tapsmofrq : float
        Only valid if `method` is `'mtmfft'` or `'mtmconvol'`. The amount of spectral 
        smoothing through multi-tapering (Hz). Note that smoothing frequency 
        specifications are one-sided, i.e., 4 Hz smoothing means plus-minus 4 Hz, 
        i.e., an 8 Hz smoothing box.
    keeptapers : bool
        Only valid if `method` is `'mtmfft'` or `'mtmconvol'`. If `True`, return 
        spectral estimates for each taper, otherwise results are averaged across
        tapers. 
    toi : float or array-like or "all"
        **Mandatory input** for time-frequency analysis methods (`method` is either 
        `"mtmconvol"` or `"wavelet"`). 
        If `toi` is scalar, it must be a value between 0 and 1 indicating the 
        percentage of overlap between time-windows specified by `t_ftimwin` (only
        valid if `method` is `'mtmconvol'`, invalid for `'wavelet'`). 
        If `toi` is an array it explicitly selects the centroids of analysis 
        windows (in seconds). If `toi` is `"all"`, analysis windows are centered
        on all samples in the data. 
    t_ftimwin : positive float
        Only valid if `method` is `'mtmconvol'`. Sliding window length (in seconds). 
    wav : str
        Only valid if `method` is `'wavelet'`. Wavelet function to use, one of 
        :data:`~.availableWavelets` (see below).
    width : positive float
        Only valid if `method` is `'wavelet'` and `wav` is `'Morlet'`. Nondimensional 
        frequency constant of Morlet wavelet function. This number should be >= 6, 
        which corresponds to 6 cycles within the analysis window to ensure sufficient 
        spectral sampling. 
    order : positive int
        Only valid if `method` is `'wavelet'` and `wav` is `'Paul'` or `'DOG'`. Order 
        of the wavelet function. If `wav` is `'Paul'`, `order` should be chosen
        >= 4 to ensure that the analysis window contains at least a single oscillation. 
        At an order of 40, the Paul wavelet exhibits about the same number of cycles 
        as the Morlet wavelet with a `width` of 6. 
        All other supported wavelet functions are *real-valued* derivatives of 
        Gaussians (DOGs). Hence, if `wav` is `'DOG'`, `order` represents the derivative order. 
        The special case of a second order DOG yields a function known as "Mexican Hat", 
        "Marr" or "Ricker" wavelet, which can be selected alternatively by setting
        `wav` to `'Mexican_hat'`, `'Marr'` or `'Ricker'`. **Note**: A real-valued
        wavelet function encodes *only* information about peaks and discontinuities 
        in the signal and does *not* provide any information about amplitude or phase. 
    out : None or :class:`SpectralData` object
        Either `None`, in which case a new :class:`SpectralData` object is created, 
        or an empty :class:`SpectralData` object to hold the results
        

    Returns
    -------
    spec : :class:`~syncopy.SpectralData`
        (Time-)frequency spectrum of input data
        
    Notes
    -----
    Coming soon...
    
    Examples
    --------
    Detailed examples coming soon. Until then, a minimal sketch (assuming
    `adata` is a suitably populated :class:`~syncopy.AnalogData` object;
    keyword choices follow the parameter descriptions above):
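
    >>> # single-taper power spectrum, averaged across trials
    >>> spec = spy.freqanalysis(adata, method='mtmfft', output='pow',
    ...                         taper='hann', keeptrials=False)
    >>> # Morlet wavelet time-frequency estimate centered on every sample
    >>> tfspec = spy.freqanalysis(adata, method='wavelet', wav='Morlet',
    ...                           toi='all', output='pow')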
        

    .. autodata:: syncopy.specest.freqanalysis.availableMethods

    .. autodata:: syncopy.specest.freqanalysis.availableOutputs

    .. autodata:: syncopy.specest.freqanalysis.availableTapers

    .. autodata:: syncopy.specest.freqanalysis.availableWavelets
    
    See also
    --------
    syncopy.specest.mtmfft.mtmfft : (multi-)tapered Fourier transform of multi-channel time series data
    syncopy.specest.mtmconvol.mtmconvol : time-frequency analysis of multi-channel time series data with a sliding window FFT
    syncopy.specest.wavelet.wavelet : time-frequency analysis of multi-channel time series data using a wavelet transform
    numpy.fft.fft : NumPy's reference FFT implementation
    scipy.signal.stft : SciPy's Short Time Fourier Transform
    """

    # Make sure our one mandatory input object can be processed
    try:
        data_parser(data,
                    varname="data",
                    dataclass="AnalogData",
                    writable=None,
                    empty=False)
    except Exception as exc:
        raise exc
    timeAxis = data.dimord.index("time")

    # Get everything of interest in local namespace
    defaults = get_defaults(freqanalysis)
    lcls = locals()

    # Ensure a valid computational method was selected
    if method not in availableMethods:
        lgl = "'" + "or '".join(opt + "' " for opt in availableMethods)
        raise SPYValueError(legal=lgl, varname="method", actual=method)

    # Ensure a valid output format was selected
    if output not in spectralConversions.keys():
        lgl = "'" + "or '".join(opt + "' "
                                for opt in spectralConversions.keys())
        raise SPYValueError(legal=lgl, varname="output", actual=output)

    # Parse all Boolean keyword arguments
    for vname in ["keeptrials", "keeptapers"]:
        if not isinstance(lcls[vname], bool):
            raise SPYTypeError(lcls[vname], varname=vname, expected="Bool")

    # If only a subset of `data` is to be processed, make some necessary adjustments
    # and compute minimal sample-count across (selected) trials
    if data._selection is not None:
        trialList = data._selection.trials
        sinfo = np.zeros((len(trialList), 2))
        for tk, trlno in enumerate(trialList):
            trl = data._preview_trial(trlno)
            tsel = trl.idx[timeAxis]
            if isinstance(tsel, list):
                sinfo[tk, :] = [0, len(tsel)]
            else:
                sinfo[tk, :] = [
                    trl.idx[timeAxis].start, trl.idx[timeAxis].stop
                ]
    else:
        trialList = list(range(len(data.trials)))
        sinfo = data.sampleinfo
    lenTrials = np.diff(sinfo).squeeze()
    numTrials = len(trialList)

    # Set default padding options: after this, `pad` is either `None`, `False` or `str`
    defaultPadding = {"mtmfft": "nextpow2", "mtmconvol": None, "wavelet": None}
    if pad is None or pad is True:
        pad = defaultPadding[method]

    # Sliding window FFT does not support "fancy" padding
    if method == "mtmconvol" and isinstance(pad, str):
        msg = "method 'mtmconvol' only supports in-place padding for windows " +\
            "exceeding trial boundaries. Your choice of `pad = '{}'` will be ignored. "
        SPYWarning(msg.format(pad))
        pad = None

    # Ensure padding selection makes sense: do not pad on a by-trial basis but
    # use the longest trial as reference and compute `padlength` from there
    # (only relevant for "global" padding options such as `maxlen` or `nextpow2`)
    if pad:
        if not isinstance(pad, str):
            raise SPYTypeError(pad, varname="pad", expected="str or None")
        if pad == "maxlen":
            padlength = lenTrials.max()
            prepadlength = True
            postpadlength = False
        elif pad == "nextpow2":
            padlength = 0
            for ltrl in lenTrials:
                padlength = max(padlength, _nextpow2(ltrl))
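            # e.g., trial lengths of 300 and 500 samples both map to 512, so
            # `padlength` ends up at 512 (assuming `_nextpow2` returns the
            # next power of two >= its argument)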
            pad = "absolute"
            prepadlength = True
            postpadlength = False
        padding(data._preview_trial(trialList[0]),
                padtype,
                pad=pad,
                padlength=padlength,
                prepadlength=prepadlength,
                postpadlength=postpadlength)

        # Compute `minSampleNum` accounting for padding
        minSamplePos = lenTrials.argmin()
        minSampleNum = padding(data._preview_trial(trialList[minSamplePos]),
                               padtype,
                               pad=pad,
                               padlength=padlength,
                               prepadlength=True).shape[timeAxis]
    else:
        if method == "mtmfft" and np.unique(
            (np.floor(lenTrials / 2))).size > 1:
            lgl = "trials of approximately equal length for method 'mtmfft'"
            act = "trials of unequal length"
            raise SPYValueError(legal=lgl, varname="data", actual=act)
        minSampleNum = lenTrials.min()

    # Compute length (in samples) of shortest trial
    minTrialLength = minSampleNum / data.samplerate

    # Basic sanitization of frequency specifications
    if foi is not None:
        if isinstance(foi, str):
            if foi == "all":
                foi = None
            else:
                raise SPYValueError(legal="'all' or `None` or list/array",
                                    varname="foi",
                                    actual=foi)
        else:
            try:
                array_parser(foi,
                             varname="foi",
                             hasinf=False,
                             hasnan=False,
                             lims=[0, data.samplerate / 2],
                             dims=(None, ))
            except Exception as exc:
                raise exc
            foi = np.array(foi, dtype="float")
    if foilim is not None:
        if isinstance(foilim, str):
            if foilim == "all":
                foilim = None
            else:
                raise SPYValueError(legal="'all' or `None` or `[fmin, fmax]`",
                                    varname="foilim",
                                    actual=foilim)
        else:
            try:
                array_parser(foilim,
                             varname="foilim",
                             hasinf=False,
                             hasnan=False,
                             lims=[0, data.samplerate / 2],
                             dims=(2, ))
            except Exception as exc:
                raise exc
    if foi is not None and foilim is not None:
        lgl = "either `foi` or `foilim` specification"
        act = "both"
        raise SPYValueError(legal=lgl, varname="foi/foilim", actual=act)

    # FIXME: implement detrending
    # see also https://docs.obspy.org/_modules/obspy/signal/detrend.html#polynomial
    if polyremoval is not None:
        raise NotImplementedError("Detrending has not been implemented yet.")
        try:
            scalar_parser(polyremoval,
                          varname="polyremoval",
                          lims=[0, 8],
                          ntype="int_like")
        except Exception as exc:
            raise exc

    # Prepare keyword dict for logging (use `lcls` to get actually provided
    # keyword values, not defaults set above)
    log_dct = {
        "method": method,
        "output": output,
        "keeptapers": keeptapers,
        "keeptrials": keeptrials,
        "polyremoval": polyremoval,
        "pad": lcls["pad"],
        "padtype": lcls["padtype"],
        "padlength": lcls["padlength"],
        "foi": lcls["foi"]
    }

    # 1st: Check time-frequency inputs to prepare/sanitize `toi`
    if method in ["mtmconvol", "wavelet"]:

        # Get start/end timing info respecting potential in-place selection
        if toi is None:
            raise SPYTypeError(toi,
                               varname="toi",
                               expected="scalar or array-like or 'all'")
        if data._selection is not None:
            tStart = data._selection.trialdefinition[:, 2] / data.samplerate
        else:
            tStart = data._t0 / data.samplerate
        tEnd = tStart + lenTrials / data.samplerate

        # Process `toi`: we have to account for three scenarios: (1) center sliding
        # windows on all samples in (selected) trials (2) `toi` was provided as
        # percentage indicating the degree of overlap b/w time-windows and (3) a set
        # of discrete time points was provided. These three cases are encoded in
        # `overlap`, i.e., ``overlap > 1`` => all, ``0 < overlap < 1`` => percentage,
        # ``overlap < 0`` => discrete `toi`
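        # (illustrative: `toi = "all"` -> `overlap = 1.1`; `toi = 0.5` ->
        # 50% window overlap; `toi = [0.1, 0.2, 0.5]` -> `overlap = -1`)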
        if isinstance(toi, str):
            if toi != "all":
                lgl = "`toi = 'all'` to center analysis windows on all time-points"
                raise SPYValueError(legal=lgl, varname="toi", actual=toi)
            overlap = 1.1
            toi = None
            equidistant = True
        elif isinstance(toi, Number):
            if method == "wavelet":
                lgl = "array of time-points wavelets are to be centered on"
                act = "scalar value"
                raise SPYValueError(legal=lgl, varname="toi", actual=act)
            try:
                scalar_parser(toi, varname="toi", lims=[0, 1])
            except Exception as exc:
                raise exc
            overlap = toi
            equidistant = True
        else:
            overlap = -1
            try:
                array_parser(toi,
                             varname="toi",
                             hasinf=False,
                             hasnan=False,
                             lims=[tStart.min(), tEnd.max()],
                             dims=(None, ))
            except Exception as exc:
                raise exc
            toi = np.array(toi)
            tSteps = np.diff(toi)
            if (tSteps < 0).any():
                lgl = "ordered list/array of time-points"
                act = "unsorted list/array"
                raise SPYValueError(legal=lgl, varname="toi", actual=act)
            # This is imho a bug in NumPy - even `arange` and `linspace` may produce
            # arrays that are numerically not exactly equidistant - `unique` will
            # show several entries here - use `allclose` to identify "even" spacings
            equidistant = np.allclose(tSteps, [tSteps[0]] * tSteps.size)

        # If `toi` was 'all' or a percentage, use entire time interval of (selected)
        # trials and check if those trials have *approximately* equal length
        if toi is None:
            if not np.allclose(lenTrials, [minSampleNum] * lenTrials.size):
                msg = "processing trials of different lengths (min = {}; max = {} samples)" +\
                    " with `toi = 'all'`"
                SPYWarning(msg.format(int(minSampleNum), int(lenTrials.max())))
            if pad is False:
                lgl = "`pad` to be `None` or `True` to permit zero-padding " +\
                    "at trial boundaries to accommodate windows if `0 < toi < 1` " +\
                    "or if `toi` is 'all'"
                act = "False"
                raise SPYValueError(legal=lgl, actual=act, varname="pad")

        # Code recycling: `overlap`, `equidistant` etc. are really only relevant
        # for `mtmconvol`, but we use padding calc below for `wavelet` as well
        if method == "mtmconvol":
            try:
                scalar_parser(t_ftimwin,
                              varname="t_ftimwin",
                              lims=[1 / data.samplerate, minTrialLength])
            except Exception as exc:
                raise exc
        else:
            t_ftimwin = 0
        nperseg = int(t_ftimwin * data.samplerate)
        minSampleNum = nperseg
        halfWin = int(nperseg / 2)

        # `mtmconvol`: compute no. of samples overlapping across adjacent windows
        if overlap < 0:  # `toi` is equidistant range or disjoint points
            noverlap = nperseg - max(1, int(tSteps[0] * data.samplerate))
        elif 0 <= overlap <= 1:  # `toi` is percentage
            noverlap = min(nperseg - 1, int(overlap * nperseg))
        else:  # `toi` is "all"
            noverlap = nperseg - 1
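        # Illustration (hypothetical numbers): with `samplerate` = 1 kHz and
        # `t_ftimwin` = 0.5 s, `nperseg` = 500; a `toi` percentage of 0.5 then
        # yields `noverlap` = min(499, 250) = 250, a `toi` array with 2 ms steps
        # yields 500 - max(1, 2) = 498, and `toi = "all"` yields 499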

        # `toi` is array
        if overlap < 0:

            # Compute necessary padding at begin/end of trials to fit sliding windows
            offStart = ((toi[0] - tStart) * data.samplerate).astype(np.intp)
            padBegin = halfWin - offStart
            padBegin = ((padBegin > 0) * padBegin).astype(np.intp)

            offEnd = ((tEnd - toi[-1]) * data.samplerate).astype(np.intp)
            padEnd = halfWin - offEnd
            padEnd = ((padEnd > 0) * padEnd).astype(np.intp)

            # Abort if padding was explicitly forbidden
            if pad is False and (np.any(padBegin) or np.any(padEnd)):
                lgl = "windows within trial bounds"
                act = "windows exceeding trials no. " +\
                    "".join(str(trlno) + ", "\
                        for trlno in np.array(trialList)[(padBegin + padEnd) > 0])[:-2]
                raise SPYValueError(legal=lgl, varname="pad", actual=act)

            # Compute sample-indices (one slice/list per trial) from time-selections
            soi = []
            if not equidistant:
                for tk in range(numTrials):
                    starts = (data.samplerate * (toi - tStart[tk]) -
                              halfWin).astype(np.intp)
                    starts += padBegin[tk]
                    stops = (data.samplerate * (toi - tStart[tk]) + halfWin +
                             1).astype(np.intp)
                    stops += padBegin[tk]
                    stops = np.maximum(stops, stops - starts, dtype=np.intp)
                    soi.append([
                        slice(start, stop)
                        for start, stop in zip(starts, stops)
                    ])
            else:
                for tk in range(numTrials):
                    start = int(data.samplerate * (toi[0] - tStart[tk]) -
                                halfWin)
                    stop = int(data.samplerate * (toi[-1] - tStart[tk]) +
                               halfWin + 1)
                    soi.append(slice(max(0, start), max(stop, stop - start)))
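            # Illustration (hypothetical numbers): with `samplerate` = 1 kHz,
            # `halfWin` = 250, `tStart[tk]` = 0 and an equidistant
            # `toi` = [0.3, ..., 0.7], this yields start = int(300 - 250) = 50 and
            # stop = int(700 + 250 + 1) = 951, i.e., `slice(50, 951)` for trial `tk`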

        # `toi` is percentage or "all"
        else:

            padBegin = np.zeros((numTrials, ))
            padEnd = np.zeros((numTrials, ))
            soi = [slice(None)] * numTrials

        # For wavelets, we need to first trim the data (via `preSelect`), then
        # extract the wanted time-points (`postSelect`)
        if method == "wavelet":

            # Simply recycle the indexing work done for `mtmconvol` (i.e., `soi`)
            preSelect = []
            if not equidistant:
                for tk in range(numTrials):
                    preSelect.append(slice(soi[tk][0].start, soi[tk][-1].stop))
            else:
                preSelect = soi

            # If `toi` is an array, convert "global" indices to "local" ones
            # (select within `preSelect`'s selection), otherwise just take all
            if overlap < 0:
                postSelect = []
                for tk in range(numTrials):
                    smpIdx = np.minimum(
                        lenTrials[tk] - 1,
                        data.samplerate * (toi - tStart[tk]) - offStart[tk] +
                        padBegin[tk])
                    postSelect.append(smpIdx.astype(np.intp))
            else:
                postSelect = [slice(None)] * numTrials

        # Update `log_dct` w/method-specific options (use `lcls` to get actually
        # provided keyword values, not defaults set in here)
        if toi is None:
            toi = "all"
        log_dct["toi"] = lcls["toi"]

    # Check options specific to mtm*-methods (particularly tapers and foi/freqs alignment)
    if "mtm" in method:

        # See if taper choice is supported
        if taper not in availableTapers:
            lgl = "'" + "or '".join(opt + "' " for opt in availableTapers)
            raise SPYValueError(legal=lgl, varname="taper", actual=taper)
        taper = getattr(spwin, taper)

        # Advanced usage: see if `taperopt` was provided - if not, leave it empty
        taperopt = kwargs.get("taperopt", {})
        if not isinstance(taperopt, dict):
            raise SPYTypeError(taperopt,
                               varname="taperopt",
                               expected="dictionary")

        # Construct array of maximally attainable frequencies
        nFreq = int(np.floor(minSampleNum / 2) + 1)
        freqs = np.linspace(0, data.samplerate / 2, nFreq)

        # Match desired frequencies as close as possible to actually attainable freqs
        if foi is not None:
            foi, _ = best_match(freqs, foi, squash_duplicates=True)
        elif foilim is not None:
            foi, _ = best_match(freqs,
                                foilim,
                                span=True,
                                squash_duplicates=True)
        else:
            foi = freqs
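        # Illustration (hypothetical numbers): with `samplerate` = 1 kHz and
        # `minSampleNum` = 1000, `nFreq` = 501 and `freqs` = [0, 1, ..., 500] Hz;
        # a request of `foi` = [9.7, 20.4] is then (presumably via the
        # nearest-neighbor matching in `best_match`) mapped onto the attainable
        # bins [10., 20.]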

        # Abort if desired frequency selection is empty
        if foi.size == 0:
            lgl = "non-empty frequency specification"
            act = "empty frequency selection"
            raise SPYValueError(legal=lgl, varname="foi/foilim", actual=act)

        # Set/get `tapsmofrq` if we're working w/Slepian tapers
        if taper.__name__ == "dpss":

            # Try to derive "sane" settings by using 3/4 octave smoothing of highest `foi`
            # following Hipp et al. "Oscillatory Synchronization in Large-Scale
            # Cortical Networks Predicts Perception", Neuron, 2011
            if tapsmofrq is None:
                foimax = foi.max()
                tapsmofrq = (foimax * 2**(3 / 4 / 2) -
                             foimax * 2**(-3 / 4 / 2)) / 2
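                # Illustration: for `foimax` = 100 Hz this evaluates to
                # (100 * 2**0.375 - 100 * 2**(-0.375)) / 2 ≈ 26.3 Hz of smoothing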
            else:
                try:
                    scalar_parser(tapsmofrq,
                                  varname="tapsmofrq",
                                  lims=[1, np.inf])
                except Exception as exc:
                    raise exc

            # Get/compute number of tapers to use (user-provided via `taperopt`;
            # if `taperopt` is empty, compute a default clamped to [2, 50])
            nTaper = taperopt.get("Kmax", 1)
            if not taperopt:
                nTaper = int(
                    max(
                        2,
                        min(
                            50,
                            np.floor(tapsmofrq * minSampleNum * 1 /
                                     data.samplerate))))
                taperopt = {"NW": tapsmofrq, "Kmax": nTaper}

        else:
            nTaper = 1

        # Warn the user in case `tapsmofrq` has no effect
        if tapsmofrq is not None and taper.__name__ != "dpss":
            msg = "`tapsmofrq` is only used if `taper` is `dpss`!"
            SPYWarning(msg)

        # Update `log_dct` w/method-specific options (use `lcls` to get actually
        # provided keyword values, not defaults set in here)
        log_dct["taper"] = lcls["taper"]
        log_dct["tapsmofrq"] = lcls["tapsmofrq"]
        log_dct["nTaper"] = nTaper

        # Check for non-default values of options not supported by chosen method
        kwdict = {"wav": wav, "width": width}
        for name, kwarg in kwdict.items():
            if kwarg is not lcls[name]:
                msg = "option `{}` has no effect in methods `mtmfft` and `mtmconvol`!"
                SPYWarning(msg.format(name))

    # Now, prepare explicit compute-classes for chosen method
    if method == "mtmfft":

        # Check for non-default values of options not supported by chosen method
        kwdict = {"t_ftimwin": t_ftimwin, "toi": toi}
        for name, kwarg in kwdict.items():
            if kwarg is not lcls[name]:
                msg = "option `{}` has no effect in method `mtmfft`!"
                SPYWarning(msg.format(name))

        # Set up compute-class
        specestMethod = MultiTaperFFT(samplerate=data.samplerate,
                                      foi=foi,
                                      nTaper=nTaper,
                                      timeAxis=timeAxis,
                                      taper=taper,
                                      taperopt=taperopt,
                                      tapsmofrq=tapsmofrq,
                                      pad=pad,
                                      padtype=padtype,
                                      padlength=padlength,
                                      keeptapers=keeptapers,
                                      polyremoval=polyremoval,
                                      output_fmt=output)

    elif method == "mtmconvol":

        # Set up compute-class
        specestMethod = MultiTaperFFTConvol(soi,
                                            list(padBegin),
                                            list(padEnd),
                                            samplerate=data.samplerate,
                                            noverlap=noverlap,
                                            nperseg=nperseg,
                                            equidistant=equidistant,
                                            toi=toi,
                                            foi=foi,
                                            nTaper=nTaper,
                                            timeAxis=timeAxis,
                                            taper=taper,
                                            taperopt=taperopt,
                                            pad=pad,
                                            padtype=padtype,
                                            padlength=padlength,
                                            prepadlength=prepadlength,
                                            postpadlength=postpadlength,
                                            keeptapers=keeptapers,
                                            polyremoval=polyremoval,
                                            output_fmt=output)

    elif method == "wavelet":

        # Check for non-default values of `taper`, `tapsmofrq`, `keeptapers` and
        # `t_ftimwin` (set to 0 above)
        kwdict = {
            "taper": taper,
            "tapsmofrq": tapsmofrq,
            "keeptapers": keeptapers
        }
        for name, kwarg in kwdict.items():
            if kwarg is not lcls[name]:
                msg = "option `{}` has no effect in method `wavelet`!"
                SPYWarning(msg.format(name))
        if t_ftimwin != 0:
            msg = "option `t_ftimwin` has no effect in method `wavelet`!"
            SPYWarning(msg)

        # Check wavelet selection
        if wav not in availableWavelets:
            lgl = "'" + "or '".join(opt + "' " for opt in availableWavelets)
            raise SPYValueError(legal=lgl, varname="wav", actual=wav)
        if wav not in ["Morlet", "Paul"]:
            msg = "the chosen wavelet '{}' is real-valued and does not provide " +\
                "any information about amplitude or phase of the data. This wavelet function " +\
                "may be used to isolate peaks or discontinuities in the signal. "
            SPYWarning(msg.format(wav))

        # Check for consistency of `width`, `order` and `wav`
        if wav == "Morlet":
            try:
                scalar_parser(width, varname="width", lims=[1, np.inf])
            except Exception as exc:
                raise exc
            wfun = getattr(spywave, wav)(w0=width)
        else:
            if width != lcls["width"]:
                msg = "option `width` has no effect for wavelet '{}'"
                SPYWarning(msg.format(wav))

        if wav == "Paul":
            try:
                scalar_parser(order,
                              varname="order",
                              lims=[4, np.inf],
                              ntype="int_like")
            except Exception as exc:
                raise exc
            wfun = getattr(spywave, wav)(m=order)
        elif wav == "DOG":
            try:
                scalar_parser(order,
                              varname="order",
                              lims=[1, np.inf],
                              ntype="int_like")
            except Exception as exc:
                raise exc
            wfun = getattr(spywave, wav)(m=order)
        else:
            if order is not None:
                msg = "option `order` has no effect for wavelet '{}'"
                SPYWarning(msg.format(wav))
            if wav != "Morlet":  # `wfun` for 'Morlet' was already constructed above
                wfun = getattr(spywave, wav)()

        # Process frequency selection (`toi` was taken care of above): `foilim`
        # selections are wrapped into `foi`, hence the seemingly weird `if` construct.
        # Note: SLURM workers don't like monkey-patching, so let's pretend
        # `get_optimal_wavelet_scales` is a class method by passing `wfun` as its
        # first argument
        if foi is None:
            scales = _get_optimal_wavelet_scales(
                wfun, int(minTrialLength * data.samplerate),
                1 / data.samplerate)
        if foilim is not None:
            foi = np.arange(foilim[0], foilim[1] + 1)
        if foi is not None:
            foi[foi < 0.01] = 0.01
            scales = wfun.scale_from_period(1 / foi)
            scales = scales[::-1]  # FIXME: this only makes sense if `foi` was sorted -> cf Issue #94
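        # Illustration: for the wavelets used here, `scale_from_period` grows
        # monotonically with the period, so an ascending `foi` yields descending
        # `scales`; the flip above restores ascending scale order (smallest
        # scale <-> highest frequency)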

        # Update `log_dct` w/method-specific options (use `lcls` to get actually
        # provided keyword values, not defaults set in here)
        log_dct["wav"] = lcls["wav"]
        log_dct["width"] = lcls["width"]
        log_dct["order"] = lcls["order"]

        # Set up compute-class
        specestMethod = WaveletTransform(preSelect,
                                         postSelect,
                                         list(padBegin),
                                         list(padEnd),
                                         samplerate=data.samplerate,
                                         toi=toi,
                                         scales=scales,
                                         timeAxis=timeAxis,
                                         wav=wfun,
                                         polyremoval=polyremoval,
                                         output_fmt=output)

    # If provided, make sure output object is appropriate
    if out is not None:
        try:
            data_parser(out,
                        varname="out",
                        writable=True,
                        empty=True,
                        dataclass="SpectralData",
                        dimord=SpectralData().dimord)
        except Exception as exc:
            raise exc
        new_out = False
    else:
        out = SpectralData(dimord=SpectralData._defaultDimord)
        new_out = True

    # Perform actual computation
    specestMethod.initialize(data,
                             chan_per_worker=kwargs.get("chan_per_worker"),
                             keeptrials=keeptrials)
    specestMethod.compute(data,
                          out,
                          parallel=kwargs.get("parallel"),
                          log_dict=log_dct)

    # Either return newly created output object or simply quit
    return out if new_out else None
Example #17
def _make_trialdef(cfg, trialdefinition, samplerate):
    """
    Local helper to construct trialdefinition arrays for time-frequency
    :class:`~syncopy.SpectralData` objects

    Parameters
    ----------
    cfg : dict
        Config dictionary attribute of `ComputationalRoutine` subclass
    trialdefinition : 2D :class:`numpy.ndarray`
        Provisional trialdefinition array either directly copied from the
        :class:`~syncopy.AnalogData` input object or computed by the
        :class:`~syncopy.datatype.base_data.Selector` class.
    samplerate : float
        Original sampling rate of :class:`~syncopy.AnalogData` input object

    Returns
    -------
    trialdefinition : 2D :class:`numpy.ndarray`
        Updated trialdefinition array reflecting provided `toi`/`toilim` selection
    samplerate : float
        Sampling rate accounting for potentially new spacing b/w time-points
        (accounting for provided `toi`/`toilim` selection)

    Notes
    -----
    This routine is a local auxiliary method that is purely intended for internal
    use. Thus, no error checking is performed.

    See also
    --------
    syncopy.specest.mtmconvol.mtmconvol : :meth:`~syncopy.shared.computational_routine.ComputationalRoutine.computeFunction`
                                          performing time-frequency analysis using (multi-)tapered sliding window Fourier transform
    syncopy.specest.wavelet.wavelet : :meth:`~syncopy.shared.computational_routine.ComputationalRoutine.computeFunction`
                                      performing time-frequency analysis using non-orthogonal continuous wavelet transform
    syncopy.specest.superlet.superlet : :meth:`~syncopy.shared.computational_routine.ComputationalRoutine.computeFunction`
                                      performing time-frequency analysis using super-resolution superlet transform

    """

    # If `toi` is array, use it to construct timing info
    toi = cfg["toi"]
    if isinstance(toi, np.ndarray):

        # Some index gymnastics to get trial begin/end samples
        nToi = toi.size
        time = np.cumsum([nToi] * trialdefinition.shape[0])
        trialdefinition[:, 0] = time - nToi
        trialdefinition[:, 1] = time
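        # Illustration (hypothetical numbers): `nToi` = 100 and three trials give
        # `time` = [100, 200, 300], hence begin samples [0, 100, 200] and end
        # samples [100, 200, 300] - trials are stacked back-to-back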

        # Important: differentiate b/w equidistant time ranges and disjoint points
        tSteps = np.diff(toi)
        if np.allclose(tSteps, [tSteps[0]] * tSteps.size):
            samplerate = 1 / (toi[1] - toi[0])
        else:
            msg = (
                "`SpectralData`'s `time` property currently does not support "
                + "unevenly spaced `toi` selections!")
            SPYWarning(msg, caller="freqanalysis")
            samplerate = 1.0
            trialdefinition[:, 2] = 0
            return trialdefinition, samplerate

        # Reconstruct trigger-onset based on provided time-point array
        trialdefinition[:, 2] = toi[0] * samplerate

    # If `toi` was a percentage, some cumsum/winSize algebra is required
    # Note: if `toi` was "all", simply use provided `trialdefinition` and `samplerate`

    elif isinstance(toi, Number):
        mKw = cfg['method_kwargs']
        winSize = mKw["nperseg"] - mKw["noverlap"]
        trialdefinitionLens = np.ceil(
            np.diff(trialdefinition[:, :2]) / winSize)
        sumLens = np.cumsum(trialdefinitionLens).reshape(
            trialdefinitionLens.shape)
        trialdefinition[:, 0] = np.ravel(sumLens - trialdefinitionLens)
        trialdefinition[:, 1] = sumLens.ravel()
        trialdefinition[:, 2] = trialdefinition[:, 2] / winSize
        samplerate = np.round(samplerate / winSize, 2)
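        # Illustration (hypothetical numbers): `nperseg` = 100 and `noverlap` = 50
        # give `winSize` = 50; two 1000-sample trials thus shrink to
        # ceil(1000 / 50) = 20 windows each, i.e., begin/end samples [0, 20] and
        # [20, 40], and a 1 kHz `samplerate` becomes 1000 / 50 = 20.0 Hz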

    # If `toi` was "all", do **not** simply use provided `trialdefinition`: overlapping
    # trials require the `cumsum` gymnastics below
    else:
        bounds = np.cumsum(np.diff(trialdefinition[:, :2]))
        trialdefinition[1:, 0] = bounds[:-1]
        trialdefinition[:, 1] = bounds
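        # Illustration (hypothetical numbers): trial lengths [100, 150] give
        # `bounds` = [100, 250]; the first trial keeps its original begin sample,
        # all following trials are stacked back-to-back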

    return trialdefinition, samplerate
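
A minimal sketch (all values hypothetical) of the percentage branch above; it assumes `_make_trialdef` and the module-level imports of this example are available:

import numpy as np

# two 1000-sample trials, trigger onset 100 samples before trial start
cfg = {"toi": 0.5, "method_kwargs": {"nperseg": 100, "noverlap": 50}}
trl = np.array([[0., 1000., -100.], [1000., 2000., -100.]])
newtrl, fs = _make_trialdef(cfg, trl, samplerate=1000.0)
# newtrl[:, :2] -> [[0, 20], [20, 40]]; newtrl[:, 2] -> [-2, -2]; fs -> 20.0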
Example #18
def save(out,
         container=None,
         tag=None,
         filename=None,
         overwrite=False,
         memuse=100):
    r"""Save Syncopy data object to disk

    The underlying array data object is stored in an HDF5 file, the metadata in
    a JSON file. Both can be placed inside a Syncopy container, which is a
    regular directory with the extension '.spy'. 

    Parameters
    ----------
    out : Syncopy data object
        Object to be stored on disk.    
    container : str
        Path to Syncopy container folder (\*.spy) to be used for saving. If
        the extension is omitted, '.spy' is appended to the folder name.
    tag : str
        Tag to be appended to container basename
    filename :  str
        Explicit path to data file. This is only necessary if the data should
        not be part of a container folder. An extension (\*.<dataclass>) is
        added if omitted. The `tag` argument is ignored.      
    overwrite : bool
        If `True` an existing HDF5 file and its accompanying JSON file is 
        overwritten (without prompt). 
    memuse : scalar 
        Approximate in-memory cache size (in MB) for writing data to disk
        (only relevant for :class:`syncopy.VirtualData` or memory map data sources)
        
    Returns
    -------
    Nothing : None
    
    Notes
    ------
    Syncopy objects may also be saved using the class method ``.save`` that 
    acts as a wrapper for :func:`syncopy.save`, e.g., 
    
    >>> save(obj, container="new_spy_container")
    
    is equivalent to
    
    >>> obj.save(container="new_spy_container")
    
    However, once a Syncopy object has been saved, the class method ``.save``
    can be used as a shortcut to quick-save recent changes, e.g., 
    
    >>> obj.save()
    
    writes the current state of `obj` to the data/meta-data files on-disk 
    associated with `obj` (overwriting both in the process). Similarly, 
    
    >>> obj.save(tag='newtag')
    
    saves `obj` in the current container 'new_spy_container' under a different 
    tag. 

    Examples
    -------- 
    Save the Syncopy data object `obj` on disk in the current working directory
    without creating a spy-container
    
    >>> spy.save(obj, filename="session1")
    >>> # --> os.getcwd()/session1.<dataclass>
    >>> # --> os.getcwd()/session1.<dataclass>.info
    
    Save `obj` without creating a spy-container using an absolute path

    >>> spy.save(obj, filename="/tmp/session1")
    >>> # --> /tmp/session1.<dataclass>
    >>> # --> /tmp/session1.<dataclass>.info
    
    Save `obj` in a new spy-container created in the current working directory

    >>> spy.save(obj, container="container.spy")
    >>> # --> os.getcwd()/container.spy/container.<dataclass>
    >>> # --> os.getcwd()/container.spy/container.<dataclass>.info

    Save `obj` in a new spy-container created by providing an absolute path

    >>> spy.save(obj, container="/tmp/container.spy")
    >>> # --> /tmp/container.spy/container.<dataclass>
    >>> # --> /tmp/container.spy/container.<dataclass>.info

    Save `obj` in a new (or existing) spy-container under a different tag
    
    >>> spy.save(obj, container="session1.spy", tag="someTag")
    >>> # --> os.getcwd()/session1.spy/session1_someTag.<dataclass>
    >>> # --> os.getcwd()/session1.spy/session1_someTag.<dataclass>.info

    See also
    --------
    syncopy.load : load data created with :func:`syncopy.save`
    """

    # Make sure `out` is a valid Syncopy data object
    data_parser(out, varname="out", writable=None, empty=False)

    if filename is None and container is None:
        raise SPYError('filename and container cannot both be `None`')

    if container is not None and filename is None:
        # construct filename from container name
        if not isinstance(container, str):
            raise SPYTypeError(container, varname="container", expected="str")
        if not os.path.splitext(container)[1] == ".spy":
            container += ".spy"
        fileInfo = filename_parser(container)
        filename = os.path.join(fileInfo["folder"], fileInfo["container"],
                                fileInfo["basename"])
        # handle tag
        if tag is not None:
            if not isinstance(tag, str):
                raise SPYTypeError(tag, varname="tag", expected="str")
            filename += '_' + tag

    elif container is not None and filename is not None:
        raise SPYError(
            "container and filename cannot be used at the same time")

    if not isinstance(filename, str):
        raise SPYTypeError(filename, varname="filename", expected="str")

    # add extension if not part of the filename
    if "." not in os.path.splitext(filename)[1]:
        filename += out._classname_to_extension()

    try:
        scalar_parser(memuse, varname="memuse", lims=[0, np.inf])
    except Exception as exc:
        raise exc

    if not isinstance(overwrite, bool):
        raise SPYTypeError(overwrite, varname="overwrite", expected="bool")

    # Parse filename for validity and construct full path to HDF5 file
    fileInfo = filename_parser(filename)
    if fileInfo["extension"] != out._classname_to_extension():
        raise SPYError("""Extension in filename ({ext}) does not match data 
                    class ({dclass})""".format(ext=fileInfo["extension"],
                                               dclass=out.__class__.__name__))
    dataFile = os.path.join(fileInfo["folder"], fileInfo["filename"])

    # If `out` is to replace its own on-disk representation, be more careful
    if overwrite and dataFile == out.filename:
        replace = True
    else:
        replace = False

    # Prevent `out` from trying to re-create its own data file
    if replace:
        out.data.flush()
        h5f = out.data.file
        dat = out.data
        trl = h5f["trialdefinition"]
    else:
        if not os.path.exists(fileInfo["folder"]):
            try:
                os.makedirs(fileInfo["folder"])
            except IOError:
                raise SPYIOError(fileInfo["folder"])
            except Exception as exc:
                raise exc
        else:
            if os.path.exists(dataFile):
                if not os.path.isfile(dataFile):
                    raise SPYIOError(dataFile)
                if overwrite:
                    try:
                        h5f = h5py.File(dataFile, mode="w")
                        h5f.close()
                    except Exception as exc:
                        msg = "Cannot overwrite {} - file may still be open. "
                        msg += "Original error message below\n{}"
                        raise SPYError(msg.format(dataFile, str(exc)))
                else:
                    raise SPYIOError(dataFile, exists=True)
        h5f = h5py.File(dataFile, mode="w")

        # Save each member of `_hdfFileDatasetProperties` in target HDF file
        for datasetName in out._hdfFileDatasetProperties:
            dataset = getattr(out, datasetName)

            # Member is a memory map
            if isinstance(dataset, np.memmap):
                # Given memory cap, compute how many data blocks can be grabbed
                # per swipe (divide by 2 since we're working with an add'l tmp array)
                memuse *= 1024**2 / 2
                nrow = int(
                    memuse /
                    (np.prod(dataset.shape[1:]) * dataset.dtype.itemsize))
                rem = int(dataset.shape[0] % nrow)
                n_blocks = [nrow] * int(
                    dataset.shape[0] // nrow) + [rem] * int(rem > 0)
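                # Illustration (hypothetical numbers): `memuse` = 100 MB and a
                # float64 memmap of shape (1000000, 500) give 52428800 cache
                # bytes and 4000 bytes per row, hence `nrow` = 13107 and
                # `n_blocks` = [13107] * 76 + [3868]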

                # Write data block-wise to dataset (use `clear` to wipe blocks of
                # mem-maps from memory)
                dat = h5f.create_dataset(datasetName,
                                         dtype=dataset.dtype,
                                         shape=dataset.shape)
                for m, M in enumerate(n_blocks):
                    dat[m * nrow:m * nrow +
                        M, :] = out.data[m * nrow:m * nrow + M, :]
                    out.clear()

            # Member is a HDF5 dataset
            else:
                dat = h5f.create_dataset(datasetName, data=dataset)

    # Now write trial-related information
    trl_arr = np.array(out.trialdefinition)
    if replace:
        trl[()] = trl_arr
        trl.flush()
    else:
        trl = h5f.create_dataset("trialdefinition",
                                 data=trl_arr,
                                 maxshape=(None, trl_arr.shape[1]))

    # Write to log already here so that the entry can be exported to json
    infoFile = dataFile + FILE_EXT["info"]
    out.log = "Wrote files " + dataFile + "\n\t\t\t" + 2 * " " + infoFile

    # While we're at it, write cfg entries
    out.cfg = {
        "method": sys._getframe().f_code.co_name,
        "files": [dataFile, infoFile]
    }

    # Assemble dict for JSON output: order things by their "readability"
    outDict = OrderedDict(startInfoDict)
    outDict["filename"] = fileInfo["filename"]
    outDict["dataclass"] = out.__class__.__name__
    outDict["data_dtype"] = dat.dtype.name
    outDict["data_shape"] = dat.shape
    outDict["data_offset"] = dat.id.get_offset()
    outDict["trl_dtype"] = trl.dtype.name
    outDict["trl_shape"] = trl.shape
    outDict["trl_offset"] = trl.id.get_offset()
    outDict["order"] = "C"
    if isinstance(out.data, np.ndarray) and np.isfortran(out.data):
        outDict["order"] = "F"

    for key in out._infoFileProperties:
        value = getattr(out, key)
        if isinstance(value, np.ndarray):
            value = value.tolist()
        # potentially nested dicts
        elif isinstance(value, dict):
            value = dict(value)
            _dict_converter(value)
        outDict[key] = value

    # Save relevant stuff as HDF5 attributes
    for key in out._hdfFileAttributeProperties:
        if outDict[key] is None:
            h5f.attrs[key] = "None"
        else:
            try:
                h5f.attrs[key] = outDict[key]
            except RuntimeError:
                msg = "Too many entries in `{}` - truncating HDF5 attribute. " +\
                    "Please refer to {} for complete listing."
                info_fle = os.path.split(
                    os.path.split(filename.format(ext=FILE_EXT["info"]))[0])[1]
                info_fle = os.path.join(
                    info_fle,
                    os.path.basename(filename.format(ext=FILE_EXT["info"])))
                SPYWarning(msg.format(key, info_fle))
                h5f.attrs[key] = [outDict[key][0], "...", outDict[key][-1]]

    # Re-assign filename after saving (and remove source in case it came from `__storage__`)
    if not replace:
        h5f.close()
        if __storage__ in out.filename:
            out.data.file.close()
            os.unlink(out.filename)
        out.data = dataFile

    # Compute checksum and finally write JSON (automatically overwrites existing)
    outDict["file_checksum"] = hash_file(dataFile)

    with open(infoFile, 'w') as out_json:
        json.dump(outDict, out_json, indent=4)

    return
Example #19
def io_parser(fs_loc, varname="", isfile=True, ext="", exists=True):
    """
    Parse file-system location strings for reading/writing files/directories

    Parameters
    ----------
    fs_loc : str
        String pointing to (hopefully valid) file-system location
        (absolute/relative path of file or directory ).
    varname : str
        Local variable name used in caller, see Examples for details.
    isfile : bool
        Indicates whether `fs_loc` points to a file (`isfile = True`) or
        directory (`isfile = False`)
    ext : str or 1darray-like
        Valid filename extension(s). Can be a single string (e.g., `ext = "lfp"`)
        or a list/1darray of valid extensions (e.g., `ext = ["lfp", "mua"]`).
    exists : bool
        If `exists = True` ensure that file-system location specified by `fs_loc` exists
        (typically used when reading from `fs_loc`), otherwise (`exists = False`)
        check for already present conflicting files/directories (typically used when
        creating/writing to `fs_loc`).

    Returns
    -------
    fs_path : str
        Absolute path of `fs_loc`.
    fs_name : str (only if `isfile = True`)
        Name (including extension) of input file (without path).

    Examples
    --------
    To test whether `"/path/to/dataset.lfp"` points to an existing file, one
    might use

    >>> io_parser("/path/to/dataset.lfp")
    '/path/to', 'dataset.lfp'

    The following call ensures that a folder called "mydata" can be safely
    created in the current working directory

    >>> io_parser("mydata", isfile=False, exists=False)
    '/path/to/cwd/mydata'

    Suppose a routine wants to save data to a file with potential
    extensions `".lfp"` or `".mua"`. The following call may be used to ensure
    the user input `dsetname = "relative/dir/dataset.mua"` is a valid choice:

    >>> abs_path, filename = io_parser(dsetname, varname="dsetname", ext=["lfp", "mua"], exists=False)
    >>> abs_path
    '/full/path/to/relative/dir/'
    >>> filename
    'dataset.mua'
    """

    # Start by resolving potential conflicts
    if not isfile and len(ext) > 0:
        msg = "filename extension(s) specified but `isfile = False`. Exiting..."
        SPYWarning(msg)
        return

    # Make sure `fs_loc` is actually a string
    if not isinstance(fs_loc, str):
        raise SPYTypeError(fs_loc, varname=varname, expected=str)

    # Avoid headaches, use absolute paths...
    fs_loc = os.path.abspath(os.path.expanduser(fs_loc))

    # Ensure that filesystem object does/does not exist
    if exists and not os.path.exists(fs_loc):
        raise SPYIOError(fs_loc, exists=False)
    if not exists and os.path.exists(fs_loc):
        raise SPYIOError(fs_loc, exists=True)

    # First, take care of directories...
    if not isfile:
        isdir = os.path.isdir(fs_loc)
        if (isdir and not exists):
            raise SPYIOError(fs_loc, exists=isdir)
        elif (not isdir and exists):
            raise SPYValueError(legal="directory", actual="file")
        else:
            return fs_loc

    # ...now files
    else:

        # Separate filename from its path
        file_name = os.path.basename(fs_loc)

        # If wanted, parse filename extension(s)
        if len(ext):

            # Extract filename extension and get rid of its dot
            file_ext = os.path.splitext(file_name)[1]
            file_ext = file_ext.replace(".", "")

            # In here, having no extension counts as an error
            error = False
            if len(file_ext) == 0:
                error = True
            if file_ext not in str(ext) or error:
                if isinstance(ext, (list, np.ndarray)):
                    ext = "'" + "or '".join(ex + "' " for ex in ext)
                raise SPYValueError(ext, varname="filename-extension", actual=file_ext)

        # Now make sure file does or does not exist
        isfile = os.path.isfile(fs_loc)
        if (isfile and not exists):
            raise SPYIOError(fs_loc, exists=isfile)
        elif (not isfile and exists):
            raise SPYValueError(legal="file", actual="directory")
        else:
            return fs_loc.split(file_name)[0], file_name
Example #20
def load(filename,
         tag=None,
         dataclass=None,
         checksum=False,
         mode="r+",
         out=None):
    """
    Load Syncopy data object(s) from disk
    
    Either loads single files within or outside of '.spy'-containers or loads
    multiple objects from a single '.spy'-container. Loading from containers can 
    be further controlled by imposing restrictions on object class(es) (via 
    `dataclass`) and file-name tag(s) (via `tag`). 
    
    Parameters
    ----------
    filename : str
        Either path to Syncopy container folder (\*.spy; if the extension is omitted,
        '.spy' will be appended) or name of data or metadata file. If `filename`
        points to a container and no further specifications are provided, the 
        entire contents of the container is loaded. Otherwise, specific objects
        may be selected using the `dataclass` or `tag` keywords (see below). 
    tag : None or str or list
        If `filename` points to a container, `tag` may be used to filter objects
        by filename-`tag`. Multiple tags can be provided using a list, e.g., 
        ``tag = ['experiment1', 'experiment2']``. Can be combined with `dataclass`
        (see below). Invalid if `filename` points to a single file. 
    dataclass : None or str or list
        If provided, only objects of provided dataclass are loaded from disk. 
        Available options are '.analog', '.spectral', '.spike' and '.event' 
        (as listed in  ``spy.FILE_EXT["data"]``). Multiple class specifications
        can be provided using a list, e.g., ``dataclass = ['.analog', '.spike']``.
        Can be combined with `tag` (see above) and is also valid if `filename`
        points to a single file (e.g., to ensure loaded object is of a specific
        type). 
    checksum : bool
        If `True`, checksum-matching is performed on loaded object(s) to ensure
        data-integrity (impairs performance particularly when loading large files). 
    mode : str
        Data access mode of loaded objects (can be 'r' for read-only, 'r+' or 'w'
        for read/write access). 
    out : Syncopy data object
        Empty object to be filled with data loaded from disk. Has to match the 
        type of the on-disk file (e.g., ``filename = 'mydata.analog'`` requires
        `out` to be a :class:`syncopy.AnalogData` object). Can only be used 
        when loading single objects from disk (`out` is ignored when multiple
        files are loaded from a container). 
        
    Returns
    -------
    Nothing : None
        If a single file is loaded and `out` was provided, `out` is filled with
        data loaded from disk, i.e., :func:`syncopy.load` does **not** create a 
        new object
    obj : Syncopy data object
        If a single file is loaded and `out` was `None`, :func:`syncopy.load` 
        returns a new object. 
    objdict : dict
        If multiple files are loaded, :func:`syncopy.load` creates a new object
        for each file and places them in a dictionary whose keys are the base-names
        (sans path) of the corresponding files. 
        
    Notes
    -----
    All of Syncopy's classes offer (limited) support for data loading upon object
    creation. Just as the class method ``.save`` can be used as a shortcut for
    :func:`syncopy.save`, Syncopy objects can be created from Syncopy data-files 
    upon creation, e.g., 
    
    >>> adata = spy.AnalogData('/path/to/session1.analog')
    
    creates a new :class:`syncopy.AnalogData` object and immediately fills it 
    with data loaded from the file "/path/to/session1.analog". 
    
    Since only one object can be created at a time, this loading shortcut only 
    supports single file specifications (i.e., ``spy.AnalogData("container.spy")``
    is invalid). 

    Examples
    -------- 
    Load all objects found in the spy-container "sessionName" (the extension ".spy" 
    may or may not be provided)
    
    >>> objectDict = spy.load("sessionName")
    >>> # --> returns a dict with base-filenames as keys
    
    Load all :class:`syncopy.AnalogData` and :class:`syncopy.SpectralData` objects
    from the spy-container "sessionName"
    
    >>> objectDict = spy.load("sessionName.spy", dataclass=['analog', 'spectral'])
    
    Load a specific :class:`syncopy.AnalogData` object from the above spy-container
    
    >>> obj = spy.load("sessionName.spy/sessionName_someTag.analog")
    
    This is equivalent to
    
    >>> obj = spy.AnalogData("sessionName.spy/sessionName_someTag.analog")
    
    If the "sessionName" spy-container only contains one object with the tag 
    "someTag", the above call is equivalent to
    
    >>> obj = spy.load("sessionName.spy", tag="someTag")
    
    If there are multiple objects of different types using the same tag "someTag",
    the above call can be further narrowed down to only load the requested 
    :class:`syncopy.AnalogData` object
       
    >>> obj = spy.load("sessionName.spy", tag="someTag", dataclass="analog")
    
    See also
    --------
    syncopy.save : save syncopy object on disk
    """

    # Ensure `filename` is either a valid .spy container or data file: if `filename`
    # is a directory w/o '.spy' extension, append it
    if not isinstance(filename, str):
        raise SPYTypeError(filename, varname="filename", expected="str")
    if len(os.path.splitext(os.path.abspath(
            os.path.expanduser(filename)))[1]) == 0:
        filename += FILE_EXT["dir"]
    try:
        fileInfo = filename_parser(filename)
    except Exception as exc:
        raise exc

    if tag is not None:
        if isinstance(tag, str):
            tags = [tag]
        else:
            tags = tag
        try:
            array_parser(tags, varname="tag", ntype=str)
        except Exception as exc:
            raise exc
        if fileInfo["filename"] is not None:
            raise SPYError("Only containers can be loaded with `tag` keyword!")
        for tk in range(len(tags)):
            tags[tk] = "*" + tags[tk] + "*"
    else:
        tags = "*"

    # If `dataclass` was provided, format it for our needs (e.g. 'spike' -> ['.spike'])
    if dataclass is not None:
        if isinstance(dataclass, str):
            dataclass = [dataclass]
        try:
            array_parser(dataclass, varname="dataclass", ntype=str)
        except Exception as exc:
            raise exc
        dataclass = [
            "." + dclass if not dclass.startswith(".") else dclass
            for dclass in dataclass
        ]
        extensions = set(dataclass).intersection(FILE_EXT["data"])
        if len(extensions) == 0:
            lgl = "extension(s) '" + "or '".join(ext + "' "
                                                 for ext in FILE_EXT["data"])
            raise SPYValueError(legal=lgl,
                                varname="dataclass",
                                actual=str(dataclass))

    # Avoid any misunderstandings here...
    if not isinstance(checksum, bool):
        raise SPYTypeError(checksum, varname="checksum", expected="bool")

    # Abuse `AnalogData.mode`-setter to vet `mode`
    try:
        spd.AnalogData().mode = mode
    except Exception as exc:
        raise exc

    # If `filename` points to a spy container, `glob` what's inside, otherwise just load
    if fileInfo["filename"] is None:

        if dataclass is None:
            extensions = FILE_EXT["data"]
        container = os.path.join(fileInfo["folder"], fileInfo["container"])
        fileList = []
        for ext in extensions:
            for tag in tags:
                fileList.extend(glob(os.path.join(container, tag + ext)))
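        # Illustration: `tag` = "someTag" and the default data extensions make
        # the loop above glob for patterns like "<container>/*someTag*.analog",
        # "<container>/*someTag*.spectral", etc.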
        if len(fileList) == 0:
            fsloc = os.path.join(container, "" + \
                                 "or ".join(tag + " " for tag in tags) + \
                                 "with extensions " + \
                                 "or ".join(ext + " " for ext in extensions))
            raise SPYIOError(fsloc, exists=False)
        if len(fileList) == 1:
            return _load(fileList[0], checksum, mode, out)
        if out is not None:
            msg = "When loading multiple objects, the `out` keyword is ignored"
            SPYWarning(msg)
        objectDict = {}
        for fname in fileList:
            obj = _load(fname, checksum, mode, None)
            objectDict[os.path.basename(obj.filename)] = obj
        return objectDict

    else:

        if dataclass is not None:
            if os.path.splitext(fileInfo["filename"])[1] not in dataclass:
                lgl = "extension '" + \
                    "or '".join(dclass + "' " for dclass in dataclass)
                raise SPYValueError(legal=lgl,
                                    varname="filename",
                                    actual=fileInfo["filename"])
        return _load(filename, checksum, mode, out)
Example #21
def selectdata(data,
               trials=None,
               channels=None,
               channels_i=None,
               channels_j=None,
               toi=None,
               toilim=None,
               foi=None,
               foilim=None,
               tapers=None,
               units=None,
               eventids=None,
               out=None,
               inplace=False,
               clear=False,
               **kwargs):
    """
    Create a new Syncopy object from a selection

    **Usage Notice**

    Syncopy offers two modes for selecting data:

    * **in-place** selections mark subsets of a Syncopy data object for processing
      via a ``select`` dictionary *without* creating a new object
    * **deep-copy** selections copy subsets of a Syncopy data object to keep and
      preserve in a new object created by :func:`~syncopy.selectdata`

    All Syncopy metafunctions, such as :func:`~syncopy.freqanalysis`, support
    **in-place** data selection via a ``select`` keyword, effectively avoiding
    potentially slow copy operations and saving disk space. The keys accepted
    by the `select` dictionary are identical to the keyword arguments discussed
    below. In addition, ``select = "all"`` can be used to select entire object
    contents. Examples

    >>> select = {"toilim" : [-0.25, 0]}
    >>> spy.freqanalysis(data, select=select)
    >>> # or equivalently
    >>> cfg = spy.get_defaults(spy.freqanalysis)
    >>> cfg.select = select
    >>> spy.freqanalysis(cfg, data)

    **Usage Summary**

    List of Syncopy data objects and respective valid data selectors:

    :class:`~syncopy.AnalogData` : trials, channels, toi/toilim
        Examples

        >>> spy.selectdata(data, trials=[0, 3, 5], channels=["channel01", "channel02"])
        >>> cfg = spy.StructDict()
        >>> cfg.trials = [5, 3, 0]; cfg.toilim = [0.25, 0.5]
        >>> spy.selectdata(cfg, data)

    :class:`~syncopy.SpectralData` : trials, channels, toi/toilim, foi/foilim, tapers
        Examples

        >>> spy.selectdata(data, trials=[0, 3, 5], channels=["channel01", "channel02"])
        >>> cfg = spy.StructDict()
        >>> cfg.foi = [30, 40, 50]; cfg.tapers = slice(2, 4)
        >>> spy.selectdata(cfg, data)

    :class:`~syncopy.EventData` : trials, toi/toilim, eventids
        Examples

        >>> spy.selectdata(data, toilim=[-1, 2.5], eventids=[0, 1])
        >>> cfg = spy.StructDict()
        >>> cfg.trials = [0, 0, 1, 0]; cfg.eventids = slice(2, None)
        >>> spy.selectdata(cfg, data)

    :class:`~syncopy.SpikeData` : trials, toi/toilim, units, channels
        Examples

        >>> spy.selectdata(data, toilim=[-1, 2.5], units=range(0, 10))
        >>> cfg = spy.StructDict()
        >>> cfg.toi = [1.25, 3.2]; cfg.trials = [0, 1, 2, 3]
        >>> spy.selectdata(cfg, data)

    **Note** Any property that is not specifically accessed via one of the provided
    selectors is taken as is, e.g., ``spy.selectdata(data, trials=[1, 2])``
    selects the entire contents of trials no. 2 and 3, while
    ``spy.selectdata(data, channels=range(0, 50))`` selects the first 50 channels
    of `data` across all defined trials. Consequently, if no keywords are specified,
    the entire contents of `data` is selected.

    **Full documentation below**

    Parameters
    ----------
    data : Syncopy data object
        A non-empty Syncopy data object. **Note** the type of `data` determines
        which keywords can be used.  Some keywords are only valid for certain
        types of Syncopy objects, e.g., "freqs" is not a valid selector for an
        :class:`~syncopy.AnalogData` object.
    trials : list (integers) or None or "all"
        List of integers representing trial numbers to be selected; can include
        repetitions and need not be sorted (e.g., ``trials = [0, 1, 0, 0, 2]``
        is valid) but must be finite and not NaN. If `trials` is `None`, or
        ``trials = "all"`` all trials are selected.
    channels : list (integers or strings), slice, range or None or "all"
        Channel-selection; can be a list of channel names (``['channel3', 'channel1']``),
        a list of channel indices (``[3, 5]``), a slice (``slice(3, 10)``) or
        range (``range(3, 10)``). Note that following Python conventions, channels
        are counted starting at zero, and range and slice selections are half-open
        intervals of the form `[low, high)`, i.e., low is included, high is
        excluded. Thus, ``channels = [0, 1, 2]`` or ``channels = slice(0, 3)``
        selects the first up to (and including) the third channel. Selections can
        be unsorted and may include repetitions but must match exactly, be finite
        and not NaN. If `channels` is `None`, or ``channels = "all"`` all channels
        are selected.
    toi : list (floats) or None or "all"
        Time-points to be selected (in seconds) in each trial. Timing is expected
        to be on a by-trial basis (e.g., relative to trigger onsets). Selections
        can be approximate, unsorted and may include repetitions but must be
        finite and not NaN. Fuzzy matching is performed for approximate selections
        (i.e., selected time-points are close but not identical to timing information
        found in `data`) using a nearest-neighbor search for elements of `toi`.
        If `toi` is `None` or ``toi = "all"``, the entire time-span in each trial
        is selected.
    toilim : list (floats [tmin, tmax]) or None or "all"
        Time-window ``[tmin, tmax]`` (in seconds) to be extracted from each trial.
        Window specifications must be sorted (e.g., ``[2.2, 1.1]`` is invalid)
        and not NaN but may be unbounded (e.g., ``[1.1, np.inf]`` is valid). Edges
        `tmin` and `tmax` are included in the selection.
        If `toilim` is `None` or ``toilim = "all"``, the entire time-span in each
        trial is selected.
    foi : list (floats) or None or "all"
        Frequencies to be selected (in Hz). Selections can be approximate, unsorted
        and may include repetitions but must be finite and not NaN. Fuzzy matching
        is performed for approximate selections (i.e., selected frequencies are
        close but not identical to frequencies found in `data`) using a nearest-
        neighbor search for elements of `foi` in `data.freq`. If `foi` is `None`
        or ``foi = "all"``, all frequencies are selected.
    foilim : list (floats [fmin, fmax]) or None or "all"
        Frequency-window ``[fmin, fmax]`` (in Hz) to be extracted. Window
        specifications must be sorted (e.g., ``[90, 70]`` is invalid) and not NaN
        but may be unbounded (e.g., ``[-np.inf, 60.5]`` is valid). Edges `fmin`
        and `fmax` are included in the selection. If `foilim` is `None` or
        ``foilim = "all"``, all frequencies are selected.
    tapers : list (integers or strings), slice, range or None or "all"
        Taper-selection; can be a list of taper names (``['dpss-win-1', 'dpss-win-3']``),
        a list of taper indices (``[3, 5]``), a slice (``slice(3, 10)``) or range
        (``range(3, 10)``). Note that following Python conventions, tapers are
        counted starting at zero, and range and slice selections are half-open
        intervals of the form `[low, high)`, i.e., low is included, high is
        excluded. Thus, ``tapers = [0, 1, 2]`` or ``tapers = slice(0, 3)`` selects
        the first up to (and including) the third taper. Selections can be unsorted
        and may include repetitions but must match exactly, be finite and not NaN.
        If `tapers` is `None` or ``tapers = "all"``, all tapers are selected.
    units : list (integers or strings), slice, range or None or "all"
        Unit-selection; can be a list of unit names (``['unit10', 'unit3']``), a
        list of unit indices (``[3, 5]``), a slice (``slice(3, 10)``) or range
        (``range(3, 10)``). Note that following Python conventions, units are
        counted starting at zero, and range and slice selections are half-open
        intervals of the form `[low, high)`, i.e., low is included, high is
        excluded. Thus, ``units = [0, 1, 2]`` or ``units = slice(0, 3)`` selects
        the first up to (and including) the third unit. Selections can be unsorted
        and may include repetitions but must match exactly, be finite and not NaN.
        If `units` is `None` or ``units = "all"``, all units are selected.
    eventids : list (integers), slice, range or None or "all"
        Event-ID-selection; can be a list of event-id codes (``[2, 0, 1]``), slice
        (``slice(0, 2)``) or range (``range(0, 2)``). Note that following Python
        conventions, range and slice selections are half-open intervals of the
        form `[low, high)`, i.e., low is included, high is excluded. Selections
        can be unsorted and may include repetitions but must match exactly, be
        finite and not NaN. If `eventids` is `None` or ``eventids = "all"``, all
        events are selected.
    inplace : bool
        If `inplace` is `True` **no** new object is created. Instead the provided
        selection is stored in the input object's `_selection` attribute for later
        use. By default `inplace` is `False` and all calls to `selectdata` create
        a new Syncopy data object.
    clear : bool
        If `clear` is `True`, an in-place selection previously attached to `data`
        is removed (no other selectors may be provided in this case). By default
        `clear` is `False`.

    Returns
    -------
    dataselection : Syncopy data object
        Syncopy data object of the same type as `data` but containing only the
        subset specified by provided selectors.

    Notes
    -----
    This routine represents a convenience function for creating new Syncopy objects
    based on existing data entities. However, in many situations, the creation
    of a new object (and thus the allocation of additional disk-space) might not
    be necessary: all Syncopy metafunctions, such as :func:`~syncopy.freqanalysis`,
    support **in-place** data selection.

    Consider the following example: assume `data` is an :class:`~syncopy.AnalogData`
    object representing 220 trials of LFP recordings containing baseline (between
    second -0.25 and 0) and stimulus-on data (on the interval [0.25, 0.5]).
    To compute the baseline spectrum, data-selection does **not**
    have to be performed before calling :func:`~syncopy.freqanalysis` but instead
    can be done in-place:

    >>> import syncopy as spy
    >>> cfg = spy.get_defaults(spy.freqanalysis)
    >>> cfg.method = 'mtmfft'
    >>> cfg.taper = 'dpss'
    >>> cfg.output = 'pow'
    >>> cfg.tapsmofrq = 10
    >>> # define baseline/stimulus-on ranges
    >>> baseSelect = {"toilim": [-0.25, 0]}
    >>> stimSelect = {"toilim": [0.25, 0.5]}
    >>> # in-place selection of baseline interval performed by `freqanalysis`
    >>> cfg.select = baseSelect
    >>> baselineSpectrum = spy.freqanalysis(cfg, data)
    >>> # in-place selection of stimulus-on time-frame performed by `freqanalysis`
    >>> cfg.select = stimSelect
    >>> stimonSpectrum = spy.freqanalysis(cfg, data)

    Especially for large data-sets, in-place data selection performed by Syncopy's
    metafunctions does not only save disk-space but can significantly increase
    performance.

    Examples
    --------
    Use :func:`~syncopy.tests.misc.generate_artificial_data` to create a synthetic
    :class:`syncopy.AnalogData` object.

    >>> from syncopy.tests.misc import generate_artificial_data
    >>> adata = generate_artificial_data(nTrials=10, nChannels=32)

    Assume a hypothetical trial onset at second 2.0 with the first second of each
    trial representing baseline recordings. To extract only the stimulus-on period
    from `adata`, one could use

    >>> stimon = spy.selectdata(adata, toilim=[2.0, np.inf])

    Note that this is equivalent to

    >>> stimon = adata.selectdata(toilim=[2.0, np.inf])

    See also
    --------
    :func:`syncopy.show` : Show (subsets) of Syncopy objects
    """

    # Ensure our one mandatory input is usable
    try:
        data_parser(data, varname="data", empty=False)
    except Exception as exc:
        raise exc

    # Vet the only inputs not checked by `Selector`
    if not isinstance(inplace, bool):
        raise SPYTypeError(inplace, varname="inplace", expected="Boolean")
    if not isinstance(clear, bool):
        raise SPYTypeError(clear, varname="clear", expected="Boolean")

    # If provided, make sure output object is appropriate
    if not inplace:
        if out is not None:
            try:
                data_parser(out,
                            varname="out",
                            writable=True,
                            empty=True,
                            dataclass=data.__class__.__name__,
                            dimord=data.dimord)
            except Exception as exc:
                raise exc
            new_out = False
        else:
            out = data.__class__(dimord=data.dimord)
            new_out = True
    else:
        if out is not None:
            lgl = "no output object for in-place selection"
            raise SPYValueError(lgl,
                                varname="out",
                                actual=out.__class__.__name__)

    # FIXME: remove once tests are in place (cf #165)
    if channels_i is not None or channels_j is not None:
        SPYWarning(
            "CrossSpectralData channel selection currently untested and experimental!"
        )

    # Collect provided keywords in dict
    selectDict = {
        "trials": trials,
        "channels": channels,
        "channels_i": channels_i,
        "channels_j": channels_j,
        "toi": toi,
        "toilim": toilim,
        "foi": foi,
        "foilim": foilim,
        "tapers": tapers,
        "units": units,
        "eventids": eventids
    }

    # Simplest case first: determine whether we just need to clear an existing selection
    if clear:
        if any(value is not None for value in selectDict.values()):
            lgl = "no data selectors if `clear = True`"
            raise SPYValueError(lgl, varname="select", actual=selectDict)
        if data._selection is None:
            SPYInfo("No in-place selection found. ")
        else:
            data._selection = None
            SPYInfo("In-place selection cleared")
        return

    # Pass provided selections on to `Selector` class which performs error checking
    data._selection = selectDict

    # If an in-place selection was requested we're done
    if inplace:
        SPYInfo("In-place selection attached to data object: {}".format(
            data._selection))
        return

    # Take inventory of all available selectors and the actually provided values
    # to build a bookkeeping dict for logging
    log_dct = {"inplace": inplace, "clear": clear}
    log_dct.update(selectDict)
    log_dct.update(**kwargs)

    # Fire up `ComputationalRoutine`-subclass to do the actual selecting/copying
    selectMethod = DataSelection()
    selectMethod.initialize(data,
                            out._stackingDim,
                            chan_per_worker=kwargs.get("chan_per_worker"))
    selectMethod.compute(data,
                         out,
                         parallel=kwargs.get("parallel"),
                         log_dict=log_dct)

    # Wipe data-selection slot to not alter input object
    data._selection = None

    # Either return newly created output object or simply quit
    return out if new_out else None
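
The `inplace` and `clear` branches above can be exercised with a few lines of driver code. A minimal sketch, assuming `adata` is an existing :class:`~syncopy.AnalogData` object (an illustrative variable, not part of the source):

import syncopy as spy

# attach the selection to `adata` itself -- no new object, no extra disk space
spy.selectdata(adata, toilim=[0.25, 0.5], inplace=True)

# ... subsequent metafunction calls honor `adata._selection` ...

# remove the in-place selection again; note that passing selectors together
# with `clear=True` raises a SPYValueError (cf. the `clear` branch above)
spy.selectdata(adata, clear=True)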
Example #22
def esi_cluster_setup(partition="8GBS", n_jobs=2, mem_per_job=None,
                      timeout=180, interactive=True, start_client=True,
                      **kwargs):
    """
    Start a distributed Dask cluster of parallel processing workers using SLURM 
    (or local multi-processing)
    
    Parameters
    ----------
    partition : str
        Name of SLURM partition/queue to use
    n_jobs : int
        Number of jobs to spawn
    mem_per_job : None or str
        Memory booking for each job. Can be specified either in megabytes
        (e.g., ``mem_per_job = "1500MB"``) or gigabytes (e.g., ``mem_per_job = "2GB"``).
        If `mem_per_job` is `None`, Syncopy attempts to infer a sane default value
        from the chosen queue, e.g., for ``partition = "8GBS"`` `mem_per_job` is
        automatically set to the allowed maximum of `'8GB'`. However, even in
        queues with guaranteed memory bookings, it is possible to allocate less
        memory than the allowed maximum per job in order to spawn numerous
        low-memory jobs. See Examples for details.
    timeout : int
        Number of seconds to wait for requested jobs to start up. 
    interactive : bool
        If `True`, user input is required in case not all jobs could 
        be started in the provided waiting period (determined by `timeout`). 
        If `interactive` is `False` and the jobs could not be started
        within `timeout` seconds, a `TimeoutError` is raised. 
    start_client : bool
        If `True`, a distributed computing client is launched and attached to
        the workers. If `start_client` is `False`, only a distributed 
        computing cluster is started to which compute-clients can connect. 
    **kwargs : dict
        Additional keyword arguments can be used to control job-submission details. 
        
    Returns
    -------
    proc : object
        A distributed computing client (if ``start_client = True``) or 
        a distributed computing cluster (otherwise). 

    Examples
    --------
    The following command launches 10 SLURM jobs with 2 gigabytes memory each 
    in the `8GBS` partition
    
    >>> spy.esi_cluster_setup(n_jobs=10, partition="8GBS", mem_per_job="2GB") 
    
    If you want to access properties of the created distributed computing client, 
    assign an explicit return quantity, i.e., 
    
    >>> client = spy.esi_cluster_setup(n_jobs=10, partition="8GBS", mem_per_job="2GB") 
    
    The underlying distributed computing cluster can be accessed using
    
    >>> client.cluster
    
    Notes
    -----
    Syncopy's parallel computing engine relies on the concurrent processing library
    `Dask <https://docs.dask.org/en/latest/>`_. Thus, the distributed computing
    clients used by Syncopy are in fact instances of :class:`dask.distributed.Client`. 
    This function specifically acts as a wrapper for :class:`dask_jobqueue.SLURMCluster`.
    Users familiar with Dask in general, and its distributed scheduler and cluster
    objects in particular, may leverage Dask's entire API to fine-tune parallel
    processing jobs to their liking (if desired).
    
    See also
    --------
    cluster_cleanup : remove dangling parallel processing job-clusters
    """
    
    # For later reference: dynamically fetch name of current function
    funcName = "Syncopy <{}>".format(inspect.currentframe().f_code.co_name)
    
    # Be optimistic: prepare success message
    successMsg = "{name:s} Cluster dashboard accessible at {dash:s}"

    # Retrieve all partitions currently available in SLURM
    out, err = subprocess.Popen("sinfo -h -o %P",
                                stdout=subprocess.PIPE, stderr=subprocess.PIPE,
                                text=True, shell=True).communicate()
    if len(err) > 0:
        
        # SLURM is not installed, either allocate `LocalCluster` or just leave
        if "sinfo: not found" in err:
            if interactive:
                msg = "{name:s} SLURM does not seem to be installed on this machine " +\
                    "({host:s}). Do you want to start a local multi-processing " +\
                    "computing client instead? "
                startLocal = user_yesno(msg.format(name=funcName, host=socket.gethostname()), 
                                        default="no")
            else:
                startLocal = True
            if startLocal:
                client = Client()
                successMsg = "{name:s} Local parallel computing client ready. \n" + successMsg
                print(successMsg.format(name=funcName, dash=client.cluster.dashboard_link))
                if start_client:
                    return client
                return client.cluster
            return 

        # SLURM is installed, but something's wrong        
        msg = "SLURM queuing system from node {node:s}. " +\
              "Original error message below:\n{error:s}"
        raise SPYIOError(msg.format(node=socket.gethostname(), error=err))
    options = out.split()

    # Make sure we're in a valid partition (exclude IT partitions from output message)
    if partition not in options:
        valid = list(set(options).difference(["DEV", "PPC"]))
        raise SPYValueError(legal="'" + "or '".join(opt + "' " for opt in valid),
                            varname="partition", actual=partition)

    # Parse job count
    try:
        scalar_parser(n_jobs, varname="n_jobs", ntype="int_like", lims=[1, np.inf])
    except Exception as exc:
        raise exc

    # Get requested memory per job
    if mem_per_job is not None:
        if not isinstance(mem_per_job, str):
            raise SPYTypeError(mem_per_job, varname="mem_per_job", expected="string")
        if not any(szstr in mem_per_job for szstr in ["MB", "GB"]):
            lgl = "string representation of requested memory (e.g., '8GB', '12000MB')"
            raise SPYValueError(legal=lgl, varname="mem_per_job", actual=mem_per_job)

    # Query memory limit of chosen partition and ensure that `mem_per_job` is
    # set for partitions w/o limit
    idx = partition.find("GB")
    if idx > 0:
        mem_lim = int(partition[:idx]) * 1000
    else:
        if partition == "PREPO":
            mem_lim = 16000
        else:
            if mem_per_job is None:
                lgl = "explicit memory amount as required by partition '{}'"
                raise SPYValueError(legal=lgl.format(partition),
                                    varname="mem_per_job", actual=mem_per_job)
            mem_lim = np.inf

    # Consolidate requested memory with chosen partition (or assign default memory)
    if mem_per_job is None:
        mem_per_job = str(mem_lim) + "MB"
    else:
        if "MB" in mem_per_job:
            mem_req = int(mem_per_job[:mem_per_job.find("MB")])
        else:
            mem_req = int(round(float(mem_per_job[:mem_per_job.find("GB")]) * 1000))
        if mem_req > mem_lim:
            msg = "`mem_per_job` exceeds limit of {lim:d}GB for partition {par:s}. " +\
                "Capping memory at partition limit. "
            SPYWarning(msg.format(lim=mem_lim, par=partition))
            mem_per_job = str(int(mem_lim)) + "GB"

    # Parse requested timeout period
    try:
        scalar_parser(timeout, varname="timeout", ntype="int_like", lims=[1, np.inf])
    except Exception as exc:
        raise exc

    # Determine if cluster allocation is happening interactively
    if not isinstance(interactive, bool):
        raise SPYTypeError(interactive, varname="interactive", expected="bool")

    # Determine if a dask client was requested
    if not isinstance(start_client, bool):
        raise SPYTypeError(start_client, varname="start_client", expected="bool")

    # Set/get "hidden" kwargs
    workers_per_job = kwargs.get("workers_per_job", 1)
    try:
        scalar_parser(workers_per_job, varname="workers_per_job",
                      ntype="int_like", lims=[1, 8])
    except Exception as exc:
        raise exc

    n_cores = kwargs.get("n_cores", 1)
    try:
        scalar_parser(n_cores, varname="n_cores",
                      ntype="int_like", lims=[1, np.inf])
    except Exception as exc:
        raise exc

    slurm_wdir = kwargs.get("slurmWorkingDirectory", None)
    if slurm_wdir is None:
        usr = getpass.getuser()
        slurm_wdir = "/mnt/hpx/slurm/{usr:s}/{usr:s}_{date:s}"
        slurm_wdir = slurm_wdir.format(usr=usr,
                                       date=datetime.now().strftime('%Y%m%d-%H%M%S'))
        os.makedirs(slurm_wdir, exist_ok=True)
    else:
        try:
            io_parser(slurm_wdir, varname="slurmWorkingDirectory", isfile=False)
        except Exception as exc:
            raise exc
        
    # Hotfix for upgraded cluster-nodes: point to correct Python executable if working from /home
    pyExec = sys.executable
    if sys.executable.startswith("/home"):
        pyExec = "/mnt/gs" + sys.executable
        
    # Create `SLURMCluster` object using provided parameters
    out_files = os.path.join(slurm_wdir, "slurm-%j.out")
    cluster = SLURMCluster(cores=n_cores,
                           memory=mem_per_job,
                           processes=workers_per_job,
                           local_directory=slurm_wdir,
                           queue=partition,
                           name="spyswarm",
                           python=pyExec,
                           header_skip=["-t", "--mem"],
                           job_extra=["--output={}".format(out_files)])
                           # interface="asdf", # interface is set via `psutil.net_if_addrs()`
                           # job_extra=["--hint=nomultithread",
                           #            "--threads-per-core=1"]
                           
    # Compute total no. of workers and up-scale cluster accordingly
    total_workers = n_jobs * workers_per_job
    cluster.scale(total_workers)

    # Fire up waiting routine to avoid premature cluster setups
    if _cluster_waiter(cluster, funcName, total_workers, timeout, interactive):
        return
    
    # Kill a zombie cluster in non-interactive mode
    if not interactive and _count_running_workers(cluster) == 0:
        cluster.close()
        err = "SLURM jobs could not be started within given time-out " +\
              "interval of {0:d} seconds"
        raise TimeoutError(err.format(timeout))
    
    # Highlight how to connect to dask performance monitor
    print(successMsg.format(name=funcName, dash=cluster.dashboard_link))

    # If client was requested, return that instead of the created cluster
    if start_client:
        return Client(cluster)
    return cluster
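
The memory consolidation above boils down to a small normalization step: requests are converted to megabytes and capped at the partition limit. The following self-contained sketch (illustrative only, not part of Syncopy) reproduces it for a hypothetical "8GBS" partition:

# Normalize a `mem_per_job` request against a hypothetical "8GBS" partition
partition = "8GBS"
mem_per_job = "12GB"

# partition limit in MB (e.g., "8GBS" -> 8000 MB)
mem_lim = int(partition[:partition.find("GB")]) * 1000

# requested memory in MB
if "MB" in mem_per_job:
    mem_req = int(mem_per_job[:mem_per_job.find("MB")])
else:
    mem_req = int(round(float(mem_per_job[:mem_per_job.find("GB")]) * 1000))

# cap the request at the partition limit
if mem_req > mem_lim:
    mem_per_job = "{}MB".format(mem_lim)

print(mem_per_job)   # prints "8000MB"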
Example #23
def _layout_subplot_panels(npanels,
                           nrow=None,
                           ncol=None,
                           ndefault=5,
                           maxpanels=50):
    """
    Create space-optimal subplot grid given required number of panels
    
    Parameters
    ----------
    npanels : int
        Number of required subplot panels in figure
    nrow : int or None
        Required number of panel rows. Note, if both `nrow` and `ncol` are not `None`,
        then ``nrow * ncol >= npanels`` has to be satisfied, otherwise a 
        :class:`~syncopy.shared.errors.SPYValueError` is raised. 
    ncol : int or None
        Required number of panel columns. Note, if both `nrow` and `ncol` are not `None`,
        then ``nrow * ncol >= npanels`` has to be satisfied, otherwise a 
        :class:`~syncopy.shared.errors.SPYValueError` is raised. 
    ndefault: int
        Default number of panel columns for grid construction (only relevant if 
        both `nrow` and `ncol` are `None`). 
    maxpanels : int
        Maximally allowed number of subplot panels for which a grid is constructed 
        
    Returns
    -------
    nrow : int
        Number of rows of constructed subplot panel grid
    ncol : int
        Number of columns of constructed subplot panel grid
        
    Notes
    -----
    If both `nrow` and `ncol` are `None`, the constructed grid will have the 
    dimension `N` x `ndefault`, where `N` is chosen "optimally", i.e., the smallest
    integer that satisfies ``ndefault * N >= npanels``. 
    Note further that this is an auxiliary function intended purely for
    internal use. Thus, error-checking is only performed on potentially
    user-provided inputs (`nrow` and `ncol`).
    
    See also
    --------
    :func:`._setup_figure` : create and prepare figures for Syncopy visualizations
    
    Examples
    --------
    Create grid of default dimensions to hold eight panels
    
    >>> _layout_subplot_panels(8, ndefault=5)
    (2, 5)
    
    Create a grid that must have 4 rows
    
    >>> _layout_subplot_panels(8, nrow=4)
    (4, 2)
    
    Create a grid that must have 8 columns
    
    >>> _layout_subplot_panels(8, ncol=8)
    (1, 8)
    """

    # Abort if requested panel count is less than one or exceeds provided maximum
    try:
        scalar_parser(npanels,
                      varname="npanels",
                      ntype="int_like",
                      lims=[1, np.inf])
    except Exception as exc:
        raise exc
    if npanels > maxpanels:
        lgl = "a maximum of {} panels in total".format(maxpanels)
        raise SPYValueError(legal=lgl, actual=str(npanels), varname="npanels")

    # Row specification was provided; columns may or may not have been
    if nrow is not None:
        try:
            scalar_parser(nrow,
                          varname="nrow",
                          ntype="int_like",
                          lims=[1, np.inf])
        except Exception as exc:
            raise exc
        if ncol is None:
            ncol = np.ceil(npanels / nrow).astype(np.intp)

    # Column specification was provided; rows may or may not have been
    if ncol is not None:
        try:
            scalar_parser(ncol,
                          varname="ncol",
                          ntype="int_like",
                          lims=[1, np.inf])
        except Exception as exc:
            raise exc
        if nrow is None:
            nrow = np.ceil(npanels / ncol).astype(np.intp)

    # After the preparations above, this condition is *only* satisfied if both
    # `nrow` = `ncol` = `None` -> then use generic grid-layout
    if nrow is None:
        ncol = ndefault
        nrow = np.ceil(npanels / ncol).astype(np.intp)
        ncol = min(ncol, npanels)

    # Complain appropriately if requested no. of panels does not fit inside grid
    if nrow * ncol < npanels:
        lgl = "row- and column-specification of grid to fit all panels"
        act = "grid with {0} rows and {1} columns but {2} panels"
        raise SPYValueError(legal=lgl,
                            actual=act.format(nrow, ncol, npanels),
                            varname="nrow/ncol")

    # In case a grid was provided too big for the requested no. of panels (e.g.,
    # 8 panels in a 4 x 3 grid -> would fit in 3 x 3), just warn, don't crash
    if nrow * ncol - npanels >= ncol:
        msg = "Grid dimension ({0} rows x {1} columns) larger than necessary " +\
            "for {2} panels. "
        SPYWarning(msg.format(nrow, ncol, npanels))

    return nrow, ncol
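
The docstring examples cover the regular layouts; the warning branch at the bottom can be exercised as well. A 4 x 3 grid offers twelve slots for eight panels, leaving a full spare row, so a `SPYWarning` is emitted while the (valid) grid is still returned:

>>> _layout_subplot_panels(8, nrow=4, ncol=3)
(4, 3)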
Example #24
    def parallel_client_detector(*args, **kwargs):

        # Extract `parallel` keyword: if `parallel` is `False`, nothing happens
        parallel = kwargs.get("parallel")

        # Determine if multiple or a single object was supplied for processing
        objList = []
        argList = list(args)
        nTrials = 0
        for arg in args:
            if hasattr(arg, "trials"):
                objList.append(arg)
                nTrials = max(nTrials, len(arg.trials))
                argList.remove(arg)
        nObs = len(objList)

        # If parallel processing was requested but ACME is not installed, show
        # warning but keep going
        if parallel is True and not spy.__acme__:
            wrng = "ACME seems not to be installed on this system. " +\
                   "Parallel processing capabilities cannot be used. "
            SPYWarning(wrng)
            parallel = False

        # If ACME is available but `parallel` was not set *and* no active
        # dask client is found, set `parallel` to `False` (i.e., sequential
        # processing as default)
        clientRunning = False
        if parallel is None and spy.__acme__:
            try:
                dd.get_client()
                parallel = True
                clientRunning = True
            except ValueError:
                parallel = False

        # If only one object was supplied, a dask client is set up in `ComputationalRoutine`;
        # for multiple objects, count the total number of trials and start a "global"
        # computing client with `n_jobs = nTrials` in here (unless a client is already running)
        cleanup = False
        if parallel and not clientRunning and nObs > 1:
            msg = "Syncopy <{fname:s}> Launching parallel computing client " +\
                    "to process {no:d} objects..."
            print(msg.format(fname=func.__name__, no=nObs))
            client = esi_cluster_setup(n_jobs=nTrials, interactive=False)
            cleanup = True

        # Add/update `parallel` to/in keyword args
        kwargs["parallel"] = parallel

        # Process provided object(s)
        if nObs <= 1:
            results = func(*args, **kwargs)
        else:
            results = []
            for obj in objList:
                results.append(func(obj, *argList, **kwargs))

        # Kill "global" cluster started in here
        if cleanup:
            cluster_cleanup(client=client)

        return results
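
Note that `parallel_client_detector` is the inner wrapper of a decorator: it closes over `func`, the metafunction being decorated. The enclosing factory is not part of this excerpt; presumably it follows the standard closure pattern sketched below (the name `detect_parallel_client` is an assumption, not confirmed by the source):

import functools

def detect_parallel_client(func):   # assumed name of the enclosing decorator
    """Decorate `func` such that a validated `parallel` keyword is injected."""

    @functools.wraps(func)
    def parallel_client_detector(*args, **kwargs):
        # ... body as listed above ...
        return func(*args, **kwargs)

    return parallel_client_detector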
Example #25
    def compute(self, data, out, parallel=False, parallel_store=None,
                method=None, mem_thresh=0.5, log_dict=None, parallel_debug=False):
        """
        Central management and processing method

        Parameters
        ----------
        data : syncopy data object
           Syncopy data object to be processed (has to be the same object
           that was used by :meth:`initialize` in the pre-calculation
           dry-run).
        out : syncopy data object
           Empty object for holding results
        parallel : bool
           If `True`, processing is performed in parallel (i.e.,
           :meth:`computeFunction` is executed concurrently across trials).
           If `parallel` is `False`, :meth:`computeFunction` is executed
           consecutively trial after trial (i.e., the calculation realized
           in :meth:`computeFunction` is performed sequentially).
        parallel_store : None or bool
           Flag controlling saving mechanism. If `None`,
           ``parallel_store = parallel``, i.e., the compute-paradigm
           dictates the employed writing method. Thus, in case of parallel
           processing, results are written in a fully concurrent
           manner (each worker saves its own local result segment on disk as
           soon as it is done with its part of the computation). If `parallel_store`
           is `False` and `parallel` is `True` the processing result is saved
           sequentially using a mutex. If both `parallel` and `parallel_store`
           are `False` standard single-process HDF5 writing is employed for
           saving the result of the (sequential) computation.
        method : None or str
           If `None` the predefined methods :meth:`compute_parallel` or
           :meth:`compute_sequential` are used to control the actual computation
           (specifically, calling :meth:`computeFunction`) depending on whether
           `parallel` is `True` or `False`, respectively. If `method` is a
           string, it has to specify the name of an alternative (provided)
           class method (starting with the word `"compute_"`).
        mem_thresh : float
           Fraction of available memory required to perform computation. By
           default, the largest single trial result must not occupy more than
           50% (``mem_thresh = 0.5``) of available single-machine or worker
           memory (if `parallel` is `False` or `True`, respectively).
        log_dict : None or dict
           If `None`, the `log` property of `out` is populated with the
           keyword arguments used in :meth:`computeFunction`.
           Otherwise, `out`'s `log` property is filled with items taken
           from `log_dict`.
        parallel_debug : bool
           If `True`, concurrent processing is performed using a single-threaded
           scheduler, i.e., all parallel computing tasks are run in the current
           Python thread, permitting usage of tools like `pdb`/`ipdb`, `cProfile`
           and the like in :meth:`computeFunction`.
           Note that enabling parallel debugging effectively runs the given computation
           locally on the calling machine, thereby requiring sufficient memory and
           CPU capacity.

        Returns
        -------
        Nothing : None
           The result of the computation is available in `out` once
           :meth:`compute` terminated successfully.

        Notes
        -----
        This routine calls several other class methods to perform all necessary
        pre- and post-processing steps in a fully automatic manner without
        requiring any user-input. Specifically, the following class methods
        are invoked consecutively (in the given order):

        1. :meth:`preallocate_output` allocates a (virtual) HDF5 dataset
           of appropriate dimension for storing the result
        2. :meth:`compute_parallel` (or :meth:`compute_sequential`) performs
           the actual computation via concurrently (or sequentially) calling
           :meth:`computeFunction`
        3. :meth:`process_metadata` attaches all relevant meta-information to
           the result `out` after successful termination of the calculation
        4. :meth:`write_log` stores employed input arguments in `out.cfg`
           and `out.log` to reproduce all relevant computational steps that
           generated `out`.

        Parallel processing setup and input sanitization is performed by
        `ACME <https://github.com/esi-neuroscience/acme>`_. Specifically, all
        necessary information for parallel execution and storage of results is
        propagated to the :class:`~acme.ParallelMap` context manager.

        See also
        --------
        initialize : pre-calculation preparations
        preallocate_output : storage provisioning
        compute_parallel : concurrent computation using :meth:`computeFunction`
        compute_sequential : sequential computation using :meth:`computeFunction`
        process_metadata : management of meta-information
        write_log : log-entry organization
        acme.ParallelMap : concurrent execution of Python callables
        """

        # By default, use VDS storage for parallel computing
        if parallel_store is None:
            parallel_store = parallel

        # Do not spill trials on disk if they're supposed to be removed anyway
        if parallel_store and not self.keeptrials:
            msg = "trial-averaging only supports sequential writing!"
            SPYWarning(msg)
            parallel_store = False

        # Create HDF5 dataset of appropriate dimension
        self.preallocate_output(out, parallel_store=parallel_store)

        # Concurrent processing requires some additional prep-work...
        if parallel:

            # Construct list of dicts that will be passed on to workers: in the
            # parallel case, `trl_dat` is a dictionary!
            workerDicts = [{"hdr": self.hdr,
                            "keeptrials": self.keeptrials,
                            "infile": data.filename,
                            "indset": data.data.name,
                            "ingrid": self.sourceLayout[chk],
                            "sigrid": self.sourceSelectors[chk],
                            "fancy": self.useFancyIdx,
                            "vdsdir": self.virtualDatasetDir,
                            "outfile": self.outFileName.format(chk),
                            "outdset": self.tmpDsetName,
                            "outgrid": self.targetLayout[chk],
                            "outshape": self.targetShapes[chk],
                            "dtype": self.dtype} for chk in range(self.numCalls)]

            # If channel-block parallelization has been set up, positional args of
            # `computeFunction` need to be massaged: any list whose elements represent
            # trial-specific args needs to be expanded (so that each channel-block
            # per trial receives the correct number of pos. args)
            ArgV = list(self.argv)
            if self.numBlocksPerTrial > 1:
                for ak, arg in enumerate(self.argv):
                    if isinstance(arg, (list, tuple)):
                        if len(arg) == self.numTrials:
                            unrolled = chain.from_iterable([[ag] * self.numBlocksPerTrial for ag in arg])
                            if isinstance(arg, list):
                                ArgV[ak] = list(unrolled)
                            else:
                                ArgV[ak] = tuple(unrolled)
                    elif isinstance(arg, np.ndarray):
                        if len(arg.squeeze().shape) == 1 and arg.squeeze().size == self.numTrials:
                            ArgV[ak] = np.array(list(chain.from_iterable([[ag] * self.numBlocksPerTrial for ag in arg])))

            # Positional args for `computeFunction` consist of `trl_dat` + others
            # (stored in `ArgV`). Account for this when setting up `ParallelMap`
            if len(ArgV) == 0:
                inargs = (workerDicts, )
            else:
                inargs = (workerDicts, *ArgV)

            # Let ACME take care of argument distribution and memory checks: note
            # that `cfg` is trial-independent, i.e., we can simply throw it in here!
            self.pmap = ParallelMap(self.computeFunction,
                                    *inargs,
                                    n_inputs=self.numCalls,
                                    write_worker_results=False,
                                    write_pickle=False,
                                    partition="auto",
                                    n_jobs="auto",
                                    mem_per_job="auto",
                                    setup_timeout=60,
                                    setup_interactive=False,
                                    stop_client="auto",
                                    verbose=None,
                                    logfile=None,
                                    **self.cfg)

            # Edge-case correction: if, by chance, any array-like element `x` of `cfg`
            # satisfies `len(x) == numCalls`, `ParallelMap` attempts to tear `x` apart and
            # distribute its elements across workers. Prevent this!
            if self.numCalls > 1:
                for key, value in self.pmap.kwargv.items():
                    if isinstance(value, (list, tuple)):
                        if len(value) == self.numCalls:
                            self.pmap.kwargv[key] = [value]
                    elif isinstance(value, np.ndarray):
                        if len(value.squeeze().shape) == 1 and value.squeeze().size == self.numCalls:
                            self.pmap.kwargv[key] = [value]

            # Check if trials actually fit into memory before we start computation
            client = self.pmap.daemon.client
            if isinstance(client.cluster, (dd.LocalCluster, dj.SLURMCluster)):
                workerMem = [w["memory_limit"] for w in client.cluster.scheduler_info["workers"].values()]
                if len(workerMem) == 0:
                    raise SPYParallelError("no online workers found", client=client)
                workerMemMax = max(workerMem)
                if self.chunkMem >= mem_thresh * workerMemMax:
                    self.chunkMem /= 1024**3
                    workerMemMax /= 1024**3
                    msg = "Single-trial result sizes ({0:2.2f} GB) larger than available " +\
                        "worker memory ({1:2.2f} GB) currently not supported"
                    raise NotImplementedError(msg.format(self.chunkMem, workerMemMax))
            else:
                msg = "`ComputationalRoutine` only supports `LocalCluster` and " +\
                    "`SLURMCluster` dask cluster objects. Proceed with caution. "
                SPYWarning(msg)

            # Store provided debugging state
            self.parallelDebug = parallel_debug

        # For sequential processing, just ensure enough memory is available
        else:
            memSize = psutil.virtual_memory().available
            if self.chunkMem >= mem_thresh * memSize:
                self.chunkMem /= 1024**3
                memSize /= 1024**3
                msg = "Single-trial result sizes ({0:2.2f} GB) larger than available " +\
                      "memory ({1:2.2f} GB) currently not supported"
                raise NotImplementedError(msg.format(self.chunkMem, memSize))

        # The `method` keyword can be used to override the `parallel` flag
        if method is None:
            if parallel:
                computeMethod = self.compute_parallel
            else:
                computeMethod = self.compute_sequential
        else:
            computeMethod = getattr(self, "compute_" + method, None)

        # Ensure `data` is opened read-only to permit (potentially concurrent)
        # reading access to backing device on disk
        data.mode = "r"

        # Take care of `VirtualData` objects
        self.hdr = getattr(data, "hdr", None)

        # Perform actual computation
        computeMethod(data, out)

        # Reset data access mode
        data.mode = self.dataMode

        # Attach computed results to output object
        out.data = h5py.File(out.filename, mode="r+")[self.outDatasetName]

        # Store meta-data, write log and get outta here
        self.process_metadata(data, out)
        self.write_log(data, out, log_dict)
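
For reference, driving a `ComputationalRoutine` subclass follows the same two-step pattern used by `selectdata` above: a pre-calculation dry-run via `initialize`, then `compute`. A minimal sketch in which `MyRoutine`, `data`, and `out` are placeholders, not names from the source:

routine = MyRoutine()                        # hypothetical CR subclass
routine.initialize(data, out._stackingDim)   # dry-run: shapes + memory estimate
routine.compute(data, out,
                parallel=True,               # concurrent execution (via ACME)
                log_dict={"method": "demo"}) # custom entries for `out.log`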
Example #26
def singlepanelplot(self,
                    trials="all",
                    channels="all",
                    tapers="all",
                    toilim=None,
                    foilim=None,
                    avg_channels=True,
                    avg_tapers=True,
                    interp="spline36",
                    cmap="plasma",
                    vmin=None,
                    vmax=None,
                    title=None,
                    grid=None,
                    fig=None,
                    **kwargs):
    """
    Plot contents of :class:`~syncopy.SpectralData` objects using single-panel figure(s)
    
    Please refer to :func:`syncopy.singlepanelplot` for detailed usage information. 
    
    Examples
    --------
    Show frequency range 30-80 Hz of channel `'ecog_mua2'` averaged across 
    trials 2, 4, and 6:
    
    >>> fig = spy.singlepanelplot(freqData, trials=[2, 4, 6], channels=["ecog_mua2"],
    ...                           foilim=[30, 80])

    Overlay channel `'ecog_mua3'` with the same settings:

    >>> fig2 = spy.singlepanelplot(freqData, trials=[2, 4, 6], channels=['ecog_mua3'],
    ...                            foilim=[30, 80], fig=fig)

    Plot the time-frequency contents of channel `'ecog_mua1'` present in both objects
    `tfData1` and `tfData2` using the 'viridis' colormap, a plot grid, manually
    defined lower and upper color value limits, and no interpolation:

    >>> fig1, fig2 = spy.singlepanelplot(tfData1, tfData2, channels=['ecog_mua1'],
    ...                                  cmap="viridis", vmin=0.25, vmax=0.95,
    ...                                  interp=None, grid=True, overlay=False)
    
    Note that overlay plotting is **not** supported for time-frequency objects. 
 
    See also
    --------
    syncopy.singlepanelplot : visualize Syncopy data objects using single-panel plots
    """

    # Collect input arguments in dict `inputArgs` and process them
    inputArgs = locals()
    inputArgs.pop("self")
    (dimArrs, dimCounts, isTimeFrequency, complexConversion, pltDtype,
     dataLbl) = _prep_spectral_plots(self, "singlepanelplot", **inputArgs)
    (nTrials, nChan, nFreq, nTap) = dimCounts
    (trList, chArr, freqArr, tpArr) = dimArrs

    # If we're overlaying, ensure data and plot type match up
    if hasattr(fig, "objCount"):
        if isTimeFrequency:
            msg = "Overlay plotting not supported for time-frequency data"
            raise SPYError(msg)
        if not hasattr(fig, "spectralPlot"):
            lgl = "figure visualizing data from a Syncopy `SpectralData` object"
            act = "visualization of other Syncopy data"
            raise SPYValueError(legal=lgl, varname="fig", actual=act)
        if hasattr(fig, "multipanelplot"):
            lgl = "single-panel figure generated by `singleplot`"
            act = "multi-panel figure generated by `multipanelplot`"
            raise SPYValueError(legal=lgl, varname="fig", actual=act)

    # No time-frequency shenanigans: this is a simple power-spectrum (line-plot)
    if not isTimeFrequency:

        # Generic titles for figures
        overlayTitle = "Overlay of {} datasets"

        # Either create new figure or fetch existing
        if fig is None:
            fig, ax = _setup_figure(1,
                                    xLabel="Frequency [Hz]",
                                    yLabel=dataLbl,
                                    grid=grid)
            fig.spectralPlot = True
        else:
            ax, = fig.get_axes()

        # Average across channels, tapers or both using local helper func
        nTime = 1
        if not avg_channels and not avg_tapers and nTap > 1:
            msg = "Either channels or trials need to be averaged for single-panel plot"
            SPYWarning(msg)
            return
        if avg_channels and not avg_tapers:
            panelTitle = "{} tapers averaged across {} channels and {} trials".format(
                nTap, nChan, nTrials)
            pltArr = _compute_pltArr(self,
                                     nFreq,
                                     nTap,
                                     nTime,
                                     complexConversion,
                                     pltDtype,
                                     avg1="channel")
        if avg_tapers and not avg_channels:
            panelTitle = "{} channels averaged across {} tapers and {} trials".format(
                nChan, nTap, nTrials)
            pltArr = _compute_pltArr(self,
                                     nFreq,
                                     nChan,
                                     nTime,
                                     complexConversion,
                                     pltDtype,
                                     avg1="taper")
        if avg_tapers and avg_channels:
            panelTitle = "Average of {} channels, {} tapers and {} trials".format(
                nChan, nTap, nTrials)
            pltArr = _compute_pltArr(self,
                                     nFreq,
                                     1,
                                     nTime,
                                     complexConversion,
                                     pltDtype,
                                     avg1="taper",
                                     avg2="channel")

        # Perform the actual plotting
        ax.plot(freqArr,
                np.log10(pltArr),
                label=os.path.basename(self.filename))
        ax.set_xlim([freqArr[0], freqArr[-1]])

        # Set plot title depending on dataset overlay
        if fig.objCount == 0:
            if title is None:
                title = panelTitle
            ax.set_title(title, size=pltConfig["singleTitleSize"])
        else:
            handles, labels = ax.get_legend_handles_labels()
            ax.legend(handles, labels)
            if title is None:
                title = overlayTitle.format(len(handles))
            ax.set_title(title, size=pltConfig["singleTitleSize"])

    else:

        # For a single-panel TF visualization, we need to average across both tapers + channels
        if not avg_channels and (not avg_tapers and nTap > 1):
            msg = "Single-panel time-frequency visualization requires averaging " +\
                "across both tapers and channels"
            SPYWarning(msg)
            return

        # Compute (and verify) length of selected time intervals and assemble array for plotting
        panelTitle = "Average of {} channels, {} tapers and {} trials".format(
            nChan, nTap, nTrials)
        tLengths = _prep_toilim_avg(self)
        nTime = tLengths[0]
        pltArr = _compute_pltArr(self,
                                 nFreq,
                                 1,
                                 nTime,
                                 complexConversion,
                                 pltDtype,
                                 avg1="taper",
                                 avg2="channel")

        # Prepare figure
        fig, ax, cax = _setup_figure(1,
                                     xLabel="Time [s]",
                                     yLabel="Frequency [Hz]",
                                     include_colorbar=True,
                                     grid=grid)
        fig.spectralPlot = True

        # Use `imshow` to render array as image
        time = self.time[trList[0]][self._selection.time[0]]
        ax.imshow(pltArr,
                  origin="lower",
                  interpolation=interp,
                  cmap=cmap,
                  vmin=vmin,
                  vmax=vmax,
                  extent=(time[0], time[-1], freqArr[0], freqArr[-1]),
                  aspect="auto")
        cbar = _setup_colorbar(fig,
                               ax,
                               cax,
                               label=dataLbl.replace(" [dB]", ""))
        if title is None:
            title = panelTitle
        ax.set_title(title, size=pltConfig["singleTitleSize"])

    # Increment overlay-counter and draw figure
    fig.objCount += 1
    plt.draw()
    self._selection = None
    return fig