def make_figure(rslds, zs_rslds, x_rslds):
    """Lay out a 2x2 summary figure for a fitted rSLDS (not saved or shown).

    Panels:
      [0, 0] ``plot_most_likely_dynamics`` for the model, with the inferred
             latent trajectory ``x_rslds`` overlaid using the last discrete
             sample ``zs_rslds[-1]``;
      [1, 0] the discrete state samples ``zs_rslds`` over the first
             ``x_rslds.shape[0]`` time steps;
      [0, 1] ``plot_input_dynamics`` for the model;
      [1, 1] ``plot_other_compartments_dynamics`` for the model.

    Args:
        rslds: fitted model; ``trans_distn``, ``dynamics_distns`` and
            ``num_states`` are read from it.
        zs_rslds: sequence of discrete state samples; the last element is
            used for the trajectory overlay.
        x_rslds: inferred continuous latent states; ``shape[0]`` is used as
            the number of time steps shown in the state-sample panel.
    """
    fig = plt.figure(figsize=(10, 10))
    gs = gridspec.GridSpec(2, 2)

    # Inferred dynamics, with the full inferred trajectory overlaid.
    # NOTE(review): ``colors`` is a module-level name not defined in this block.
    ax3 = fig.add_subplot(gs[0, 0])
    plot_most_likely_dynamics(rslds.trans_distn, rslds.dynamics_distns, colors,
                              xlim=(-4, 8), ylim=(-.5, .5), ax=ax3)
    rplt.plot_trajectory(zs_rslds[-1], x_rslds, ax=ax3, ls="-")
    ax3.set_title("Inferred Dynamics (rSLDS)")

    # Samples of the discrete state sequence over the whole trajectory.
    ax4 = fig.add_subplot(gs[1, 0])
    rplt.plot_z_samples(rslds.num_states, zs_rslds,
                        plt_slice=(0, x_rslds.shape[0]), ax=ax4)
    ax4.set_title("Discrete State Samples")

    # Model-specific views of the dynamics (helpers defined elsewhere).
    ax5 = fig.add_subplot(gs[0, 1])
    plot_input_dynamics(rslds.trans_distn, rslds.dynamics_distns,
                        xlim=(-4, 8), ylim=(-4, 10), ax=ax5)

    ax6 = fig.add_subplot(gs[1, 1])
    plot_other_compartments_dynamics(rslds.trans_distn, rslds.dynamics_distns,
                                     xlim=(-4, 8), ylim=(-4, 10), ax=ax6)

    fig.tight_layout()
