def correlation_fields(self, fignum, field_ensembles, xy_coord, title="", **kwargs):
    """Plot, in a 2-column grid, one correlation field per ensemble.

    Only the 2-dimensional entries of `field_ensembles` are plotted (one
    panel each, drawn by `corr_field_vs`); other entries are ignored.
    Spare panels of an incomplete last row are hidden, and a shared
    colorbar (ticks at -1, -0.4, 0, 0.4, 1) is added.

    Parameters
    ----------
    fignum:
        Figure identifier passed on to `fig_layout.freshfig`.
    field_ensembles: dict
        Mapping from label to array; entries with ``ndim != 2`` are skipped.
    xy_coord:
        Forwarded to `corr_field_vs`.
    title: str
        Figure super-title.
    **kwargs:
        Forwarded to `corr_field_vs`.
    """
    ensembles = {}
    for label, ens in field_ensembles.items():
        if ens.ndim == 2:
            ensembles[label] = ens

    n_cols = 2
    n_axes = len(ensembles)
    n_rows = int(np.ceil(n_axes / n_cols))
    fig, axs = fig_layout.freshfig(fignum, figsize=(8, 4 * n_rows),
                                   ncols=n_cols, nrows=n_rows,
                                   sharex=True, sharey=True)
    fig.subplots_adjust(hspace=.3)
    fig.suptitle(title)

    all_axes = axs.ravel()
    # Fill the first n_axes panels, in the dict's iteration order
    for ax, (label, ens) in zip(all_axes, ensembles.items()):
        handle = corr_field_vs(self, ax, ens, xy_coord, label, **kwargs)
    # Hide the leftover panel(s) of an incomplete last row
    for ax in all_axes[n_axes:]:
        ax.set_visible(False)

    fig_colorbar(fig, handle, ticks=[-1, -0.4, 0, 0.4, 1])  # type: ignore
def plot_hovmoller(xx, chrono=None, **kwargs):
    """Plot Hovmöller diagram.

    Parameters
    ----------
    xx: ndarray
        Plotted array, indexed as ``xx[k, i]``
        (time index ``k``, dimension index ``i``).
    chrono: `dapper.tools.chronos.Chronology`, optional
        Provides the time axis through its ``tt``, ``kk`` and ``Tplot``
        attributes; only times with ``tt <= 2 * Tplot`` are plotted.
        If None, the time indices ``0, ..., K-1`` are used, with ``K``
        chosen by `estimate_good_plot_length`. Default: None
    **kwargs:
        Not used.
    """
    fig, ax = freshfig(26, figsize=(4, 3.5))

    if chrono is not None:
        mask = chrono.tt <= chrono.Tplot * 2
        kk = chrono.kk[mask]
        tt = chrono.tt[mask]
        ax.set_ylabel('Time (t)')
    else:
        K = estimate_good_plot_length(xx, mult=20)
        kk = arange(K)
        tt = kk
        ax.set_ylabel('Time indices (k)')

    # Draw on, and attach the colorbar to, the axes created above, rather
    # than relying on pyplot's implicit "current axes/image" state.
    contours = ax.contourf(arange(xx.shape[1]), tt, xx[kk], 25)
    fig.colorbar(contours, ax=ax)
    ax.get_xaxis().set_major_locator(MaxNLocator(integer=True))
    ax.set_title("Hovmoller diagram (of 'Truth')")
    ax.set_xlabel('Dimension index (i)')

    plot_pause(0.1)
    plt.tight_layout()
def productions(dct, fignum, figsize=None, title="", nProd=None, legend=True):
    """Plot production curves, one panel per well.

    In panel ``i`` (well ``i``), ``1 - series.T[i]`` is plotted for every
    ``series`` in `dct`, styled by ``style(label, N=...)``. Only the first
    line of each series carries a legend label.

    Parameters
    ----------
    dct: dict
        Mapping from label to production array.
    fignum:
        Figure identifier passed on to `fig_layout.freshfig`.
    figsize:
        Passed on to `fig_layout.freshfig`.
    title: str
        Currently unused (the corresponding suptitle call is commented out).
    nProd: int, optional
        Number of wells (panels) to plot. Default: number of columns of
        ``get0(dct)``, capped at 23.
    legend: bool
        Whether to draw a legend (to the right of the last plotted panel).

    Returns
    -------
    list
        One line handle per entry of `dct` (taken from the first panel).
    """
    if nProd is None:
        nProd = min(23, get0(dct).shape[1])

    fig, axs = fig_layout.freshfig(fignum, figsize=figsize, **nRowCol(nProd),
                                   sharex=True, sharey=True)
    # fig.suptitle("Oil productions " + title)
    all_axes = axs.ravel()

    # Turn off redundant axes
    for ax in all_axes[nProd:]:
        ax.set_visible(False)

    # Loop-invariant: the ensemble size used for styling is that of the last
    # entry of `dct`. (Guarded so that an empty `dct` does not raise here.)
    N = len(list(dct.values())[-1]) if dct else None

    handles = []
    # For each well
    for i in range(nProd):
        ax = all_axes[i]
        ax.text(1, 1, f"Well {i}" if i == 0 else i, c="k",
                ha="right", va="top", transform=ax.transAxes)

        for label, series in dct.items():
            # Get style props
            props = style(label, N=N)

            # Plot
            ll = ax.plot(1 - series.T[i], **props)

            # Rm duplicate labels
            plt.setp(ll[1:], label="_nolegend_")

            # Store 1 handle of series
            if i == 0:
                handles.append(ll[0])

    # Legend
    if legend:
        leg = ax.legend(loc="upper left", bbox_to_anchor=(1, 1))
        for ln in leg.get_lines():
            ln.set(alpha=1, linewidth=max(1, ln.get_linewidth()))

    return handles
def __init__(self, fignum, stats, key0, plot_u, E, P, **kwargs):
    """Set up a histogram panel for the particle weights (log x-axis).

    If `stats` has no weights (no ``w`` attribute), the plotter is
    deactivated (``self.is_active = False``) and nothing is drawn.
    Only `fignum` and `stats` are used; the other parameters are accepted
    presumably to match the constructor signature shared by the live plotters.
    """
    if not hasattr(stats, 'w'):
        self.is_active = False
        return

    fig, ax = freshfig(fignum, figsize=(7, 3), gridspec_kw={'bottom': .15})
    ax.set_xscale('log')
    ax.set_xlabel('Weight')
    ax.set_ylabel('Count')

    self.stats = stats
    self.ax = ax
    self.hist = []
    # 31 log-spaced bin edges (i.e. 30 bins) spanning [1e-10, 1].
    self.bins = np.exp(np.linspace(np.log(1e-10), np.log(1), 31))
