def plot_hovmoller(xx, tseq=None):
    """Plot Hovmöller diagram.

    Parameters
    ----------
    xx: ndarray
        Plotted array, indexed as ``xx[k, i]``: rows are time indices,
        columns are state (dimension) indices.
    tseq: `dapper.tools.chronos.Chronology`, optional
        Time sequence providing ``tt``, ``kk`` and ``Tplot``.
        If given, only the times ``t <= 2*Tplot`` are plotted, against ``t``.
        Default: None, in which case the first ``K`` rows of ``xx`` are
        plotted against their index, with ``K`` from
        ``estimate_good_plot_length``.
    """
    fig, ax = place.freshfig("Hovmoller", figsize=(4, 3.5))

    if tseq is not None:
        mask = tseq.tt <= tseq.Tplot*2
        kk = tseq.kk[mask]
        tt = tseq.tt[mask]
        ax.set_ylabel('Time (t)')
    else:
        K = estimate_good_plot_length(xx, mult=20)
        kk = arange(K)
        tt = kk
        ax.set_ylabel('Time indices (k)')

    # Draw explicitly on the axes created above, rather than on whatever
    # pyplot currently considers the "current" axes.
    cs = ax.contourf(arange(xx.shape[1]), tt, xx[kk], 25)
    fig.colorbar(cs, ax=ax)
    ax.get_xaxis().set_major_locator(MaxNLocator(integer=True))
    ax.set_title("Hovmoller diagram (of 'Truth')")
    ax.set_xlabel('Dimension index (i)')

    plt.pause(0.1)
    fig.tight_layout()
def amplitude_animation(EE, dt=None, interval=1, periodicity=None,
                        blit=True, fignum=None, repeat=False):
    """Animation of line chart.

    Using an ensemble of the shape (time, ensemble size, state vector length).

    Parameters
    ----------
    EE: ndarray
        Ensemble array of the shape (K, N, Nx).
        K is the length of time, N is the ensemble size, and
        Nx is the length of state vector.
        A 2-D array of shape (K, Nx) is treated as an ensemble of size 1.
    dt: float, optional
        Time interval of each frame. If given, a "time = ..." text
        (``dt*k`` for frame ``k``) is shown. Default: None (no time text).
    interval: float, optional
        Delay between frames in milliseconds. Default: 1.
    periodicity: str, optional
        The mode of the wrapping, passed to `setup_wrapping`.
        "+1": the first element is appended after the last.
        "+/-05": adding the midpoint of the first and last elements.
        Default: None.
    blit: bool, optional
        Controls whether blitting is used to optimize drawing. Default: True
    fignum: int or str, optional
        Figure identifier. Default: None (uses "Amplitude animation").
    repeat: bool, optional
        If True, repeat the animation. Default: False

    Returns
    -------
    FuncAnimation
        Keep a reference to it, otherwise the animation may be
        garbage-collected.
    """
    fig, ax = place.freshfig(fignum or "Amplitude animation")
    ax.set_xlabel('State index')
    ax.set_ylabel('Amplitude')
    ax.set_ylim(*stretch(*xtrema(EE), 1.1))

    if EE.ndim == 2:
        EE = np.expand_dims(EE, 1)
    K, N, Nx = EE.shape

    ii, wrap = setup_wrapping(Nx, periodicity)

    # One line per ensemble member.
    lines = ax.plot(ii, wrap(EE[0]).T)
    ax.set_xlim(*xtrema(ii))

    if dt is not None:
        times = 'time = %.1f'
        # The text artist is appended after the N member lines so that it is
        # also returned by `anim` (needed for blitting).
        lines += [ax.text(0.05, 0.9, '', transform=ax.transAxes)]

    def anim(k):
        Ek = wrap(EE[k])
        for n in range(N):
            lines[n].set_ydata(Ek[n])
        if len(lines) > N:
            lines[-1].set_text(times % (dt*k))
        return lines

    return FuncAnimation(fig, anim, range(K),
                         interval=interval, blit=blit, repeat=repeat)
def spectrum(ydata, title="", figsize=(1.6, .7), semilogy=False, **kwargs):
    """Plotter specialized for spectra.

    Plots `ydata` against its index with a logarithmic y-axis. The x-axis
    is logarithmic too, unless `semilogy` is True. The figure is named
    "Spectrum" joined (via `dash`) with `title`. Extra `kwargs` are accepted
    but not used.

    Returns the line handles produced by the plotting call.
    """
    fig, ax = place.freshfig(dash("Spectrum", title), figsize=figsize, rel=True)

    plot_fn = ax.semilogy if semilogy else ax.loglog
    handles = plot_fn(ydata)

    ax.grid(True, "both", axis="both")
    ax.set(xlabel="eigenvalue index", ylabel="variance")
    fig.tight_layout()
    return handles
