def contour_example(  # noqa: C901
    A=np.array([[1, 1]]),
    b=np.zeros([1, 1]),
    cov_11=0.5,
    cov_01=-0.25,
    initial_mean=np.array([0.25, 0.25]),
    alpha=1,
    omega=1,
    obs_std=1,
    show_full=True,
    show_data=True,
    show_est=False,
    param_ref=None,
    compare=False,
    fsize=42,
    figname="latest_figure.png",
    save=False,
):
    """
    Plot contours of a weighted data-mismatch functional for a 2D linear problem.

    The plotted functional is ``alpha * zi + zd - omega * zp``, where ``zi``,
    ``zd`` and ``zp`` are the values of ``norm_input``, ``norm_data`` and
    ``norm_predicted`` evaluated on a 250x250 mesh of the unit square.
    Depending on the flags, the data-mismatch contours, the solution contour
    (null space of ``A`` through the least-squares point), the initial mean,
    least-squares, MAP and MUD points and the projection line are overlaid.

    A, b: linear map and offset. The mesh is 2D, so A should have two columns;
        the null-space lines are only drawn when A has fewer rows than columns.
    cov_11, cov_01: entries of the initial covariance
        ``[[1, cov_01], [cov_01, cov_11]]``; it must be positive definite
        (checked with an ``assert``).
    initial_mean: mean of the initial density (length-2 array).
    alpha: float in [0, 1], weight of Tikhonov regularization
    omega: float in [0, 1], weight of Modified regularization
    obs_std: observation standard deviation, placed on the diagonal of the
        observation covariance passed to the norm helpers.
    show_full: draw the full functional and the solution markers.
    show_data: draw the data-mismatch contours and the solution contour.
    show_est: additionally mark the numerical (mesh) minimizers.
    param_ref: reference parameter; indexed when ``compare`` is True and the
        full plot is drawn, so it must be given in that case.
    compare: draw ``param_ref`` and the MUD point next to the MAP point.
    fsize: base font size.
    figname: output path used when ``save`` is True; its directory is created.
    save: save the figure at 300 dpi and close all figures.
    """
    # mesh for plotting
    N, r = 250, 1
    _, _, XX = make_2d_unit_mesh(N, r)
    inputs = XX

    std_of_data = [obs_std]
    # NOTE(review): the diagonal holds the standard deviation, not its square;
    # confirm whether the norm_* helpers expect a variance here.
    obs_cov = np.diag(std_of_data)
    observed_data_mean = np.array([[1]])
    initial_cov = np.array([[1, cov_01], [cov_01, cov_11]])

    # the initial covariance must be positive definite
    assert np.all(np.linalg.eigvals(initial_cov) > 0)

    z = full_functional(
        A, XX, b, initial_mean, initial_cov, observed_data_mean, observed_cov=obs_cov
    )
    zp = norm_predicted(A, XX, initial_mean, initial_cov)
    zi = norm_input(XX, initial_mean, initial_cov)
    zd = norm_data(A, XX, b, observed_data_mean, observed_cov=obs_cov)
    # sanity check that all arguments passed correctly:
    # the unweighted full functional equals zi + zd - zp
    assert np.linalg.norm(z - zi - zd + zp) < 1e-8

    # plotting contours: weighted functional
    z = alpha * zi + zd - omega * zp
    mud_a = np.argmin(z)
    map_a = np.argmin(alpha * zi + zd)

    # get mud/map points from minimal values on mesh
    mud_pt = inputs[mud_a, :]
    map_pt = inputs[map_a, :]

    msize = 500
    # minimum-norm least-squares solution of A x = observed data
    ls = (np.linalg.pinv(A) @ observed_data_mean.T).ravel()
    if show_data:
        plt.contour(
            inputs[:, 0].reshape(N, N),
            inputs[:, 1].reshape(N, N),
            (zd).reshape(N, N),
            25,
            cmap=cm.viridis,
            alpha=0.5,
            vmin=0,
            vmax=4,
        )
        plt.axis("equal")

        s = np.linspace(-2 * r, 2 * r, 10)

        if A.shape[0] < A.shape[1]:
            # nullspace through least-squares
            null_line = null_space(A) * s + ls.reshape(-1, 1)
            plt.plot(
                null_line[0, :],
                null_line[1, :],
                label="Solution Contour",
                lw=2,
                color="xkcd:red",
            )
            if not show_full:
                plt.annotate(
                    "Solution Contour", (0.1, 0.9), fontsize=fsize, backgroundcolor="w"
                )

    if show_full:
        plt.contour(
            inputs[:, 0].reshape(N, N),
            inputs[:, 1].reshape(N, N),
            z.reshape(N, N),
            50,
            cmap=cm.viridis,
            alpha=1.0,
        )
    elif alpha + omega > 0:
        # regularization terms only
        plt.contour(
            inputs[:, 0].reshape(N, N),
            inputs[:, 1].reshape(N, N),
            (alpha * zi - omega * zp).reshape(N, N),
            100,
            cmap=cm.viridis,
            alpha=0.25,
        )
    plt.axis("equal")

    if alpha + omega > 0:
        plt.scatter(
            initial_mean[0], initial_mean[1], label="Initial Mean", color="k", s=msize
        )
        if not show_full:
            plt.annotate(
                "Initial Mean",
                (initial_mean[0] + 0.001 * fsize, initial_mean[1] - 0.001 * fsize),
                fontsize=fsize,
                backgroundcolor="w",
            )
        else:
            if compare:
                plt.scatter(
                    param_ref[0],
                    param_ref[1],
                    label="$\\lambda^\\dagger$",
                    color="k",
                    s=msize,
                    marker="s",
                )

                plt.annotate(
                    "Truth",
                    (param_ref[0] + 0.00075 * fsize, param_ref[1] + 0.00075 * fsize),
                    fontsize=fsize,
                    backgroundcolor="w",
                )

        show_mud = omega > 0 or compare

        if show_full:
            # scatter and line from origin to least squares
            plt.scatter(
                ls[0],
                ls[1],
                label="Least Squares",
                color="xkcd:blue",
                marker="d",
                s=msize,
                zorder=10,
            )
            plt.plot(
                [0, ls[0]], [0, ls[1]], color="xkcd:blue", marker="d", lw=1, zorder=10
            )
            plt.annotate(
                "Least Squares",
                (ls[0] - 0.001 * fsize, ls[1] + 0.001 * fsize),
                fontsize=fsize,
                backgroundcolor="w",
            )

            if show_est:  # numerical solutions
                if omega > 0:
                    plt.scatter(
                        mud_pt[0],
                        mud_pt[1],
                        label="min: Tk - Un",
                        color="xkcd:sky blue",
                        marker="o",
                        s=3 * msize,
                        zorder=10,
                    )
                if alpha > 0 and omega != 1:
                    plt.scatter(
                        map_pt[0],
                        map_pt[1],
                        label="min: Tk",
                        color="xkcd:blue",
                        marker="o",
                        s=3 * msize,
                        zorder=10,
                    )

            if alpha > 0 and omega != 1:  # analytical MAP point
                map_pt_eq = map_sol(
                    A,
                    b,
                    observed_data_mean,
                    initial_mean,
                    initial_cov,
                    data_cov=obs_cov,
                    w=alpha,
                )
                plt.scatter(
                    map_pt_eq[0],
                    map_pt_eq[1],
                    label="MAP",
                    color="xkcd:orange",
                    marker="x",
                    s=msize,
                    lw=10,
                    zorder=10,
                )

                if compare:  # second map point has half the regularization strength
                    plt.annotate(
                        "MAP$_{\\alpha}$",
                        (map_pt_eq[0] - 0.004 * fsize, map_pt_eq[1] - 0.002 * fsize),
                        fontsize=fsize,
                        backgroundcolor="w",
                    )

                else:
                    plt.annotate(
                        "MAP$_{\\alpha}$",
                        (map_pt_eq[0] + 0.0001 * fsize, map_pt_eq[1] - 0.002 * fsize),
                        fontsize=fsize,
                        backgroundcolor="w",
                    )

            if show_mud:  # analytical MUD point
                mud_pt_eq = mud_sol(A, b, observed_data_mean, initial_mean, initial_cov)
                plt.scatter(
                    mud_pt_eq[0],
                    mud_pt_eq[1],
                    label="MUD",
                    color="xkcd:brown",
                    marker="*",
                    s=2 * msize,
                    lw=5,
                    zorder=10,
                )
                plt.annotate(
                    "MUD",
                    (mud_pt_eq[0] + 0.001 * fsize, mud_pt_eq[1] - 0.001 * fsize),
                    fontsize=fsize,
                    backgroundcolor="w",
                )

        if A.shape[0] < A.shape[1]:
            # want orthogonal nullspace, function gives one that is already normalized
            v = null_space(A @ initial_cov)
            # in 2D, we can just swap entries and put a negative sign in front of one
            v = v[::-1]
            v[0] = -v[0]

            if show_full and show_mud:
                # grid search to find upper/lower bounds of line being drawn.
                # importance is the direction, image is nicer with a proper origin/termination
                s = np.linspace(-1, 1, 1000)
                new_line = (v.reshape(-1, 1) * s) + initial_mean.reshape(-1, 1)
                mx = np.argmin(
                    np.linalg.norm(new_line - initial_mean.reshape(-1, 1), axis=0)
                )
                mn = np.argmin(
                    np.linalg.norm(new_line - mud_pt_eq.reshape(-1, 1), axis=0)
                )
                # null_space() fixes v only up to sign, so the MUD point can lie on
                # either side of the initial mean; order the indices so the slice is
                # never empty, and include its end point.
                lo, hi = sorted((mn, mx))
                plt.plot(
                    new_line[0, lo : hi + 1],
                    new_line[1, lo : hi + 1],
                    lw=1,
                    label="projection line",
                    c="k",
                )
        elif show_full:
            plt.plot(
                [initial_mean[0], ls[0]],
                [initial_mean[1], ls[1]],
                lw=1,
                label="Projection Line",
                c="k",
            )

    plt.axis("square")
    plt.axis([0, r, 0, r])
    plt.xticks(fontsize=0.75 * fsize)
    plt.yticks(fontsize=0.75 * fsize)
    plt.tight_layout()
    if save:
        if "/" in figname:
            fdir = "/".join(figname.split("/")[:-1])
            check_dir(fdir)
        plt.savefig(figname, dpi=300)
        plt.close("all")
