Exemplo n.º 1
0
def correlation_fields(self,
                       fignum,
                       field_ensembles,
                       xy_coord,
                       title="",
                       **kwargs):
    """Plot correlation fields, one panel per ensemble, in a 2-column grid.

    Only entries of ``field_ensembles`` whose value is 2-dimensional
    (``v.ndim == 2``) are plotted; other entries are dropped.

    Parameters
    ----------
    fignum:
        Figure identifier passed to ``fig_layout.freshfig``.
    field_ensembles: dict
        Maps panel label -> ensemble array (presumably (N, M) -- TODO confirm).
    xy_coord:
        Passed through to ``corr_field_vs``.
    title: str
        Figure super-title.
    **kwargs:
        Forwarded to ``corr_field_vs``.
    """
    field_ensembles = {k: v for k, v in field_ensembles.items() if v.ndim == 2}

    nAx = len(field_ensembles)
    if nAx == 0:
        # Nothing to plot: avoid a 0-row figure and an undefined colorbar handle.
        return

    ncols = 2
    nrows = int(np.ceil(nAx / ncols))
    fig, axs = fig_layout.freshfig(fignum,
                                   figsize=(8, 4 * nrows),
                                   ncols=ncols,
                                   nrows=nrows,
                                   sharex=True,
                                   sharey=True)

    fig.subplots_adjust(hspace=.3)
    fig.suptitle(title)

    axes = axs.ravel()
    # Hide the unused trailing panel (when nAx is odd).
    for ax in axes[nAx:]:
        ax.set_visible(False)

    for ax, (label, field) in zip(axes, field_ensembles.items()):
        handle = corr_field_vs(self, ax, field, xy_coord, label, **kwargs)

    # nAx > 0 guarantees `handle` was assigned above.
    fig_colorbar(fig, handle, ticks=[-1, -0.4, 0, 0.4, 1])  # type: ignore
Exemplo n.º 2
0
def plot_hovmoller(xx, chrono=None, **kwargs):
    """Plot Hovmöller diagram.

    Parameters
    ----------
    xx: ndarray
        Plotted array, indexed as ``xx[time_index, dimension_index]``.
    chrono: `dapper.tools.chronos.Chronology`, optional
        Object with attributes ``tt``, ``kk`` and ``Tplot``. If given, the
        plot covers the times with ``tt <= 2 * Tplot``. Otherwise the plot
        length is estimated from ``xx``. Defaults: None
    **kwargs:
        Accepted for interface compatibility; currently unused.
    """
    fig, ax = freshfig(26, figsize=(4, 3.5))

    if chrono is not None:
        mask = chrono.tt <= chrono.Tplot * 2
        kk = chrono.kk[mask]
        tt = chrono.tt[mask]
        ax.set_ylabel('Time (t)')
    else:
        K = estimate_good_plot_length(xx, mult=20)
        kk = arange(K)
        tt = kk
        ax.set_ylabel('Time indices (k)')

    # Draw on the axes created above rather than on pyplot's "current" axes,
    # which may belong to another figure.
    cf = ax.contourf(arange(xx.shape[1]), tt, xx[kk], 25)
    fig.colorbar(cf, ax=ax)
    ax.get_xaxis().set_major_locator(MaxNLocator(integer=True))
    ax.set_title("Hovmoller diagram (of 'Truth')")
    ax.set_xlabel('Dimension index (i)')

    plot_pause(0.1)
    plt.tight_layout()
Exemplo n.º 3
0
def productions(dct, fignum, figsize=None, title="", nProd=None, legend=True):
    """Plot each well's production series in its own panel.

    Parameters
    ----------
    dct: dict
        Maps label -> production array; ``1 - series.T[i]`` is plotted in
        the panel of well ``i``.
    fignum:
        Figure identifier passed to ``fig_layout.freshfig``.
    figsize:
        Passed to ``fig_layout.freshfig``.
    title: str
        Currently unused (the suptitle call is commented out).
    nProd: int, optional
        Number of wells (panels) to plot. Default: number of columns of
        ``get0(dct)``, capped at 23.
    legend: bool
        Draw a legend next to each panel.

    Returns
    -------
    handles: list
        One line handle per entry of ``dct``, taken from the first panel.
    """
    if nProd is None:
        nProd = min(23, get0(dct).shape[1])

    fig, axs = fig_layout.freshfig(fignum,
                                   figsize=figsize,
                                   **nRowCol(nProd),
                                   sharex=True,
                                   sharey=True)
    # fig.suptitle("Oil productions " + title)

    axes = axs.ravel()

    # Turn off redundant axes
    for ax in axes[nProd:]:
        ax.set_visible(False)

    # Style props depend only on the label (and on the length of the last
    # entry of dct), not on the well, so compute them once per label.
    ensembles = list(dct.values())
    N = len(ensembles[-1]) if ensembles else 0
    props = {label: style(label, N=N) for label in dct}

    handles = []

    # For each well
    for i, ax in enumerate(axes[:nProd]):
        ax.text(1,
                1,
                f"Well {i}" if i == 0 else i,
                c="k",
                ha="right",
                va="top",
                transform=ax.transAxes)

        for label, series in dct.items():
            # Plot
            ll = ax.plot(1 - series.T[i], **props[label])

            # Rm duplicate labels (only the first line of a label is in the legend)
            plt.setp(ll[1:], label="_nolegend_")

            # Store 1 handle of series
            if i == 0:
                handles.append(ll[0])

        # Legend
        if legend:
            leg = ax.legend(loc="upper left", bbox_to_anchor=(1, 1))
            for ln in leg.get_lines():
                ln.set(alpha=1, linewidth=max(1, ln.get_linewidth()))

    return handles
Exemplo n.º 4
0
    def __init__(self, fignum, stats, key0, plot_u, E, P, **kwargs):
        """Set up a histogram plot of particle weights (log-scaled x-axis).

        If ``stats`` has no ``w`` attribute, the plot is marked inactive
        (``self.is_active = False``) and no figure is created.
        """
        if not hasattr(stats, 'w'):
            self.is_active = False
            return
        fig, ax = freshfig(fignum, figsize=(7, 3), gridspec_kw={'bottom': .15})

        ax.set_xscale('log')
        ax.set_xlabel('Weight')
        ax.set_ylabel('Count')
        self.stats = stats
        self.ax = ax
        self.hist = []
        # 31 log-spaced bin edges (i.e. 30 bins) spanning [1e-10, 1].
        self.bins = np.exp(np.linspace(np.log(1e-10), np.log(1), 31))