def fields(self, fignum, plotter, ZZ, figsize=None, title="",
           txt_color="k", colorbar=True, **kwargs):
    """Plot several fields, one per panel, on a common colour scale.

    At most 12 fields are shown; ``vmin``/``vmax`` are computed over all of
    them and passed to `plotter`, so the panels share a colour scale.

    Parameters
    ----------
    fignum:
        Figure identifier passed on to `fig_layout.freshfig`.
    plotter: callable
        Called as ``plotter(self, ax, Z, vmin=..., vmax=..., **kwargs)`` for
        each field ``Z``; its return values are collected.
    ZZ: dict or list-like
        The fields. A non-dict is converted to ``{index: field}``.
        The keys are written as labels in each panel's upper-left corner.
    figsize:
        Passed on to `fig_layout.freshfig`.
    title: str
        Figure super-title (only set if non-empty).
    txt_color:
        Colour of the panel labels.
    colorbar: bool
        Whether to add a colorbar (based on the first handle).
    **kwargs:
        Forwarded to `plotter`.

    Returns
    -------
    fig, axs, hh
        The figure, its axes array, and the list of `plotter` return values.
    """
    fig, axs = fig_layout.freshfig(fignum, figsize=figsize,
                                   **nRowCol(min(12, len(ZZ))),
                                   sharex=True, sharey=True)
    flat_axs = axs.ravel()

    # Turn off redundant axes.
    # NB: index the *flattened* grid: slicing a 2D `axs` directly selects
    # whole rows, which leaves the spare panels visible.
    for ax in flat_axs[len(ZZ):]:
        ax.set_visible(False)

    # Convert list-like ZZ into dict
    if not isinstance(ZZ, dict):
        ZZ = {i: Z for (i, Z) in enumerate(ZZ)}

    # Get min/max across all fields
    flat = np.array(list(ZZ.values())).ravel()
    vmin = flat.min()
    vmax = flat.max()

    hh = []
    for ax, label in zip(flat_axs, ZZ):
        ax.text(0, 1, label, ha="left", va="top",
                c=txt_color, transform=ax.transAxes)
        # Call plotter
        hh.append(plotter(self, ax, ZZ[label], vmin=vmin, vmax=vmax, **kwargs))

    if colorbar:
        fig_colorbar(fig, hh[0])

    if title:
        fig.suptitle(title)

    return fig, axs, hh
def __init__(self, fignum, stats, key0, plot_u, E, P, **kwargs):
    """Set up the log-scale panel for spectral (singular-value) statistics.

    Keeps references to ``stats.umisf`` (as ``self.msft``) and
    ``stats.svals`` (as ``self.sprd``). If either attribute is missing, the
    plotter is deactivated and a "not available" note is drawn instead.
    """
    _, axes = freshfig(fignum, figsize=(6, 3))
    axes.set_xlabel('Sing. value index')
    axes.set_yscale('log')

    self.init_incomplete, self.ax, self.plot_u = True, axes, plot_u

    try:
        misfit = stats.umisf
        self.msft = misfit
        self.sprd = stats.svals
    except AttributeError:
        self.is_active = False
        not_available_text(axes, "Spectral stats not being computed")
def plot_rank_histogram(stats):
    """Plot rank histogram of ensemble.

    Parameters
    ----------
    stats: `dapper.stats.Stats`

    Uses the analysis ranks ``stats.rh.a`` at the times selected by
    ``stats.HMM.t.maskObs_BI``. If `stats` has particle weights ``w``,
    each rank is weighted by the inverse of the (capped) weight of the
    particle at that rank. If the ranks are not available, a
    "not available" note is shown instead.
    """
    chrono = stats.HMM.t

    # NOTE(review): "not computed" is detected by the last analysis row being
    # equal to nan cast to int; that cast is platform-dependent -- confirm it
    # matches how `stats.rh` is initialised.
    has_been_computed = \
        hasattr(stats, 'rh') and \
        not all(stats.rh.a[-1] == array(np.nan).astype(int))

    fig, ax = freshfig(24, (6, 3), loc="3313")
    ax.set_title('(Mean of marginal) rank histogram (_a)')
    ax.set_ylabel('Freq. of occurence\n (of truth in interval n)')
    ax.set_xlabel('ensemble member index (n)')

    if has_been_computed:
        ranks = stats.rh.a[chrono.maskObs_BI]  # shape (K, Nx)
        N = stats.xp.N
        if not hasattr(stats, 'w'):
            # Ensemble rank histogram
            integer_hist(ranks.ravel(), N)
        else:
            # Experimental: weighted rank histogram.
            # Weight ranks by inverse of particle weight. Why? Coz, with correct
            # importance weights, the "expected value" histogram is then flat.
            # Potential improvement: interpolate weights between particles.
            w = stats.w.a[chrono.maskObs_BI]
            K = len(w)
            w = np.hstack([w, np.ones((K, 1)) / N])  # define weights for rank N+1
            # For every (time k, component i), gather the weight of the
            # particle at rank ranks[k, i]. Broadcasting the row index
            # (K, 1) against `ranks` (K, Nx) does this in one vectorised step
            # and yields the same row-major ordering as ranks.ravel().
            w = w[arange(K)[:, None], ranks].ravel()
            # Artificial cap. Reduces variance, but introduces bias.
            w = np.maximum(w, 1 / N / 100)
            w = 1 / w
            integer_hist(ranks.ravel(), N, weights=w)
    else:
        not_available_text(ax)

    plot_pause(0.1)
    plt.tight_layout()
def oilfield_means(self, fignum, water_sat_fields, title="", **kwargs):
    """Plot saturation fields (ensemble means) in a 2-column grid.

    Each 2D entry of `water_sat_fields` is averaged over its first axis
    (presumably the ensemble members) before plotting; other entries are
    plotted as they are. Spare panels of an incomplete last row are hidden,
    and a shared colorbar is added.

    Parameters
    ----------
    fignum:
        Figure identifier passed on to `fig_layout.freshfig`.
    water_sat_fields: dict
        Mapping from label (used as panel title) to field or field ensemble.
    title: str
        Appended to the figure super-title.
    **kwargs:
        Forwarded to `oilfield`.
    """
    ncols = 2
    nAx = len(water_sat_fields)
    nrows = int(np.ceil(nAx / ncols))

    fig, axs = fig_layout.freshfig(fignum, figsize=(8, 4 * nrows),
                                   ncols=ncols, nrows=nrows,
                                   sharex=True, sharey=True)
    fig.subplots_adjust(hspace=.3)
    fig.suptitle(f"Oil saturation (mean fields) - {title}")

    flat_axs = axs.ravel()
    for ax, label in zip(flat_axs, water_sat_fields):
        field = water_sat_fields[label]
        if field.ndim == 2:
            field = field.mean(axis=0)
        handle = oilfield(self, ax, field, title=label, **kwargs)

    # With an odd number of fields the last panel would otherwise be left
    # as an empty frame.
    for ax in flat_axs[nAx:]:
        ax.set_visible(False)

    fig_colorbar(fig, handle)  # type: ignore
