# Example #1
# 0
def make_figure(rslds, zs_rslds, x_rslds):
    """Draw a 2x2 summary figure for a fitted rSLDS model.

    Panels:
      - top-left: ``plot_most_likely_dynamics`` for ``rslds`` with the
        inferred trajectory ``x_rslds`` overlaid (using the last discrete
        state sample ``zs_rslds[-1]``);
      - bottom-left: the discrete state samples ``zs_rslds``;
      - top-right: ``plot_input_dynamics`` for ``rslds``;
      - bottom-right: ``plot_other_compartments_dynamics`` for ``rslds``.

    Relies on a module-level ``colors`` (passed to the top-left panel) and
    on ``plot_most_likely_dynamics``, ``plot_input_dynamics``,
    ``plot_other_compartments_dynamics`` and ``rplt`` being defined
    elsewhere in the file.

    Args:
        rslds: fitted model; its ``trans_distn``, ``dynamics_distns`` and
            ``num_states`` attributes are read.
        zs_rslds: sequence of discrete-state samples; the last entry is
            paired with ``x_rslds`` for the trajectory overlay.
        x_rslds: inferred continuous latent trajectory; ``x_rslds.shape[0]``
            sets the time window of the discrete-state panel.

    Returns:
        The created matplotlib figure, so callers can save or close it.
    """
    # Shared horizontal extent of the three dynamics panels.
    xlim = (-4, 8)

    fig = plt.figure(figsize=(10, 10))
    gs = gridspec.GridSpec(2, 2)

    # Top-left: most likely dynamics plus the inferred trajectory.
    ax3 = fig.add_subplot(gs[0, 0])
    plot_most_likely_dynamics(rslds.trans_distn,
                              rslds.dynamics_distns,
                              colors,
                              xlim=xlim,
                              ylim=(-.5, .5),
                              ax=ax3)
    # Overlay the full inferred trajectory with the last z sample.
    rplt.plot_trajectory(zs_rslds[-1], x_rslds, ax=ax3, ls="-")
    ax3.set_title("Inferred Dynamics (rSLDS)")

    # Bottom-left: samples of the discrete state sequence.
    ax4 = fig.add_subplot(gs[1, 0])
    rplt.plot_z_samples(rslds.num_states,
                        zs_rslds,
                        plt_slice=(0, x_rslds.shape[0]),
                        ax=ax4)
    ax4.set_title("Discrete State Samples")

    # Top-right: input-driven dynamics.
    ax5 = fig.add_subplot(gs[0, 1])
    plot_input_dynamics(rslds.trans_distn,
                        rslds.dynamics_distns,
                        xlim=xlim,
                        ylim=(-4, 10),
                        ax=ax5)

    # Bottom-right: dynamics of the remaining compartments.
    ax6 = fig.add_subplot(gs[1, 1])
    plot_other_compartments_dynamics(rslds.trans_distn,
                                     rslds.dynamics_distns,
                                     xlim=xlim,
                                     ylim=(-4, 10),
                                     ax=ax6)

    plt.tight_layout()
    return fig
