예제 #1
0
def plot_hovmoller(xx, tseq=None):
    """Plot Hovmöller diagram.

    Parameters
    ----------
    xx: ndarray
        Plotted array, indexed as ``xx[k, i]`` (time index ``k``,
        dimension index ``i``).
    tseq: `dapper.tools.chronos.Chronology`, optional
        Time sequence providing ``tt``, ``kk`` and ``Tplot``.
        If given, the times ``tt <= 2*Tplot`` are plotted against time ``t``.
        Otherwise, a plot length is estimated from ``xx``
        (by `estimate_good_plot_length`) and plotted against time index ``k``.
        Default: None
    """
    fig, ax = place.freshfig("Hovmoller", figsize=(4, 3.5))

    if tseq is not None:
        mask = tseq.tt <= tseq.Tplot*2
        kk   = tseq.kk[mask]
        tt   = tseq.tt[mask]
        ax.set_ylabel('Time (t)')
    else:
        K    = estimate_good_plot_length(xx, mult=20)
        kk   = arange(K)
        tt   = kk
        ax.set_ylabel('Time indices (k)')

    # Use the object-oriented API so that the contours and the colorbar are
    # tied to this figure/axes, rather than to pyplot's "current" axes.
    contours = ax.contourf(arange(xx.shape[1]), tt, xx[kk], 25)
    fig.colorbar(contours, ax=ax)
    ax.get_xaxis().set_major_locator(MaxNLocator(integer=True))
    ax.set_title("Hovmoller diagram (of 'Truth')")
    ax.set_xlabel('Dimension index (i)')

    plt.pause(0.1)
    plt.tight_layout()
예제 #2
0
def amplitude_animation(EE, dt=None, interval=1,
                        periodicity=None, blit=True,
                        fignum=None, repeat=False):
    """Animation of line chart.

    Animates each member of an ensemble of
    the shape (time, ensemble size, state vector length).

    Parameters
    ----------
    EE: ndarray
        Ensemble array of the shape (K, N, Nx).
        K is the length of time, N is the ensemble size, and
        Nx is the length of state vector.
        A 2D array of shape (K, Nx) is treated as a single-member ensemble.
    dt: float, optional
        Time interval of each frame. If given, the current time
        (``dt*k``) is displayed as text in the axes. Default: None
    interval: float, optional
        Delay between frames in milliseconds. Default: 1
    periodicity: str, optional
        The mode of the wrapping, passed to `setup_wrapping`.
        "+1": the first element is appended after the last.
        "+/-05": adding the midpoint of the first and last elements.
        Default: None (handling of None is up to `setup_wrapping`).
    blit: bool, optional
        Controls whether blitting is used to optimize drawing. Default: True
    fignum: int, optional
        Figure index. Default: None (uses the name "Amplitude animation").
    repeat: bool, optional
        If True, repeat the animation. Default: False

    Returns
    -------
    matplotlib.animation.FuncAnimation
    """
    fig, ax = place.freshfig(fignum or "Amplitude animation")
    ax.set_xlabel('State index')
    ax.set_ylabel('Amplitude')
    ax.set_ylim(*stretch(*xtrema(EE), 1.1))

    if EE.ndim == 2:
        EE = np.expand_dims(EE, 1)
    K, N, Nx = EE.shape

    ii, wrap = setup_wrapping(Nx, periodicity)

    lines = ax.plot(ii, wrap(EE[0]).T)
    ax.set_xlim(*xtrema(ii))

    # Optional time display. It is appended to `lines` so that it is also
    # returned by `anim` (artists must be returned for blitting to redraw them).
    time_text = None
    if dt is not None:
        time_text = ax.text(0.05, 0.9, '', transform=ax.transAxes)
        lines += [time_text]

    def anim(k):
        Ek = wrap(EE[k])
        for n in range(N):
            lines[n].set_ydata(Ek[n])
        if time_text is not None:
            time_text.set_text('time = %.1f' % (dt*k))
        return lines

    return FuncAnimation(fig, anim, range(K),
                         interval=interval, blit=blit,
                         repeat=repeat)
예제 #3
0
def spectrum(ydata, title="", figsize=(1.6, .7), semilogy=False, **kwargs):
    """Plot a spectrum (variance versus eigenvalue index) in a fresh figure.

    The y-axis is always logarithmic; the x-axis is logarithmic as well
    unless `semilogy` is True. Extra keyword arguments are accepted but
    not used.

    Returns the line handles produced by the plotting call.
    """
    fig, ax = place.freshfig(dash("Spectrum", title), figsize=figsize, rel=True)
    plot_fn = ax.semilogy if semilogy else ax.loglog
    handles = plot_fn(ydata)
    ax.grid(True, "both", axis="both")
    ax.set(xlabel="eigenvalue index", ylabel="variance")
    fig.tight_layout()
    return handles
예제 #4
0
def fields(ZZ,
           style=None,
           title="",
           figsize=(1.7, 1),
           label_color="k",
           colorbar=True,
           **kwargs):
    """Do `field(Z)` for each `Z` in `ZZ`.

    Parameters
    ----------
    ZZ: dict or list-like
        The fields to plot. If not a dict, it is converted to one keyed by
        enumeration index. Each key is used as the label of its axes.
        The axes grid is sized (by `nRowCol`) for at most 12 fields;
        fields beyond the number of axes are not plotted, and the suptitle
        then says so.
    style: optional
        Passed on to `field`, and used via `pop_style_with_fallback`
        (together with `kwargs`) to look up "title" and "ticks".
    title: str
        Appended to the figure title.
    figsize: tuple
        Figure size, relative (``rel=True`` in `place.freshfig`).
    label_color: str
        Colour of the axes labels.
    colorbar: bool
        Whether to add a single shared colorbar, based on the first handle.
    **kwargs
        Passed on to `field`. NOTE(review): `pop_style_with_fallback` may
        pop keys from it -- confirm.

    Returns
    -------
    fig: Figure
    axs: `mpl_toolkits.axes_grid1.AxesGrid`
    hh: list
        The handles returned by `field`, one per plotted field.
    """
    # Look up style option `k` (from `style`, with `kwargs` as fallback).
    kw = lambda k: pop_style_with_fallback(k, style, kwargs)

    # Create figure using freshfig
    title = dash("Fields", kw("title"), title)
    # (the `axs` returned here is replaced by the AxesGrid below)
    fig, axs = place.freshfig(title, figsize=figsize, rel=True)
    # Store suptitle (exists if mpl is inline) coz gets cleared below
    try:
        suptitle = fig._suptitle.get_text()
    except AttributeError:
        suptitle = ""
    # Create axes using AxesGrid
    fig.clear()
    from mpl_toolkits.axes_grid1 import AxesGrid
    axs = AxesGrid(fig,
                   111,
                   nrows_ncols=nRowCol(min(12, len(ZZ))).values(),
                   cbar_mode='single',
                   cbar_location='right',
                   share_all=True,
                   axes_pad=0.2,
                   cbar_pad=0.1)
    # Turn off redundant axes
    for ax in axs[len(ZZ):]:
        ax.set_visible(False)

    # Convert (potential) list-like ZZ into dict
    if not isinstance(ZZ, dict):
        ZZ = {i: Z for (i, Z) in enumerate(ZZ)}

    hh = []
    # zip stops at the shorter of the two => at most len(axs) fields plotted.
    for ax, label in zip(axs, ZZ):
        label_ax(ax, label, c=label_color)
        hh.append(field(ax, ZZ[label], style, **kwargs))

    # Suptitle
    if len(ZZ) > len(axs):
        suptitle = dash(suptitle, f"First {len(axs)} instances")
    # Re-set suptitle (since it got cleared above)
    if suptitle:
        fig.suptitle(suptitle)

    # NOTE(review): if ZZ is empty, `hh` is empty and hh[0] raises IndexError.
    # "ticks" is looked up here, i.e. after `field` has been called with kwargs.
    if colorbar:
        fig.colorbar(hh[0], cax=axs.cbar_axes[0], ticks=kw("ticks"))

    return fig, axs, hh