Exemplo n.º 5
0
def fields(self,
           fignum,
           plotter,
           ZZ,
           figsize=None,
           title="",
           txt_color="k",
           colorbar=True,
           **kwargs):
    """Plot several fields in a grid of panels that share one colour scale.

    Parameters
    ----------
    fignum:
        Figure identifier passed to ``fig_layout.freshfig``.
    plotter: callable
        Called as ``plotter(self, ax, Z, vmin=..., vmax=..., **kwargs)`` for
        each field; the return values are collected, and the first one is
        used for the colorbar.
    ZZ: dict or sequence
        Fields to plot. A non-dict is converted to ``{index: field}``.
        At most 12 panels are created; extra fields are not plotted.
    figsize:
        Passed to ``fig_layout.freshfig``.
    title: str
        Figure super-title (skipped if empty).
    txt_color:
        Colour of the label drawn in each panel's top-left corner.
    colorbar: bool
        Add a shared colorbar.
    **kwargs:
        Forwarded to ``plotter``.

    Returns
    -------
    fig, axs, hh
        The figure, the axes as returned by ``freshfig``, and the list of
        ``plotter`` return values.
    """
    fig, axs = fig_layout.freshfig(fignum,
                                   figsize=figsize,
                                   **nRowCol(min(12, len(ZZ))),
                                   sharex=True,
                                   sharey=True)

    # Flatten the axes grid. Slicing a 2D `axs` directly would select whole
    # rows of axes (ndarrays), which have no `set_visible`.
    axes = np.ravel(axs)

    # Turn off redundant axes
    for ax in axes[len(ZZ):]:
        ax.set_visible(False)

    # Convert list-like ZZ into dict
    if not isinstance(ZZ, dict):
        ZZ = dict(enumerate(ZZ))

    # Common colour limits across all fields (fields may differ in shape).
    flat = np.concatenate([np.ravel(Z) for Z in ZZ.values()])
    vmin = flat.min()
    vmax = flat.max()

    hh = []
    for ax, label in zip(axes, ZZ):

        ax.text(0,
                1,
                label,
                ha="left",
                va="top",
                c=txt_color,
                transform=ax.transAxes)

        # Call plotter
        hh.append(plotter(self, ax, ZZ[label], vmin=vmin, vmax=vmax, **kwargs))

    if colorbar and hh:
        fig_colorbar(fig, hh[0])

    if title:
        fig.suptitle(title)

    return fig, axs, hh
Exemplo n.º 6
0
    def __init__(self, fignum, stats, key0, plot_u, E, P, **kwargs):
        """Prepare a log-scale plot of spectral statistics.

        Stores ``stats.umisf`` as ``self.msft`` and ``stats.svals`` as
        ``self.sprd``. If either attribute is missing, the plot is flagged
        inactive and a notice is written on the axes instead.
        """
        _, axes = freshfig(fignum, figsize=(6, 3))
        axes.set_xlabel('Sing. value index')
        axes.set_yscale('log')

        self.init_incomplete, self.ax, self.plot_u = True, axes, plot_u

        try:
            # Assigned one at a time: if only `svals` is missing,
            # `self.msft` still ends up set.
            self.msft = stats.umisf
            self.sprd = stats.svals
        except AttributeError:
            self.is_active = False
            not_available_text(axes, "Spectral stats not being computed")
Exemplo n.º 7
0
def plot_rank_histogram(stats):
    """Plot rank histogram of ensemble.

    If ``stats`` carries particle weights (``stats.w``), an experimental
    weighted rank histogram is drawn instead. If the rank histogram has not
    been computed, a "not available" note is shown.

    Parameters
    ----------
    stats: `dapper.stats.Stats`
    """
    chrono = stats.HMM.t

    # Uncomputed ranks appear to be stored as nan cast to int; the last
    # row is compared against that sentinel (TODO confirm in Stats).
    has_been_computed = \
        hasattr(stats, 'rh') and \
        not all(stats.rh.a[-1] == array(np.nan).astype(int))

    fig, ax = freshfig(24, (6, 3), loc="3313")
    ax.set_title('(Mean of marginal) rank histogram (_a)')
    ax.set_ylabel('Freq. of occurrence\n (of truth in interval n)')
    ax.set_xlabel('ensemble member index (n)')

    if has_been_computed:
        ranks = stats.rh.a[chrono.maskObs_BI]
        N = stats.xp.N
        if not hasattr(stats, 'w'):
            # Ensemble rank histogram
            integer_hist(ranks.ravel(), N)
        else:
            # Experimental: weighted rank histogram.
            # Weight ranks by the inverse of the particle weight, because with
            # correct importance weights the "expected value" histogram is
            # then flat.
            # Potential improvement: interpolate weights between particles.
            w = stats.w.a[chrono.maskObs_BI]
            K = len(w)
            # Define weights for rank N+1 (the appended last column).
            w = np.hstack([w, np.ones((K, 1)) / N])
            # Gather w[k, ranks[k, i]] for every (k, i), in the same
            # row-major order as ranks.ravel().
            w = np.take_along_axis(w, ranks, axis=1).ravel()
            # Artificial cap. Reduces variance, but introduces bias.
            w = np.maximum(w, 1 / N / 100)
            w = 1 / w
            integer_hist(ranks.ravel(), N, weights=w)
    else:
        not_available_text(ax)

    plot_pause(0.1)
    plt.tight_layout()
Exemplo n.º 8
0
def oilfield_means(self, fignum, water_sat_fields, title="", **kwargs):
    """Plot the mean of each saturation field in a 2-column grid.

    Parameters
    ----------
    fignum:
        Figure identifier passed to ``fig_layout.freshfig``.
    water_sat_fields: dict
        Maps panel label -> field. A 2D entry is averaged over axis 0
        (presumably the ensemble axis -- TODO confirm) before plotting.
    title: str
        Appended to the figure super-title.
    **kwargs:
        Forwarded to ``oilfield``.
    """
    nAx = len(water_sat_fields)
    if nAx == 0:
        # Nothing to plot: avoid a 0-row figure and an undefined colorbar handle.
        return

    ncols = 2
    nrows = int(np.ceil(nAx / ncols))
    fig, axs = fig_layout.freshfig(fignum,
                                   figsize=(8, 4 * nrows),
                                   ncols=ncols,
                                   nrows=nrows,
                                   sharex=True,
                                   sharey=True)

    fig.subplots_adjust(hspace=.3)
    fig.suptitle(f"Oil saturation (mean fields) - {title}")

    axes = axs.ravel()
    # Hide the unused trailing panel (when nAx is odd) instead of leaving
    # an empty axes on screen.
    for ax in axes[nAx:]:
        ax.set_visible(False)

    for ax, (label, field) in zip(axes, water_sat_fields.items()):
        if field.ndim == 2:
            field = field.mean(axis=0)

        handle = oilfield(self, ax, field, title=label, **kwargs)

    # nAx > 0 guarantees `handle` was assigned above.
    fig_colorbar(fig, handle)  # type: ignore
