def plot_svgp_mixing_prob_and_start_end(gp, solver, traj_opt=None):
    """Contour-plot the SVGP mixing probability with start/end points and trajectories.

    The mixing probability is evaluated by ``mogpe_mixing_probability`` on a
    961-point grid built from ``gp.X``, using the sparse parameters
    ``gp.Z``, ``gp.q_mu`` and ``gp.q_sqrt``. Only the first column of the
    result is drawn.

    Args:
        gp: object exposing ``X``, ``Z``, ``kernel``, ``mean_func``, ``q_mu``
            and ``q_sqrt``.
        solver: passed to ``plot_start_and_end_pos``; its ``state_guesses``
            are drawn as the initial trajectory.
        traj_opt: optional optimised trajectory drawn after the initial one.

    Returns:
        ``(fig, ax)`` of the single-axis figure.
    """
    grid_points, grid_x, grid_y = create_grid(gp.X, N=961)
    probs = mogpe_mixing_probability(
        grid_points,
        gp.Z,
        gp.kernel,
        mean_func=gp.mean_func,
        f=gp.q_mu,
        q_sqrt=gp.q_sqrt,
        full_cov=False,
    )

    fig, ax = plt.subplots(1, 1)
    plot_contourf(fig, ax, grid_x, grid_y, probs[:, 0:1])
    plot_start_and_end_pos(fig, ax, solver)

    # (trajectory, colour, legend label) in drawing order.
    trajectories = [(solver.state_guesses, color_init, "Initial trajectory")]
    if traj_opt is not None:
        trajectories.append((traj_opt, color_opt, "Optimised trajectory"))
    for traj, colour, traj_label in trajectories:
        plot_traj(fig, ax, traj, color=colour, label=traj_label)

    ax.legend()
    return fig, ax
def plot_mixing_prob_and_start_end(gp, solver):
    """Contour-plot the mixing probability and mark the start/end positions.

    Unlike the sparse variants, this passes ``gp.X`` / ``gp.Y`` (not
    ``gp.Z`` / ``gp.q_mu``) to ``mogpe_mixing_probability``. Only the first
    column of the result is drawn. Start and end positions are scattered in
    red and annotated "start" / "end".

    NOTE(review): a ``plot_mixing_prob_and_start_end(gp, solver,
    traj_opt=None)`` is also defined later in this file; if both live in the
    same module the later one shadows this definition — confirm which is
    intended.

    Args:
        gp: object exposing ``X``, ``Y``, ``kernel``, ``mean_func`` and
            ``q_sqrt``.
        solver: object exposing ``pos_init`` and ``pos_end_targ``.

    Returns:
        ``(fig, ax)`` of the single-axis figure.
    """
    from ProbGeo.mogpe import mogpe_mixing_probability
    from ProbGeo.visualisation.gp import plot_contourf
    from ProbGeo.visualisation.utils import create_grid

    grid_points, grid_x, grid_y = create_grid(gp.X, N=961)
    probs = mogpe_mixing_probability(
        grid_points,
        gp.X,
        gp.kernel,
        mean_func=gp.mean_func,
        f=gp.Y,
        q_sqrt=gp.q_sqrt,
        full_cov=False,
    )

    fig, ax = plt.subplots(1, 1)
    plot_contourf(fig, ax, grid_x, grid_y, probs[:, 0:1])

    endpoints = (("start", solver.pos_init), ("end", solver.pos_end_targ))
    # Markers first, then text, matching the original drawing order.
    for _, pos in endpoints:
        ax.scatter(pos[0], pos[1], marker="o", color="r")
    for text, pos in endpoints:
        ax.annotate(text, (pos[0], pos[1]))
    return fig, ax