def fields(ZZ, style=None, title="", figsize=(1.7, 1),
           label_color="k", colorbar=True, **kwargs):
    """Do `field(Z)` for each `Z` in `ZZ`.

    Each field is drawn on its own axes of an `AxesGrid` sharing a single
    colorbar. At most 12 axes are created; if `ZZ` has more entries, only
    the first ones are plotted and the suptitle says so.

    Parameters
    ----------
    ZZ: dict or list-like
        Fields to plot. For a dict, the keys are used as axes labels;
        otherwise the enumeration index is used.
    style: optional
        Forwarded to `field`, and used as fallback by
        `pop_style_with_fallback` for "title" and "ticks".
    title: str
        Joined (via `dash`) after "Fields" and any style title to name the figure.
    figsize: tuple
        Passed to `place.freshfig` (with ``rel=True``).
    label_color: str
        Colour of the per-axes labels.
    colorbar: bool
        Whether to draw the shared colorbar (from the first field's handle).
    **kwargs:
        Forwarded to `field` for every field.

    Returns
    -------
    fig, axs, hh
        The figure, the `AxesGrid`, and the list of values returned by `field`.

    NOTE(review): with an empty `ZZ` and ``colorbar=True``, ``hh[0]`` raises
    IndexError.
    """
    # Look up option `k` -- presumably popped from kwargs, else from `style`
    # (see pop_style_with_fallback; TODO confirm it pops).
    kw = lambda k: pop_style_with_fallback(k, style, kwargs)

    # Create figure using freshfig
    title = dash("Fields", kw("title"), title)
    fig, axs = place.freshfig(title, figsize=figsize, rel=True)
    # Store suptitle (exists if mpl is inline) coz gets cleared below
    try:
        suptitle = fig._suptitle.get_text()
    except AttributeError:
        suptitle = ""
    # Create axes using AxesGrid
    fig.clear()
    from mpl_toolkits.axes_grid1 import AxesGrid
    axs = AxesGrid(fig, 111,
                   nrows_ncols=nRowCol(min(12, len(ZZ))).values(),
                   cbar_mode='single', cbar_location='right',
                   share_all=True,
                   axes_pad=0.2,
                   cbar_pad=0.1)
    # Turn off redundant axes
    for ax in axs[len(ZZ):]:
        ax.set_visible(False)

    # Convert (potential) list-like ZZ into dict
    if not isinstance(ZZ, dict):
        ZZ = {i: Z for (i, Z) in enumerate(ZZ)}

    hh = []
    # zip() stops at the shorter of the two, i.e. at most len(axs) fields.
    for ax, label in zip(axs, ZZ):

        label_ax(ax, label, c=label_color)

        hh.append(field(ax, ZZ[label], style, **kwargs))

    # Suptitle
    if len(ZZ) > len(axs):
        suptitle = dash(suptitle, f"First {len(axs)} instances")
    # Re-set suptitle (since it got cleared above)
    if suptitle:
        fig.suptitle(suptitle)

    # NB: kw("ticks") is only evaluated here, after the `field` calls above,
    # so a "ticks" entry in kwargs is still forwarded to `field`.
    if colorbar:
        fig.colorbar(hh[0], cax=axs.cbar_axes[0], ticks=kw("ticks"))

    return fig, axs, hh
def __init__(self, fignum, stats, key0, plot_u, E, P, **kwargs):
    """Create the singular-value spectrum figure.

    Keeps references to ``stats.umisf`` (misfit) and ``stats.svals``
    (spread). If either attribute is missing, the plotter is marked
    inactive and the axes show a notice instead.
    """
    self.init_incomplete = True
    self.plot_u = plot_u

    _, self.ax = place.freshfig(fignum, figsize=(6, 3))
    self.ax.set_xlabel('Sing. value index')
    self.ax.set_yscale('log')

    try:
        self.msft = stats.umisf
        self.sprd = stats.svals
    except AttributeError:
        self.is_active = False
        not_available_text(self.ax, "Spectral stats not being computed")
def __init__(self, fignum, stats, key0, plot_u, E, P, **kwargs):
    """Set up the (log-x) histogram of particle weights.

    If `stats` has no weights (no ``stats.w``), the plotter is marked
    inactive and no figure is created.
    """
    if not hasattr(stats, 'w'):
        self.is_active = False
        return
    fig, ax = place.freshfig(fignum, figsize=(7, 3),
                             gridspec_kw={'bottom': .15})

    ax.set_xscale('log')
    ax.set_xlabel('Weight')
    ax.set_ylabel('Count')
    self.stats = stats
    self.ax = ax
    self.hist = []
    # 31 log-spaced bin edges from 1e-10 to 1 (i.e. 30 bins).
    self.bins = np.exp(np.linspace(np.log(1e-10), np.log(1), 31))
def plot_rank_histogram(stats):
    """Plot rank histogram of ensemble.

    Parameters
    ----------
    stats: `dapper.stats.Stats`
        Uses ``stats.rh`` (ranks of the truth within the ensemble), if
        present. If ``stats.w`` (particle weights) is present, the
        histogram is weighted by the inverse weights.

    If the ranks have not been computed, a "not available" note is shown.
    """
    tseq = stats.HMM.tseq

    # NOTE(review): unfilled rows of `rh.a` presumably hold NaN cast to int;
    # if the last row consists only of that value, ranks were never computed.
    has_been_computed = \
        hasattr(stats, 'rh') and \
        not all(stats.rh.a[-1] == array(np.nan).astype(int))

    fig, ax = place.freshfig("Rank histogram", figsize=(6, 3))
    ax.set_title('(Mean of marginal) rank histogram (_a)')
    ax.set_ylabel('Freq. of occurrence\n (of truth in interval n)')
    ax.set_xlabel('ensemble member index (n)')

    if has_been_computed:
        ranks = stats.rh.a[tseq.masko]  # shape (K, Nx)
        N = stats.xp.N
        if not hasattr(stats, 'w'):
            # Ensemble rank histogram
            integer_hist(ranks.ravel(), N)
        else:
            # Experimental: weighted rank histogram.
            # Weight ranks by inverse of particle weight. Why? Because, with
            # correct importance weights, the "expected value" histogram is
            # then flat.
            # Potential improvement: interpolate weights between particles.
            w = stats.w.a[tseq.masko]
            K = len(w)
            w = np.hstack([w, np.ones((K, 1))/N])  # define weights for rank N+1
            # Weight of the member at each rank: element [k, i] is
            # w[k, ranks[k, i]]; ravel order matches ranks.ravel().
            w = w[arange(K)[:, None], ranks].ravel()
            # Artificial cap. Reduces variance, but introduces bias.
            w = np.maximum(w, 1/N/100)
            w = 1/w
            integer_hist(ranks.ravel(), N, weights=w)
    else:
        not_available_text(ax)

    plt.pause(0.1)
    plt.tight_layout()
def fig_ax(num):
    """Create fig, axs. Deserving of particular attention, so factored out.

    Notes on figure creation:
    - For *interactive* mpl backends this should only run once. Running it
      from inside f (even with checks for single execution) gave a blank
      figure, hence it is run outside of f().
    - Capturing output with `ipywidgets.Output`, however, requires running
      it inside f; in that case it does seem to work (no blank figures).

    NOTE(review): `kwargs` is not a parameter of this function; it must be
    provided by an enclosing scope (presumably a function taking **kwargs).
    """
    if is_inline():
        # Remove the previous (static) image; necessary when using
        # `ipywidgets.Output`. `wait=True` avoids flickering
        # (ref ipywidgets/issues/1582).
        clear_output(wait=True)
    elif plt.fignum_exists(num):
        # Interactive backend. Closing an existing figure first
        # - avoids (silently) duplicated figures on the first run
        #   (no longer seems to be an issue, but the check doesn't hurt);
        # - fixes the figure not displaying when a cell is re-run (probably
        #   related to being in an ipython widget; changing `num` also works).
        plt.close(num)

    fig, axs = place.freshfig(num, **kwargs)
    return fig, axs