Exemplo n.º 9
0
    def init(fignum, stats, key0, plot_u, E, P, **kwargs):
        """Set up a phase-space trajectory liveplot (2D or 3D).

        Relies on names from the enclosing scope that are not visible here:
        ``M`` (number of plotted dimensions), ``is_3d`` (use a 3D projection)
        and ``params_orig`` (default plot parameters, overridable via
        ``kwargs``).

        Returns
        -------
        update: callable
            ``update(key, E, P)`` inserts the newest state into the rolling
            tail buffers and redraws the tails and scatter markers.
        """
        xx, yy, mu, _, chrono = \
            stats.xx, stats.yy, stats.mu, stats.std, stats.HMM.t

        # Set parameters (kwargs takes precedence over params_orig)
        p = DotDict(
            **{kw: kwargs.get(kw, val)
               for kw, val in params_orig.items()})

        # Lag settings:
        has_w = hasattr(stats, 'w')
        if p.Tplot == 0:
            K_plot = 1
        else:
            _, K_lag, a_lag = validate_lag(p.Tplot, chrono)
            K_plot = comp_K_plot(K_lag, a_lag, plot_u)
            # Extend K_plot further for adding blanks in resampling (PartFilt):
            if has_w:
                K_plot += a_lag

        # Dimension settings.
        # NB: test emptiness via len(); `not p.dims` raises for multi-element
        # numpy arrays.
        if p.dims is None or len(p.dims) == 0:
            p.dims = arange(M)
        if not p.labels:
            p.labels = ["$x_%d$" % d for d in p.dims]
        assert len(p.dims) == M

        # Set up figure, axes
        fig, _ = freshfig(fignum, figsize=(5, 5))
        ax = plt.subplot(111, projection='3d' if is_3d else None)
        ax.set_facecolor('w')
        ax.set_title("Phase space trajectories")
        # Tune plot: limits and label for each plotted dimension.
        for ind, (s, i, t) in enumerate(zip(p.labels, p.dims, "xyz")):
            viz.set_ilim(ax, ind,
                         *viz.stretch(*viz.xtrema(xx[:, i]), 1 / p.zoom))
            # Call set_xlabel/set_ylabel/set_zlabel directly. The previous
            # eval() of a quoted literal re-interpreted backslashes in LaTeX
            # labels (e.g. "\t" in "\theta") and broke on quotes.
            getattr(ax, "set_%slabel" % t)(s)

        # Allocate
        d = DotDict()  # data arrays
        h = DotDict()  # plot handles
        s = DotDict()  # scatter handles
        if E is not None:
            d.E = RollingArray((K_plot, len(E), M))
            h.E = []
        if P is not None:
            d.mu = RollingArray((K_plot, M))
        d.x = RollingArray((K_plot, M))
        if list(p.obs_inds) == list(p.dims):
            d.y = RollingArray((K_plot, M))

        # Plot tails (invisible coz everything here is nan, for the moment).
        if 'E' in d:
            h.E += [
                ax.plot(*xn, **p.ens_props)[0]
                for xn in np.transpose(d.E, [1, 2, 0])
            ]
        if 'mu' in d:
            h.mu = ax.plot(*d.mu.T, 'b', lw=2)[0]
        h.x = ax.plot(*d.x.T, 'k', lw=3)[0]
        if 'y' in d:
            h.y = ax.plot(*d.y.T, 'g*', ms=14)[0]

        # Scatter. NB: don't init with nan's coz it's buggy
        # (wrt. get_color() and _offsets3d) since mpl 3.1.
        if 'E' in d:
            s.E = ax.scatter(*E.T[p.dims],
                             s=3**2,
                             c=[hn.get_color() for hn in h.E])
        if 'mu' in d:
            s.mu = ax.scatter(*ones(M), s=8**2, c=[h.mu.get_color()])
        s.x = ax.scatter(*ones(M),
                         s=14**2,
                         c=[h.x.get_color()],
                         marker=(5, 1),
                         zorder=99)

        def update(key, E, P):
            k, kObs, faus = key
            show_y = 'y' in d and kObs is not None

            def update_tail(handle, newdata):
                handle.set_data(newdata[:, 0], newdata[:, 1])
                if is_3d:
                    handle.set_3d_properties(newdata[:, 2])

            def update_sctr(handle, newdata):
                if is_3d:
                    handle._offsets3d = juggle_axes(*newdata.T, 'z')
                else:
                    handle.set_offsets(newdata)

            EE = duplicate_with_blanks_for_resampled(E, p.dims, key, has_w)

            # Roll data array
            ind = k if plot_u else kObs
            for Ens in EE:  # If E is duplicated, so must the others be.
                if 'E' in d:
                    d.E.insert(ind, Ens)
                d.x.insert(ind, xx[k, p.dims])
                if 'y' in d:
                    d.y.insert(ind, yy[kObs, :] if show_y else nan * ones(M))
                if 'mu' in d:
                    d.mu.insert(ind, mu[key][p.dims])

            # Update graph
            update_sctr(s.x, d.x[[-1]])
            update_tail(h.x, d.x)
            if 'y' in d:
                update_tail(h.y, d.y)
            if 'mu' in d:
                update_sctr(s.mu, d.mu[[-1]])
                update_tail(h.mu, d.mu)
            elif 'E' in d:
                # Only touch ensemble artists if they were created (a bare
                # `else` failed on s.E / h.E when E was None).
                update_sctr(s.E, d.E[-1])
                for n in range(len(E)):
                    update_tail(h.E[n], d.E[:, n, :])
                update_alpha(key, stats, h.E, s.E)

            return

        return update
Exemplo n.º 10
0
def plot_err_components(stats):
    """Plot components of the error.

    Three panels: the element-wise time-averaged error magnitude vs. spread,
    the spectral (principal component) error vs. singular values, and a
    histogram of RMSE values at the indices selected by
    ``stats.HMM.t.maskObs_BI``.

    Parameters
    ----------
    stats: `dapper.stats.Stats`

    .. note::
      it was chosen to plot(ii, mean_in_time(abs(err_i))),
      and thus the corresponding spread measure is MAD.
      If one chose instead: plot(ii, std_in_time(err_i)),
      then the corresponding measure of spread would have been std.
      This choice was made in part because (wrt. subplot 2)
      the singular values (svals) correspond to rotated MADs,
      and because rms(umisf) seems too convoluted for interpretation.
    """
    fig, (ax0, ax1, ax2) = freshfig(25, figsize=(6, 6), nrows=3)

    chrono = stats.HMM.t
    Nx = stats.xx.shape[1]

    # Time averages (over axis 0) of the analysis (".a") series.
    err = np.mean(np.abs(stats.err.a), 0)
    sprd = np.mean(stats.mad.a, 0)
    umsft = np.mean(np.abs(stats.umisf.a), 0)
    usprd = np.mean(stats.svals.a, 0)

    ax0.plot(arange(Nx), err, 'k', lw=2, label='Error')
    if Nx < 10**3:
        ax0.fill_between(arange(Nx), 0,
                         sprd,
                         alpha=0.7,
                         label='Spread')
    else:
        # For very large Nx a plain line is used instead of a filled area.
        ax0.plot(arange(Nx), sprd, alpha=0.7, label='Spread')
    # ax0.set_yscale('log')
    ax0.set_title('Element-wise error comparison')
    ax0.set_xlabel('Dimension index (i)')
    ax0.set_ylabel('Time-average (_a) magnitude')
    ax0.set_xlim(0, Nx - 1)
    ax0.get_xaxis().set_major_locator(MaxNLocator(integer=True))
    ax0.legend(loc='upper right')

    ax1.set_xlim(0, Nx - 1)
    ax1.set_xlabel('Principal component index')
    ax1.set_ylabel('Time-average (_a) magnitude')
    ax1.set_title('Spectral error comparison')
    has_been_computed = np.any(np.isfinite(umsft))
    if has_been_computed:
        L = len(umsft)
        ax1.plot(arange(L), umsft, 'k', lw=2, label='Error')
        ax1.fill_between(arange(L), 0, usprd, alpha=0.7, label='Spread')
        ax1.set_yscale('log')
        ax1.get_xaxis().set_major_locator(MaxNLocator(integer=True))
    else:
        not_available_text(ax1)

    rmse = stats.err_rms.a[chrono.maskObs_BI]
    ax2.hist(rmse, bins=30, density=False)
    ax2.set_ylabel('Num. of occurrence (_a)')
    ax2.set_xlabel('RMSE')
    ax2.set_title('Histogram of RMSE values')
    ax2.set_xlim(left=0)

    plot_pause(0.1)
    plt.tight_layout()