def plot_svgp_and_start_end(gp, solver, traj_opt=None):
    """Plot SVGP predictive mean and variance with start/end points and trajectories.

    Enables LaTeX text rendering by updating the global ``plt.rcParams``,
    predicts with ``gp_predict`` on a 961-point grid built from ``gp.X``, and
    draws mean/variance with ``plot_mean_and_var``. On every axis it marks
    the start/end positions and draws the initial trajectory
    (``solver.state_guesses``) and, if given, ``traj_opt``. The legend is
    shown on the first axis only.

    NOTE(review): a ``plot_svgp_and_start_end(gp, solver)`` is also defined
    later in this file; if both live in the same module the later one shadows
    this definition — confirm which is intended.

    Args:
        gp: object exposing ``X``, ``Z``, ``kernel``, ``mean_func``, ``q_mu``
            and ``q_sqrt``.
        solver: passed to ``plot_start_and_end_pos``; provides
            ``state_guesses``.
        traj_opt: optional optimised trajectory.

    Returns:
        ``(fig, axs)`` as returned by ``plot_mean_and_var`` (``fig`` possibly
        replaced by ``plot_start_and_end_pos``).
    """
    plt.rcParams.update(
        {
            "text.usetex": True,
            # Matplotlib >= 3.3 expects a single string here; a list of
            # strings was deprecated and is rejected by later releases.
            "text.latex.preamble": "\n".join(
                [r"\usepackage{amssymb}", r"\usepackage{amsmath}"]
            ),
        }
    )
    Xnew, xx, yy = create_grid(gp.X, N=961)
    mu, var = gp_predict(
        Xnew,
        gp.Z,
        kernels=gp.kernel,
        mean_funcs=gp.mean_func,
        f=gp.q_mu,
        q_sqrt=gp.q_sqrt,
        full_cov=False,
    )
    # Raw strings: "\m" is not a valid Python escape (SyntaxWarning on 3.12+).
    fig, axs = plot_mean_and_var(
        xx,
        yy,
        mu,
        var,
        llabel=r"$\mathbb{E}[h^{(1)}(\mathbf{x})]$",
        rlabel=r"$\mathbb{V}[h^{(1)}(\mathbf{x})]$",
    )
    for ax in axs:
        fig, ax = plot_start_and_end_pos(fig, ax, solver)
        plot_traj(
            fig,
            ax,
            solver.state_guesses,
            color=color_init,
            label="Initial trajectory",
        )
        if traj_opt is not None:
            plot_traj(
                fig,
                ax,
                traj_opt,
                color=color_opt,
                label="Optimised trajectory",
            )
    axs[0].legend()
    return fig, axs
def plot_domain_and_start_end(gp, solver, traj_opts=None, labels=None):
    """Plot the thresholded mixing probability with start/end points and mode labels.

    The first column of the mixing probability (computed point-wise with
    ``single_mogpe_mixing_probability`` under ``jax.vmap`` on a 961-point
    grid) is drawn as filled contours with levels ``[0, 0.5, 1]``. Omitted
    data and start/end positions are overlaid, and the fixed coordinates
    (-2.2, -1.5) and (0.1, -0.2) are annotated "Mode 1" / "Mode 2".

    Args:
        gp: object exposing ``X``, ``Z``, ``kernel``, ``mean_func``, ``q_mu``
            and ``q_sqrt``.
        solver: passed to ``plot_start_and_end_pos``.
        traj_opts: accepted but currently unused (trajectory plotting is
            disabled in this function); kept for interface compatibility.
        labels: accepted but currently unused; see ``traj_opts``.

    Returns:
        ``(fig, ax)`` of the single-axis figure.
    """
    from ProbGeo.mogpe import single_mogpe_mixing_probability
    from ProbGeo.visualisation.utils import create_grid

    Xnew, xx, yy = create_grid(gp.X, N=961)
    # Map over grid points only; every other argument is shared.
    mixing_probs = jax.vmap(
        single_mogpe_mixing_probability,
        (0, None, None, None, None, None, None),
    )(Xnew, gp.Z, gp.kernel, gp.mean_func, gp.q_mu, False, gp.q_sqrt)

    fig, ax = plt.subplots(1, 1)
    # ``linewidth`` is not a contourf keyword (matplotlib reports it as
    # unused), so it is not passed.
    ax.contourf(
        xx,
        yy,
        mixing_probs[:, 0:1].reshape(xx.shape),
        cmap=cm.coolwarm,
        levels=[0.0, 0.5, 1.0],
        antialiased=False,
    )
    ax.set_xlabel("$x$")
    ax.set_ylabel("$y$")
    plot_omitted_data(fig, ax, color="k")
    plot_start_and_end_pos(fig, ax, solver)
    ax.annotate(
        "Mode 1",
        (-2.2, -1.5),
        horizontalalignment="left",
        verticalalignment="top",
    )
    ax.annotate(
        "Mode 2",
        (0.1, -0.2),
        horizontalalignment="left",
        verticalalignment="top",
    )
    # loc=3 is "lower left". NOTE(review): only artists labelled by the
    # helpers above (if any) appear in this legend — confirm it is needed.
    ax.legend(loc=3)
    return fig, ax