예제 #5
0
    def __init__(self, fignum, stats, key0, plot_u, E, P, **kwargs):
        """Set up a log-scaled axes for the spectral (singular value) stats.

        Sets ``self.is_active = False`` and writes a notice in the axes
        if `stats` lacks ``umisf`` or ``svals``.
        """
        _fig, axes = place.freshfig(fignum, figsize=(6, 3))
        axes.set_xlabel('Sing. value index')
        axes.set_yscale('log')

        self.init_incomplete = True
        self.ax, self.plot_u = axes, plot_u

        # Attributes are assigned one at a time (umisf first), so a missing
        # `svals` still leaves `msft` set.
        try:
            self.msft = stats.umisf
            self.sprd = stats.svals
        except AttributeError:
            self.is_active = False
            not_available_text(axes, "Spectral stats not being computed")
예제 #6
0
    def __init__(self, fignum, stats, key0, plot_u, E, P, **kwargs):
        """Set up a log-scaled histogram axes for the weights ``stats.w``.

        If `stats` has no ``w`` attribute, sets ``self.is_active = False``
        and returns without creating a figure.
        """
        if not hasattr(stats, 'w'):
            self.is_active = False
            return
        fig, ax = place.freshfig(fignum,
                                 figsize=(7, 3),
                                 gridspec_kw={'bottom': .15})

        ax.set_xscale('log')
        ax.set_xlabel('Weight')
        ax.set_ylabel('Count')
        self.stats = stats
        self.ax = ax
        self.hist = []
        # 31 logarithmically spaced bin edges spanning [1e-10, 1].
        self.bins = np.exp(np.linspace(np.log(1e-10), np.log(1), 31))
예제 #7
0
def plot_rank_histogram(stats):
    """Plot rank histogram of ensemble.

    Uses the analysis ranks ``stats.rh.a`` at the times selected by
    ``tseq.masko``. If `stats` has weights (``stats.w``), an experimental
    weighted rank histogram is drawn instead. If the ranks are not
    available, a notice is written in the axes.

    Parameters
    ----------
    stats: `dapper.stats.Stats`
    """
    tseq = stats.HMM.tseq

    # Unfilled entries presumably hold NaN cast to int. The ranks count as
    # computed unless every entry of the last row equals that value.
    has_been_computed = \
        hasattr(stats, 'rh') and \
        not all(stats.rh.a[-1] == array(np.nan).astype(int))

    fig, ax = place.freshfig("Rank histogram", figsize=(6, 3))
    ax.set_title('(Mean of marginal) rank histogram (_a)')
    ax.set_ylabel('Freq. of occurrence\n (of truth in interval n)')
    ax.set_xlabel('ensemble member index (n)')

    if has_been_computed:
        ranks = stats.rh.a[tseq.masko]
        N     = stats.xp.N
        if not hasattr(stats, 'w'):
            # Ensemble rank histogram
            integer_hist(ranks.ravel(), N)
        else:
            # Experimental: weighted rank histogram.
            # Weight ranks by inverse of particle weight. Why? Coz, with correct
            # importance weights, the "expected value" histogram is then flat.
            # Potential improvement: interpolate weights between particles.
            w  = stats.w.a[tseq.masko]
            K  = len(w)
            w  = np.hstack([w, np.ones((K, 1))/N])  # define weights for rank N+1
            # Gather w[k, ranks[k, i]] for all (k, i) at once: shape (K, Nx),
            # which ravels in the same (row-major) order as `ranks`.
            w  = w[arange(K)[:, None], ranks].ravel()
            # Artificial cap. Reduces variance, but introduces bias.
            w  = np.maximum(w, 1/N/100)
            w  = 1/w
            integer_hist(ranks.ravel(), N, weights=w)
    else:
        not_available_text(ax)

    plt.pause(0.1)
    plt.tight_layout()
예제 #8
0
 def fig_ax(num):
     """Create and return the figure and axes for figure `num`.

     Deserving of particular attention, so factored out:

     - For *interactive* mpl backends, figure creation should only be run
       once. But running it from inside f (with appropriate checks for
       single execution) causes a blank figure => run it outside of f().
     - However, using `ipywidgets.Output` to capture output requires that
       it runs inside f. In that case it actually seems to work
       (no blank figures).
     """
     if is_inline():
         # Remove the previous (static) image; necessary when using
         # `ipywidgets.Output`. `wait=True` avoids flickering,
         # ref ipywidgets/issues/1582.
         clear_output(wait=True)
     elif plt.fignum_exists(num):
         # Without this existence check, duplicate figures were created the
         # first time it is run (without any error). That no longer seems to
         # be an issue, but the check doesn't hurt.
         # Closing fixes: figure doesn't display **when cell is re-run**.
         # Possibly related to being in an ipython widget; changing `num`
         # (so that freshfig creates a new one) also fixes it.
         plt.close(num)
     figure, axes = place.freshfig(num, **kwargs)
     return figure, axes