# Example #2
# 0
def make_figure(true_model, z_true, x_true, y,
                rslds, zs_rslds, x_rslds,
                z_rslds_gen, x_rslds_gen,
                z_slds_gen, x_slds_gen):
    """Draw, save and show the six-panel NASCAR comparison figure.

    Panels (labelled (a)-(f) in the figure):
      (a) true latent dynamics with a segment of the true trajectory;
      (b) every observed output dimension over the same segment;
      (c) dynamics inferred by the rSLDS with its inferred trajectory;
      (d) discrete state samples from the rSLDS against ``z_true``;
      (e) the last window of states generated by the SLDS;
      (f) the last window of states generated by the rSLDS.

    Reads the module-level ``args`` for ``D_obs``, ``K`` and ``output_dir``.
    Writes ``nascar.png`` (200 dpi) and ``nascar.pdf`` into
    ``args.output_dir``, then calls ``plt.show()``.
    """
    # Number of time steps shown in each panel.
    T_plot = 1000

    fig = plt.figure(figsize=(6.5, 3.5))
    gs = gridspec.GridSpec(2, 3)

    # Bold font for the (a)-(f) panel labels.
    fp = FontProperties()
    fp.set_weight("bold")

    # (a) True dynamics with a partial trajectory overlaid.
    # NOTE(review): trajectory slices start at index 1 (skipping t=0) while
    # the z-sample panel starts at 0 -- confirm this offset is intentional.
    ax1 = fig.add_subplot(gs[0, 0])
    rplt.plot_most_likely_dynamics(true_model.trans_distn,
                                   true_model.dynamics_distns,
                                   xlim=(-3, 3), ylim=(-2, 2),
                                   ax=ax1)
    rplt.plot_trajectory(z_true[1:T_plot], x_true[1:T_plot], ax=ax1, ls="-")
    ax1.set_title("True Latent Dynamics")
    fig.text(.025, 1 - .075, '(a)', fontproperties=fp)

    # (b) Each observed output dimension over the same window.
    ax2 = fig.add_subplot(gs[1, 0])
    for n in range(args.D_obs):
        rplt.plot_data(z_true[1:T_plot], y[1:T_plot, n], ax=ax2, ls="-")
    ax2.set_xlabel("Time")
    ax2.set_ylabel("$y$")
    ax2.set_title("Observed Data")
    fig.text(.025, .5 - .075, '(b)', fontproperties=fp)

    # (c) Dynamics inferred by the rSLDS with its inferred trajectory.
    ax3 = fig.add_subplot(gs[0, 1])
    rplt.plot_most_likely_dynamics(rslds.trans_distn,
                                   rslds.dynamics_distns,
                                   xlim=(-3, 3), ylim=(-2, 2),
                                   ax=ax3)
    rplt.plot_trajectory(zs_rslds[-1][1:T_plot], x_rslds[1:T_plot],
                         ax=ax3, ls="-")
    ax3.set_title("Inferred Dynamics (rSLDS)")
    fig.text(.33 + .025, 1. - .075, '(c)', fontproperties=fp)

    # (d) Samples of the discrete state sequence against the true states.
    ax4 = fig.add_subplot(gs[1, 1])
    rplt.plot_z_samples(args.K, zs_rslds, zref=z_true,
                        plt_slice=(0, T_plot), ax=ax4)
    ax4.set_title("Discrete State Samples")
    fig.text(.33 + .025, .5 - .075, '(d)', fontproperties=fp)

    # (e) Last window of states simulated from the SLDS.
    # Grid is set on the explicit axes instead of via pyplot's implicit
    # "current axes", so it cannot land on another panel.
    ax5 = fig.add_subplot(gs[0, 2])
    rplt.plot_trajectory(z_slds_gen[-T_plot:], x_slds_gen[-T_plot:],
                         ax=ax5, ls="-")
    ax5.grid(True)
    ax5.set_title("Generated States (SLDS)")
    fig.text(.66 + .025, 1. - .075, '(e)', fontproperties=fp)

    # (f) Last window of states simulated from the rSLDS.
    ax6 = fig.add_subplot(gs[1, 2])
    rplt.plot_trajectory(z_rslds_gen[-T_plot:], x_rslds_gen[-T_plot:],
                         ax=ax6, ls="-")
    ax6.set_title("Generated States (rSLDS)")
    ax6.grid(True)
    fig.text(.66 + .025, .5 - .075, '(f)', fontproperties=fp)

    plt.tight_layout()
    plt.savefig(os.path.join(args.output_dir, "nascar.png"), dpi=200)
    plt.savefig(os.path.join(args.output_dir, "nascar.pdf"))
    plt.show()