def __init__(self, fignum, stats, key0, plot_u, E, P, **kwargs):
    """Set up the state-correlation and auto-correlation figure.

    Top axes: image of the state correlation matrix (initialised to the
    identity). Bottom axes: mean (and mean absolute) auto-correlation
    versus distance in state indices.

    If neither an ensemble `E` nor a non-NaN covariance `P` is available
    (e.g. in replays), a notice is shown and the plotter is deactivated.
    For more than 1003 state variables, only a "not available" notice is shown.
    """
    GS = {'height_ratios': [4, 1], 'hspace': 0.09, 'top': 0.95}
    fig, (ax, ax2) = place.freshfig(fignum, figsize=(5, 6), nrows=2,
                                    gridspec_kw=GS)

    if E is None and np.isnan(
            P.diag if isinstance(P, CovMat) else P).all():
        not_available_text(ax, (
            'Not available in replays'
            '\ncoz full Ens/Cov not stored.'))
        self.is_active = False
        return

    Nx = len(stats.mu[key0])
    if Nx <= 1003:
        C = np.eye(Nx)
        # Mask half.
        # NB: use the builtin `bool`; the `np.bool` alias was removed in
        # NumPy 1.24 (AttributeError).
        mask = np.zeros_like(C, dtype=bool)
        mask[np.tril_indices_from(mask)] = True
        # Make colormap. Log-transform cmap,
        # but not internally in matplotlib,
        # so as to avoid transforming the colorbar too.
        cmap = plt.get_cmap('RdBu_r')
        trfm = mpl.colors.SymLogNorm(linthresh=0.2, linscale=0.2,
                                     base=np.e, vmin=-1, vmax=1)
        cmap = cmap(trfm(np.linspace(-0.6, 0.6, cmap.N)))
        cmap = mpl.colors.ListedColormap(cmap)
        #
        VM = 1.0  # abs(np.percentile(C,[1,99])).max()
        im = ax.imshow(C, cmap=cmap, vmin=-VM, vmax=VM)
        # Colorbar
        _ = ax.figure.colorbar(im, ax=ax, shrink=0.8)
        # Tune plot
        # NOTE(review): plt.box acts on pyplot's *current* axes, which need
        # not be `ax` (possibly `ax2`, the last created); confirm intent.
        plt.box(False)
        ax.set_facecolor('w')
        ax.grid(False)
        ax.set_title("State correlation matrix:", y=1.07)
        ax.xaxis.tick_top()

        # ax2 = inset_axes(ax,width="30%",height="60%",loc=3)
        line_AC, = ax2.plot(arange(Nx), ones(Nx), label='Correlation')
        line_AA, = ax2.plot(arange(Nx), ones(Nx), label='Abs. corr.')
        _ = ax2.hlines(0, 0, Nx - 1, 'k', 'dotted', lw=1)
        # Align ax2 with ax
        bb_AC = ax2.get_position()
        bb_C = ax.get_position()
        ax2.set_position([bb_C.x0, bb_AC.y0, bb_C.width, bb_AC.height])
        # Tune plot
        ax2.set_title("Auto-correlation:")
        ax2.set_ylabel("Mean value")
        ax2.set_xlabel("Distance (in state indices)")
        ax2.set_xticklabels([])
        ax2.set_yticks([0, 1] + list(ax2.get_yticks()[[0, -1]]))
        ax2.set_ylim(top=1)
        ax2.legend(frameon=True, facecolor='w',
                   bbox_to_anchor=(1, 1), loc='upper left',
                   borderaxespad=0.02)

        self.ax = ax
        self.ax2 = ax2
        self.im = im
        self.line_AC = line_AC
        self.line_AA = line_AA
        self.mask = mask
        if hasattr(stats, 'w'):
            self.w = stats.w
    else:
        # NOTE(review): `self.is_active` is not set to False here, although
        # none of the handles above are stored; confirm the update method
        # copes with that.
        not_available_text(ax)