def main_meas(args):
    """
    Main entrypoint for High-Dim Linear Measurement Example

    Tracks the singular values of the updated covariance (``updated_cov``) as
    the number of repeated measurements N grows and saves two figures in
    ``figures/lin``:

    * ``lin-meas-cov-convergence.png`` -- each singular value versus N (log-log);
    * ``lin-meas-cov-sd-convergence.png`` -- the whole spectrum, one curve per N.

    ``args`` is parsed with ``parse_args``; ``loglevel`` and ``seed`` are read.
    """
    args = parse_args(args)
    setup_logging(args.loglevel)
    np.random.seed(args.seed)

    presentation = False
    save = True

    if not presentation:
        plt.rcParams["mathtext.fontset"] = "stix"
        plt.rcParams["font.family"] = "STIXGeneral"
    fdir = "figures/lin"
    check_dir(fdir)

    fsize = 42

    # # Impact of Number of Measurements for Various Choices of $\\Sigma_\text{init}$
    dim_input, dim_output = 20, 5

    # Decreasing diagonal with entries in [0.5, 1.5). A scaled identity here
    # would give the updated covariance repeated eigenvalues.
    initial_cov = np.diag(np.sort(np.random.rand(dim_input))[::-1] + 0.5)

    plt.figure(figsize=(10, 10))

    lam_ref = np.random.randn(dim_input).reshape(-1, 1)

    prefix = "lin-meas-cov"
    sigma = 1e-1
    Ns = [10, 100, 1000, 10000]

    operator_list, data_list, _ = models.createRandomLinearProblem(
        lam_ref,
        dim_output,
        [max(Ns)] * dim_output,  # want to iterate over increasing measurements
        [0] * dim_output,  # noiseless data bc we want to simulate multiple trials
        dist="norm",
        repeated=True,
    )

    # UP[:, j] holds the singular values of the updated covariance for Ns[j].
    UP = np.zeros((dim_input, len(Ns)))
    # One noise realization shared by every N (each N presumably uses a prefix
    # of it inside transform_measurements -- confirm).
    noise_draw = np.random.randn(dim_output, max(Ns)) * sigma

    for j, N in enumerate(Ns):
        A, _, _ = transform_measurements(operator_list, data_list, N, sigma, noise_draw)
        # For a (presumably symmetric PSD) covariance the singular values are
        # its eigenvalues, which is what the axis labels below call them.
        UP[:, j] = sp.linalg.svdvals(updated_cov(A, initial_cov))

    lines = ["solid", "dashed", "dashdot", "dotted"]

    for p in range(dim_input):
        plt.plot(
            Ns,
            UP[p, :],
            label=f"SV {p}",
            alpha=0.4,
            lw=5,
            ls=lines[p % len(lines)],
        )

    plt.yscale("log")
    plt.xscale("log")
    plt.ylabel("Eigenvalues of $\\Sigma_{up}$", fontsize=fsize * 1.25)
    plt.xlabel("Number of Measurements", fontsize=fsize)
    # NOTE(review): main_meas_var saves to the same path (same prefix), so the
    # entrypoint that runs last overwrites this figure.
    if save:
        plt.savefig(f"{fdir}/{prefix}-convergence.png", bbox_inches="tight")
    else:
        plt.show()
    plt.close("all")

    _, ax = plt.subplots(figsize=(15, 10))
    ax.set_yscale("log")
    index_values = np.arange(dim_input) + 1

    for i, N in enumerate(Ns):
        ax.scatter(
            index_values,
            UP[:, i],
            marker="o",
            s=200,
            facecolors="none",
            edgecolors="k",
        )

        ax.plot(
            index_values,
            UP[:, i],
            label=f"$N={N:1.0E}$",
            alpha=1,
            lw=3,
            ls=lines[i % len(lines)],
            c="k",
        )
    ax.set_xticks(index_values)
    ax.set_xlabel("Index", fontsize=fsize)
    ax.set_ylabel("Eigenvalue", fontsize=fsize)

    # integer tick labels for the 1-based index
    ax.xaxis.set_major_formatter(FormatStrFormatter("%2d"))
    ax.legend(loc="lower left", fontsize=fsize * 0.75)

    if save:
        plt.savefig(f"{fdir}/{prefix}-sd-convergence.png", bbox_inches="tight")
    else:
        plt.show()
    plt.close("all")