def make_figure(true_model, z_true, x_true, y, rslds, zs_rslds, x_rslds,
                z_rslds_gen, x_rslds_gen, z_slds_gen, x_slds_gen):
    """Draw, save and show the 2x3 "nascar" comparison figure.

    Panels (labelled with bold letters in figure coordinates):
      (a) most likely dynamics of ``true_model`` with the true trajectory
          (time steps 1..999) overlaid;
      (b) each of the ``args.D_obs`` columns of ``y`` over steps 1..999,
          plotted with ``rplt.plot_data`` alongside ``z_true``;
      (c) most likely dynamics of ``rslds`` with the inferred trajectory
          (steps 1..999 of ``x_rslds`` and of the last discrete sample
          ``zs_rslds[-1]``) overlaid;
      (d) ``rslds`` discrete state samples with ``z_true`` as reference,
          passed ``plt_slice=(0, 1000)``;
      (e) the last 1000 latent states generated from the SLDS;
      (f) the last 1000 latent states generated from the rSLDS.

    Side effects: writes ``nascar.png`` (200 dpi) and ``nascar.pdf`` into
    ``args.output_dir`` and calls ``plt.show()``.

    Relies on the module-level ``args`` (``D_obs``, ``K``, ``output_dir``)
    and the ``rplt`` plotting helpers.
    """
    fig = plt.figure(figsize=(6.5, 3.5))
    gs = gridspec.GridSpec(2, 3)
    fp = FontProperties()
    fp.set_weight("bold")

    # Number of time steps covered by every time-windowed panel.
    n_plot = 1000

    # (a) True dynamics with a partial true trajectory overlaid.
    ax1 = fig.add_subplot(gs[0, 0])
    rplt.plot_most_likely_dynamics(true_model.trans_distn, true_model.dynamics_distns,
                                   xlim=(-3, 3), ylim=(-2, 2), ax=ax1)
    rplt.plot_trajectory(z_true[1:n_plot], x_true[1:n_plot], ax=ax1, ls="-")
    ax1.set_title("True Latent Dynamics")

    # (b) Every observed output dimension.
    ax2 = fig.add_subplot(gs[1, 0])
    for n in range(args.D_obs):
        rplt.plot_data(z_true[1:n_plot], y[1:n_plot, n], ax=ax2, ls="-")
    ax2.set_xlabel("Time")
    ax2.set_ylabel("$y$")
    ax2.set_title("Observed Data")

    # (c) Inferred rSLDS dynamics with a partial inferred trajectory overlaid.
    ax3 = fig.add_subplot(gs[0, 1])
    rplt.plot_most_likely_dynamics(rslds.trans_distn, rslds.dynamics_distns,
                                   xlim=(-3, 3), ylim=(-2, 2), ax=ax3)
    rplt.plot_trajectory(zs_rslds[-1][1:n_plot], x_rslds[1:n_plot], ax=ax3, ls="-")
    ax3.set_title("Inferred Dynamics (rSLDS)")

    # (d) Discrete state samples compared against the true states.
    ax4 = fig.add_subplot(gs[1, 1])
    rplt.plot_z_samples(args.K, zs_rslds, zref=z_true, plt_slice=(0, n_plot), ax=ax4)
    ax4.set_title("Discrete State Samples")

    # (e) Tail of a sequence simulated from the SLDS.
    ax5 = fig.add_subplot(gs[0, 2])
    rplt.plot_trajectory(z_slds_gen[-n_plot:], x_slds_gen[-n_plot:], ax=ax5, ls="-")
    # Target the axes explicitly: plt.grid() acts on whichever axes is "current".
    ax5.grid(True)
    ax5.set_title("Generated States (SLDS)")

    # (f) Tail of a sequence simulated from the rSLDS.
    ax6 = fig.add_subplot(gs[1, 2])
    rplt.plot_trajectory(z_rslds_gen[-n_plot:], x_rslds_gen[-n_plot:], ax=ax6, ls="-")
    ax6.set_title("Generated States (rSLDS)")
    ax6.grid(True)

    # Bold panel letters: (text, left edge of column, top of row) in figure coords.
    panel_labels = [("(a)", 0.0, 1.0), ("(b)", 0.0, 0.5),
                    ("(c)", 0.33, 1.0), ("(d)", 0.33, 0.5),
                    ("(e)", 0.66, 1.0), ("(f)", 0.66, 0.5)]
    for text, left, top in panel_labels:
        fig.text(left + .025, top - .075, text, fontproperties=fp)

    fig.tight_layout()
    fig.savefig(os.path.join(args.output_dir, "nascar.png"), dpi=200)
    fig.savefig(os.path.join(args.output_dir, "nascar.pdf"))
    plt.show()