def make_figure(true_model, z_true, x_true, y, rslds, zs_rslds, x_rslds,
                z_rslds_gen, x_rslds_gen, y_rslds_gen, slds, zs_slds, x_slds,
                z_slds_gen, x_slds_gen, y_slds_gen):
    """
    Show the following:
     - True latent dynamics (for most likely state)
     - Segment of trajectory in latent space
     - A few examples of observations in 10D space
     - ARHMM segmentation of factors
     - rSLDS segmentation of factors
     - ARHMM synthesis
     - rSLDS synthesis
    """
    # fig = plt.figure(figsize=(6.5,3.5))
    fig = plt.figure(figsize=(13, 7))
    gs = gridspec.GridSpec(2, 3)

    fp = FontProperties()
    fp.set_weight("bold")

    # True dynamics
    ax1 = fig.add_subplot(gs[0, 0])
    plot_most_likely_dynamics(true_model.trans_distn,
                              true_model.dynamics_distns,
                              xlim=(-3, 3),
                              ylim=(-2, 2),
                              ax=ax1)

    # Overlay a partial trajectory
    plot_trajectory(z_true[1:1000], x_true[1:1000], ax=ax1, ls="-")
    ax1.set_title("True Latent Dynamics")
    plt.figtext(.025, 1 - .075, '(a)', fontproperties=fp)

    # Plot a few output dimensions
    ax2 = fig.add_subplot(gs[1, 0])
    for n in range(D_obs):
        plot_data(z_true[1:1000], y[1:1000, n], ax=ax2, ls="-")
    ax2.set_xlabel("Time")
    ax2.set_ylabel("$y$")
    ax2.set_title("Observed Data")
    plt.figtext(.025, .5 - .075, '(b)', fontproperties=fp)

    # Plot the inferred dynamics under the rSLDS
    ax3 = fig.add_subplot(gs[0, 1])
    ax3_lim = 1.05 * abs(x_rslds[1:1000]).max(axis=0)
    plot_most_likely_dynamics(rslds.trans_distn,
                              rslds.dynamics_distns,
                              xlim=(-ax3_lim[0], ax3_lim[0]),
                              ylim=(-ax3_lim[1], ax3_lim[1]),
                              ax=ax3)

    # Overlay a partial trajectory
    plot_trajectory(zs_rslds[-1][1:1000], x_rslds[1:1000], ax=ax3, ls="-")
    ax3.set_title("Inferred Dynamics (rSLDS)")
    plt.figtext(.33 + .025, 1. - .075, '(c)', fontproperties=fp)

    # Plot something... z samples?
    ax4 = fig.add_subplot(gs[1, 1])
    plot_z_samples(K, zs_rslds, zref=z_true, plt_slice=(0, 1000), ax=ax4)
    ax4.set_title("Discrete State Samples")
    plt.figtext(.33 + .025, .5 - .075, '(d)', fontproperties=fp)

    # Plot simulated SLDS data
    ax5 = fig.add_subplot(gs[0, 2])
    # for n, ls in enumerate(["-", ":", "-."]):
    #     plot_data(z_slds_gen[-1000:], y_slds_gen[-1000:, n], ax=ax5, ls=ls)
    plot_trajectory(z_slds_gen[-1000:], x_slds_gen[-1000:], ax=ax5, ls="-")
    # ax5.set_xlabel("Time")
    # ax5.set_ylabel("$y$")
    plt.grid(True)
    ax5.set_title("Generated States (SLDS)")
    plt.figtext(.66 + .025, 1. - .075, '(e)', fontproperties=fp)

    # Plot simulated rSLDS data
    ax6 = fig.add_subplot(gs[1, 2])
    # for n, ls in enumerate(["-", ":", "-."]):
    #     plot_data(z_rslds_gen[-1000:], y_rslds_gen[-1000:, n], ax=ax6, ls=ls)
    # ax6.set_xlabel("Time")
    # ax6.set_ylabel("$y$")
    plot_trajectory(z_rslds_gen[-1000:], x_rslds_gen[-1000:], ax=ax6, ls="-")
    ax6.set_title("Generated States (rSLDS)")
    plt.grid(True)
    plt.figtext(.66 + .025, .5 - .075, '(f)', fontproperties=fp)

    plt.tight_layout()
    plt.savefig(os.path.join(RESULTS_DIR, "nascar.png"), dpi=200)
    plt.savefig(os.path.join(RESULTS_DIR, "nascar.pdf"))
    plt.show()
    plt.xlabel("Iteration")
    plt.ylabel("ELBO")

    # rslds, rslds_lps, rslds_z_smpls, rslds_x = \
    #     fit_rslds_vbem(inputs, y, mask,
    #                    initialization="none",
    #                    N_iters=100)

    plot_trajectory_and_probs(rslds_z_smpls[-1][1:],
                              rslds_x[1:],
                              trans_distn=rslds.trans_distn,
                              title="Recurrent SLDS")

    plot_all_dynamics(rslds.dynamics_distns)

    plot_z_samples(K, rslds_z_smpls, plt_slice=(0, 1000))

    ## Generate from the model
    rslds_y_gen, rslds_x_gen, rslds_z_gen = rslds.generate(T=T_gen,
                                                           inputs=np.ones(
                                                               (T_gen, 1)),
                                                           with_noise=True)
    slds_y_gen, slds_x_gen, slds_z_gen = slds.generate(T=T_gen,
                                                       inputs=np.ones(
                                                           (T_gen, 1)))

    make_figure(
        true_model,
        z_true,
        x_true,
        y,