Exemplo n.º 1
0
    def add_coefficients(self, data, **kwargs):
        """
            Add an individual CoefficientResult. The same data cannot be
            added twice: re-adding it redraws the existing entry with the
            newly provided arguments/style options instead.

            Parameters
            ----------
            data : CoefficientResult
                Added to the list of plotted elements.

            kwargs
                Keyword arguments passed to
                :obj:`matplotlib.axes.Axes.plot`. Use to customise the
                plots. If a `label` is set via `kwargs`, it will be used to
                overwrite the description of `data` in the meta file.
                If an alpha value or a linestyle is set, the shaded error
                region will be omitted.

            Raises
            ------
            ValueError
                If `data` is not a CoefficientResult, or if this handler
                already holds content of another type (e.g. a time series
                added via `add_ts()`).

            Example
            -------
            .. code-block:: python

                rk = mre.coefficients(mre.simulate_branching())

                mout = mre.OutputHandler()
                mout.add_coefficients(rk, color='C1', label='test')
            ..
        """
        # -- validate the input and the handler state --
        if not isinstance(data, CoefficientResult):
            log.exception("'data' needs to be of type CoefficientResult")
            raise ValueError
        if self.type is not None and self.type != 'correlation':
            log.exception(
                "It is not possible to 'add_coefficients()' to "
                "an OutputHandler containing a time series\n"
                "\tHave you previously called 'add_ts()' on this handler?")
            raise ValueError
        self.type = 'correlation'

        # column description used in the meta data
        desc = str(data.desc)

        # -- legend labels --
        if 'label' not in kwargs:
            # nothing user-provided: fall back to the data description
            if desc == '':
                label, labelerr = 'Data', 'Errors'
            else:
                label, labelerr = desc, desc + ' Errors'
        else:
            label = kwargs.get('label')
            if label == '':
                label = None
            if label is not None:
                # explicit custom label, also applied to the meta data
                label = str(label)
                labelerr = label + ' Errors'
                desc = label
            else:
                labelerr = None

        if desc != '':
            desc += ' '

        # dont put errors in the legend. this should become a user choice
        labelerr = ''

        # -- the first coefficients define the axis setup --
        if not self.rks:
            self.dt = data.dt
            self.dtunit = data.dtunit
            dtstr = ut._printeger(data.dt, 5)
            self.xlabel = 'steps[{}{}]'.format(dtstr, data.dtunit)
            self.ax.set_xlabel('k [{}{}]'.format(dtstr, data.dtunit))
            self.ax.set_ylabel('$r_{k}$')
            self.ax.set_title('Correlation')

        # -- duplicates replace their old entry; new data goes to meta --
        oldcurves = []
        if data in self.rks:
            pos = self.rks.index(data)
            log.warning(
                'Coefficients ({}/{}) '.format(self.rklabels[pos][0], label)
                + 'have already been added\n\tOverwriting with new style')
            oldcurves = self.rkcurves[pos]
            # the four bookkeeping lists are kept index-aligned
            for store in (self.rks, self.rklabels, self.rkcurves,
                          self.rkkwargs):
                del store[pos]
        else:
            inds = self.set_xdata(data.steps, dt=data.dt, dtunit=data.dtunit)
            columns = [('coefficients', data.coefficients)]
            if data.stderrs is not None:
                columns.append(('stderrs', data.stderrs))
            for suffix, values in columns:
                # pad to the (possibly shared) x-axis, gaps become nan
                padded = np.full(self.xdata.size, np.nan)
                padded[inds] = values
                self.ydata.append(padded)
                self.ylabels.append(desc + suffix)

        self.rks.append(data)
        self.rklabels.append([label, labelerr])
        self.rkcurves.append(oldcurves)
        self.rkkwargs.append(kwargs)

        # -- redraw everything: coefficients first, then fits --
        for rk in self.rks:
            self._render_coefficients(rk)
        for fit in self.fits:
            self._render_fit(fit)
