예제 #1
0
def plot_svgp_and_start_end(gp, solver, traj_opts=None, labels=("", "")):
    """Plot the SVGP predictive mean/variance with start/end markers and trajectories.

    The GP is evaluated with ``gp_predict`` on a grid that ``create_grid``
    builds from the inducing inputs ``gp.Z``, and the result is drawn with
    ``plot_mean_and_var``. Each returned axis is then decorated with the
    output of ``plot_start_and_end_pos``, the initial trajectory
    ``solver.state_guesses`` and, if given, the optimised trajectories.

    Side effects: updates the global ``plt.rcParams`` to render text with
    LaTeX and prints the shapes of the predicted mean and variance.

    Args:
        gp: Object exposing ``Z``, ``kernel``, ``mean_func``, ``q_mu`` and
            ``q_sqrt``.
        solver: Object exposing ``state_guesses``; also forwarded to
            ``plot_start_and_end_pos``.
        traj_opts: ``None``, a single trajectory, or a ``list`` of
            trajectories. A list is zipped with ``labels`` and the
            module-level ``color_opts``, so only as many trajectories are
            drawn as the shortest of the three. Anything that is not a
            ``list`` is passed to ``plot_traj`` as one trajectory.
        labels: Legend labels paired with ``traj_opts`` when it is a list.

    Returns:
        The ``(fig, axs)`` pair; ``fig`` is the figure returned by the last
        ``plot_start_and_end_pos`` call.
    """
    plt.rcParams.update({
        "text.usetex": True,
        # A single newline-joined string: list values for this rcParam were
        # deprecated in Matplotlib 3.3 and removed afterwards.
        "text.latex.preamble": "\n".join([
            r"\usepackage{amssymb}",
            r"\usepackage{amsmath}",
        ]),
    })

    # The grid is spanned by the inducing inputs rather than the training data.
    Xnew, xx, yy = create_grid(gp.Z, N=961)
    mu, var = gp_predict(
        Xnew,
        gp.Z,
        kernels=gp.kernel,
        mean_funcs=gp.mean_func,
        f=gp.q_mu,
        q_sqrt=gp.q_sqrt,
        full_cov=False,
    )
    print("mu var")
    print(mu.shape)
    print(var.shape)
    fig, axs = plot_mean_and_var(
        xx,
        yy,
        mu,
        var,
        # Raw strings keep the LaTeX backslashes without invalid-escape warnings.
        llabel=r"$\mathbb{E}[h^{(1)}(\mathbf{x})]$",
        rlabel=r"$\mathbb{V}[h^{(1)}(\mathbf{x})]$",
    )

    for ax in axs:
        fig, ax = plot_start_and_end_pos(fig, ax, solver)
        ax.set_xlabel("$x$")
        ax.set_ylabel("$y$")
        plot_traj(fig,
                  ax,
                  solver.state_guesses,
                  color=color_init,
                  label="Init traj")
        if traj_opts is None:
            continue
        if isinstance(traj_opts, list):
            for traj, label, color_opt in zip(traj_opts, labels, color_opts):
                plot_traj(fig, ax, traj, color=color_opt, label=label)
        else:
            # Single trajectory: rely on plot_traj's own colour/label defaults.
            plot_traj(fig, ax, traj_opts)
    axs[0].legend(loc="lower left")
    return fig, axs
