def plot_svgp_and_start_end(gp, solver, traj_opts=None, labels=None):
    """Plot SVGP predictive mean/variance with start/end positions and trajectories.

    The GP is evaluated on a 961-point grid built from the inducing inputs
    ``gp.Z`` and drawn with ``plot_mean_and_var``. The following are drawn on
    every returned axis:

    * the solver's start/end positions;
    * the initial trajectory ``solver.state_guesses``;
    * any optimised trajectories.

    Args:
        gp: SVGP-like object exposing ``Z``, ``kernel``, ``mean_func``, ``q_mu``
            and ``q_sqrt``.
        solver: Object exposing ``state_guesses`` plus whatever
            ``plot_start_and_end_pos`` reads.
        traj_opts: ``None``, a single trajectory, or a list of trajectories.
            A list is drawn with colours from ``color_opts`` and the matching
            ``labels``. A single (non-list) trajectory is drawn with
            ``plot_traj``'s default styling.
        labels: Legend labels for a list of trajectories. Defaults to one empty
            string per trajectory. This replaces the previous mutable
            ``["", ""]`` default, which also capped the plot at two
            trajectories.

    Returns:
        ``(fig, axs)`` as returned by ``plot_mean_and_var``.
    """
    plt.rcParams.update(
        {
            "text.usetex": True,
            # Newer Matplotlib releases only accept a single string here
            # (lists were deprecated in 3.3), so join the package lines.
            "text.latex.preamble": "\n".join(
                [r"\usepackage{amssymb}", r"\usepackage{amsmath}"]
            ),
        }
    )

    # The grid spans the inducing inputs rather than the training data.
    Xnew, xx, yy = create_grid(gp.Z, N=961)
    mu, var = gp_predict(
        Xnew,
        gp.Z,
        kernels=gp.kernel,
        mean_funcs=gp.mean_func,
        f=gp.q_mu,
        q_sqrt=gp.q_sqrt,
        full_cov=False,
    )

    fig, axs = plot_mean_and_var(
        xx,
        yy,
        mu,
        var,
        llabel=r"$\mathbb{E}[h^{(1)}(\mathbf{x})]$",
        rlabel=r"$\mathbb{V}[h^{(1)}(\mathbf{x})]$",
    )

    if isinstance(traj_opts, list) and labels is None:
        labels = [""] * len(traj_opts)

    for ax in axs:
        fig, ax = plot_start_and_end_pos(fig, ax, solver)
        ax.set_xlabel("$x$")
        ax.set_ylabel("$y$")
        plot_traj(fig, ax, solver.state_guesses, color=color_init, label="Init traj")
        if traj_opts is None:
            continue
        if isinstance(traj_opts, list):
            for traj, label, traj_color in zip(traj_opts, labels, color_opts):
                plot_traj(fig, ax, traj, color=traj_color, label=label)
        else:
            plot_traj(fig, ax, traj_opts)

    axs[0].legend(loc="lower left")
    return fig, axs
def plot_svgp_jacobian_var(gp, solver, traj_opt=None):
    """Plot each entry of the SVGP Jacobian variance as a filled contour over a grid.

    The Jacobian is evaluated with ``gp_jacobian`` at every point of a
    961-point grid built from ``gp.X``. One subplot is drawn per entry
    ``(i, j)`` of the variance, for ``i, j`` in ``range(input_dim)``.

    Args:
        gp: SVGP-like object exposing ``X``, ``Z``, ``kernel``, ``mean_func``,
            ``q_mu`` and ``q_sqrt``.
        solver: Unused; kept for signature compatibility with the other
            plotting helpers.
        traj_opt: Unused; kept for signature compatibility.

    Returns:
        ``(fig, axs)``, where ``axs`` is always a 2-D ``(input_dim, input_dim)``
        array of Axes (including when ``input_dim == 1``).
    """
    plt.rcParams.update(
        {
            "text.usetex": True,
            # Newer Matplotlib releases only accept a single string here
            # (lists were deprecated in 3.3), so join the package lines.
            "text.latex.preamble": "\n".join(
                [r"\usepackage{amssymb}", r"\usepackage{amsmath}"]
            ),
        }
    )
    input_dim = gp.X.shape[1]
    Xnew, xx, yy = create_grid(gp.X, N=961)

    def gp_jacobian_all(x):
        # vmap hands over a single 1-D point; promote it to a batch of one.
        if len(x.shape) == 1:
            x = x.reshape(1, -1)
        return gp_jacobian(
            x,
            gp.Z,
            gp.kernel,
            gp.mean_func,
            f=gp.q_mu,
            q_sqrt=gp.q_sqrt,
            full_cov=False,
        )

    # NOTE(review): var_j is indexed as (grid_point, i, j) below — confirm
    # against gp_jacobian's return shape.
    _, var_j = jax.vmap(gp_jacobian_all, in_axes=0)(Xnew)

    # squeeze=False keeps axs 2-D, so axs[i, j] also works for 1-D inputs.
    fig, axs = plt.subplots(input_dim, input_dim, squeeze=False)
    for i in range(input_dim):
        for j in range(input_dim):
            ax = axs[i, j]
            plot_contourf(
                fig,
                ax,
                xx,
                yy,
                var_j[:, i, j],
                label=r"$\Sigma_J(\mathbf{x})$",
            )
            ax.set_xlabel("$x$")
            ax.set_ylabel("$y$")
    return fig, axs
