Ejemplo n.º 1
0
def plot_svgp_and_start_end(gp, solver, traj_opt=None):
    """Plot SVGP predictive mean/variance with start/end positions and trajectories.

    Predicts the sparse GP ``gp`` on a grid built from ``gp.X``, draws the mean
    and variance via ``plot_mean_and_var`` and, on every returned axis, overlays
    the solver's start/end positions, its initial state guesses and (optionally)
    an optimised trajectory.

    Args:
        gp: object exposing ``X``, ``Z``, ``kernel``, ``mean_func``, ``q_mu`` and
            ``q_sqrt`` (passed through to ``create_grid`` / ``gp_predict``).
        solver: object handed to ``plot_start_and_end_pos``; its
            ``state_guesses`` are drawn as the initial trajectory.
        traj_opt: optional optimised trajectory; drawn only when not ``None``.

    Returns:
        The ``(fig, axs)`` pair produced by ``plot_mean_and_var``.
    """
    # Matplotlib >= 3.3 requires ``text.latex.preamble`` to be a single string
    # (list values were deprecated, then rejected). A newline-joined string is
    # accepted by both older and newer versions.
    params = {
        "text.usetex": True,
        "text.latex.preamble": "\n".join(
            [
                r"\usepackage{amssymb}",
                r"\usepackage{amsmath}",
            ]
        ),
    }
    # NOTE: this updates the global rcParams, so it affects later figures too.
    plt.rcParams.update(params)

    Xnew, xx, yy = create_grid(gp.X, N=961)
    mu, var = gp_predict(
        Xnew,
        gp.Z,
        kernels=gp.kernel,
        mean_funcs=gp.mean_func,
        f=gp.q_mu,
        q_sqrt=gp.q_sqrt,
        full_cov=False,
    )
    fig, axs = plot_mean_and_var(
        xx,
        yy,
        mu,
        var,
        # Raw strings: "\m" is an invalid escape sequence in a normal literal.
        llabel=r"$\mathbb{E}[h^{(1)}(\mathbf{x})]$",
        rlabel=r"$\mathbb{V}[h^{(1)}(\mathbf{x})]$",
    )

    for ax in axs:
        fig, ax = plot_start_and_end_pos(fig, ax, solver)
        plot_traj(
            fig,
            ax,
            solver.state_guesses,
            color=color_init,
            label="Initial trajectory",
        )
        if traj_opt is not None:
            plot_traj(
                fig,
                ax,
                traj_opt,
                color=color_opt,
                label="Optimised trajectory",
            )
    # A single legend on the first axis; the trajectory labels are identical
    # on every axis.
    axs[0].legend()
    return fig, axs
Ejemplo n.º 2
0
def plot_gp_and_start_end(gp, solver):
    """Plot a GP's predictive mean and variance with start/end positions overlaid.

    The grid is built from ``gp.X``. ``gp_predict`` receives ``gp.X`` as its
    second positional argument and ``f=gp.Y``.

    Args:
        gp: object exposing ``X``, ``Y``, ``kernel``, ``mean_func`` and ``q_sqrt``.
        solver: object handed to ``plot_start_and_end_pos`` for every axis.

    Returns:
        The ``(fig, axs)`` pair produced by ``plot_mean_and_var``.
    """
    grid, grid_xx, grid_yy = create_grid(gp.X, N=961)
    predict_kwargs = dict(
        kernels=gp.kernel,
        mean_funcs=gp.mean_func,
        f=gp.Y,
        q_sqrt=gp.q_sqrt,
        full_cov=False,
    )
    mean, variance = gp_predict(grid, gp.X, **predict_kwargs)
    fig, axes = plot_mean_and_var(grid_xx, grid_yy, mean, variance)

    for axis in axes:
        fig, axis = plot_start_and_end_pos(fig, axis, solver)
    return fig, axes