Exemplo n.º 2
0
def simulate_branching(m,
                       a=None,
                       h=None,
                       length=10000,
                       numtrials=1,
                       subp=1,
                       seed='random'):
    """
        Simulates a branching process with Poisson input. Returns data
        in the trial structure.

        Per default, the function discards the first
        few time steps to produce stationary activity. If a
        `drive` is passed as ``h=0``, the recording starts instantly
        (and produces exponentially decaying activity).

        Parameters
        ----------
        m : float
            Branching parameter.

        a : float
            Stationary activity of the process.
            Only considered if no drive `h` is specified.

        h : ~numpy.array, optional
            Specify a custom drive (possibly changing) for every time step.
            If `h` is given, its length takes priority over the `length`
            parameter. If the first or only value of `h` is zero, the recording
            starts instantly with set activity `a` and the resulting timeseries
            will not be stationary in the beginning.

        length : int, optional
            Number of steps for the process, thereby sets the total length of
            the generated time series. Overwritten if drive `h` is set as an
            array.

        numtrials : int, optional
            Generate 'numtrials' trials. Default is 1.

        seed : int, optional
            Initialise the random number generator with a seed. Per default,
            ``seed='random'`` and the generator is seeded randomly (hence
            each call to `simulate_branching()` returns different results).
            ``seed=None`` skips (re)seeding.

        subp : float, optional
            Subsample the activity with the probability `subp` (calls
            `simulate_subsampling()` before returning).

        Returns
        -------
        : :class:`~numpy.ndarray`
            with `numtrials` time series, each containing
            `length` entries of activity.
            Per default, one trial is created with
            10000 measurements.
            If the activity exceeds the numeric limits of the Poisson
            generator (possible for ``m >= 1``), the series are truncated.

        Raises
        ------
        TypeError
            If neither the activity `a` nor the drive `h` is given.
        ValueError
            If `h` is neither a scalar nor a 1d array, or if the process
            would have no time steps (``length < 1`` or an empty `h`).
    """

    length = int(length)
    numtrials = int(numtrials)
    if h is None:
        if a is None:
            log.exception("Missing argument, either provide " +
                          "the activity 'a' or the drive 'h'")
            raise TypeError
        # constant drive that keeps the activity stationary at `a`
        h = np.full((length), a * (1 - m))
    else:
        if a is None:
            a = 0
        h = np.asarray(h)
        if h.size == 1:
            # any single value (scalar, [x], [[x]], ...) is a constant drive
            h = np.full((length), h.ravel()[0])
        elif len(h.shape) != 1:
            log.exception("Argument drive 'h' needs to be a float or 1d array")
            raise ValueError
        else:
            length = h.size

    # everything below reads h[0]; fail early with a clear message
    if length < 1:
        log.exception("Process needs at least one time step, " +
                      "check 'length' or the drive 'h'")
        raise ValueError

    log.debug('simulate_branching() seeding to {}'.format(seed))
    if seed is None:
        pass
    elif seed == 'random':
        np.random.seed(None)
    else:
        np.random.seed(seed)

    if h[0] == 0 and a != 0:
        log.debug('Skipping thermalization since initial h=0')
    if h[0] == 0 and a == 0:
        log.warning('activity a=0 and initial h=0')

    log.info('Generating branching process with m={}'.format(ut._printeger(m)))
    log.debug(
        '{:d} trials with {:d} time steps each\n'.format(numtrials, length) +
        'branching ratio m={}\n'.format(m) +
        '(initial) activity a={}\n'.format(a) +
        '(initial) drive rate h={}'.format(h[0]))

    A_t = np.zeros(shape=(numtrials, length), dtype=int)
    # one starting activity per trial
    a = np.ones_like(A_t[:, 0]) * a

    # if drive is zero, user would expect exp-decay of set activity
    # for m>1 we want exp-increase, else
    # avoid nonstationarity by discarding some steps
    if h[0] != 0 and m < 1:
        therm = max(100, int(length * 0.05))
        log.info('Setting up stationarity, {:d} steps'.format(therm))
        for _ in range(therm):
            a = np.random.poisson(lam=m * a + h[0])

    A_t[:, 0] = np.random.poisson(lam=m * a + h[0])
    for idx in range(1, length):
        try:
            # if m >= 1 activity may explode until this throws an error
            A_t[:, idx] = np.random.poisson(lam=m * A_t[:, idx - 1] + h[idx])
        except ValueError:
            log.debug('Exception passed for bp generation', exc_info=True)
            # keep only the steps that were generated successfully
            A_t = A_t[:, 0:idx]
            log.info(
                'Activity is exceeding numeric limits, canceling ' +
                'and resizing output from length={} to {}'.format(length, idx))
            break

    if subp != 1 and subp is not None:
        try:
            # do not change rng seed when calling this as nested, otherwise
            # bp with subs. is not reproducible even with given seed
            return simulate_subsampling(A_t, prob=subp, seed=None)
        except ValueError:
            # best effort: fall back to the fully sampled activity
            log.debug('Exception passed', exc_info=True)
    return A_t