def plot_mixing_prob_all_trajs(gp, solver, traj_opts=None, labels=None):
    """Plot the mixing probability and save one PDF per trajectory added.

    A filled contour of column 0 of the mixing probabilities (labelled
    ``Pr(alpha=1 | x)``) is drawn on a 961-point grid built from ``gp.X``.
    The start/end positions and the initial trajectory are then drawn, and the
    figure is saved as ``dir_name + "mixing_prob_2d_traj_init.pdf"``. Each
    optimised trajectory is then added in turn, and the figure is saved again
    as ``dir_name + "mixing_prob_2d_traj_<i>.pdf"``.

    ``dir_name``, ``color_init`` and ``color_opts`` are names defined outside
    this function.

    Args:
        gp: Object exposing ``X``, ``Z``, ``kernel``, ``mean_func``, ``q_mu``
            and ``q_sqrt``.
        solver: Object exposing ``state_guesses`` plus whatever
            ``plot_start_and_end_pos`` reads.
        traj_opts: Optional iterable of optimised trajectories.
        labels: Optional legend labels matching ``traj_opts``. Defaults to
            empty strings; previously ``None`` raised ``TypeError`` in ``zip``.

    Returns:
        ``(fig, ax)``.
    """
    import itertools

    Xnew, xx, yy = create_grid(gp.X, N=961)
    # Map over grid points only; all other arguments are shared across the
    # batch. NOTE(review): the positional False is presumably full_cov —
    # confirm against single_mogpe_mixing_probability.
    mixing_probs = jax.vmap(
        single_mogpe_mixing_probability,
        (0, None, None, None, None, None, None),
    )(Xnew, gp.Z, gp.kernel, gp.mean_func, gp.q_mu, False, gp.q_sqrt)

    fig, ax = plt.subplots(1, 1)
    plot_contourf(
        fig,
        ax,
        xx,
        yy,
        mixing_probs[:, 0:1],
        label=r"$\Pr(\alpha=1 | \mathbf{x})$",
    )
    ax.set_xlabel("$x$")
    ax.set_ylabel("$y$")
    plot_start_and_end_pos(fig, ax, solver)
    plot_traj(fig, ax, solver.state_guesses, color=color_init, label="Init traj")
    # Save this specific figure rather than whatever pyplot considers current.
    fig.savefig(
        dir_name + "mixing_prob_2d_traj_init.pdf",
        transparent=True,
        bbox_inches="tight",
        pad_inches=0,
    )

    if traj_opts is not None:
        if labels is None:
            labels = itertools.repeat("")
        for i, (traj, label, color) in enumerate(zip(traj_opts, labels, color_opts)):
            plot_traj(fig, ax, traj, color=color, label=label)
            fig.savefig(
                dir_name + "mixing_prob_2d_traj_" + str(i) + ".pdf",
                transparent=True,
                bbox_inches="tight",
                pad_inches=0,
            )
    ax.legend(loc=3)
    return fig, ax