def plot_3d_traj_mean_and_var(fig, axs, gp, traj):
    """Draw a trajectory on two 3D axes, lifted by the GP predictive mean and variance.

    The first two columns of ``traj`` are used as GP inputs to
    ``gp_predict``. The trajectory is drawn on ``axs[0]`` with the predictive
    mean as height and on ``axs[1]`` with the predictive variance as height.

    Args:
        fig: figure owning ``axs``.
        axs: indexable of at least two 3D axes.
        gp: object exposing ``Z``, ``kernel``, ``mean_func``, ``q_mu`` and
            ``q_sqrt``.
        traj: array whose first two columns are GP inputs (presumably
            shape ``(T, >=2)`` — confirm).

    Returns:
        ``(fig, axs)`` unchanged.
    """
    # The original also built an unused 961-point grid here; it is dropped.
    traj_mu, traj_var = gp_predict(
        traj[:, 0:2],
        gp.Z,
        kernels=gp.kernel,
        mean_funcs=gp.mean_func,
        f=gp.q_mu,
        q_sqrt=gp.q_sqrt,
        full_cov=False,
    )
    plot_3d_traj(fig, axs[0], traj, zs=traj_mu)
    plot_3d_traj(fig, axs[1], traj, zs=traj_var)
    return fig, axs
def plot_svgp_jacobian_var(gp, solver, traj_opt=None):
    """Contour-plot every (i, j) entry of the SVGP Jacobian covariance over a grid.

    Enables LaTeX text rendering by updating the global ``plt.rcParams``,
    evaluates ``gp_jacobian`` at each of 961 grid points via ``jax.vmap``,
    and draws ``var_j[:, i, j]`` on an ``input_dim x input_dim`` subplot
    grid, where ``input_dim = gp.X.shape[1]``.

    Args:
        gp: object exposing ``X``, ``Z``, ``kernel``, ``mean_func``, ``q_mu``
            and ``q_sqrt``.
        solver: unused; kept for interface compatibility.
        traj_opt: unused; kept for interface compatibility.

    Returns:
        ``(fig, axs)`` from ``plt.subplots(input_dim, input_dim)``.
    """
    plt.rcParams.update(
        {
            "text.usetex": True,
            # Matplotlib >= 3.3 expects a single string here; a list of
            # strings was deprecated and is rejected by later releases.
            "text.latex.preamble": "\n".join(
                [r"\usepackage{amssymb}", r"\usepackage{amsmath}"]
            ),
        }
    )
    input_dim = gp.X.shape[1]
    Xnew, xx, yy = create_grid(gp.X, N=961)

    def gp_jacobian_all(x):
        # vmap hands over a single 1-D point; promote it to a (1, D) row.
        if len(x.shape) == 1:
            x = x.reshape(1, -1)
        return gp_jacobian(
            x,
            gp.Z,
            gp.kernel,
            gp.mean_func,
            f=gp.q_mu,
            q_sqrt=gp.q_sqrt,
            full_cov=False,
        )

    _, var_j = jax.vmap(gp_jacobian_all, in_axes=0)(Xnew)

    fig, axs = plt.subplots(input_dim, input_dim)
    for i in range(input_dim):
        for j in range(input_dim):
            plot_contourf(
                fig,
                axs[i, j],
                xx,
                yy,
                var_j[:, i, j],
                # Raw string: "\S" / "\m" are invalid Python escapes.
                label=r"$\Sigma_J(\mathbf{x})$",
            )
            axs[i, j].set_xlabel("$x$")
            axs[i, j].set_ylabel("$y$")
    return fig, axs
def plot_gp_and_start_end(gp, solver):
    """Plot GP predictive mean and variance on a grid and mark start/end positions.

    Predicts with ``gp_predict`` conditioned on ``gp.X`` / ``gp.Y`` over a
    961-point grid, draws the result with ``plot_mean_and_var``, then calls
    ``plot_start_and_end_pos`` on every returned axis.

    Args:
        gp: object exposing ``X``, ``Y``, ``kernel``, ``mean_func`` and
            ``q_sqrt``.
        solver: passed to ``plot_start_and_end_pos``.

    Returns:
        ``(fig, axs)``; ``fig`` is whatever the last
        ``plot_start_and_end_pos`` call returned.
    """
    grid_points, grid_x, grid_y = create_grid(gp.X, N=961)
    pred_mean, pred_var = gp_predict(
        grid_points,
        gp.X,
        kernels=gp.kernel,
        mean_funcs=gp.mean_func,
        f=gp.Y,
        q_sqrt=gp.q_sqrt,
        full_cov=False,
    )
    fig, axs = plot_mean_and_var(grid_x, grid_y, pred_mean, pred_var)
    for axis in axs:
        fig, axis = plot_start_and_end_pos(fig, axis, solver)
    return fig, axs
