def main():
    Y_obs, amp, tec, freqs = generate_data()
    TEC_CONV = -8.4479745e6  # mTECU/Hz
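    # Dispersive ionospheric phase model: phase(nu) = TEC_CONV * tec / nu,
    # with tec in mTECU and nu in Hz.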

    def log_normal(x, mean, scale):
        dx = (x - mean) / scale
        return -0.5 * x.size * jnp.log(2. * jnp.pi) - x.size*jnp.log(scale) \
               - 0.5 * dx @ dx

    def log_laplace(x, mean, scale):
        dx = jnp.abs(x - mean) / scale
        return -x.size * jnp.log(2. * scale) - jnp.sum(dx)

    def log_likelihood(tec, const, uncert, **kwargs):
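        # Compare model and data in (cos, sin) space so the likelihood is
        # invariant to 2*pi phase wraps.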
        phase = tec * (TEC_CONV / freqs) + const
        Y = jnp.concatenate([amp * jnp.cos(phase), amp * jnp.sin(phase)],
                            axis=-1)
        log_prob = log_laplace(Y, Y_obs, uncert[0])
        return log_prob

    prior_chain = PriorChain() \
        .push(UniformPrior('tec', -100., 100.)) \
        .push(UniformPrior('const', -jnp.pi, jnp.pi)) \
        .push(HalfLaplacePrior('uncert', 0.25))

    print("Probabilistic model:\n{}".format(prior_chain))

    ns = NestedSampler(
        log_likelihood,
        prior_chain,
        sampler_name='slice',
        tec_mean=lambda tec, **kw: tec,  # I would like to marginalise this function over the posterior
        const_mean=lambda const, **kw: const  # I would like to marginalise this function over the posterior
    )

    run = jit(lambda key: ns(key=key,
                             num_live_points=1000,
                             max_samples=1e5,
                             collect_samples=True,
                             termination_frac=0.01,
                             stoachastic_uncertainty=False,
                             sampler_kwargs=dict(depth=4, num_slices=1)))
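    # The first call triggers XLA compilation; the second call with a fresh key
    # reuses the compiled executable, so the two timings isolate compile cost.
    # Note: 'stoachastic_uncertainty' is left as spelled, matching the keyword
    # in the jaxns version these examples target.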

    t0 = default_timer()
    results = run(random.PRNGKey(2364))
    print(results.efficiency)
    print("Time compile", default_timer() - t0)

    t0 = default_timer()
    results = run(random.PRNGKey(1324))
    print(results.efficiency)
    print("Time no compile", default_timer() - t0)

    ###
    print(results.marginalised['tec_mean'])
    print(results.marginalised['const_mean'])
    plot_diagnostics(results)
    plot_cornerplot(results)
Example #2
def diagnostics(self):
    """
    Plot diagnostics of the result. This is a wrapper of :func:`jaxns.plotting.plot_diagnostics`
    and :func:`jaxns.plotting.plot_cornerplot`.
    """
    if self._results is None:
        raise RuntimeError(
            "NestedSampler.run(...) method should be called first to obtain results."
        )
    plot_diagnostics(self._results)
    plot_cornerplot(self._results)
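
# A minimal usage sketch (hypothetical wrapper name; assumes run(...) stores
# the results object on self._results, as the guard above implies):
#
#     sampler = MyNestedSamplerWrapper(log_likelihood, prior_chain)
#     sampler.run(random.PRNGKey(42))
#     sampler.diagnostics()  # renders both the diagnostic and corner plots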
Example #3
def main():

    ndims = 4
    sigma = 0.1

    def log_likelihood(theta, **kwargs):
        r2 = jnp.sum(theta**2)
        logL = -0.5 * jnp.log(2. * jnp.pi * sigma**2) * ndims
        logL += -0.5 * r2 / sigma**2
        return logL
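
    # With sigma = 0.1 the Gaussian mass lies well inside the [-1, 1]^4 box, so
    # the evidence is roughly the prior density: log Z ~= -ndims * log(2) ~= -2.77,
    # a handy analytic sanity check on results.logZ below.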

    prior_transform = PriorChain().push(
        UniformPrior('theta', -jnp.ones(ndims), jnp.ones(ndims)))
    ns = NestedSampler(log_likelihood, prior_transform, sampler_name='slice')

    def run_with_n(n):
        @jit
        def run(key):
            return ns(key=key,
                      num_live_points=n,
                      max_samples=1e5,
                      collect_samples=True,
                      termination_frac=0.01,
                      stoachastic_uncertainty=False,
                      sampler_kwargs=dict(depth=3))

        t0 = default_timer()
        results = run(random.PRNGKey(0))
        print(results.efficiency)
        print("Time to run including compile:", default_timer() - t0)
        print("Time efficiency normalised:",
              results.efficiency * (default_timer() - t0))
        t0 = default_timer()
        results = run(random.PRNGKey(1))
        print(results.efficiency)
        print("Time to run no compile:", default_timer() - t0)
        print("Time efficiency normalised:",
              results.efficiency * (default_timer() - t0))
        return results

    for n in [1000]:
        results = run_with_n(n)
        plt.scatter(n, results.logZ)
        plt.errorbar(n, results.logZ, yerr=results.logZerr)

    plt.show()

    # plot_samples_development(results, save_name='./example.mp4')
    plot_diagnostics(results)
    plot_cornerplot(results)
Example #4
def main():
    def log_likelihood(theta, **kwargs):
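        # Eggbox-style target: the cosine product creates a lattice of
        # well-separated modes across the prior box.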
        return 5. * (2. + jnp.prod(jnp.cos(0.5 * theta)))

    prior_chain = PriorChain() \
        .push(UniformPrior('theta', low=jnp.zeros(2), high=jnp.pi * 10. * jnp.ones(2)))
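
    # Visual sanity check: scatter 10000 prior draws coloured by likelihood to
    # expose the multimodal structure before running the sampler.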

    theta = vmap(
        lambda key: prior_chain(random.uniform(key, (prior_chain.U_ndims, ))))(
            random.split(random.PRNGKey(0), 10000))
    lik = vmap(lambda theta: log_likelihood(**theta))(theta)
    sc = plt.scatter(theta['theta'][:, 0], theta['theta'][:, 1], c=lik)
    plt.colorbar(sc)
    plt.show()

    ns = NestedSampler(log_likelihood, prior_chain, sampler_name='slice')

    def run_with_n(n):
        @jit
        def run(key):
            return ns(key=key,
                      num_live_points=n,
                      max_samples=1e5,
                      collect_samples=True,
                      termination_frac=0.01,
                      stoachastic_uncertainty=False,
                      sampler_kwargs=dict(depth=7))

        t0 = default_timer()
        # with disable_jit():
        results = run(random.PRNGKey(0))
        print("Efficiency", results.efficiency)
        print("Time to run (including compile)", default_timer() - t0)
        t0 = default_timer()
        results = run(random.PRNGKey(1))
        print(results.efficiency)
        print("Time to run (no compile)", default_timer() - t0)
        return results

    for n in [500]:
        results = run_with_n(n)
        plt.scatter(n, results.logZ)
        plt.errorbar(n, results.logZ, yerr=results.logZerr)
    plt.ylabel('log Z')
    plt.show()

    plot_diagnostics(results)
    plot_cornerplot(results)
    return results.logZ, results.logZerr
Example #5
def test_unit_cube_mixture_prior():
    import jax.numpy as jnp
    from jax import random
    from jaxns.nested_sampling import NestedSampler
    from jaxns.plotting import plot_cornerplot, plot_diagnostics

    # prior_chain = PriorChain().push(MultiCubeMixturePrior('x', 2, 1, -5., 15.))
    prior_chain = PriorChain().push(GMMMarginalPrior('x', 2, -5., 15.))

    def loglikelihood(x, **kwargs):
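        # Equal-weight mixture of two unit-variance Gaussians in the scalar
        # s = jnp.sum(x), producing two well-separated modes.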
        return jnp.log(
            0.5 * jnp.exp(-0.5 * jnp.sum(x)**2) / jnp.sqrt(2. * jnp.pi) +
            0.5 * jnp.exp(-0.5 * jnp.sum(x - 10.)**2) / jnp.sqrt(2. * jnp.pi))

    ns = NestedSampler(loglikelihood, prior_chain, sampler_name='ellipsoid')
    results = ns(random.PRNGKey(0),
                 100,
                 max_samples=1e5,
                 collect_samples=True,
                 termination_frac=0.05,
                 stoachastic_uncertainty=True)
    plot_diagnostics(results)
    plot_cornerplot(results)
Example #6
def main(kernel):
    def log_normal(x, mean, cov):
        L = jnp.linalg.cholesky(cov)
        dx = x - mean

        dx = solve_triangular(L, dx, lower=True)
        # maha = dx @ jnp.linalg.solve(cov, dx)
        maha = dx @ dx
        # logdet = jnp.log(jnp.linalg.det(cov))
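        # sum(log(diag(L))) equals 0.5 * log|cov|, so a single -logdet below
        # supplies the -0.5 * log|cov| term of the Gaussian log-density.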
        logdet = jnp.sum(jnp.log(jnp.diag(L)))
        log_prob = -0.5 * x.size * jnp.log(2. * jnp.pi) - logdet - 0.5 * maha
        return log_prob

    true_height, true_width, true_sigma, true_l, true_uncert = 200., 100., 1., 10., 2.5
    nant = 5
    ndir = 5
    X, Y, Y_obs = rbf_dtec(nant, ndir, true_height, true_width, true_sigma, true_l, true_uncert)
    a = X[:, 0:3]
    k = X[:, 3:6]
    x0 = a[0, :]

    def log_likelihood(dtec, uncert, **kwargs):
        """
        P(Y|sigma, half_width) = N[Y, mu, K]
        Args:
            sigma:
            l:

        Returns:

        """
        data_cov = jnp.square(uncert) * jnp.eye(X.shape[0])
        return log_normal(Y_obs, dtec, data_cov)

    def predict_f(dtec, uncert, **kwargs):
        return dtec

    def predict_fvar(dtec, uncert, **kwargs):
        return dtec ** 2

    def tec_to_dtec(tec):
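        # Differential TEC: subtract the first antenna so it acts as the
        # zero-phase reference, which is what the data actually constrain.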
        tec = tec.reshape((nant, ndir))
        dtec = jnp.reshape(tec - tec[0, :], (-1,))
        return dtec

    prior_chain = build_prior(X, kernel, tec_to_dtec, x0)
    print(prior_chain)
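
    # Sanity check: evaluate the likelihood at 1000 random points in the unit
    # cube and report any NaNs arising from the prior transform or kernel.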

    U_test = jnp.array([random.uniform(key, shape=(prior_chain.U_ndims,))
                        for key in random.split(random.PRNGKey(4325), 1000)])
    log_lik = jnp.array([log_likelihood(**prior_chain(U)) for U in U_test])
    print(jnp.sum(jnp.isnan(log_lik)))
    print(U_test[jnp.isnan(log_lik)])
    ns = NestedSampler(log_likelihood, prior_chain, sampler_name='slice', predict_f=predict_f,
                       predict_fvar=predict_fvar)

    def run_with_n(n):
        @jit
        def run(key):
            return ns(key=key,
                      num_live_points=n,
                      max_samples=1e5,
                      collect_samples=True,
                      termination_frac=0.01,
                      stoachastic_uncertainty=False,
                      sampler_kwargs=dict(depth=7, num_slices=1))

        t0 = default_timer()
        results = run(random.PRNGKey(0))
        print("Efficiency", results.efficiency)
        print("Time to run (including compile)", default_timer() - t0)
        t0 = default_timer()
        results = run(random.PRNGKey(1))
        print("Efficiency",results.efficiency)
        print("Time to run (no compile)", default_timer() - t0)
        print("Time efficiency normalised", (default_timer() - t0)*results.efficiency)
        return results

    for n in [1000]:
        results = run_with_n(n)
        plt.scatter(n, results.logZ)
        plt.errorbar(n, results.logZ, yerr=results.logZerr)

    # #
    # K = GaussianProcessKernelPrior('K',
    #                                TomographicKernel(x0, kernel, S=20), X,
    #                                MVNPrior('height', results.param_mean['height'], results.param_covariance['height']),#UniformPrior('height', 100., 300.),
    #                                MVNPrior('width', results.param_mean['width'], results.param_covariance['width']),#UniformPrior('width', 50., 150.),
    #                                MVNPrior('l', results.param_mean['l'], results.param_covariance['l']),#UniformPrior('l', 7., 20.),
    #                                MVNPrior('sigma', results.param_mean['sigma'], results.param_covariance['sigma']),#UniformPrior('sigma', 0.3, 2.),
    #                                tracked=False)
    # tec = MVNPrior('tec', jnp.zeros((X.shape[0],)), K, ill_cond=True, tracked=False)
    # dtec = DeterministicTransformPrior('dtec', tec_to_dtec, tec.to_shape, tec, tracked=False)
    # prior_chain = PriorChain() \
    #     .push(dtec) \
    #     .push(UniformPrior('uncert', 0., 5.))
    #
    # ns = NestedSampler(log_likelihood, prior_chain, sampler_name='multi_ellipsoid', predict_f=predict_f,
    #                    predict_fvar=predict_fvar)
    #
    # def run_with_n(n):
    #     @jit
    #     def run(key):
    #         return ns(key=key,
    #                   num_live_points=n,
    #                   max_samples=1e5,
    #                   collect_samples=True,
    #                   termination_frac=0.01,
    #                   stoachastic_uncertainty=False,
    #                   sampler_kwargs=dict(depth=4))
    #
    #     t0 = default_timer()
    #     results = run(random.PRNGKey(0))
    #     print("Efficiency", results.efficiency)
    #     print("Time to run (including compile)", default_timer() - t0)
    #     t0 = default_timer()
    #     results = run(random.PRNGKey(1))
    #     print(results.efficiency)
    #     print("Time to run (no compile)", default_timer() - t0)
    #     return results
    #
    # for n in [100]:
    #     results = run_with_n(n)
    #     plt.scatter(n, results.logZ)
    #     plt.errorbar(n, results.logZ, yerr=results.logZerr)


    plt.title("Kernel: {}".format(kernel.__class__.__name__))
    plt.ylabel('log Z')
    plt.show()

    fstd = jnp.sqrt(results.marginalised['predict_fvar'] -
                    results.marginalised['predict_f'] ** 2)
    plt.scatter(jnp.arange(Y.size), Y_obs, marker='+', label='data')
    plt.scatter(jnp.arange(Y.size), Y, marker="o", label='underlying')
    plt.scatter(jnp.arange(Y.size), results.marginalised['predict_f'],
                marker=".", label='posterior mean')
    plt.errorbar(jnp.arange(Y.size), results.marginalised['predict_f'],
                 yerr=fstd, label='marginalised')
    plt.title("Kernel: {}".format(kernel.__class__.__name__))
    plt.legend()
    plt.show()

    # plot_samples_development(results,save_name='./ray_integral_solution.mp4')
    plot_diagnostics(results)
    plot_cornerplot(results)
    return results.logZ, results.logZerr
Example #7
def main(kernel):
    print(("Working on Kernel: {}".format(kernel.__class__.__name__)))

    def log_normal(x, mean, cov):
        L = jnp.linalg.cholesky(cov)
        # U, S, Vh = jnp.linalg.svd(cov)
        log_det = jnp.sum(jnp.log(jnp.diag(L)))  # jnp.sum(jnp.log(S))#
        dx = x - mean
        dx = solve_triangular(L, dx, lower=True)
        # U S Vh V 1/S Uh
        # pinv = (Vh.T.conj() * jnp.where(S!=0., jnp.reciprocal(S), 0.)) @ U.T.conj()
        maha = dx @ dx  # dx @ pinv @ dx#solve_triangular(L, dx, lower=True)
        log_likelihood = -0.5 * x.size * jnp.log(2. * jnp.pi) \
                         - log_det \
                         - 0.5 * maha
        # print(log_likelihood)
        return log_likelihood

    N = 100
    X = jnp.linspace(-2., 2., N)[:, None]
    true_sigma, true_l, true_uncert = 1., 0.2, 0.2
    data_mu = jnp.zeros((N, ))
    prior_cov = RBF()(X, X, true_l, true_sigma) + 1e-13 * jnp.eye(N)
    # print(jnp.linalg.cholesky(prior_cov), jnp.linalg.eigvals(prior_cov))
    # return
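    # Draw one sample from the GP prior via the Cholesky factor:
    # y = L z + mu, with L = chol(prior_cov) and z ~ N(0, I).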
    Y = jnp.linalg.cholesky(prior_cov) @ random.normal(random.PRNGKey(0),
                                                       shape=(N, )) + data_mu
    Y_obs = Y + true_uncert * random.normal(random.PRNGKey(1), shape=(N, ))
    Y_obs = jnp.where((jnp.arange(N) > 50) & (jnp.arange(N) < 60),
                      random.normal(random.PRNGKey(1), shape=(N, )), Y_obs)
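    # Samples 51-59 are overwritten with pure noise, simulating a corrupted
    # stretch of data to test the robustness of the fit.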

    # plt.scatter(X[:, 0], Y_obs, label='data')
    # plt.plot(X[:, 0], Y, label='underlying')
    # plt.legend()
    # plt.show()

    def log_likelihood(K, uncert, **kwargs):
        """
        P(Y|sigma, half_width) = N[Y, mu, K]
        Args:
            sigma:
            l:

        Returns:

        """
        data_cov = jnp.square(uncert) * jnp.eye(X.shape[0])
        mu = jnp.zeros_like(Y_obs)
        return log_normal(Y_obs, mu, K + data_cov)

    def predict_f(K, uncert, **kwargs):
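        # Standard GP posterior mean, K (K + uncert^2 I)^{-1} y; predict_fvar
        # below gives the matching variance diag(K - K (K + uncert^2 I)^{-1} K).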
        data_cov = jnp.square(uncert) * jnp.eye(X.shape[0])
        mu = jnp.zeros_like(Y_obs)
        return mu + K @ jnp.linalg.solve(K + data_cov, Y_obs)

    def predict_fvar(K, uncert, **kwargs):
        data_cov = jnp.square(uncert) * jnp.eye(X.shape[0])
        mu = jnp.zeros_like(Y_obs)
        return jnp.diag(K - K @ jnp.linalg.solve(K + data_cov, K))

    l = UniformPrior('l', 0., 2.)
    uncert = UniformPrior('uncert', 0., 2.)
    sigma = UniformPrior('sigma', 0., 2.)
    cov = GaussianProcessKernelPrior('K', kernel, X, l, sigma)
    prior_chain = PriorChain().push(uncert).push(cov)
    # print(prior_chain)

    ns = NestedSampler(log_likelihood,
                       prior_chain,
                       sampler_name='multi_ellipsoid',
                       predict_f=predict_f,
                       predict_fvar=predict_fvar)

    def run_with_n(n):
        @jit
        def run(key):
            return ns(key=key,
                      num_live_points=n,
                      max_samples=1e5,
                      collect_samples=True,
                      termination_frac=0.01,
                      stoachastic_uncertainty=False,
                      sampler_kwargs=dict(depth=3))

        t0 = default_timer()
        # with disable_jit():
        results = run(random.PRNGKey(6))
        print(results.efficiency)
        print(
            "Time to execute (including compile): {}".format(default_timer() -
                                                             t0))
        t0 = default_timer()
        results = run(random.PRNGKey(6))
        print(results.efficiency)
        print("Time to execute (not including compile): {}".format(
            (default_timer() - t0)))
        return results

    for n in [100]:
        results = run_with_n(n)
        plt.scatter(n, results.logZ)
        plt.errorbar(n, results.logZ, yerr=results.logZerr)
    plt.title("Kernel: {}".format(kernel.__class__.__name__))
    plt.ylabel('log Z')
    plt.show()

    plt.scatter(X[:, 0], Y_obs, label='data')
    plt.plot(X[:, 0], Y, label='underlying')
    plt.plot(X[:, 0], results.marginalised['predict_f'], label='marginalised')
    plt.plot(X[:, 0],
             results.marginalised['predict_f'] +
             jnp.sqrt(results.marginalised['predict_fvar']),
             ls='dotted',
             c='black')
    plt.plot(X[:, 0],
             results.marginalised['predict_f'] -
             jnp.sqrt(results.marginalised['predict_fvar']),
             ls='dotted',
             c='black')
    plt.title("Kernel: {}".format(kernel.__class__.__name__))
    plt.legend()
    plt.show()
    plot_diagnostics(results)
    plot_cornerplot(results)
    return results.logZ, results.logZerr
Example #8
def main(kernel):
    def log_normal(x, mean, cov):
        L = jnp.linalg.cholesky(cov)
        dx = x - mean

        dx = solve_triangular(L, dx, lower=True)
        # maha = dx @ jnp.linalg.solve(cov, dx)
        maha = dx @ dx
        # logdet = jnp.log(jnp.linalg.det(cov))
        logdet = jnp.sum(jnp.log(jnp.diag(L)))
        log_prob = -0.5 * x.size * jnp.log(2. * jnp.pi) - logdet - 0.5 * maha
        return log_prob

    true_height, true_width, true_sigma, true_l, true_uncert, true_v = 200., 100., 1., 10., 2.5, jnp.array(
        [0.3, 0., 0.])
    nant = 2
    ndir = 1
    ntime = 20
    X, Y, Y_obs = rbf_dtec(nant, ndir, ntime, true_height, true_width,
                           true_sigma, true_l, true_uncert, true_v)
    a = X[:, 0:3]
    k = X[:, 3:6]
    t = X[:, 6:7]
    x0 = a[0, :]

    def log_likelihood(dtec, uncert, **kwargs):
        """
        P(Y|sigma, half_width) = N[Y, mu, K]
        Args:
            sigma:
            l:

        Returns:

        """
        data_cov = jnp.square(uncert) * jnp.eye(X.shape[0])
        return log_normal(Y_obs, dtec, data_cov)

    def predict_f(dtec, **kwargs):
        return dtec

    def predict_fvar(dtec, **kwargs):
        return dtec**2

    def tec_to_dtec(tec):
        tec = tec.reshape((nant, ndir, ntime))
        dtec = jnp.reshape(tec - tec[0, :, :], (-1, ))
        return dtec

    prior_chain = build_frozen_flow_prior(X, kernel, tec_to_dtec, x0)

    ns = NestedSampler(log_likelihood,
                       prior_chain,
                       sampler_name='slice',
                       predict_f=predict_f,
                       predict_fvar=predict_fvar)

    def run_with_n(n):
        @jit
        def run(key):
            return ns(key=key,
                      num_live_points=n,
                      max_samples=1e5,
                      collect_samples=True,
                      termination_frac=0.01,
                      stoachastic_uncertainty=False,
                      sampler_kwargs=dict(depth=5, num_slices=1))

        t0 = default_timer()
        results = run(random.PRNGKey(0))
        print("Efficiency", results.efficiency)
        print("Time to run (including compile)", default_timer() - t0)
        # t0 = default_timer()
        # results = run(random.PRNGKey(1))
        # print(results.efficiency)
        # print("Time to run (no compile)", default_timer() - t0)
        return results

    for n in [100]:
        results = run_with_n(n)
        plt.scatter(n, results.logZ)
        plt.errorbar(n, results.logZ, yerr=results.logZerr)
    plt.title("Kernel: {}".format(kernel.__class__.__name__))
    plt.ylabel('log Z')
    plt.show()

    fstd = jnp.sqrt(results.marginalised['predict_fvar'] -
                    results.marginalised['predict_f']**2)
    plt.scatter(jnp.arange(Y.size), Y_obs, marker='+', label='data')
    plt.scatter(jnp.arange(Y.size), Y, marker="o", label='underlying')
    plt.scatter(jnp.arange(Y.size),
                results.marginalised['predict_f'],
                marker=".",
                label='posterior mean')
    plt.errorbar(jnp.arange(Y.size),
                 results.marginalised['predict_f'],
                 yerr=fstd,
                 label='marginalised')
    plt.title("Kernel: {}".format(kernel.__class__.__name__))
    plt.legend()
    plt.show()

    plot_diagnostics(results)
    plot_cornerplot(results)
    return results.logZ, results.logZerr
Example #9
def main():
    def log_normal(x, mean, cov):
        L = jnp.linalg.cholesky(cov)
        dx = x - mean
        dx = solve_triangular(L, dx, lower=True)
        return -0.5 * x.size * jnp.log(2. * jnp.pi) - jnp.sum(jnp.log(jnp.diag(L))) \
               - 0.5 * dx @ dx

    N = 100
    X = jnp.linspace(-2., 2., N)[:, None]
    true_alpha, true_sigma, true_l, true_uncert = 1., 1., 0.2, 0.25
    data_mu = jnp.zeros((N, ))
    prior_cov = RationalQuadratic()(X, X, true_l, true_sigma, true_alpha)
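    # Sample the latent function from the GP prior: y = L z + mu with
    # L = chol(prior_cov) and z ~ N(0, I).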
    Y = jnp.linalg.cholesky(prior_cov) @ random.normal(random.PRNGKey(0),
                                                       shape=(N, )) + data_mu
    Y_obs = Y + true_uncert * random.normal(random.PRNGKey(1), shape=(N, ))

    # Y_obs = jnp.where((jnp.arange(N) > 50) & (jnp.arange(N) < 60),
    #                   random.normal(random.PRNGKey(1), shape=(N, )),
    #                   Y_obs)

    # plt.scatter(X[:, 0], Y_obs, label='data')
    # plt.plot(X[:, 0], Y, label='underlying')
    # plt.legend()
    # plt.show()

    def log_likelihood(K, uncert, **kwargs):
        """
        P(Y|sigma, half_width) = N[Y, mu, K]
        Args:
            sigma:
            l:

        Returns:

        """
        data_cov = jnp.square(uncert) * jnp.eye(X.shape[0])
        mu = jnp.zeros_like(Y_obs)
        log_prob = log_normal(Y_obs, mu, K + data_cov)
        # print(log_prob)
        return log_prob

    def predict_f(K, uncert, **kwargs):
        data_cov = jnp.square(uncert) * jnp.eye(X.shape[0])
        mu = jnp.zeros_like(Y_obs)
        return mu + K @ jnp.linalg.solve(K + data_cov, Y_obs)

    def predict_fvar(K, uncert, **kwargs):
        data_cov = jnp.square(uncert) * jnp.eye(X.shape[0])
        mu = jnp.zeros_like(Y_obs)
        return jnp.diag(K - K @ jnp.linalg.solve(K + data_cov, K))

    prior_chain = PriorChain() \
        .push(GaussianProcessKernelPrior('K', RationalQuadratic(), X,
                                         UniformPrior('l', 0., 4.),
                                         UniformPrior('sigma', 0., 4.),
                                         UniformPrior('alpha', 0., 4.))) \
        .push(UniformPrior('uncert', 0., 2.))

    ns = NestedSampler(log_likelihood,
                       prior_chain,
                       sampler_name='multi_ellipsoid',
                       predict_f=predict_f,
                       predict_fvar=predict_fvar)

    def run_with_n(n):
        @jit
        def run():
            return ns(key=random.PRNGKey(0),
                      num_live_points=n,
                      max_samples=1e4,
                      collect_samples=True,
                      termination_frac=0.01,
                      stoachastic_uncertainty=False,
                      sampler_kwargs=dict(depth=4))

        results = run()
        return results

    for n in [200]:
        results = run_with_n(n)
        plt.scatter(n, results.logZ)
        plt.errorbar(n, results.logZ, yerr=results.logZerr)
    plt.title("Kernel: {}".format(RationalQuadratic.__name__))
    plt.ylabel('log Z')
    plt.show()

    plt.scatter(X[:, 0], Y_obs, label='data')
    plt.plot(X[:, 0], Y, label='underlying')
    plt.plot(X[:, 0], results.marginalised['predict_f'], label='marginalised')
    plt.plot(X[:, 0],
             results.marginalised['predict_f'] +
             jnp.sqrt(results.marginalised['predict_fvar']),
             ls='dotted',
             c='black')
    plt.plot(X[:, 0],
             results.marginalised['predict_f'] -
             jnp.sqrt(results.marginalised['predict_fvar']),
             ls='dotted',
             c='black')
    plt.title("Kernel: {}".format(RationalQuadratic.__name__))
    plt.legend()
    plt.show()

    plot_diagnostics(results)
    plot_cornerplot(results)
    return results.logZ, results.logZerr