def plot_svgp_metric_and_start_end(metric, solver, traj_opt=None):
    """Plot each entry of the GP metric tensor as a filled contour over a grid.

    ``gp_metric_tensor`` is evaluated on a 961-point grid built from
    ``metric.gp.X``. One subplot is drawn per entry ``(i, j)``.

    Args:
        metric: Object exposing ``gp`` (with ``X``, ``Z``, ``kernel``,
            ``mean_func``, ``q_mu``, ``q_sqrt``) and ``cov_weight``.
        solver: Unused; kept for signature compatibility.
        traj_opt: Unused; kept for signature compatibility.

    Returns:
        ``(fig, axs)``, where ``axs`` is always a 2-D ``(input_dim, input_dim)``
        array of Axes.
    """
    plt.rcParams.update(
        {
            "text.usetex": True,
            # Newer Matplotlib releases only accept a single string here
            # (lists were deprecated in 3.3), so join the package lines.
            "text.latex.preamble": "\n".join(
                [r"\usepackage{amssymb}", r"\usepackage{amsmath}"]
            ),
        }
    )
    gp = metric.gp
    input_dim = gp.X.shape[1]
    Xnew, xx, yy = create_grid(gp.X, N=961)
    metric_tensor, _, _ = gp_metric_tensor(
        Xnew,
        gp.Z,
        gp.kernel,
        mean_func=gp.mean_func,
        f=gp.q_mu,
        full_cov=True,
        q_sqrt=gp.q_sqrt,
        cov_weight=metric.cov_weight,
    )

    # squeeze=False keeps axs 2-D, so axs[i, j] also works for 1-D inputs.
    fig, axs = plt.subplots(input_dim, input_dim, squeeze=False)
    for i in range(input_dim):
        for j in range(input_dim):
            ax = axs[i, j]
            plot_contourf(
                fig,
                ax,
                xx,
                yy,
                metric_tensor[:, i, j],
                label=r"$G(\mathbf{x})$",
            )
            ax.set_xlabel("$x$")
            ax.set_ylabel("$y$")
    return fig, axs
def plot_mixing_prob_and_start_end(gp, solver, traj_opts=None, labels=None):
    """Plot the mixing probability with omitted data, start/end positions and trajectories.

    A filled contour of column 0 of the mixing probabilities (labelled
    ``Pr(alpha=1 | x)``) is drawn on a 961-point grid built from ``gp.X``.
    The following are overlaid:

    * the omitted data;
    * the start/end positions;
    * the initial trajectory;
    * any optimised trajectories.

    Args:
        gp: Object exposing ``X``, ``Z``, ``kernel``, ``mean_func``, ``q_mu``
            and ``q_sqrt``.
        solver: Object exposing ``state_guesses`` plus whatever
            ``plot_start_and_end_pos`` reads.
        traj_opts: Optional iterable of optimised trajectories, coloured from
            ``color_opts``.
        labels: Optional legend labels matching ``traj_opts``. Defaults to
            empty strings; previously ``None`` raised ``TypeError`` in ``zip``.

    Returns:
        ``(fig, ax)``.
    """
    import itertools

    Xnew, xx, yy = create_grid(gp.X, N=961)
    # Map over grid points only; all other arguments are shared across the
    # batch. NOTE(review): the positional False is presumably full_cov —
    # confirm against single_mogpe_mixing_probability.
    mixing_probs = jax.vmap(
        single_mogpe_mixing_probability,
        (0, None, None, None, None, None, None),
    )(Xnew, gp.Z, gp.kernel, gp.mean_func, gp.q_mu, False, gp.q_sqrt)

    fig, ax = plt.subplots(1, 1)
    plot_contourf(
        fig,
        ax,
        xx,
        yy,
        mixing_probs[:, 0:1],
        label=r"$\Pr(\alpha=1 | \mathbf{x})$",
    )
    ax.set_xlabel("$x$")
    ax.set_ylabel("$y$")
    plot_omitted_data(fig, ax, color="k")
    plot_start_and_end_pos(fig, ax, solver)
    plot_traj(fig, ax, solver.state_guesses, color=color_init, label="Init traj")

    if traj_opts is not None:
        if labels is None:
            labels = itertools.repeat("")
        for traj, label, color in zip(traj_opts, labels, color_opts):
            plot_traj(fig, ax, traj, color=color, label=label)
    ax.legend(loc=3)
    return fig, ax
