Esempio n. 1
0
        def update(key, E, P):
            """Roll the data buffers forward by one step and redraw each axis.

            Relies on names from the enclosing scope (``d``, ``h``, ``axs``,
            ``DimsX``, ``DimsY``, ``iiY``, ``Ny``, ``xx``, ``yy``, ``mu``,
            ``std``, ``chrono``, ``stats``, ``p``, ``plot_u``, ``has_w``,
            ``T_lag``), none of which are defined in this function.

            Args:
                key: Unpacked as ``(k, kObs, faus)``. ``k`` indexes ``xx`` and
                    ``chrono.tt``; ``kObs`` indexes ``yy`` and may be ``None``.
                E: Ensemble passed to ``duplicate_with_blanks_for_resampled``;
                    its length sets how many member lines get updated.
                P: Unused here.

            Returns:
                None.
            """
            k, kObs, _faus = key

            EE = duplicate_with_blanks_for_resampled(E, DimsX, key, has_w)

            # Roll data array
            ind = k if plot_u else kObs
            for Ens in EE:  # If E is duplicated, so must the others be.
                if 'E' in d:
                    d.E.insert(ind, Ens)
                if 'mu' in d:
                    d.mu.insert(ind, mu[key][DimsX])
                if 's' in d:
                    # Rows are mean + std and mean - std.
                    d.s.insert(ind,
                               mu[key][DimsX] + [[1], [-1]] * std[key][DimsX])
                d.t.insert(ind, chrono.tt[k])
                # Without an observation at this step, insert a row of nan's.
                if kObs is not None:
                    d.y.insert(ind, yy[kObs, iiY])
                else:
                    d.y.insert(ind, nan * ones(Ny))
                d.x.insert(ind, xx[k, DimsX])

            # Update graphs
            for ix, (_m, iy, ax) in enumerate(zip(DimsX, DimsY, axs)):
                sliding_xlim(ax, d.t, T_lag, True)
                h.x[ix].set_data(d.t, d.x[:, ix])
                if iy is not None:
                    h.y[iy].set_data(d.t, d.y[:, iy])
                if 'mu' in d:
                    h.mu[ix].set_data(d.t, d.mu[:, ix])
                if 's' in d:
                    for b in (0, 1):
                        h.s[ix][b].set_data(d.t, d.s[:, b, ix])
                if 'E' in d:
                    for n in range(len(E)):
                        h.E[ix][n].set_data(d.t, d.E[:, n, ix])
                    update_alpha(key, stats, h.E[ix])

                # TODO 3: fixup. This might be slow?
                # In any case, it is very far from tested.
                # Also, relim'iting all of the time is distracting.
                # Use d_ylim?
                if 'E' in d:
                    lims = d.E
                elif 'mu' in d:
                    lims = d.mu
                else:
                    # Neither ensemble nor mean is tracked, so there is
                    # nothing to derive y-limits from; leave the axis as is
                    # (previously this fell through to an unbound ``lims``).
                    continue
                lims = np.array(viz.xtrema(lims[..., ix]))
                if lims[0] == lims[1]:
                    # Avoid a zero-height axis.
                    lims += [-.5, +.5]
                ax.set_ylim(*viz.stretch(*lims, 1 / p.zoomy))

            return