def main_meas_var(args):
    """
    Main entrypoint for High-Dim Linear Measurement Example (variance of MUD estimates)

    For each number of repeated measurements N in ``Ns``, computes the MUD
    estimate for ``num_trials`` independent noise realizations. It then plots
    the variance across trials, averaged over the parameter components,
    against N on log-log axes. The figure is saved to
    ``figures/lin/lin-meas-cov-convergence.png``.

    ``args`` is parsed with ``parse_args``; ``loglevel`` and ``seed`` are read.
    """
    args = parse_args(args)
    setup_logging(args.loglevel)
    np.random.seed(args.seed)

    presentation = False
    save = True

    if not presentation:
        plt.rcParams["mathtext.fontset"] = "stix"
        plt.rcParams["font.family"] = "STIXGeneral"
    fdir = "figures/lin"
    check_dir(fdir)

    fsize = 42

    # # Impact of Number of Measurements for Various Choices of $\\Sigma_\text{init}$
    dim_input, dim_output = 4, 2
    initial_cov = np.eye(dim_input)

    plt.figure(figsize=(10, 10))

    lam_ref = np.random.randn(dim_input).reshape(-1, 1)

    # NOTE(review): main_meas uses the same prefix and therefore writes the
    # same "-convergence.png" file; whichever runs last wins.
    prefix = "lin-meas-cov"

    sigma = 1e-1
    Ns = np.arange(10, 2001, 50).tolist()

    num_trials = 50

    operator_list, data_list, _ = models.createRandomLinearProblem(
        lam_ref,
        dim_output,
        [max(Ns)] * dim_output,  # want to iterate over increasing measurements
        [0] * dim_output,  # noiseless data bc we want to simulate multiple trials
        dist="norm",
        repeated=True,
    )

    # MUD[:, j, i] is the estimate obtained with Ns[j] measurements in trial i.
    MUD = np.zeros((dim_input, len(Ns), num_trials))
    # one independent noise realization per trial, reused for every N
    noise_draw = [
        np.random.randn(dim_output, max(Ns)) * sigma for _ in range(num_trials)
    ]
    for j, N in enumerate(Ns):
        for i, noise in enumerate(noise_draw):
            A, b, _ = transform_measurements(operator_list, data_list, N, sigma, noise)
            MUD[:, j, i] = mud_sol(A, b, cov=initial_cov)

    # variance across trials, then mean over the parameter components
    mud_var = MUD.var(axis=2).mean(axis=0)
    plt.plot(Ns, mud_var, label="MUD", c="k", lw=10)

    plt.yscale("log")
    plt.xscale("log")
    plt.ylabel("Mean Variance of MUD Estimates", fontsize=fsize * 1.25)
    plt.xlabel("Number of Measurements", fontsize=fsize)
    plt.legend()
    if save:
        plt.savefig(f"{fdir}/{prefix}-convergence.png", bbox_inches="tight")
    else:
        plt.show()
    plt.close("all")
def main_contours(args):
    """
    Main entrypoint for 2D Linear Rank-Deficient Example (Contour Plots)

    Renders one figure per entry of the experiment table below with
    ``contour_example`` and saves it under ``figures/contours``. Settings a
    configuration leaves out fall back to the defaults in the loop.
    """
    args = parse_args(args)
    setup_logging(args.loglevel)
    np.random.seed(args.seed)

    presentation, save = False, True
    if not presentation:
        plt.rcParams["mathtext.fontset"] = "stix"
        plt.rcParams["font.family"] = "STIXGeneral"

    fdir = "figures/contours"
    check_dir(fdir)

    lam_true = np.array([0.7, 0.3])
    initial_mean = np.array([0.25, 0.25])
    A = np.array([[1, 1]])
    b = np.zeros((1, 1))

    # Experiments are rendered in insertion order.
    experiments = {
        # data mismatch
        "data_mismatch": {
            "out_file": f"{fdir}/data_mismatch_contour.png",
            "data_check": True,
            "full_check": False,
            "tk_reg": 0,
            "pr_reg": 0,
        },
        # tikonov regularization
        "tikonov": {
            "out_file": f"{fdir}/tikonov_contour.png",
            "tk_reg": 1,
            "pr_reg": 0,
            "data_check": False,
            "full_check": False,
        },
        # modified regularization
        "modified": {
            "out_file": f"{fdir}/consistent_contour.png",
            "tk_reg": 1,
            "pr_reg": 1,
            "data_check": False,
            "full_check": False,
        },
        # map point
        "classical": {
            "out_file": f"{fdir}/classical_solution.png",
            "tk_reg": 1,
            "pr_reg": 0,
            "data_check": True,
            "full_check": True,
        },
        # mud point
        "consistent": {
            "out_file": f"{fdir}/consistent_solution.png",
            "tk_reg": 1,
            "pr_reg": 1,
            "data_check": True,
            "full_check": True,
        },
        # comparison
        "compare": {
            "out_file": f"{fdir}/map_compare_contour.png",
            "data_check": True,
            "full_check": True,
            "tk_reg": 1,
            "pr_reg": 0,
            "comparison": True,
            "cov_01": -0.5,
        },
    }

    for name, settings in experiments.items():
        _logger.info(f"Running {name}")
        contour_example(
            A=A,
            b=b,
            save=save,
            param_ref=lam_true,
            compare=settings.get("comparison", False),
            cov_01=settings.get("cov_01", -0.25),
            cov_11=settings.get("cov_11", 0.5),
            initial_mean=initial_mean,
            alpha=settings.get("tk_reg", 1),
            omega=settings.get("pr_reg", 1),
            show_full=settings.get("full_check", True),
            show_data=settings.get("data_check", True),
            show_est=settings.get("numr_check", False),
            obs_std=settings.get("obs_std", 0.5),
            figname=settings.get("out_file", "latest_figure.png"),
        )
