def init(fignum, stats, key0, plot_u, E, P, **kwargs):
    """Create the "Phase space trajectories" figure and return its updater.

    Relies on names from an enclosing scope that is not visible here
    (``params_orig``, ``is_3d``, ``M``) as well as module-level helpers
    (``DotDict``, ``RollingArray``, ``validate_lag``, ``comp_K_plot``,
    ``place``, ``viz``, ``plt``, ...).

    Args:
        fignum: Figure identifier forwarded to ``place.freshfig``.
        stats: Object providing ``xx``, ``yy``, ``mu`` and ``HMM.tseq``.
            Having a ``w`` attribute enlarges the rolling buffers
            (see the PartFilt comment below).
        key0: Unused here.
        plot_u: If truthy, the rolling arrays are indexed by ``k``,
            otherwise by ``ko`` (see ``update``).
        E: Ensemble, or ``None``. When given, one tail per member is drawn.
        P: When not ``None``, the mean ``stats.mu`` is drawn.
        **kwargs: Overrides for the entries of ``params_orig``.

    Returns:
        ``update(key, E, P)``: callable that rolls the data arrays and
        refreshes the artists for ``key = (k, ko, faus)``.
    """
    xx, yy, mu, tseq = stats.xx, stats.yy, stats.mu, stats.HMM.tseq

    # Set parameters (kwargs takes precedence over params_orig)
    p = DotDict(
        **{kw: kwargs.get(kw, val) for kw, val in params_orig.items()})

    # Lag settings:
    has_w = hasattr(stats, 'w')
    if p.Tplot == 0:
        K_plot = 1
    else:
        _, K_lag, a_lag = validate_lag(p.Tplot, tseq)
        K_plot = comp_K_plot(K_lag, a_lag, plot_u)
        # Extend K_plot further for adding blanks in resampling (PartFilt).
        # Kept inside this branch because a_lag only exists here.
        if has_w:
            K_plot += a_lag

    # Dimension settings
    if not p.dims:
        p.dims = arange(M)
    if not p.labels:
        p.labels = ["$x_%d$" % dim for dim in p.dims]
    assert len(p.dims) == M

    # Set up figure, axes
    place.freshfig(fignum, figsize=(5, 5))
    ax = plt.subplot(111, projection='3d' if is_3d else None)
    ax.set_facecolor('w')
    ax.set_title("Phase space trajectories")

    # Tune plot
    for ind, (label, dim, axis_name) in enumerate(
            zip(p.labels, p.dims, "xyz")):
        viz.set_ilim(
            ax, ind, *viz.stretch(*viz.xtrema(xx[:, dim]), 1 / p.zoom))
        # Look the setter up by name instead of eval()-ing source text:
        # eval re-parses the label as a string literal, which corrupts
        # backslashes (LaTeX labels) and fails on embedded quotes.
        getattr(ax, "set_%slabel" % axis_name)(label)

    # Allocate
    d = DotDict()  # data arrays
    h = DotDict()  # plot handles
    s = DotDict()  # scatter handles
    if E is not None:
        d.E = RollingArray((K_plot, len(E), M))
    if P is not None:
        d.mu = RollingArray((K_plot, M))
    d.x = RollingArray((K_plot, M))
    if list(p.obs_inds) == list(p.dims):
        d.y = RollingArray((K_plot, M))

    # Plot tails (invisible coz everything here is nan, for the moment).
    if 'E' in d:
        h.E = [
            ax.plot(*xn, **p.ens_props)[0]
            for xn in np.transpose(d.E, [1, 2, 0])
        ]
    if 'mu' in d:
        h.mu = ax.plot(*d.mu.T, 'b', lw=2)[0]
    h.x = ax.plot(*d.x.T, 'k', lw=3)[0]
    if 'y' in d:
        h.y = ax.plot(*d.y.T, 'g*', ms=14)[0]

    # Scatter. NB: don't init with nan's coz it's buggy
    # (wrt. get_color() and _offsets3d) since mpl 3.1.
    if 'E' in d:
        s.E = ax.scatter(
            *E.T[p.dims], s=3**2, c=[hn.get_color() for hn in h.E])
    if 'mu' in d:
        s.mu = ax.scatter(*ones(M), s=8**2, c=[h.mu.get_color()])
    s.x = ax.scatter(
        *ones(M), s=14**2, c=[h.x.get_color()], marker=(5, 1), zorder=99)

    def update_tail(handle, newdata):
        """Point a line artist at the columns of ``newdata``."""
        handle.set_data(newdata[:, 0], newdata[:, 1])
        if is_3d:
            handle.set_3d_properties(newdata[:, 2])

    def update_sctr(handle, newdata):
        """Move a scatter artist to the rows of ``newdata``."""
        if is_3d:
            handle._offsets3d = juggle_axes(*newdata.T, 'z')
        else:
            handle.set_offsets(newdata)

    def update(key, E, P):
        """Insert the data for ``key`` and redraw tails and markers."""
        k, ko, _ = key
        show_y = 'y' in d and ko is not None

        EE = duplicate_with_blanks_for_resampled(E, p.dims, key, has_w)

        # Roll data array
        ind = k if plot_u else ko
        for Ens in EE:  # If E is duplicated, so must the others be.
            if 'E' in d:
                d.E.insert(ind, Ens)
            d.x.insert(ind, xx[k, p.dims])
            if 'y' in d:
                d.y.insert(ind, yy[ko, :] if show_y else nan * ones(M))
            if 'mu' in d:
                d.mu.insert(ind, mu[key][p.dims])

        # Update graph
        update_sctr(s.x, d.x[[-1]])
        update_tail(h.x, d.x)
        if 'y' in d:
            update_tail(h.y, d.y)
        if 'mu' in d:
            update_sctr(s.mu, d.mu[[-1]])
            update_tail(h.mu, d.mu)
        # Guard on the ensemble's own presence rather than on the absence
        # of the mean: with neither E nor P there are no ensemble artists
        # to touch, and with both, each set is refreshed.
        if 'E' in d:
            update_sctr(s.E, d.E[-1])
            for n, tail in enumerate(h.E):
                update_tail(tail, d.E[:, n, :])
            update_alpha(key, stats, h.E, s.E)

    return update