Esempio n. 2
0
    def init(fignum, stats, key0, plot_u, E, P, **kwargs):
        """Create the "Phase space trajectories" figure; return its updater.

        Relies on ``M``, ``is_3d`` and ``params_orig`` from the enclosing
        scope (not defined in this function).

        Args:
            fignum: Passed to ``place.freshfig``.
            stats: Provides ``xx``, ``yy``, ``mu``, ``spread`` and
                ``HMM.tseq``; a ``w`` attribute enables the extra buffer
                length used for resampling blanks.
            key0: Unused here.
            plot_u: Selects whether the buffers are indexed by ``k`` or ``ko``.
            E: Ensemble, or ``None`` to skip the ensemble tails/scatter.
            P: If not ``None``, the mean tail/scatter is drawn.
            **kwargs: Overrides for the entries of ``params_orig``.

        Returns:
            ``update(key, E, P)``, which rolls the buffers and redraws.
        """
        xx, yy, mu, _, tseq = \
            stats.xx, stats.yy, stats.mu, stats.spread, stats.HMM.tseq

        # Set parameters (kwargs takes precedence over params_orig)
        p = DotDict(
            **{kw: kwargs.get(kw, val)
               for kw, val in params_orig.items()})

        # Lag settings:
        has_w = hasattr(stats, 'w')
        if p.Tplot == 0:
            K_plot = 1
        else:
            _T_lag, K_lag, a_lag = validate_lag(p.Tplot, tseq)
            K_plot = comp_K_plot(K_lag, a_lag, plot_u)
            # Extend K_plot further for adding blanks in resampling (PartFilt):
            if has_w:
                K_plot += a_lag

        # Dimension settings
        if not p.dims:
            p.dims = arange(M)
        if not p.labels:
            p.labels = ["$x_%d$" % dim for dim in p.dims]
        assert len(p.dims) == M

        # Set up figure, axes
        fig, _ = place.freshfig(fignum, figsize=(5, 5))
        ax = plt.subplot(111, projection='3d' if is_3d else None)
        ax.set_facecolor('w')
        ax.set_title("Phase space trajectories")
        # Tune plot
        for ind, (label, dim, axis_name) in enumerate(
                zip(p.labels, p.dims, "xyz")):
            viz.set_ilim(ax, ind,
                         *viz.stretch(*viz.xtrema(xx[:, dim]), 1 / p.zoom))
            # Look the setter up by name instead of eval'ing source text:
            # eval breaks on labels containing a quote and interprets
            # backslash escapes in LaTeX labels (e.g. "\t" in "$\theta$").
            getattr(ax, "set_%slabel" % axis_name)(label)

        # Allocate
        d = DotDict()  # data arrays
        h = DotDict()  # plot handles
        s = DotDict()  # scatter handles
        if E is not None:
            d.E = RollingArray((K_plot, len(E), M))
            h.E = []
        if P is not None:
            d.mu = RollingArray((K_plot, M))
        d.x = RollingArray((K_plot, M))
        if list(p.obs_inds) == list(p.dims):
            d.y = RollingArray((K_plot, M))

        # Plot tails (invisible coz everything here is nan, for the moment).
        if 'E' in d:
            h.E += [
                ax.plot(*xn, **p.ens_props)[0]
                for xn in np.transpose(d.E, [1, 2, 0])
            ]
        if 'mu' in d:
            h.mu = ax.plot(*d.mu.T, 'b', lw=2)[0]
        h.x = ax.plot(*d.x.T, 'k', lw=3)[0]
        if 'y' in d:
            h.y = ax.plot(*d.y.T, 'g*', ms=14)[0]

        # Scatter. NB: don't init with nan's coz it's buggy
        # (wrt. get_color() and _offsets3d) since mpl 3.1.
        if 'E' in d:
            s.E = ax.scatter(*E.T[p.dims],
                             s=3**2,
                             c=[hn.get_color() for hn in h.E])
        if 'mu' in d:
            s.mu = ax.scatter(*ones(M), s=8**2, c=[h.mu.get_color()])
        s.x = ax.scatter(*ones(M),
                         s=14**2,
                         c=[h.x.get_color()],
                         marker=(5, 1),
                         zorder=99)

        def update(key, E, P):
            """Insert the data for ``key`` into the buffers and redraw."""
            k, ko, faus = key
            show_y = 'y' in d and ko is not None

            def update_tail(handle, newdata):
                # Line handles take x/y via set_data; z is set separately.
                handle.set_data(newdata[:, 0], newdata[:, 1])
                if is_3d:
                    handle.set_3d_properties(newdata[:, 2])

            def update_sctr(handle, newdata):
                if is_3d:
                    handle._offsets3d = juggle_axes(*newdata.T, 'z')
                else:
                    handle.set_offsets(newdata)

            EE = duplicate_with_blanks_for_resampled(E, p.dims, key, has_w)

            # Roll data array
            ind = k if plot_u else ko
            for Ens in EE:  # If E is duplicated, so must the others be.
                if 'E' in d:
                    d.E.insert(ind, Ens)
                d.x.insert(ind, xx[k, p.dims])
                if 'y' in d:
                    d.y.insert(ind, yy[ko, :] if show_y else nan * ones(M))
                if 'mu' in d:
                    d.mu.insert(ind, mu[key][p.dims])

            # Update graph
            update_sctr(s.x, d.x[[-1]])
            update_tail(h.x, d.x)
            if 'y' in d:
                update_tail(h.y, d.y)
            if 'mu' in d:
                update_sctr(s.mu, d.mu[[-1]])
                update_tail(h.mu, d.mu)
            elif 'E' in d:
                # Ensemble artists only exist when E was given to init();
                # checking avoids touching s.E / h.E when it was not.
                update_sctr(s.E, d.E[-1])
                for n in range(len(E)):
                    update_tail(h.E[n], d.E[:, n, :])
                update_alpha(key, stats, h.E, s.E)

            return

        return update