def plot_3d_mixing_prob(gp, solver):
    """Draw a 3D surface of the first column of the SVGP mixing probability.

    The mixing probability is computed point-wise with
    ``single_mogpe_mixing_probability`` under ``jax.vmap`` on a 961-point
    grid built from ``gp.X``.

    Args:
        gp: object exposing ``X``, ``Z``, ``kernel``, ``mean_func``, ``q_mu``
            and ``q_sqrt``.
        solver: unused; kept for interface compatibility.

    Returns:
        ``(fig, ax)`` with a single 3D axis.
    """
    Xnew, xx, yy = create_grid(gp.X, N=961)
    # Map over grid points only; every other argument is shared.
    mixing_probs = jax.vmap(
        single_mogpe_mixing_probability,
        (0, None, None, None, None, None, None),
    )(Xnew, gp.Z, gp.kernel, gp.mean_func, gp.q_mu, False, gp.q_sqrt)

    fig = plt.figure()
    ax = fig.add_subplot(1, 1, 1, projection="3d")
    plot_3d_surf(fig, ax, xx, yy, mixing_probs[:, 0:1])
    # Raw string: "\P" / "\m" are invalid Python escapes.
    # NOTE(review): column 0 is labelled alpha=0 here, while
    # plot_mixing_prob_and_start_end labels the same column alpha=1 — confirm.
    ax.set_zlabel(r"$\Pr(\alpha_*=0 | \mathbf{x}_*)$")
    return fig, ax
def plot_ode_svgp_quad():
    """Visualise the geodesic ODE field over the FakeSVGPQuad domain and show it.

    Plots the SVGP predictive mean and (marginal) variance on a 961-point
    grid. It then evaluates ``geodesic_ode`` at every grid point, with
    ``ode.vel_init_guess`` appended as the initial velocity. Columns 2 and 3
    of the ODE output are overlaid as a quiver on each axis (presumably the
    acceleration, assuming a 2-D position — confirm). Finally calls
    ``plt.show()``; returns ``None``.
    """
    from ProbGeo.gp import gp_predict
    from ProbGeo.metric_tensor import gp_metric_tensor
    from ProbGeo.visualisation.gp import plot_mean_and_var
    from ProbGeo.visualisation.utils import create_grid

    # Instantiated, consistent with the other Fake* helpers used below (the
    # original bound the class itself).
    gp = FakeSVGPQuad()
    Xnew, xx, yy = create_grid(gp.X, N=961)
    mu, var = gp_predict(
        Xnew,
        gp.Z,
        kernel=gp.kernel,
        mean_func=gp.mean_func,
        f=gp.q_mu,
        q_sqrt=gp.q_sqrt,
        full_cov=True,
    )
    # Keep only the marginal variances (assumes var is a single (N, N)
    # matrix — confirm).
    var = np.diag(var)
    fig, axs = plot_mean_and_var(xx, yy, mu, var)

    ode = FakeCollocationSVGPQuad()
    metric = FakeSVGPMetricQuad()

    def ode_fn(pos_init):
        # Full ODE state = position followed by the initial velocity guess.
        state_init = np.concatenate([pos_init, ode.vel_init_guess])
        return geodesic_ode(
            ode.times, state_init, gp_metric_tensor, metric.metric_fn_kwargs
        )

    state_primes = jax.vmap(ode_fn)(Xnew)
    for ax in axs:
        ax.quiver(
            Xnew[:, 0], Xnew[:, 1], state_primes[:, 2], state_primes[:, 3]
        )
    plt.show()