예제 #2
0
def plot_svgp_jacobian_var(gp, solver, traj_opt=None):
    """Contour-plot every ``[i, j]`` entry of the GP Jacobian variance.

    ``gp_jacobian`` is evaluated point by point (via ``jax.vmap``) on a grid
    that ``create_grid`` builds from ``gp.X``. The second value it returns is
    indexed as ``var_j[:, i, j]`` and drawn with ``plot_contourf`` on an
    ``input_dim x input_dim`` grid of subplots, where ``input_dim`` is
    ``gp.X.shape[1]``.

    Side effects: updates the global ``plt.rcParams`` to render text with
    LaTeX and prints the shapes of the Jacobian mean and variance.

    Args:
        gp: Object exposing ``X``, ``Z``, ``kernel``, ``mean_func``, ``q_mu``
            and ``q_sqrt``.
        solver: Not used by this function.
        traj_opt: Not used by this function.

    Returns:
        The ``(fig, axs)`` pair created by ``plt.subplots``.
    """
    plt.rcParams.update({
        "text.usetex": True,
        # A single newline-joined string: list values for this rcParam were
        # deprecated in Matplotlib 3.3 and removed afterwards.
        "text.latex.preamble": "\n".join([
            r"\usepackage{amssymb}",
            r"\usepackage{amsmath}",
        ]),
    })
    input_dim = gp.X.shape[1]

    Xnew, xx, yy = create_grid(gp.X, N=961)

    def gp_jacobian_all(x):
        # vmap hands over one grid point at a time; gp_jacobian is called
        # with a 2-D input, so promote a 1-D point to a single-row matrix.
        if len(x.shape) == 1:
            x = x.reshape(1, -1)
        return gp_jacobian(
            x,
            gp.Z,
            gp.kernel,
            gp.mean_func,
            f=gp.q_mu,
            q_sqrt=gp.q_sqrt,
            full_cov=False,
        )

    mu_j, var_j = jax.vmap(gp_jacobian_all, in_axes=0)(Xnew)
    print("gp jacobain mu var")
    print(mu_j.shape)
    print(var_j.shape)

    fig, axs = plt.subplots(input_dim, input_dim)
    for i in range(input_dim):
        for j in range(input_dim):
            cell = axs[i, j]
            plot_contourf(
                fig,
                cell,
                xx,
                yy,
                var_j[:, i, j],
                # Raw string keeps the LaTeX backslashes intact.
                label=r"$\Sigma_J(\mathbf{x})$",
            )
            cell.set_xlabel("$x$")
            cell.set_ylabel("$y$")

    return fig, axs
예제 #3
0
def plot_mixing_prob_all_trajs(gp, solver, traj_opts=None, labels=None):
    """Plot the mixing probability and save one PDF per added trajectory.

    ``single_mogpe_mixing_probability`` is mapped (``jax.vmap`` over the
    first axis of the grid only) across a grid that ``create_grid`` builds
    from ``gp.X``. Column ``0:1`` of the result is drawn with
    ``plot_contourf``, followed by the start/end markers and the initial
    trajectory ``solver.state_guesses``.

    Files written (paths are the module-level ``dir_name`` plus a suffix):

    * ``mixing_prob_2d_traj_init.pdf`` after the initial trajectory is drawn;
    * ``mixing_prob_2d_traj_<i>.pdf`` after the ``i``-th optimised trajectory
      is added, so each file also contains all earlier trajectories.

    The legend is added only after all files have been written.

    Args:
        gp: Object exposing ``X``, ``Z``, ``kernel``, ``mean_func``, ``q_mu``
            and ``q_sqrt``.
        solver: Object exposing ``state_guesses``; also forwarded to
            ``plot_start_and_end_pos``.
        traj_opts: Optional iterable of trajectories. It is zipped with
            ``labels`` and the module-level ``color_opts``, so only as many
            trajectories are drawn as the shortest of the three.
        labels: Iterable of legend labels. Must be provided whenever
            ``traj_opts`` is given (``zip`` raises ``TypeError`` on ``None``).

    Returns:
        The ``(fig, ax)`` pair created by ``plt.subplots``.
    """
    Xnew, xx, yy = create_grid(gp.X, N=961)

    mixing_probs = jax.vmap(
        single_mogpe_mixing_probability,
        (0, None, None, None, None, None, None),
    )(Xnew, gp.Z, gp.kernel, gp.mean_func, gp.q_mu, False, gp.q_sqrt)
    fig, ax = plt.subplots(1, 1)
    plot_contourf(
        fig,
        ax,
        xx,
        yy,
        mixing_probs[:, 0:1],
        # Raw string keeps the LaTeX backslashes intact.
        label=r"$\Pr(\alpha=1 | \mathbf{x})$",
    )
    ax.set_xlabel("$x$")
    ax.set_ylabel("$y$")

    plot_start_and_end_pos(fig, ax, solver)

    plot_traj(fig,
              ax,
              solver.state_guesses,
              color=color_init,
              label="Init traj")

    save_name = dir_name + "mixing_prob_2d_traj_init.pdf"
    plt.savefig(save_name, transparent=True, bbox_inches="tight", pad_inches=0)

    if traj_opts is not None:
        trajectories = zip(traj_opts, labels, color_opts)
        for i, (traj, label, color) in enumerate(trajectories):
            plot_traj(fig, ax, traj, color=color, label=label)
            save_name = dir_name + "mixing_prob_2d_traj_" + str(i) + ".pdf"
            plt.savefig(save_name,
                        transparent=True,
                        bbox_inches="tight",
                        pad_inches=0)

    ax.legend(loc=3)

    return fig, ax