def main_dim(args):
    """
    Main entrypoint for High-Dim Linear Dimension Example

    Plots the relative error ``||sol - lam_ref|| / ||lam_ref||`` of the MUD,
    MAP and least-squares solutions against the dimension of the output space.
    One set of curves is drawn per ``alpha`` in ``alpha_list`` (the initial
    covariance scale, per the plot title). Saves
    ``figures/lin/lin-dim-cov-convergence.png``.

    ``args`` is parsed with ``parse_args``; ``loglevel`` and ``seed`` are read.
    """
    args = parse_args(args)
    setup_logging(args.loglevel)
    np.random.seed(args.seed)

    presentation = False

    if not presentation:
        plt.rcParams["mathtext.fontset"] = "stix"
        plt.rcParams["font.family"] = "STIXGeneral"
    fdir = "figures/lin"
    check_dir(fdir)

    fsize = 42

    # # Impact of Dimension for Various Choices of $\\Sigma_\text{init}$
    dim_input, dim_output = 100, 100
    # NOTE(review): this reseed overrides args.seed, so the figure is always
    # produced from seed 12 regardless of the command-line seed.
    seed = 12
    np.random.seed(seed)

    initial_cov = np.diag(np.sort(np.random.rand(dim_input))[::-1] + 0.5)

    plt.figure(figsize=(10, 10))
    initial_mean = np.zeros(dim_input).reshape(-1, 1)
    randA = models.randA_gauss  # choose which variety of generating map
    A, b = models.randP(dim_input, randA=randA)
    prefix = "lin-dim-cov"
    alpha_list = [10 ** (n) for n in np.linspace(-3, 4, 8)]

    lam_ref = np.random.randn(dim_input).reshape(-1, 1)

    sols = compare_linear_sols_dim(lam_ref, A, b, alpha_list, initial_mean, initial_cov)

    lam_norm = np.linalg.norm(lam_ref)

    def relative_errors(k):
        """Relative error of solution ``k`` of every entry in ``sols[alpha]``, per alpha."""
        return [
            [np.linalg.norm(_m[k] - lam_ref) / lam_norm for _m in sols[alpha]]
            for alpha in alpha_list
        ]

    # index 0 / 1 / 2 of each entry is plotted below as MUD / MAP / least squares
    err_mud_list = relative_errors(0)
    err_map_list = relative_errors(1)
    err_pin_list = relative_errors(2)

    # The last entry of each error list is dropped so the curves match x.
    x = np.arange(1, dim_output, 1)

    # # Convergence Plot

    for idx, alpha in enumerate(alpha_list):
        if (1 + idx) % 2 and alpha <= 10:
            plt.annotate(
                f"$\\alpha$={alpha:1.2E}",
                (100, max(err_map_list[idx][-1], 0.01)),
                fontsize=24,
            )
        _err_mud = err_mud_list[idx]
        _err_map = err_map_list[idx]
        _err_pin = err_pin_list[idx]

        plt.plot(x, _err_mud[:-1], label="MUD", c="k", lw=10)
        plt.plot(x, _err_map[:-1], label="MAP", c="r", ls="--", lw=5)
        plt.plot(x, _err_pin[:-1], label="LSQ", c="xkcd:light blue", ls="-", lw=5)

    if "id" in prefix:
        plt.title(
            "Convergence for Various $\\Sigma_{init} = \\alpha I$",
            fontsize=1.25 * fsize,
        )
    else:
        plt.title(
            "Convergence for Various $\\Sigma_{init} = \\alpha \\Sigma$",
            fontsize=1.25 * fsize,
        )
    plt.ylim(0, 1.0)
    plt.ylabel("Relative Error", fontsize=fsize * 1.25)
    plt.xlabel("Dimension of Output Space", fontsize=fsize)
    # three labels map onto the first MUD/MAP/LSQ handles
    plt.legend(["MUD", "MAP", "Least Squares"], fontsize=fsize)
    plt.savefig(f"{fdir}/{prefix}-convergence.png", bbox_inches="tight")
    plt.close("all")
# Example #6
    def plot(
        self,
        sols=None,
        num_measurements=20,
        example="mud",
        fsize=36,
        ftype="png",
        save=False,
    ):
        """
        Plot the reference function ``g`` with a family of parameter curves.

        Every curve is piecewise linear, pinned to 0 at x=0 and x=1, and takes
        the parameter values at equally spaced interior points. The sample of
        ``self.lam`` whose ``self.qoi`` is nearest to ``self.qoi_ref`` is drawn
        as a dashed green curve. If ``sols`` is None, the first 100 samples of
        ``self.lam`` are overlaid. Otherwise the estimates in
        ``sols[num_measurements]`` are overlaid, and the title names the
        estimator selected by ``example`` ("mud", "mud-alt" or "map").

        The output directory ``figures/<self.fname without .pkl>`` is created.
        When ``save`` is True the figure is written as ``<prefix>.<ftype>`` and
        all figures are closed.

        Raises
        ------
        AttributeError
            If ``sols`` has no (non-None) entry for ``num_measurements``.
        ValueError
            If ``example`` is unsupported (checked only when ``sols`` is given).
        """
        lam, qoi, qoi_ref, g = self.lam, self.qoi, self.qoi_ref, self.g
        out_dir = "figures/" + self.fname.replace(".pkl", "")
        check_dir(out_dir)

        misfit = np.linalg.norm(qoi - np.array(qoi_ref), axis=1)
        g_projected = list(lam[np.argmin(misfit), :])
        plt.figure(figsize=(10, 10))

        g_mesh, g_plot = g
        intervals = list(np.linspace(0, 1, lam.shape[1] + 2)[1:-1])
        knots = [0] + intervals + [1]
        plt.plot(g_mesh, g_plot, lw=5, c="k", label="$g$")
        plt.plot(
            knots,
            [0, *g_projected, 0],
            lw=5,
            c="green",
            alpha=0.6,
            ls="--",
            label="$\\hat{g}$",
            zorder=5,
        )

        if sols is None:  # initial plot, first 100
            prefix = f"{out_dir}/{example}_initial_S{lam.shape[0]}"
            plot_lam = lam[0:100, :]
            plt.title("Samples from Initial Density", fontsize=1.25 * fsize)
        else:
            if sols.get(num_measurements) is None:
                raise AttributeError(
                    f"Solutions `sols` missing requested N={num_measurements}. `sols`={sols!r}"
                )
            prefix = f"{out_dir}/{example}_solutions_N{num_measurements}"
            plot_lam = np.array(sols[num_measurements])
            n_dim = lam.shape[1]
            title_parts = {
                "mud-alt": ("$Q_{%dD}^\\prime$" % n_dim, "MUD"),
                "mud": ("$Q_{%dD}$" % n_dim, "MUD"),
                "map": ("$Q_{1D}$", "MAP"),
            }
            if example not in title_parts:
                raise ValueError("Unsupported example type.")
            qmap, soltype = title_parts[example]
            plt.title(
                f"{soltype} Estimates for {qmap}, $N={num_measurements}$",
                fontsize=1.25 * fsize,
            )

        for sample in plot_lam:
            plt.plot(knots, [0, *sample, 0], lw=1, c="purple", alpha=0.2)

        plt.xlabel("$x_2$", fontsize=fsize)
        plt.ylabel("$g(x, \\lambda)$", fontsize=fsize)
        plt.ylim(-4, 0)
        plt.xlim(0, 1)
        plt.legend()

        if not save:
            return
        out_name = f"{prefix}.{ftype}"
        plt.savefig(out_name, bbox_inches="tight")
        _logger.info(f"Saved {out_name}")
        plt.close("all")