예제 #9
0
    def __init__(self, fignum, stats, key0, plot_u, E, P, **kwargs):
        """Set up the axes for the state correlation matrix and auto-correlation.

        The upper axes shows the correlation matrix (initialised to the
        identity). The lower axes shows the mean correlation and the mean
        absolute correlation versus distance (in state indices).

        Writes a notice in the axes and sets ``self.is_active = False`` if
        neither the ensemble `E` nor a non-NaN covariance `P` is available,
        or if the state length exceeds 1003.
        """

        GS = {'height_ratios': [4, 1], 'hspace': 0.09, 'top': 0.95}
        fig, (ax, ax2) = place.freshfig(fignum,
                                        figsize=(5, 6),
                                        nrows=2,
                                        gridspec_kw=GS)

        if E is None and np.isnan(
                P.diag if isinstance(P, CovMat) else P).all():
            not_available_text(ax, ('Not available in replays'
                                    '\ncoz full Ens/Cov not stored.'))
            self.is_active = False
            return

        Nx = len(stats.mu[key0])
        if Nx > 1003:
            # Too big to plot. Deactivate (as above), since none of the
            # handles needed for updating (im, line_AC, mask, ...) get created.
            not_available_text(ax)
            self.is_active = False
            return

        C = np.eye(Nx)
        # Mask half (lower triangle, incl. diagonal).
        # Use builtin `bool`: the `np.bool` alias was removed in NumPy 1.24.
        mask = np.zeros_like(C, dtype=bool)
        mask[np.tril_indices_from(mask)] = True
        # Make colormap. Log-transform cmap,
        # but not internally in matplotlib,
        # so as to avoid transforming the colorbar too.
        cmap = plt.get_cmap('RdBu_r')
        trfm = mpl.colors.SymLogNorm(linthresh=0.2,
                                     linscale=0.2,
                                     base=np.e,
                                     vmin=-1,
                                     vmax=1)
        cmap = cmap(trfm(np.linspace(-0.6, 0.6, cmap.N)))
        cmap = mpl.colors.ListedColormap(cmap)
        #
        VM = 1.0  # abs(np.percentile(C,[1,99])).max()
        im = ax.imshow(C, cmap=cmap, vmin=-VM, vmax=VM)
        # Colorbar
        _ = ax.figure.colorbar(im, ax=ax, shrink=0.8)
        # Tune plot
        # NOTE(review): `plt.box` acts on pyplot's *current* axes, which need
        # not be `ax` (it may e.g. be `ax2`). Confirm which frame should go.
        plt.box(False)
        ax.set_facecolor('w')
        ax.grid(False)
        ax.set_title("State correlation matrix:", y=1.07)
        ax.xaxis.tick_top()

        # ax2 = inset_axes(ax,width="30%",height="60%",loc=3)
        line_AC, = ax2.plot(arange(Nx), ones(Nx), label='Correlation')
        line_AA, = ax2.plot(arange(Nx), ones(Nx), label='Abs. corr.')
        _ = ax2.hlines(0, 0, Nx - 1, 'k', 'dotted', lw=1)
        # Align ax2 with ax
        bb_AC = ax2.get_position()
        bb_C = ax.get_position()
        ax2.set_position([bb_C.x0, bb_AC.y0, bb_C.width, bb_AC.height])
        # Tune plot
        ax2.set_title("Auto-correlation:")
        ax2.set_ylabel("Mean value")
        ax2.set_xlabel("Distance (in state indices)")
        ax2.set_xticklabels([])
        ax2.set_yticks([0, 1] + list(ax2.get_yticks()[[0, -1]]))
        ax2.set_ylim(top=1)
        ax2.legend(frameon=True,
                   facecolor='w',
                   bbox_to_anchor=(1, 1),
                   loc='upper left',
                   borderaxespad=0.02)

        self.ax = ax
        self.ax2 = ax2
        self.im = im
        self.line_AC = line_AC
        self.line_AA = line_AA
        self.mask = mask
        if hasattr(stats, 'w'):
            self.w = stats.w