예제 #4
0
def plot_svgp_metric_and_start_end(metric, solver, traj_opt=None):
    """Contour-plot every ``[i, j]`` entry of the GP metric tensor.

    ``gp_metric_tensor`` is evaluated on a grid that ``create_grid`` builds
    from ``metric.gp.X``. Its first return value is indexed as
    ``metric_tensor[:, i, j]`` and drawn with ``plot_contourf`` on an
    ``input_dim x input_dim`` grid of subplots, where ``input_dim`` is
    ``metric.gp.X.shape[1]``.

    Side effects: updates the global ``plt.rcParams`` to render text with
    LaTeX and prints the shape of the metric tensor.

    Args:
        metric: Object exposing ``cov_weight`` and a ``gp`` attribute with
            ``X``, ``Z``, ``kernel``, ``mean_func``, ``q_mu`` and ``q_sqrt``.
        solver: Not used by this function.
        traj_opt: Not used by this function.

    Returns:
        The ``(fig, axs)`` pair created by ``plt.subplots``.
    """
    plt.rcParams.update({
        "text.usetex": True,
        # A single newline-joined string: list values for this rcParam were
        # deprecated in Matplotlib 3.3 and removed afterwards.
        "text.latex.preamble": "\n".join([
            r"\usepackage{amssymb}",
            r"\usepackage{amsmath}",
        ]),
    })

    gp = metric.gp
    input_dim = gp.X.shape[1]
    Xnew, xx, yy = create_grid(gp.X, N=961)
    # Only the tensor itself is plotted; the other two return values are
    # not needed here.
    metric_tensor, _, _ = gp_metric_tensor(
        Xnew,
        gp.Z,
        gp.kernel,
        mean_func=gp.mean_func,
        f=gp.q_mu,
        full_cov=True,
        q_sqrt=gp.q_sqrt,
        cov_weight=metric.cov_weight,
    )
    print("metric yo yo")
    print(metric_tensor.shape)

    fig, axs = plt.subplots(input_dim, input_dim)
    for i in range(input_dim):
        for j in range(input_dim):
            cell = axs[i, j]
            plot_contourf(
                fig,
                cell,
                xx,
                yy,
                metric_tensor[:, i, j],
                # Raw string keeps the LaTeX backslashes intact.
                label=r"$G(\mathbf{x})$",
            )
            cell.set_xlabel("$x$")
            cell.set_ylabel("$y$")

    return fig, axs