def plot_svgp_and_start_end(gp, solver):
    """Plot SVGP predictive mean and variance on a grid and mark start/end positions.

    Predicts with ``gp_predict`` at a 961-point grid built from ``gp.X``,
    using the sparse parameters ``gp.Z``, ``gp.q_mu`` and ``gp.q_sqrt``. It
    draws the result with ``plot_mean_and_var`` and calls
    ``plot_start_and_end_pos`` on every axis.

    NOTE(review): this passes ``kernel=`` / ``mean_func=`` to ``gp_predict``,
    whereas other helpers in this file pass ``kernels=`` / ``mean_funcs=``.
    Only one spelling can match the real signature — confirm. This also
    shares its name with the ``plot_svgp_and_start_end(gp, solver,
    traj_opt=None)`` defined earlier in this file.

    Args:
        gp: object exposing ``X``, ``Z``, ``kernel``, ``mean_func``, ``q_mu``
            and ``q_sqrt``.
        solver: passed to ``plot_start_and_end_pos``.

    Returns:
        ``(fig, axs)``; ``fig`` is whatever the last
        ``plot_start_and_end_pos`` call returned.
    """
    from ProbGeo.gp import gp_predict
    from ProbGeo.visualisation.gp import plot_mean_and_var
    from ProbGeo.visualisation.utils import create_grid

    grid_points, grid_x, grid_y = create_grid(gp.X, N=961)
    predict_kwargs = dict(
        kernel=gp.kernel,
        mean_func=gp.mean_func,
        f=gp.q_mu,
        q_sqrt=gp.q_sqrt,
        full_cov=False,
    )
    mean, variance = gp_predict(grid_points, gp.Z, **predict_kwargs)
    fig, axs = plot_mean_and_var(grid_x, grid_y, mean, variance)
    for axis in axs:
        fig, axis = plot_start_and_end_pos(fig, axis, solver)
    return fig, axs
def test_and_plot_shooting_geodesic_solver_with_prob_svgp():
    """Run the shooting geodesic solver and overlay its trajectory on the mixing probability.

    Builds ``FakeSVGP`` / ``FakeODESolverProbSVGP`` fixtures and contour-plots
    the first column of the mixing probability on a 961-point grid. It marks
    the start/end positions in red, then overlays the geodesic returned by
    ``test_shooting_geodesic_solver_with_prob_svgp`` (rows 0 and 1 used as
    x/y) in black. Calls ``plt.show()``; returns ``None``.
    """
    import matplotlib.pyplot as plt
    from ProbGeo.mogpe import mogpe_mixing_probability
    from ProbGeo.visualisation.gp import plot_contourf
    from ProbGeo.visualisation.utils import create_grid

    gp = FakeSVGP()
    solver = FakeODESolverProbSVGP()

    grid_points, grid_x, grid_y = create_grid(gp.X, N=961)
    probs = mogpe_mixing_probability(
        grid_points,
        gp.Z,
        gp.kernel,
        mean_func=gp.mean_func,
        f=gp.q_mu,
        q_sqrt=gp.q_sqrt,
        full_cov=False,
    )
    fig, ax = plt.subplots(1, 1)
    plot_contourf(fig, ax, grid_x, grid_y, probs[:, 0:1])

    start, end = solver.pos_init, solver.pos_end_targ
    ax.scatter(start[0], start[1], marker="o", color="r")
    ax.scatter(end[0], end[1], color="r", marker="o")
    ax.annotate("start", (start[0], start[1]))
    ax.annotate("end", (end[0], end[1]))

    _, geodesic_traj = test_shooting_geodesic_solver_with_prob_svgp()
    traj_x, traj_y = geodesic_traj[0, :], geodesic_traj[1, :]
    ax.scatter(traj_x, traj_y, marker="x", color="k")
    ax.plot(traj_x, traj_y, marker="x", color="k")
    plt.show()