def plot_svgp_metric_trace_and_start_end(metrics, solver, traj_opts, labels, linestyles):
    """Plot the normalised metric-tensor trace along trajectories versus time.

    For each ``(traj_opt, metric, color, label, linestyle)`` tuple, the trace of
    ``gp_metric_tensor`` is evaluated along both trajectories, using the first
    two state columns as inputs:

    * the initial trajectory ``solver.state_guesses``;
    * the optimised trajectory ``traj_opt``.

    Both curves are min-max normalised jointly and plotted against
    ``solver.times``.

    Args:
        metrics: Sequence of metric objects exposing ``gp`` and ``cov_weight``.
        solver: Object exposing ``state_guesses`` and ``times``.
        traj_opts: Sequence of optimised trajectories, one per metric.
        labels: Sequence of strings appended to the ``lambda=`` legend entries.
        linestyles: Sequence of Matplotlib line styles, one per metric.

    Returns:
        ``(fig, ax)``.
    """

    def metric_trace(metric, states):
        # Trace of the metric tensor at each state's first two coordinates.
        gp = metric.gp
        metric_tensor, _, _ = gp_metric_tensor(
            states[:, 0:2],
            gp.Z,
            gp.kernel,
            mean_func=gp.mean_func,
            f=gp.q_mu,
            full_cov=True,
            q_sqrt=gp.q_sqrt,
            cov_weight=metric.cov_weight,
        )
        return np.trace(metric_tensor, axis1=1, axis2=2)

    fig, ax = plt.subplots(1, 1, figsize=(6.4, 2.8))
    ax.set_xlabel("Time $t$")
    ax.set_ylabel(r"Tr$(\mathbf{G}(\mathbf{x}(t)))$")

    for traj_opt, metric, color, label, linestyle in zip(
        traj_opts, metrics, color_opts, labels, linestyles
    ):
        trace_init = metric_trace(metric, solver.state_guesses)
        trace_opt = metric_trace(metric, traj_opt)
        # Normalise both curves with a shared range so they stay comparable.
        traces = np.stack([trace_opt, trace_init])
        trace_min = traces.min()
        span = float(traces.max() - trace_min)
        if span == 0.0:
            # Constant traces: avoid 0/0 and draw them flat at zero.
            span = 1.0
        ax.plot(
            solver.times,
            (trace_init - trace_min) / span,
            color=color_init,
            linestyle=linestyle,
            label=r"Init traj $\lambda=$" + label,
        )
        ax.plot(
            solver.times,
            (trace_opt - trace_min) / span,
            color=color,
            linestyle=linestyle,
            label=r"Opt traj $\lambda=$" + label,
        )
    ax.legend()
    return fig, ax
def plot_svgp_jacobian_mean(gp, solver, traj_opt=None):
    """Plot SVGP mean/variance overlaid with the Jacobian-mean vector field.

    The GP prediction and ``gp_jacobian`` are evaluated on a 961-point grid
    built from ``gp.X``. On each axis from ``plot_mean_and_var``, a quiver of
    ``mu_j[:, 0]`` / ``mu_j[:, 1]`` is drawn at the grid points, followed by:

    * the start/end positions;
    * the omitted data;
    * the initial trajectory;
    * the optional optimised trajectory.

    Args:
        gp: SVGP-like object exposing ``X``, ``Z``, ``kernel``, ``mean_func``,
            ``q_mu`` and ``q_sqrt``.
        solver: Object exposing ``state_guesses`` plus whatever
            ``plot_start_and_end_pos`` reads.
        traj_opt: Optional optimised trajectory, drawn with ``color_opt``.

    Returns:
        ``(fig, axs)`` as returned by ``plot_mean_and_var``.
    """
    plt.rcParams.update(
        {
            "text.usetex": True,
            # Newer Matplotlib releases only accept a single string here
            # (lists were deprecated in 3.3), so join the package lines.
            "text.latex.preamble": "\n".join(
                [r"\usepackage{amssymb}", r"\usepackage{amsmath}"]
            ),
        }
    )
    Xnew, xx, yy = create_grid(gp.X, N=961)
    mu, var = gp_predict(
        Xnew,
        gp.Z,
        kernels=gp.kernel,
        mean_funcs=gp.mean_func,
        f=gp.q_mu,
        q_sqrt=gp.q_sqrt,
        full_cov=False,
    )

    def gp_jacobian_all(x):
        # vmap hands over a single 1-D point; promote it to a batch of one.
        if len(x.shape) == 1:
            x = x.reshape(1, -1)
        return gp_jacobian(
            x,
            gp.Z,
            gp.kernel,
            gp.mean_func,
            f=gp.q_mu,
            q_sqrt=gp.q_sqrt,
            full_cov=False,
        )

    mu_j, _ = jax.vmap(gp_jacobian_all, in_axes=0)(Xnew)

    fig, axs = plot_mean_and_var(
        xx,
        yy,
        mu,
        var,
        llabel=r"$\mathbb{E}[h^{(1)}]$",
        rlabel=r"$\mathbb{V}[h^{(1)}]$",
    )
    for ax in axs:
        ax.quiver(Xnew[:, 0], Xnew[:, 1], mu_j[:, 0], mu_j[:, 1], color="k")
        fig, ax = plot_start_and_end_pos(fig, ax, solver)
        plot_omitted_data(fig, ax, color="k")
        plot_traj(fig, ax, solver.state_guesses, color=color_init, label="Init traj")
        if traj_opt is not None:
            plot_traj(fig, ax, traj_opt, color=color_opt, label="Opt traj")
    axs[0].legend()
    return fig, axs
