def update(key, E, P):
    """Refresh the four 2D panels, the observation markers and the time label.

    ``key`` is unpacked as ``(k, ko, faus)``: ``k`` indexes the truth ``xx``
    and the time array ``tt``, ``ko`` gates whether observations are drawn,
    and the full ``key`` indexes the ``mu``/``spread``/``err`` statistics.
    ``E`` and ``P`` are not used here.
    """
    k, ko, _ = key
    now = tt[k]

    # Panels: (mean, truth) on the top row, (spread, error) on the bottom row.
    im_11.set_data(square(mu[key]))
    im_12.set_data(square(xx[k]))
    im_21.set_data(square(spread[key]))
    im_22.set_data(square(err[key]))

    # Drop the previously drawn obs markers.
    try:
        lh[0].remove()
    except ValueError:
        # NOTE(review): presumably raised when this artist was already removed
        # on an earlier call (no new obs were drawn in between).
        pass

    # ind2sub yields (iy, ix) whereas plot() wants (x, y): hence the reversal.
    # plot() adapts by itself to the direction of the y-axis in use.
    if ko is not None and obs_inds is not None:
        xy = ind2sub(obs_inds(now))[::-1]
        lh[0] = ax_12.plot(*xy, 'k.', ms=1, zorder=5)[0]

    text_t.set_text(format_time(k, ko, now))
def update(key, E, P):
    """Redraw the 1D amplitude plot for ``key = (k, ko, faus)``.

    If ``p.conf_mult`` is truthy, draws the mean ``mu[key]`` together with
    ``mu[key] + p.conf_mult * stats.spread[key]`` and
    ``mu[key] - p.conf_mult * stats.spread[key]``. Otherwise, draws every
    ensemble member ``E[n]`` and refreshes their alpha via ``update_alpha``.
    The truth line and the time label are always refreshed. When
    ``p.obs_inds`` is set, observations are shown if ``'f'`` is in ``faus``
    and hidden if ``'u'`` is. ``P`` is not used.
    """
    k, ko, faus = key
    dims = p.dims

    if not p.conf_mult:
        # Ensemble mode: one line per member.
        for n, line in enumerate(lines_E):
            line.set_ydata(wrap(E[n, dims]))
        update_alpha(key, stats, lines_E)
    else:
        # Row 0: mean + conf_mult*spread; row 1: mean - conf_mult*spread.
        bounds = mu[key] + p.conf_mult * stats.spread[key] * [[1], [-1]]
        for row in (0, 1):
            lines_s[row].set_ydata(wrap(bounds[row, dims]))
        line_mu.set_ydata(wrap(mu[key][dims]))

    line_x.set_ydata(wrap(xx[k, dims]))
    text_t.set_text(format_time(k, ko, stats.HMM.tseq.tt[k]))

    if 'f' in faus and p.obs_inds is not None:
        line_y.set_ydata(yy[ko])
        line_y.set_zorder(5)
        line_y.set_visible(True)
    if 'u' in faus and p.obs_inds is not None:
        line_y.set_visible(False)