Exemplo n.º 3
0
    def set_xdata(self, data=None, dt=1, dtunit=None):
        """
            Adjust xdata of the plot, matching the input value.
            Returns an array of indices matching the incoming indices to
            already present ones. Automatically called when adding content.

            If you want to customize the plot range, add all the content
            and use matplotlibs
            :obj:`~matplotlib.axes.Axes.set_xlim` function once at the end.
            (`set_xdata()` also manages meta data and can only *increase* the
            plot range)

            Parameters
            ----------
            data : ~numpy.array, optional
                x-values to plot the fits for. `data` does not need to be
                spaced equally but is assumed to be sorted. Anything
                convertible to a float array is accepted; the input is
                copied and never modified. If `None` and no x-values are set
                yet, the default axis ``0, 1, ..., 1500`` is used.

            dt : float
                check if existing data can be mapped to the new, provided `dt`
                or the other way around. `set_xdata()` pads
                undefined areas with `nan`.

            dtunit : str
                check if the new `dtunit` matches the one set previously. Any
                padding to match `dt` is only done if `dtunits` are the same,
                otherwise the plot falls back to using generic integer steps.

            Returns
            -------
            : :class:`~numpy.array`
                containing the indices where the `data` given to this function
                coincides with (possibly) already existing data that was
                added/plotted before.

            Example
            -------
            .. code-block:: python

                out = mre.OutputHandler()

                # 100 intervals of 2ms
                out.set_xdata(np.arange(0,100), dt=2, dtunit='ms')

                # increase resolution to 1ms for the first 50ms
                # this changes the existing structure in the meta data. also
                # the axis of `out` is not equally spaced anymore
                fiftyms = np.arange(0,50)
                out.set_xdata(fiftyms, dt=1, dtunit='ms')

                # data with larger intervals is less dense, the returned list
                # tells you which index in `out` belongs to every index
                # in `xdat`
                xdat = np.arange(0,50)
                ydat = np.random.random_sample(50)
                inds = out.set_xdata(xdat, dt=4, dtunit='ms')

                # to pad `ydat` to match the axis of `out`:
                temp = np.full(out.xdata.size, np.nan)
                temp[inds] = ydat

            ..
        """
        log.debug('OutputHandler.set_xdata()')
        # float copy: the caller's array must not be altered and the
        # in-place rescaling below (`xdata *= scd`) needs a float dtype.
        # `None` is kept so that the no-argument branches are reachable.
        xdata = None if data is None else np.array(data, dtype='float64')

        # matches the bracketed unit part of axis labels, e.g. '[2ms]'
        regex = r'\[.*?\]'

        def _mark_different_units():
            # replace the unit in the axis label and the meta-data label;
            # either label may still be unset (None) -> TypeError
            try:
                oldlabel = self.ax.get_xlabel()
                self.ax.set_xlabel(
                    re.sub(regex, '[different units]', oldlabel))
                self.xlabel = re.sub(regex, '[different units]',
                                     self.xlabel)
            except TypeError:
                log.debug('Exception passed', exc_info=True)

        # nothing set so far, no argument provided, return some default
        if self.xdata is None and xdata is None:
            # float like every other x-axis, so a later in-place rescale of
            # the existing axis (`self.xdata *= scd`) works
            self.xdata = np.arange(0, 1501, dtype='float64')
            self.dtunit = dtunit
            self.dt = dt
            return np.arange(0, 1501)

        # set x for the first time (xdata already is a private copy)
        if self.xdata is None:
            self.xdata = xdata
            self.dtunit = dtunit
            self.dt = dt
            return np.arange(0, self.xdata.size)

        # no new data provided, no need to call this
        elif xdata is None:
            log.debug("set_xdata() called without argument when " +
                      "xdata is already set. Nothing to adjust")
            return np.arange(0, self.xdata.size)

        # compare dtunits
        elif dtunit != self.dtunit and dtunit is not None:
            log.warning("'dtunit' does not match across added elements, " +
                        "adjusting axis label to '[different units]'")
            oldlabel = self.ax.get_xlabel()
            self.ax.set_xlabel(re.sub(regex, '[different units]', oldlabel))

        # set dtunit to new value if not assigned yet
        elif self.dtunit is None and dtunit is not None:
            self.dtunit = dtunit

        # new data matches old data, nothing to adjust
        if np.array_equal(self.xdata, xdata) and self.dt == dt:
            return np.arange(0, self.xdata.size)

        # compare timescales dt
        elif self.dt < dt:
            log.debug('dt does not match,')
            scd = dt / self.dt
            if float(scd).is_integer():
                log.debug(
                    'Changing axis values of new data (dt={}) '.format(dt) +
                    'to match higher resolution of ' +
                    'old xaxis (dt={})'.format(self.dt))
                xdata *= scd
            else:
                log.warning(
                    "New 'dt={}' is not an integer multiple of ".format(dt) +
                    "the previous 'dt={}'\n".format(self.dt) +
                    "\tPlotting with '[different units]'\n" +
                    "\tAs a workaround, try adding the data with the " +
                    "smallest 'dt' first")
                _mark_different_units()

        elif self.dt > dt:
            scd = self.dt / dt
            if float(scd).is_integer():
                log.debug(
                    "Changing 'dt' to new value 'dt={}'\n".format(dt) +
                    "\tAdjusting existing axis values (dt={})".format(self.dt))
                # rescale the existing axis to the finer resolution
                self.xdata *= scd
                self.dt = dt
                try:
                    oldlabel = self.ax.get_xlabel()
                    newlabel = str('[{}{}]'.format(ut._printeger(self.dt),
                                                   self.dtunit))
                    self.ax.set_xlabel(re.sub(regex, newlabel, oldlabel))
                    self.xlabel = re.sub(regex, newlabel, self.xlabel)
                except TypeError:
                    log.debug('Exception passed', exc_info=True)
            else:
                log.warning(
                    "old 'dt={}' is not an integer multiple ".format(self.dt) +
                    "of the new value 'dt={}'\n".format(dt) +
                    "\tPlotting with '[different units]'\n")
                _mark_different_units()

        # check if new is subset of old
        temp = np.union1d(self.xdata, xdata)
        if not np.array_equal(self.xdata, temp):
            log.debug('Rearranging present data')
            # move every stored column onto the enlarged axis, nan-padded
            _, indtemp = ut._intersecting_index(self.xdata, temp)
            self.xdata = temp
            for ydx, col in enumerate(self.ydata):
                coln = np.full(self.xdata.size, np.nan)
                coln[indtemp] = col
                self.ydata[ydx] = coln

        # return list of indices where to place new ydata in the existing
        # (higher-resolution) notation
        indold, indnew = ut._intersecting_index(self.xdata, xdata)
        assert (len(indold) == len(xdata))

        return indold