Exemplo n.º 11
0
    def init(fignum, stats, key0, plot_u, E, P, **kwargs):
        """Set up a 1D amplitude liveplot (state value vs. state index).

        Relies on ``params_orig`` from the enclosing scope (default plot
        parameters, overridable through ``kwargs``).

        Draws either ``p.conf_mult``-sigma bands and the DA mean, or (if
        ``p.conf_mult`` is falsy) the ensemble members. It also draws the
        truth, and, when ``p.obs_inds`` is set, observations that are shown
        at 'f' stages and hidden at 'u' stages.

        Returns
        -------
        update: callable
            ``update(key, E, P)`` redraws all lines for the given key.
        """
        xx, yy, mu = stats.xx, stats.yy, stats.mu

        # Set parameters (kwargs takes precedence over params_orig)
        p = DotDict(
            **{kw: kwargs.get(kw, val)
               for kw, val in params_orig.items()})

        # NB: test emptiness via len(); `not p.dims` raises for multi-element
        # numpy arrays.
        if p.dims is None or len(p.dims) == 0:
            M = xx.shape[-1]
            p.dims = arange(M)
        else:
            M = len(p.dims)

        # Make periodic wrapper
        ii, wrap = viz.setup_wrapping(M, p.periodicity)

        # Set up figure, axes
        fig, ax = freshfig(fignum, figsize=(8, 5))
        fig.suptitle("1d amplitude plot")

        # Nans
        nan1 = wrap(nan * ones(M))

        if E is None and p.conf_mult is None:
            p.conf_mult = 2

        # Init plots
        if p.conf_mult:
            lines_s = ax.plot(ii,
                              nan1,
                              "b-",
                              lw=1,
                              label=(str(p.conf_mult) + r'$\sigma$ conf'))
            lines_s += ax.plot(ii, nan1, "b-", lw=1)
            line_mu, = ax.plot(ii, nan1, 'b-', lw=2, label='DA mean')
        else:
            nanE = nan * ones((stats.xp.N, M))
            lines_E = ax.plot(ii,
                              wrap(nanE[0]),
                              **p.ens_props,
                              lw=1,
                              label='Ensemble')
            lines_E += ax.plot(ii, wrap(nanE[1:]).T, **p.ens_props, lw=1)
        # Truth, Obs
        (line_x, ) = ax.plot(ii, nan1, 'k-', lw=3, label='Truth')
        if p.obs_inds is not None:
            # Build the nan placeholder from len() so that obs_inds may also
            # be a plain list (`nan * list` raises TypeError).
            (line_y, ) = ax.plot(p.obs_inds,
                                 nan * ones(len(p.obs_inds)),
                                 'g*',
                                 ms=5,
                                 label='Obs')

        # Tune plot
        ax.set_ylim(*viz.xtrema(xx))
        ax.set_xlim(viz.stretch(ii[0], ii[-1], 1))
        # Xticks
        xt = ax.get_xticks()
        xt = xt[abs(xt % 1) < 0.01].astype(int)  # Keep only the integer ticks
        xt = xt[xt >= 0]
        xt = xt[xt < len(p.dims)]
        ax.set_xticks(xt)
        # Label ticks by the actual state index (works for list or array dims).
        ax.set_xticklabels([p.dims[j] for j in xt])

        ax.set_xlabel('State index')
        ax.set_ylabel('Value')
        ax.legend(loc='upper right')

        text_t = ax.text(0.01,
                         0.01,
                         format_time(None, None, None),
                         transform=ax.transAxes,
                         family='monospace',
                         ha='left')

        # Init visibility (must come after legend):
        if p.obs_inds is not None:
            line_y.set_visible(False)

        def update(key, E, P):
            k, kObs, faus = key

            if p.conf_mult:
                # Row 0: mean + conf_mult*std, row 1: mean - conf_mult*std.
                sigma = mu[key] + p.conf_mult * stats.std[key] * [[1], [-1]]
                lines_s[0].set_ydata(wrap(sigma[0, p.dims]))
                lines_s[1].set_ydata(wrap(sigma[1, p.dims]))
                line_mu.set_ydata(wrap(mu[key][p.dims]))
            else:
                for n, line in enumerate(lines_E):
                    line.set_ydata(wrap(E[n, p.dims]))
                update_alpha(key, stats, lines_E)

            line_x.set_ydata(wrap(xx[k, p.dims]))

            text_t.set_text(format_time(k, kObs, stats.HMM.t.tt[k]))

            if 'f' in faus:
                if p.obs_inds is not None:
                    line_y.set_ydata(yy[kObs])
                    line_y.set_zorder(5)
                    line_y.set_visible(True)

            if 'u' in faus:
                if p.obs_inds is not None:
                    line_y.set_visible(False)

            return

        return update
Exemplo n.º 12
0
    def init(fignum, stats, key0, plot_u, E, P, **kwargs):
        """Set up a 2x2 liveplot of a 2D field: mean, truth, std. and error.

        Relies on names from the enclosing scope that are not visible here:
        ``cm`` (colormap of the top row), ``square`` (presumably reshapes a
        state vector into a 2D image -- TODO confirm), ``clims`` (one colour
        limit per image, in the order mean/true/std/err), ``obs_inds``
        (called with time ``t``, or None) and ``ind2sub``.

        Returns
        -------
        update: callable
            ``update(key, E, P)`` refreshes the four images, the observation
            markers on the "true" panel and the time text.
        """

        # Shift the subplot grid left by 0.04 to make room for the colorbar
        # axes that are added to the right of the panels below.
        GS = {'left': 0.125 - 0.04, 'right': 0.9 - 0.04}
        fig, axs = freshfig(fignum,
                            figsize=(6, 6),
                            nrows=2,
                            ncols=2,
                            sharex=True,
                            sharey=True,
                            gridspec_kw=GS)

        for ax in axs.flatten():
            ax.set_aspect('equal', viz.adjustable_box_or_forced())

        ((ax_11, ax_12), (ax_21, ax_22)) = axs

        ax_11.grid(color='w', linewidth=0.2)
        ax_12.grid(color='w', linewidth=0.2)
        ax_21.grid(color='k', linewidth=0.1)
        ax_22.grid(color='k', linewidth=0.1)

        # Upper colorbar -- position relative to ax_12
        # (NB: must come after the layout/aspect setup above, since it reads
        # the panel's current position.)
        bb = ax_12.get_position()
        dy = 0.1 * bb.height
        ax_13 = fig.add_axes(
            [bb.x1 + 0.03, bb.y0 + dy, 0.04, bb.height - 2 * dy])
        # Lower colorbar -- position relative to ax_22
        bb = ax_22.get_position()
        dy = 0.1 * bb.height
        ax_23 = fig.add_axes(
            [bb.x1 + 0.03, bb.y0 + dy, 0.04, bb.height - 2 * dy])

        # Extract data arrays
        xx, _, mu, std, err = stats.xx, stats.yy, stats.mu, stats.std, stats.err
        k = key0[0]
        tt = stats.HMM.t.tt

        # Plot
        # - origin='lower' might get overturned by set_ylim() below.
        im_11 = ax_11.imshow(square(mu[key0]), cmap=cm)
        im_12 = ax_12.imshow(square(xx[k]), cmap=cm)
        # hot is better, but needs +1 colorbar
        im_21 = ax_21.imshow(square(std[key0]), cmap=plt.cm.bwr)
        im_22 = ax_22.imshow(square(err[key0]), cmap=plt.cm.bwr)
        ims = (im_11, im_12, im_21, im_22)
        # Obs init -- a list where item 0 is the handle of something invisible.
        # (A list is used so that update() can replace item 0 in place.)
        lh = list(ax_12.plot(0, 0)[0:1])

        sx = '$\\psi$'
        ax_11.set_title('mean ' + sx)
        ax_12.set_title('true ' + sx)
        ax_21.set_title('std. ' + sx)
        ax_22.set_title('err. ' + sx)

        # TODO 7
        # for ax in axs.flatten():
        # Crop boundaries (which should be 0, i.e. yield harsh q gradients):
        # lims = (1, nx-2)
        # step = (nx - 1)/8
        # ticks = arange(step,nx-1,step)
        # ax.set_xlim  (lims)
        # ax.set_ylim  (lims[::-1])
        # ax.set_xticks(ticks)
        # ax.set_yticks(ticks)

        for im, clim in zip(ims, clims):
            im.set_clim(clim)

        # One colorbar per row: the top row shares im_12's scale, the bottom
        # row im_22's.
        fig.colorbar(im_12, cax=ax_13)
        fig.colorbar(im_22, cax=ax_23)
        for ax in [ax_13, ax_23]:
            ax.yaxis.set_tick_params('major',
                                     length=2,
                                     width=0.5,
                                     direction='in',
                                     left=True,
                                     right=True)
            ax.set_axisbelow('line')  # make ticks appear over colorbar patch

        # Title
        title = "Streamfunction (" + sx + ")"
        fig.suptitle(title)
        # Time info
        text_t = ax_12.text(1,
                            1.1,
                            format_time(None, None, None),
                            transform=ax_12.transAxes,
                            family='monospace',
                            ha='left')

        def update(key, E, P):
            k, kObs, faus = key
            t = tt[k]

            im_11.set_data(square(mu[key]))
            im_12.set_data(square(xx[k]))
            im_21.set_data(square(std[key]))
            im_22.set_data(square(err[key]))

            # Remove previous obs
            # (ValueError is raised when the artist was already removed,
            # e.g. on steps without a new observation.)
            try:
                lh[0].remove()
            except ValueError:
                pass
            # Plot current obs.
            #  - plot() automatically adjusts to direction of y-axis in use.
            #  - ind2sub returns (iy,ix), while plot takes (ix,iy) => reverse.
            if kObs is not None and obs_inds is not None:
                lh[0] = ax_12.plot(*ind2sub(obs_inds(t))[::-1],
                                   'k.',
                                   ms=1,
                                   zorder=5)[0]

            text_t.set_text(format_time(k, kObs, t))

            return

        return update