def init(fignum, stats, key0, plot_u, E, P, **kwargs):
    """Create the "Marginal time series" figure and return its updater.

    One subplot is drawn per selected state dimension. Relies on
    ``params_orig`` from an enclosing scope that is not visible here, and on
    module-level helpers (``DotDict``, ``RollingArray``, ``validate_lag``,
    ``comp_K_plot``, ``linspace_int``, ``place``, ``viz``, ``plt``, ...).

    Args:
        fignum: Figure identifier forwarded to ``place.freshfig``.
        stats: Object providing ``xx``, ``yy``, ``mu``, ``spread`` and
            ``HMM.tseq``. Having a ``w`` attribute enlarges the rolling
            buffers (see the PartFilt comment below).
        key0: Unused here.
        plot_u: If truthy, the rolling arrays are indexed by ``k``,
            otherwise by ``ko`` (see ``update``).
        E: Ensemble, or ``None``. When given, one line per member is drawn
            in every subplot.
        P: When not ``None``, the mean and the mean +/- spread are drawn.
        **kwargs: Overrides for the entries of ``params_orig``.

    Returns:
        ``update(key, E, P)``: callable that rolls the data arrays and
        refreshes every subplot for ``key = (k, ko, faus)``.
    """
    xx, yy, mu, spread, tseq = (
        stats.xx, stats.yy, stats.mu, stats.spread, stats.HMM.tseq)

    # Set parameters (kwargs takes precedence over params_orig)
    p = DotDict(
        **{kw: kwargs.get(kw, val) for kw, val in params_orig.items()})

    # Lag settings:
    T_lag, K_lag, a_lag = validate_lag(p.Tplot, tseq)
    K_plot = comp_K_plot(K_lag, a_lag, plot_u)
    # Extend K_plot further for adding blanks in resampling (PartFilt):
    has_w = hasattr(stats, 'w')
    if has_w:
        K_plot += a_lag

    # Choose marginal dims to plot
    if not p.dims:
        Nx = min(10, xx.shape[-1])
        DimsX = linspace_int(xx.shape[-1], Nx)
    else:
        Nx = len(p.dims)
        DimsX = p.dims

    # Pre-process obs dimensions
    # Positions (within obs_inds) of the observations that are plotted.
    iiY = [i for i, m in enumerate(p.obs_inds) if m in DimsX]
    # The corresponding state dimensions, in obs_inds order.
    kept_obs_dims = [m for m in p.obs_inds if m in DimsX]
    # For each plotted x-dim: its column in d.y, or None if unobserved.
    DimsY = [
        kept_obs_dims.index(m) if m in kept_obs_dims else None
        for m in DimsX
    ]
    Ny = len(iiY)

    # Set up figure, axes
    _, axs = place.freshfig(
        fignum, figsize=(5, 7), nrows=Nx, sharex=True)
    if Nx == 1:
        axs = [axs]

    # Tune plots
    axs[0].set_title("Marginal time series")
    for ix, (m, ax) in enumerate(zip(DimsX, axs)):
        if not p.labels:
            ax.set_ylabel("$x_{%d}$" % m)
        else:
            ax.set_ylabel(p.labels[ix])
    axs[-1].set_xlabel('Time (t)')

    plot_pause(0.05)
    plt.tight_layout()

    # Allocate
    d = DotDict()  # data arrays
    h = DotDict()  # plot handles
    d.t = RollingArray((K_plot, ))
    d.x = RollingArray((K_plot, Nx))
    h.x = []
    d.y = RollingArray((K_plot, Ny))
    # One entry per axis (None where that dim is unobserved), so that the
    # handle is looked up by axis index and not by the column index in d.y.
    h.y = []
    if E is not None:
        d.E = RollingArray((K_plot, len(E), Nx))
        h.E = []
    if P is not None:
        d.mu = RollingArray((K_plot, Nx))
        h.mu = []
        d.s = RollingArray((K_plot, 2, Nx))
        h.s = []

    # Plot (invisible coz everything here is nan, for the moment).
    for ix, (iy, ax) in enumerate(zip(DimsY, axs)):
        h.x += ax.plot(d.t, d.x[:, ix], 'k')
        if iy is not None:
            h.y.append(ax.plot(d.t, d.y[:, iy], 'g*', ms=10)[0])
        else:
            h.y.append(None)
        if 'E' in d:
            h.E += [ax.plot(d.t, d.E[:, :, ix], **p.ens_props)]
        if 'mu' in d:
            h.mu += ax.plot(d.t, d.mu[:, ix], 'b')
        if 's' in d:
            h.s += [ax.plot(d.t, d.s[:, :, ix], 'b--', lw=1)]

    def update(key, E, P):
        """Insert the data for ``key`` and redraw every subplot."""
        k, ko, _ = key

        EE = duplicate_with_blanks_for_resampled(E, DimsX, key, has_w)

        # Roll data array
        ind = k if plot_u else ko
        for Ens in EE:  # If E is duplicated, so must the others be.
            if 'E' in d:
                d.E.insert(ind, Ens)
            if 'mu' in d:
                d.mu.insert(ind, mu[key][DimsX])
            if 's' in d:
                d.s.insert(
                    ind, mu[key][DimsX] + [[1], [-1]] * spread[key][DimsX])
            d.t.insert(ind, tseq.tt[k])
            d.y.insert(
                ind, yy[ko, iiY] if ko is not None else nan * ones(Ny))
            d.x.insert(ind, xx[k, DimsX])

        # Update graphs
        for ix, (iy, ax) in enumerate(zip(DimsY, axs)):
            sliding_xlim(ax, d.t, T_lag, True)
            h.x[ix].set_data(d.t, d.x[:, ix])
            if iy is not None:
                h.y[ix].set_data(d.t, d.y[:, iy])
            if 'mu' in d:
                h.mu[ix].set_data(d.t, d.mu[:, ix])
            if 's' in d:
                for b in (0, 1):
                    h.s[ix][b].set_data(d.t, d.s[:, b, ix])
            if 'E' in d:
                for n, member_line in enumerate(h.E[ix]):
                    member_line.set_data(d.t, d.E[:, n, ix])
                update_alpha(key, stats, h.E[ix])

            # TODO 3: fixup. This might be slow?
            # In any case, it is very far from tested.
            # Also, relim'iting all of the time is distracting.
            # Use d_ylim?
            if 'E' in d:
                lims = d.E
            elif 'mu' in d:
                lims = d.mu
            else:
                # Neither ensemble nor mean: scale to the truth instead
                # of hitting an unbound `lims`.
                lims = d.x
            lims = np.array(viz.xtrema(lims[..., ix]))
            if lims[0] == lims[1]:
                # Not in-place, so an integer-typed pair cannot trip a
                # casting error.
                lims = lims + [-.5, +.5]
            ax.set_ylim(*viz.stretch(*lims, 1 / p.zoomy))

    return update