def __init__(self, fignum, stats, key0, plot_u, E, P, Tplot=None, **kwargs):
    """Set up the time-series diagnostics figure.

    One axes per style table ("RMS", "Values"). Each statistic in a table
    that exists on `stats` gets a line backed by rolling arrays of
    data and times. Statistics that do not exist are skipped.
    `Tplot` is passed to `validate_lag` to set the plotted time lag.
    """

    # STYLE TABLES - Defines which/how diagnostics get plotted
    styles = {}

    # Affine transform x -> a + b*x, used to rescale a stat into view.
    def lin(a, b):
        return (lambda x: a + b * x)

    # 1/N (ensemble size); 99 is a fallback when `stats.xp` has no N.
    divN = 1 / getattr(stats.xp, 'N', 99)
    # Columns: transf, shape, plt kwargs
    styles['RMS'] = {
        'err.rms': [None, None, dict(c='k', label='Error')],
        'spread.rms': [None, None, dict(c='b', label='Spread', alpha=0.6)],
    }
    styles['Values'] = {
        'skew': [None, None, dict(c='g', label=star + r'Skew/$\sigma^3$')],
        'kurt': [None, None, dict(c='r', label=star + r'Kurt$/\sigma^4{-}3$')],
        'trHK': [None, None, dict(c='k', label=star + 'HK')],
        'infl': [lin(-10, 10), 'step', dict(c='c', label='10(infl-1)')],
        'N_eff': [lin(0, divN), 'dirac', dict(c='y', label='N_eff/N', lw=3)],
        'iters': [lin(0, .1), 'dirac', dict(c='m', label='iters/10')],
        'resmpl': [None, 'dirac', dict(c='k', label='resampled?')],
    }

    nAx = len(styles)
    GS = {'left': 0.125, 'right': 0.76}
    fig, axs = place.freshfig(fignum, figsize=(5, 1 + nAx),
                              nrows=nAx, sharex=True, gridspec_kw=GS)

    axs[0].set_title("Diagnostics")
    for style, ax in zip(styles, axs):
        ax.set_ylabel(style)
    # After the loop, `ax` is the bottom axes (x is shared).
    ax.set_xlabel('Time (t)')
    place_ax.adjust_position(ax, y0=0.03)

    self.T_lag, K_lag, a_lag = validate_lag(Tplot, stats.HMM.tseq)

    def init_ax(ax, style_table):
        """Create (all-NaN) lines on `ax` for each existing stat in `style_table`."""
        lines = {}
        for name in style_table:

            # SKIP -- if stats[name] is not in existence
            # Note: The nan check/deletion comes after the first ko.
            try:
                stat = deep_getattr(stats, name)
            except AttributeError:
                continue
            # try: val0 = stat[key0[0]]
            # except KeyError: continue
            # PS: recall (from series.py) that even if store_u is false, stat[k] is
            # still present if liveplots=True via the k_tmp functionality.

            # Unpack style
            ln = {}
            ln['transf'] = style_table[name][0] or (lambda x: x)
            ln['shape'] = style_table[name][1]
            ln['plt'] = style_table[name][2]

            # Create series
            # FAUSt stats (presumably stored at forecast/analysis/universal/
            # smoothing times) may be plotted at every time step (plot_u);
            # others only get `a_lag` points.
            if isinstance(stat, FAUSt):
                ln['plot_u'] = plot_u
                K_plot = comp_K_plot(K_lag, a_lag, ln['plot_u'])
            else:
                ln['plot_u'] = False
                K_plot = a_lag
            ln['data'] = RollingArray(K_plot)
            ln['tt'] = RollingArray(K_plot)

            # Plot (init)
            ln['handle'], = ax.plot(ln['tt'], ln['data'], **ln['plt'])

            # Plotting only nans yield ugly limits. Revert to defaults.
            ax.set_xlim(0, 1)
            ax.set_ylim(0, 1)

            lines[name] = ln
        return lines

    # Plot
    self.d = [init_ax(ax, styles[style]) for style, ax in zip(styles, axs)]

    # Horizontal line at y=0
    # NB: the comprehension above has its own scope, so `ax` here is still the
    # last (bottom, "Values") axes from the earlier for-loop.
    self.baseline0, = ax.plot(
        ax.get_xlim(), [0, 0], c=0.5 * ones(3), lw=0.7, label='_nolegend_')

    # Store
    self.axs = axs
    self.stats = stats
    self.init_incomplete = True
def init(fignum, stats, key0, plot_u, E, P, **kwargs):
    """Initialise the 2x2 streamfunction figure; return its update callback.

    Panels: DA mean (top-left), truth (top-right), spread (bottom-left) and
    error (bottom-right), each reshaped by `square` and shown with `imshow`.
    Colour limits come from `clims`, one per panel. There are two colorbars:
    one for the truth panel and one for the error panel.

    NOTE(review): `cm`, `clims`, `square`, `obs_inds`, `ind2sub` and
    `format_time` are not defined here; they come from an enclosing/module
    scope (not visible).

    Returns
    -------
    update: callable(key, E, P)
        Refreshes the four images, the observation markers and the time text
        for ``key = (k, ko, faus)``.
    """
    GS = {'left': 0.125 - 0.04, 'right': 0.9 - 0.04}
    fig, axs = place.freshfig(fignum, figsize=(6, 6),
                              nrows=2, ncols=2, sharex=True, sharey=True,
                              gridspec_kw=GS)

    for ax in axs.flatten():
        ax.set_aspect('equal', 'box')

    ((ax_11, ax_12), (ax_21, ax_22)) = axs

    ax_11.grid(color='w', linewidth=0.2)
    ax_12.grid(color='w', linewidth=0.2)
    ax_21.grid(color='k', linewidth=0.1)
    ax_22.grid(color='k', linewidth=0.1)

    # Upper colorbar -- position relative to ax_12
    bb = ax_12.get_position()
    dy = 0.1 * bb.height
    ax_13 = fig.add_axes(
        [bb.x1 + 0.03, bb.y0 + dy, 0.04, bb.height - 2 * dy])
    # Lower colorbar -- position relative to ax_22
    bb = ax_22.get_position()
    dy = 0.1 * bb.height
    ax_23 = fig.add_axes(
        [bb.x1 + 0.03, bb.y0 + dy, 0.04, bb.height - 2 * dy])

    # Extract data arrays
    xx, _, mu, spread, err = stats.xx, stats.yy, stats.mu, stats.spread, stats.err
    k = key0[0]
    tt = stats.HMM.tseq.tt

    # Plot
    # - origin='lower' might get overturned by set_ylim() below.
    im_11 = ax_11.imshow(square(mu[key0]), cmap=cm)
    im_12 = ax_12.imshow(square(xx[k]), cmap=cm)
    # hot is better, but needs +1 colorbar
    im_21 = ax_21.imshow(square(spread[key0]), cmap=plt.cm.bwr)
    im_22 = ax_22.imshow(square(err[key0]), cmap=plt.cm.bwr)
    ims = (im_11, im_12, im_21, im_22)
    # Obs init -- a list where item 0 is the handle of something invisible.
    # (A list so that `update` can replace item 0 in place.)
    lh = list(ax_12.plot(0, 0)[0:1])

    sx = '$\\psi$'
    ax_11.set_title('mean ' + sx)
    ax_12.set_title('true ' + sx)
    ax_21.set_title('spread. ' + sx)
    ax_22.set_title('err. ' + sx)

    # TODO 7
    # for ax in axs.flatten():
    # Crop boundaries (which should be 0, i.e. yield harsh q gradients):
    # lims = (1, nx-2)
    # step = (nx - 1)/8
    # ticks = arange(step,nx-1,step)
    # ax.set_xlim (lims)
    # ax.set_ylim (lims[::-1])
    # ax.set_xticks(ticks)
    # ax.set_yticks(ticks)

    for im, clim in zip(ims, clims):
        im.set_clim(clim)

    fig.colorbar(im_12, cax=ax_13)
    fig.colorbar(im_22, cax=ax_23)
    for ax in [ax_13, ax_23]:
        ax.yaxis.set_tick_params('major', length=2, width=0.5,
                                 direction='in', left=True, right=True)
        ax.set_axisbelow('line')  # make ticks appear over colorbar patch

    # Title
    title = "Streamfunction (" + sx + ")"
    fig.suptitle(title)
    # Time info
    text_t = ax_12.text(1, 1.1, format_time(None, None, None),
                        transform=ax_12.transAxes, family='monospace', ha='left')

    def update(key, E, P):
        """Redraw the panels for ``key = (k, ko, faus)``."""
        k, ko, faus = key
        t = tt[k]

        im_11.set_data(square(mu[key]))
        im_12.set_data(square(xx[k]))
        im_21.set_data(square(spread[key]))
        im_22.set_data(square(err[key]))

        # Remove previous obs
        # (ValueError, e.g. when the handle was already removed on an
        # earlier call without new obs, is ignored.)
        try:
            lh[0].remove()
        except ValueError:
            pass
        # Plot current obs.
        #  - plot() automatically adjusts to direction of y-axis in use.
        #  - ind2sub returns (iy,ix), while plot takes (ix,iy) => reverse.
        if ko is not None and obs_inds is not None:
            lh[0] = ax_12.plot(*ind2sub(obs_inds(t))[::-1], 'k.', ms=1, zorder=5)[0]

        text_t.set_text(format_time(k, ko, t))

        return
    return update