예제 #10
0
    def __init__(self,
                 fignum,
                 stats,
                 key0,
                 plot_u,
                 E,
                 P,
                 Tplot=None,
                 **kwargs):
        """Set up stacked time-series axes for scalar diagnostics.

        One axes per style table ("RMS", "Values"), all sharing the time axis.
        Each diagnostic listed in a table that exists in `stats` (looked up via
        `deep_getattr`) gets a line, whose data and times are held in
        `RollingArray`s. Diagnostics missing from `stats` are skipped.

        Parameters
        ----------
        fignum:
            Passed to `place.freshfig`.
        stats:
            Source of the diagnostics. ``stats.xp.N`` (99 if absent) is used
            to normalise ``N_eff``, and ``stats.HMM.tseq`` for the lag settings.
        plot_u: bool
            Stored per line and passed to `comp_K_plot`, for `FAUSt` stats only.
        Tplot: float, optional
            Passed to `validate_lag` (presumably the plotted time window).

        `key0`, `E`, `P` and `kwargs` are not used here.
        """

        # STYLE TABLES - Defines which/how diagnostics get plotted
        styles = {}

        # Affine transform: x -> a + b*x
        def lin(a, b):
            return (lambda x: a + b * x)

        # 1/N, with N defaulting to 99 if stats.xp has no attribute N.
        divN = 1 / getattr(stats.xp, 'N', 99)
        # Columns: transf, shape, plt kwargs
        styles['RMS'] = {
            'err.rms': [None, None, dict(c='k', label='Error')],
            'spread.rms': [None, None,
                           dict(c='b', label='Spread', alpha=0.6)],
        }
        styles['Values'] = {
            'skew': [None, None,
                     dict(c='g', label=star + r'Skew/$\sigma^3$')],
            'kurt':
            [None, None,
             dict(c='r', label=star + r'Kurt$/\sigma^4{-}3$')],
            'trHK': [None, None, dict(c='k', label=star + 'HK')],
            'infl': [lin(-10, 10), 'step',
                     dict(c='c', label='10(infl-1)')],
            'N_eff':
            [lin(0, divN), 'dirac',
             dict(c='y', label='N_eff/N', lw=3)],
            'iters': [lin(0, .1), 'dirac',
                      dict(c='m', label='iters/10')],
            'resmpl': [None, 'dirac',
                       dict(c='k', label='resampled?')],
        }

        nAx = len(styles)
        GS = {'left': 0.125, 'right': 0.76}
        fig, axs = place.freshfig(fignum,
                                  figsize=(5, 1 + nAx),
                                  nrows=nAx,
                                  sharex=True,
                                  gridspec_kw=GS)

        axs[0].set_title("Diagnostics")
        for style, ax in zip(styles, axs):
            ax.set_ylabel(style)
        # After the loop, `ax` is the last (bottom) axes.
        ax.set_xlabel('Time (t)')
        place_ax.adjust_position(ax, y0=0.03)

        self.T_lag, K_lag, a_lag = validate_lag(Tplot, stats.HMM.tseq)

        def init_ax(ax, style_table):
            """Create (nan-filled) lines in `ax` for each available stat in `style_table`.

            Returns a dict: stat name -> dict of its transform, shape,
            plot kwargs, plot_u flag, rolling data/time arrays and line handle.
            """
            lines = {}
            for name in style_table:

                # SKIP -- if stats[name] is not in existence
                # Note: The nan check/deletion comes after the first ko.
                try:
                    stat = deep_getattr(stats, name)
                except AttributeError:
                    continue
                # try: val0 = stat[key0[0]]
                # except KeyError: continue
                # PS: recall (from series.py) that even if store_u is false, stat[k] is
                # still present if liveplots=True via the k_tmp functionality.

                # Unpack style
                ln = {}
                ln['transf'] = style_table[name][0] or (lambda x: x)
                ln['shape'] = style_table[name][1]
                ln['plt'] = style_table[name][2]

                # Create series
                if isinstance(stat, FAUSt):
                    ln['plot_u'] = plot_u
                    K_plot = comp_K_plot(K_lag, a_lag, ln['plot_u'])
                else:
                    # Other stats: plot_u is forced off; buffer length is a_lag.
                    ln['plot_u'] = False
                    K_plot = a_lag
                ln['data'] = RollingArray(K_plot)
                ln['tt'] = RollingArray(K_plot)

                # Plot (init)
                ln['handle'], = ax.plot(ln['tt'], ln['data'], **ln['plt'])

                # Plotting only nans yields ugly limits. Revert to defaults.
                ax.set_xlim(0, 1)
                ax.set_ylim(0, 1)

                lines[name] = ln
            return lines

        # Plot
        self.d = [init_ax(ax, styles[style]) for style, ax in zip(styles, axs)]

        # Horizontal line at y=0
        # (`ax` is still the bottom axes: the comprehension above has its own scope.)
        self.baseline0, = ax.plot(ax.get_xlim(), [0, 0],
                                  c=0.5 * ones(3),
                                  lw=0.7,
                                  label='_nolegend_')

        # Store
        self.axs = axs
        self.stats = stats
        self.init_incomplete = True
예제 #11
0
    def init(fignum, stats, key0, plot_u, E, P, **kwargs):
        """Set up the 2x2 streamfunction figure; return its `update` function.

        Panels: ensemble mean (upper left), truth (upper right),
        spread (lower left) and error (lower right), each passed through
        `square` (presumably reshaping the state vector to 2D).
        The upper colorbar belongs to the truth image, the lower one to
        the error image.

        NOTE(review): relies on `cm`, `clims`, `square`, `obs_inds`,
        `ind2sub` and `format_time`, which are not defined in this function
        (presumably from an enclosing scope) -- confirm.
        """

        GS = {'left': 0.125 - 0.04, 'right': 0.9 - 0.04}
        fig, axs = place.freshfig(fignum,
                                  figsize=(6, 6),
                                  nrows=2,
                                  ncols=2,
                                  sharex=True,
                                  sharey=True,
                                  gridspec_kw=GS)

        for ax in axs.flatten():
            ax.set_aspect('equal', 'box')

        ((ax_11, ax_12), (ax_21, ax_22)) = axs

        ax_11.grid(color='w', linewidth=0.2)
        ax_12.grid(color='w', linewidth=0.2)
        ax_21.grid(color='k', linewidth=0.1)
        ax_22.grid(color='k', linewidth=0.1)

        # Upper colorbar -- position relative to ax_12
        bb = ax_12.get_position()
        dy = 0.1 * bb.height
        ax_13 = fig.add_axes(
            [bb.x1 + 0.03, bb.y0 + dy, 0.04, bb.height - 2 * dy])
        # Lower colorbar -- position relative to ax_22
        bb = ax_22.get_position()
        dy = 0.1 * bb.height
        ax_23 = fig.add_axes(
            [bb.x1 + 0.03, bb.y0 + dy, 0.04, bb.height - 2 * dy])

        # Extract data arrays
        xx, _, mu, spread, err = stats.xx, stats.yy, stats.mu, stats.spread, stats.err
        k = key0[0]
        tt = stats.HMM.tseq.tt

        # Plot
        # - origin='lower' might get overturned by set_ylim() below.
        im_11 = ax_11.imshow(square(mu[key0]), cmap=cm)
        im_12 = ax_12.imshow(square(xx[k]), cmap=cm)
        # hot is better, but needs +1 colorbar
        im_21 = ax_21.imshow(square(spread[key0]), cmap=plt.cm.bwr)
        im_22 = ax_22.imshow(square(err[key0]), cmap=plt.cm.bwr)
        ims = (im_11, im_12, im_21, im_22)
        # Obs init -- a list where item 0 is the handle of something invisible.
        # (A list, so that `update` can replace item 0 in place.)
        lh = list(ax_12.plot(0, 0)[0:1])

        sx = '$\\psi$'
        ax_11.set_title('mean ' + sx)
        ax_12.set_title('true ' + sx)
        ax_21.set_title('spread. ' + sx)
        ax_22.set_title('err. ' + sx)

        # TODO 7
        # for ax in axs.flatten():
        # Crop boundaries (which should be 0, i.e. yield harsh q gradients):
        # lims = (1, nx-2)
        # step = (nx - 1)/8
        # ticks = arange(step,nx-1,step)
        # ax.set_xlim  (lims)
        # ax.set_ylim  (lims[::-1])
        # ax.set_xticks(ticks)
        # ax.set_yticks(ticks)

        # Apply the colour limits, one per image, in the order of `ims`.
        for im, clim in zip(ims, clims):
            im.set_clim(clim)

        fig.colorbar(im_12, cax=ax_13)
        fig.colorbar(im_22, cax=ax_23)
        for ax in [ax_13, ax_23]:
            ax.yaxis.set_tick_params('major',
                                     length=2,
                                     width=0.5,
                                     direction='in',
                                     left=True,
                                     right=True)
            ax.set_axisbelow('line')  # make ticks appear over colorbar patch

        # Title
        title = "Streamfunction (" + sx + ")"
        fig.suptitle(title)
        # Time info
        text_t = ax_12.text(1,
                            1.1,
                            format_time(None, None, None),
                            transform=ax_12.transAxes,
                            family='monospace',
                            ha='left')

        def update(key, E, P):
            """Refresh the images, obs markers and time text for `key`.

            `key` is unpacked as ``(k, ko, faus)``; `E` and `P` are not used.
            """
            k, ko, faus = key
            t = tt[k]

            im_11.set_data(square(mu[key]))
            im_12.set_data(square(xx[k]))
            im_21.set_data(square(spread[key]))
            im_22.set_data(square(err[key]))

            # Remove previous obs
            # (a ValueError from `remove`, e.g. if already removed, is ignored)
            try:
                lh[0].remove()
            except ValueError:
                pass
            # Plot current obs.
            #  - plot() automatically adjusts to direction of y-axis in use.
            #  - ind2sub returns (iy,ix), while plot takes (ix,iy) => reverse.
            if ko is not None and obs_inds is not None:
                lh[0] = ax_12.plot(*ind2sub(obs_inds(t))[::-1],
                                   'k.',
                                   ms=1,
                                   zorder=5)[0]

            text_t.set_text(format_time(k, ko, t))

            return

        return update