Exemplo n.º 4
0
def overview(src, rks, fits, **kwargs):
    """
        creates an A4 overview panel and returns the matplotlib figure element.
        No Argument checks are done

        Parameters
        ----------
        src : ~numpy.ndarray
            Input data in trial structure. Row ``i`` is plotted as trial
            ``i``; ``src.shape[1]`` is used as the number of time steps.
        rks : list of CoefficientResult
            Correlation coefficients. ``rks[0]`` provides `dt`, `dtunit`
            and the trial statistics shown in the second panel.
        fits : list
            Fit results, plotted together with `rks` and summarised in the
            bottom panel (their `fitfunc`, `tau`, `mre` and quantiles are
            read).
        kwargs
            ``numsegs`` (int, default 50): number of intervals for the
            running average if `src` holds a single trial.
            ``title`` (str): figure title.
            ``warning`` (str): red warning text at the bottom of the figure.

        Returns
        -------
        : :class:`matplotlib.figure.Figure`
    """

    ratios = None
    # A4 in inches, should check rc params in the future
    # matplotlib changes the figure size when modifying subplots
    topshift = 0.925
    fig, axes = plt.subplots(nrows=4,
                             figsize=(8.27, 11.69 * topshift),
                             gridspec_kw={"height_ratios": ratios})

    # avoid huge file size for many trials due to separate layers.
    # everything below 0 gets rastered to the same layer.
    axes[0].set_rasterization_zorder(0)

    # ------------------------------------------------------------------ #
    # Time Series
    # ------------------------------------------------------------------ #

    tsout = OutputHandler(ax=axes[0])
    tsout.add_ts(src, label='Trials')
    if src.shape[0] > 1:
        try:
            prevclr = plt.rcParams["axes.prop_cycle"].by_key()["color"][0]
        except Exception:
            prevclr = 'navy'
            log.debug('Exception getting color cycle', exc_info=True)
        tsout.add_ts(np.mean(src, axis=0), color=prevclr, label='Average')
    else:
        tsout.ax.legend().set_visible(False)

    tsout.ax.set_title('Time Series (Input Data)')
    tsout.ax.set_xlabel('t [{}{}]'.format(ut._printeger(rks[0].dt),
                                          rks[0].dtunit))

    # ------------------------------------------------------------------ #
    # Mean Trial Activity
    # ------------------------------------------------------------------ #

    if src.shape[0] > 1:
        # average trial activites as function of trial number
        taout = OutputHandler(rks[0].trialactivities, ax=axes[1])
        try:
            err1 = rks[0].trialactivities - np.sqrt(rks[0].trialvariances)
            err2 = rks[0].trialactivities + np.sqrt(rks[0].trialvariances)
            prevclr = plt.rcParams["axes.prop_cycle"].by_key()["color"][0]
            taout.ax.fill_between(np.arange(1, rks[0].numtrials + 1),
                                  err1,
                                  err2,
                                  color=prevclr,
                                  alpha=0.2)
        except Exception:
            log.debug('Exception adding std deviation to plot', exc_info=True)
        taout.ax.set_title('Mean Trial Activity and Std. Deviation')
        taout.ax.set_xlabel('Trial i')
        taout.ax.set_ylabel('$\\bar{A}_i$')
    else:
        # running average over the one trial to see if stays stationary
        numsegs = kwargs.get('numsegs', 50)
        ravg = np.zeros(numsegs)
        err1 = np.zeros(numsegs)
        err2 = np.zeros(numsegs)
        seglen = int(src.shape[1] / numsegs)
        for s in range(numsegs):
            segment = src[0][s * seglen:(s + 1) * seglen]
            mean = np.mean(segment)
            stddev = np.sqrt(np.var(segment))
            ravg[s] = mean
            err1[s] = mean - stddev
            err2[s] = mean + stddev

        taout = OutputHandler(ravg, ax=axes[1])
        try:
            prevclr = plt.rcParams["axes.prop_cycle"].by_key()["color"][0]
            taout.ax.fill_between(np.arange(1, numsegs + 1),
                                  err1,
                                  err2,
                                  color=prevclr,
                                  alpha=0.2)
        except Exception:
            log.debug('Exception adding std deviation to plot', exc_info=True)
        taout.ax.set_title(
            'Average Activity and Stddev for {} Intervals'.format(numsegs))
        taout.ax.set_xlabel('Interval i')
        taout.ax.set_ylabel('$\\bar{A}_i$')

    # ------------------------------------------------------------------ #
    # Coefficients and Fit results
    # ------------------------------------------------------------------ #

    cout = OutputHandler(rks + fits, ax=axes[2])

    # one legend entry (handle + multi-line text) per fit for the last panel
    fitcurves = []
    fitlabels = []
    for i, f in enumerate(cout.fits):
        fitcurves.append(cout.fitcurves[i][0])
        label = ut.math_from_doc(f.fitfunc, 5)
        label += '\n\n$\\tau={:.2f}${}\n'.format(f.tau, f.dtunit)
        if f.tauquantiles is not None:
            label += '$[{:.2f}:{:.2f}]$\n\n' \
                .format(f.tauquantiles[0], f.tauquantiles[-1])
        else:
            label += '\n\n'
        label += '$m={:.5f}$\n'.format(f.mre)
        if f.mrequantiles is not None:
            label += '$[{:.5f}:{:.5f}]$' \
                .format(f.mrequantiles[0], f.mrequantiles[-1])
        else:
            label += '\n'
        fitlabels.append(label)

    tempkwargs = {
        'ncol': len(fitlabels),
        'loc': 'upper center',
        'mode': 'expand',
        'frameon': True,
        'markerfirst': True,
        'fancybox': False,
        'borderaxespad': 0,
        'edgecolor': 'black',
        # hide handles
        'handlelength': 0,
        'handletextpad': 0,
    }
    try:
        axes[3].legend(fitcurves, fitlabels, **tempkwargs)
    except Exception:
        # presumably an older matplotlib without `edgecolor`; retry without
        log.debug('Exception passed', exc_info=True)
        del tempkwargs['edgecolor']
        axes[3].legend(fitcurves, fitlabels, **tempkwargs)

    legend = axes[3].get_legend()

    # hide handles. `Legend.legendHandles` was renamed to `legend_handles`
    # in matplotlib 3.7 and removed in 3.9, so support both spellings
    handles = getattr(legend, 'legend_handles', None)
    if handles is None:
        handles = legend.legendHandles
    for handle in handles:
        handle.set_visible(False)

    # center text
    for t in legend.texts:
        t.set_multialignment('center')

    # apply style and fill legend
    legend.get_frame().set_linewidth(0.5)
    axes[3].axis('off')
    axes[3].set_title('Fitresults\n[$12.5\\%$:$87.5\\%$]')
    for ax in axes:
        ax.xaxis.set_tick_params(width=0.5)
        ax.yaxis.set_tick_params(width=0.5)
        for spine in ax.spines:
            ax.spines[spine].set_linewidth(0.5)

    fig.tight_layout(h_pad=2.0)
    plt.subplots_adjust(top=topshift)
    title = kwargs.get('title')
    if title is not None and title != '':
        fig.suptitle(title + '\n', fontsize=14)

    warning = kwargs.get('warning')
    if warning is not None:
        fig.text(.5,
                 .01,
                 u'\u26A0 {}'.format(warning),
                 fontsize=13,
                 horizontalalignment='center',
                 color='red')

    fig.text(.995,
             .005,
             'v{}'.format(__version__),
             fontsize=8,
             horizontalalignment='right',
             color='silver')
    return fig