Exemplo n.º 13
0
def dashboard(self,
              saturation,
              production,
              pause=200,
              animate=True,
              title="",
              **kwargs):
    """Show the saturation field (initial and evolving) and production curves.

    Parameters
    ----------
    saturation: sequence
        Saturation fields, one per time index; ``saturation[0]`` is shown as
        "Initial" and ``saturation[iT]`` in the animated "Evolution" panel.
    production: ndarray
        Production series; the rows of ``1 - production.T`` are animated
        on the production lines returned by ``production1``.
    pause: int
        Delay between animation frames, passed to ``FuncAnimation(interval=)``
        (milliseconds).
    animate: bool
        If True, build and return a ``FuncAnimation``.
    title: str
        Appended to the figure super-title (skipped if empty).
    **kwargs:
        Forwarded to ``oilfield``.

    Returns
    -------
    ani: matplotlib.animation.FuncAnimation or None
        The animation if ``animate`` is True, else None.
    """
    fig, axs = fig_layout.freshfig(231, ncols=2, nrows=2, figsize=(12, 10))
    if is_notebook_or_qt:
        plt.close()  # https://stackoverflow.com/q/47138023

    tt = np.arange(len(saturation))

    axs[0, 0].set_title("Initial")
    axs[0, 0].cc = oilfield(self, axs[0, 0], saturation[0], **kwargs)
    axs[0, 0].set_ylabel(f"y ({COORD_TYPE})")

    axs[0, 1].set_title("Evolution")
    axs[0, 1].cc = oilfield(self, axs[0, 1], saturation[-1], **kwargs)
    well_scatter(self, axs[0, 1], self.injectors)
    well_scatter(self,
                 axs[0, 1],
                 self.producers,
                 False,
                 color=[f"C{i}" for i in range(len(self.producers))])

    axs[1, 0].set_title("Production")
    prod_handles = production1(axs[1, 0], production)

    axs[1, 1].set_visible(False)

    # fig.tight_layout()
    fig_colorbar(fig, axs[0, 0].cc)

    if title:
        fig.suptitle(f"Oil saturation -- {title}")

    if animate:
        from matplotlib import animation

        def update_fig(iT):
            # Update field: remove the previous contour artists, then redraw.
            # NB: use Artist.remove(); `Axes.collections` is a read-only view
            # in recent matplotlib and no longer supports `.remove(c)`.
            for c in axs[0, 1].cc.collections:
                try:
                    c.remove()
                except ValueError:
                    pass  # already removed (occurs when re-running script)
            axs[0, 1].cc = oilfield(self, axs[0, 1], saturation[iT], **kwargs)

            # Update production lines
            # NOTE(review): shows data up to index iT-2; confirm intended.
            if iT >= 1:
                for h, p in zip(prod_handles, 1 - production.T):
                    h.set_data(tt[:iT - 1], p[:iT - 1])

        ani = animation.FuncAnimation(fig,
                                      update_fig,
                                      len(tt),
                                      blit=False,
                                      interval=pause)

        return ani
