Code Example #1
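Note: the snippets on this page omit their import preambles. A minimal preamble covering the calls they make, inferred from the import lines visible in Code Examples #4 and #6 (an early jaxns API; treat the exact module paths and prior names as assumptions):

from timeit import default_timer

import jax.numpy as jnp
import matplotlib.pyplot as plt
from jax import random, jit, vmap
from jax.scipy.linalg import solve_triangular

from jaxns.nested_sampling import NestedSampler
from jaxns.prior_transforms import (PriorChain, UniformPrior, LaplacePrior,
                                    HalfLaplacePrior, MVNPrior, MVNDiagPrior)
from jaxns.plotting import plot_diagnostics, plot_cornerplot

Example-specific helpers (generate_data, RBF, RationalQuadratic, GaussianProcessKernelPrior, and so on) are defined in the respective project files and are not reproduced here.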
def main():
    Y_obs, amp, tec, freqs = generate_data()
    TEC_CONV = -8.4479745e6  # rad * Hz / mTECU

    def log_normal(x, mean, scale):
        dx = (x - mean) / scale
        return -0.5 * x.size * jnp.log(2. * jnp.pi) - x.size*jnp.log(scale) \
               - 0.5 * dx @ dx

    def log_laplace(x, mean, scale):
        dx = jnp.abs(x - mean) / scale
        return -x.size * jnp.log(2. * scale) - jnp.sum(dx)

    def log_likelihood(tec, const, uncert, **kwargs):
        phase = tec * (TEC_CONV / freqs) + const
        Y = jnp.concatenate([amp * jnp.cos(phase), amp * jnp.sin(phase)],
                            axis=-1)
        log_prob = log_laplace(Y, Y_obs, uncert[0])
        return log_prob

    prior_chain = PriorChain() \
        .push(UniformPrior('tec', -100., 100.)) \
        .push(UniformPrior('const', -jnp.pi, jnp.pi)) \
        .push(HalfLaplacePrior('uncert', 0.25))

    print("Probabilistic model:\n{}".format(prior_chain))

    ns = NestedSampler(
        log_likelihood,
        prior_chain,
        sampler_name='slice',
        # marginalise these functions over the posterior; the results appear
        # under results.marginalised['tec_mean'] and ['const_mean'] below
        tec_mean=lambda tec, **kw: tec,
        const_mean=lambda const, **kw: const,
    )

    run = jit(lambda key: ns(key=key,
                             num_live_points=1000,
                             max_samples=1e5,
                             collect_samples=True,
                             termination_frac=0.01,
                             stoachastic_uncertainty=False,
                             sampler_kwargs=dict(depth=4, num_slices=1)))

    t0 = default_timer()
    results = run(random.PRNGKey(2364))
    print(results.efficiency)
    print("Time compile", default_timer() - t0)

    t0 = default_timer()
    results = run(random.PRNGKey(1324))
    print(results.efficiency)
    print("Time no compile", default_timer() - t0)

    ###
    print(results.marginalised['tec_mean'])
    print(results.marginalised['const_mean'])
    plot_diagnostics(results)
    plot_cornerplot(results)
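For reference, the signal model inside log_likelihood above is a dispersive phase plus a constant offset per frequency channel,

\phi(\nu) = \texttt{TEC\_CONV} \cdot \frac{\mathrm{tec}}{\nu} + \mathrm{const},

and the data are compared in the complex plane as Y = [amp cos φ, amp sin φ] against Y_obs under a Laplace error model with scale uncert.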
Code Example #2
File: toy_gaussian.py, Project: fehiepsi/jaxns
def main():

    ndims = 4
    sigma = 0.1

    def log_likelihood(theta, **kwargs):
        r2 = jnp.sum(theta**2)
        logL = -0.5 * jnp.log(2. * jnp.pi * sigma**2) * ndims
        logL += -0.5 * r2 / sigma**2
        return logL

    prior_transform = PriorChain().push(
        UniformPrior('theta', -jnp.ones(ndims), jnp.ones(ndims)))
    ns = NestedSampler(log_likelihood, prior_transform, sampler_name='slice')

    def run_with_n(n):
        @jit
        def run(key):
            return ns(key=key,
                      num_live_points=n,
                      max_samples=1e5,
                      collect_samples=True,
                      termination_frac=0.01,
                      stoachastic_uncertainty=False,
                      sampler_kwargs=dict(depth=3))

        t0 = default_timer()
        results = run(random.PRNGKey(0))
        print(results.efficiency)
        print("Time to run including compile:", default_timer() - t0)
        print("Time efficiency normalised:",
              results.efficiency * (default_timer() - t0))
        t0 = default_timer()
        results = run(random.PRNGKey(1))
        print(results.efficiency)
        print("Time to run no compile:", default_timer() - t0)
        print("Time efficiency normalised:",
              results.efficiency * (default_timer() - t0))
        return results

    for n in [1000]:
        results = run_with_n(n)
        plt.scatter(n, results.logZ)
        plt.errorbar(n, results.logZ, yerr=results.logZerr)

    plt.show()

    # plot_samples_development(results, save_name='./example.mp4')
    plot_diagnostics(results)
    plot_cornerplot(results)
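As a sanity check (not part of the snippet): this toy evidence is essentially analytic. The uniform prior on [-1, 1]^d has density 2^{-d}, and with σ = 0.1 the Gaussian likelihood places almost all of its mass inside the prior box, so

\log Z = \log \int_{[-1,1]^d} 2^{-d}\,\mathcal{N}(\theta;\,0,\,\sigma^2 I)\,d\theta \approx -d \log 2 \approx -2.77 \quad (d = 4),

which results.logZ should reproduce to within roughly results.logZerr.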
Code Example #3
File: egg_box_problem.py, Project: fehiepsi/jaxns
def main():
    def log_likelihood(theta, **kwargs):
        return 5. * (2. + jnp.prod(jnp.cos(0.5 * theta)))

    prior_chain = PriorChain() \
        .push(UniformPrior('theta', low=jnp.zeros(2), high=jnp.pi * 10. * jnp.ones(2)))

    theta = vmap(
        lambda key: prior_chain(random.uniform(key, (prior_chain.U_ndims, ))))(
            random.split(random.PRNGKey(0), 10000))
    lik = vmap(lambda theta: log_likelihood(**theta))(theta)
    sc = plt.scatter(theta['theta'][:, 0], theta['theta'][:, 1], c=lik)
    plt.colorbar(sc)
    plt.show()

    ns = NestedSampler(log_likelihood, prior_chain, sampler_name='slice')

    def run_with_n(n):
        @jit
        def run(key):
            return ns(key=key,
                      num_live_points=n,
                      max_samples=1e5,
                      collect_samples=True,
                      termination_frac=0.01,
                      stoachastic_uncertainty=False,
                      sampler_kwargs=dict(depth=7))

        t0 = default_timer()
        # with disable_jit():
        results = run(random.PRNGKey(0))
        print("Efficiency", results.efficiency)
        print("Time to run (including compile)", default_timer() - t0)
        t0 = default_timer()
        results = run(random.PRNGKey(1))
        print(results.efficiency)
        print("Time to run (no compile)", default_timer() - t0)
        return results

    for n in [500]:
        results = run_with_n(n)
        plt.scatter(n, results.logZ)
        plt.errorbar(n, results.logZ, yerr=results.logZerr)
    plt.ylabel('log Z')
    plt.show()

    plot_diagnostics(results)
    plot_cornerplot(results)
    return results.logZ, results.logZerr
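For reference, the log-likelihood implemented above is

\log L(\theta) = 5\left(2 + \cos\tfrac{\theta_1}{2}\,\cos\tfrac{\theta_2}{2}\right),

a linear variant of the classic egg-box test function, whose more common form in the literature raises the bracket to the fifth power instead of scaling it by 5.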
Code Example #4
    def run_jaxns(num_live_points):
        try:
            from jaxns.nested_sampling import NestedSampler
            from jaxns.prior_transforms import PriorChain, UniformPrior
        except ImportError:
            raise ImportError("Install JaxNS!")
        from timeit import default_timer
        from jax import random, jit
        import jax.numpy as jnp

        def log_likelihood(theta, **kwargs):
            r2 = jnp.sum(theta ** 2)
            logL = -0.5 * jnp.log(2. * jnp.pi * sigma ** 2) * ndims
            logL += -0.5 * r2 / sigma ** 2
            return logL

        prior_transform = PriorChain().push(UniformPrior('theta', -jnp.ones(ndims), jnp.ones(ndims)))
        ns = NestedSampler(log_likelihood, prior_transform, sampler_name='slice')

        def run_with_n(n):
            @jit
            def run(key):
                return ns(key=key,
                          num_live_points=n,
                          max_samples=1e6,
                          collect_samples=False,
                          termination_frac=0.01,
                          stoachastic_uncertainty=False,
                          sampler_kwargs=dict(depth=3, num_slices=2))

            results = run(random.PRNGKey(0))
            results.logZ.block_until_ready()
            t0 = default_timer()
            results = run(random.PRNGKey(1))
            print("Efficiency and logZ", results.efficiency, results.logZ)
            run_time = (default_timer() - t0)
            return run_time

        return run_with_n(num_live_points)
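run_jaxns closes over ndims and sigma from its enclosing scope (the indentation suggests it is defined inside a benchmarking harness). A hypothetical driver, assuming the same toy Gaussian as Code Example #2 and that the function is lifted to module scope:

ndims = 4    # free variables assumed by run_jaxns
sigma = 0.1
run_time = run_jaxns(num_live_points=1000)
print("wall-clock run time [s], excluding compilation:", run_time)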
Code Example #5
    def E_update(self, prior_mu, prior_Gamma, Y, Sigma, *control_params):
        # amp = control_params[0]
        key = control_params[1]

        prior_chain = PriorChain() \
            .push(MVNPrior('param', prior_mu, prior_Gamma))

        # .push(HalfLaplacePrior('uncert', jnp.sqrt(jnp.mean(jnp.diag(Sigma)))))

        def log_normal(x, mean, cov):
            # Whiten with the diagonal of the covariance only (a cheap
            # approximation; a full Cholesky version is sketched after this example).
            dx = x - mean
            scale = jnp.sqrt(jnp.diag(cov))
            dx = dx / scale
            return -0.5 * x.size * jnp.log(2. * jnp.pi) - jnp.sum(jnp.log(scale)) \
                   - 0.5 * dx @ dx

        def log_likelihood(param, **kwargs):
            Y_model = self.forward_model(param, *control_params)
            # Sigma = uncert**2 * jnp.eye(Y.shape[-1])
            return log_normal(Y_model, Y, Sigma)

        ns = NestedSampler(log_likelihood,
                           prior_chain,
                           sampler_name='whitened_ellipsoid')
        results = ns(key,
                     self._phase_basis_size * 15,
                     max_samples=1e5,
                     collect_samples=False,
                     termination_frac=0.01,
                     stoachastic_uncertainty=True)

        post_mu = results.param_mean['param']
        post_Gamma = results.param_covariance['param']

        return post_mu, post_Gamma
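The log_normal above whitens the residual with only the diagonal of the covariance, which is cheap but ignores correlations. A full-covariance variant, matching log_mvnormal in Code Example #8 below, would be (a sketch, assuming cov is positive definite):

def log_mvnormal(x, mean, cov):
    # whiten with the Cholesky factor; log|cov| = 2 * sum(log(diag(L)))
    L = jnp.linalg.cholesky(cov)
    dx = solve_triangular(L, x - mean, lower=True)
    return -0.5 * x.size * jnp.log(2. * jnp.pi) \
           - jnp.sum(jnp.log(jnp.diag(L))) \
           - 0.5 * dx @ dx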
Code Example #6
File: tests.py, Project: fehiepsi/jaxns
def test_unit_cube_mixture_prior():
    import jax.numpy as jnp
    from jax import random
    from jaxns.nested_sampling import NestedSampler
    from jaxns.plotting import plot_cornerplot, plot_diagnostics

    # prior_chain = PriorChain().push(MultiCubeMixturePrior('x', 2, 1, -5., 15.))
    prior_chain = PriorChain().push(GMMMarginalPrior('x', 2, -5., 15.))

    def loglikelihood(x, **kwargs):
        # equal-weight mixture of two unit-variance Gaussians centred at 0 and 10
        return jnp.log(
            0.5 * jnp.exp(-0.5 * jnp.sum(x ** 2)) / jnp.sqrt(2. * jnp.pi) +
            0.5 * jnp.exp(-0.5 * jnp.sum((x - 10.) ** 2)) / jnp.sqrt(2. * jnp.pi))

    ns = NestedSampler(loglikelihood, prior_chain, sampler_name='ellipsoid')
    results = ns(random.PRNGKey(0),
                 100,
                 max_samples=1e5,
                 collect_samples=True,
                 termination_frac=0.05,
                 stoachastic_uncertainty=True)
    plot_diagnostics(results)
    plot_cornerplot(results)
Code Example #7
        def run(key):
            prior_transform = PriorChain().push(
                MVNDiagPrior('x', prior_mu, jnp.sqrt(jnp.diag(prior_cov))))

            # prior_transform = LaplacePrior(prior_mu, jnp.sqrt(jnp.diag(prior_cov)))
            # prior_transform = UniformPrior(-20.*jnp.ones(ndims), 20.*jnp.ones(ndims))
            def param_mean(x, **args):
                return x

            def param_covariance(x, **args):
                return jnp.outer(x, x)

            ns = NestedSampler(log_likelihood,
                               prior_transform,
                               sampler_name='slice',
                               x_mean=param_mean,
                               x_cov=param_covariance)
            return ns(key=key,
                      num_live_points=n,
                      max_samples=1e5,
                      collect_samples=True,
                      termination_frac=0.01,
                      stoachastic_uncertainty=False,
                      sampler_kwargs=dict(depth=3, num_slices=2))
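Note that x_cov above accumulates the second moment E[x xᵀ], not the covariance, so the posterior covariance must be assembled afterwards. A sketch, assuming the moments come back under the keyword names given to NestedSampler (as with results.marginalised in Code Example #1):

post_mean = results.marginalised['x_mean']    # E[x]
post_ExxT = results.marginalised['x_cov']     # E[x x^T]
post_cov = post_ExxT - jnp.outer(post_mean, post_mean)  # Cov[x] = E[x x^T] - E[x] E[x]^T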
Code Example #8
def main():
    Sigma, T, Y_obs, amp, tec, freqs = generate_data()
    TEC_CONV = -8.4479745e6  # rad * Hz / mTECU

    def log_mvnormal(x, mean, cov):
        L = jnp.linalg.cholesky(cov)
        dx = x - mean
        dx = solve_triangular(L, dx, lower=True)
        return -0.5 * x.size * jnp.log(2. * jnp.pi) - jnp.sum(jnp.log(jnp.diag(L))) \
               - 0.5 * dx @ dx

    def log_normal(x, mean, uncert):
        dx = (x - mean)/uncert
        return -0.5 * x.size * jnp.log(2. * jnp.pi) - x.size * jnp.log(uncert) \
               - 0.5 * dx @ dx

    def log_likelihood(tec, uncert, **kwargs):
        # tec = x[0]  # [:, 0]
        # uncert = x[1]  # [:, 1]
        # clock = x[2] * 1e-9
        # uncert = 0.25#x[2]
        phase = tec * (TEC_CONV / freqs)  # + clock *(jnp.pi*2)*freqs#+ clock
        Y = jnp.concatenate([jnp.cos(phase), jnp.sin(phase)], axis=-1)
        return jnp.sum(vmap(lambda Y, Y_obs: log_normal(Y, Y_obs, uncert))(Y, Y_obs))

    # prior_transform = MVNDiagPrior(prior_mu, jnp.sqrt(jnp.diag(prior_cov)))
    # prior_transform = LaplacePrior(prior_mu, jnp.sqrt(jnp.diag(prior_cov)))
    prior_chain = PriorChain() \
        .push(DiagGaussianWalkPrior('tec', T, LaplacePrior('tec0', 0., 100.), UniformPrior('omega', 1, 15))) \
        .push(UniformPrior('uncert', 0.01, 0.5))

    ns = NestedSampler(log_likelihood, prior_chain, sampler_name='slice',
                       tec_mean=lambda tec,**kwargs: tec)

    @jit
    def run(key):
        return ns(key=key,
                  num_live_points=500,
                  max_samples=1e5,
                  collect_samples=True,
                  termination_frac=0.01,
                  stoachastic_uncertainty=False,
                  sampler_kwargs=dict(depth=7))

    # with disable_jit():
    t0 = default_timer()
    results = run(random.PRNGKey(0))
    print("Time with compile efficiency normalised", results.efficiency * (default_timer() - t0))
    print("Time with compile", default_timer() - t0)
    t0 = default_timer()
    results = run(random.PRNGKey(1))
    print("Time no compile efficiency normalised", results.efficiency * (default_timer() - t0))
    print("Time no compile", default_timer() - t0)


    plt.plot(tec)
    plt.plot(results.marginalised['tec_mean'])
    plt.show()
    plt.plot(results.marginalised['tec_mean'][:,0]-tec)
    plt.show()
    ###

    plot_diagnostics(results)
Code Example #9
File: main.py, Project: fehiepsi/jaxns
def main():

    nant, ndir = 5, 20
    uncert = 0.1

    theta_true, gamma_true, y, y_obs = fake_vis(nant, ndir, uncert)

    def log_likelihood(delta, **kwargs):
        dy = delta - y_obs
        r2 = jnp.sum(jnp.real(dy)**2) / uncert**2
        r2 = r2 + jnp.sum(jnp.imag(dy)**2) / uncert**2
        logL = -0.5 * r2 - jnp.log(2. * jnp.pi * uncert**2) * dy.size
        return logL

    prior_transform = build_prior(nant, ndir)

    ### MAP with  BFGS
    def constrain(U):
        return 0.05 + sigmoid(U) * 0.9

    def loss(U):
        U = constrain(U)
        return -log_likelihood(**prior_transform(U))

    print(loss(jnp.zeros(prior_transform.U_ndims)))

    @jit
    def do_minimisation():
        results = minimize(loss,
                           jnp.zeros(prior_transform.U_ndims),
                           method='BFGS',
                           options=dict(gtol=1e-10, line_search_maxiter=200))
        print(results.message)
        return prior_transform(constrain(results.x)), constrain(
            results.x), results.status

    results = do_minimisation()
    print('Status', results[2])
    print(results)
    plt.scatter(jnp.arange(nant * ndir), results[0]['theta'], label='inferred')
    plt.scatter(jnp.arange(nant * ndir), theta_true, label='true')
    plt.legend()
    plt.show()

    plt.scatter(jnp.arange(ndir), results[0]['gamma'], label='inferred')
    plt.scatter(jnp.arange(ndir), gamma_true, label='true')
    plt.legend()
    plt.show()
    return  # NOTE: early return; the nested-sampling code below never runs

    ns = NestedSampler(log_likelihood,
                       prior_transform,
                       sampler_name='multi_ellipsoid')

    def run_with_n(n):
        @jit
        def run(key):
            return ns(key=key,
                      num_live_points=n,
                      max_samples=1e5,
                      collect_samples=False,
                      termination_frac=0.01,
                      stoachastic_uncertainty=False,
                      sampler_kwargs=dict(depth=3))

        t0 = default_timer()
        results = run(random.PRNGKey(0))
        print(results.efficiency)
        print("Time to run including compile:", default_timer() - t0)
        print("Time efficiency normalised:",
              results.efficiency * (default_timer() - t0))
        return results

    for n in [100]:
        results = run_with_n(n)
        plt.scatter(n, results.logZ)
        plt.errorbar(n, results.logZ, yerr=results.logZerr)

    plt.show()

    # plot_diagnostics(results)
    plt.errorbar(jnp.arange(nant * ndir),
                 results.param_mean['theta'],
                 yerr=jnp.sqrt(jnp.diag(results.param_covariance['theta'])),
                 label='inferred')
    plt.scatter(jnp.arange(nant * ndir), theta_true, label='true')
    plt.legend()
    plt.show()

    plt.errorbar(jnp.arange(ndir),
                 results.param_mean['gamma'],
                 yerr=jnp.sqrt(jnp.diag(results.param_covariance['gamma'])),
                 label='inferred')
    plt.scatter(jnp.arange(ndir), gamma_true, label='true')
    plt.legend()
    plt.show()
Code Example #10
File: main.py, Project: fehiepsi/jaxns
def main(kernel):
    def log_normal(x, mean, cov):
        L = jnp.linalg.cholesky(cov)
        dx = x - mean

        dx = solve_triangular(L, dx, lower=True)
        # maha = dx @ jnp.linalg.solve(cov, dx)
        maha = dx @ dx
        # logdet = jnp.log(jnp.linalg.det(cov))
        logdet = jnp.sum(jnp.log(jnp.diag(L)))
        log_prob = -0.5 * x.size * jnp.log(2. * jnp.pi) - logdet - 0.5 * maha
        return log_prob

    true_height, true_width, true_sigma, true_l, true_uncert = 200., 100., 1., 10., 2.5
    nant = 5
    ndir = 5
    X, Y, Y_obs = rbf_dtec(nant, ndir, true_height, true_width, true_sigma, true_l, true_uncert)
    a = X[:, 0:3]
    k = X[:, 3:6]
    x0 = a[0, :]

    def log_likelihood(dtec, uncert, **kwargs):
        """
        P(Y|sigma, half_width) = N[Y, mu, K]
        Args:
            sigma:
            l:

        Returns:

        """
        data_cov = jnp.square(uncert) * jnp.eye(X.shape[0])
        return log_normal(Y_obs, dtec, data_cov)

    def predict_f(dtec, uncert, **kwargs):
        return dtec

    def predict_fvar(dtec, uncert, **kwargs):
        return dtec ** 2

    def tec_to_dtec(tec):
        tec = tec.reshape((nant, ndir))
        dtec = jnp.reshape(tec - tec[0, :], (-1,))
        return dtec

    prior_chain = build_prior(X, kernel, tec_to_dtec, x0)
    print(prior_chain)

    U_test = jnp.array([random.uniform(key, shape=(prior_chain.U_ndims,))
                        for key in random.split(random.PRNGKey(4325), 1000)])
    log_lik = jnp.array([log_likelihood(**prior_chain(U)) for U in U_test])
    print(jnp.sum(jnp.isnan(log_lik)))
    print(U_test[jnp.isnan(log_lik)])
    ns = NestedSampler(log_likelihood, prior_chain, sampler_name='slice', predict_f=predict_f,
                       predict_fvar=predict_fvar)

    def run_with_n(n):
        @jit
        def run(key):
            return ns(key=key,
                      num_live_points=n,
                      max_samples=1e5,
                      collect_samples=True,
                      termination_frac=0.01,
                      stoachastic_uncertainty=False,
                      sampler_kwargs=dict(depth=7, num_slices=1))

        t0 = default_timer()
        results = run(random.PRNGKey(0))
        print("Efficiency", results.efficiency)
        print("Time to run (including compile)", default_timer() - t0)
        t0 = default_timer()
        results = run(random.PRNGKey(1))
        print("Efficiency",results.efficiency)
        print("Time to run (no compile)", default_timer() - t0)
        print("Time efficiency normalised", (default_timer() - t0)*results.efficiency)
        return results

    for n in [1000]:
        results = run_with_n(n)
        plt.scatter(n, results.logZ)
        plt.errorbar(n, results.logZ, yerr=results.logZerr)

    # #
    # K = GaussianProcessKernelPrior('K',
    #                                TomographicKernel(x0, kernel, S=20), X,
    #                                MVNPrior('height', results.param_mean['height'], results.param_covariance['height']),#UniformPrior('height', 100., 300.),
    #                                MVNPrior('width', results.param_mean['width'], results.param_covariance['width']),#UniformPrior('width', 50., 150.),
    #                                MVNPrior('l', results.param_mean['l'], results.param_covariance['l']),#UniformPrior('l', 7., 20.),
    #                                MVNPrior('sigma', results.param_mean['sigma'], results.param_covariance['sigma']),#UniformPrior('sigma', 0.3, 2.),
    #                                tracked=False)
    # tec = MVNPrior('tec', jnp.zeros((X.shape[0],)), K, ill_cond=True, tracked=False)
    # dtec = DeterministicTransformPrior('dtec', tec_to_dtec, tec.to_shape, tec, tracked=False)
    # prior_chain = PriorChain() \
    #     .push(dtec) \
    #     .push(UniformPrior('uncert', 0., 5.))
    #
    # ns = NestedSampler(log_likelihood, prior_chain, sampler_name='multi_ellipsoid', predict_f=predict_f,
    #                    predict_fvar=predict_fvar)
    #
    # def run_with_n(n):
    #     @jit
    #     def run(key):
    #         return ns(key=key,
    #                   num_live_points=n,
    #                   max_samples=1e5,
    #                   collect_samples=True,
    #                   termination_frac=0.01,
    #                   stoachastic_uncertainty=False,
    #                   sampler_kwargs=dict(depth=4))
    #
    #     t0 = default_timer()
    #     results = run(random.PRNGKey(0))
    #     print("Efficiency", results.efficiency)
    #     print("Time to run (including compile)", default_timer() - t0)
    #     t0 = default_timer()
    #     results = run(random.PRNGKey(1))
    #     print(results.efficiency)
    #     print("Time to run (no compile)", default_timer() - t0)
    #     return results
    #
    # for n in [100]:
    #     results = run_with_n(n)
    #     plt.scatter(n, results.logZ)
    #     plt.errorbar(n, results.logZ, yerr=results.logZerr)


    plt.title("Kernel: {}".format(kernel.__class__.__name__))
    plt.ylabel('log Z')
    plt.show()

    fstd = jnp.sqrt(results.marginalised['predict_fvar'] - results.marginalised['predict_f'] ** 2)
    plt.scatter(jnp.arange(Y.size),Y_obs, marker='+', label='data')
    plt.scatter(jnp.arange(Y.size),Y, marker="o", label='underlying')
    plt.scatter(jnp.arange(Y.size), results.marginalised['predict_f'], marker=".", label='posterior mean')
    plt.errorbar(jnp.arange(Y.size), results.marginalised['predict_f'], yerr=fstd, label='marginalised')
    plt.title("Kernel: {}".format(kernel.__class__.__name__))
    plt.legend()
    plt.show()

    # plot_samples_development(results,save_name='./ray_integral_solution.mp4')
    plot_diagnostics(results)
    plot_cornerplot(results)
    return results.logZ, results.logZerr
Code Example #11
File: gp_marginalisation.py, Project: fehiepsi/jaxns
def main(kernel):
    print(("Working on Kernel: {}".format(kernel.__class__.__name__)))

    def log_normal(x, mean, cov):
        L = jnp.linalg.cholesky(cov)
        # U, S, Vh = jnp.linalg.svd(cov)
        log_det = jnp.sum(jnp.log(jnp.diag(L)))  # jnp.sum(jnp.log(S))#
        dx = x - mean
        dx = solve_triangular(L, dx, lower=True)
        # U S Vh V 1/S Uh
        # pinv = (Vh.T.conj() * jnp.where(S!=0., jnp.reciprocal(S), 0.)) @ U.T.conj()
        maha = dx @ dx  # dx @ pinv @ dx#solve_triangular(L, dx, lower=True)
        log_likelihood = -0.5 * x.size * jnp.log(2. * jnp.pi) \
                         - log_det \
                         - 0.5 * maha
        # print(log_likelihood)
        return log_likelihood

    N = 100
    X = jnp.linspace(-2., 2., N)[:, None]
    true_sigma, true_l, true_uncert = 1., 0.2, 0.2
    data_mu = jnp.zeros((N, ))
    prior_cov = RBF()(X, X, true_l, true_sigma) + 1e-13 * jnp.eye(N)
    # print(jnp.linalg.cholesky(prior_cov), jnp.linalg.eigvals(prior_cov))
    # return
    Y = jnp.linalg.cholesky(prior_cov) @ random.normal(random.PRNGKey(0),
                                                       shape=(N, )) + data_mu
    Y_obs = Y + true_uncert * random.normal(random.PRNGKey(1), shape=(N, ))
    Y_obs = jnp.where((jnp.arange(N) > 50) & (jnp.arange(N) < 60),
                      random.normal(random.PRNGKey(1), shape=(N, )), Y_obs)

    # plt.scatter(X[:, 0], Y_obs, label='data')
    # plt.plot(X[:, 0], Y, label='underlying')
    # plt.legend()
    # plt.show()

    def log_likelihood(K, uncert, **kwargs):
        """
        P(Y|sigma, half_width) = N[Y, mu, K]
        Args:
            sigma:
            l:

        Returns:

        """
        data_cov = jnp.square(uncert) * jnp.eye(X.shape[0])
        mu = jnp.zeros_like(Y_obs)
        return log_normal(Y_obs, mu, K + data_cov)

    def predict_f(K, uncert, **kwargs):
        data_cov = jnp.square(uncert) * jnp.eye(X.shape[0])
        mu = jnp.zeros_like(Y_obs)
        return mu + K @ jnp.linalg.solve(K + data_cov, Y_obs)

    def predict_fvar(K, uncert, **kwargs):
        data_cov = jnp.square(uncert) * jnp.eye(X.shape[0])
        mu = jnp.zeros_like(Y_obs)
        return jnp.diag(K - K @ jnp.linalg.solve(K + data_cov, K))

    l = UniformPrior('l', 0., 2.)
    uncert = UniformPrior('uncert', 0., 2.)
    sigma = UniformPrior('sigma', 0., 2.)
    cov = GaussianProcessKernelPrior('K', kernel, X, l, sigma)
    prior_chain = PriorChain().push(uncert).push(cov)
    # print(prior_chain)

    ns = NestedSampler(log_likelihood,
                       prior_chain,
                       sampler_name='multi_ellipsoid',
                       predict_f=predict_f,
                       predict_fvar=predict_fvar)

    def run_with_n(n):
        @jit
        def run(key):
            return ns(key=key,
                      num_live_points=n,
                      max_samples=1e5,
                      collect_samples=True,
                      termination_frac=0.01,
                      stoachastic_uncertainty=False,
                      sampler_kwargs=dict(depth=3))

        t0 = default_timer()
        # with disable_jit():
        results = run(random.PRNGKey(6))
        print(results.efficiency)
        print(
            "Time to execute (including compile): {}".format(default_timer() -
                                                             t0))
        t0 = default_timer()
        results = run(random.PRNGKey(6))
        print(results.efficiency)
        print("Time to execute (not including compile): {}".format(
            (default_timer() - t0)))
        return results

    for n in [100]:
        results = run_with_n(n)
        plt.scatter(n, results.logZ)
        plt.errorbar(n, results.logZ, yerr=results.logZerr)
    plt.title("Kernel: {}".format(kernel.__class__.__name__))
    plt.ylabel('log Z')
    plt.show()

    plt.scatter(X[:, 0], Y_obs, label='data')
    plt.plot(X[:, 0], Y, label='underlying')
    plt.plot(X[:, 0], results.marginalised['predict_f'], label='marginalised')
    plt.plot(X[:, 0],
             results.marginalised['predict_f'] +
             jnp.sqrt(results.marginalised['predict_fvar']),
             ls='dotted',
             c='black')
    plt.plot(X[:, 0],
             results.marginalised['predict_f'] -
             jnp.sqrt(results.marginalised['predict_fvar']),
             ls='dotted',
             c='black')
    plt.title("Kernel: {}".format(kernel.__class__.__name__))
    plt.legend()
    plt.show()
    plot_diagnostics(results)
    plot_cornerplot(results)
    return results.logZ, results.logZerr
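predict_f and predict_fvar above are the standard zero-mean GP regression posterior, here marginalised over the kernel hyperparameters by the sampler:

\bar{f}_* = K\,(K + \sigma_n^2 I)^{-1} y_{\mathrm{obs}}, \qquad
\mathbb{V}[f_*] = \operatorname{diag}\!\left(K - K\,(K + \sigma_n^2 I)^{-1} K\right),

where σ_n is the sampled 'uncert' parameter; these expressions correspond line-for-line to the jnp.linalg.solve calls in the code.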
Code Example #12
def unconstrained_solve(freqs, key, phase_obs, phase_outliers):
    key1, key2, key3, key4 = random.split(key, 4)
    Nt, Nf = phase_obs.shape
    assert Nt == 2, "Observations should come in consecutive pairs"

    tec0_array = jnp.linspace(-300., 300., 30)
    dtec_array = jnp.linspace(-30., 30., 30)
    const_array = jnp.linspace(-jnp.pi, jnp.pi, 10)
    uncert0_array = jnp.linspace(0., 1., 10)
    uncert1_array = jnp.linspace(0., 1., 10)

    def log_likelihood(tec0, dtec, const, uncert0, uncert1, **kwargs):
        tec = jnp.asarray([tec0, tec0 + dtec])
        t = freqs - jnp.min(freqs)
        t /= t[-1]
        uncert = uncert0 + (uncert1 - uncert0) * t
        phase = tec[:, None] * (TEC_CONV / freqs) + const  # 2,Nf
        logL = jnp.sum(
            jnp.where(
                phase_outliers, 0.,
                log_normal(wrap(wrap(phase) - wrap(phase_obs)), 0., uncert)))
        return logL

    # X = make_coord_array(tec0_array[:, None], dtec_array[:, None], const_array[:, None], uncert0_array[:, None], uncert1_array[:, None], flat=True)
    # log_prob_array = vmap(lambda x: log_likelihood(x[0], x[1], x[2], x[3], x[4]))(X)
    # log_prob_array = log_prob_array.reshape((tec0_array.size, dtec_array.size, const_array.size, uncert0_array.size, uncert1_array.size))
    #
    # lookup_func = build_lookup_index(tec0_array, dtec_array, const_array, uncert0_array, uncert1_array)
    #
    # def efficient_log_likelihood(tec0, dtec, const, uncert0, uncert1, **kwargs):
    #     b = 0.5
    #     log_prob_uncert0 = - uncert0 / b - jnp.log(b)
    #     log_prob_uncert1 = - uncert1 / b - jnp.log(b)
    #     return lookup_func(log_prob_array, tec0, dtec, const, uncert0, uncert1) + log_prob_uncert0 + log_prob_uncert1

    tec0 = UniformPrior('tec0', tec0_array.min(), tec0_array.max())
    # 30mTECU/30seconds is the maximum change
    dtec = UniformPrior('dtec', dtec_array.min(), dtec_array.max())
    const = UniformPrior('const', const_array.min(), const_array.max())
    uncert0 = UniformPrior('uncert0', uncert0_array.min(), uncert0_array.max())
    uncert1 = UniformPrior('uncert1', uncert1_array.min(), uncert1_array.max())
    prior_chain = PriorChain(tec0, dtec, const, uncert0, uncert1)

    ns = NestedSampler(log_likelihood,
                       prior_chain,
                       sampler_name='slice',
                       num_live_points=20 * prior_chain.U_ndims,
                       sampler_kwargs=dict(num_slices=prior_chain.U_ndims * 4))

    results = ns(key=key1, termination_evidence_frac=0.3)

    ESS = 900  # empirically estimated for this problem

    def marginalisation(tec0, dtec, const, uncert0, uncert1, **kwargs):
        tec = jnp.asarray([tec0, tec0 + dtec])
        return tec, tec**2, jnp.cos(const), jnp.sin(const), 0.5 * (uncert0 +
                                                                   uncert1)

    tec_mean, tec2_mean, const_real, const_imag, uncert_mean = marginalise_static(
        key2, results.samples, results.log_p, ESS, marginalisation)

    tec_std = jnp.sqrt(tec2_mean - tec_mean**2)
    const_mean = jnp.arctan2(const_imag, const_real)

    def marginalisation(const, **kwargs):
        return wrap(wrap(const) - wrap(const_mean))**2

    const_var = marginalise_static(key2, results.samples, results.log_p, ESS,
                                   marginalisation)
    const_std = jnp.sqrt(const_var)

    return tec_mean, tec_std, const_mean * jnp.ones(Nt), const_std * jnp.ones(
        Nt), uncert_mean * jnp.ones(Nt)
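wrap, log_normal, and TEC_CONV are used above but not defined in this snippet. TEC_CONV presumably matches the constant in Code Examples #1 and #8; plausible definitions for the other two, consistent with their elementwise use inside jnp.where (treat as assumptions):

def wrap(phi):
    # map angles into (-pi, pi]
    return jnp.arctan2(jnp.sin(phi), jnp.cos(phi))

def log_normal(x, mean, scale):
    # elementwise normal log-density; the caller sums over channels
    dx = (x - mean) / scale
    return -0.5 * jnp.log(2. * jnp.pi) - jnp.log(scale) - 0.5 * dx ** 2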
Code Example #13
def main(kernel):
    def log_normal(x, mean, cov):
        L = jnp.linalg.cholesky(cov)
        dx = x - mean

        dx = solve_triangular(L, dx, lower=True)
        # maha = dx @ jnp.linalg.solve(cov, dx)
        maha = dx @ dx
        # logdet = jnp.log(jnp.linalg.det(cov))
        logdet = jnp.sum(jnp.log(jnp.diag(L)))
        log_prob = -0.5 * x.size * jnp.log(2. * jnp.pi) - logdet - 0.5 * maha
        return log_prob

    true_height, true_width, true_sigma, true_l, true_uncert, true_v = 200., 100., 1., 10., 2.5, jnp.array(
        [0.3, 0., 0.])
    nant = 2
    ndir = 1
    ntime = 20
    X, Y, Y_obs = rbf_dtec(nant, ndir, ntime, true_height, true_width,
                           true_sigma, true_l, true_uncert, true_v)
    a = X[:, 0:3]
    k = X[:, 3:6]
    t = X[:, 6:7]
    x0 = a[0, :]

    def log_likelihood(dtec, uncert, **kwargs):
        """
        P(Y|sigma, half_width) = N[Y, mu, K]
        Args:
            sigma:
            l:

        Returns:

        """
        data_cov = jnp.square(uncert) * jnp.eye(X.shape[0])
        return log_normal(Y_obs, dtec, data_cov)

    def predict_f(dtec, **kwargs):
        return dtec

    def predict_fvar(dtec, **kwargs):
        return dtec**2

    def tec_to_dtec(tec):
        tec = tec.reshape((nant, ndir, ntime))
        dtec = jnp.reshape(tec - tec[0, :, :], (-1, ))
        return dtec

    prior_chain = build_frozen_flow_prior(X, kernel, tec_to_dtec, x0)

    ns = NestedSampler(log_likelihood,
                       prior_chain,
                       sampler_name='slice',
                       predict_f=predict_f,
                       predict_fvar=predict_fvar)

    def run_with_n(n):
        @jit
        def run(key):
            return ns(key=key,
                      num_live_points=n,
                      max_samples=1e5,
                      collect_samples=True,
                      termination_frac=0.01,
                      stoachastic_uncertainty=False,
                      sampler_kwargs=dict(depth=5, num_slices=1))

        t0 = default_timer()
        results = run(random.PRNGKey(0))
        print("Efficiency", results.efficiency)
        print("Time to run (including compile)", default_timer() - t0)
        # t0 = default_timer()
        # results = run(random.PRNGKey(1))
        # print(results.efficiency)
        # print("Time to run (no compile)", default_timer() - t0)
        return results

    for n in [100]:
        results = run_with_n(n)
        plt.scatter(n, results.logZ)
        plt.errorbar(n, results.logZ, yerr=results.logZerr)
    plt.title("Kernel: {}".format(kernel.__class__.__name__))
    plt.ylabel('log Z')
    plt.show()

    fstd = jnp.sqrt(results.marginalised['predict_fvar'] -
                    results.marginalised['predict_f']**2)
    plt.scatter(jnp.arange(Y.size), Y_obs, marker='+', label='data')
    plt.scatter(jnp.arange(Y.size), Y, marker="o", label='underlying')
    plt.scatter(jnp.arange(Y.size),
                results.marginalised['predict_f'],
                marker=".",
                label='posterior mean')
    plt.errorbar(jnp.arange(Y.size),
                 results.marginalised['predict_f'],
                 yerr=fstd,
                 label='marginalised')
    plt.title("Kernel: {}".format(kernel.__class__.__name__))
    plt.legend()
    plt.show()

    plot_diagnostics(results)
    plot_cornerplot(results)
    return results.logZ, results.logZerr
Code Example #14
def main():
    def log_normal(x, mean, uncert):
        dx = x - mean
        dx = dx / uncert
        return -0.5 * x.size * jnp.log(
            2. * jnp.pi) - x.size * jnp.log(uncert) - 0.5 * dx @ dx

    N = 100
    X = jnp.linspace(-2., 2., N)[:, None]
    true_alpha, true_sigma, true_l, true_uncert = 1., 1., 0.2, 0.25
    data_mu = jnp.zeros((N, ))
    prior_cov = RBF()(X, X, true_l, true_sigma)
    Y = jnp.linalg.cholesky(prior_cov) @ random.normal(random.PRNGKey(0),
                                                       shape=(N, )) + data_mu
    Y_obs = Y + true_uncert * random.normal(random.PRNGKey(1), shape=(N, ))

    def predict_f(sigma, K, uncert, **kwargs):
        data_cov = jnp.square(uncert) * jnp.eye(X.shape[0])
        mu = jnp.zeros_like(Y_obs)
        return mu + K @ jnp.linalg.solve(K + data_cov, Y_obs)

    def predict_fvar(sigma, K, uncert, **kwargs):
        data_cov = jnp.square(uncert) * jnp.eye(X.shape[0])
        mu = jnp.zeros_like(Y_obs)
        return jnp.diag(K - K @ jnp.linalg.solve(K + data_cov, K))

    ###
    # define the prior chain
    # Here we assume each image is represented by pixels.
    # Alternatively, you could choose regions arranged non-uniformly over the image.

    image_shape = (128, 128)
    npix = image_shape[0] * image_shape[1]
    I150 = jnp.ones(image_shape)

    alpha_cw_gp_sigma = HalfLaplacePrior('alpha_cw_gp_sigma', 1.)
    alpha_mw_gp_sigma = HalfLaplacePrior('alpha_mw_gp_sigma', 1.)
    l_cw = UniformPrior('l_cw', 0., 0.5)  #degrees
    l_mw = UniformPrior('l_mw', 0.5, 2.)  #degrees
    K_cw = GaussianProcessKernelPrior('K_cw', RBF(), X, l_cw,
                                      alpha_cw_gp_sigma)
    K_mw = GaussianProcessKernelPrior('K_mw', RBF(), X, l_mw,
                                      alpha_mw_gp_sigma)
    alpha_cw = MVNPrior('alpha_cw', -1.5, K_cw)
    alpha_mw = MVNPrior('alpha_mw', -2.5, K_mw)
    S_cw_150 = UniformPrior('S150_cw', 0., I150)
    S_mw_150 = UniformPrior('S150_mw', 0., I150)
    uncert = HalfLaplacePrior('uncert', 1.)

    def log_likelihood(uncert, alpha_cw, alpha_mw, S_cw_150, S_mw_150, **kwargs):
        log_prob = 0
        for img, freq in zip(images, freqs):  # <- need to define these
            I_total = S_mw_150 * (freq / 150e6)**(alpha_mw) + S_cw_150 * (
                freq / 150e6)**(alpha_cw)
            log_prob += log_normal(img, I_total, uncert)
        return log_prob

    prior_chain = PriorChain()\
        .push(alpha_cw).push(S_cw_150)\
        .push(alpha_mw).push(S_mw_150)\
        .push(uncert)
    print(prior_chain)

    ns = NestedSampler(log_likelihood,
                       prior_chain,
                       sampler_name='ellipsoid',
                       predict_f=predict_f,
                       predict_fvar=predict_fvar)

    def run_with_n(n):
        @jit
        def run():
            return ns(key=random.PRNGKey(0),
                      num_live_points=n,
                      max_samples=1e3,
                      collect_samples=True,
                      termination_frac=0.01,
                      stoachastic_uncertainty=True)

        results = run()
        return results
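The likelihood above models each image as the sum of two power-law components anchored at 150 MHz,

I(\nu) = S^{150}_{\mathrm{mw}} \left(\tfrac{\nu}{150\,\mathrm{MHz}}\right)^{\alpha_{\mathrm{mw}}} + S^{150}_{\mathrm{cw}} \left(\tfrac{\nu}{150\,\mathrm{MHz}}\right)^{\alpha_{\mathrm{cw}}},

with per-pixel amplitudes and GP-smoothed spectral-index maps; images and freqs still need to be supplied, as flagged in the code.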
Code Example #15
def main():
    def log_normal(x, mean, cov):
        L = jnp.linalg.cholesky(cov)
        dx = x - mean
        dx = solve_triangular(L, dx, lower=True)
        return -0.5 * x.size * jnp.log(2. * jnp.pi) - jnp.sum(jnp.log(jnp.diag(L))) \
               - 0.5 * dx @ dx

    N = 100
    X = jnp.linspace(-2., 2., N)[:, None]
    true_alpha, true_sigma, true_l, true_uncert = 1., 1., 0.2, 0.25
    data_mu = jnp.zeros((N, ))
    prior_cov = RationalQuadratic()(X, X, true_l, true_sigma, true_alpha)
    Y = jnp.linalg.cholesky(prior_cov) @ random.normal(random.PRNGKey(0),
                                                       shape=(N, )) + data_mu
    Y_obs = Y + true_uncert * random.normal(random.PRNGKey(1), shape=(N, ))

    # Y_obs = jnp.where((jnp.arange(N) > 50) & (jnp.arange(N) < 60),
    #                   random.normal(random.PRNGKey(1), shape=(N, )),
    #                   Y_obs)

    # plt.scatter(X[:, 0], Y_obs, label='data')
    # plt.plot(X[:, 0], Y, label='underlying')
    # plt.legend()
    # plt.show()

    def log_likelihood(K, uncert, **kwargs):
        """
        P(Y|sigma, half_width) = N[Y, mu, K]
        Args:
            sigma:
            l:

        Returns:

        """
        data_cov = jnp.square(uncert) * jnp.eye(X.shape[0])
        mu = jnp.zeros_like(Y_obs)
        log_prob = log_normal(Y_obs, mu, K + data_cov)
        # print(log_prob)
        return log_prob

    def predict_f(K, uncert, **kwargs):
        data_cov = jnp.square(uncert) * jnp.eye(X.shape[0])
        mu = jnp.zeros_like(Y_obs)
        return mu + K @ jnp.linalg.solve(K + data_cov, Y_obs)

    def predict_fvar(K, uncert, **kwargs):
        data_cov = jnp.square(uncert) * jnp.eye(X.shape[0])
        mu = jnp.zeros_like(Y_obs)
        return jnp.diag(K - K @ jnp.linalg.solve(K + data_cov, K))

    prior_chain = PriorChain() \
        .push(GaussianProcessKernelPrior('K', RationalQuadratic(), X,
                                         UniformPrior('l', 0., 4.),
                                         UniformPrior('sigma', 0., 4.),
                                         UniformPrior('alpha', 0., 4.))) \
        .push(UniformPrior('uncert', 0., 2.))

    ns = NestedSampler(log_likelihood,
                       prior_chain,
                       sampler_name='multi_ellipsoid',
                       predict_f=predict_f,
                       predict_fvar=predict_fvar)

    def run_with_n(n):
        @jit
        def run():
            return ns(key=random.PRNGKey(0),
                      num_live_points=n,
                      max_samples=1e4,
                      collect_samples=True,
                      termination_frac=0.01,
                      stoachastic_uncertainty=False,
                      sampler_kwargs=dict(depth=4))

        results = run()
        return results

    for n in [200]:
        results = run_with_n(n)
        plt.scatter(n, results.logZ)
        plt.errorbar(n, results.logZ, yerr=results.logZerr)
    plt.title("Kernel: {}".format(RationalQuadratic.__name__))
    plt.ylabel('log Z')
    plt.show()

    plt.scatter(X[:, 0], Y_obs, label='data')
    plt.plot(X[:, 0], Y, label='underlying')
    plt.plot(X[:, 0], results.marginalised['predict_f'], label='marginalised')
    plt.plot(X[:, 0],
             results.marginalised['predict_f'] +
             jnp.sqrt(results.marginalised['predict_fvar']),
             ls='dotted',
             c='black')
    plt.plot(X[:, 0],
             results.marginalised['predict_f'] -
             jnp.sqrt(results.marginalised['predict_fvar']),
             ls='dotted',
             c='black')
    plt.title("Kernel: {}".format(RationalQuadratic.__name__))
    plt.legend()
    plt.show()

    plot_diagnostics(results)
    plot_cornerplot(results)
    return results.logZ, results.logZerr
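For reference, the rational quadratic kernel sampled above is conventionally parameterised as

k_{\mathrm{RQ}}(x, x') = \sigma^2 \left(1 + \frac{\lVert x - x' \rVert^2}{2\,\alpha\,l^2}\right)^{-\alpha},

which recovers the RBF kernel in the limit α → ∞; the argument order l, sigma, alpha is taken from the RationalQuadratic()(X, X, true_l, true_sigma, true_alpha) call at the top of this example.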