def init(fignum, stats, key0, plot_u, E, P, **kwargs):
    """Initialise a live plot of trajectories in (2D or 3D) phase space.

    Draws "tails" (recent history) and a scatter marker for the current
    point of: the truth, the observations (only if ``p.obs_inds`` equals
    ``p.dims``), the mean (if `P` is given) and the ensemble (if `E` is
    given). Returns the ``update(key, E, P)`` callback, which rolls the
    data arrays and refreshes the artists.

    NOTE(review): `M`, `is_3d` and `params_orig` are not defined in this
    function; presumably they come from an enclosing scope -- confirm.
    """
    xx, yy, mu, _, chrono = \
        stats.xx, stats.yy, stats.mu, stats.std, stats.HMM.t

    # Set parameters (kwargs takes precedence over params_orig)
    p = DotDict(
        **{kw: kwargs.get(kw, val) for kw, val in params_orig.items()})

    # Lag settings:
    has_w = hasattr(stats, 'w')
    if p.Tplot == 0:
        K_plot = 1
    else:
        T_lag, K_lag, a_lag = validate_lag(p.Tplot, chrono)
        K_plot = comp_K_plot(K_lag, a_lag, plot_u)
        # Extend K_plot further for adding blanks in resampling (PartFilt):
        if has_w:
            K_plot += a_lag

    # Dimension settings
    if not p.dims:
        p.dims = arange(M)
    if not p.labels:
        p.labels = ["$x_%d$" % d for d in p.dims]
    assert len(p.dims) == M

    # Set up figure, axes
    fig, _ = freshfig(fignum, figsize=(5, 5))
    ax = plt.subplot(111, projection='3d' if is_3d else None)
    ax.set_facecolor('w')
    ax.set_title("Phase space trajectories")
    # Tune plot: axis limits (from the truth's range) and axis labels
    for ind, (s, i, t) in enumerate(zip(p.labels, p.dims, "xyz")):
        viz.set_ilim(ax, ind, *viz.stretch(*viz.xtrema(xx[:, i]), 1 / p.zoom))
        eval("ax.set_%slabel('%s')" % (t, s))

    # Allocate
    d = DotDict()  # data arrays
    h = DotDict()  # plot handles
    s = DotDict()  # scatter handles
    if E is not None:
        d.E = RollingArray((K_plot, len(E), M))
        h.E = []
    if P is not None:
        d.mu = RollingArray((K_plot, M))
    if True:
        d.x = RollingArray((K_plot, M))
    if list(p.obs_inds) == list(p.dims):
        d.y = RollingArray((K_plot, M))

    # Plot tails (invisible coz everything here is nan, for the moment).
    if 'E' in d:
        h.E += [
            ax.plot(*xn, **p.ens_props)[0]
            for xn in np.transpose(d.E, [1, 2, 0])
        ]
    if 'mu' in d:
        h.mu = ax.plot(*d.mu.T, 'b', lw=2)[0]
    if True:
        h.x = ax.plot(*d.x.T, 'k', lw=3)[0]
    if 'y' in d:
        h.y = ax.plot(*d.y.T, 'g*', ms=14)[0]

    # Scatter. NB: don't init with nan's coz it's buggy
    # (wrt. get_color() and _offsets3d) since mpl 3.1.
    if 'E' in d:
        s.E = ax.scatter(*E.T[p.dims], s=3**2,
                         c=[hn.get_color() for hn in h.E])
    if 'mu' in d:
        s.mu = ax.scatter(*ones(M), s=8**2,
                          c=[h.mu.get_color()])
    if True:
        s.x = ax.scatter(*ones(M), s=14**2,
                         c=[h.x.get_color()], marker=(5, 1), zorder=99)

    def update(key, E, P):
        k, kObs, faus = key
        show_y = 'y' in d and kObs is not None

        def update_tail(handle, newdata):
            # Set the line's 2D data, plus the z-data when in 3D.
            handle.set_data(newdata[:, 0], newdata[:, 1])
            if is_3d:
                handle.set_3d_properties(newdata[:, 2])

        def update_sctr(handle, newdata):
            # 3D scatter offsets must be set via the private `_offsets3d`.
            if is_3d:
                handle._offsets3d = juggle_axes(*newdata.T, 'z')
            else:
                handle.set_offsets(newdata)

        EE = duplicate_with_blanks_for_resampled(E, p.dims, key, has_w)

        # Roll data array
        ind = k if plot_u else kObs
        for Ens in EE:  # If E is duplicated, so must the others be.
            if 'E' in d:
                d.E.insert(ind, Ens)
            if True:
                d.x.insert(ind, xx[k, p.dims])
            if 'y' in d:
                d.y.insert(ind, yy[kObs, :] if show_y else nan * ones(M))
            if 'mu' in d:
                d.mu.insert(ind, mu[key][p.dims])

        # Update graph
        update_sctr(s.x, d.x[[-1]])
        update_tail(h.x, d.x)
        if 'y' in d:
            update_tail(h.y, d.y)
        if 'mu' in d:
            update_sctr(s.mu, d.mu[[-1]])
            update_tail(h.mu, d.mu)
        else:
            # NOTE(review): this branch assumes the ensemble is plotted
            # (i.e. `E` was given at init); otherwise `s.E` does not exist.
            update_sctr(s.E, d.E[-1])
            for n in range(len(E)):
                update_tail(h.E[n], d.E[:, n, :])
            update_alpha(key, stats, h.E, s.E)

        return
    return update