Exemplo n.º 14
0
    def __init__(self,
                 fignum,
                 stats,
                 key0,
                 plot_u,
                 E,
                 P,
                 Tplot=None,
                 **kwargs):
        """Set up the sliding "Diagnostics" liveplot: one axes per style table.

        Each style table maps a stat name (looked up with ``deep_getattr`` on
        ``stats``) to ``[transf, shape, plt_kwargs]``. Stats that do not
        exist on ``stats`` are skipped. Each line is backed by rolling
        buffers (``RollingArray``) that initially plot as nans.

        Relies on names defined outside this block: ``star`` (label prefix),
        ``freshfig``, ``viz``, ``validate_lag``, ``deep_getattr``, ``FAUSt``,
        ``comp_K_plot``, ``RollingArray`` and ``ones``.
        """

        # STYLE TABLES - Defines which/how diagnostics get plotted
        styles = {}

        # Affine map x -> a + b*x, used to rescale a stat onto shared axes.
        def lin(a, b):
            return (lambda x: a + b * x)

        # 1/N for the ensemble size N (falls back to 1/99 if stats.xp has no N).
        divN = 1 / getattr(stats.xp, 'N', 99)
        # Columns: transf, shape, plt kwargs
        styles['RMS'] = {
            'err.rms': [None, None, dict(c='k', label='Error')],
            'std.rms': [None, None,
                        dict(c='b', label='Spread', alpha=0.6)],
        }
        styles['Values'] = {
            'skew': [None, None,
                     dict(c='g', label=star + r'Skew/$\sigma^3$')],
            'kurt':
            [None, None,
             dict(c='r', label=star + r'Kurt$/\sigma^4{-}3$')],
            'trHK': [None, None, dict(c='k', label=star + 'HK')],
            'infl': [lin(-10, 10), 'step',
                     dict(c='c', label='10(infl-1)')],
            'N_eff':
            [lin(0, divN), 'dirac',
             dict(c='y', label='N_eff/N', lw=3)],
            'iters': [lin(0, .1), 'dirac',
                      dict(c='m', label='iters/10')],
            'resmpl': [None, 'dirac',
                       dict(c='k', label='resampled?')],
        }

        nAx = len(styles)
        GS = {'left': 0.125, 'right': 0.76}
        fig, axs = freshfig(fignum,
                            figsize=(5, 1 + nAx),
                            nrows=nAx,
                            sharex=True,
                            gridspec_kw=GS)

        axs[0].set_title("Diagnostics")
        for style, ax in zip(styles, axs):
            ax.set_ylabel(style)
        # NB: `ax` now refers to the last (bottom) axes; it is reused below
        # for the x-label, the position adjustment and the baseline.
        ax.set_xlabel('Time (t)')
        viz.adjust_position(ax, y0=0.03)

        self.T_lag, K_lag, a_lag = validate_lag(Tplot, stats.HMM.t)

        # Create one line per existing stat of `style_table` on `ax`;
        # returns {stat name: line-info dict}.
        def init_ax(ax, style_table):
            lines = {}
            for name in style_table:

                # SKIP -- if stats[name] is not in existence
                # Note: The nan check/deletion comes after the first kObs.
                try:
                    stat = deep_getattr(stats, name)
                except AttributeError:
                    continue
                # try: val0 = stat[key0[0]]
                # except KeyError: continue
                # PS: recall (from series.py) that even if store_u is false, stat[k] is
                # still present if liveplots=True via the k_tmp functionality.

                # Unpack style
                ln = {}
                ln['transf'] = style_table[name][0] or (lambda x: x)
                ln['shape'] = style_table[name][1]
                ln['plt'] = style_table[name][2]

                # Create series
                # (FAUSt stats get a buffer length from comp_K_plot; other
                # stats use a_lag.)
                if isinstance(stat, FAUSt):
                    ln['plot_u'] = plot_u
                    K_plot = comp_K_plot(K_lag, a_lag, ln['plot_u'])
                else:
                    ln['plot_u'] = False
                    K_plot = a_lag
                ln['data'] = RollingArray(K_plot)
                ln['tt'] = RollingArray(K_plot)

                # Plot (init)
                ln['handle'], = ax.plot(ln['tt'], ln['data'], **ln['plt'])

                # Plotting only nans yields ugly limits. Revert to defaults.
                ax.set_xlim(0, 1)
                ax.set_ylim(0, 1)

                lines[name] = ln
            return lines

        # Plot
        self.d = [init_ax(ax, styles[style]) for style, ax in zip(styles, axs)]

        # Horizontal line at y=0
        self.baseline0, = ax.plot(ax.get_xlim(), [0, 0],
                                  c=0.5 * ones(3),
                                  lw=0.7,
                                  label='_nolegend_')

        # Store
        self.axs = axs
        self.stats = stats
        self.init_incomplete = True
Exemplo n.º 15
0
    """Random field generation.

    Uses:
    - Gaussian variogram.
    - Gaussian distributions.

    Draws ``N`` samples of a zero-mean Gaussian random field whose covariance
    is ``1 - variogram_gauss(dists, r)``, where ``dists`` comes from
    ``dist_euclid(vectorize(*pts))`` (presumably the pairwise Euclidean
    distance matrix of the points -- TODO confirm).

    Returns an array of shape ``(N, len(dists))``: one sampled field per row.
    """
    dists = dist_euclid(vectorize(*pts))
    # Convert the (Gaussian) variogram into a covariance matrix.
    Cov = 1 - variogram_gauss(dists, r)
    # Matrix square root; `.real` drops imaginary parts that sqrtm may
    # return (presumably only round-off -- TODO confirm).
    C12 = sla.sqrtm(Cov).real
    # Colour white noise: each row of randn(...) @ C12.T has covariance
    # C12 @ C12.T, which equals Cov when C12 is the symmetric square root.
    fields = randn(N, len(dists)) @ C12.T
    return fields


if __name__ == "__main__":
    # Demo: sample and display random fields in 1D and 2D.
    np.random.seed(3000)  # fixed seed for a reproducible demo
    plt.ion()  # enable matplotlib interactive mode
    N = 15  # ensemble size

    ## 1D
    xx = np.linspace(0, 1, 201)
    fields = gaussian_fields((xx, ), N)
    fig, ax = freshfig(1)
    ax.plot(xx, fields.T, lw=2)  # one line per sampled field

    ## 2D
    grid = Grid2D(Lx=1, Ly=1, Nx=20, Ny=20)
    fields = gaussian_fields(grid.mesh(), N)
    # Shift and scale the samples around 0.5 (presumably to resemble
    # saturation values -- TODO confirm).
    fields = 0.5 + .2 * fields
    # fields = truncate_01(fields)
    plots.oilfields(grid, 2, fields)
Exemplo n.º 16
0
def amplitude_animation(EE,
                        dt=None,
                        interval=1,
                        periodicity=None,
                        blit=True,
                        fignum=None,
                        repeat=False):
    """Animate the amplitude of each ensemble member against the state index.

    Frame ``k`` draws one line per ensemble member, showing ``EE[k]``
    (after wrapping by ``setup_wrapping``) versus the state index.

    Parameters
    ----------
    EE: ndarray
        Ensemble array of the shape (K, N, Nx).
        K is the length of time (number of frames), N is the ensemble size,
        and Nx is the length of the state vector.
        A 2D array of shape (K, Nx) is also accepted and treated as a
        single-member ensemble (N = 1).
    dt: float, optional
        Time interval between consecutive frames. If given, the time
        ``dt * k`` of the current frame is shown in the axes.
        Default: None (no time label).
    interval: float, optional
        Delay between frames in milliseconds. Default: 1.
    periodicity: str or None, optional
        The mode of the wrapping, passed unchanged to ``setup_wrapping``.
        "+1": the first element is appended after the last.
        "+/-05": adding the midpoint of the first and last elements.
        Default: None (its handling is up to ``setup_wrapping``).
    blit: bool, optional
        Controls whether blitting is used to optimize drawing. Default: True
    fignum: int, optional
        Figure index. Default: None
    repeat: bool, optional
        If True, repeat the animation. Default: False

    Returns
    -------
    FuncAnimation
        The animation. Keep a reference to it, otherwise matplotlib may
        garbage-collect it before it is shown.
    """
    fig, ax = freshfig(fignum)
    ax.set_xlabel('State index')
    ax.set_ylabel('Amplitude')
    ax.set_ylim(*stretch(*xtrema(EE), 1.1))

    # Promote a single trajectory (K, Nx) to a 1-member ensemble (K, 1, Nx).
    if EE.ndim == 2:
        EE = np.expand_dims(EE, 1)
    K, N, Nx = EE.shape

    ii, wrap = setup_wrapping(Nx, periodicity)

    # The rows of wrap(EE[k]) are the members; transpose so that plot()
    # (which draws one line per column) draws one line per member.
    lines = ax.plot(ii, wrap(EE[0]).T)
    ax.set_xlim(*xtrema(ii))

    time_fmt = 'time = %.1f'
    time_label = None
    if dt is not None:
        time_label = ax.text(0.05, 0.9, '', transform=ax.transAxes)
        # Returned from anim() along with the lines, so blitting redraws it.
        lines += [time_label]

    def anim(k):
        Ek = wrap(EE[k])
        for line, member in zip(lines[:N], Ek):
            line.set_ydata(member)
        if time_label is not None:
            time_label.set_text(time_fmt % (dt * k))
        return lines

    return FuncAnimation(fig,
                         anim,
                         range(K),
                         interval=interval,
                         blit=blit,
                         repeat=repeat)
