def plot_svgp_and_start_end(gp, solver, traj_opt=None):
    params = {
        "text.usetex": True,
        "text.latex.preamble": [
            "\\usepackage{amssymb}",
            "\\usepackage{amsmath}",
        ],
    }
    plt.rcParams.update(params)

    # Predict the SVGP posterior mean and variance over a grid spanning the data.
    Xnew, xx, yy = create_grid(gp.X, N=961)
    mu, var = gp_predict(
        Xnew,
        gp.Z,
        kernels=gp.kernel,
        mean_funcs=gp.mean_func,
        f=gp.q_mu,
        q_sqrt=gp.q_sqrt,
        full_cov=False,
    )

    fig, axs = plot_mean_and_var(
        xx,
        yy,
        mu,
        var,
        llabel=r"$\mathbb{E}[h^{(1)}(\mathbf{x})]$",
        rlabel=r"$\mathbb{V}[h^{(1)}(\mathbf{x})]$",
    )
    for ax in axs:
        fig, ax = plot_start_and_end_pos(fig, ax, solver)
        plot_traj(
            fig,
            ax,
            solver.state_guesses,
            color=color_init,
            label="Initial trajectory",
        )
        if traj_opt is not None:
            plot_traj(
                fig,
                ax,
                traj_opt,
                color=color_opt,
                label="Optimised trajectory",
            )
    axs[0].legend()
    return fig, axs
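# Usage sketch for the helper above (an assumption, not from the original
# code): it reuses the Fake* fixtures that plot_ode_svgp_quad relies on below
# and assumes FakeCollocationSVGPQuad carries the start/end positions and
# state_guesses expected by plot_start_and_end_pos and plot_traj. The output
# filename is hypothetical.
def example_plot_svgp_and_start_end():
    gp = FakeSVGPQuad
    solver = FakeCollocationSVGPQuad()
    fig, axs = plot_svgp_and_start_end(gp, solver, traj_opt=None)
    fig.savefig("svgp_start_end.pdf", bbox_inches="tight")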
def plot_gp_and_start_end(gp, solver):
    Xnew, xx, yy = create_grid(gp.X, N=961)
    mu, var = gp_predict(
        Xnew,
        gp.X,
        kernels=gp.kernel,
        mean_funcs=gp.mean_func,
        f=gp.Y,
        q_sqrt=gp.q_sqrt,
        full_cov=False,
    )
    fig, axs = plot_mean_and_var(xx, yy, mu, var)
    for ax in axs:
        fig, ax = plot_start_and_end_pos(fig, ax, solver)
    return fig, axs
def plot_ode_svgp_quad():
    from ProbGeo.gp import gp_predict
    from ProbGeo.metric_tensor import gp_metric_tensor
    from ProbGeo.visualisation.gp import plot_mean_and_var
    from ProbGeo.visualisation.utils import create_grid

    # Plot the SVGP posterior over a grid; predict the full covariance and
    # keep only its diagonal for the variance panel.
    gp = FakeSVGPQuad
    Xnew, xx, yy = create_grid(gp.X, N=961)
    mu, var = gp_predict(
        Xnew,
        gp.Z,
        kernel=gp.kernel,
        mean_func=gp.mean_func,
        f=gp.q_mu,
        q_sqrt=gp.q_sqrt,
        full_cov=True,
    )
    var = np.diag(var)
    fig, axs = plot_mean_and_var(xx, yy, mu, var)

    ode = FakeCollocationSVGPQuad()
    metric = FakeSVGPMetricQuad()
    metric_fn = gp_metric_tensor

    def ode_fn(pos_init):
        # Evaluate the geodesic ODE right-hand side at a grid position,
        # reusing the same initial velocity guess for every point.
        state_init = np.concatenate([pos_init, ode.vel_init_guess])
        state_prime = geodesic_ode(
            ode.times, state_init, metric_fn, metric.metric_fn_kwargs
        )
        return state_prime

    state_primes = jax.vmap(ode_fn)(Xnew)

    # Overlay the velocity derivatives (state dimensions 2 and 3) as a quiver
    # plot on both panels.
    for ax in axs:
        ax.quiver(
            Xnew[:, 0], Xnew[:, 1], state_primes[:, 2], state_primes[:, 3]
        )
    plt.show()
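# Rough sketch of integrating a single geodesic, rather than only evaluating
# the ODE right-hand side as above. Assumptions (not taken from the original
# code): geodesic_ode accepts a scalar time and a flat [position, velocity]
# state vector, matching scipy's fun(t, y) convention, and the fixture
# exposes pos_init_guess, vel_init_guess and times; adjust names as needed.
def example_integrate_single_geodesic():
    from scipy.integrate import solve_ivp

    from ProbGeo.metric_tensor import gp_metric_tensor

    ode = FakeCollocationSVGPQuad()
    metric = FakeSVGPMetricQuad()
    state_init = np.concatenate([ode.pos_init_guess, ode.vel_init_guess])
    sol = solve_ivp(
        lambda t, state: geodesic_ode(
            t, state, gp_metric_tensor, metric.metric_fn_kwargs
        ),
        t_span=(ode.times[0], ode.times[-1]),
        y0=state_init,
        t_eval=ode.times,
    )
    # The first two state dimensions are positions (assumed ordering).
    plt.plot(sol.y[0, :], sol.y[1, :], "k-")
    plt.show()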
def plot_svgp_and_start_end(gp, solver):
    from ProbGeo.gp import gp_predict
    from ProbGeo.visualisation.gp import plot_mean_and_var
    from ProbGeo.visualisation.utils import create_grid

    Xnew, xx, yy = create_grid(gp.X, N=961)
    mu, var = gp_predict(
        Xnew,
        gp.Z,
        kernel=gp.kernel,
        mean_func=gp.mean_func,
        f=gp.q_mu,
        q_sqrt=gp.q_sqrt,
        full_cov=False,
    )
    fig, axs = plot_mean_and_var(xx, yy, mu, var)
    for ax in axs:
        fig, ax = plot_start_and_end_pos(fig, ax, solver)
    return fig, axs
@tf.function
def tf_optimization_step():
    optimizer.minimize(training_loss, m.trainable_variables)


for epoch in range(epochs):
    for _ in range(num_batches_per_epoch):
        tf_optimization_step()
    epoch_id = epoch + 1
    if epoch_id % logging_epoch_freq == 0:
        tf.print(f"Epoch {epoch_id}: ELBO (train) {training_loss()}")

gpf.utilities.print_summary(m)

# Plot the predictive posterior over observations and over the latent function.
mu, var = m.predict_y(Xnew)
fig, axs = plot_mean_and_var(xx, yy, mu.numpy(), var.numpy())
mu, var = m.predict_f(Xnew)
fig, axs = plot_mean_and_var(xx, yy, mu.numpy(), var.numpy())
plt.show()

# Extract the trained SVGP parameters and save them to disk.
lengthscales = m.kernel.lengthscales.numpy()
variance = m.kernel.variance.numpy()
q_mu = m.q_mu.numpy()
q_sqrt = m.q_sqrt.numpy()
z = m.inducing_variable.Z.numpy()
mean_func = m.mean_function.c.numpy()
np.savez(
    save_params_filename,
    l=lengthscales,
    # NOTE: the original call was truncated after `l=lengthscales`; the key
    # names below are assumed from the variables gathered above.
    var=variance,
    q_mu=q_mu,
    q_sqrt=q_sqrt,
    z=z,
    mean_func=mean_func,
)
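# Sketch of restoring the saved parameters later (e.g. to build the
# FakeSVGPQuad-style containers used by the plotting helpers). The key names
# mirror the assumed keys in the np.savez call above; adjust if they differ.
def load_svgp_params(filename):
    import numpy

    params = numpy.load(filename)
    return {
        "lengthscales": params["l"],
        "variance": params["var"],
        "q_mu": params["q_mu"],
        "q_sqrt": params["q_sqrt"],
        "Z": params["z"],
        "mean_func": params["mean_func"],
    }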
def plot_svgp_jacobian_mean(gp, solver, traj_opt=None):
    params = {
        "text.usetex": True,
        "text.latex.preamble": [
            "\\usepackage{amssymb}",
            "\\usepackage{amsmath}",
        ],
    }
    plt.rcParams.update(params)

    # Posterior mean and variance of the SVGP over a grid spanning the data.
    Xnew, xx, yy = create_grid(gp.X, N=961)
    mu, var = gp_predict(
        Xnew,
        gp.Z,
        kernels=gp.kernel,
        mean_funcs=gp.mean_func,
        f=gp.q_mu,
        q_sqrt=gp.q_sqrt,
        full_cov=False,
    )

    def gp_jacobian_all(x):
        if len(x.shape) == 1:
            x = x.reshape(1, -1)
        return gp_jacobian(
            x,
            gp.Z,
            gp.kernel,
            gp.mean_func,
            f=gp.q_mu,
            q_sqrt=gp.q_sqrt,
            full_cov=False,
        )

    # Jacobian of the posterior at every grid point.
    mu_j, var_j = jax.vmap(gp_jacobian_all, in_axes=0)(Xnew)

    fig, axs = plot_mean_and_var(
        xx,
        yy,
        mu,
        var,
        llabel=r"$\mathbb{E}[h^{(1)}]$",
        rlabel=r"$\mathbb{V}[h^{(1)}]$",
    )
    for ax in axs:
        # Overlay the Jacobian mean as a vector field.
        ax.quiver(Xnew[:, 0], Xnew[:, 1], mu_j[:, 0], mu_j[:, 1], color="k")
        fig, ax = plot_start_and_end_pos(fig, ax, solver)
        plot_omitted_data(fig, ax, color="k")
        plot_traj(
            fig,
            ax,
            solver.state_guesses,
            color=color_init,
            label="Initial trajectory",
        )
        if traj_opt is not None:
            plot_traj(
                fig,
                ax,
                traj_opt,
                color=color_opt,
                label="Optimised trajectory",
            )
    axs[0].legend()
    return fig, axs
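# Sanity-check sketch (an assumption, not part of the original code): the mean
# returned by gp_jacobian at a single input should agree with differentiating
# the SVGP posterior mean via jax.jacfwd. x_star is an arbitrary test point;
# exact output shapes depend on gp_jacobian/gp_predict and may need reshaping.
def example_check_jacobian_mean(gp, x_star):
    def posterior_mean(x):
        mu, _ = gp_predict(
            x.reshape(1, -1),
            gp.Z,
            kernels=gp.kernel,
            mean_funcs=gp.mean_func,
            f=gp.q_mu,
            q_sqrt=gp.q_sqrt,
            full_cov=False,
        )
        return mu.reshape(-1)

    jac_autodiff = jax.jacfwd(posterior_mean)(x_star)
    jac_mu, _ = gp_jacobian(
        x_star.reshape(1, -1),
        gp.Z,
        gp.kernel,
        gp.mean_func,
        f=gp.q_mu,
        q_sqrt=gp.q_sqrt,
        full_cov=False,
    )
    print("autodiff Jacobian:", jac_autodiff)
    print("gp_jacobian mean:", jac_mu)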