def make_figure(true_model, z_true, x_true, y,
                rslds, zs_rslds, x_rslds,
                z_rslds_gen, x_rslds_gen, y_rslds_gen,
                slds, zs_slds, x_slds,
                z_slds_gen, x_slds_gen, y_slds_gen):
    """Draw, save and show the 2x3 "nascar" comparison figure.

    Panels (labelled with bold letters in figure coordinates):
      (a) most likely dynamics of ``true_model`` with the true latent
          trajectory (time steps 1..999) overlaid;
      (b) each of the ``D_obs`` columns of ``y`` over steps 1..999, plotted
          with ``plot_data`` alongside ``z_true``;
      (c) most likely dynamics of ``rslds`` with the inferred trajectory
          (steps 1..999 of ``x_rslds`` and of ``zs_rslds[-1]``) overlaid.
          The x/y limits are 5% beyond the largest absolute value of latent
          dimensions 0 and 1 of ``x_rslds`` within that window;
      (d) ``rslds`` discrete state samples with ``z_true`` as reference,
          passed ``plt_slice=(0, 1000)``;
      (e) the last 1000 latent states generated from the SLDS;
      (f) the last 1000 latent states generated from the rSLDS.

    Side effects: writes ``nascar.png`` (200 dpi) and ``nascar.pdf`` into
    ``RESULTS_DIR`` and calls ``plt.show()``.

    ``y_rslds_gen``, ``slds``, ``zs_slds``, ``x_slds`` and ``y_slds_gen``
    are currently unused; they are kept so existing call sites keep working.
    Relies on the module-level names ``D_obs``, ``K`` and ``RESULTS_DIR``.
    """
    fig = plt.figure(figsize=(13, 7))
    gs = gridspec.GridSpec(2, 3)
    fp = FontProperties()
    fp.set_weight("bold")

    # Number of time steps covered by every time-windowed panel.
    n_plot = 1000

    # (a) True dynamics with a partial true trajectory overlaid.
    ax1 = fig.add_subplot(gs[0, 0])
    plot_most_likely_dynamics(true_model.trans_distn, true_model.dynamics_distns,
                              xlim=(-3, 3), ylim=(-2, 2), ax=ax1)
    plot_trajectory(z_true[1:n_plot], x_true[1:n_plot], ax=ax1, ls="-")
    ax1.set_title("True Latent Dynamics")

    # (b) Every observed output dimension.
    ax2 = fig.add_subplot(gs[1, 0])
    for n in range(D_obs):
        plot_data(z_true[1:n_plot], y[1:n_plot, n], ax=ax2, ls="-")
    ax2.set_xlabel("Time")
    ax2.set_ylabel("$y$")
    ax2.set_title("Observed Data")

    # (c) Inferred rSLDS dynamics; limits fit the plotted part of the trajectory.
    # NOTE(review): assumes x_rslds is a numpy array with >= 2 latent dims.
    ax3 = fig.add_subplot(gs[0, 1])
    ax3_lim = 1.05 * abs(x_rslds[1:n_plot]).max(axis=0)
    plot_most_likely_dynamics(rslds.trans_distn, rslds.dynamics_distns,
                              xlim=(-ax3_lim[0], ax3_lim[0]),
                              ylim=(-ax3_lim[1], ax3_lim[1]), ax=ax3)
    plot_trajectory(zs_rslds[-1][1:n_plot], x_rslds[1:n_plot], ax=ax3, ls="-")
    ax3.set_title("Inferred Dynamics (rSLDS)")

    # (d) Discrete state samples compared against the true states.
    ax4 = fig.add_subplot(gs[1, 1])
    plot_z_samples(K, zs_rslds, zref=z_true, plt_slice=(0, n_plot), ax=ax4)
    ax4.set_title("Discrete State Samples")

    # (e) Tail of a latent sequence simulated from the SLDS.
    ax5 = fig.add_subplot(gs[0, 2])
    plot_trajectory(z_slds_gen[-n_plot:], x_slds_gen[-n_plot:], ax=ax5, ls="-")
    # Target the axes explicitly: plt.grid() acts on whichever axes is "current".
    ax5.grid(True)
    ax5.set_title("Generated States (SLDS)")

    # (f) Tail of a latent sequence simulated from the rSLDS.
    ax6 = fig.add_subplot(gs[1, 2])
    plot_trajectory(z_rslds_gen[-n_plot:], x_rslds_gen[-n_plot:], ax=ax6, ls="-")
    ax6.set_title("Generated States (rSLDS)")
    ax6.grid(True)

    # Bold panel letters: (text, left edge of column, top of row) in figure coords.
    panel_labels = [("(a)", 0.0, 1.0), ("(b)", 0.0, 0.5),
                    ("(c)", 0.33, 1.0), ("(d)", 0.33, 0.5),
                    ("(e)", 0.66, 1.0), ("(f)", 0.66, 0.5)]
    for text, left, top in panel_labels:
        fig.text(left + .025, top - .075, text, fontproperties=fp)

    fig.tight_layout()
    fig.savefig(os.path.join(RESULTS_DIR, "nascar.png"), dpi=200)
    fig.savefig(os.path.join(RESULTS_DIR, "nascar.pdf"))
    plt.show()
plt.xlabel("Iteration") plt.ylabel("ELBO") # rslds, rslds_lps, rslds_z_smpls, rslds_x = \ # fit_rslds_vbem(inputs, y, mask, # initialization="none", # N_iters=100) plot_trajectory_and_probs(rslds_z_smpls[-1][1:], rslds_x[1:], trans_distn=rslds.trans_distn, title="Recurrent SLDS") plot_all_dynamics(rslds.dynamics_distns) plot_z_samples(K, rslds_z_smpls, plt_slice=(0, 1000)) ## Generate from the model rslds_y_gen, rslds_x_gen, rslds_z_gen = rslds.generate(T=T_gen, inputs=np.ones( (T_gen, 1)), with_noise=True) slds_y_gen, slds_x_gen, slds_z_gen = slds.generate(T=T_gen, inputs=np.ones( (T_gen, 1))) make_figure( true_model, z_true, x_true, y,