def plot_decay_solution(
    solutions,
    model_generator,
    sigma,
    prefix,
    time_vector,
    lam_true,
    qoi_true,
    end_time=3,
    fsize=32,
    save=True,
):
    """
    Plot the true decay signal, noisy observations, initial predictions and estimates.

    One figure is produced per key ``N`` of ``solutions``. When ``save`` is
    True it is written to ``f"{prefix}_{N}_reference_solution.png"`` and then
    closed.

    Parameters
    ----------
    solutions : dict
        Maps a number of measurements ``N`` to a sequence of estimates. The
        first element of each estimate (``est[0]``) is passed to the model.
    model_generator : callable
        ``model_generator(times, lam_true)`` returns a model ``m``. ``m()``
        evaluates the true response and ``m(lam)`` the response at ``lam``.
    sigma : float
        Noise standard deviation for the synthetic observations and the
        plotted +-3 sigma band.
    prefix : str
        Output path prefix; its directory component is created.
    time_vector, qoi_true : array-like
        Observation times and noiseless observed values; the first ``N`` are
        plotted.
    lam_true : parameter passed to ``model_generator``.
    end_time : float
        Right end of the plotting mesh (1000 points per time unit).
    fsize : int
        Base font size.
    save : bool
        Save (and close) each figure.
    """
    alpha_signal = 0.2
    alpha_points = 0.6
    annotate_height = 0.82
    num_sample_signals = 100
    alpha_signal_sample = 0.15
    alpha_signal_mudpts = 0.45

    fdir = "/".join(prefix.split("/")[:-1])
    check_dir(fdir)
    print("Plotting decay solution.")

    # The mesh and the true response do not depend on N, so compute them once.
    # numpy.linspace needs an integer count, hence int() for a float end_time.
    plotting_mesh = np.linspace(0, end_time, int(1000 * end_time))
    plot_model = model_generator(plotting_mesh, lam_true)
    true_response = plot_model()  # no args evaluates true param

    for num_meas_plot in solutions:
        filename = f"{prefix}_{num_meas_plot}_reference_solution.png"
        # NOTE: this sets the global default figure size for later figures too.
        plt.rcParams["figure.figsize"] = 25, 10
        fig = plt.figure()

        # true signal
        plt.plot(
            plotting_mesh,
            true_response,
            lw=5,
            c="k",
            alpha=1,
            label="True Signal, $\\xi \\sim N(0, \\sigma^2)$",
        )

        # observations: reseeded so every figure shows the same noisy data
        np.random.seed(11)
        u = qoi_true + np.random.randn(len(qoi_true)) * sigma
        plt.scatter(
            time_vector[:num_meas_plot],
            u[:num_meas_plot],
            color="k",
            marker=".",
            s=250,
            alpha=alpha_points,
            label=f"{num_meas_plot} Sample Measurements",
        )
        plt.annotate("$ \\downarrow$ Observations begin",
                     (0.95, annotate_height),
                     fontsize=fsize)

        # sample signals from uniform(0,1) draws of the parameter; only the
        # first one carries a legend label
        for i in range(num_sample_signals):
            sample_response = plot_model(np.random.rand())
            if i == 0:
                plt.plot(
                    plotting_mesh,
                    sample_response,
                    lw=2,
                    c="k",
                    alpha=alpha_signal_sample,
                    label="Predictions from Initial Density",
                )
            else:
                plt.plot(plotting_mesh,
                         sample_response,
                         lw=1,
                         c="k",
                         alpha=alpha_signal_sample)

        # +-3 sigma band: solid after the first 1000 mesh points (about t = 1),
        # faint before. Call order is kept because the legend reorder below
        # indexes handles by position.
        sigma_label = f"$\\pm3\\sigma \\qquad\\qquad \\sigma^2={sigma**2:1.3E}$"
        plt.plot(
            plotting_mesh[1000:],
            true_response[1000:] + 3 * sigma,
            ls="--",
            lw=3,
            c="xkcd:black",
            alpha=1,
        )
        plt.plot(
            plotting_mesh[1000:],
            true_response[1000:] - 3 * sigma,
            ls="--",
            lw=3,
            c="xkcd:black",
            alpha=1,
            label=sigma_label,
        )
        plt.plot(
            plotting_mesh[:1000],
            true_response[:1000] + 3 * sigma,
            ls="--",
            lw=3,
            c="xkcd:black",
            alpha=alpha_signal,
        )
        plt.plot(
            plotting_mesh[:1000],
            true_response[:1000] - 3 * sigma,
            ls="--",
            lw=3,
            c="xkcd:black",
            alpha=alpha_signal,
        )

        # solutions / samples (only the first one is labelled)
        mud_solutions = solutions[num_meas_plot]
        plt.plot(
            plotting_mesh,
            plot_model(mud_solutions[0][0]),
            lw=3,
            c="xkcd:bright red",
            alpha=alpha_signal_mudpts,
            label=f"{len(mud_solutions)} Estimates with $N={num_meas_plot:3d}$",
        )
        for _lam in mud_solutions[1:]:
            plt.plot(
                plotting_mesh,
                plot_model(_lam[0]),
                lw=3,
                c="xkcd:bright red",
                alpha=alpha_signal_mudpts,
            )

        plt.ylim([0, 0.9])
        plt.xlim([0, end_time + 0.05])
        plt.ylabel("Response", fontsize=60)
        plt.xlabel("Time", fontsize=60)
        plt.xticks(fontsize=fsize)
        plt.yticks(fontsize=fsize)
        # legend ordering has a mind of its own, so we format it to our will
        handles, labels = plt.gca().get_legend_handles_labels()
        order = [4, 0, 2, 1, 3]
        plt.legend(
            [handles[idx] for idx in order],
            [labels[idx] for idx in order],
            fontsize=fsize,
            loc="upper right",
        )
        plt.tight_layout()
        if save:
            plt.savefig(filename, bbox_inches="tight")
            # release the figure; otherwise one stays open per entry of `solutions`
            plt.close(fig)