Exemplo n.º 5
0
    def save_meta(self, fname=''):
        """
            Saves only the details/source used to create the plot. It is
            recommended to call this manually, if you decide to save
            the plots yourself or when you want only the fit results.

            The result is written with :func:`numpy.savetxt` to
            ``<fname>.tsv`` (tab separated). The header lists every fit
            (m, tau, quantiles, fit range, function and parameters) and the
            column labels, followed by the x-axis and all stored data
            columns.

            Parameters
            ----------
            fname : str, optional
                Path where to save, without file extension. Defaults to "./mre".
                The enclosing directory is created if it does not exist.
        """
        fname = str(fname)
        if fname == '':
            fname = './mre'

        # try creating enclosing dir if not existing
        tempdir = os.path.abspath(os.path.expanduser(fname + "/../"))
        os.makedirs(tempdir, exist_ok=True)

        fname = os.path.expanduser(fname)

        log.info('Saving meta to {}.tsv'.format(fname))
        sep = '{}\n'.format('-' * 72)
        # fits
        hdr = 'Mr. Estimator v{}\n'.format(__version__)
        # best effort: incomplete fit details must not prevent the data
        # columns from being written
        try:
            for fdx, fit in enumerate(self.fits):
                hdr += sep
                hdr += 'legendlabel: ' + str(self.fitlabels[fdx]) + '\n'
                hdr += sep
                if fit.desc != '':
                    hdr += 'description: ' + str(fit.desc) + '\n'
                hdr += 'm = {}\ntau = {} [{}]\n' \
                    .format(fit.mre, fit.tau, fit.dtunit)
                if fit.quantiles is not None:
                    hdr += 'quantiles | tau [{}] | m:\n'.format(fit.dtunit)
                    for i, q in enumerate(fit.quantiles):
                        hdr += '{:6.3f} | '.format(q)
                        hdr += '{:8.3f} | '.format(fit.tauquantiles[i])
                        hdr += '{:8.8f}\n'.format(fit.mrequantiles[i])
                    hdr += '\n'
                hdr += 'fitrange: {} <= k <= {} [{}{}]\n'.format(
                    fit.steps[0], fit.steps[-1], ut._printeger(fit.dt),
                    fit.dtunit)
                hdr += 'function: ' + ut.math_from_doc(fit.fitfunc) + '\n'
                # skip the first argument of the fit function; the remaining
                # names are matched to the entries of `popt` by position
                parname = list(inspect.signature(fit.fitfunc).parameters)[1:]
                # `default=0` and `str()`: a fit function without free
                # parameters or an unset dtunit must not abort the header
                # for this and all following fits
                parlen = max(map(len, parname), default=0)
                unitlen = len(str(fit.dtunit))
                for pdx, par in enumerate(fit.popt):
                    unit = ''
                    if parname[pdx] == 'nu':
                        unit += '[1/{}]'.format(fit.dtunit)
                    elif 'tau' in parname[pdx]:
                        unit += '[{}]'.format(fit.dtunit)
                    hdr += '\t{: <{width}}'.format(parname[pdx] + ' ' + unit,
                                                   width=parlen + 5 + unitlen)
                    hdr += ' = {}\n'.format(par)
                hdr += '\n'
        except Exception:
            log.debug('Exception passed', exc_info=True)

        # rks / ts
        labels = ''
        dat = []
        if self.ydata is not None and len(self.ydata) != 0:
            hdr += sep + 'Data\n' + sep
            # column labels are numbered from 1, the x-axis comes first
            labels += '1_' + self.xlabel
            for ldx, label in enumerate(self.ylabels, start=2):
                labels += '\t' + str(ldx) + '_' + label
            labels = labels.replace(' ', '_')
            dat = np.vstack((self.xdata, np.asarray(self.ydata)))
        np.savetxt(fname + '.tsv',
                   np.transpose(dat),
                   delimiter='\t',
                   header=hdr + labels)
