def init(fignum, stats, key0, plot_u, E, P, **kwargs):
    """Set up the phase-space trajectory live plot and return its updater.

    Draws the truth, and depending on what is supplied either the ensemble
    members (``E``) or the mean (``P`` given), plus the observations when
    they observe exactly the plotted dimensions. Each quantity is shown as a
    "tail" line (backed by a rolling array of length ``K_plot``) and, for
    truth/mean/ensemble, a scatter marker at the newest position.

    NOTE(review): ``M``, ``is_3d`` and ``params_orig`` are free variables
    expected from the enclosing scope (not visible here).

    Parameters
    ----------
    fignum : figure number passed to ``place.freshfig``.
    stats : statistics object; ``xx``, ``yy``, ``mu`` and ``HMM.tseq`` are
        read, and the presence of a ``w`` attribute enables the
        resampling-blanks logic.
    key0 : initial time key (not used here).
    plot_u : selects whether the rolling arrays are indexed by ``k``
        (True) or by ``ko`` (False).
    E : initial ensemble, or None.
    P : covariance, or None. When not None, the mean is plotted instead of
        the ensemble tails.
    **kwargs : overrides for entries of ``params_orig``.

    Returns
    -------
    update : callable ``update(key, E, P)`` that appends the new state and
        redraws.
    """
    xx, yy, mu = stats.xx, stats.yy, stats.mu
    tseq = stats.HMM.tseq

    # Set parameters (kwargs takes precedence over params_orig)
    p = DotDict(**{kw: kwargs.get(kw, val) for kw, val in params_orig.items()})

    # Lag settings:
    has_w = hasattr(stats, 'w')
    if p.Tplot == 0:
        K_plot = 1
    else:
        _, K_lag, a_lag = validate_lag(p.Tplot, tseq)
        K_plot = comp_K_plot(K_lag, a_lag, plot_u)
        # Extend K_plot further for adding blanks in resampling (PartFilt):
        if has_w:
            K_plot += a_lag

    # Dimension settings.
    # Test None/empty explicitly: `not arr` raises for multi-element arrays.
    if p.dims is None or len(p.dims) == 0:
        p.dims = arange(M)
    if not p.labels:
        p.labels = ["$x_%d$" % i for i in p.dims]
    assert len(p.dims) == M

    # Set up figure, axes
    place.freshfig(fignum, figsize=(5, 5))
    ax = plt.subplot(111, projection='3d' if is_3d else None)
    ax.set_facecolor('w')
    ax.set_title("Phase space trajectories")

    # Tune plot
    for ind, (label, i, t) in enumerate(zip(p.labels, p.dims, "xyz")):
        viz.set_ilim(ax, ind, *viz.stretch(*viz.xtrema(xx[:, i]), 1 / p.zoom))
        # getattr rather than eval: eval re-parsed the label as a Python
        # string literal, so quotes crashed it and LaTeX such as "\alpha"
        # was turned into control characters.
        getattr(ax, "set_%slabel" % t)(label)

    # Allocate
    d = DotDict()  # data arrays
    h = DotDict()  # plot handles
    s = DotDict()  # scatter handles
    if E is not None:
        d.E = RollingArray((K_plot, len(E), M))
        h.E = []
    if P is not None:
        d.mu = RollingArray((K_plot, M))
    d.x = RollingArray((K_plot, M))
    # Obs are only drawn when they observe exactly the plotted dims.
    if p.obs_inds is not None and list(p.obs_inds) == list(p.dims):
        d.y = RollingArray((K_plot, M))

    # Plot tails (invisible coz everything here is nan, for the moment).
    if 'E' in d:
        h.E += [ax.plot(*xn, **p.ens_props)[0]
                for xn in np.transpose(d.E, [1, 2, 0])]
    if 'mu' in d:
        h.mu = ax.plot(*d.mu.T, 'b', lw=2)[0]
    h.x = ax.plot(*d.x.T, 'k', lw=3)[0]
    if 'y' in d:
        h.y = ax.plot(*d.y.T, 'g*', ms=14)[0]

    # Scatter. NB: don't init with nan's coz it's buggy
    # (wrt. get_color() and _offsets3d) since mpl 3.1.
    if 'E' in d:
        s.E = ax.scatter(*E.T[p.dims], s=3**2,
                         c=[hn.get_color() for hn in h.E])
    if 'mu' in d:
        s.mu = ax.scatter(*ones(M), s=8**2, c=[h.mu.get_color()])
    s.x = ax.scatter(*ones(M), s=14**2, c=[h.x.get_color()],
                     marker=(5, 1), zorder=99)

    def update(key, E, P):
        """Insert the state at time ``key`` into the rolling arrays and redraw."""
        k, ko, faus = key
        show_y = 'y' in d and ko is not None

        def update_tail(handle, newdata):
            # Line data: columns are the plotted dims (x, y[, z]).
            handle.set_data(newdata[:, 0], newdata[:, 1])
            if is_3d:
                handle.set_3d_properties(newdata[:, 2])

        def update_sctr(handle, newdata):
            if is_3d:
                handle._offsets3d = juggle_axes(*newdata.T, 'z')
            else:
                handle.set_offsets(newdata)

        EE = duplicate_with_blanks_for_resampled(E, p.dims, key, has_w)

        # Roll data array
        ind = k if plot_u else ko
        for Ens in EE:  # If E is duplicated, so must the others be.
            if 'E' in d:
                d.E.insert(ind, Ens)
            d.x.insert(ind, xx[k, p.dims])
            if 'y' in d:
                d.y.insert(ind, yy[ko, :] if show_y else nan * ones(M))
            if 'mu' in d:
                d.mu.insert(ind, mu[key][p.dims])

        # Update graph
        update_sctr(s.x, d.x[[-1]])
        update_tail(h.x, d.x)
        if 'y' in d:
            update_tail(h.y, d.y)
        if 'mu' in d:
            update_sctr(s.mu, d.mu[[-1]])
            update_tail(h.mu, d.mu)
        elif 'E' in d:  # was a bare `else`, which crashed when E was None
            update_sctr(s.E, d.E[-1])
            for n, handle in enumerate(h.E):
                update_tail(handle, d.E[:, n, :])
            update_alpha(key, stats, h.E, s.E)

    return update