# Example #8
def plot_without_fenics(fname,
                        num_sensors=None,
                        num_qoi=2,
                        mode="sca",
                        fsize=36,
                        example=None):
    """
    Plot the stored response surface of the Poisson example, optionally with sensors.

    Parameters
    ----------
    fname : str
        Path of a pickled reference dictionary. If ``"data"`` occurs in the
        path, it is read as package data via ``pkgutil.get_data``; otherwise
        it is opened from disk. The dictionary must provide the keys
        ``"sensors"`` and ``"plot_u"`` (a ``(coords, vals)`` pair).
    num_sensors : int, optional
        If given, sensors whose index is below this value are scattered,
        grouped into bands by ``band_qoi``.
    num_qoi : int
        Number of QoI bands. It is also the number of triangle markers drawn
        on the left edge.
    mode : str
        Case-insensitive partition mode: ``"sca"``, ``"hor"`` or ``"ver"``.
    fsize : int
        Base font size.
    example : str, optional
        If truthy, the figure is saved to
        ``figures/<fname without .pkl>/<example>_surface.png``.

    Raises
    ------
    ValueError
        If sensors are requested and `mode` is unsupported.
    FileNotFoundError
        If the package data for `fname` cannot be located.
    """
    plt.figure(figsize=(10, 10))
    mode = mode.lower()
    colors = [
        "xkcd:red", "xkcd:black", "xkcd:orange", "xkcd:blue", "xkcd:green"
    ]

    # NOTE(review): pickle.load can execute arbitrary code; only use trusted
    # reference files.
    if "data" in fname:  # TODO turn into function.
        _logger.info(f"Loading {fname} from package")
        raw = pkgutil.get_data(__package__, fname)
        if raw is None:
            # get_data returns None when the loader cannot provide the resource
            raise FileNotFoundError(f"Could not load package data {fname!r}")
        ref = pickle.load(BytesIO(raw))
    else:
        _logger.info("Loading from disk")
        # context manager guarantees the file handle is closed
        with open(fname, "rb") as handle:
            ref = pickle.load(handle)

    sensors = ref["sensors"]
    coords, vals = ref["plot_u"]
    plt.tricontourf(coords[:, 0],
                    coords[:, 1],
                    vals,
                    levels=20,
                    vmin=-0.5,
                    vmax=0)

    plt.title("Response Surface", fontsize=1.25 * fsize)
    if num_sensors is not None:  # plot sensors
        intervals = np.linspace(0, 1, num_qoi + 2)[1:-1]
        # partitions equidistant between sensors (midpoints of `intervals`)
        between_sensors = (
            np.array(intervals[1:]) +
            (np.array(intervals[:-1]) - np.array(intervals[1:])) / 2)

        if mode == "sca":
            qoi_indices = band_qoi(sensors, 1, axis=1)
            _intervals = between_sensors
        elif mode == "hor":
            qoi_indices = band_qoi(sensors, num_qoi, axis=1)
            _intervals = between_sensors
        elif mode == "ver":
            qoi_indices = band_qoi(sensors, num_qoi, axis=0)
            # partitions equidistant on x_1 = (0, 1)
            _intervals = np.linspace(0, 1, num_qoi + 1)[1:]
        else:
            raise ValueError(
                "Unsupported mode type. Select from ('sca', 'ver', 'hor'). ")

        for i, band in enumerate(qoi_indices):
            # keep only the first `num_sensors` sensors of this band
            _q = band[band < num_sensors]
            plt.scatter(sensors[_q, 0],
                        sensors[_q, 1],
                        s=100,
                        color=colors[i % 2])  # alternate red / black bands
            if i < num_qoi - 1:
                if mode == "hor":
                    plt.axhline(_intervals[i], lw=3, c="k")
                elif mode == "ver":
                    plt.axvline(_intervals[i], lw=3, c="k")

        plt.scatter([0] * num_qoi, intervals, s=500, marker="^", c="w")
    plt.xlim(0, 1)
    plt.ylim(0, 1)
    plt.xlabel("$x_1$", fontsize=fsize)
    plt.ylabel("$x_2$", fontsize=fsize)

    if example:
        fdir = "figures/" + fname.replace(".pkl", "")
        check_dir(fdir)
        fname = f"{fdir}/{example}_surface.png"
        plt.savefig(fname, bbox_inches="tight")
        _logger.info(f"Saved {fname}")
Exemple #9
0
def _load_model_list(path, num_samples):
    """
    Unpickle the model list stored at `path` and clamp `num_samples` to it.

    Returns ``(model_list, num_samples)``. `num_samples` is replaced by
    ``len(model_list)`` when it is None or larger than the list. A missing
    `path` propagates as FileNotFoundError.
    """
    # NOTE(review): pickle.load can execute arbitrary code; only load files
    # produced by this project's own data generation.
    with open(path, "rb") as handle:
        model_list = pickle.load(handle)
    if num_samples is None or num_samples > len(model_list):
        num_samples = len(model_list)
    return model_list, num_samples