예제 #5
0
def plot_mixing_prob_and_start_end(gp, solver, traj_opts=None, labels=None):
    """Plot the mixing probability with omitted data, start/end and trajectories.

    ``single_mogpe_mixing_probability`` is mapped (``jax.vmap`` over the
    first axis of the grid only) across a grid that ``create_grid`` builds
    from ``gp.X``. Column ``0:1`` of the result is drawn with
    ``plot_contourf``; ``plot_omitted_data``, ``plot_start_and_end_pos``, the
    initial trajectory ``solver.state_guesses`` and any optimised
    trajectories are drawn on top.

    Args:
        gp: Object exposing ``X``, ``Z``, ``kernel``, ``mean_func``, ``q_mu``
            and ``q_sqrt``.
        solver: Object exposing ``state_guesses``; also forwarded to
            ``plot_start_and_end_pos``.
        traj_opts: Optional iterable of trajectories. It is zipped with
            ``labels`` and the module-level ``color_opts``, so only as many
            trajectories are drawn as the shortest of the three.
        labels: Iterable of legend labels. Must be provided whenever
            ``traj_opts`` is given (``zip`` raises ``TypeError`` on ``None``).

    Returns:
        The ``(fig, ax)`` pair created by ``plt.subplots``.
    """
    Xnew, xx, yy = create_grid(gp.X, N=961)

    mixing_probs = jax.vmap(
        single_mogpe_mixing_probability,
        (0, None, None, None, None, None, None),
    )(Xnew, gp.Z, gp.kernel, gp.mean_func, gp.q_mu, False, gp.q_sqrt)

    fig, ax = plt.subplots(1, 1)
    plot_contourf(
        fig,
        ax,
        xx,
        yy,
        mixing_probs[:, 0:1],
        # Raw string keeps the LaTeX backslashes intact.
        label=r"$\Pr(\alpha=1 | \mathbf{x})$",
    )
    ax.set_xlabel("$x$")
    ax.set_ylabel("$y$")

    plot_omitted_data(fig, ax, color="k")
    plot_start_and_end_pos(fig, ax, solver)

    plot_traj(fig,
              ax,
              solver.state_guesses,
              color=color_init,
              label="Init traj")
    if traj_opts is not None:
        for traj, label, color in zip(traj_opts, labels, color_opts):
            plot_traj(fig, ax, traj, color=color, label=label)
    ax.legend(loc=3)

    return fig, ax
예제 #6
0
def plot_svgp_metric_trace_and_start_end(metrics, solver, traj_opts, labels,
                                         linestyles):
    """Plot the normalised metric-tensor trace along initial and optimised trajectories.

    For every ``(traj_opt, metric, color, label, linestyle)`` tuple obtained
    by zipping ``traj_opts``, ``metrics``, the module-level ``color_opts``,
    ``labels`` and ``linestyles``, the metric tensor is evaluated with
    ``gp_metric_tensor`` on columns ``0:2`` of both ``solver.state_guesses``
    and ``traj_opt``. The traces of both are min-max normalised using the
    joint minimum and maximum of the pair and plotted against
    ``solver.times``.

    Side effects: prints the joint maximum and minimum for every pair.

    NOTE(review): if both traces are constant and equal, the normalisation
    divides by zero; this is left as in the original computation.

    Args:
        metrics: Iterable of objects exposing ``cov_weight`` and a ``gp``
            attribute with ``Z``, ``kernel``, ``mean_func``, ``q_mu`` and
            ``q_sqrt``.
        solver: Object exposing ``state_guesses`` and ``times``.
        traj_opts: Iterable of optimised trajectories, one per metric.
        labels: Iterable of strings appended to the legend entries.
        linestyles: Iterable of Matplotlib line styles, one per metric.

    Returns:
        The ``(fig, ax)`` pair created by ``plt.subplots``.
    """

    def trace_along(states, metric):
        # Trace of the metric tensor at each row of ``states`` (first two
        # columns only).
        tensor, _, _ = gp_metric_tensor(
            states[:, 0:2],
            metric.gp.Z,
            metric.gp.kernel,
            mean_func=metric.gp.mean_func,
            f=metric.gp.q_mu,
            full_cov=True,
            q_sqrt=metric.gp.q_sqrt,
            cov_weight=metric.cov_weight,
        )
        return np.trace(tensor, axis1=1, axis2=2)

    fig, ax = plt.subplots(1, 1, figsize=(6.4, 2.8))

    ax.set_xlabel("Time $t$")
    # Raw strings keep the LaTeX backslashes intact.
    ax.set_ylabel(r"Tr$(\mathbf{G}(\mathbf{x}(t)))$")

    for traj_opt, metric, color, label, linestyle in zip(
            traj_opts, metrics, color_opts, labels, linestyles):
        metric_trace_init = trace_along(solver.state_guesses, metric)
        metric_trace_opt = trace_along(traj_opt, metric)

        # Normalise both curves with the same bounds so they stay comparable.
        traces = np.stack([metric_trace_opt, metric_trace_init])
        max_trace = np.max(traces)
        min_trace = np.min(traces)
        print("max min")
        print(max_trace)
        print(min_trace)
        span = max_trace - min_trace
        ax.plot(
            solver.times,
            (metric_trace_init - min_trace) / span,
            color=color_init,
            linestyle=linestyle,
            label=r"Init traj $\lambda=$" + label,
        )
        ax.plot(
            solver.times,
            (metric_trace_opt - min_trace) / span,
            color=color,
            linestyle=linestyle,
            label=r"Opt traj $\lambda=$" + label,
        )

    ax.legend()
    return fig, ax