Exemplo n.º 6
0
def overview(src, rks, fits, **kwargs):
    """
        Create an overview panel (A5 portrait, 5.8 x 8.3 inches) and return
        the matplotlib figure element.

        The panel has five stacked rows: the input time series, the mean
        activity per trial (or, for a single trial, per interval), the
        coefficients together with the fits, a legend summarising the fit
        results, and an (almost) zero-height row reserved for version and
        warning text.

        No argument checks are done.

        Parameters
        ----------
        src : array
            Input data, indexed as ``src[trial, time]``; ``src.shape[0]`` is
            used as the number of trials.
        rks : list
            Coefficient results. Only ``rks[0]`` is queried for ``dt``,
            ``dtunit`` and the per-trial statistics.
        fits : list
            Fit results, plotted together with ``rks``.

        Other Parameters
        ----------------
        numsegs : int, optional
            Number of intervals for the running average when there is only
            one trial. Default 50, clamped to ``[1, src.shape[1]]``.
        title : str, optional
            Figure suptitle. Omitted when ``None`` or empty.
        warning : str, optional
            Warning text printed in red at the bottom of the figure.

        Returns
        -------
        matplotlib.figure.Figure
    """

    ratios = np.ones(5)
    # the last row only hosts the version/warning text, keep it tiny
    ratios[4] = 0.0001
    # A5 in inches, should check rc params in the future
    # matplotlib changes the figure size when modifying subplots
    fig, axes = plt.subplots(nrows=5,
                             figsize=(5.8, 8.3),
                             gridspec_kw={"height_ratios": ratios})

    # avoid huge file size for many trials due to separate layers.
    # everything below 0 gets rastered to the same layer.
    axes[0].set_rasterization_zorder(0)

    # first colour of the active property cycle, used for averages/shading
    try:
        prevclr = plt.rcParams["axes.prop_cycle"].by_key()["color"][0]
    except Exception:
        prevclr = "navy"
        log.debug("Exception getting color cycle", exc_info=True)

    # ------------------------------------------------------------------ #
    # Time Series
    # ------------------------------------------------------------------ #

    tsout = OutputHandler(ax=axes[0])
    tsout.add_ts(src, label="Trials")
    if src.shape[0] > 1:
        tsout.add_ts(np.mean(src, axis=0), color=prevclr, label="Average")
    else:
        tsout.ax.legend().set_visible(False)

    tsout.ax.set_title("Time Series", fontweight="bold", loc="center")
    tsout.ax.set_title("(Input Data)",
                       fontsize="medium",
                       color="#646464",
                       loc="right")
    tsout.ax.set_xlabel("t [{}{}]".format(
        ut._printeger(rks[0].dt) + " " if rks[0].dt != 1 else "",
        rks[0].dtunit))

    # ------------------------------------------------------------------ #
    # Mean Trial Activity
    # ------------------------------------------------------------------ #

    if src.shape[0] > 1:
        # average trial activites as function of trial number
        taout = OutputHandler(rks[0].trialactivities, ax=axes[1])
        try:
            stddev = np.sqrt(rks[0].trialvariances)
            taout.ax.fill_between(np.arange(1, rks[0].numtrials + 1),
                                  rks[0].trialactivities - stddev,
                                  rks[0].trialactivities + stddev,
                                  color=prevclr,
                                  alpha=0.2)
        except Exception:
            log.debug("Exception adding std deviation to plot", exc_info=True)
        taout.ax.set_title("Mean Trial Activity and Std. Deviation",
                           fontweight="bold")
        taout.ax.set_xlabel("Trial i")
        taout.ax.set_ylabel("$\\bar{A}_i$")
    else:
        # running average over the one trial to see if stays stationary.
        # the kwarg must be looked up by its string key.
        numsegs = int(kwargs.get("numsegs", 50))
        # need at least one sample per interval, otherwise every segment
        # is empty and the mean/std become NaN
        numsegs = max(1, min(numsegs, src.shape[1]))
        seglen = int(src.shape[1] / numsegs)
        ravg = np.zeros(numsegs)
        stddev = np.zeros(numsegs)
        for s in range(numsegs):
            segment = src[0][s * seglen:(s + 1) * seglen]
            ravg[s] = np.mean(segment)
            stddev[s] = np.std(segment)
        err1 = ravg - stddev
        err2 = ravg + stddev

        taout = OutputHandler(ravg, ax=axes[1])
        try:
            taout.ax.fill_between(np.arange(1, numsegs + 1),
                                  err1,
                                  err2,
                                  color=prevclr,
                                  alpha=0.2)
        except Exception:
            log.debug("Exception adding std deviation to plot", exc_info=True)
        taout.ax.set_title(
            "Average Activity and Stddev for {} Intervals".format(numsegs),
            fontweight="bold",
        )
        taout.ax.set_xlabel("Interval i")
        taout.ax.set_ylabel("$\\bar{A}_i$")

    # ------------------------------------------------------------------ #
    # Coefficients and Fit results
    # ------------------------------------------------------------------ #

    cout = OutputHandler(rks + fits, ax=axes[2])

    fitcurves = []
    fitlabels = []
    for i, f in enumerate(cout.fits):
        fitcurves.append(cout.fitcurves[i][0])
        label = ut.math_from_doc(f.fitfunc, 5)
        label += "\n\n$\\tau={:.2f}${}\n".format(f.tau, f.dtunit)
        if f.tauquantiles is not None:
            label += "$[{:.2f}:{:.2f}]$\n\n".format(f.tauquantiles[0],
                                                    f.tauquantiles[-1])
        else:
            label += "\n\n"
        label += "$m={:.5f}$\n".format(f.mre)
        if f.mrequantiles is not None:
            label += "$[{:.5f}:{:.5f}]$".format(f.mrequantiles[0],
                                                f.mrequantiles[-1])
        else:
            label += "\n"
        fitlabels.append(label)

    tempkwargs = {
        "ncol": len(fitlabels),
        "loc": "upper center",
        "mode": "expand",
        "frameon": True,
        "markerfirst": True,
        "fancybox": False,
        "borderaxespad": 0,
        "edgecolor": "black",
        # hide handles
        "handlelength": 0,
        "handletextpad": 0,
    }
    try:
        axes[3].legend(fitcurves, fitlabels, **tempkwargs)
    except Exception:
        # older matplotlib versions do not accept `edgecolor`
        log.debug("Exception passed", exc_info=True)
        del tempkwargs["edgecolor"]
        axes[3].legend(fitcurves, fitlabels, **tempkwargs)

    legend = axes[3].get_legend()

    # hide handles. `legendHandles` was renamed to `legend_handles` in
    # matplotlib 3.7 and removed in 3.9, so support both spellings.
    handles = getattr(legend, "legend_handles", None)
    if handles is None:
        handles = legend.legendHandles
    for handle in handles:
        handle.set_visible(False)

    # center text
    for t in legend.texts:
        t.set_multialignment("center")

    # apply stile and fill legend
    legend.get_frame().set_linewidth(0.5)
    axes[3].axis("off")
    axes[3].set_title(
        "Fitresults",
        fontweight="bold",
        loc="center",
    )
    axes[3].set_title(
        " (with CI: [$12.5\\%:87.5\\%$])",
        color="#646464",
        fontsize="medium",
        loc="right",
    )
    for a in axes:
        a.xaxis.set_tick_params(width=0.5)
        a.yaxis.set_tick_params(width=0.5)
        for s in a.spines:
            a.spines[s].set_linewidth(0.5)

    # dummy axes for version and warnings
    axes[4].axis("off")

    fig.tight_layout()
    plt.subplots_adjust(hspace=0.8, top=0.95, bottom=0.0, left=0.1, right=0.99)

    title = kwargs.get("title")
    if title is not None and title != "":
        fig.suptitle(title, fontsize=14)
        # make room for the suptitle
        plt.subplots_adjust(top=0.91)

    warning = kwargs.get("warning")
    if warning is not None:
        fig.text(0.5,
                 0.01,
                 "\u26A0 {}".format(warning),
                 fontsize=13,
                 horizontalalignment="center",
                 color="red")

    fig.text(
        0.995,
        0.005,
        "v{}".format(__version__),
        fontsize="small",
        horizontalalignment="right",
        color="#646464",
    )

    return fig