def plot_svgp_and_all_trajs(gp, solver, traj_opts=None, labels=None):
    """Plot SVGP mean/variance and save one PDF per trajectory added.

    The GP is evaluated on a 961-point grid built from ``gp.X``. On every axis
    the following are drawn:

    * the start/end positions;
    * the omitted data;
    * the initial trajectory.

    The figure is then saved as ``dir_name + "svgp_2d_traj_init.pdf"``. Each
    optimised trajectory is then added to all axes, and the figure is saved as
    ``dir_name + "svgp_2d_traj_<i>.pdf"``.

    ``dir_name``, ``color_init`` and ``color_opts`` are names defined outside
    this function.

    Args:
        gp: SVGP-like object exposing ``X``, ``Z``, ``kernel``, ``mean_func``,
            ``q_mu`` and ``q_sqrt``.
        solver: Object exposing ``state_guesses`` plus whatever
            ``plot_start_and_end_pos`` reads.
        traj_opts: Optional iterable of optimised trajectories.
        labels: Optional legend labels matching ``traj_opts``. Defaults to
            empty strings; previously ``None`` raised ``TypeError`` in ``zip``.

    Returns:
        ``(fig, axs)`` as returned by ``plot_mean_and_var``.
    """
    import itertools

    plt.rcParams.update(
        {
            "text.usetex": True,
            # Newer Matplotlib releases only accept a single string here
            # (lists were deprecated in 3.3), so join the package lines.
            "text.latex.preamble": "\n".join(
                [r"\usepackage{amssymb}", r"\usepackage{amsmath}"]
            ),
        }
    )
    Xnew, xx, yy = create_grid(gp.X, N=961)
    mu, var = gp_predict(
        Xnew,
        gp.Z,
        kernels=gp.kernel,
        mean_funcs=gp.mean_func,
        f=gp.q_mu,
        q_sqrt=gp.q_sqrt,
        full_cov=False,
    )

    fig, axs = plot_mean_and_var(
        xx,
        yy,
        mu,
        var,
        llabel=r"$\mathbb{E}[h^{(1)}]$",
        rlabel=r"$\mathbb{V}[h^{(1)}]$",
    )
    for ax in axs:
        fig, ax = plot_start_and_end_pos(fig, ax, solver)
        plot_omitted_data(fig, ax, color="k")
        ax.set_xlabel("$x$")
        ax.set_ylabel("$y$")
        plot_traj(fig, ax, solver.state_guesses, color=color_init, label="Init traj")
    # Save this specific figure rather than whatever pyplot considers current.
    fig.savefig(
        dir_name + "svgp_2d_traj_init.pdf",
        transparent=True,
        bbox_inches="tight",
        pad_inches=0,
    )

    if traj_opts is not None:
        if labels is None:
            labels = itertools.repeat("")
        for i, (traj, label, traj_color) in enumerate(
            zip(traj_opts, labels, color_opts)
        ):
            for ax in axs:
                plot_traj(fig, ax, traj, color=traj_color, label=label)
            fig.savefig(
                dir_name + "svgp_2d_traj_" + str(i) + ".pdf",
                transparent=True,
                bbox_inches="tight",
                pad_inches=0,
            )
    return fig, axs