Esempio n. 3
0
    def init(fignum, stats, key0, plot_u, E, P, **kwargs):
        """Create the "1d amplitude plot" figure and return its updater.

        Relies on ``params_orig`` from the enclosing scope.

        Args:
            fignum: Passed to ``place.freshfig``.
            stats: Provides ``xx``, ``yy``, ``mu``, ``spread``, ``xp.N`` and
                ``HMM.tseq.tt``.
            key0: Unused here.
            plot_u: Unused here.
            E: Ensemble; if ``None`` and ``conf_mult`` is unset, the
                confidence band is drawn with multiplier 2.
            P: Unused here.
            **kwargs: Overrides for the entries of ``params_orig``.

        Returns:
            ``update(key, E, P)``, which refreshes the line data.
        """
        xx, yy, mu = stats.xx, stats.yy, stats.mu

        # Set parameters (kwargs takes precedence over params_orig)
        p = DotDict(
            **{kw: kwargs.get(kw, val)
               for kw, val in params_orig.items()})

        # ``not arr`` is ambiguous for multi-element arrays, so test size.
        if isinstance(p.dims, np.ndarray):
            no_dims = p.dims.size == 0
        else:
            no_dims = not p.dims

        if no_dims:
            M = xx.shape[-1]
            p.dims = arange(M)
        else:
            # Coerce to an array: ``p.dims[xt]`` below indexes with an
            # integer array, which a plain list does not support.
            p.dims = np.asarray(p.dims)
            M = len(p.dims)

        # Make periodic wrapper
        ii, wrap = viz.setup_wrapping(M, p.periodicity)

        # Set up figure, axes
        fig, ax = place.freshfig(fignum, figsize=(8, 5))
        fig.suptitle("1d amplitude plot")

        # Nans
        nan1 = wrap(nan * ones(M))

        if E is None and p.conf_mult is None:
            p.conf_mult = 2

        # Init plots
        if p.conf_mult:
            # Two band lines (only the first gets a legend label) and the mean.
            lines_s = ax.plot(ii,
                              nan1,
                              "b-",
                              lw=1,
                              label=(str(p.conf_mult) + r'$\sigma$ conf'))
            lines_s += ax.plot(ii, nan1, "b-", lw=1)
            line_mu, = ax.plot(ii, nan1, 'b-', lw=2, label='DA mean')
        else:
            # One line per ensemble member; only the first is labelled.
            nanE = nan * ones((stats.xp.N, M))
            lines_E = ax.plot(ii,
                              wrap(nanE[0]),
                              **p.ens_props,
                              lw=1,
                              label='Ensemble')
            lines_E += ax.plot(ii, wrap(nanE[1:]).T, **p.ens_props, lw=1)
        # Truth, Obs
        (line_x, ) = ax.plot(ii, nan1, 'k-', lw=3, label='Truth')
        if p.obs_inds is not None:
            p.obs_inds = np.asarray(p.obs_inds)
            (line_y, ) = ax.plot(p.obs_inds,
                                 nan * p.obs_inds,
                                 'g*',
                                 ms=5,
                                 label='Obs')

        # Tune plot
        ax.set_ylim(*viz.xtrema(xx))
        ax.set_xlim(viz.stretch(ii[0], ii[-1], 1))
        # Xticks
        xt = ax.get_xticks()
        xt = xt[abs(xt % 1) < 0.01].astype(int)  # Keep only the integer ticks
        xt = xt[xt >= 0]
        xt = xt[xt < len(p.dims)]
        ax.set_xticks(xt)
        ax.set_xticklabels(p.dims[xt])

        ax.set_xlabel('State index')
        ax.set_ylabel('Value')
        ax.legend(loc='upper right')

        text_t = ax.text(0.01,
                         0.01,
                         format_time(None, None, None),
                         transform=ax.transAxes,
                         family='monospace',
                         ha='left')

        # Init visibility (must come after legend):
        if p.obs_inds is not None:
            line_y.set_visible(False)

        def update(key, E, P):
            """Refresh all lines and the time label for ``key``."""
            k, ko, faus = key

            if p.conf_mult:
                # Rows: mean + mult*spread, mean - mult*spread.
                sigma = mu[key] + p.conf_mult * stats.spread[key] * [[1], [-1]]
                lines_s[0].set_ydata(wrap(sigma[0, p.dims]))
                lines_s[1].set_ydata(wrap(sigma[1, p.dims]))
                line_mu.set_ydata(wrap(mu[key][p.dims]))
            else:
                for n, line in enumerate(lines_E):
                    line.set_ydata(wrap(E[n, p.dims]))
                update_alpha(key, stats, lines_E)

            line_x.set_ydata(wrap(xx[k, p.dims]))

            text_t.set_text(format_time(k, ko, stats.HMM.tseq.tt[k]))

            if p.obs_inds is not None:
                # Show the obs when 'f' is in faus; hide them when 'u' is.
                if 'f' in faus:
                    line_y.set_ydata(yy[ko])
                    line_y.set_zorder(5)
                    line_y.set_visible(True)
                if 'u' in faus:
                    line_y.set_visible(False)

            return

        return update