def plot_svgp_metric_and_start_end(metric, solver, traj_opt=None):
    """Contour-plot every (i, j) entry of the GP metric tensor G(x) over a grid.

    Enables LaTeX text rendering by updating the global ``plt.rcParams``,
    evaluates ``gp_metric_tensor`` on a 961-point grid built from
    ``metric.gp.X``, and draws ``metric_tensor[:, i, j]`` on an
    ``input_dim x input_dim`` subplot grid.

    Args:
        metric: object exposing ``gp`` (with ``X``, ``Z``, ``kernel``,
            ``mean_func``, ``q_mu``, ``q_sqrt``) and ``cov_weight``.
        solver: unused; kept for interface compatibility.
        traj_opt: unused; kept for interface compatibility.

    Returns:
        ``(fig, axs)`` from ``plt.subplots(input_dim, input_dim)``.
    """
    plt.rcParams.update(
        {
            "text.usetex": True,
            # Matplotlib >= 3.3 expects a single string here; a list of
            # strings was deprecated and is rejected by later releases.
            "text.latex.preamble": "\n".join(
                [r"\usepackage{amssymb}", r"\usepackage{amsmath}"]
            ),
        }
    )
    gp = metric.gp
    input_dim = gp.X.shape[1]
    Xnew, xx, yy = create_grid(gp.X, N=961)
    # Only the tensor is plotted; the Jacobian mean/covariance are discarded.
    metric_tensor, _, _ = gp_metric_tensor(
        Xnew,
        gp.Z,
        gp.kernel,
        mean_func=gp.mean_func,
        f=gp.q_mu,
        full_cov=True,
        q_sqrt=gp.q_sqrt,
        cov_weight=metric.cov_weight,
    )

    fig, axs = plt.subplots(input_dim, input_dim)
    for i in range(input_dim):
        for j in range(input_dim):
            plot_contourf(
                fig,
                axs[i, j],
                xx,
                yy,
                metric_tensor[:, i, j],
                # Raw string: "\m" is an invalid Python escape.
                label=r"$G(\mathbf{x})$",
            )
            axs[i, j].set_xlabel("$x$")
            axs[i, j].set_ylabel("$y$")
    return fig, axs
def plot_3d_metric_trace(metric, solver):
    """Draw a 3D surface of the trace of the GP metric tensor over a grid.

    ``gp_metric_tensor`` is evaluated on a 961-point grid built from
    ``metric.gp.X``. The trace is taken over axes 1 and 2 of the result.

    Args:
        metric: object exposing ``gp`` (with ``X``, ``Z``, ``kernel``,
            ``mean_func``, ``q_mu``, ``q_sqrt``) and ``cov_weight``.
        solver: unused; kept for interface compatibility.

    Returns:
        ``(fig, ax)`` with a single 3D axis.
    """
    gp = metric.gp
    Xnew, xx, yy = create_grid(gp.X, N=961)
    metric_tensor, _, _ = gp_metric_tensor(
        Xnew,
        gp.Z,
        gp.kernel,
        mean_func=gp.mean_func,
        f=gp.q_mu,
        full_cov=True,
        q_sqrt=gp.q_sqrt,
        cov_weight=metric.cov_weight,
    )
    metric_trace = np.trace(metric_tensor, axis1=1, axis2=2)

    fig = plt.figure()
    ax = fig.add_subplot(1, 1, 1, projection="3d")
    plot_3d_surf(fig, ax, xx, yy, metric_trace)
    # Raw string: "\m" is an invalid Python escape.
    ax.set_zlabel(r"Tr$(G(\mathbf{x}_*))$")
    return fig, ax
def plot_svgp_metric_trace_and_start_end(metric, solver, traj_opt=None):
    """Contour-plot the metric-tensor trace with start/end points and trajectories.

    ``gp_metric_tensor`` is evaluated on a 961-point grid built from
    ``metric.gp.X``, and its trace over axes 1 and 2 is contour-plotted. The
    start/end positions, the initial trajectory (``solver.state_guesses``)
    and, if given, ``traj_opt`` are overlaid.

    Args:
        metric: object exposing ``gp`` (with ``X``, ``Z``, ``kernel``,
            ``mean_func``, ``q_mu``, ``q_sqrt``) and ``cov_weight``.
        solver: passed to ``plot_start_and_end_pos``; provides
            ``state_guesses``.
        traj_opt: optional optimised trajectory.

    Returns:
        ``(fig, ax)`` of the single-axis figure.
    """
    gp = metric.gp
    Xnew, xx, yy = create_grid(gp.X, N=961)
    metric_tensor, _, _ = gp_metric_tensor(
        Xnew,
        gp.Z,
        gp.kernel,
        mean_func=gp.mean_func,
        f=gp.q_mu,
        full_cov=True,
        q_sqrt=gp.q_sqrt,
        cov_weight=metric.cov_weight,
    )
    metric_trace = np.trace(metric_tensor, axis1=1, axis2=2)

    fig, ax = plt.subplots(1, 1)
    # Adjust this figure explicitly rather than pyplot's implicit "current"
    # figure.
    fig.subplots_adjust(wspace=0, hspace=0)
    # Raw string: "\m" is an invalid Python escape.
    plot_contourf(
        fig, ax, xx, yy, metric_trace, label=r"Tr$(G(\mathbf{x}))$"
    )
    plot_start_and_end_pos(fig, ax, solver)
    ax.set_xlabel("$x$")
    ax.set_ylabel("$y$")
    plot_traj(
        fig,
        ax,
        solver.state_guesses,
        color=color_init,
        label="Initial trajectory",
    )
    if traj_opt is not None:
        plot_traj(
            fig, ax, traj_opt, color=color_opt, label="Optimised trajectory"
        )
    ax.legend()
    return fig, ax