def plot_err_components(stats):
    """Plot components of the error.

    Parameters
    ----------
    stats: `dapper.stats.Stats`

    .. note::
      it was chosen to plot(ii, mean_in_time(abs(err_i))),
      and thus the corresponding spread measure is MAD.
      If one chose instead: plot(ii, std_in_time(err_i)),
      then the corresponding measure of spread would have been std.
      This choice was made in part because (wrt. subplot 2)
      the singular values (svals) correspond to rotated MADs,
      and because rms(umisf) seems too convoluted for interpretation.

    The spectral panel (subplot 2) shows a "not available" note when the
    spectral stats (``stats.umisf``, ``stats.svals``) are absent or all
    non-finite.
    """
    fig, (ax0, ax1, ax2) = freshfig(25, figsize=(6, 6), nrows=3)

    chrono = stats.HMM.t
    Nx = stats.xx.shape[1]

    err = np.mean(np.abs(stats.err.a), 0)
    sprd = np.mean(stats.mad.a, 0)
    # The spectral stats are optional: they may not have been computed at
    # all, in which case the attributes are missing.
    try:
        umsft = np.mean(np.abs(stats.umisf.a), 0)
        usprd = np.mean(stats.svals.a, 0)
    except AttributeError:
        umsft = usprd = None

    ax0.plot(arange(Nx), err, 'k', lw=2, label='Error')
    if Nx < 10**3:
        ax0.fill_between(arange(Nx), [0] * len(sprd), sprd, alpha=0.7,
                         label='Spread')
    else:
        ax0.plot(arange(Nx), sprd, alpha=0.7, label='Spread')
    # ax0.set_yscale('log')
    ax0.set_title('Element-wise error comparison')
    ax0.set_xlabel('Dimension index (i)')
    ax0.set_ylabel('Time-average (_a) magnitude')
    ax0.set_xlim(0, Nx - 1)
    ax0.get_xaxis().set_major_locator(MaxNLocator(integer=True))
    ax0.legend(loc='upper right')

    ax1.set_xlim(0, Nx - 1)
    ax1.set_xlabel('Principal component index')
    ax1.set_ylabel('Time-average (_a) magnitude')
    ax1.set_title('Spectral error comparison')
    has_been_computed = umsft is not None and np.any(np.isfinite(umsft))
    if has_been_computed:
        L = len(umsft)
        ax1.plot(arange(L), umsft, 'k', lw=2, label='Error')
        ax1.fill_between(arange(L), [0] * L, usprd, alpha=0.7,
                         label='Spread')
        ax1.set_yscale('log')
        ax1.get_xaxis().set_major_locator(MaxNLocator(integer=True))
    else:
        not_available_text(ax1)

    rmse = stats.err_rms.a[chrono.maskObs_BI]
    ax2.hist(rmse, bins=30, density=False)
    ax2.set_ylabel('Num. of occurence (_a)')
    ax2.set_xlabel('RMSE')
    ax2.set_title('Histogram of RMSE values')
    ax2.set_xlim(left=0)

    plot_pause(0.1)
    plt.tight_layout()
def init(fignum, stats, key0, plot_u, E, P, **kwargs):
    """Initialise a live 1D plot of the state values versus state index.

    Shows the truth, the observations (if ``p.obs_inds`` is not None), and
    either the mean with ``mean +/- conf_mult * std`` curves (when
    ``p.conf_mult`` is truthy) or the individual ensemble members.
    ``conf_mult`` defaults to 2 when no ensemble `E` is given.
    Observations are shown at forecast ('f') steps and hidden at
    update ('u') steps. Returns the ``update(key, E, P)`` callback.

    NOTE(review): `params_orig` is not defined here; presumably it comes from
    an enclosing scope -- confirm.
    """
    xx, yy, mu = stats.xx, stats.yy, stats.mu

    # Set parameters (kwargs takes precedence over params_orig)
    p = DotDict(
        **{kw: kwargs.get(kw, val) for kw, val in params_orig.items()})

    if not p.dims:
        M = xx.shape[-1]
        p.dims = arange(M)
    else:
        M = len(p.dims)

    # Make periodic wrapper
    ii, wrap = viz.setup_wrapping(M, p.periodicity)

    # Set up figure, axes
    fig, ax = freshfig(fignum, figsize=(8, 5))
    fig.suptitle("1d amplitude plot")

    # Nans
    nan1 = wrap(nan * ones(M))

    if E is None and p.conf_mult is None:
        p.conf_mult = 2

    # Init plots
    if p.conf_mult:
        lines_s = ax.plot(ii, nan1, "b-", lw=1,
                          label=(str(p.conf_mult) + r'$\sigma$ conf'))
        lines_s += ax.plot(ii, nan1, "b-", lw=1)
        line_mu, = ax.plot(ii, nan1, 'b-', lw=2, label='DA mean')
    else:
        nanE = nan * ones((stats.xp.N, M))
        lines_E = ax.plot(ii, wrap(nanE[0]), **p.ens_props, lw=1,
                          label='Ensemble')
        lines_E += ax.plot(ii, wrap(nanE[1:]).T, **p.ens_props, lw=1)
    # Truth, Obs
    (line_x, ) = ax.plot(ii, nan1, 'k-', lw=3, label='Truth')
    if p.obs_inds is not None:
        (line_y, ) = ax.plot(p.obs_inds, nan * p.obs_inds, 'g*', ms=5,
                             label='Obs')

    # Tune plot
    ax.set_ylim(*viz.xtrema(xx))
    ax.set_xlim(viz.stretch(ii[0], ii[-1], 1))
    # Xticks
    xt = ax.get_xticks()
    xt = xt[abs(xt % 1) < 0.01].astype(int)  # Keep only the integer ticks
    xt = xt[xt >= 0]
    xt = xt[xt < len(p.dims)]
    ax.set_xticks(xt)
    # Label ticks with the actual state indices being plotted.
    # NOTE(review): requires `p.dims` to support array indexing
    # (e.g. an ndarray), not a plain list -- confirm.
    ax.set_xticklabels(p.dims[xt])

    ax.set_xlabel('State index')
    ax.set_ylabel('Value')
    ax.legend(loc='upper right')

    text_t = ax.text(0.01, 0.01, format_time(None, None, None),
                     transform=ax.transAxes, family='monospace', ha='left')

    # Init visibility (must come after legend):
    if p.obs_inds is not None:
        line_y.set_visible(False)

    def update(key, E, P):
        k, kObs, faus = key

        if p.conf_mult:
            # Row 0: mean + conf_mult*std; row 1: mean - conf_mult*std.
            sigma = mu[key] + p.conf_mult * stats.std[key] * [[1], [-1]]
            lines_s[0].set_ydata(wrap(sigma[0, p.dims]))
            lines_s[1].set_ydata(wrap(sigma[1, p.dims]))
            line_mu.set_ydata(wrap(mu[key][p.dims]))
        else:
            for n, line in enumerate(lines_E):
                line.set_ydata(wrap(E[n, p.dims]))
            update_alpha(key, stats, lines_E)

        line_x.set_ydata(wrap(xx[k, p.dims]))

        text_t.set_text(format_time(k, kObs, stats.HMM.t.tt[k]))

        if 'f' in faus:
            if p.obs_inds is not None:
                line_y.set_ydata(yy[kObs])
                line_y.set_zorder(5)
                line_y.set_visible(True)

        if 'u' in faus:
            if p.obs_inds is not None:
                line_y.set_visible(False)

        return
    return update