예제 #7
0
def plot_svgp_jacobian_mean(gp, solver, traj_opt=None):
    """Plot the SVGP mean/variance with the Jacobian mean overlaid as arrows.

    ``gp_predict`` and ``gp_jacobian`` are both evaluated on a grid that
    ``create_grid`` builds from ``gp.X``. The predictive mean and variance
    are drawn with ``plot_mean_and_var``; on every returned axis the columns
    ``mu_j[:, 0]`` and ``mu_j[:, 1]`` of the Jacobian mean are drawn as a
    quiver field, followed by the start/end markers, the omitted data, the
    initial trajectory ``solver.state_guesses`` and, if given, ``traj_opt``.

    Side effects: updates the global ``plt.rcParams`` to render text with
    LaTeX and prints the shapes of the Jacobian mean and variance.

    Args:
        gp: Object exposing ``X``, ``Z``, ``kernel``, ``mean_func``, ``q_mu``
            and ``q_sqrt``.
        solver: Object exposing ``state_guesses``; also forwarded to
            ``plot_start_and_end_pos``.
        traj_opt: Optional single optimised trajectory.

    Returns:
        The ``(fig, axs)`` pair; ``fig`` is the figure returned by the last
        ``plot_start_and_end_pos`` call.
    """
    plt.rcParams.update({
        "text.usetex": True,
        # A single newline-joined string: list values for this rcParam were
        # deprecated in Matplotlib 3.3 and removed afterwards.
        "text.latex.preamble": "\n".join([
            r"\usepackage{amssymb}",
            r"\usepackage{amsmath}",
        ]),
    })

    Xnew, xx, yy = create_grid(gp.X, N=961)
    mu, var = gp_predict(
        Xnew,
        gp.Z,
        kernels=gp.kernel,
        mean_funcs=gp.mean_func,
        f=gp.q_mu,
        q_sqrt=gp.q_sqrt,
        full_cov=False,
    )

    def gp_jacobian_all(x):
        # vmap hands over one grid point at a time; gp_jacobian is called
        # with a 2-D input, so promote a 1-D point to a single-row matrix.
        if len(x.shape) == 1:
            x = x.reshape(1, -1)
        return gp_jacobian(
            x,
            gp.Z,
            gp.kernel,
            gp.mean_func,
            f=gp.q_mu,
            q_sqrt=gp.q_sqrt,
            full_cov=False,
        )

    mu_j, var_j = jax.vmap(gp_jacobian_all, in_axes=0)(Xnew)
    print("gp jacobain mu var")
    print(mu_j.shape)
    print(var_j.shape)

    fig, axs = plot_mean_and_var(
        xx,
        yy,
        mu,
        var,
        # Raw strings keep the LaTeX backslashes intact.
        llabel=r"$\mathbb{E}[h^{(1)}]$",
        rlabel=r"$\mathbb{V}[h^{(1)}]$",
    )

    for ax in axs:
        ax.quiver(Xnew[:, 0], Xnew[:, 1], mu_j[:, 0], mu_j[:, 1], color="k")

        fig, ax = plot_start_and_end_pos(fig, ax, solver)
        plot_omitted_data(fig, ax, color="k")
        plot_traj(fig,
                  ax,
                  solver.state_guesses,
                  color=color_init,
                  label="Init traj")
        if traj_opt is not None:
            # NOTE(review): this reads a module-level ``color_opt`` (singular);
            # sibling functions use ``color_opts``. Confirm that the singular
            # name is actually defined at module level.
            plot_traj(fig, ax, traj_opt, color=color_opt, label="Opt traj")
    axs[0].legend()
    return fig, axs