Exemplo n.º 17
0
    def init(fignum, stats, key0, plot_u, E, P, **kwargs):
        """Set up the sliding marginal time-series figure and return its updater.

        One axes is created per plotted state dimension. It shows:
        - the truth (``stats.xx``);
        - the observations (``stats.yy``), if that dimension is observed;
        - the ensemble members, if ``E`` is not None;
        - the mean +/- one std. dev. (``stats.mu``, ``stats.std``), if ``P``
          is not None.

        Parameters
        ----------
        fignum: int
            Figure number passed to ``freshfig``.
        stats:
            Statistics object. Its ``xx``, ``yy``, ``mu``, ``std`` and
            ``HMM.t`` attributes are read, and the presence of ``w`` is checked.
        key0:
            Initial key (not used here).
        plot_u: bool
            If True, the rolling arrays are indexed by ``k`` (every time step),
            otherwise by ``kObs``.
        E: ndarray or None
            Initial ensemble. If not None, its length (ensemble size) sizes
            the ensemble buffer.
        P:
            Initial covariance. Only whether it is None is used here.
        **kwargs:
            Overrides for the entries of ``params_orig`` (enclosing scope).

        Returns
        -------
        update: callable
            ``update(key, E, P)``, with ``key = (k, kObs, faus)``. It rolls the
            data arrays and redraws the lines.
        """
        xx, yy, mu, std, chrono = \
            stats.xx, stats.yy, stats.mu, stats.std, stats.HMM.t

        # Set parameters (kwargs take precedence over params_orig)
        p = DotDict(
            **{kw: kwargs.get(kw, val)
               for kw, val in params_orig.items()})

        # Lag settings:
        T_lag, K_lag, a_lag = validate_lag(p.Tplot, chrono)
        K_plot = comp_K_plot(K_lag, a_lag, plot_u)
        # Extend K_plot further for adding blanks in resampling (PartFilt):
        has_w = hasattr(stats, 'w')
        if has_w:
            K_plot += a_lag

        # Choose the marginal dims to plot
        if not p.dims:
            Nx = min(10, xx.shape[-1])
            DimsX = linspace_int(xx.shape[-1], Nx)
        else:
            Nx = len(p.dims)
            DimsX = p.dims
        # Pre-process obs dimensions.
        # Indices (within y) of the obs whose state dim is plotted.
        iiY = [i for i, m in enumerate(p.obs_inds) if m in DimsX]
        # The state dims of those obs, in the same order as iiY
        # (i.e. the order of the columns of d.y).
        obs_dims = [m for m in p.obs_inds if m in DimsX]
        # For each plotted dim: its column in d.y, or None if unobserved.
        DimsY = [obs_dims.index(m) if m in obs_dims else None for m in DimsX]
        Ny = len(iiY)

        # Set up figure, axes
        fig, axs = freshfig(fignum, figsize=(5, 7), nrows=Nx, sharex=True)
        if Nx == 1:
            axs = [axs]

        # Tune plots
        axs[0].set_title("Marginal time series")
        for ix, (m, ax) in enumerate(zip(DimsX, axs)):
            if not p.labels:
                ax.set_ylabel("$x_{%d}$" % m)
            else:
                ax.set_ylabel(p.labels[ix])
        axs[-1].set_xlabel('Time (t)')

        plot_pause(0.05)
        plt.tight_layout()

        # Allocate
        d = DotDict()  # data arrays
        h = DotDict()  # plot handles
        d.t = RollingArray((K_plot, ))
        d.x = RollingArray((K_plot, Nx))
        h.x = []
        d.y = RollingArray((K_plot, Ny))
        # Keyed by axes index ix (not by iy), since not every dim is observed
        # and the obs order (iy) need not follow the order of DimsX.
        h.y = {}
        if E is not None:
            d.E = RollingArray((K_plot, len(E), Nx))
            h.E = []
        if P is not None:
            d.mu = RollingArray((K_plot, Nx))
            h.mu = []
            d.s = RollingArray((K_plot, 2, Nx))
            h.s = []

        # Plot (invisible coz everything here is nan, for the moment).
        for ix, (_m, iy, ax) in enumerate(zip(DimsX, DimsY, axs)):
            h.x += ax.plot(d.t, d.x[:, ix], 'k')
            if iy is not None:
                h.y[ix] = ax.plot(d.t, d.y[:, iy], 'g*', ms=10)[0]
            if 'E' in d:
                h.E += [ax.plot(d.t, d.E[:, :, ix], **p.ens_props)]
            if 'mu' in d:
                h.mu += ax.plot(d.t, d.mu[:, ix], 'b')
            if 's' in d:
                h.s += [ax.plot(d.t, d.s[:, :, ix], 'b--', lw=1)]

        def update(key, E, P):
            k, kObs, faus = key

            EE = duplicate_with_blanks_for_resampled(E, DimsX, key, has_w)

            # Roll data array
            ind = k if plot_u else kObs
            for Ens in EE:  # If E is duplicated, so must the others be.
                if 'E' in d:
                    d.E.insert(ind, Ens)
                if 'mu' in d:
                    d.mu.insert(ind, mu[key][DimsX])
                if 's' in d:
                    # Rows: mean + std, mean - std.
                    d.s.insert(ind,
                               mu[key][DimsX] + [[1], [-1]] * std[key][DimsX])
                d.t.insert(ind, chrono.tt[k])
                y_now = yy[kObs, iiY] if kObs is not None else nan * ones(Ny)
                d.y.insert(ind, y_now)
                d.x.insert(ind, xx[k, DimsX])

            # Update graphs
            for ix, (_m, iy, ax) in enumerate(zip(DimsX, DimsY, axs)):
                sliding_xlim(ax, d.t, T_lag, True)
                h.x[ix].set_data(d.t, d.x[:, ix])
                if iy is not None:
                    h.y[ix].set_data(d.t, d.y[:, iy])
                if 'mu' in d:
                    h.mu[ix].set_data(d.t, d.mu[:, ix])
                if 's' in d:
                    for b in [0, 1]:
                        h.s[ix][b].set_data(d.t, d.s[:, b, ix])
                if 'E' in d:
                    for n in range(len(E)):
                        h.E[ix][n].set_data(d.t, d.E[:, n, ix])
                    update_alpha(key, stats, h.E[ix])

                # Y-limits from the ensemble if present, else from the mean,
                # else from the truth (previously this case left `lims`
                # unbound).
                # TODO 3: fixup. This might be slow?
                # In any case, it is very far from tested.
                # Also, relim'iting all of the time is distracting.
                # Use d_ylim?
                if 'E' in d:
                    lims = d.E
                elif 'mu' in d:
                    lims = d.mu
                else:
                    lims = d.x
                lims = np.array(viz.xtrema(lims[..., ix]))
                if lims[0] == lims[1]:
                    lims += [-.5, +.5]
                ax.set_ylim(*viz.stretch(*lims, 1 / p.zoomy))

        return update