def init(fignum, stats, key0, plot_u, E, P, **kwargs):
    """Initialise a 2x2 live plot of a square-reshaped field.

    Panels: mean (top-left), truth (top-right), std (bottom-left) and error
    (bottom-right), each shown via ``imshow(square(...))``, with colorbars
    for the top-right and bottom-right panels. Current observation
    locations are drawn on the truth panel. Returns the
    ``update(key, E, P)`` callback.

    NOTE(review): `cm`, `clims`, `square`, `ind2sub` and `obs_inds` are not
    defined here; presumably they come from an enclosing scope -- confirm.
    """
    GS = {'left': 0.125 - 0.04, 'right': 0.9 - 0.04}
    fig, axs = freshfig(fignum, figsize=(6, 6),
                        nrows=2, ncols=2, sharex=True, sharey=True,
                        gridspec_kw=GS)

    for ax in axs.flatten():
        ax.set_aspect('equal', viz.adjustable_box_or_forced())

    ((ax_11, ax_12), (ax_21, ax_22)) = axs

    ax_11.grid(color='w', linewidth=0.2)
    ax_12.grid(color='w', linewidth=0.2)
    ax_21.grid(color='k', linewidth=0.1)
    ax_22.grid(color='k', linewidth=0.1)

    # Upper colorbar -- position relative to ax_12
    bb = ax_12.get_position()
    dy = 0.1 * bb.height
    ax_13 = fig.add_axes(
        [bb.x1 + 0.03, bb.y0 + dy, 0.04, bb.height - 2 * dy])
    # Lower colorbar -- position relative to ax_22
    bb = ax_22.get_position()
    dy = 0.1 * bb.height
    ax_23 = fig.add_axes(
        [bb.x1 + 0.03, bb.y0 + dy, 0.04, bb.height - 2 * dy])

    # Extract data arrays
    xx, _, mu, std, err = stats.xx, stats.yy, stats.mu, stats.std, stats.err
    k = key0[0]
    tt = stats.HMM.t.tt

    # Plot
    # - origin='lower' might get overturned by set_ylim() below.
    im_11 = ax_11.imshow(square(mu[key0]), cmap=cm)
    im_12 = ax_12.imshow(square(xx[k]), cmap=cm)
    # hot is better, but needs +1 colorbar
    im_21 = ax_21.imshow(square(std[key0]), cmap=plt.cm.bwr)
    im_22 = ax_22.imshow(square(err[key0]), cmap=plt.cm.bwr)
    ims = (im_11, im_12, im_21, im_22)
    # Obs init -- a list where item 0 is the handle of something invisible.
    lh = list(ax_12.plot(0, 0)[0:1])

    sx = '$\\psi$'
    ax_11.set_title('mean ' + sx)
    ax_12.set_title('true ' + sx)
    ax_21.set_title('std. ' + sx)
    ax_22.set_title('err. ' + sx)

    # TODO 7
    # for ax in axs.flatten():
    # Crop boundaries (which should be 0, i.e. yield harsh q gradients):
    # lims = (1, nx-2)
    # step = (nx - 1)/8
    # ticks = arange(step,nx-1,step)
    # ax.set_xlim  (lims)
    # ax.set_ylim  (lims[::-1])
    # ax.set_xticks(ticks)
    # ax.set_yticks(ticks)

    for im, clim in zip(ims, clims):
        im.set_clim(clim)

    fig.colorbar(im_12, cax=ax_13)
    fig.colorbar(im_22, cax=ax_23)
    for ax in [ax_13, ax_23]:
        ax.yaxis.set_tick_params('major', length=2, width=0.5,
                                 direction='in', left=True, right=True)
        ax.set_axisbelow('line')  # make ticks appear over colorbar patch

    # Title
    title = "Streamfunction (" + sx + ")"
    fig.suptitle(title)
    # Time info
    text_t = ax_12.text(1, 1.1, format_time(None, None, None),
                        transform=ax_12.transAxes, family='monospace', ha='left')

    def update(key, E, P):
        k, kObs, faus = key
        t = tt[k]

        im_11.set_data(square(mu[key]))
        im_12.set_data(square(xx[k]))
        im_21.set_data(square(std[key]))
        im_22.set_data(square(err[key]))

        # Remove previous obs
        try:
            lh[0].remove()
        except ValueError:
            pass
        # Plot current obs.
        #  - plot() automatically adjusts to direction of y-axis in use.
        #  - ind2sub returns (iy,ix), while plot takes (ix,iy) => reverse.
        if kObs is not None and obs_inds is not None:
            lh[0] = ax_12.plot(*ind2sub(obs_inds(t))[::-1], 'k.', ms=1,
                               zorder=5)[0]

        text_t.set_text(format_time(k, kObs, t))

        return
    return update
def dashboard(self, saturation, production,
              pause=200, animate=True, title="", **kwargs):
    """Plot (and optionally animate) saturation fields and production curves.

    Panels: initial saturation (top-left), last saturation with the wells
    (top-right, animated if `animate`), production curves (bottom-left);
    the bottom-right panel is hidden.

    Parameters
    ----------
    saturation:
        Sequence of fields, one per time step; ``saturation[0]`` and
        ``saturation[-1]`` are drawn initially.
    production:
        Passed to `production1`; during animation ``1 - production.T`` is
        iterated alongside the returned line handles.
    pause:
        Delay between frames (ms), passed as ``interval`` to `FuncAnimation`.
    animate: bool
        If True, build and return a `FuncAnimation`; otherwise return None.
    title: str
        If non-empty, included in the figure super-title.
    **kwargs:
        Forwarded to `oilfield`.
    """
    fig, axs = fig_layout.freshfig(231, ncols=2, nrows=2, figsize=(12, 10))
    # NOTE(review): if `is_notebook_or_qt` is a function rather than a flag,
    # this condition is always true -- confirm.
    if is_notebook_or_qt:
        plt.close()  # https://stackoverflow.com/q/47138023

    tt = np.arange(len(saturation))

    axs[0, 0].set_title("Initial")
    axs[0, 0].cc = oilfield(self, axs[0, 0], saturation[0], **kwargs)
    axs[0, 0].set_ylabel(f"y ({COORD_TYPE})")

    axs[0, 1].set_title("Evolution")
    axs[0, 1].cc = oilfield(self, axs[0, 1], saturation[-1], **kwargs)
    well_scatter(self, axs[0, 1], self.injectors)
    well_scatter(self, axs[0, 1], self.producers, False,
                 color=[f"C{i}" for i in range(len(self.producers))])

    axs[1, 0].set_title("Production")
    prod_handles = production1(axs[1, 0], production)

    axs[1, 1].set_visible(False)

    # fig.tight_layout()
    fig_colorbar(fig, axs[0, 0].cc)

    if title:
        fig.suptitle(f"Oil saturation -- {title}")

    if animate:
        from matplotlib import animation

        def update_fig(iT):
            # Update field: remove the previous contours, then redraw.
            # NOTE(review): `Axes.collections` became a read-only view in
            # Matplotlib 3.5 (its mutating methods were deprecated), and
            # `ContourSet.collections` is deprecated since 3.8. If this
            # breaks, remove the artists via their own `.remove()` --
            # confirm against the Matplotlib version in use and the type
            # returned by `oilfield`.
            for c in axs[0, 1].cc.collections:
                try:
                    axs[0, 1].collections.remove(c)
                except ValueError:
                    pass  # occurs when re-running script
            axs[0, 1].cc = oilfield(self, axs[0, 1], saturation[iT], **kwargs)

            # Update production lines
            # NOTE(review): this shows points up to index iT-2 while the
            # field shows step iT -- confirm the intended offset.
            if iT >= 1:
                for h, p in zip(prod_handles, 1 - production.T):
                    h.set_data(tt[:iT - 1], p[:iT - 1])

        ani = animation.FuncAnimation(
            fig, update_fig, len(tt), blit=False, interval=pause)

        return ani