예제 #12
0
    def init(fignum, stats, key0, plot_u, E, P, **kwargs):
        """Set up the 1D amplitude plot; return its `update` function.

        Plots, versus state index (restricted to ``p.dims``), the truth and
        either the mean +/- ``p.conf_mult`` times the spread (if
        ``p.conf_mult`` is truthy) or the ensemble members.
        Observations (at ``p.obs_inds``, if given) are made visible when
        'f' is in ``faus`` and hidden when 'u' is in ``faus``.

        NOTE(review): relies on `params_orig`, `DotDict`, `viz`, `nan`,
        `format_time` and `update_alpha` being defined elsewhere.
        """
        xx, yy, mu = stats.xx, stats.yy, stats.mu

        # Set parameters (kwargs takes precedence over params_orig)
        p = DotDict(
            **{kw: kwargs.get(kw, val)
               for kw, val in params_orig.items()})

        # Default: all state dims.
        # NOTE(review): `not p.dims` is ambiguous if p.dims is a numpy array
        # with more than one element.
        if not p.dims:
            M = xx.shape[-1]
            p.dims = arange(M)
        else:
            M = len(p.dims)

        # Make periodic wrapper
        ii, wrap = viz.setup_wrapping(M, p.periodicity)

        # Set up figure, axes
        fig, ax = place.freshfig(fignum, figsize=(8, 5))
        fig.suptitle("1d amplitude plot")

        # Nans (placeholder data, invisible until `update` sets real values)
        nan1 = wrap(nan * ones(M))

        # Without an ensemble (and no explicit choice), plot mean +/- 2 spread.
        if E is None and p.conf_mult is None:
            p.conf_mult = 2

        # Init plots
        if p.conf_mult:
            lines_s = ax.plot(ii,
                              nan1,
                              "b-",
                              lw=1,
                              label=(str(p.conf_mult) + r'$\sigma$ conf'))
            lines_s += ax.plot(ii, nan1, "b-", lw=1)
            line_mu, = ax.plot(ii, nan1, 'b-', lw=2, label='DA mean')
        else:
            nanE = nan * ones((stats.xp.N, M))
            # Only the first member gets a legend label.
            lines_E = ax.plot(ii,
                              wrap(nanE[0]),
                              **p.ens_props,
                              lw=1,
                              label='Ensemble')
            lines_E += ax.plot(ii, wrap(nanE[1:]).T, **p.ens_props, lw=1)
        # Truth, Obs
        (line_x, ) = ax.plot(ii, nan1, 'k-', lw=3, label='Truth')
        if p.obs_inds is not None:
            p.obs_inds = np.asarray(p.obs_inds)
            (line_y, ) = ax.plot(p.obs_inds,
                                 nan * p.obs_inds,
                                 'g*',
                                 ms=5,
                                 label='Obs')

        # Tune plot
        ax.set_ylim(*viz.xtrema(xx))
        ax.set_xlim(viz.stretch(ii[0], ii[-1], 1))
        # Xticks
        # Keep only integer ticks within [0, len(p.dims)), labelled by the
        # state index they represent (requires p.dims to support
        # integer-array indexing, e.g. a numpy array).
        xt = ax.get_xticks()
        xt = xt[abs(xt % 1) < 0.01].astype(int)  # Keep only the integer ticks
        xt = xt[xt >= 0]
        xt = xt[xt < len(p.dims)]
        ax.set_xticks(xt)
        ax.set_xticklabels(p.dims[xt])

        ax.set_xlabel('State index')
        ax.set_ylabel('Value')
        ax.legend(loc='upper right')

        text_t = ax.text(0.01,
                         0.01,
                         format_time(None, None, None),
                         transform=ax.transAxes,
                         family='monospace',
                         ha='left')

        # Init visibility (must come after legend):
        if p.obs_inds is not None:
            line_y.set_visible(False)

        def update(key, E, P):
            """Refresh the lines and time text for `key` = ``(k, ko, faus)``.

            `E` is used only when the ensemble (not mean/spread) is plotted;
            `P` is not used.
            """
            k, ko, faus = key

            if p.conf_mult:
                # Row 0: mu + conf_mult*spread; row 1: mu - conf_mult*spread.
                sigma = mu[key] + p.conf_mult * stats.spread[key] * [[1], [-1]]
                lines_s[0].set_ydata(wrap(sigma[0, p.dims]))
                lines_s[1].set_ydata(wrap(sigma[1, p.dims]))
                line_mu.set_ydata(wrap(mu[key][p.dims]))
            else:
                for n, line in enumerate(lines_E):
                    line.set_ydata(wrap(E[n, p.dims]))
                update_alpha(key, stats, lines_E)

            line_x.set_ydata(wrap(xx[k, p.dims]))

            text_t.set_text(format_time(k, ko, stats.HMM.tseq.tt[k]))

            if 'f' in faus:
                if p.obs_inds is not None:
                    line_y.set_ydata(yy[ko])
                    line_y.set_zorder(5)
                    line_y.set_visible(True)

            if 'u' in faus:
                if p.obs_inds is not None:
                    line_y.set_visible(False)

            return

        return update