Exemplo n.º 18
0
    def __init__(self, fignum, stats, key0, plot_u, E, P, **kwargs):
        """Set up the state-correlation figure.

        The top axes shows the state correlation matrix as an image
        (initialised to the identity). The bottom axes shows two lines,
        "Correlation" and "Abs. corr.", against the distance in state indices.

        Parameters
        ----------
        fignum: int
            Figure number passed to ``freshfig``.
        stats:
            Statistics object. ``stats.mu[key0]`` gives the state length Nx,
            and ``stats.w`` (if present) is stored as ``self.w``.
        key0:
            Initial key, used to look up the state length.
        plot_u: bool
            Not used here.
        E: ndarray or None
            Ensemble. If None, availability is decided from ``P``.
        P: CovMat, ndarray or None
            Covariance. If ``E`` is None and ``P`` is None or all-nan, only
            a "not available" message is shown.
        **kwargs:
            Not used here.

        If no plot is set up (data not available, or Nx > 1003), only a
        message is shown and ``self.is_active`` is set to False.
        """
        GS = {'height_ratios': [4, 1], 'hspace': 0.09, 'top': 0.95}
        fig, (ax, ax2) = freshfig(fignum,
                                  figsize=(5, 6),
                                  nrows=2,
                                  gridspec_kw=GS)

        # Without E, the covariance must have been stored for there to be
        # anything to show. (Check P is None first: np.isnan(None) raises.)
        if E is None and (P is None or np.isnan(
                P.diag if isinstance(P, CovMat) else P).all()):
            not_available_text(ax, ('Not available in replays'
                                    '\ncoz full Ens/Cov not stored.'))
            self.is_active = False
            return

        Nx = len(stats.mu[key0])
        if Nx > 1003:
            # Too large to display. Nothing below is set up (self.im,
            # self.mask, ...), so deactivate like the branch above.
            not_available_text(ax)
            self.is_active = False
            return

        C = np.eye(Nx)
        # Mask half: True on and below the diagonal.
        # (np.bool was removed in NumPy 1.24; use the builtin bool.)
        mask = np.zeros_like(C, dtype=bool)
        mask[np.tril_indices_from(mask)] = True
        # Make colormap. Log-transform cmap,
        # but not internally in matplotlib,
        # so as to avoid transforming the colorbar too.
        cmap = plt.get_cmap('RdBu_r')
        trfm = mpl.colors.SymLogNorm(linthresh=0.2,
                                     linscale=0.2,
                                     base=np.e,
                                     vmin=-1,
                                     vmax=1)
        cmap = cmap(trfm(np.linspace(-0.6, 0.6, cmap.N)))
        cmap = mpl.colors.ListedColormap(cmap)
        # Fixed colour range, since correlations lie in [-1, 1].
        VM = 1.0  # abs(np.percentile(C,[1,99])).max()
        im = ax.imshow(C, cmap=cmap, vmin=-VM, vmax=VM)
        # Colorbar
        _ = ax.figure.colorbar(im, ax=ax, shrink=0.8)
        # Tune plot
        # NOTE(review): plt.box() acts on the *current* axes, which is not
        # necessarily `ax` at this point -- confirm which frame to hide.
        plt.box(False)
        ax.set_facecolor('w')
        ax.grid(False)
        ax.set_title("State correlation matrix:", y=1.07)
        ax.xaxis.tick_top()

        # Lines with initial placeholder values (ones).
        line_AC, = ax2.plot(arange(Nx), ones(Nx), label='Correlation')
        line_AA, = ax2.plot(arange(Nx), ones(Nx), label='Abs. corr.')
        _ = ax2.hlines(0, 0, Nx - 1, 'k', 'dotted', lw=1)
        # Align ax2 with ax
        bb_AC = ax2.get_position()
        bb_C = ax.get_position()
        ax2.set_position([bb_C.x0, bb_AC.y0, bb_C.width, bb_AC.height])
        # Tune plot
        ax2.set_title("Auto-correlation:")
        ax2.set_ylabel("Mean value")
        ax2.set_xlabel("Distance (in state indices)")
        ax2.set_xticklabels([])
        ax2.set_yticks([0, 1] + list(ax2.get_yticks()[[0, -1]]))
        ax2.set_ylim(top=1)
        ax2.legend(frameon=True,
                   facecolor='w',
                   bbox_to_anchor=(1, 1),
                   loc='upper left',
                   borderaxespad=0.02)

        self.ax = ax
        self.ax2 = ax2
        self.im = im
        self.line_AC = line_AC
        self.line_AA = line_AA
        self.mask = mask
        if hasattr(stats, 'w'):
            self.w = stats.w
Exemplo n.º 19
0
########################
# Reference trajectory
########################
# NB: t0 is arbitrary, since the models are autonomous.
# But don't use nan, since QG doesn't like it.
t0 = 0.0
K  = int(round(T/dt))       # Num of time steps.
tt = np.linspace(dt, T, K)  # Time seq.
# Burn-in: int(10/dt) steps of size dt from x0; keep only the last state.
x  = with_recursion(step, prog="BurnIn")(x0, int(10/dt), t0, dt)[-1]
# Reference trajectory: K steps starting from the burnt-in state.
xx = with_recursion(step, prog="Reference")(x, K, t0, dt)


########################
# ACF
########################
# Plot the auto-correlation function, averaged (nanmean) over components ii.
# NB: Won't work with QG (too big, and BCs==0).
fig, ax = freshfig(4)
# Defaults, unless `ii`/`nlags` were already defined by the session.
# (At module level, locals() is the module namespace.)
if "ii" not in locals():
    ii = np.arange(min(100, Nx))
if "nlags" not in locals():
    nlags = min(100, K-1)
# NOTE(review): tt starts at dt, so if auto_cov returns lags 0..nlags-1,
# lag 0 is plotted at t=dt rather than t=0 -- confirm the intended lag axis.
# Also, only the first nlags time steps of xx enter the estimate.
ax.plot(tt[:nlags], np.nanmean(
    series.auto_cov(xx[:nlags, ii], nlags=nlags-1, corr=1),
    axis=1))
ax.set_xlabel('Time (t)')
ax.set_ylabel('Auto-corr')
viz.plot_pause(0.1)


########################
# "Linearized" forecasting
########################
Exemplo n.º 20
0
# One row of 3 values per producer well. The first two are x, y (see how
# model.producers is unpacked for obs_inds below); the third is presumably
# a rate -- confirm.
well_grid = well_grid.T.reshape((-1, 3))
# A single injector at (0.5, 0.5) (assuming the same row layout as the
# producers), and the producers from well_grid.
model.init_Q(
    inj =[[0.50, 0.50, 1.00]],
    prod=well_grid,
);

# # Random well configuration
# model.init_Q(
#     inj =rand(1, 3),
#     prod=rand(8, 3)
# );
# -

# #### Plot true field

fig, ax = freshfig(110)
# cs = plots.field(model, ax, perm.Truth)
# Plot the transformed true permeability, with log-spaced levels/ticks.
cs = plots.field(model, ax, f_perm(perm.Truth), locator=ticker.LogLocator())
plots.well_scatter(model, ax, model.producers, inj=False)
plots.well_scatter(model, ax, model.injectors, inj=True)
fig.colorbar(cs)
fig.suptitle("True field");
plt.pause(.1)


# #### Define obs operator
# There is no well model. The data consists purely of the water cut at the location of the wells.

# Grid index of each producer well, from its (x, y) position.
obs_inds = [model.xy2ind(x, y) for (x, y, _) in model.producers]
def obs(water_sat):
    """Return the entries of `water_sat` at the producer locations (`obs_inds`)."""
    return [water_sat[i] for i in obs_inds]