def __init__(self, fignum, stats, key0, plot_u, E, P, Tplot=None, **kwargs):
    """Set up stacked live panels of scalar diagnostics versus time.

    One panel per style table ('RMS' and 'Values'). Each table maps a stat
    name (looked up with `deep_getattr` on `stats`) to
    ``[transf, shape, plt_kwargs]``; stats that do not exist are skipped.
    For each plotted stat, rolling data/time arrays and a line handle are
    created and stored in ``self.d`` (one dict per panel).
    """
    # STYLE TABLES - Defines which/how diagnostics get plotted
    styles = {}

    def lin(a, b):
        # Affine transform x -> a + b*x, used to rescale a stat for display.
        return (lambda x: a + b * x)

    divN = 1 / getattr(stats.xp, 'N', 99)
    # Columns: transf, shape, plt kwargs
    styles['RMS'] = {
        'err.rms': [None, None, dict(c='k', label='Error')],
        'std.rms': [None, None, dict(c='b', label='Spread', alpha=0.6)],
    }
    styles['Values'] = {
        'skew': [None, None, dict(c='g', label=star + r'Skew/$\sigma^3$')],
        'kurt': [None, None, dict(c='r', label=star + r'Kurt$/\sigma^4{-}3$')],
        'trHK': [None, None, dict(c='k', label=star + 'HK')],
        'infl': [lin(-10, 10), 'step', dict(c='c', label='10(infl-1)')],
        'N_eff': [lin(0, divN), 'dirac', dict(c='y', label='N_eff/N', lw=3)],
        'iters': [lin(0, .1), 'dirac', dict(c='m', label='iters/10')],
        'resmpl': [None, 'dirac', dict(c='k', label='resampled?')],
    }

    nAx = len(styles)
    GS = {'left': 0.125, 'right': 0.76}
    fig, axs = freshfig(fignum, figsize=(5, 1 + nAx),
                        nrows=nAx, sharex=True, gridspec_kw=GS)

    axs[0].set_title("Diagnostics")
    for style, ax in zip(styles, axs):
        ax.set_ylabel(style)
    # After the loop, `ax` is the bottom panel.
    ax.set_xlabel('Time (t)')
    viz.adjust_position(ax, y0=0.03)

    self.T_lag, K_lag, a_lag = validate_lag(Tplot, stats.HMM.t)

    def init_ax(ax, style_table):
        lines = {}
        for name in style_table:

            # SKIP -- if stats[name] is not in existence
            # Note: The nan check/deletion comes after the first kObs.
            try:
                stat = deep_getattr(stats, name)
            except AttributeError:
                continue
            # try: val0 = stat[key0[0]]
            # except KeyError: continue
            # PS: recall (from series.py) that even if store_u is false, stat[k] is
            # still present if liveplots=True via the k_tmp functionality.

            # Unpack style
            ln = {}
            ln['transf'] = style_table[name][0] or (lambda x: x)
            ln['shape'] = style_table[name][1]
            ln['plt'] = style_table[name][2]

            # Create series
            if isinstance(stat, FAUSt):
                ln['plot_u'] = plot_u
                K_plot = comp_K_plot(K_lag, a_lag, ln['plot_u'])
            else:
                ln['plot_u'] = False
                K_plot = a_lag
            ln['data'] = RollingArray(K_plot)
            ln['tt'] = RollingArray(K_plot)

            # Plot (init)
            ln['handle'], = ax.plot(ln['tt'], ln['data'], **ln['plt'])

            # Plotting only nans yield ugly limits. Revert to defaults.
            ax.set_xlim(0, 1)
            ax.set_ylim(0, 1)

            lines[name] = ln
        return lines

    # Plot
    self.d = [init_ax(ax, styles[style]) for style, ax in zip(styles, axs)]

    # Horizontal line at y=0 (drawn on the bottom panel: the comprehension
    # above does not rebind `ax`).
    self.baseline0, = ax.plot(
        ax.get_xlim(), [0, 0], c=0.5 * ones(3), lw=0.7, label='_nolegend_')

    # Store
    self.axs = axs
    self.stats = stats
    self.init_incomplete = True
"""Random field generation. Uses: - Gaussian variogram. - Gaussian distributions. """ dists = dist_euclid(vectorize(*pts)) Cov = 1 - variogram_gauss(dists, r) C12 = sla.sqrtm(Cov).real fields = randn(N, len(dists)) @ C12.T return fields if __name__ == "__main__": np.random.seed(3000) plt.ion() N = 15 # ensemble size ## 1D xx = np.linspace(0, 1, 201) fields = gaussian_fields((xx, ), N) fig, ax = freshfig(1) ax.plot(xx, fields.T, lw=2) ## 2D grid = Grid2D(Lx=1, Ly=1, Nx=20, Ny=20) fields = gaussian_fields(grid.mesh(), N) fields = 0.5 + .2 * fields # fields = truncate_01(fields) plots.oilfields(grid, 2, fields)
def amplitude_animation(EE, dt=None, interval=1,
                        periodicity=None,
                        blit=True,
                        fignum=None, repeat=False):
    """Animation of line chart.

    Using an ensemble of
    the shape (time, ensemble size, state vector length).

    Parameters
    ----------
    EE: ndarray
        Ensemble array of the shape (K, N, Nx).
        K is the length of time, N is the ensemble size, and
        Nx is the length of state vector.
        A 2D array of shape (K, Nx) is also accepted (treated as N=1).
    dt: float, optional
        Time interval of each frame. If given, the text
        ``time = <dt*k>`` is shown in each frame. Default: None
    interval: float, optional
        Delay between frames in milliseconds. Default: 1
    periodicity: optional
        The mode of the wrapping, passed on to `setup_wrapping`.
        "+1": the first element is appended after the last.
        "+/-05": adding the midpoint of the first and last elements.
        Default: None (behaviour then determined by `setup_wrapping`)
    blit: bool, optional
        Controls whether blitting is used to optimize drawing. Default: True
    fignum: int, optional
        Figure index. Default: None
    repeat: bool, optional
        If True, repeat the animation. Default: False

    Returns
    -------
    FuncAnimation
        The animation object (keep a reference to it to keep it running).
    """
    fig, ax = freshfig(fignum)
    ax.set_xlabel('State index')
    ax.set_ylabel('Amplitude')
    ax.set_ylim(*stretch(*xtrema(EE), 1.1))

    # Promote a single trajectory (K, Nx) to a 1-member ensemble (K, 1, Nx)
    if EE.ndim == 2:
        EE = np.expand_dims(EE, 1)
    K, N, Nx = EE.shape

    ii, wrap = setup_wrapping(Nx, periodicity)

    lines = ax.plot(ii, wrap(EE[0]).T)
    ax.set_xlim(*xtrema(ii))

    if dt is not None:
        times = 'time = %.1f'
        # The time label is appended last, so it is returned with the lines
        # (needed for blitting).
        lines += [ax.text(0.05, 0.9, '', transform=ax.transAxes)]

    def anim(k):
        Ek = wrap(EE[k])
        for n in range(N):
            lines[n].set_ydata(Ek[n])
        if len(lines) > N:
            lines[-1].set_text(times % (dt * k))
        return lines

    return FuncAnimation(fig, anim, range(K),
                         interval=interval, blit=blit,
                         repeat=repeat)