예제 #13
0
    - Gaussian distributions.
    """
    dists = dist_euclid(vectorize(*pts))
    Cov = 1 - variogram_gauss(dists, r)
    C12 = sla.sqrtm(Cov).real
    fields = randn(N, len(dists)) @ C12.T
    return fields


if __name__ == "__main__":
    # Demo: draw N random fields with `gaussian_fields` on a 1D and on a
    # 2D grid, and plot them.
    from simulator import plotting as plots
    from simulator.grid import Grid2D

    np.random.seed(3000)  # reproducible draws
    plt.ion()
    N = 15  # ensemble size

    ## 1D
    xx = np.linspace(0, 1, 201)
    fields = gaussian_fields((xx, ), N)
    fig, ax = freshfig(1)
    ax.plot(xx, fields.T, lw=2)  # one line per field (fields has N rows)

    ## 2D
    grid = Grid2D(Lx=1, Ly=1, Nx=20, Ny=20)
    plots.model = grid
    fields = gaussian_fields(grid.mesh(), N)
    fields = 0.5 + .2 * fields  # affine shift/scale of the samples
    # fields = truncate_01(fields)
    # NOTE(review): passes `plots.field` as the first argument of
    # `plots.fields`; confirm this matches the `simulator.plotting.fields`
    # signature.
    plots.fields(plots.field, fields)
예제 #14
0
def plot_err_components(stats):
    """Plot components of the error.

    Three panels, all based on the analysis (_a) values:

    1. Element-wise: time-averaged absolute error and spread per state
       dimension.
    2. Spectral: time-averaged absolute ``umisf`` and ``svals`` per
       principal component (or a notice, if these were not computed).
    3. Histogram of the RMSE at the times selected by ``tseq.masko``.

    Parameters
    ----------
    stats: `dapper.stats.Stats`

    .. note::
      it was chosen to `plot(ii, mean_in_time(abs(err_i)))`,
      and thus the corresponding spread measure is MAD.
      If one chose instead: `plot(ii, std_spread_in_time(err_i))`,
      then the corresponding measure of spread would have been `spread`.
      This choice was made in part because (wrt. subplot 2)
      the singular values (`svals`) correspond to rotated MADs,
      and because `rms(umisf)` seems too convoluted for interpretation.
    """
    fig, (ax0, ax1, ax2) = place.freshfig("Error components", figsize=(6, 6), nrows=3)

    tseq = stats.HMM.tseq
    Nx   = stats.xx.shape[1]
    ii   = arange(Nx)

    def time_mean(x):
        """Average over the first (time) axis."""
        return np.mean(x, axis=0)

    err   = time_mean(abs(stats.err.a))
    sprd  = time_mean(stats.spread.a)
    umsft = time_mean(abs(stats.umisf.a))
    usprd = time_mean(stats.svals.a)

    ax0.plot(ii, err, 'k', lw=2, label='Error')
    if Nx < 10**3:
        ax0.fill_between(ii, 0, sprd, alpha=0.7, label='Spread')
    else:
        ax0.plot(ii, sprd, alpha=0.7, label='Spread')
    # ax0.set_yscale('log')
    ax0.set_title('Element-wise error comparison')
    ax0.set_xlabel('Dimension index (i)')
    ax0.set_ylabel('Time-average (_a) magnitude')
    ax0.set_xlim(0, Nx-1)
    ax0.get_xaxis().set_major_locator(MaxNLocator(integer=True))
    ax0.legend(loc='upper right')

    ax1.set_xlim(0, Nx-1)
    ax1.set_xlabel('Principal component index')
    ax1.set_ylabel('Time-average (_a) magnitude')
    ax1.set_title('Spectral error comparison')
    has_been_computed = np.any(np.isfinite(umsft))
    if has_been_computed:
        L = len(umsft)
        ax1.plot(arange(L), umsft, 'k', lw=2, label='Error')
        ax1.fill_between(arange(L), 0, usprd, alpha=0.7, label='Spread')
        ax1.set_yscale('log')
        ax1.get_xaxis().set_major_locator(MaxNLocator(integer=True))
    else:
        not_available_text(ax1)

    rmse = stats.err.rms.a[tseq.masko]
    ax2.hist(rmse, bins=30, density=False)
    ax2.set_ylabel('Num. of occurrence (_a)')
    ax2.set_xlabel('RMSE')
    ax2.set_title('Histogram of RMSE values')
    ax2.set_xlim(left=0)

    plt.pause(0.1)
    plt.tight_layout()
예제 #15
0
    def init(fignum, stats, key0, plot_u, E, P, **kwargs):
        """Set up the marginal time-series figure; return its `update` function.

        One axes (all sharing the time axis) per plotted state dimension,
        showing the truth, the observations (for observed dimensions) and,
        if available, the ensemble members (`E`) and/or mean +/- spread (`P`).
        Data is held in `RollingArray`s so that a sliding time window is shown.

        Parameters in ``params_orig`` can be overridden via `kwargs`.
        """
        xx, yy, mu, spread, tseq = \
            stats.xx, stats.yy, stats.mu, stats.spread, stats.HMM.tseq

        # Set parameters (kwargs takes precedence over params_orig)
        p = DotDict(
            **{kw: kwargs.get(kw, val)
               for kw, val in params_orig.items()})

        # Lag settings:
        T_lag, K_lag, a_lag = validate_lag(p.Tplot, tseq)
        K_plot = comp_K_plot(K_lag, a_lag, plot_u)
        # Extend K_plot further for adding blanks in resampling (PartFilt):
        has_w = hasattr(stats, 'w')
        if has_w:
            K_plot += a_lag

        # Choose marginal dims to plot
        if not p.dims:
            Nx = min(10, xx.shape[-1])
            DimsX = linspace_int(xx.shape[-1], Nx)
        else:
            Nx = len(p.dims)
            DimsX = p.dims
        # Pre-process obs dimensions
        # Inds (within y) of the obs whose dim is in DimsX
        iiY = [i for i, m in enumerate(p.obs_inds) if m in DimsX]
        # The obs dims that are in DimsX (same order as iiY, i.e. as d.y columns)
        DimsY = [m for m in p.obs_inds if m in DimsX]
        # Column (within d.y) of each x, or None if that x is not observed
        DimsY = [DimsY.index(m) if m in DimsY else None for m in DimsX]
        Ny = len(iiY)

        # Set up figure, axes
        fig, axs = place.freshfig(fignum,
                                  figsize=(5, 7),
                                  nrows=Nx,
                                  sharex=True)
        if Nx == 1:
            axs = [axs]

        # Tune plots
        axs[0].set_title("Marginal time series")
        for ix, (m, ax) in enumerate(zip(DimsX, axs)):
            # ax.set_ylim(*viz.stretch(*viz.xtrema(xx[:, m]), 1/p.zoomy))
            if not p.labels:
                ax.set_ylabel("$x_{%d}$" % m)
            else:
                ax.set_ylabel(p.labels[ix])
        axs[-1].set_xlabel('Time (t)')

        plot_pause(0.05)
        plt.tight_layout()

        # Allocate
        d = DotDict()  # data arrays
        h = DotDict()  # plot handles
        d.t = RollingArray((K_plot, ))
        d.x = RollingArray((K_plot, Nx))
        h.x = []
        d.y = RollingArray((K_plot, Ny))
        # Keyed by axes index `ix` (not every axes has obs), so that each
        # obs line is always updated with the data of its own axes.
        h.y = {}
        if E is not None:
            d.E = RollingArray((K_plot, len(E), Nx))
            h.E = []
        if P is not None:
            d.mu = RollingArray((K_plot, Nx))
            h.mu = []
            d.s = RollingArray((K_plot, 2, Nx))
            h.s = []

        # Plot (invisible coz everything here is nan, for the moment).
        for ix, (_m, iy, ax) in enumerate(zip(DimsX, DimsY, axs)):
            h.x += ax.plot(d.t, d.x[:, ix], 'k')
            if iy is not None:
                h.y[ix] = ax.plot(d.t, d.y[:, iy], 'g*', ms=10)[0]
            if 'E' in d:
                h.E += [ax.plot(d.t, d.E[:, :, ix], **p.ens_props)]
            if 'mu' in d:
                h.mu += ax.plot(d.t, d.mu[:, ix], 'b')
            if 's' in d:
                h.s += [ax.plot(d.t, d.s[:, :, ix], 'b--', lw=1)]

        def update(key, E, P):
            """Append the data for `key` = ``(k, ko, faus)`` and redraw.

            `P` is not used.
            """
            k, ko, faus = key

            EE = duplicate_with_blanks_for_resampled(E, DimsX, key, has_w)

            # Roll data array
            ind = k if plot_u else ko
            for Ens in EE:  # If E is duplicated, so must the others be.
                if 'E' in d:
                    d.E.insert(ind, Ens)
                if 'mu' in d:
                    d.mu.insert(ind, mu[key][DimsX])
                if 's' in d:
                    d.s.insert(
                        ind, mu[key][DimsX] + [[1], [-1]] * spread[key][DimsX])
                d.t.insert(ind, tseq.tt[k])
                d.y.insert(
                    ind, yy[ko, iiY] if ko is not None else nan * ones(Ny))
                d.x.insert(ind, xx[k, DimsX])

            # Update graphs
            for ix, (_m, iy, ax) in enumerate(zip(DimsX, DimsY, axs)):
                sliding_xlim(ax, d.t, T_lag, True)
                h.x[ix].set_data(d.t, d.x[:, ix])
                if iy is not None:
                    h.y[ix].set_data(d.t, d.y[:, iy])
                if 'mu' in d:
                    h.mu[ix].set_data(d.t, d.mu[:, ix])
                if 's' in d:
                    for b, line in enumerate(h.s[ix]):
                        line.set_data(d.t, d.s[:, b, ix])
                if 'E' in d:
                    for n, line in enumerate(h.E[ix]):
                        line.set_data(d.t, d.E[:, n, ix])
                    update_alpha(key, stats, h.E[ix])

                # TODO 3: fixup. This might be slow?
                # In any case, it is very far from tested.
                # Also, relim'iting all of the time is distracting.
                # Use d_ylim?
                if 'E' in d:
                    lims = d.E
                elif 'mu' in d:
                    lims = d.mu
                else:
                    # Neither ensemble nor mean available: fit to the truth
                    # (otherwise `lims` would be unbound here).
                    lims = d.x
                lims = np.array(viz.xtrema(lims[..., ix]))
                if lims[0] == lims[1]:
                    lims += [-.5, +.5]
                ax.set_ylim(*viz.stretch(*lims, 1 / p.zoomy))

            return

        return update
예제 #16
0
# ensured by the `config_wells` function used below.

grid1 = [.1, .9]
# Every (x, y) pair of grid1 values, with x varying fastest (4 producers).
grid2 = np.column_stack([coord.ravel() for coord in np.meshgrid(grid1, grid1)])
# ==> all production wells use the same (constant) rate
rates = np.full((grid2.shape[0], 1), 1.0)
model.config_wells(
    # Every row of `inj` and `prod` is a tuple (x, y, rate),
    # with x, y ∈ (0, 1) and rate > 0.
    inj=[[0.50, 0.50, 1.00]],
    prod=np.column_stack((grid2, rates)),
)

# #### Plot
# Let's take a moment to visualize the (true) model permeability field, and the well locations.

fig, ax = freshfig("True perm. field", figsize=(1.5, 1), rel=1)
# Plot the transformed true permeability with a LogLocator for the ticks,
# the well locations, and a colorbar.
# (Untransformed alternative: plots.field(ax, perm.Truth, "pperm"))
plots.field(
    ax,
    perm_transf(perm.Truth),
    locator=LogLocator(),
    wells=True,
    colorbar=True,
)
fig.tight_layout()

# #### Observation operator
# The data will consist of the water saturation at the well locations, i.e. of the
# production. That is, there is no well model. It should be pointed out, however, that
# ensemble methods technically support observation models of any complexity, though your
# accuracy mileage may vary (again, depending on the incurred nonlinearity and
# non-Gaussianity). Furthermore, it is also no problem to include time-dependence in the
# observation model.
예제 #17
0
    def init(fignum, stats, key0, plot_u, E, P, **kwargs):
        """Set up the phase-space trajectory figure and return its updater.

        Parameters
        ----------
        fignum:
            Figure identifier, forwarded to ``place.freshfig``.
        stats: object
            Provides ``xx``, ``yy``, ``mu``, ``spread`` and ``HMM.tseq``.
            If it also has a ``w`` attribute, extra rows are reserved in the
            rolling buffers for the blanks inserted upon resampling.
        key0: tuple
            Unused here.
        plot_u: bool
            If True, data is rolled in at index ``k``; otherwise at ``ko``.
        E: ndarray or None
            Ensemble. When not None, per-member tails and a scatter are drawn.
            NOTE(review): presumably shaped (N, Nx) -- confirm.
        P: object or None
            When not None, a tail and a marker for the mean ``mu`` are drawn.
        **kwargs:
            Overrides for the entries of ``params_orig``.

        Returns
        -------
        update: callable
            ``update(key, E, P)`` rolls in the data for ``key`` and redraws.
        """
        xx, yy, mu, _, tseq = \
            stats.xx, stats.yy, stats.mu, stats.spread, stats.HMM.tseq

        # Set parameters (kwargs take precedence over params_orig).
        p = DotDict(
            **{kw: kwargs.get(kw, val)
               for kw, val in params_orig.items()})

        # Lag settings:
        has_w = hasattr(stats, 'w')
        if p.Tplot == 0:
            K_plot = 1
        else:
            _, K_lag, a_lag = validate_lag(p.Tplot, tseq)
            K_plot = comp_K_plot(K_lag, a_lag, plot_u)
            # Extend K_plot further, to make room for the blanks added
            # upon resampling (PartFilt):
            if has_w:
                K_plot += a_lag

        # Dimension settings
        if not p.dims:
            p.dims = arange(M)
        if not p.labels:
            p.labels = ["$x_%d$" % d for d in p.dims]
        assert len(p.dims) == M

        # Set up figure, axes
        place.freshfig(fignum, figsize=(5, 5))
        ax = plt.subplot(111, projection='3d' if is_3d else None)
        ax.set_facecolor('w')
        ax.set_title("Phase space trajectories")
        # Tune plot: per-axis limits from the truth's extremes, and labels.
        for ind, (label, i, t) in enumerate(zip(p.labels, p.dims, "xyz")):
            viz.set_ilim(ax, ind,
                         *viz.stretch(*viz.xtrema(xx[:, i]), 1 / p.zoom))
            # Look up the setter with getattr rather than eval'ing source
            # text: eval re-parses the label as a string literal, which
            # breaks on quotes and turns backslashes (e.g. r"$\theta$")
            # into escape sequences.
            getattr(ax, "set_%slabel" % t)(label)

        # Allocate
        d = DotDict()  # data arrays
        h = DotDict()  # plot handles
        s = DotDict()  # scatter handles
        if E is not None:
            d.E = RollingArray((K_plot, len(E), M))
            h.E = []
        if P is not None:
            d.mu = RollingArray((K_plot, M))
        d.x = RollingArray((K_plot, M))
        if list(p.obs_inds) == list(p.dims):
            d.y = RollingArray((K_plot, M))

        # Plot tails (invisible coz everything here is nan, for the moment).
        if 'E' in d:
            h.E += [
                ax.plot(*xn, **p.ens_props)[0]
                for xn in np.transpose(d.E, [1, 2, 0])
            ]
        if 'mu' in d:
            h.mu = ax.plot(*d.mu.T, 'b', lw=2)[0]
        h.x = ax.plot(*d.x.T, 'k', lw=3)[0]
        if 'y' in d:
            h.y = ax.plot(*d.y.T, 'g*', ms=14)[0]

        # Scatter. NB: don't init with nan's coz it's buggy
        # (wrt. get_color() and _offsets3d) since mpl 3.1.
        if 'E' in d:
            s.E = ax.scatter(*E.T[p.dims],
                             s=3**2,
                             c=[hn.get_color() for hn in h.E])
        if 'mu' in d:
            s.mu = ax.scatter(*ones(M), s=8**2, c=[h.mu.get_color()])
        s.x = ax.scatter(*ones(M),
                         s=14**2,
                         c=[h.x.get_color()],
                         marker=(5, 1),
                         zorder=99)

        def update_tail(handle, newdata):
            """Set the line data of ``handle`` from the columns of ``newdata``."""
            handle.set_data(newdata[:, 0], newdata[:, 1])
            if is_3d:
                handle.set_3d_properties(newdata[:, 2])

        def update_sctr(handle, newdata):
            """Move the scatter points of ``handle`` to the rows of ``newdata``."""
            if is_3d:
                handle._offsets3d = juggle_axes(*newdata.T, 'z')
            else:
                handle.set_offsets(newdata)

        def update(key, E, P):
            """Roll in the data for ``key`` and redraw the tails and markers."""
            k, ko, faus = key
            show_y = 'y' in d and ko is not None

            EE = duplicate_with_blanks_for_resampled(E, p.dims, key, has_w)

            # Roll data array
            ind = k if plot_u else ko
            for Ens in EE:  # If E is duplicated, so must the others be.
                if 'E' in d:
                    d.E.insert(ind, Ens)
                d.x.insert(ind, xx[k, p.dims])
                if 'y' in d:
                    d.y.insert(ind, yy[ko, :] if show_y else nan * ones(M))
                if 'mu' in d:
                    d.mu.insert(ind, mu[key][p.dims])

            # Update graph
            update_sctr(s.x, d.x[[-1]])
            update_tail(h.x, d.x)
            if 'y' in d:
                update_tail(h.y, d.y)
            if 'mu' in d:
                update_sctr(s.mu, d.mu[[-1]])
                update_tail(h.mu, d.mu)
            else:
                # NOTE(review): assumes an ensemble is present whenever the
                # mean is not (i.e. E and P are not both None) -- confirm.
                update_sctr(s.E, d.E[-1])
                for n in range(len(E)):
                    update_tail(h.E[n], d.E[:, n, :])
                update_alpha(key, stats, h.E, s.E)

            return

        return update
예제 #18
0
    ########################
    # Reference trajectory
    ########################
    # NB: Arbitrary, coz models are autonom. But dont use nan coz QG doesn't like it.
    t0 = 0.0
    K = int(round(T / dt))  # Num of time steps.
    tt = np.linspace(dt, T, K)  # Time seq.
    x = with_recursion(step, prog="BurnIn")(x0, int(10 / dt), t0, dt)[-1]
    xx = with_recursion(step, prog="Reference")(x, K, t0, dt)

    ########################
    # ACF
    ########################
    # NB: Won't work with QG (too big, and BCs==0).
    fig, ax = place.freshfig("ACF")
    if "ii" not in locals():
        ii = np.arange(min(100, Nx))
    if "nlags" not in locals():
        nlags = min(100, K - 1)
    ax.plot(
        tt[:nlags],
        np.nanmean(series.auto_cov(xx[:nlags, ii], nlags=nlags - 1, corr=1),
                   axis=1))
    ax.set_xlabel('Time (t)')
    ax.set_ylabel('Auto-corr')
    viz.plot_pause(0.1)

    ########################
    # "Linearized" forecasting
    ########################