def plot_ode_svgp_quad():
    """Diagnostic plot: SVGP mean/variance with a geodesic-ODE arrow field.

    Predicts the ``FakeSVGPQuad`` fixture on a grid (full covariance, reduced to
    its diagonal for plotting). It then evaluates ``geodesic_ode`` at every grid
    point, using the point concatenated with the collocation fixture's
    ``vel_init_guess`` as the state. Components 2 and 3 of each resulting state
    derivative are drawn as arrows on both axes, and the figure is shown.
    Intermediate shapes and values are printed for debugging.
    """
    from ProbGeo.gp import gp_predict
    from ProbGeo.metric_tensor import gp_metric_tensor
    from ProbGeo.visualisation.gp import plot_mean_and_var
    from ProbGeo.visualisation.utils import create_grid

    gp = FakeSVGPQuad
    grid, grid_xx, grid_yy = create_grid(gp.X, N=961)
    mean, full_var = gp_predict(
        grid,
        gp.Z,
        kernel=gp.kernel,
        mean_func=gp.mean_func,
        f=gp.q_mu,
        q_sqrt=gp.q_sqrt,
        full_cov=True,
    )
    # Keep only the diagonal of the full covariance for the variance panel.
    # NOTE(review): assumes full_var is a 2-D covariance matrix — TODO confirm.
    marginal_var = np.diag(full_var)
    fig, axes = plot_mean_and_var(grid_xx, grid_yy, mean, marginal_var)

    ode = FakeCollocationSVGPQuad()
    metric = FakeSVGPMetricQuad()

    def state_derivative(position):
        # Debug output is emitted while jax traces this function under vmap.
        print('inside ode')
        print(position.shape)
        print(ode.vel_init_guess.shape)
        full_state = np.concatenate([position, ode.vel_init_guess])
        print('after concat')
        print(full_state.shape)
        return geodesic_ode(
            ode.times, full_state, gp_metric_tensor, metric.metric_fn_kwargs
        )

    derivatives = jax.vmap(state_derivative)(grid)
    print('state primes')
    print(derivatives.shape)
    print(derivatives)

    for axis in axes:
        axis.quiver(grid[:, 0], grid[:, 1], derivatives[:, 2], derivatives[:, 3])
    plt.show()
Ejemplo n.º 4
0
def plot_svgp_and_start_end(gp, solver):
    """Plot SVGP predictive mean/variance on a grid and mark start/end positions.

    Args:
        gp: object exposing ``X``, ``Z``, ``kernel``, ``mean_func``, ``q_mu``
            and ``q_sqrt`` (forwarded to ``create_grid`` / ``gp_predict``).
        solver: object handed to ``plot_start_and_end_pos`` for every axis.

    Returns:
        The ``(fig, axs)`` pair produced by ``plot_mean_and_var``.
    """
    from ProbGeo.gp import gp_predict
    from ProbGeo.visualisation.gp import plot_mean_and_var
    from ProbGeo.visualisation.utils import create_grid

    grid, grid_xx, grid_yy = create_grid(gp.X, N=961)
    predict_kwargs = dict(
        kernel=gp.kernel,
        mean_func=gp.mean_func,
        f=gp.q_mu,
        q_sqrt=gp.q_sqrt,
        full_cov=False,
    )
    mean, variance = gp_predict(grid, gp.Z, **predict_kwargs)
    fig, axes = plot_mean_and_var(grid_xx, grid_yy, mean, variance)

    for axis in axes:
        fig, axis = plot_start_and_end_pos(fig, axis, solver)
    return fig, axes
Ejemplo n.º 5
0
@tf.function
def tf_optimization_step():
    """Apply one optimizer update to the model's trainable variables.

    Uses the module-level ``optimizer``, ``training_loss`` (passed as the loss
    callable) and ``m``; wrapped in ``tf.function`` for graph execution.
    """
    trainable = m.trainable_variables
    optimizer.minimize(training_loss, trainable)


# Training loop: each epoch runs ``num_batches_per_epoch`` optimizer steps and
# the training loss is printed every ``logging_epoch_freq`` epochs.
for epoch in range(epochs):
    for _ in range(num_batches_per_epoch):
        tf_optimization_step()
        # tf_optimization_step(model, training_loss, optimizer)
    epoch_id = epoch + 1  # 1-based epoch number used for logging
    if epoch_id % logging_epoch_freq == 0:
        # NOTE(review): the label says "ELBO", but ``training_loss`` is the
        # quantity being minimised; if it is GPflow's training loss it is the
        # *negative* ELBO. Confirm before interpreting the sign.
        tf.print(f"Epoch {epoch_id}: ELBO (train) {training_loss()}")

gpf.utilities.print_summary(m)
# Plot predictions on the grid (Xnew / xx / yy are defined elsewhere): first
# predict_y, then predict_f. In GPflow, predict_y includes the likelihood and
# predict_f is the latent function. Each call overwrites mu/var/fig/axs.
mu, var = m.predict_y(Xnew)
fig, axs = plot_mean_and_var(xx, yy, mu.numpy(), var.numpy())
mu, var = m.predict_f(Xnew)
fig, axs = plot_mean_and_var(xx, yy, mu.numpy(), var.numpy())
plt.show()