def init(fignum, stats, key0, plot_u, E, P, **kwargs):
    """Create a 2x2 live plot of 2D fields and return its ``update`` callback.

    Panels (row-major): DA mean, truth, spread and error, each reshaped to 2D
    by ``square``. Observation locations are drawn on the truth panel. Each
    row gets a colorbar on its right-hand side.

    Returns
    -------
    update : callable
        ``update(key, E, P)`` redraws all panels for ``key = (k, ko, faus)``.
    """
    # Nudge the grid leftwards to leave room for the colorbars.
    gridspec = {'left': 0.125 - 0.04, 'right': 0.9 - 0.04}
    fig, axs = place.freshfig(fignum, figsize=(6, 6), nrows=2, ncols=2,
                              sharex=True, sharey=True, gridspec_kw=gridspec)
    for ax in axs.flatten():
        ax.set_aspect('equal', 'box')
    (ax_11, ax_12), (ax_21, ax_22) = axs

    # Top row: faint white grid; bottom row: faint black grid.
    for ax, colour, width in ((ax_11, 'w', 0.2), (ax_12, 'w', 0.2),
                              (ax_21, 'k', 0.1), (ax_22, 'k', 0.1)):
        ax.grid(color=colour, linewidth=width)

    def _colorbar_axes(host):
        """Add a slim axes right of ``host``, inset by 10% of its height at each end."""
        box = host.get_position()
        inset = 0.1 * box.height
        return fig.add_axes(
            [box.x1 + 0.03, box.y0 + inset, 0.04, box.height - 2 * inset])

    ax_13 = _colorbar_axes(ax_12)  # upper colorbar
    ax_23 = _colorbar_axes(ax_22)  # lower colorbar

    # Data arrays
    xx, _, mu, spread, err = (stats.xx, stats.yy, stats.mu,
                              stats.spread, stats.err)
    k0 = key0[0]
    tt = stats.HMM.tseq.tt

    # NB: origin='lower' might get overturned by a set_ylim() (see TODO below).
    im_11 = ax_11.imshow(square(mu[key0]), cmap=cm)
    im_12 = ax_12.imshow(square(xx[k0]), cmap=cm)  # 'hot' is better, but needs +1 colorbar
    im_21 = ax_21.imshow(square(spread[key0]), cmap=plt.cm.bwr)
    im_22 = ax_22.imshow(square(err[key0]), cmap=plt.cm.bwr)
    ims = (im_11, im_12, im_21, im_22)

    # Obs handle holder: item 0 starts as something invisible (a one-point
    # line) and is swapped out by update().
    lh = [ax_12.plot(0, 0)[0]]

    sx = '$\\psi$'
    for ax, prefix in ((ax_11, 'mean '), (ax_12, 'true '),
                       (ax_21, 'spread. '), (ax_22, 'err. ')):
        ax.set_title(prefix + sx)

    # TODO 7: crop the boundaries of every axes (they should be 0, i.e. yield
    # harsh q gradients), e.g. limits (1, nx-2) with the y-limits reversed,
    # and ticks at arange(step, nx-1, step) where step = (nx-1)/8.

    for im, clim in zip(ims, clims):
        im.set_clim(clim)
    fig.colorbar(im_12, cax=ax_13)
    fig.colorbar(im_22, cax=ax_23)
    for cax in (ax_13, ax_23):
        cax.yaxis.set_tick_params('major', length=2, width=0.5,
                                  direction='in', left=True, right=True)
        cax.set_axisbelow('line')  # draw the ticks over the colorbar patch

    fig.suptitle("Streamfunction (" + sx + ")")

    # Time info, anchored at the top-right corner of ax_12, slightly above it.
    text_t = ax_12.text(1, 1.1, format_time(None, None, None),
                        transform=ax_12.transAxes, family='monospace',
                        ha='left')

    def update(key, E, P):
        """Refresh the panels, obs markers and time label for ``key = (k, ko, faus)``."""
        k, ko, _ = key
        now = tt[k]

        im_11.set_data(square(mu[key]))
        im_12.set_data(square(xx[k]))
        im_21.set_data(square(spread[key]))
        im_22.set_data(square(err[key]))

        # Drop the previously drawn obs markers.
        try:
            lh[0].remove()
        except ValueError:
            # NOTE(review): presumably raised when this artist was already
            # removed on an earlier call (no new obs were drawn in between).
            pass

        # ind2sub yields (iy, ix) whereas plot() wants (x, y): hence the
        # reversal. plot() adapts by itself to the y-axis direction in use.
        if ko is not None and obs_inds is not None:
            xy = ind2sub(obs_inds(now))[::-1]
            lh[0] = ax_12.plot(*xy, 'k.', ms=1, zorder=5)[0]

        text_t.set_text(format_time(k, ko, now))

    return update