def init(fignum, stats, key0, plot_u, E, P, **kwargs):
    """Initialise the 1-D amplitude plot; return its update callback.

    Shows the truth, the observations (if ``obs_inds`` is set) and either
    a mean +/- ``conf_mult`` * spread band or the individual ensemble
    members, plotted against the (possibly wrapped) state index.

    Parameters are taken from `params_orig`, overridden by `kwargs`.
    NOTE(review): `params_orig` is not defined here; it comes from an
    enclosing scope.

    Returns
    -------
    update: callable(key, E, P)
        Refreshes the lines for ``key = (k, ko, faus)``.
    """
    xx, yy, mu = stats.xx, stats.yy, stats.mu

    # Set parameters (kwargs takes precedence over params_orig)
    p = DotDict(
        **{kw: kwargs.get(kw, val) for kw, val in params_orig.items()})

    # NB: `not p.dims` is ambiguous for multi-element arrays, and wrongly
    # treats e.g. array([0]) as "not given"; test emptiness explicitly.
    if p.dims is None or len(p.dims) == 0:
        M = xx.shape[-1]
        p.dims = arange(M)
    else:
        # Array, so that it supports fancy indexing (`p.dims[xt]` below).
        p.dims = np.asarray(p.dims)
        M = len(p.dims)

    # Make periodic wrapper
    ii, wrap = viz.setup_wrapping(M, p.periodicity)

    # Set up figure, axes
    fig, ax = place.freshfig(fignum, figsize=(8, 5))
    fig.suptitle("1d amplitude plot")

    # Nans
    nan1 = wrap(nan * ones(M))

    # Without an ensemble, fall back to plotting the confidence band.
    if E is None and p.conf_mult is None:
        p.conf_mult = 2

    # Init plots
    if p.conf_mult:
        lines_s = ax.plot(ii, nan1, "b-", lw=1,
                          label=(str(p.conf_mult) + r'$\sigma$ conf'))
        lines_s += ax.plot(ii, nan1, "b-", lw=1)
        line_mu, = ax.plot(ii, nan1, 'b-', lw=2, label='DA mean')
    else:
        nanE = nan * ones((stats.xp.N, M))
        # Only the first member gets a legend label.
        lines_E = ax.plot(ii, wrap(nanE[0]), **p.ens_props, lw=1,
                          label='Ensemble')
        lines_E += ax.plot(ii, wrap(nanE[1:]).T, **p.ens_props, lw=1)
    # Truth, Obs
    (line_x, ) = ax.plot(ii, nan1, 'k-', lw=3, label='Truth')
    if p.obs_inds is not None:
        p.obs_inds = np.asarray(p.obs_inds)
        (line_y, ) = ax.plot(p.obs_inds, nan * p.obs_inds, 'g*', ms=5,
                             label='Obs')

    # Tune plot
    ax.set_ylim(*viz.xtrema(xx))
    ax.set_xlim(viz.stretch(ii[0], ii[-1], 1))
    # Xticks
    xt = ax.get_xticks()
    xt = xt[abs(xt % 1) < 0.01].astype(int)  # Keep only the integer ticks
    xt = xt[xt >= 0]
    xt = xt[xt < len(p.dims)]
    ax.set_xticks(xt)
    ax.set_xticklabels(p.dims[xt])  # label ticks with the actual state index

    ax.set_xlabel('State index')
    ax.set_ylabel('Value')
    ax.legend(loc='upper right')

    text_t = ax.text(0.01, 0.01, format_time(None, None, None),
                     transform=ax.transAxes, family='monospace', ha='left')

    # Init visibility (must come after legend):
    if p.obs_inds is not None:
        line_y.set_visible(False)

    def update(key, E, P):
        """Redraw for ``key = (k, ko, faus)``."""
        k, ko, faus = key

        if p.conf_mult:
            # Row 0: mean + c*spread; row 1: mean - c*spread.
            sigma = mu[key] + p.conf_mult * stats.spread[key] * [[1], [-1]]
            lines_s[0].set_ydata(wrap(sigma[0, p.dims]))
            lines_s[1].set_ydata(wrap(sigma[1, p.dims]))
            line_mu.set_ydata(wrap(mu[key][p.dims]))
        else:
            for n, line in enumerate(lines_E):
                line.set_ydata(wrap(E[n, p.dims]))
            update_alpha(key, stats, lines_E)

        line_x.set_ydata(wrap(xx[k, p.dims]))

        text_t.set_text(format_time(k, ko, stats.HMM.tseq.tt[k]))

        # Show obs at forecast ('f') time; hide them again at 'u'.
        if 'f' in faus:
            if p.obs_inds is not None:
                line_y.set_ydata(yy[ko])
                line_y.set_zorder(5)
                line_y.set_visible(True)

        if 'u' in faus:
            if p.obs_inds is not None:
                line_y.set_visible(False)

        return
    return update