# Pull the trained kernel hyperparameters out as NumPy arrays.
lengthscales = m.kernel.lengthscales.numpy()
variance = m.kernel.variance.numpy()

# Variational parameters, inducing inputs and mean-function parameter, also as
# NumPy arrays, for the np.savez call that follows.
# NOTE(review): ``mean_function.c`` presumably belongs to a GPflow Constant
# mean function — confirm.
q_mu = m.q_mu.numpy()
q_sqrt = m.q_sqrt.numpy()
z = m.inducing_variable.Z.numpy()
mean_func = m.mean_function.c.numpy()
np.savez(
    save_params_filename,
    l=lengthscales,
Ejemplo n.º 6
0
def plot_svgp_jacobian_mean(gp, solver, traj_opt=None):
    """Plot SVGP mean/variance with the GP Jacobian mean drawn as arrows.

    Predicts the sparse GP ``gp`` on a grid built from ``gp.X``. It evaluates
    ``gp_jacobian`` at every grid point (via ``jax.vmap``) and draws the first
    two components of the Jacobian mean as black arrows on each axis. It also
    overlays start/end positions, omitted data, the solver's initial trajectory
    and, optionally, an optimised trajectory.

    Args:
        gp: object exposing ``X``, ``Z``, ``kernel``, ``mean_func``, ``q_mu`` and
            ``q_sqrt``.
        solver: object handed to ``plot_start_and_end_pos``; its
            ``state_guesses`` are drawn as the initial trajectory.
        traj_opt: optional optimised trajectory; drawn only when not ``None``.

    Returns:
        The ``(fig, axs)`` pair produced by ``plot_mean_and_var``.
    """
    # Matplotlib >= 3.3 requires ``text.latex.preamble`` to be a single string
    # (list values were deprecated, then rejected). A newline-joined string is
    # accepted by both older and newer versions.
    params = {
        "text.usetex": True,
        "text.latex.preamble": "\n".join(
            [
                r"\usepackage{amssymb}",
                r"\usepackage{amsmath}",
            ]
        ),
    }
    # NOTE: this updates the global rcParams, so it affects later figures too.
    plt.rcParams.update(params)

    Xnew, xx, yy = create_grid(gp.X, N=961)
    mu, var = gp_predict(
        Xnew,
        gp.Z,
        kernels=gp.kernel,
        mean_funcs=gp.mean_func,
        f=gp.q_mu,
        q_sqrt=gp.q_sqrt,
        full_cov=False,
    )

    def gp_jacobian_all(x):
        """Evaluate ``gp_jacobian`` at ``x``, reshaping a 1-D input to one row."""
        if len(x.shape) == 1:
            x = x.reshape(1, -1)
        return gp_jacobian(
            x,
            gp.Z,
            gp.kernel,
            gp.mean_func,
            f=gp.q_mu,
            q_sqrt=gp.q_sqrt,
            full_cov=False,
        )

    # vmap over the rows of Xnew; only the Jacobian mean is plotted.
    mu_j, _ = jax.vmap(gp_jacobian_all, in_axes=0)(Xnew)
    fig, axs = plot_mean_and_var(
        xx,
        yy,
        mu,
        var,
        # Raw strings: "\m" is an invalid escape sequence in a normal literal.
        llabel=r"$\mathbb{E}[h^{(1)}]$",
        rlabel=r"$\mathbb{V}[h^{(1)}]$",
    )

    for ax in axs:
        # NOTE(review): assumes mu_j[:, 0] / mu_j[:, 1] are the per-point
        # components along the two input dimensions — TODO confirm the shape
        # returned by gp_jacobian.
        ax.quiver(Xnew[:, 0], Xnew[:, 1], mu_j[:, 0], mu_j[:, 1], color="k")

        fig, ax = plot_start_and_end_pos(fig, ax, solver)
        plot_omitted_data(fig, ax, color="k")
        plot_traj(
            fig,
            ax,
            solver.state_guesses,
            color=color_init,
            label="Initial trajectory",
        )
        if traj_opt is not None:
            plot_traj(
                fig,
                ax,
                traj_opt,
                color=color_opt,
                label="Optimised trajectory",
            )
    # A single legend on the first axis; the trajectory labels are identical
    # on every axis.
    axs[0].legend()
    return fig, axs