def init(fignum, stats, key0, plot_u, E, P, **kwargs):
    """Initialise live plots of marginal time series, one panel per component.

    Up to 10 components (or ``p.dims``) are shown, each with the truth,
    observations of that component (if observed), the ensemble (if `E` is
    given) and the mean +/- one std (if `P` is given), in a sliding time
    window. Returns the ``update(key, E, P)`` callback.

    NOTE(review): `params_orig` is not defined here; presumably it comes from
    an enclosing scope -- confirm.
    """
    xx, yy, mu, std, chrono = \
        stats.xx, stats.yy, stats.mu, stats.std, stats.HMM.t

    # Set parameters (kwargs takes precedence over params_orig)
    p = DotDict(
        **{kw: kwargs.get(kw, val) for kw, val in params_orig.items()})

    # Lag settings:
    T_lag, K_lag, a_lag = validate_lag(p.Tplot, chrono)
    K_plot = comp_K_plot(K_lag, a_lag, plot_u)
    # Extend K_plot further for adding blanks in resampling (PartFilt):
    has_w = hasattr(stats, 'w')
    if has_w:
        K_plot += a_lag

    # Choose marginal dims to plot
    if not p.dims:
        Nx = min(10, xx.shape[-1])
        DimsX = linspace_int(xx.shape[-1], Nx)
    else:
        Nx = len(p.dims)
        DimsX = p.dims
    # Pre-process obs dimensions
    # NOTE(review): iterating `p.obs_inds` fails if it is None -- confirm it
    # is always set for this plotter.
    # Rm inds of obs if not in DimsX
    iiY = [i for i, m in enumerate(p.obs_inds) if m in DimsX]
    # Rm obs_inds if not in DimsX
    DimsY = [m for i, m in enumerate(p.obs_inds) if m in DimsX]
    # Get dim (within y) of each x
    DimsY = [DimsY.index(m) if m in DimsY else None for m in DimsX]
    Ny = len(iiY)

    # Set up figure, axes
    fig, axs = freshfig(fignum, figsize=(5, 7),
                        nrows=Nx, sharex=True)
    if Nx == 1:
        axs = [axs]

    # Tune plots
    axs[0].set_title("Marginal time series")
    for ix, (m, ax) in enumerate(zip(DimsX, axs)):
        # ax.set_ylim(*viz.stretch(*viz.xtrema(xx[:, m]), 1/p.zoomy))
        if not p.labels:
            ax.set_ylabel("$x_{%d}$" % m)
        else:
            ax.set_ylabel(p.labels[ix])
    axs[-1].set_xlabel('Time (t)')

    plot_pause(0.05)
    plt.tight_layout()

    # Allocate
    d = DotDict()  # data arrays
    h = DotDict()  # plot handles
    # Why "if True" ? Just to indent the rest of the line...
    if True:
        d.t = RollingArray((K_plot, ))
    if True:
        d.x = RollingArray((K_plot, Nx))
        h.x = []
    if True:
        d.y = RollingArray((K_plot, Ny))
        h.y = []
    if E is not None:
        d.E = RollingArray((K_plot, len(E), Nx))
        h.E = []
    if P is not None:
        d.mu = RollingArray((K_plot, Nx))
        h.mu = []
    if P is not None:
        d.s = RollingArray((K_plot, 2, Nx))
        h.s = []

    # Plot (invisible coz everything here is nan, for the moment).
    for ix, (_m, iy, ax) in enumerate(zip(DimsX, DimsY, axs)):
        if True:
            h.x += ax.plot(d.t, d.x[:, ix], 'k')
        if iy != None:
            h.y += ax.plot(d.t, d.y[:, iy], 'g*', ms=10)
        if 'E' in d:
            h.E += [ax.plot(d.t, d.E[:, :, ix], **p.ens_props)]
        if 'mu' in d:
            h.mu += ax.plot(d.t, d.mu[:, ix], 'b')
        if 's' in d:
            h.s += [ax.plot(d.t, d.s[:, :, ix], 'b--', lw=1)]

    def update(key, E, P):
        k, kObs, faus = key

        EE = duplicate_with_blanks_for_resampled(E, DimsX, key, has_w)

        # Roll data array
        ind = k if plot_u else kObs
        for Ens in EE:  # If E is duplicated, so must the others be.

            if 'E' in d:
                d.E.insert(ind, Ens)
            if 'mu' in d:
                d.mu.insert(ind, mu[key][DimsX])
            if 's' in d:
                # Rows: mean + std, mean - std
                d.s.insert(ind, mu[key][DimsX] + [[1], [-1]] * std[key][DimsX])
            if True:
                d.t.insert(ind, chrono.tt[k])
            if True:
                d.y.insert(ind, yy[kObs, iiY] if kObs is not None else nan * ones(Ny))
            if True:
                d.x.insert(ind, xx[k, DimsX])

        # Update graphs
        for ix, (_m, iy, ax) in enumerate(zip(DimsX, DimsY, axs)):
            sliding_xlim(ax, d.t, T_lag, True)

            if True:
                h.x[ix].set_data(d.t, d.x[:, ix])
            # NOTE(review): `h.y` was filled in DimsX order while `iy`
            # follows the order of `p.obs_inds`; these agree only if both
            # orderings are consistent -- confirm.
            if iy != None:
                h.y[iy].set_data(d.t, d.y[:, iy])
            if 'mu' in d:
                h.mu[ix].set_data(d.t, d.mu[:, ix])
            if 's' in d:
                [h.s[ix][b].set_data(d.t, d.s[:, b, ix]) for b in [0, 1]]
            if 'E' in d:
                [h.E[ix][n].set_data(d.t, d.E[:, n, ix]) for n in range(len(E))]
            if 'E' in d:
                update_alpha(key, stats, h.E[ix])

            # TODO 3: fixup. This might be slow?
            # In any case, it is very far from tested.
            # Also, relim'iting all of the time is distracting.
            # Use d_ylim?
            # NOTE(review): if neither 'E' nor 'mu' is in d, `lims` is
            # undefined here.
            if 'E' in d:
                lims = d.E
            elif 'mu' in d:
                lims = d.mu
            lims = np.array(viz.xtrema(lims[..., ix]))
            if lims[0] == lims[1]:
                lims += [-.5, +.5]
            ax.set_ylim(*viz.stretch(*lims, 1 / p.zoomy))

        return
    return update