def make_reproducible_without_fenics(
    example="mud",
    lam_true=-3,
    input_dim=2,
    sample_dist="u",
    sample_tol=0.95,
    num_samples=None,
    num_measure=100,
):
    """
    (Currently) requires XML data to be on disk, simulates sensors
    and saves everything required to one pickle file.

    The model list ``<prefix>_<input_dim><sample_dist>.pkl`` is loaded from
    the working directory. If it is missing, generation is attempted once via
    the ``generate_poisson_data`` command. The saved reference dictionary has
    the keys ``sensors``, ``lam``, ``qoi``, ``truth``, ``data``, ``plot_u``
    and ``plot_g``.

    Parameters
    ----------
    example : str
        If ``input_dim == 1`` and this contains ``"alt"``, alternative sensor
        locations (``xmax=0.25``) are used.
    lam_true : float
        True parameter value; must lie in [-4, 0].
    input_dim : int
        Input dimension, used in file and directory names.
    sample_dist : str
        ``"u"`` (uniform, forces ``sample_tol = 1.0``) or ``"n"`` (normal).
    sample_tol : float
        Sampling tolerance; must lie in (0, 1) for normal sampling.
    num_samples : int, optional
        Number of models to use. Clamped to the number available, or
        defaulting to 50 when data must be generated.
    num_measure : int
        Number of simulated sensors.

    Returns
    -------
    str
        Path of the written pickle file.

    Raises
    ------
    ValueError
        On an unsupported `sample_dist` or out-of-range `sample_tol`/`lam_true`.
    ModuleNotFoundError
        If the model list still cannot be loaded after the generation attempt.
    """
    if sample_dist == "u":
        sample_tol = 1.0
    elif sample_dist == "n":
        # open interval, matching the error message (NaN is rejected as well)
        if not 0 < sample_tol < 1:
            raise ValueError(
                "Sample tolerance must be in (0, 1) when using normal distributions."
            )
    else:
        raise ValueError("Unsupported argument for `sample_dist`.")

    if lam_true < -4 or lam_true > 0:
        raise ValueError("True value must be in [-4, 0].")
    prefix = str(round(np.floor(sample_tol * 1000)))
    _logger.info("Running make_reproducible without fenics")
    # TODO: generalize this path here... take as argument
    model_path = f"{prefix}_{input_dim}{sample_dist}.pkl"
    # Either load or generate the data.
    try:
        model_list, num_samples = _load_model_list(model_path, num_samples)
    except FileNotFoundError as e:
        if num_samples is None:
            num_samples = 50
        _logger.error(f"make_reproducible: {e}")
        _logger.warning("Attempting data generation with system call.")
        # below has to match where we expected our git-controlled file to be... TODO: generalize to data/
        os.system(
            f"generate_poisson_data -v -s {num_samples} -i {input_dim} -d {sample_dist} -t {sample_tol}"
        )
        try:
            model_list, num_samples = _load_model_list(model_path, num_samples)
        except (FileNotFoundError, TypeError) as err:
            # the generation command did not produce a usable file; the hint
            # points at the FEniCS dependency the generator needs
            raise ModuleNotFoundError(
                "Try `conda install -c conda-forge fenics`") from err

    fdir = f"pde_{input_dim}D"
    check_dir(fdir)

    if input_dim == 1 and "alt" in example:
        # alternative measurement locations for more sensitivity / precision
        sensors = generate_sensors_pde(num_measure, ymax=0.95, xmax=0.25)
        fname = f"{fdir}/ref_alt_{prefix}_{input_dim}{sample_dist}.pkl"
    else:
        sensors = generate_sensors_pde(num_measure, ymax=0.95, xmax=0.95)
        fname = f"{fdir}/ref_{prefix}_{input_dim}{sample_dist}.pkl"

    lam, qoi = load_poisson_from_fenics_run(sensors,
                                            model_list[0:num_samples],
                                            nx=36,
                                            ny=36)
    qoi_ref = poisson_sensor_model(sensors, gamma=lam_true, nx=36, ny=36)

    # evaluate the true solution on its mesh vertices for later plotting
    pn = poissonModel(gamma=lam_true)
    c = pn.function_space().mesh().coordinates()
    v = [pn(c[i, 0], c[i, 1]) for i in range(len(c))]

    # boundary condition g(0, y) sampled on [0, 1] for plotting
    g = gamma_boundary_condition(lam_true)
    g_mesh = np.linspace(0, 1, 1000)
    g_plot = [g(0, y) for y in g_mesh]
    ref = {
        "sensors": sensors,
        "lam": lam,
        "qoi": qoi,
        "truth": lam_true,
        "data": qoi_ref,
        "plot_u": (c, v),
        "plot_g": (g_mesh, g_plot),
    }

    with open(fname, "wb") as f:
        pickle.dump(ref, f)
    _logger.info(fname + " saved: " + str(Path(fname).stat().st_size // 1000) +
                 "KB")

    return fname
def plot_experiment_equipment(
    tolerances,
    res,
    prefix,
    fsize=32,
    linewidth=5,
    title="Variance of MUD Error",
    save=True,
):
    """
    Plot convergence of the MUD error as the equipment tolerance varies.

    Two log-log figures are produced, one for the mean and one for the
    variance of the error against `tolerances`. Each experiment contributes a
    regression line, labelled with its slope, and scatter markers.

    Parameters
    ----------
    tolerances : array-like
        Horizontal-axis values shared by all experiments.
    res : sequence of tuple
        One ``(example, inputs, measurement_results, equipment_results, fname)``
        tuple per experiment. ``equipment_results`` unpacks as
        ``(regression_err_mean, slope_err_mean, regression_err_vars,
        slope_err_vars, sd_means, sd_vars, num_sensors)``.
    prefix : str
        Path prefix (may contain ``/``) of the saved figure names.
    fsize, linewidth : int
        Font size and line width.
    title : str
        Title of the variance figure.
    save : bool
        If True, save both figures under ``figures/<fname>/``, using the
        ``fname`` of the last entry of `res`. Otherwise call ``plt.show()``.

    Raises
    ------
    ValueError
        If `res` is empty (nothing to plot and no output directory known).
    """
    if not res:
        raise ValueError("`res` must contain at least one experiment result.")
    print("Plotting experiments involving equipment differences...")

    # --- mean of the error -------------------------------------------------
    plt.figure(figsize=(10, 10))
    for _example, _in, _rm, _re, _fname in res:
        (regression_err_mean, slope_err_mean, _, _, sd_means, _,
         num_sensors) = _re
        plt.plot(
            tolerances,
            regression_err_mean,
            label=f"{_example.upper()} slope: {slope_err_mean:1.4f}",
            lw=linewidth,
        )
        plt.scatter(tolerances, sd_means, marker="x", lw=20)

    plt.yscale("log")
    plt.xscale("log")
    plt.Axes.set_aspect(plt.gca(), 1)
    plt.xlabel("Tolerance", fontsize=fsize)
    plt.legend()
    plt.title(f"Mean of MUD Error for N={num_sensors}", fontsize=1.25 * fsize)
    if save:
        # Create the directory part of `prefix` (e.g. "a/b" -> "a") so that
        # figures/<fname>/<prefix>_*.png can be written.
        fdir = "/".join(prefix.split("/")[:-1])
        check_dir(f"figures/{_fname}/{fdir}")
        _logger.info("Saving equipment experiments: mean convergence.")
        plt.savefig(
            f"figures/{_fname}/{prefix}_convergence_mud_std_mean.png",
            bbox_inches="tight",
        )
    else:
        plt.show()

    # --- variance of the error ---------------------------------------------
    plt.figure(figsize=(10, 10))
    for _example, _in, _rm, _re, _fname in res:
        _, _, regression_err_vars, slope_err_vars, _, sd_vars, _ = _re
        plt.plot(
            tolerances,
            regression_err_vars,
            label=f"{_example.upper()} slope: {slope_err_vars:1.4f}",
            lw=linewidth,
        )
        plt.scatter(tolerances, sd_vars, marker="x", lw=20)
    plt.xscale("log")
    plt.yscale("log")
    plt.Axes.set_aspect(plt.gca(), 1)
    plt.xlabel("Tolerance", fontsize=fsize)
    plt.legend()
    plt.title(title, fontsize=1.25 * fsize)
    if save:
        _logger.info("Saving equipment experiments: variance convergence.")
        plt.savefig(
            f"figures/{_fname}/{prefix}_convergence_mud_std_var.png",
            bbox_inches="tight",
        )
    else:
        plt.show()
def plot_experiment_measurements(
    res,
    prefix,
    fsize=32,
    linewidth=5,
    xlabel="Number of Measurements",
    save=True,
    legend=True,
):
    """
    Plot convergence of the MUD error as the number of measurements grows.

    Two log-log figures are produced, one for the mean and one for the
    variance of the error. Each experiment contributes a regression line,
    labelled with its slope, and scatter markers.

    Parameters
    ----------
    res : sequence of tuple
        One ``(example, inputs, measurement_results, equipment_results, fname)``
        tuple per experiment. ``inputs[-1]`` is a mapping whose keys are used
        as the horizontal-axis values. ``measurement_results`` unpacks as
        ``(regression_mean, slope_mean, regression_vars, slope_vars, means,
        variances)``.
    prefix : str
        Path prefix (may contain ``/``) of the saved figure names.
    fsize, linewidth : int
        Font size and line width.
    xlabel : str
        Horizontal-axis label for both figures.
    save : bool
        If True, save both figures under ``figures/<fname>/``, using the
        ``fname`` of the last entry of `res`. Otherwise call ``plt.show()``.
    legend : bool
        Whether to draw legends.

    Raises
    ------
    ValueError
        If `res` is empty (nothing to plot and no output directory known).
    """
    if not res:
        raise ValueError("`res` must contain at least one experiment result.")
    print("Plotting experiments involving increasing # of measurements.")

    # --- mean of the error -------------------------------------------------
    plt.figure(figsize=(10, 10))
    for _example, _in, _rm, _re, _fname in res:
        measurements = list(_in[-1].keys())
        regression_mean, slope_mean, _, _, means, _ = _rm
        plt.plot(
            measurements[:len(regression_mean)],
            regression_mean,
            label=f"{_example.upper()} slope: {slope_mean:1.4f}",
            lw=linewidth,
        )
        plt.scatter(measurements[:len(means)], means, marker="x", lw=20)
    plt.xscale("log")
    plt.yscale("log")
    plt.Axes.set_aspect(plt.gca(), 1)
    plt.xlabel(xlabel, fontsize=fsize)
    if legend:
        plt.legend(fontsize=fsize * 0.8)
    title = "$\\mathrm{\\mathbb{E}}(|\\lambda^* - \\lambda^\\dagger|)$"  # noqa E501
    plt.title(title, fontsize=1.25 * fsize)
    if save:
        # directory part of `prefix` (e.g. "a/b" -> "a")
        fdir = "/".join(prefix.split("/")[:-1])
        check_dir(f"figures/{_fname}/{fdir}")
        _logger.info("Saving measurement experiments: mean convergence.")
        plt.savefig(f"figures/{_fname}/{prefix}_convergence_obs_mean.png",
                    bbox_inches="tight")
    else:
        plt.show()

    # --- variance of the error ---------------------------------------------
    plt.figure(figsize=(10, 10))
    for _example, _in, _rm, _re, _fname in res:
        # recompute per experiment so each curve uses its own x-values
        measurements = list(_in[-1].keys())
        _, _, regression_vars, slope_vars, _, variances = _rm
        plt.plot(
            measurements[:len(regression_vars)],
            regression_vars,
            label=f"{_example.upper()} slope: {slope_vars:1.4f}",
            lw=linewidth,
        )
        plt.scatter(measurements[:len(variances)],
                    variances,
                    marker="x",
                    lw=20)
    plt.xscale("log")
    plt.yscale("log")
    plt.Axes.set_aspect(plt.gca(), 1)
    plt.xlabel(xlabel, fontsize=fsize)
    if legend:
        plt.legend(fontsize=fsize * 0.8)
    plt.title("$\\mathrm{Var}(|\\lambda^* - \\lambda^\\dagger|)$",
              fontsize=1.25 * fsize)
    if save:
        _logger.info("Saving measurement experiments: variance convergence.")
        plt.savefig(f"figures/{_fname}/{prefix}_convergence_obs_var.png",
                    bbox_inches="tight")
    else:
        plt.show()
def _format_comparison_axes(ax, num_data, xlabel, tick_fsize, legend_fsize):
    """
    Apply the axis styling shared by both comparison figures drawn in `main`.

    Limits the horizontal axis to [-1, 1]. For more than one datum, it also
    annotates the number of data points and fixes the vertical range to
    [0, 28], so figures for different `num_data` can be compared directly.
    """
    ax.set_xlim([-1, 1])
    if num_data > 1:
        ax.annotate(f"$N={num_data}$", (-0.75, 5),
                    fontsize=legend_fsize * 1.5)
        ax.set_ylim([0, 28])  # fix axis height for comparisons
    ax.tick_params(axis="x", labelsize=tick_fsize)
    ax.tick_params(axis="y", labelsize=tick_fsize)
    ax.set_xlabel(xlabel, fontsize=1.25 * tick_fsize)
    ax.legend(fontsize=legend_fsize, loc="upper left")


def main(args):
    """
    Main entrypoint for example-generation.

    Compares the statistical Bayesian posterior with the data-consistent
    update for the QoI map ``lam -> lam**5``, starting from a uniform
    initial/prior density on [-1, 1].

    For each number of observations N in (1, 5, 10, 20), two figures are
    written to ``figures/comparison``:

    * ``bip-vs-sip-{N}.png`` shows the initial, updated and posterior
      densities on the parameter space.
    * ``bip-vs-sip-pf-{N}.png`` shows the observed density and the
      push-forwards of the initial, updated and posterior densities.

    Parameters
    ----------
    args
        Command-line arguments, forwarded to ``parse_args``. The ``loglevel``
        and ``seed`` fields are used.
    """
    args = parse_args(args)
    setup_logging(args.loglevel)
    np.random.seed(args.seed)
    fdir = "figures/comparison"
    check_dir(fdir)
    presentation = False
    save = True

    if not presentation:
        plt.rcParams["mathtext.fontset"] = "stix"
        plt.rcParams["font.family"] = "STIXGeneral"

    # number of samples from initial and observed mean (mu) and st. dev (sigma)
    N, mu, sigma = int(1e3), 0.25, 0.1
    lam = np.random.uniform(low=-1, high=1, size=N)

    # Evaluate the QoI map on this initial sample set to form a predicted data
    qvals_predict = QoI(lam, 5)  # Evaluate lam^5 samples

    # Estimate the push-forward density for the QoI
    pi_predict = kde(qvals_predict)

    # Compute more observations for use in BIP
    tick_fsize = 28
    legend_fsize = 24
    # evaluation grids for the parameter space and the data space
    lam_plot = np.linspace(-1, 1, num=1000)
    qplot = np.linspace(-1, 1, num=1000)
    for num_data in [1, 5, 10, 20]:
        np.random.seed(
            123456
        )  # Just for reproducibility, you can comment out if you want.
        # scipy's `scale` is the standard deviation: draw the data from the
        # same N(mu, sigma^2) used for the observed density and likelihood.
        # NOTE: earlier versions passed scale=sigma**2 (std 0.01), so figures
        # produced before this change differ.
        data = norm.rvs(loc=mu, scale=sigma, size=num_data)

        # We will estimate the observed distribution using a parametric estimate to keep
        # the assumptions involved as similar as possible between the BIP and the SIP
        # So, we will assume the sigma is known but that the mean mu is unknown and estimated
        # from data to fit a Gaussian distribution
        mu_est = np.mean(data)

        # ratio of observed to predicted density at each predicted datum
        r_approx = np.divide(norm.pdf(qvals_predict, loc=mu_est, scale=sigma),
                             pi_predict(qvals_predict))

        # Use r to compute weighted KDE approximating the updated density
        update_kde = kde(lam, weights=r_approx)

        # Construct estimated push-forward of this updated density
        pf_update_kde = kde(qvals_predict, weights=r_approx)

        likelihood_vals = np.zeros(N)
        for i in range(N):
            likelihood_vals[i] = data_likelihood(qvals_predict[i], data,
                                                 num_data, sigma)

        # compute normalizing constants
        C_nonlinear = np.mean(likelihood_vals)
        data_like_normalized = likelihood_vals / C_nonlinear

        posterior_kde = kde(lam, weights=data_like_normalized)

        # Construct push-forward of statistical Bayesian posterior
        pf_posterior_kde = kde(qvals_predict, weights=data_like_normalized)

        # Plot the initial, updated, and posterior densities
        fig, ax = plt.subplots(figsize=(10, 10))
        ax.plot(
            lam_plot,
            uniform.pdf(lam_plot, loc=-1, scale=2),
            "b--",
            linewidth=4,
            label="Initial/Prior",
        )
        ax.plot(lam_plot,
                update_kde(lam_plot),
                "k-.",
                linewidth=4,
                label="Update")
        ax.plot(lam_plot,
                posterior_kde(lam_plot),
                "g:",
                linewidth=4,
                label="Posterior")
        _format_comparison_axes(ax, num_data, "$\\Lambda$", tick_fsize,
                                legend_fsize)
        if save:
            fig.savefig(f"{fdir}/bip-vs-sip-{num_data}.png",
                        bbox_inches="tight")
            plt.close(fig)

        # Plot the push-forward of the initial, observed density,
        # and push-forward of pullback and stats posterior
        fig, ax = plt.subplots(figsize=(10, 10))
        ax.plot(
            qplot,
            norm.pdf(qplot, loc=mu, scale=sigma),
            "r-",
            linewidth=6,
            label="$N(0.25, 0.1^2)$",
        )
        ax.plot(qplot,
                pi_predict(qplot),
                "b-.",
                linewidth=4,
                label="PF of Initial")
        ax.plot(qplot,
                pf_update_kde(qplot),
                "k--",
                linewidth=4,
                label="PF of Update")
        ax.plot(qplot,
                pf_posterior_kde(qplot),
                "g:",
                linewidth=4,
                label="PF of Posterior")
        _format_comparison_axes(ax, num_data, "$\\mathcal{D}$", tick_fsize,
                                legend_fsize)
        if save:
            fig.savefig(f"{fdir}/bip-vs-sip-pf-{num_data}.png",
                        bbox_inches="tight")
            plt.close(fig)