- Gaussian distributions. """ dists = dist_euclid(vectorize(*pts)) Cov = 1 - variogram_gauss(dists, r) C12 = sla.sqrtm(Cov).real fields = randn(N, len(dists)) @ C12.T return fields if __name__ == "__main__": from simulator import plotting as plots from simulator.grid import Grid2D np.random.seed(3000) plt.ion() N = 15 # ensemble size ## 1D xx = np.linspace(0, 1, 201) fields = gaussian_fields((xx, ), N) fig, ax = freshfig(1) ax.plot(xx, fields.T, lw=2) ## 2D grid = Grid2D(Lx=1, Ly=1, Nx=20, Ny=20) plots.model = grid fields = gaussian_fields(grid.mesh(), N) fields = 0.5 + .2 * fields # fields = truncate_01(fields) plots.fields(plots.field, fields)
def plot_err_components(stats):
    """Plot components of the error.

    Three panels: element-wise time-averaged |error| vs. spread, the same in
    terms of principal components (``umisf`` vs. ``svals``), and a histogram
    of the RMSE over analysis times.

    Parameters
    ----------
    stats: `dapper.stats.Stats`

    .. note::
      it was chosen to `plot(ii, mean_in_time(abs(err_i)))`,
      and thus the corresponding spread measure is MAD.
      If one chose instead: `plot(ii, std_spread_in_time(err_i))`,
      then the corresponding measure of spread would have been `spread`.
      This choice was made in part because (wrt. subplot 2)
      the singular values (`svals`) correspond to rotated MADs,
      and because `rms(umisf)` seems too convoluted for interpretation.
    """
    fig, (ax0, ax1, ax2) = place.freshfig("Error components",
                                          figsize=(6, 6), nrows=3)

    tseq = stats.HMM.tseq
    Nx = stats.xx.shape[1]

    def en_mean(x):
        # Average over axis 0 (the analysis-time axis, cf. the "_a" labels).
        return np.mean(x, axis=0)

    err = en_mean(abs(stats.err.a))
    sprd = en_mean(stats.spread.a)
    umsft = en_mean(abs(stats.umisf.a))
    usprd = en_mean(stats.svals.a)

    ax0.plot(arange(Nx), err, 'k', lw=2, label='Error')
    if Nx < 10**3:
        ax0.fill_between(arange(Nx), 0, sprd, alpha=0.7, label='Spread')
    else:
        # Too many points for a filled area; draw a line instead.
        ax0.plot(arange(Nx), sprd, alpha=0.7, label='Spread')
    # ax0.set_yscale('log')
    ax0.set_title('Element-wise error comparison')
    ax0.set_xlabel('Dimension index (i)')
    ax0.set_ylabel('Time-average (_a) magnitude')
    ax0.set_xlim(0, Nx-1)
    ax0.get_xaxis().set_major_locator(MaxNLocator(integer=True))
    ax0.legend(loc='upper right')

    ax1.set_xlim(0, Nx-1)
    ax1.set_xlabel('Principal component index')
    ax1.set_ylabel('Time-average (_a) magnitude')
    ax1.set_title('Spectral error comparison')
    has_been_computed = np.any(np.isfinite(umsft))
    if has_been_computed:
        L = len(umsft)
        ax1.plot(arange(L), umsft, 'k', lw=2, label='Error')
        ax1.fill_between(arange(L), 0, usprd, alpha=0.7, label='Spread')
        ax1.set_yscale('log')
        ax1.get_xaxis().set_major_locator(MaxNLocator(integer=True))
    else:
        not_available_text(ax1)

    rmse = stats.err.rms.a[tseq.masko]
    ax2.hist(rmse, bins=30, density=False)
    ax2.set_ylabel('Num. of occurrence (_a)')
    ax2.set_xlabel('RMSE')
    ax2.set_title('Histogram of RMSE values')
    ax2.set_xlim(left=0)

    plt.pause(0.1)
    plt.tight_layout()
def init(fignum, stats, key0, plot_u, E, P, **kwargs):
    """Initialise the marginal time-series figure; return its update callback.

    One axes per plotted state dimension (`DimsX`), showing truth, obs
    (where the dimension is observed), and the ensemble and/or mean
    +/- spread over a sliding time window.

    Parameters are taken from `params_orig`, overridden by `kwargs`.
    NOTE(review): `params_orig` is not defined here; it comes from an
    enclosing scope.

    Returns
    -------
    update: callable(key, E, P)
        Appends the data for ``key = (k, ko, faus)`` and redraws.
    """
    xx, yy, mu, spread, tseq = \
        stats.xx, stats.yy, stats.mu, stats.spread, stats.HMM.tseq

    # Set parameters (kwargs takes precedence over params_orig)
    p = DotDict(
        **{kw: kwargs.get(kw, val) for kw, val in params_orig.items()})

    # Lag settings:
    T_lag, K_lag, a_lag = validate_lag(p.Tplot, tseq)
    K_plot = comp_K_plot(K_lag, a_lag, plot_u)
    # Extend K_plot further for adding blanks in resampling (PartFilt):
    has_w = hasattr(stats, 'w')
    if has_w:
        K_plot += a_lag

    # Choose marginal dims to plot.
    # NB: `not p.dims` is ambiguous for multi-element arrays; test explicitly.
    if p.dims is None or len(p.dims) == 0:
        Nx = min(10, xx.shape[-1])
        DimsX = linspace_int(xx.shape[-1], Nx)
    else:
        Nx = len(p.dims)
        DimsX = p.dims

    # Pre-process obs dimensions
    obs_inds = [] if p.obs_inds is None else p.obs_inds
    # Inds (within y) of the obs whose state index is in DimsX
    iiY = [i for i, m in enumerate(obs_inds) if m in DimsX]
    # State indices of those obs (same order as iiY, i.e. obs order)
    DimsY = [m for m in obs_inds if m in DimsX]
    # Get dim (within the columns of d.y) of each x; None if unobserved
    DimsY = [DimsY.index(m) if m in DimsY else None for m in DimsX]
    Ny = len(iiY)

    # Set up figure, axes
    fig, axs = place.freshfig(fignum, figsize=(5, 7), nrows=Nx, sharex=True)
    if Nx == 1:
        axs = [axs]

    # Tune plots
    axs[0].set_title("Marginal time series")
    for ix, (m, ax) in enumerate(zip(DimsX, axs)):
        # ax.set_ylim(*viz.stretch(*viz.xtrema(xx[:, m]), 1/p.zoomy))
        if not p.labels:
            ax.set_ylabel("$x_{%d}$" % m)
        else:
            ax.set_ylabel(p.labels[ix])
    axs[-1].set_xlabel('Time (t)')

    plot_pause(0.05)
    plt.tight_layout()

    # Allocate
    d = DotDict()  # data arrays
    h = DotDict()  # plot handles
    d.t = RollingArray((K_plot, ))
    d.x = RollingArray((K_plot, Nx))
    h.x = []
    d.y = RollingArray((K_plot, Ny))
    h.y = []  # one entry per axes (None where the dim is unobserved)
    if E is not None:
        d.E = RollingArray((K_plot, len(E), Nx))
        h.E = []
    if P is not None:
        d.mu = RollingArray((K_plot, Nx))
        h.mu = []
        d.s = RollingArray((K_plot, 2, Nx))
        h.s = []

    # Plot (invisible coz everything here is nan, for the moment).
    for ix, (_m, iy, ax) in enumerate(zip(DimsX, DimsY, axs)):
        h.x += ax.plot(d.t, d.x[:, ix], 'k')
        # Obs handles are stored per axes (indexed by `ix`), not per obs
        # column (`iy`): the two orders differ unless obs_inds and DimsX
        # are sorted alike.
        if iy is not None:
            h.y.append(ax.plot(d.t, d.y[:, iy], 'g*', ms=10)[0])
        else:
            h.y.append(None)
        if 'E' in d:
            h.E += [ax.plot(d.t, d.E[:, :, ix], **p.ens_props)]
        if 'mu' in d:
            h.mu += ax.plot(d.t, d.mu[:, ix], 'b')
        if 's' in d:
            h.s += [ax.plot(d.t, d.s[:, :, ix], 'b--', lw=1)]

    def update(key, E, P):
        """Append data for ``key = (k, ko, faus)`` and redraw all axes."""
        k, ko, faus = key

        EE = duplicate_with_blanks_for_resampled(E, DimsX, key, has_w)

        # Roll data array
        ind = k if plot_u else ko
        for Ens in EE:  # If E is duplicated, so must the others be.
            if 'E' in d:
                d.E.insert(ind, Ens)
            if 'mu' in d:
                d.mu.insert(ind, mu[key][DimsX])
            if 's' in d:
                d.s.insert(
                    ind, mu[key][DimsX] + [[1], [-1]] * spread[key][DimsX])
            d.t.insert(ind, tseq.tt[k])
            d.y.insert(
                ind, yy[ko, iiY] if ko is not None else nan * ones(Ny))
            d.x.insert(ind, xx[k, DimsX])

        # Update graphs
        for ix, (_m, iy, ax) in enumerate(zip(DimsX, DimsY, axs)):
            sliding_xlim(ax, d.t, T_lag, True)
            h.x[ix].set_data(d.t, d.x[:, ix])
            if iy is not None:
                h.y[ix].set_data(d.t, d.y[:, iy])
            if 'mu' in d:
                h.mu[ix].set_data(d.t, d.mu[:, ix])
            if 's' in d:
                for b in [0, 1]:
                    h.s[ix][b].set_data(d.t, d.s[:, b, ix])
            if 'E' in d:
                for n in range(len(E)):
                    h.E[ix][n].set_data(d.t, d.E[:, n, ix])
                update_alpha(key, stats, h.E[ix])

            # TODO 3: fixup. This might be slow?
            # In any case, it is very far from tested.
            # Also, relim'iting all of the time is distracting.
            # Use d_ylim?
            if 'E' in d:
                lims = d.E
            elif 'mu' in d:
                lims = d.mu
            else:
                # Nothing besides truth/obs to base the limits on.
                continue
            lims = np.array(viz.xtrema(lims[..., ix]))
            if lims[0] == lims[1]:
                lims += [-.5, +.5]
            ax.set_ylim(*viz.stretch(*lims, 1 / p.zoomy))

        return
    return update
# ensured by the `config_wells` function used below. grid1 = [.1, .9] grid2 = np.dstack(np.meshgrid(grid1, grid1)).reshape((-1, 2)) rates = np.ones((len(grid2), 1)) # ==> all wells use the same (constant) rate model.config_wells( # Each row in `inj` and `prod` should be a tuple: (x, y, rate), # where x, y ∈ (0, 1) and rate > 0. inj=[[0.50, 0.50, 1.00]], prod=np.hstack((grid2, rates)), ) # #### Plot # Let's take a moment to visualize the (true) model permeability field, and the well locations. fig, ax = freshfig("True perm. field", figsize=(1.5, 1), rel=1) # plots.field(ax, perm.Truth, "pperm") plots.field(ax, perm_transf(perm.Truth), locator=LogLocator(), wells=True, colorbar=True) fig.tight_layout() # #### Observation operator # The data will consist in the water saturation of at the well locations, i.e. of the # production. I.e. there is no well model. It should be pointed out, however, that # ensemble methods technically support observation models of any complexity, though your # accuracy mileage may vary (again, depending on the incurred nonlinearity and # non-Gaussianity). Furthermore, it is also no problem to include time-dependence in the # observation model.
def init(fignum, stats, key0, plot_u, E, P, **kwargs):
    """Initialise the phase-space (2-D or 3-D) trajectory plot.

    Shows the truth, the ensemble members or the mean, and the obs (only
    when ``obs_inds`` equals ``dims``). Each is drawn as a "tail" line over
    a sliding window plus a scatter marker at the current position.

    Parameters are taken from `params_orig`, overridden by `kwargs`.
    NOTE(review): `params_orig`, `M` (number of plotted dims) and `is_3d`
    are not defined here; they come from an enclosing scope.

    Returns
    -------
    update: callable(key, E, P)
        Appends the data for ``key = (k, ko, faus)`` and redraws.
    """
    xx, yy, mu, _, tseq = \
        stats.xx, stats.yy, stats.mu, stats.spread, stats.HMM.tseq

    # Set parameters (kwargs takes precedence over params_orig)
    p = DotDict(
        **{kw: kwargs.get(kw, val) for kw, val in params_orig.items()})

    # Lag settings:
    has_w = hasattr(stats, 'w')
    if p.Tplot == 0:
        K_plot = 1
    else:
        _, K_lag, a_lag = validate_lag(p.Tplot, tseq)
        K_plot = comp_K_plot(K_lag, a_lag, plot_u)
        # Extend K_plot further for adding blanks in resampling (PartFilt):
        if has_w:
            K_plot += a_lag

    # Dimension settings
    # NB: `not p.dims` is ambiguous for multi-element arrays; test explicitly.
    if p.dims is None or len(p.dims) == 0:
        p.dims = arange(M)
    if not p.labels:
        p.labels = ["$x_%d$" % d for d in p.dims]
    assert len(p.dims) == M

    # Set up figure, axes
    fig, _ = place.freshfig(fignum, figsize=(5, 5))
    ax = plt.subplot(111, projection='3d' if is_3d else None)
    ax.set_facecolor('w')
    ax.set_title("Phase space trajectories")
    # Tune plot
    for ind, (label, i, axis_name) in enumerate(zip(p.labels, p.dims, "xyz")):
        viz.set_ilim(ax, ind, *viz.stretch(*viz.xtrema(xx[:, i]), 1 / p.zoom))
        # getattr rather than eval(): labels containing quotes or backslash
        # sequences (e.g. r"$\theta$" -> TAB) broke the eval'd source code.
        getattr(ax, "set_%slabel" % axis_name)(label)

    # Allocate
    d = DotDict()  # data arrays
    h = DotDict()  # plot handles
    s = DotDict()  # scatter handles
    if E is not None:
        d.E = RollingArray((K_plot, len(E), M))
        h.E = []
    if P is not None:
        d.mu = RollingArray((K_plot, M))
    d.x = RollingArray((K_plot, M))
    # Obs can only be drawn in this space if they observe exactly `dims`.
    if p.obs_inds is not None and list(p.obs_inds) == list(p.dims):
        d.y = RollingArray((K_plot, M))

    # Plot tails (invisible coz everything here is nan, for the moment).
    if 'E' in d:
        h.E += [
            ax.plot(*xn, **p.ens_props)[0]
            for xn in np.transpose(d.E, [1, 2, 0])
        ]
    if 'mu' in d:
        h.mu = ax.plot(*d.mu.T, 'b', lw=2)[0]
    h.x = ax.plot(*d.x.T, 'k', lw=3)[0]
    if 'y' in d:
        h.y = ax.plot(*d.y.T, 'g*', ms=14)[0]

    # Scatter. NB: don't init with nan's coz it's buggy
    # (wrt. get_color() and _offsets3d) since mpl 3.1.
    if 'E' in d:
        s.E = ax.scatter(*E.T[p.dims], s=3**2,
                         c=[hn.get_color() for hn in h.E])
    if 'mu' in d:
        s.mu = ax.scatter(*ones(M), s=8**2,
                          c=[h.mu.get_color()])
    s.x = ax.scatter(*ones(M), s=14**2,
                     c=[h.x.get_color()], marker=(5, 1), zorder=99)

    def update_tail(handle, newdata):
        """Set the line data from the columns of `newdata` (time x dims)."""
        handle.set_data(newdata[:, 0], newdata[:, 1])
        if is_3d:
            handle.set_3d_properties(newdata[:, 2])

    def update_sctr(handle, newdata):
        """Move scatter points to the rows of `newdata`."""
        if is_3d:
            handle._offsets3d = juggle_axes(*newdata.T, 'z')
        else:
            handle.set_offsets(newdata)

    def update(key, E, P):
        """Append data for ``key = (k, ko, faus)`` and redraw."""
        k, ko, faus = key
        show_y = 'y' in d and ko is not None

        EE = duplicate_with_blanks_for_resampled(E, p.dims, key, has_w)

        # Roll data array
        ind = k if plot_u else ko
        for Ens in EE:  # If E is duplicated, so must the others be.
            if 'E' in d:
                d.E.insert(ind, Ens)
            d.x.insert(ind, xx[k, p.dims])
            if 'y' in d:
                d.y.insert(ind, yy[ko, :] if show_y else nan * ones(M))
            if 'mu' in d:
                d.mu.insert(ind, mu[key][p.dims])

        # Update graph
        update_sctr(s.x, d.x[[-1]])
        update_tail(h.x, d.x)
        if 'y' in d:
            update_tail(h.y, d.y)
        if 'mu' in d:
            update_sctr(s.mu, d.mu[[-1]])
            update_tail(h.mu, d.mu)
        elif 'E' in d:
            update_sctr(s.E, d.E[-1])
            for n in range(len(E)):
                update_tail(h.E[n], d.E[:, n, :])
            update_alpha(key, stats, h.E, s.E)

        return
    return update
########################
# Reference trajectory
########################
# NB: Arbitrary, coz models are autonom. But dont use nan coz QG doesn't like it.
t0 = 0.0
K = int(round(T / dt))  # Num of time steps.
tt = np.linspace(dt, T, K)  # Time seq. (excludes t0)
# Burn in for 10 time units, then keep the final state as initial condition.
x = with_recursion(step, prog="BurnIn")(x0, int(10 / dt), t0, dt)[-1]
xx = with_recursion(step, prog="Reference")(x, K, t0, dt)

########################
# ACF
########################
# NB: Won't work with QG (too big, and BCs==0).
fig, ax = place.freshfig("ACF")
# `ii` and `nlags` may be pre-set by the user (e.g. in an interactive session).
if "ii" not in locals():
    ii = np.arange(min(100, Nx))
if "nlags" not in locals():
    nlags = min(100, K - 1)
# auto_cov with nlags-1 yields lags 0, ..., nlags-1, i.e. time offsets
# 0, dt, ..., (nlags-1)*dt. (Previously plotted against tt[:nlags], which
# starts at dt, shifting the whole curve by one step.)
lags = dt * np.arange(nlags)
# NOTE(review): only the first `nlags` states of xx are used for the
# estimate; confirm that using the whole trajectory was not intended.
acf = series.auto_cov(xx[:nlags, ii], nlags=nlags - 1, corr=1)
ax.plot(lags, np.nanmean(acf, axis=1))
ax.set_xlabel('Time (t)')
ax.set_ylabel('Auto-corr')
viz.plot_pause(0.1)

########################
# "Linearized" forecasting
########################