Ejemplo n.º 1
0
import numpy as np
from matplotlib import pyplot as plt

import dapper as dpr
from dapper.mods.LorenzUV.lorenz96 import LUV
from dapper.tools.utils import progbar
from dapper.tools.viz import setup_wrapping

# Sizes read from the model: the first nU state entries are plotted as "U",
# the remaining ones as "V" (with J used to rescale the V abscissa).
nU = LUV.nU
J = LUV.J

# Step size, and the number of steps such that K*dt equals 4.
dt = 0.005
K = int(4/dt)

# Single-step integrator, wrapped into a K-step recursion.
step_1 = dpr.with_rk4(LUV.dxdt, autonom=True)
step_K = dpr.with_recursion(step_1, prog='Simulating')

# Simulated trajectory; its first axis is indexed by step number below.
xx = step_K(LUV.x0, K, np.nan, dt)

# Abscissae and wrapping functions for each part of the state vector
ii, wrapU = setup_wrapping(nU)
jj, wrapV = setup_wrapping(nU*J)

# Animate linear: draw the last state once to create the line handles ...
plt.figure()
_last = xx[-1]
lhU = plt.plot(ii, wrapU(_last[:nU]), 'b', lw=3)[0]
lhV = plt.plot(jj/J, wrapV(_last[nU:]), 'g', lw=2)[0]

# ... then replay the trajectory by overwriting the y-data of both lines.
for k in progbar(range(K), 'Plotting'):
    _state = xx[k]
    for _line, _wrap, _part in (
        (lhU, wrapU, _state[:nU]),
        (lhV, wrapV, _state[nU:]),
    ):
        _line.set_ydata(_wrap(_part))
    plt.pause(0.001)
Ejemplo n.º 2
0
    def init(fignum, stats, key0, plot_u, E, P, **kwargs):
        """Create the "1d amplitude plot" figure and return its updater.

        Parameters
        ----------
        fignum :
            Figure identifier, forwarded to ``place.freshfig``.
        stats :
            Object from which ``xx``, ``yy``, ``mu``, ``spread``, ``xp.N``
            and ``HMM.tseq.tt`` are read.
        key0, plot_u, P :
            Accepted but not used by this function.
        E :
            Ensemble or ``None``. If it is ``None`` and no ``conf_mult`` was
            configured, ``conf_mult`` defaults to 2, i.e. confidence lines
            are drawn rather than ensemble members.
        **kwargs :
            Overrides for the entries of ``params_orig``. Keys that are not
            in ``params_orig`` are ignored.

        Returns
        -------
        update : callable
            ``update(key, E, P)`` with ``key = (k, ko, faus)``; it refreshes
            the plotted lines and the time text in place and returns ``None``.
        """
        xx, yy, mu = stats.xx, stats.yy, stats.mu

        # Set parameters (kwargs takes precedence over params_orig)
        p = DotDict(
            **{kw: kwargs.get(kw, val)
               for kw, val in params_orig.items()})

        # Select the state dimensions to plot (default: all of them).
        # `not dims` cannot be used directly: it is ambiguous for ndarrays
        # with several elements (and wrong for e.g. array([0])).
        dims = p.dims
        if isinstance(dims, np.ndarray):
            use_all_dims = dims.size == 0
        else:
            use_all_dims = not dims

        if use_all_dims:
            M = xx.shape[-1]
            p.dims = arange(M)
        else:
            # Must be an array: it is indexed by an integer array (the tick
            # positions) further down, which lists/tuples do not support.
            p.dims = np.asarray(dims)
            M = len(p.dims)

        # Make periodic wrapper
        ii, wrap = viz.setup_wrapping(M, p.periodicity)

        # Set up figure, axes
        fig, ax = place.freshfig(fignum, figsize=(8, 5))
        fig.suptitle("1d amplitude plot")

        # Nans (placeholder y-data until the first call to `update`)
        nan1 = wrap(nan * ones(M))

        if E is None and p.conf_mult is None:
            p.conf_mult = 2

        # Init plots
        if p.conf_mult:
            # Two confidence lines (only the first one carries the legend
            # label) plus the mean.
            lines_s = ax.plot(ii,
                              nan1,
                              "b-",
                              lw=1,
                              label=(str(p.conf_mult) + r'$\sigma$ conf'))
            lines_s += ax.plot(ii, nan1, "b-", lw=1)
            line_mu, = ax.plot(ii, nan1, 'b-', lw=2, label='DA mean')
        else:
            # One line per ensemble member (only the first one carries the
            # legend label).
            nanE = nan * ones((stats.xp.N, M))
            lines_E = ax.plot(ii,
                              wrap(nanE[0]),
                              **p.ens_props,
                              lw=1,
                              label='Ensemble')
            lines_E += ax.plot(ii, wrap(nanE[1:]).T, **p.ens_props, lw=1)

        # Truth, Obs
        (line_x, ) = ax.plot(ii, nan1, 'k-', lw=3, label='Truth')
        has_obs = p.obs_inds is not None
        if has_obs:
            p.obs_inds = np.asarray(p.obs_inds)
            (line_y, ) = ax.plot(p.obs_inds,
                                 nan * p.obs_inds,
                                 'g*',
                                 ms=5,
                                 label='Obs')

        # Tune plot
        ax.set_ylim(*viz.xtrema(xx))
        ax.set_xlim(viz.stretch(ii[0], ii[-1], 1))
        # Xticks
        xt = ax.get_xticks()
        xt = xt[abs(xt % 1) < 0.01].astype(int)  # Keep only the integer ticks
        xt = xt[xt >= 0]
        xt = xt[xt < len(p.dims)]
        ax.set_xticks(xt)
        # Label each tick with the state index it stands for
        ax.set_xticklabels(p.dims[xt])

        ax.set_xlabel('State index')
        ax.set_ylabel('Value')
        ax.legend(loc='upper right')

        text_t = ax.text(0.01,
                         0.01,
                         format_time(None, None, None),
                         transform=ax.transAxes,
                         family='monospace',
                         ha='left')

        # Init visibility (must come after legend):
        if has_obs:
            line_y.set_visible(False)

        def update(key, E, P):
            """Refresh the artists for ``key = (k, ko, faus)``; ``P`` is unused."""
            k, ko, faus = key

            if p.conf_mult:
                # Row 0: mean + conf_mult*spread; row 1: mean - conf_mult*spread
                sigma = mu[key] + p.conf_mult * stats.spread[key] * [[1], [-1]]
                lines_s[0].set_ydata(wrap(sigma[0, p.dims]))
                lines_s[1].set_ydata(wrap(sigma[1, p.dims]))
                line_mu.set_ydata(wrap(mu[key][p.dims]))
            else:
                for n, line in enumerate(lines_E):
                    line.set_ydata(wrap(E[n, p.dims]))
                update_alpha(key, stats, lines_E)

            line_x.set_ydata(wrap(xx[k, p.dims]))

            text_t.set_text(format_time(k, ko, stats.HMM.tseq.tt[k]))

            # Observations are shown when 'f' is in `faus` ...
            if 'f' in faus and has_obs:
                line_y.set_ydata(yy[ko])
                line_y.set_zorder(5)
                line_y.set_visible(True)

            # ... and hidden again when 'u' is in `faus`.
            if 'u' in faus and has_obs:
                line_y.set_visible(False)

        return update