def __init__(self, fignum, stats, key0, plot_u, E, P, **kwargs):
    """Set up the state-correlation-matrix and auto-correlation panels.

    Top panel: an ``imshow`` of an (initially identity) Nx-by-Nx matrix with
    a symmetric-log-shaped diverging colormap on [-1, 1]. Bottom panel: two
    lines ('Correlation' and 'Abs. corr.') versus distance in state indices.

    The plotter is deactivated (``self.is_active = False``) when neither an
    ensemble nor a usable covariance is available (replays), or when the
    state is too large (Nx > 1003) to display.
    """
    GS = {'height_ratios': [4, 1], 'hspace': 0.09, 'top': 0.95}
    fig, (ax, ax2) = freshfig(fignum, figsize=(5, 6), nrows=2, gridspec_kw=GS)

    if E is None and np.isnan(
            P.diag if isinstance(P, CovMat) else P).all():
        not_available_text(ax, (
            'Not available in replays'
            '\ncoz full Ens/Cov not stored.'))
        self.is_active = False
        return

    Nx = len(stats.mu[key0])
    if Nx <= 1003:
        C = np.eye(Nx)
        # Mask half: True on and below the diagonal.
        # NB: use the builtin `bool`; the `np.bool` alias was removed in
        # NumPy 1.24 and raises AttributeError there.
        mask = np.zeros_like(C, dtype=bool)
        mask[np.tril_indices_from(mask)] = True

        # Make colormap. Log-transform cmap,
        # but not internally in matplotlib,
        # so as to avoid transforming the colorbar too.
        cmap = plt.get_cmap('RdBu_r')
        trfm = mpl.colors.SymLogNorm(linthresh=0.2, linscale=0.2,
                                     base=np.e, vmin=-1, vmax=1)
        cmap = cmap(trfm(np.linspace(-0.6, 0.6, cmap.N)))
        cmap = mpl.colors.ListedColormap(cmap)
        #
        VM = 1.0  # abs(np.percentile(C,[1,99])).max()
        im = ax.imshow(C, cmap=cmap, vmin=-VM, vmax=VM)
        # Colorbar
        _ = ax.figure.colorbar(im, ax=ax, shrink=0.8)
        # Tune plot
        # NOTE(review): plt.box acts on pyplot's *current* axes, which may
        # be ax2 rather than ax -- confirm which frame is meant to be hidden.
        plt.box(False)
        ax.set_facecolor('w')
        ax.grid(False)
        ax.set_title("State correlation matrix:", y=1.07)
        ax.xaxis.tick_top()

        # ax2 = inset_axes(ax,width="30%",height="60%",loc=3)
        line_AC, = ax2.plot(arange(Nx), ones(Nx), label='Correlation')
        line_AA, = ax2.plot(arange(Nx), ones(Nx), label='Abs. corr.')
        _ = ax2.hlines(0, 0, Nx - 1, 'k', 'dotted', lw=1)
        # Align ax2 with ax
        bb_AC = ax2.get_position()
        bb_C = ax.get_position()
        ax2.set_position([bb_C.x0, bb_AC.y0, bb_C.width, bb_AC.height])
        # Tune plot
        ax2.set_title("Auto-correlation:")
        ax2.set_ylabel("Mean value")
        ax2.set_xlabel("Distance (in state indices)")
        ax2.set_xticklabels([])
        ax2.set_yticks([0, 1] + list(ax2.get_yticks()[[0, -1]]))
        ax2.set_ylim(top=1)
        ax2.legend(frameon=True, facecolor='w',
                   bbox_to_anchor=(1, 1), loc='upper left',
                   borderaxespad=0.02)

        self.ax = ax
        self.ax2 = ax2
        self.im = im
        self.line_AC = line_AC
        self.line_AA = line_AA
        self.mask = mask
        if hasattr(stats, 'w'):
            self.w = stats.w
    else:
        not_available_text(ax)
        # Nothing was set up for updating, so deactivate the plotter, as
        # the early-return branch above does.
        self.is_active = False
########################
# Reference trajectory
########################
# NB: Arbitrary, coz models are autonomous. But don't use nan coz QG doesn't like it.
t0 = 0.0
K = int(round(T/dt))       # Num of time steps.
tt = np.linspace(dt, T, K)  # Time seq.
x = with_recursion(step, prog="BurnIn")(x0, int(10/dt), t0, dt)[-1]
xx = with_recursion(step, prog="Reference")(x, K, t0, dt)


########################
# ACF
########################
# NB: Won't work with QG (too big, and BCs==0).
fig, ax = freshfig(4)
if "ii" not in locals():
    ii = np.arange(min(100, Nx))
if "nlags" not in locals():
    nlags = min(100, K-1)
# The ACF has one value per lag 0, ..., nlags-1 (presumably lag 0 first), so
# lag j sits at time j*dt. NB: `tt` starts at dt rather than 0, so using it
# as the lag axis would shift every lag by one time step.
lag_times = dt * np.arange(nlags)
ax.plot(lag_times, np.nanmean(
    series.auto_cov(xx[:nlags, ii], nlags=nlags-1, corr=1),
    axis=1))
ax.set_xlabel('Time (t)')
ax.set_ylabel('Auto-corr')
viz.plot_pause(0.1)


########################
# "Linearized" forecasting
########################
# NOTE(review): `well_grid` comes from earlier in this script (outside this
# excerpt); it is flattened here into rows of 3 values, presumably laid out
# like the `inj` entry below (x, y, and a third quantity) -- confirm.
well_grid = well_grid.T.reshape((-1, 3))
# One injector, plus the grid of producers.
model.init_Q(
    inj =[[0.50, 0.50, 1.00]],
    prod=well_grid,
);

# # Random well configuration
# model.init_Q(
#     inj =rand(1, 3),
#     prod=rand(8, 3)
# );

# -

# #### Plot true field

fig, ax = freshfig(110)
# cs = plots.field(model, ax, perm.Truth)
cs = plots.field(model, ax, f_perm(perm.Truth), locator=ticker.LogLocator())
plots.well_scatter(model, ax, model.producers, inj=False)
plots.well_scatter(model, ax, model.injectors, inj=True)
fig.colorbar(cs)
fig.suptitle("True field");
plt.pause(.1)


# #### Define obs operator
# There is no well model. The data consists purely of the water cut at the location of the wells.

# Index (via `model.xy2ind`) of each producer's (x, y) location.
obs_inds = [model.xy2ind(x, y) for (x, y, _) in model.producers]
def obs(water_sat):
    """Observation operator: sample `water_sat` at the producer locations.

    Returns a list containing ``water_sat[i]`` for each ``i`` in `obs_inds`.
    """
    return [water_sat[i] for i in obs_inds]