def plot_3d_mean_and_var(gp, solver):
    """Draw side-by-side 3D surfaces of the SVGP predictive mean and variance.

    Predicts with ``gp_predict`` on a 961-point grid built from ``gp.X``,
    using ``gp.Z``, ``gp.q_mu`` and ``gp.q_sqrt``. The mean is drawn on the
    left axis (z-label "Mean") and the variance on the right ("Variance").

    Args:
        gp: object exposing ``X``, ``Z``, ``kernel``, ``mean_func``, ``q_mu``
            and ``q_sqrt``.
        solver: unused; kept for interface compatibility.

    Returns:
        ``(fig, [ax_mean, ax_var])``.
    """
    grid_points, grid_x, grid_y = create_grid(gp.X, N=961)
    pred_mean, pred_var = gp_predict(
        grid_points,
        gp.Z,
        kernels=gp.kernel,
        mean_funcs=gp.mean_func,
        f=gp.q_mu,
        q_sqrt=gp.q_sqrt,
        full_cov=False,
    )

    fig = plt.figure(figsize=plt.figaspect(0.5))
    panels = ((pred_mean, "Mean"), (pred_var, "Variance"))
    axs = []
    for position, (surface, z_label) in enumerate(panels, start=1):
        axis = fig.add_subplot(1, 2, position, projection="3d")
        plot_3d_surf(fig, axis, grid_x, grid_y, surface)
        axis.set_zlabel(z_label)
        axs.append(axis)
    return fig, axs
inducing_variable = X[idx, ...].reshape(-1, input_dim) inducing_variable = gpf.inducing_variables.InducingPoints(inducing_variable) m = gpf.models.SVGP( kernel=kern, likelihood=lik, inducing_variable=inducing_variable, mean_function=mean_func, ) gpf.utilities.print_summary(m) gpf.set_trainable(m.likelihood.variance, False) gpf.set_trainable(m.inducing_variable, False) gpf.utilities.print_summary(m) Xnew, xx, yy = create_grid(X, N=961) # mu, var = m.predict_y(Xnew) # fig, axs = plot_mean_and_var(xx, yy, mu.numpy(), var.numpy()) # plt.show() optimizer = tf.optimizers.Adam() prefetch_size = tf.data.experimental.AUTOTUNE shuffle_buffer_size = num_data // 2 num_batches_per_epoch = num_data // batch_size train_dataset = tf.data.Dataset.from_tensor_slices(dataset) train_dataset = ( train_dataset.repeat() .prefetch(prefetch_size) .shuffle(buffer_size=shuffle_buffer_size) .batch(batch_size, drop_remainder=True)
def plot_mixing_prob_and_start_end(gp, solver, traj_opt=None):
    """Contour-plot the SVGP mixing probability with start/end points and trajectories.

    The mixing probability is computed point-wise with
    ``single_mogpe_mixing_probability`` under ``jax.vmap`` on a 961-point
    grid built from ``gp.X``, using the sparse parameters ``gp.Z``,
    ``gp.q_mu`` and ``gp.q_sqrt``. Its first column is contour-plotted. The
    omitted data, start/end positions, the initial trajectory
    (``solver.state_guesses``) and, if given, ``traj_opt`` are overlaid.

    Args:
        gp: object exposing ``X``, ``Z``, ``kernel``, ``mean_func``, ``q_mu``
            and ``q_sqrt``.
        solver: passed to ``plot_start_and_end_pos``; provides
            ``state_guesses``.
        traj_opt: optional optimised trajectory.

    Returns:
        ``(fig, ax)`` of the single-axis figure.
    """
    Xnew, xx, yy = create_grid(gp.X, N=961)
    # Map over grid points only; every other argument is shared.
    mixing_probs = jax.vmap(
        single_mogpe_mixing_probability,
        (0, None, None, None, None, None, None),
    )(Xnew, gp.Z, gp.kernel, gp.mean_func, gp.q_mu, False, gp.q_sqrt)

    fig, ax = plt.subplots(1, 1)
    plot_contourf(
        fig,
        ax,
        xx,
        yy,
        mixing_probs[:, 0:1],
        # Raw string: "\P" / "\m" are invalid Python escapes.
        label=r"$\Pr(\alpha=1 | \mathbf{x})$",
    )
    ax.set_xlabel("$x$")
    ax.set_ylabel("$y$")
    plot_omitted_data(fig, ax, color="k")
    plot_start_and_end_pos(fig, ax, solver)
    plot_traj(
        fig,
        ax,
        solver.state_guesses,
        color=color_init,
        label="Initial trajectory",
    )
    if traj_opt is not None:
        plot_traj(
            fig, ax, traj_opt, color=color_opt, label="Optimised trajectory"
        )
    ax.legend()
    return fig, ax