Exemplo n.º 7
0
    def save_meta(self, fname=""):
        """
            Saves only the details/source used to create the plot. It is
            recommended to call this manually, if you decide to save
            the plots yourself or when you want only the fit results.

            The file header lists, for every fit, its legend label,
            description, m, tau, optional quantiles, fit range, fit function
            and fitted parameters. If y data are present, x and y data are
            written as tab-separated columns below the header.

            A fit whose attributes cannot be formatted is skipped (logged at
            debug level); the remaining fits are still written.

            Parameters
            ----------
            fname : str, optional
                Path where to save, without file extension. Defaults to "./mre"
        """
        if not isinstance(fname, str):
            fname = str(fname)
        if fname == "":
            fname = "./mre"

        fname = os.path.expanduser(fname)

        # try creating enclosing dir if not existing
        os.makedirs(os.path.dirname(os.path.abspath(fname)), exist_ok=True)

        log.info("Saving meta to {}.tsv".format(fname))
        sep = "{}\n".format("-" * 72)
        hdr = "Mr. Estimator v{}\n".format(__version__)

        # fits. Each fit is formatted on its own so that one malformed fit
        # neither drops the following fits nor leaves a half-written block.
        for fdx, fit in enumerate(self.fits or []):
            try:
                parts = [
                    sep,
                    "legendlabel: " + str(self.fitlabels[fdx]) + "\n",
                    sep,
                ]
                if fit.desc != "":
                    parts.append("description: " + str(fit.desc) + "\n")
                parts.append("m = {}\ntau = {} [{}]\n".format(
                    fit.mre, fit.tau, fit.dtunit))
                if fit.quantiles is not None:
                    parts.append(
                        "quantiles | tau [{}] | m:\n".format(fit.dtunit))
                    for q, tq, mq in zip(fit.quantiles, fit.tauquantiles,
                                         fit.mrequantiles):
                        parts.append(
                            "{:6.3f} | {:8.3f} | {:8.8f}\n".format(q, tq, mq))
                    parts.append("\n")
                parts.append("fitrange: {} <= k <= {} [{} {}]\n".format(
                    fit.steps[0], fit.steps[-1], ut._printeger(fit.dt),
                    fit.dtunit))
                parts.append(
                    "function: " + ut.math_from_doc(fit.fitfunc) + "\n")

                # the first argument of the fit function is the independent
                # variable, the remaining ones correspond to `popt`
                parname = list(inspect.signature(fit.fitfunc).parameters)[1:]
                parlen = max((len(p) for p in parname), default=0)
                width = parlen + 5 + len(fit.dtunit)
                for name, par in zip(parname, fit.popt):
                    if name == "nu":
                        unit = "[1/{}]".format(fit.dtunit)
                    elif "tau" in name:
                        unit = "[{}]".format(fit.dtunit)
                    else:
                        unit = ""
                    parts.append("\t{: <{width}}".format(name + " " + unit,
                                                         width=width))
                    parts.append(" = {}\n".format(par))
                parts.append("\n")
            except Exception:
                log.debug("Exception passed", exc_info=True)
                continue
            hdr += "".join(parts)

        # rks / ts
        labels = ""
        dat = []
        if self.ydata is not None and len(self.ydata) != 0:
            hdr += sep + "Data\n" + sep
            # column labels are numbered from 1 (x column) and must not
            # contain spaces
            columns = [self.xlabel] + list(self.ylabels)
            labels = "\t".join(
                "{}_{}".format(idx, lbl)
                for idx, lbl in enumerate(columns, start=1)).replace(" ", "_")
            dat = np.vstack((self.xdata, np.asarray(self.ydata)))
        np.savetxt(fname + ".tsv",
                   np.transpose(dat),
                   delimiter="\t",
                   header=hdr + labels)