Esempio n. 4
0
# Observation operator: linear map given by the matrix H
# (``H``, ``y2x_dists``, ``batches`` are defined elsewhere).
Obs = dict(
    M=len(H),
    model=lambda E, t: E @ H.T,
    linear=lambda E, t: H,
    noise=1,
    localizer=localization_setup(lambda t: y2x_dists, batches),
)

# Assemble the full model; the 'land' sector spans the integer indices
# between the extrema of obs_sites.
HMM = modelling.HiddenMarkovModel(
    Dyn, Obs, tseq, X0,
    LP=LPs(),
    sectors={'land': np.arange(*xtrema(obs_sites)).astype(int)},
)

####################
# Suggested tuning
####################

# Reproduce Anderson Figure 2
# -----------------------------------------------------------------------------------
# xp = SL_EAKF(N=6, infl=sqrt(1.1), loc_rad=0.2/1.82*40)
# for lbl in ['err', 'spread']:
#     stat = getattr(xp.stats,lbl).f[HMM.tseq.masko]
#     plt.plot(sqrt(np.mean(stat**2, axis=0)),label=lbl)
#
# Note: for this xp, one must be lucky with the random seed to avoid a
#       blow-up in the ocean sector (which is not constrained by obs) due to infl.
#       Instead, I recommend lowering dt (as in Miyoshi 2011) to stabilize integration.
Esempio n. 5
0
    def init(fignum, stats, key0, plot_u, E, P, **kwargs):
        """Create the "Marginal time series" figure and return its updater.

        Relies on ``params_orig`` from the enclosing scope.

        Args:
            fignum: Passed to ``freshfig``.
            stats: Provides ``xx``, ``yy``, ``mu``, ``std`` and ``HMM.t``;
                a ``w`` attribute enables the extra buffer length used for
                resampling blanks.
            key0: Unused here.
            plot_u: Selects whether the buffers are indexed by ``k`` or
                ``kObs``.
            E: Ensemble, or ``None`` to skip the ensemble lines.
            P: If not ``None``, mean and mean +/- std lines are drawn.
            **kwargs: Overrides for the entries of ``params_orig``.

        Returns:
            ``update(key, E, P)``, which rolls the buffers and redraws.
        """
        xx, yy, mu, std, chrono = \
            stats.xx, stats.yy, stats.mu, stats.std, stats.HMM.t

        # Set parameters (kwargs takes precedence over params_orig)
        p = DotDict(
            **{kw: kwargs.get(kw, val)
               for kw, val in params_orig.items()})

        # Lag settings:
        T_lag, K_lag, a_lag = validate_lag(p.Tplot, chrono)
        K_plot = comp_K_plot(K_lag, a_lag, plot_u)
        # Extend K_plot further for adding blanks in resampling (PartFilt):
        has_w = hasattr(stats, 'w')
        if has_w:
            K_plot += a_lag

        # Chose marginal dims to plot
        if p.dims == []:
            Nx = min(10, xx.shape[-1])
            DimsX = dapper.tools.math.linspace_int(xx.shape[-1], Nx)
        else:
            Nx = len(p.dims)
            DimsX = p.dims
        # Pre-process obs dimensions
        # Rm inds of obs if not in DimsX
        iiY = [i for i, m in enumerate(p.obs_inds) if m in DimsX]
        # Rm obs_inds    if not in DimsX
        DimsY = [m for m in p.obs_inds if m in DimsX]
        # Get dim (within y) of each x
        DimsY = [DimsY.index(m) if m in DimsY else None for m in DimsX]
        Ny = len(iiY)

        # Set up figure, axes
        fig, axs = freshfig(fignum, (5, 7),
                            loc='231-22',
                            nrows=Nx,
                            sharex=True)
        if Nx == 1:
            axs = [axs]

        # Tune plots
        axs[0].set_title("Marginal time series")
        for ix, (m, ax) in enumerate(zip(DimsX, axs)):
            ax.set_ylim(*viz.stretch(*viz.xtrema(xx[:, m]), 1 / p.zoomy))
            if p.labels == []:
                ax.set_ylabel("$x_{%d}$" % m)
            else:
                ax.set_ylabel(p.labels[ix])
        axs[-1].set_xlabel('Time (t)')

        plot_pause(0.05)
        plt.tight_layout()

        # Allocate
        d = DotDict()  # data arrays
        h = DotDict()  # plot handles
        d.t = RollingArray((K_plot, ))
        d.x = RollingArray((K_plot, Nx))
        h.x = []
        d.y = RollingArray((K_plot, Ny))
        # One slot per column of d.y. Handles are stored at index ``iy``
        # (not in axis order) because update() looks them up as h.y[iy];
        # appending in axis order put obs on the wrong axis whenever
        # obs_inds was ordered differently from DimsX.
        h.y = [None] * Ny
        if E is not None:
            d.E = RollingArray((K_plot, len(E), Nx))
            h.E = []
        if P is not None:
            d.mu = RollingArray((K_plot, Nx))
            h.mu = []
            d.s = RollingArray((K_plot, 2, Nx))
            h.s = []

        # Plot (invisible coz everything here is nan, for the moment).
        for ix, (m, iy, ax) in enumerate(zip(DimsX, DimsY, axs)):
            h.x += ax.plot(d.t, d.x[:, ix], 'k')
            if iy is not None:
                h.y[iy] = ax.plot(d.t, d.y[:, iy], 'g*', ms=10)[0]
            if 'E' in d:
                h.E += [ax.plot(d.t, d.E[:, :, ix], **p.ens_props)]
            if 'mu' in d:
                h.mu += ax.plot(d.t, d.mu[:, ix], 'b')
            if 's' in d:
                h.s += [ax.plot(d.t, d.s[:, :, ix], 'b--', lw=1)]

        def update(key, E, P):
            """Insert the data for ``key`` into the buffers and redraw."""
            k, kObs, faus = key

            EE = duplicate_with_blanks_for_resampled(E, DimsX, key, has_w)

            # Roll data array
            ind = k if plot_u else kObs
            for Ens in EE:  # If E is duplicated, so must the others be.
                if 'E' in d:
                    d.E.insert(ind, Ens)
                if 'mu' in d:
                    d.mu.insert(ind, mu[key][DimsX])
                if 's' in d:
                    # Rows are mean + std and mean - std.
                    d.s.insert(ind,
                               mu[key][DimsX] + [[1], [-1]] * std[key][DimsX])
                d.t.insert(ind, chrono.tt[k])
                # Without an observation at this step, insert a row of nan's.
                if kObs is not None:
                    d.y.insert(ind, yy[kObs, iiY])
                else:
                    d.y.insert(ind, nan * ones(Ny))
                d.x.insert(ind, xx[k, DimsX])

            # Update graphs
            for ix, (iy, ax) in enumerate(zip(DimsY, axs)):
                sliding_xlim(ax, d.t, T_lag, True)
                h.x[ix].set_data(d.t, d.x[:, ix])
                if iy is not None:
                    h.y[iy].set_data(d.t, d.y[:, iy])
                if 'mu' in d:
                    h.mu[ix].set_data(d.t, d.mu[:, ix])
                if 's' in d:
                    for b in (0, 1):
                        h.s[ix][b].set_data(d.t, d.s[:, b, ix])
                if 'E' in d:
                    for n in range(len(E)):
                        h.E[ix][n].set_data(d.t, d.E[:, n, ix])
                    update_alpha(key, stats, h.E[ix])

            return

        return update