def plot_svgp_jacobian_mean(gp, solver, traj_opt=None):
    """Plot SVGP mean/variance with the Jacobian mean overlaid as a vector field.

    Enables LaTeX text rendering by updating the global ``plt.rcParams`` and
    predicts on a 961-point grid built from ``gp.X``. It evaluates
    ``gp_jacobian`` at every grid point via ``jax.vmap`` and draws
    mean/variance with ``plot_mean_and_var``. On each axis it overlays:

    - a black quiver of the first two Jacobian-mean components;
    - start/end positions and omitted data;
    - the initial trajectory (``solver.state_guesses``) and, if given,
      ``traj_opt``.

    The legend is shown on the first axis only.

    Args:
        gp: object exposing ``X``, ``Z``, ``kernel``, ``mean_func``, ``q_mu``
            and ``q_sqrt``.
        solver: passed to ``plot_start_and_end_pos``; provides
            ``state_guesses``.
        traj_opt: optional optimised trajectory.

    Returns:
        ``(fig, axs)`` as returned by ``plot_mean_and_var`` (``fig`` possibly
        replaced by ``plot_start_and_end_pos``).
    """
    plt.rcParams.update(
        {
            "text.usetex": True,
            # Matplotlib >= 3.3 expects a single string here; a list of
            # strings was deprecated and is rejected by later releases.
            "text.latex.preamble": "\n".join(
                [r"\usepackage{amssymb}", r"\usepackage{amsmath}"]
            ),
        }
    )
    Xnew, xx, yy = create_grid(gp.X, N=961)
    mu, var = gp_predict(
        Xnew,
        gp.Z,
        kernels=gp.kernel,
        mean_funcs=gp.mean_func,
        f=gp.q_mu,
        q_sqrt=gp.q_sqrt,
        full_cov=False,
    )

    def gp_jacobian_all(x):
        # vmap hands over a single 1-D point; promote it to a (1, D) row.
        if len(x.shape) == 1:
            x = x.reshape(1, -1)
        return gp_jacobian(
            x,
            gp.Z,
            gp.kernel,
            gp.mean_func,
            f=gp.q_mu,
            q_sqrt=gp.q_sqrt,
            full_cov=False,
        )

    # Only the Jacobian mean is drawn; its covariance is discarded.
    mu_j, _ = jax.vmap(gp_jacobian_all, in_axes=0)(Xnew)

    # Raw strings: "\m" is an invalid Python escape.
    fig, axs = plot_mean_and_var(
        xx,
        yy,
        mu,
        var,
        llabel=r"$\mathbb{E}[h^{(1)}]$",
        rlabel=r"$\mathbb{V}[h^{(1)}]$",
    )
    for ax in axs:
        ax.quiver(Xnew[:, 0], Xnew[:, 1], mu_j[:, 0], mu_j[:, 1], color="k")
        fig, ax = plot_start_and_end_pos(fig, ax, solver)
        plot_omitted_data(fig, ax, color="k")
        plot_traj(
            fig,
            ax,
            solver.state_guesses,
            color=color_init,
            label="Initial trajectory",
        )
        if traj_opt is not None:
            plot_traj(
                fig,
                ax,
                traj_opt,
                color=color_opt,
                label="Optimised trajectory",
            )
    axs[0].legend()
    return fig, axs