def init(fignum, stats, key0, plot_u, E, P, **kwargs):
    """Set up the 1D amplitude liveplot and return its ``update`` callback.

    Plots, against state index, either the DA mean with a confidence band
    (``p.conf_mult`` times the spread) or the individual ensemble members,
    together with the truth and, optionally, the observations.

    Parameters are taken from ``params_orig``; any of them may be overridden
    through ``kwargs``. ``p.dims`` may be None/empty (meaning all state
    components), or any index sequence (list, tuple, range, ndarray).

    Returns
    -------
    update : callable
        ``update(key, E, P)`` redraws the plot for ``key = (k, ko, faus)``.
    """
    xx, yy, mu = stats.xx, stats.yy, stats.mu

    # Set parameters (kwargs takes precedence over params_orig)
    p = DotDict(
        **{kw: kwargs.get(kw, val) for kw, val in params_orig.items()})

    # Normalise dims to a 1D index array. `not <ndarray>` is ambiguous for
    # more than one element, and a plain list cannot be fancy-indexed by the
    # tick array in `p.dims[xt]` (below), so both forms used to crash.
    if p.dims is not None:
        p.dims = np.asarray(p.dims).ravel()
    if p.dims is None or p.dims.size == 0:
        M = xx.shape[-1]
        p.dims = arange(M)
    else:
        M = len(p.dims)

    # Make periodic wrapper
    ii, wrap = viz.setup_wrapping(M, p.periodicity)

    # Set up figure, axes
    fig, ax = place.freshfig(fignum, figsize=(8, 5))
    fig.suptitle("1d amplitude plot")

    # All-nan placeholder data (filled in by update)
    nan1 = wrap(nan * ones(M))

    # Without an ensemble, default to drawing a 2-sigma confidence band
    if E is None and p.conf_mult is None:
        p.conf_mult = 2

    # Init plots
    if p.conf_mult:
        lines_s = ax.plot(ii, nan1, "b-", lw=1,
                          label=(str(p.conf_mult) + r'$\sigma$ conf'))
        lines_s += ax.plot(ii, nan1, "b-", lw=1)
        line_mu, = ax.plot(ii, nan1, 'b-', lw=2, label='DA mean')
    else:
        nanE = nan * ones((stats.xp.N, M))
        lines_E = ax.plot(ii, wrap(nanE[0]), **p.ens_props, lw=1,
                          label='Ensemble')
        lines_E += ax.plot(ii, wrap(nanE[1:]).T, **p.ens_props, lw=1)

    # Truth, Obs
    (line_x, ) = ax.plot(ii, nan1, 'k-', lw=3, label='Truth')
    if p.obs_inds is not None:
        p.obs_inds = np.asarray(p.obs_inds)
        # NOTE(review): obs are placed at their raw state indices, whereas
        # the other lines use positions 0..M-1 into p.dims. These presumably
        # only line up when dims covers the full state -- TODO confirm.
        (line_y, ) = ax.plot(p.obs_inds, nan * p.obs_inds, 'g*', ms=5,
                             label='Obs')

    # Tune plot
    ax.set_ylim(*viz.xtrema(xx))
    ax.set_xlim(viz.stretch(ii[0], ii[-1], 1))

    # Xticks: keep only integer ticks that are valid positions into p.dims,
    # and label them with the corresponding state index.
    xt = ax.get_xticks()
    xt = xt[abs(xt % 1) < 0.01].astype(int)
    xt = xt[xt >= 0]
    xt = xt[xt < len(p.dims)]
    ax.set_xticks(xt)
    ax.set_xticklabels(p.dims[xt])

    ax.set_xlabel('State index')
    ax.set_ylabel('Value')
    ax.legend(loc='upper right')

    text_t = ax.text(0.01, 0.01, format_time(None, None, None),
                     transform=ax.transAxes, family='monospace', ha='left')

    # Init visibility (must come after legend):
    if p.obs_inds is not None:
        line_y.set_visible(False)

    def update(key, E, P):
        """Redraw for ``key = (k, ko, faus)``; ``P`` is not used."""
        k, ko, faus = key

        if p.conf_mult:
            # Row 0: mean + conf_mult*spread; row 1: mean - conf_mult*spread.
            sigma = mu[key] + p.conf_mult * stats.spread[key] * [[1], [-1]]
            lines_s[0].set_ydata(wrap(sigma[0, p.dims]))
            lines_s[1].set_ydata(wrap(sigma[1, p.dims]))
            line_mu.set_ydata(wrap(mu[key][p.dims]))
        else:
            for n, line in enumerate(lines_E):
                line.set_ydata(wrap(E[n, p.dims]))
            update_alpha(key, stats, lines_E)

        line_x.set_ydata(wrap(xx[k, p.dims]))

        text_t.set_text(format_time(k, ko, stats.HMM.tseq.tt[k]))

        # Show obs on 'f' stages, hide them on 'u' stages.
        if 'f' in faus:
            if p.obs_inds is not None:
                line_y.set_ydata(yy[ko])
                line_y.set_zorder(5)
                line_y.set_visible(True)

        if 'u' in faus:
            if p.obs_inds is not None:
                line_y.set_visible(False)

        return

    return update