Ejemplo n.º 1
0
def plot_gmm_3d_trace(trace_hist,
                      π,
                      μ,
                      σ,
                      title,
                      xmin,
                      xmax,
                      ax,
                      h=1,
                      n_eval=500):
    """Plot a sample trace with the target mixture density and its KDE.

    Three curves are drawn on ``ax``:

    * the trace itself (iteration index against sample value),
    * the mixture density ``sum_k π_k * N(x | μ_k, σ_k)`` in red,
    * the density estimate returned by ``pml.kdeg`` in blue.

    Both density curves are placed on the plane ``x = 0`` so that they line
    up with the "Samples" axis of the trace.

    Parameters
    ----------
    trace_hist : array
        One-dimensional history of samples; indexed as ``trace_hist[:, None]``.
    π, μ, σ : array-like
        Mixture weights, component means and component scales. They are
        broadcast against a column of evaluation points and summed over the
        last axis.
    title : str
        Title of the axes.
    xmin, xmax : float
        Range over which the densities are evaluated.
    ax : 3D axes
        Must provide ``plot``, ``set_zlim`` and ``view_init``.
    h : float, optional
        Third argument forwarded to ``pml.kdeg`` (presumably the kernel
        bandwidth -- confirm against ``pml``).
    n_eval : int, optional
        Number of evaluation points between ``xmin`` and ``xmax``.
    """
    x_eval = np.linspace(xmin, xmax, n_eval)
    kde_eval = pml.kdeg(x_eval[:, None], trace_hist[:, None], h)
    # Weighted sum of the component densities -> mixture density at x_eval.
    px = norm(μ, σ).pdf(x_eval[:, None]) * π
    px = px.sum(axis=-1)

    # Derive the iteration count from the trace instead of relying on a
    # module-level ``n_iterations`` that may be missing or out of sync.
    n_iterations = len(trace_hist)
    ax.plot(np.arange(n_iterations), trace_hist)
    ax.plot(np.zeros(n_eval), x_eval, px, c="tab:red", linewidth=2)
    ax.plot(np.zeros(n_eval), x_eval, kde_eval, c="tab:blue")

    # The vertical range is driven by the KDE curve, with 10% head-room.
    ax.set_zlim(0, kde_eval.max() * 1.1)
    ax.set_xlabel("Iterations")
    ax.set_ylabel("Samples")

    ax.view_init(25, -30)
    ax.set_title(title)
Ejemplo n.º 2
0
def plot_3d_belief_state(mu_hist,
                         dim,
                         ax,
                         skip=3,
                         npoints=2000,
                         azimuth=-30,
                         elevation=30):
    """Plot per-step density estimates of one state dimension on 3D axes.

    For every ``skip``-th time step a density curve over the samples of
    dimension ``dim`` is obtained from ``pml.kdeg`` and drawn at ``x = t``.

    Parameters
    ----------
    mu_hist : array
        History of samples. Assumed to be laid out as
        ``(steps, samples, dims)`` -- the indexing below (``[..., dim]`` and
        a reduction along axis 1) only yields plottable curves for that
        layout.
    dim : int
        Index, along the last axis of ``mu_hist``, of the dimension to plot.
    ax : 3D axes
        Must provide ``plot``, ``set_zlim``, ``set_zlabel`` and ``view_init``.
    skip : int, optional
        Stride between plotted time steps.
    npoints : int, optional
        Number of evaluation points of each density curve.
    azimuth, elevation : float, optional
        Camera angles forwarded to ``ax.view_init``.
    """
    nsteps = len(mu_hist)
    # Only dimension ``dim`` is ever drawn, so work on that slice alone.
    dim_hist = mu_hist[..., dim]
    xmin, xmax = dim_hist.min(), dim_hist.max()
    x_grid = jnp.linspace(xmin, xmax, npoints).reshape(-1, 1)

    # Evaluate the KDE only for the time steps that are plotted, rather than
    # for every step and every dimension.
    steps = range(0, nsteps, skip)
    densities = np.apply_along_axis(
        lambda X: pml.kdeg(x_grid, X[..., None], 0.5), 1, dim_hist[::skip])

    for t, px in zip(steps, densities):
        tloc = t * np.ones(npoints)
        ax.plot(tloc, x_grid, px, c="tab:blue", linewidth=1)

    ax.set_zlim(0, 1)
    pml.style3d(ax, 1.8, 1.2, 0.7, 0.8)
    ax.view_init(elevation, azimuth)
    ax.set_xlabel(r"$t$", fontsize=13)
    ax.set_ylabel(r"$x_{" f"d={dim}" ",t}$", fontsize=13)
    ax.set_zlabel(r"$p(x_{d, t} \vert y_{1:t})$", fontsize=13)