def init(fignum, stats, key0, plot_u, E, P, **kwargs):
    """Set up the 1d amplitude (value vs. state index) live plot.

    Draws the truth plus either confidence bands around the mean (when
    ``p.conf_mult`` is set, forced to 2 when ``E`` is None) or one line per
    ensemble member. Observations are drawn as markers after a forecast
    step and hidden again after an analysis step.

    NOTE(review): ``params_orig`` is a free variable expected from the
    enclosing scope (not visible here).

    Parameters
    ----------
    fignum : figure number passed to ``place.freshfig``.
    stats : statistics object; ``xx``, ``yy``, ``mu``, ``spread``, ``xp.N``
        and ``HMM.tseq.tt`` are read.
    key0, plot_u, P : not used here.
    E : initial ensemble, or None.
    **kwargs : overrides for entries of ``params_orig``.

    Returns
    -------
    update : callable ``update(key, E, P)`` that redraws for time ``key``.
    """
    xx, yy, mu = stats.xx, stats.yy, stats.mu

    # Set parameters (kwargs takes precedence over params_orig)
    p = DotDict(**{kw: kwargs.get(kw, val) for kw, val in params_orig.items()})

    # Test None/empty explicitly (`not arr` raises for multi-element arrays),
    # and store dims as an array so the fancy indexing `p.dims[xt]` below
    # also works when the user passed a plain list.
    if p.dims is None or len(p.dims) == 0:
        M = xx.shape[-1]
        p.dims = arange(M)
    else:
        p.dims = np.asarray(p.dims)
        M = len(p.dims)

    # Make periodic wrapper
    ii, wrap = viz.setup_wrapping(M, p.periodicity)

    # Set up figure, axes
    fig, ax = place.freshfig(fignum, figsize=(8, 5))
    fig.suptitle("1d amplitude plot")

    # Nans
    nan1 = wrap(nan * ones(M))

    if E is None and p.conf_mult is None:
        p.conf_mult = 2

    # Init plots
    if p.conf_mult:
        lines_s = ax.plot(ii, nan1, "b-", lw=1,
                          label=(str(p.conf_mult) + r'$\sigma$ conf'))
        lines_s += ax.plot(ii, nan1, "b-", lw=1)
        line_mu, = ax.plot(ii, nan1, 'b-', lw=2, label='DA mean')
    else:
        nanE = nan * ones((stats.xp.N, M))
        lines_E = ax.plot(ii, wrap(nanE[0]), **p.ens_props, lw=1,
                          label='Ensemble')
        lines_E += ax.plot(ii, wrap(nanE[1:]).T, **p.ens_props, lw=1)

    # Truth, Obs
    (line_x, ) = ax.plot(ii, nan1, 'k-', lw=3, label='Truth')
    if p.obs_inds is not None:
        p.obs_inds = np.asarray(p.obs_inds)
        # NOTE(review): obs are placed at x = obs_inds (state indices) while
        # the other lines use positions within p.dims; these only agree when
        # p.dims covers the full state in order -- confirm intended.
        (line_y, ) = ax.plot(p.obs_inds, nan * p.obs_inds, 'g*', ms=5,
                             label='Obs')

    # Tune plot: y-limits from the plotted components only.
    ax.set_ylim(*viz.xtrema(xx[:, p.dims]))
    ax.set_xlim(viz.stretch(ii[0], ii[-1], 1))

    # Xticks
    xt = ax.get_xticks()
    xt = xt[abs(xt % 1) < 0.01].astype(int)  # Keep only the integer ticks
    xt = xt[(xt >= 0) & (xt < len(p.dims))]
    ax.set_xticks(xt)
    ax.set_xticklabels(p.dims[xt])

    ax.set_xlabel('State index')
    ax.set_ylabel('Value')
    ax.legend(loc='upper right')

    text_t = ax.text(0.01, 0.01, format_time(None, None, None),
                     transform=ax.transAxes, family='monospace', ha='left')

    # Init visibility (must come after legend):
    if p.obs_inds is not None:
        line_y.set_visible(False)

    def update(key, E, P):
        """Redraw the lines, time label and obs markers for time ``key``."""
        k, ko, faus = key

        if p.conf_mult:
            # Row 0: mean + conf_mult*spread, row 1: mean - conf_mult*spread.
            sigma = mu[key] + p.conf_mult * stats.spread[key] * [[1], [-1]]
            lines_s[0].set_ydata(wrap(sigma[0, p.dims]))
            lines_s[1].set_ydata(wrap(sigma[1, p.dims]))
            line_mu.set_ydata(wrap(mu[key][p.dims]))
        else:
            for n, line in enumerate(lines_E):
                line.set_ydata(wrap(E[n, p.dims]))
            update_alpha(key, stats, lines_E)

        line_x.set_ydata(wrap(xx[k, p.dims]))

        text_t.set_text(format_time(k, ko, stats.HMM.tseq.tt[k]))

        # Show obs after the forecast ('f'), hide them after the update ('u').
        if 'f' in faus:
            if p.obs_inds is not None:
                line_y.set_ydata(yy[ko])
                line_y.set_zorder(5)
                line_y.set_visible(True)

        if 'u' in faus:
            if p.obs_inds is not None:
                line_y.set_visible(False)

    return update