예제 #8
0
def plot_svgp_and_all_trajs(gp, solver, traj_opts=None, labels=None):
    """Plot the SVGP mean/variance and save one PDF per added trajectory.

    ``gp_predict`` is evaluated on a grid that ``create_grid`` builds from
    ``gp.X`` and drawn with ``plot_mean_and_var``. Every returned axis is
    decorated with the start/end markers, the omitted data and the initial
    trajectory ``solver.state_guesses``.

    Files written (paths are the module-level ``dir_name`` plus a suffix):

    * ``svgp_2d_traj_init.pdf`` once all axes show the initial trajectory;
    * ``svgp_2d_traj_<i>.pdf`` after the ``i``-th optimised trajectory has
      been added to all axes, so each file also contains all earlier
      trajectories.

    Side effects: updates the global ``plt.rcParams`` to render text with
    LaTeX and prints the shapes of the predicted mean and variance.

    Args:
        gp: Object exposing ``X``, ``Z``, ``kernel``, ``mean_func``, ``q_mu``
            and ``q_sqrt``.
        solver: Object exposing ``state_guesses``; also forwarded to
            ``plot_start_and_end_pos``.
        traj_opts: Optional iterable of trajectories. It is zipped with
            ``labels`` and the module-level ``color_opts``, so only as many
            trajectories are drawn as the shortest of the three.
        labels: Iterable of legend labels. Must be provided whenever
            ``traj_opts`` is given (``zip`` raises ``TypeError`` on ``None``).

    Returns:
        The ``(fig, axs)`` pair; ``fig`` is the figure returned by the last
        ``plot_start_and_end_pos`` call.
    """
    plt.rcParams.update({
        "text.usetex": True,
        # A single newline-joined string: list values for this rcParam were
        # deprecated in Matplotlib 3.3 and removed afterwards.
        "text.latex.preamble": "\n".join([
            r"\usepackage{amssymb}",
            r"\usepackage{amsmath}",
        ]),
    })

    Xnew, xx, yy = create_grid(gp.X, N=961)
    mu, var = gp_predict(
        Xnew,
        gp.Z,
        kernels=gp.kernel,
        mean_funcs=gp.mean_func,
        f=gp.q_mu,
        q_sqrt=gp.q_sqrt,
        full_cov=False,
    )
    print("mu var")
    print(mu.shape)
    print(var.shape)
    fig, axs = plot_mean_and_var(
        xx,
        yy,
        mu,
        var,
        # Raw strings keep the LaTeX backslashes intact.
        llabel=r"$\mathbb{E}[h^{(1)}]$",
        rlabel=r"$\mathbb{V}[h^{(1)}]$",
    )

    for ax in axs:
        fig, ax = plot_start_and_end_pos(fig, ax, solver)
        plot_omitted_data(fig, ax, color="k")
        ax.set_xlabel("$x$")
        ax.set_ylabel("$y$")
        plot_traj(fig,
                  ax,
                  solver.state_guesses,
                  color=color_init,
                  label="Init traj")

    # Written once after every axis is decorated; saving inside the loop
    # only produced intermediate files that the last iteration overwrote.
    save_name = dir_name + "svgp_2d_traj_init.pdf"
    plt.savefig(save_name,
                transparent=True,
                bbox_inches="tight",
                pad_inches=0)

    if traj_opts is not None:
        trajectories = zip(traj_opts, labels, color_opts)
        for i, (traj, label, color_opt) in enumerate(trajectories):
            for ax in axs:
                plot_traj(fig, ax, traj, color=color_opt, label=label)
            save_name = dir_name + "svgp_2d_traj_" + str(i) + ".pdf"
            plt.savefig(save_name,
                        transparent=True,
                        bbox_inches="tight",
                        pad_inches=0)

    return fig, axs