Example #1
0
def test_forced_identifiability_prior():
    """ForcedIdentifiabilityPrior must give sorted samples inside [0, 10].

    Two configurations are checked over ten PRNG seeds each: a scalar lower
    bound, and a length-2 lower bound which must yield samples of shape
    (10, 2).
    """
    from jax import random

    def _check(prior, expected_shape=None):
        # Draw from the unit cube, push through the prior, validate 'x'.
        for seed in range(10):
            unit = random.uniform(random.PRNGKey(seed),
                                  shape=(prior.U_ndims, ))
            x = prior(unit)['x']
            if expected_shape is not None:
                assert x.shape == expected_shape
            # Already in ascending order along axis 0.
            assert jnp.all(jnp.sort(x, axis=0) == x)
            assert jnp.all((x >= 0.) & (x <= 10.))

    _check(PriorChain().push(ForcedIdentifiabilityPrior('x', 10, 0., 10.)))
    _check(PriorChain().push(
        ForcedIdentifiabilityPrior('x', 10, jnp.array([0., 0.]), 10.)),
           expected_shape=(10, 2))
Example #2
0
def build_prior(X, kernel, tec_to_dtec, x0):
    """Build a prior chain for differential TEC under a tomographic GP kernel.

    The kernel hyper-parameters 'height', 'width', 'l' and 'sigma' get
    uniform priors.  'tec' is an MVNPrior with zero mean and the kernel prior
    'K' as its second argument (presumably the covariance).  'tec' is mapped
    to 'dtec' by ``tec_to_dtec``.

    Args:
        X: coordinates for the kernel prior; ``X.shape[0]`` sets the TEC size.
        kernel: base kernel wrapped by TomographicKernel.
        tec_to_dtec: callable mapping a 'tec' sample to 'dtec'.
        x0: reference point passed to TomographicKernel.

    Returns:
        PriorChain with 'dtec' and a Uniform(0, 5) 'uncert' pushed.
    """
    tomographic_kernel = TomographicKernel(x0, kernel, S_marg=100, S_gamma=10)
    hyper_priors = [
        UniformPrior('height', 100., 300.),
        UniformPrior('width', 50., 150.),
        UniformPrior('l', 7., 20.),
        UniformPrior('sigma', 0.3, 2.),
    ]
    K = GaussianProcessKernelPrior('K', tomographic_kernel, X, *hyper_priors,
                                   tracked=False)
    tec_mean = jnp.zeros((X.shape[0], ))
    tec = MVNPrior('tec', tec_mean, K, ill_cond=False, tracked=False)
    dtec = DeterministicTransformPrior('dtec', tec_to_dtec, tec.to_shape, tec,
                                       tracked=False)
    return PriorChain().push(dtec).push(UniformPrior('uncert', 0., 5.))
def main():
    """Infer TEC, a constant phase and a noise scale from synthetic data.

    ``generate_data()`` supplies the observations.  The model phase is
    ``tec * TEC_CONV / freqs + const``, and the likelihood is a Laplace
    density on the concatenated ``[amp*cos(phase), amp*sin(phase)]``.  The
    jitted nested sampler is run twice (the first call includes
    compilation).  The posterior-marginalised 'tec_mean' and 'const_mean'
    are then printed and diagnostic/corner plots drawn.
    """
    # The true TEC returned by generate_data is not used by this fit.
    Y_obs, amp, _true_tec, freqs = generate_data()
    TEC_CONV = -8.4479745e6  # mTECU/Hz

    def log_laplace(x, mean, scale):
        # Sum of element-wise Laplace log-densities sharing one scale.
        dx = jnp.abs(x - mean) / scale
        return -x.size * jnp.log(2. * scale) - jnp.sum(dx)

    def log_likelihood(tec, const, uncert, **kwargs):
        phase = tec * (TEC_CONV / freqs) + const
        Y = jnp.concatenate([amp * jnp.cos(phase), amp * jnp.sin(phase)],
                            axis=-1)
        # uncert is indexed with [0]; presumably the prior yields a length-1
        # array -- TODO confirm.
        log_prob = log_laplace(Y, Y_obs, uncert[0])
        return log_prob

    prior_chain = PriorChain() \
        .push(UniformPrior('tec', -100., 100.)) \
        .push(UniformPrior('const', -jnp.pi, jnp.pi)) \
        .push(HalfLaplacePrior('uncert', 0.25))

    print("Probabilistic model:\n{}".format(prior_chain))

    # Extra keyword functions are marginalised over the posterior; their
    # values are read below from results.marginalised['tec_mean'] and
    # results.marginalised['const_mean'].
    ns = NestedSampler(
        log_likelihood,
        prior_chain,
        sampler_name='slice',
        tec_mean=lambda tec, **kw: tec,
        const_mean=lambda const, **kw: const,
    )

    run = jit(lambda key: ns(key=key,
                             num_live_points=1000,
                             max_samples=1e5,
                             collect_samples=True,
                             termination_frac=0.01,
                             stoachastic_uncertainty=False,
                             sampler_kwargs=dict(depth=4, num_slices=1)))

    # First call includes JIT compilation.
    t0 = default_timer()
    results = run(random.PRNGKey(2364))
    print(results.efficiency)
    print("Time compile", default_timer() - t0)

    t0 = default_timer()
    results = run(random.PRNGKey(1324))
    print(results.efficiency)
    print("Time no compile", default_timer() - t0)

    print(results.marginalised['tec_mean'])
    print(results.marginalised['const_mean'])
    plot_diagnostics(results)
    plot_cornerplot(results)
Example #4
0
def build_prior(X, kernel, tec_to_dtec, x0, tec_conv):
    """Build a prior chain over unit phasors of differential TEC.

    The tomographic GP kernel prior 'K' feeds an MVNPrior 'tec' (zero mean,
    ``ill_cond=True``).  'tec' is mapped to 'dtec' by ``tec_to_dtec``, and
    'dtec' is turned into 'Y' = ``[cos(dtec * tec_conv), sin(dtec * tec_conv)]``
    concatenated along the last axis.

    Args:
        X: coordinates for the kernel prior; ``X.shape[0]`` sets the TEC size.
        kernel: base kernel wrapped by TomographicKernel.
        tec_to_dtec: callable mapping 'tec' to 'dtec'.
        x0: reference point passed to TomographicKernel.
        tec_conv: conversion factors; 'Y' has ``2 * tec_conv.size`` trailing
            entries.

    Returns:
        PriorChain with 'Y' and a Uniform(0.01, 1) 'uncert' pushed.
    """
    tomographic_kernel = TomographicKernel(x0, kernel, S_marg=100, S_gamma=10)
    K = GaussianProcessKernelPrior('K',
                                   tomographic_kernel,
                                   X,
                                   UniformPrior('height', 100., 300.),
                                   UniformPrior('width', 50., 150.),
                                   UniformPrior('l', 7., 20.),
                                   UniformPrior('sigma', 0.3, 2.),
                                   tracked=False)
    tec = MVNPrior('tec', jnp.zeros((X.shape[0], )), K,
                   ill_cond=True, tracked=False)
    dtec = DeterministicTransformPrior('dtec', tec_to_dtec, tec.to_shape, tec,
                                       tracked=False)

    def dtec_to_phasor(dtec):
        phase = dtec[:, None] * tec_conv
        return jnp.concatenate([jnp.cos(phase), jnp.sin(phase)], axis=-1)

    Y_shape = dtec.to_shape + (tec_conv.size * 2, )
    Y = DeterministicTransformPrior('Y', dtec_to_phasor, Y_shape, dtec,
                                    tracked=False)
    return PriorChain().push(Y).push(UniformPrior('uncert', 0.01, 1.))
    def run_block(key, dtec, dtec_uncert, log_prob):
        """Fit GP hyper-parameters for one block and marginalise a screen.

        Nested sampling runs over 'lengthscale' and 'sigma'.  The
        log-likelihood is read from a precomputed table via
        ``lookup_func(log_prob, lengthscale, sigma)``.  The posterior samples
        are then passed to ``marginalise_static`` (500 draws) to average a GP
        prediction of ``dtec`` at ``Xstar``.

        Relies on names from an enclosing scope not visible here: ``kernel``,
        ``X``, ``Xstar``, ``lengthscale_array``, ``sigma_array``,
        ``time_block_size``, ``lookup_func`` and ``marginalise_static``.

        Args:
            key: PRNG key, split into one key for sampling and one for
                marginalisation.
            dtec, dtec_uncert: data and its uncertainty; ``screen`` is vmapped
                over their leading axis.
            log_prob: table passed to ``lookup_func``.

        Returns:
            ``(mean, uncert, mean_lengthscale, mean_sigma, ESS, logZ,
            likelihood_evals)``; the last five are broadcast to length
            ``time_block_size``.
        """
        key1, key2 = random.split(key, 2)

        def log_likelihood(lengthscale, sigma, **kwargs):
            # Earlier direct-evaluation version, kept for reference:
            # K = kernel(X, X, lengthscale, sigma)
            # def _compute(dtec, dtec_uncert):
            #     #each [Nd]
            #     return log_normal_with_outliers(dtec, 0., K, jnp.maximum(1e-6, dtec_uncert))
            # return chunked_pmap(_compute, dtec, dtec_uncert, chunksize=1).sum()
            # The likelihood is looked up from the precomputed log_prob table.
            return lookup_func(log_prob, lengthscale, sigma)

        # Uniform priors spanning the grids the lookup table was built on.
        lengthscale = UniformPrior('lengthscale', jnp.min(lengthscale_array),
                                   jnp.max(lengthscale_array))
        sigma = UniformPrior('sigma', sigma_array.min(), sigma_array.max())
        prior_chain = PriorChain(lengthscale, sigma)

        ns = NestedSampler(loglikelihood=log_likelihood,
                           prior_chain=prior_chain,
                           sampler_kwargs=dict(num_slices=prior_chain.U_ndims *
                                               1),
                           num_live_points=prior_chain.U_ndims * 50)
        ns = jit(ns)
        results = ns(key1, termination_evidence_frac=0.1)

        # Evaluated by marginalise_static on posterior samples: returns the
        # screen prediction, lengthscale and log(sigma).
        def marg_func(lengthscale, sigma, **kwargs):
            def screen(dtec, dtec_uncert, **kw):
                K = kernel(X, X, lengthscale, sigma)
                Kstar = kernel(X, Xstar, lengthscale, sigma)
                # Whitened system L L^T = K / (u u^T) + I with u = dtec_uncert,
                # so that mean = Kstar^T (K + diag(u^2))^{-1} dtec.
                L = jnp.linalg.cholesky(
                    K / (dtec_uncert[:, None] * dtec_uncert[None, :]) +
                    jnp.eye(dtec.shape[0]))
                # L = jnp.where(jnp.isnan(L), jnp.eye(L.shape[0])/sigma, L)
                dx = solve_triangular(L, dtec / dtec_uncert, lower=True)
                JT = solve_triangular(L,
                                      Kstar / dtec_uncert[:, None],
                                      lower=True)
                #var_ik = JT_ji JT_jk
                mean = JT.T @ dx
                # NOTE(review): this equals diag(Kstar^T (K + diag(u^2))^{-1}
                # Kstar), the variance explained by the data.  The GP
                # posterior variance would be diag(k(Xstar, Xstar)) minus
                # this.  Confirm which quantity `uncert` is meant to be.
                var = jnp.sum(JT * JT, axis=0)
                return mean, var

            return vmap(screen)(dtec, dtec_uncert), lengthscale, jnp.log(
                sigma
            )  #[time_block_size,  Nd_screen], [time_block_size,  Nd_screen]

        #[time_block_size,  Nd_screen], [time_block_size,  Nd_screen], [time_block_size]
        (mean, var), mean_lengthscale, mean_logsigma = marginalise_static(
            key2, results.samples, results.log_p, 500, marg_func)
        uncert = jnp.sqrt(var)
        # exp of the marginalised log(sigma), not the marginalised sigma.
        mean_sigma = jnp.exp(mean_logsigma)
        # Broadcast scalar summaries to one value per time step in the block.
        mean_lengthscale = jnp.ones(time_block_size) * mean_lengthscale
        mean_sigma = jnp.ones(time_block_size) * mean_sigma
        ESS = results.ESS * jnp.ones(time_block_size)
        logZ = results.logZ * jnp.ones(time_block_size)
        likelihood_evals = results.num_likelihood_evaluations * jnp.ones(
            time_block_size)
        return mean, uncert, mean_lengthscale, mean_sigma, ESS, logZ, likelihood_evals
Example #6
0
def build_layered_prior(X, kernel, x0, tec_to_dtec):
    """Build a TEC prior that selects one of several fixed-height layers.

    Layer edges are ``jnp.linspace(80., 500., int((500. - 80.) / 50.) + 1)``.
    Each consecutive pair of edges gets a GaussianProcessKernelPrior
    'K{i}' with:
      - height fixed (DeltaPrior) at the midpoint;
      - width fixed at the edge spacing;
      - uniform priors on 'l{i}' in [7, 20] and 'sigma{i}' in [0.3, 2].

    A CategoricalPrior 'j' with all-zero logits picks which layer's kernel
    becomes 'K', the covariance argument of the 'tec' MVNPrior.

    Args:
        X: coordinates for the kernel priors; ``X.shape[0]`` sets the TEC size.
        kernel: base kernel wrapped by each TomographicKernel.
        x0: reference point passed to TomographicKernel.
        tec_to_dtec: callable mapping 'tec' to 'dtec'.

    Returns:
        PriorChain with 'dtec' and a Uniform(2, 3) 'uncert' pushed.
    """
    # Benchmark notes recorded with this configuration:
    #   Efficiency 0.39664684771546416
    #   Time to run (including compile) 246.36920081824064
    #   0.39198953960498245
    #   Time to run (no compile) 130.1565416753292
    #   Efficiency normalised time 51.020025508433804

    def layer_kernel(idx, lower, upper):
        # Kernel prior for the layer between edges `lower` and `upper`.
        return GaussianProcessKernelPrior(
            f'K{idx}',
            TomographicKernel(x0, kernel, S_marg=100, S_gamma=20),
            X,
            DeltaPrior(f'height{idx}', 0.5 * (lower + upper), tracked=False),
            DeltaPrior(f'width{idx}', upper - lower, tracked=False),
            UniformPrior(f'l{idx}', 7., 20., tracked=False),
            UniformPrior(f'sigma{idx}', 0.3, 2., tracked=False),
            tracked=False)

    edges = jnp.linspace(80., 500., int((500. - 80.) / 50.) + 1)
    layer_kernels = [
        layer_kernel(idx, lower, upper)
        for idx, (lower, upper) in enumerate(zip(edges[:-1], edges[1:]))
    ]

    selector = CategoricalPrior('j', jnp.zeros(len(layer_kernels)),
                                tracked=True)

    def pick_kernel(j, *K):
        # Stack every layer kernel, then index with the sampled category.
        return jnp.stack(K, axis=0)[j[0], :, :]

    K = DeterministicTransformPrior('K',
                                    pick_kernel,
                                    layer_kernels[0].to_shape,
                                    selector,
                                    *layer_kernels,
                                    tracked=False)
    tec = MVNPrior('tec', jnp.zeros((X.shape[0], )), K,
                   ill_cond=True, tracked=False)
    dtec = DeterministicTransformPrior('dtec', tec_to_dtec, tec.to_shape, tec,
                                       tracked=False)
    return PriorChain().push(dtec).push(UniformPrior('uncert', 2., 3.))
Example #7
0
def main():
    """Benchmark the slice nested sampler on a 4-D isotropic Gaussian.

    The log-likelihood is a zero-mean Gaussian with std ``sigma`` in each of
    ``ndims`` dimensions, under a Uniform(-1, 1) prior per dimension.  The
    jitted sampler is timed twice (the first call includes compilation).
    The evidence is plotted against the number of live points.
    """
    ndims = 4
    sigma = 0.1

    def log_likelihood(theta, **kwargs):
        r2 = jnp.sum(theta**2)
        logL = -0.5 * jnp.log(2. * jnp.pi * sigma**2) * ndims
        logL += -0.5 * r2 / sigma**2
        return logL

    prior_transform = PriorChain().push(
        UniformPrior('theta', -jnp.ones(ndims), jnp.ones(ndims)))
    ns = NestedSampler(log_likelihood, prior_transform, sampler_name='slice')

    def timed_call(run, key, label):
        # Read the clock once so the printed time and the
        # efficiency-normalised figure refer to the same interval.
        t0 = default_timer()
        results = run(key)
        print(results.efficiency)
        elapsed = default_timer() - t0
        print(label, elapsed)
        print("Time efficiency normalised:", results.efficiency * elapsed)
        return results

    def run_with_n(n):
        @jit
        def run(key):
            return ns(key=key,
                      num_live_points=n,
                      max_samples=1e5,
                      collect_samples=True,
                      termination_frac=0.01,
                      stoachastic_uncertainty=False,
                      sampler_kwargs=dict(depth=3))

        timed_call(run, random.PRNGKey(0), "Time to run including compile:")
        return timed_call(run, random.PRNGKey(1), "Time to run no compile:")

    for n in [1000]:
        results = run_with_n(n)
        plt.scatter(n, results.logZ)
        plt.errorbar(n, results.logZ, yerr=results.logZerr)

    plt.show()

    # plot_samples_development(results, save_name='./example.mp4')
    plot_diagnostics(results)
    plot_cornerplot(results)
Example #8
0
def main():
    """Run slice nested sampling on a periodic 2-D likelihood and plot it.

    The log-likelihood is ``5 * (2 + prod(cos(theta / 2)))`` with 'theta'
    uniform on ``[0, 10*pi]^2``.  The likelihood is first visualised on
    10000 prior draws.  The jitted sampler is then run twice (the first
    call includes compilation).

    Returns:
        ``(logZ, logZerr)`` of the final run.
    """
    def log_likelihood(theta, **kwargs):
        return 5. * (2. + jnp.prod(jnp.cos(0.5 * theta)))

    prior_chain = PriorChain().push(
        UniformPrior('theta', low=jnp.zeros(2),
                     high=jnp.pi * 10. * jnp.ones(2)))

    # Visualise the likelihood surface over prior draws.
    def draw(key):
        unit = random.uniform(key, (prior_chain.U_ndims, ))
        return prior_chain(unit)

    samples = vmap(draw)(random.split(random.PRNGKey(0), 10000))
    log_l = vmap(lambda sample: log_likelihood(**sample))(samples)
    points = samples['theta']
    sc = plt.scatter(points[:, 0], points[:, 1], c=log_l)
    plt.colorbar(sc)
    plt.show()

    ns = NestedSampler(log_likelihood, prior_chain, sampler_name='slice')

    def run_with_n(n):
        run = jit(lambda key: ns(key=key,
                                 num_live_points=n,
                                 max_samples=1e5,
                                 collect_samples=True,
                                 termination_frac=0.01,
                                 stoachastic_uncertainty=False,
                                 sampler_kwargs=dict(depth=7)))

        start = default_timer()
        results = run(random.PRNGKey(0))
        print("Efficiency", results.efficiency)
        print("Time to run (including compile)", default_timer() - start)

        start = default_timer()
        results = run(random.PRNGKey(1))
        print(results.efficiency)
        print("Time to run (no compile)", default_timer() - start)
        return results

    for n in [500]:
        results = run_with_n(n)
        plt.scatter(n, results.logZ)
        plt.errorbar(n, results.logZ, yerr=results.logZerr)
    plt.ylabel('log Z')
    plt.show()

    plot_diagnostics(results)
    plot_cornerplot(results)
    return results.logZ, results.logZerr
    def run_jaxns(num_live_points):
        """Return the wall-clock time of one compiled JaxNS run.

        The Gaussian log-likelihood uses ``ndims`` and ``sigma`` from an
        enclosing scope (not visible here).  The sampler is called once to
        trigger JIT compilation, then timed on a second key.

        Args:
            num_live_points: number of live points for the sampler.

        Returns:
            Seconds taken by the second (already compiled) run.

        Raises:
            ImportError: if jaxns cannot be imported.
        """
        try:
            from jaxns.nested_sampling import NestedSampler
            from jaxns.prior_transforms import PriorChain, UniformPrior
        except ImportError as err:
            # Catch only ImportError: a bare `except:` would also turn
            # unrelated errors (and KeyboardInterrupt) into this message.
            raise ImportError("Install JaxNS!") from err
        from timeit import default_timer
        from jax import random, jit
        import jax.numpy as jnp

        def log_likelihood(theta, **kwargs):
            r2 = jnp.sum(theta ** 2)
            logL = -0.5 * jnp.log(2. * jnp.pi * sigma ** 2) * ndims
            logL += -0.5 * r2 / sigma ** 2
            return logL

        prior_transform = PriorChain().push(UniformPrior('theta', -jnp.ones(ndims), jnp.ones(ndims)))
        ns = NestedSampler(log_likelihood, prior_transform, sampler_name='slice')

        def run_with_n(n):
            @jit
            def run(key):
                return ns(key=key,
                          num_live_points=n,
                          max_samples=1e6,
                          collect_samples=False,
                          termination_frac=0.01,
                          stoachastic_uncertainty=False,
                          sampler_kwargs=dict(depth=3, num_slices=2))

            # Warm-up call compiles; block so compilation is not timed.
            results = run(random.PRNGKey(0))
            results.logZ.block_until_ready()
            t0 = default_timer()
            results = run(random.PRNGKey(1))
            # Block on the asynchronously dispatched result before reading
            # the clock, and keep the print out of the timed region.
            results.logZ.block_until_ready()
            run_time = (default_timer() - t0)
            print("Efficiency and logZ", results.efficiency, results.logZ)
            return run_time

        return run_with_n(num_live_points)
Example #10
0
def main():
    """Visualise ellipsoidal clustering of the high-likelihood region.

    The demo:
      - draws 700 unit-cube points;
      - keeps those whose likelihood ``(2 + prod(cos(theta / 2)))**5``
        exceeds 150;
      - clusters them with ``ellipsoid_clustering`` (7 passed as its third
        argument);
      - plots each ellipse together with its member points.
    """
    def log_likelihood(theta, **kwargs):
        return (2. + jnp.prod(jnp.cos(0.5 * theta)))**5

    prior_chain = PriorChain() \
        .push(UniformPrior('theta', low=jnp.zeros(2), high=jnp.pi * 10. * jnp.ones(2)))

    U = vmap(lambda key: random.uniform(key, (prior_chain.U_ndims, )))(
        random.split(random.PRNGKey(0), 700))
    theta = vmap(lambda u: prior_chain(u))(U)
    lik = vmap(lambda theta: log_likelihood(**theta))(theta)

    select = lik > 150.
    print("Selecting", jnp.sum(select), "need", 18 * 3)
    # Log of the fraction of drawn points that were selected.
    log_VS = jnp.log(jnp.sum(select) / select.size)
    print("V(S)", jnp.exp(log_VS))

    U = U[select, :]

    with disable_jit():
        # Forward the key argument instead of ignoring it (same PRNGKey(0)).
        cluster_id, ellipsoid_parameters = \
            jit(lambda key, points, log_VS: ellipsoid_clustering(key, points, 7, log_VS)
                )(random.PRNGKey(0), U, log_VS)
        mu, radii, rotation = ellipsoid_parameters

    # Unit circle, mapped through each ellipsoid's (mu, radii, rotation).
    angles = jnp.linspace(0., jnp.pi * 2, 100)
    circle = jnp.stack([jnp.cos(angles), jnp.sin(angles)], axis=0)

    # ellipsoid_parameters is the 3-tuple (mu, radii, rotation), so its len
    # is 3.  The cluster count is the leading length of mu; normalising by 3
    # pushed the colour index past 1 for i >= 3, saturating the colormap.
    num_clusters = len(mu)
    for i, (mu_i, radii_i, rotation_i) in enumerate(zip(mu, radii, rotation)):
        y = mu_i[:, None] + rotation_i @ jnp.diag(radii_i) @ circle
        plt.plot(y[0, :], y[1, :])
        mask = cluster_id == i
        plt.scatter(U[mask, 0],
                    U[mask, 1],
                    c=jnp.atleast_2d(plt.cm.jet(i / num_clusters)))

    plt.show()
    def E_update(self, prior_mu, prior_Gamma, Y, Sigma, *control_params):
        """Nested-sampling E-step returning posterior mean/covariance of 'param'.

        Args:
            prior_mu: mean passed to the MVN prior on 'param'.
            prior_Gamma: second argument of that MVNPrior (presumably the
                prior covariance -- TODO confirm).
            Y: observed data compared against ``self.forward_model``.
            Sigma: data covariance; only its diagonal is used.
            *control_params: forwarded to ``self.forward_model``.
                ``control_params[1]`` is the PRNG key for the sampler.

        Returns:
            ``(post_mu, post_Gamma)``, read from
            ``results.param_mean['param']`` and
            ``results.param_covariance['param']``.
        """
        # amp = control_params[0]
        key = control_params[1]

        prior_chain = PriorChain() \
            .push(MVNPrior('param', prior_mu, prior_Gamma))

        # .push(HalfLaplacePrior('uncert', jnp.sqrt(jnp.mean(jnp.diag(Sigma)))))

        def log_normal(x, mean, cov):
            """Gaussian log-density using only the diagonal of ``cov``."""
            std = jnp.sqrt(jnp.diag(cov))
            dx = (x - mean) / std
            # The normaliser is sum(log(std)).  Do NOT wrap std in jnp.diag:
            # for a 1-D input that builds an n x n matrix whose zero
            # off-diagonal entries give log(0) = -inf, turning the result
            # into +inf.
            return -0.5 * x.size * jnp.log(2. * jnp.pi) - jnp.sum(jnp.log(std)) \
                   - 0.5 * dx @ dx

        def log_likelihood(param, **kwargs):
            Y_model = self.forward_model(param, *control_params)
            return log_normal(Y_model, Y, Sigma)

        ns = NestedSampler(log_likelihood,
                           prior_chain,
                           sampler_name='whitened_ellipsoid')
        # Second positional argument is presumably the number of live points.
        results = ns(key,
                     self._phase_basis_size * 15,
                     max_samples=1e5,
                     collect_samples=False,
                     termination_frac=0.01,
                     stoachastic_uncertainty=True)

        post_mu = results.param_mean['param']
        post_Gamma = results.param_covariance['param']

        return post_mu, post_Gamma
Example #12
0
def build_frozen_flow_prior(X, kernel, tec_to_dtec, x0):
    """Build a TEC prior whose GP coordinates are advected by a frozen flow.

    A direction 'v_dir' (the normalised draw of an MVNDiagPrior 'n' with zero
    mean and unit diagonal) and a magnitude 'v_mag' ~ U(0, 0.5) give a
    velocity 'v'.  The first six columns of X are shifted by
    ``[v, 0, 0, 0] * X[:, 6]`` and fed to the tomographic kernel prior.

    Args:
        X: array whose columns 0-5 are kernel coordinates and whose column 6
            scales the shift (presumably time -- TODO confirm).
        kernel: base kernel wrapped by TomographicKernel.
        tec_to_dtec: callable mapping 'tec' to 'dtec'.
        x0: reference point passed to TomographicKernel.

    Returns:
        PriorChain with 'dtec' and a Uniform(0, 5) 'uncert' pushed.
    """
    def unit_direction(n):
        return n / jnp.linalg.norm(n)

    def velocity(v_dir, v_mag):
        return v_mag * v_dir

    def advected_coords(v):
        # Shift columns 0-2 by v * X[:, 6]; columns 3-5 are left unchanged.
        shift = jnp.concatenate([v, jnp.zeros(3)])
        return X[:, 0:6] - shift * X[:, 6:7]

    n = MVNDiagPrior('n', jnp.zeros(3), jnp.ones(3), tracked=False)
    v_dir = DeterministicTransformPrior('v_dir', unit_direction, (3,), n,
                                        tracked=False)
    v_mag = UniformPrior('v_mag', 0., 0.5, tracked=False)
    v = DeterministicTransformPrior('v', velocity, (3,), v_dir, v_mag,
                                    tracked=True)
    X_frozen_flow = DeterministicTransformPrior('X', advected_coords,
                                                X[:, 0:6].shape, v,
                                                tracked=False)
    K = GaussianProcessKernelPrior('K',
                                   TomographicKernel(x0, kernel, S_marg=20, S_gamma=10),
                                   X_frozen_flow,
                                   UniformPrior('height', 100., 300.),
                                   UniformPrior('width', 50., 150.),
                                   UniformPrior('l', 0., 20.),
                                   UniformPrior('sigma', 0., 2.),
                                   tracked=False)
    tec = MVNPrior('tec', jnp.zeros((X.shape[0],)), K, ill_cond=True,
                   tracked=False)
    dtec = DeterministicTransformPrior('dtec', tec_to_dtec, tec.to_shape, tec,
                                       tracked=False)
    return PriorChain().push(dtec).push(UniformPrior('uncert', 0., 5.))
Example #13
0
def test_unit_cube_mixture_prior():
    """Smoke-test nested sampling with a GMMMarginalPrior on a bimodal likelihood.

    The likelihood is an equal-weight mixture of two unit-variance Gaussian
    terms in ``sum(x)`` and ``sum(x - 10.)``.  There are no assertions: the
    run must complete, and the diagnostic and corner plots are produced.
    """
    import jax.numpy as jnp
    from jax import random
    from jaxns.nested_sampling import NestedSampler
    from jaxns.plotting import plot_cornerplot, plot_diagnostics

    # prior_chain = PriorChain().push(MultiCubeMixturePrior('x', 2, 1, -5., 15.))
    prior_chain = PriorChain().push(GMMMarginalPrior('x', 2, -5., 15.))

    norm = jnp.sqrt(2. * jnp.pi)

    def loglikelihood(x, **kwargs):
        # NOTE(review): jnp.sum(x)**2 squares the sum, not the sum of
        # squares.  If x is multi-dimensional, confirm this is intended.
        mode_a = jnp.exp(-0.5 * jnp.sum(x)**2) / norm
        mode_b = jnp.exp(-0.5 * jnp.sum(x - 10.)**2) / norm
        return jnp.log(0.5 * mode_a + 0.5 * mode_b)

    ns = NestedSampler(loglikelihood, prior_chain, sampler_name='ellipsoid')
    results = ns(random.PRNGKey(0),
                 100,
                 max_samples=1e5,
                 collect_samples=True,
                 termination_frac=0.05,
                 stoachastic_uncertainty=True)
    plot_diagnostics(results)
    plot_cornerplot(results)
Example #14
0
def build_prior(nant, ndir):
    """Prior over antenna phases 'theta' and per-direction 'gamma', mapped to 'delta'.

    ``delta[a, b]`` is the mean over directions ``d`` of
    ``exp(-gamma[d] + 1j * (theta[a, d] - theta[b, d]))``, with 'theta'
    reshaped to ``(nant, ndir)``.

    Args:
        nant: number of antennas.
        ndir: number of directions.

    Returns:
        PriorChain with the untracked ``(nant, nant)`` 'delta' pushed.
    """
    theta = MVNDiagPrior('theta',
                         jnp.zeros(nant * ndir),
                         jnp.ones(nant * ndir),
                         tracked=True)
    # NOTE(review): the second argument is all zeros, so 'gamma' is
    # presumably pinned at its mean (0).  Confirm this is intended.
    gamma = MVNDiagPrior('gamma',
                         jnp.zeros(ndir),
                         0. * jnp.ones(ndir),
                         tracked=True)

    def vis(theta, gamma, **kwargs):
        phases = theta.reshape((nant, ndir))
        # Pairwise phase differences, shape (nant, nant, ndir).
        diff = 1j * (phases[:, None, :] - phases[None, :, :])
        return jnp.mean(jnp.exp(-gamma + diff), axis=-1)

    delta = DeterministicTransformPrior('delta', vis, (nant, nant), theta,
                                        gamma, tracked=False)
    return PriorChain().push(delta)
Example #15
0
        def run(key):
            """Run slice nested sampling with a diagonal-MVN prior on 'x'.

            Uses ``prior_mu``, ``prior_cov``, ``log_likelihood`` and ``n`` from
            enclosing scopes (not visible here).  Also asks the sampler to
            marginalise 'x_mean' (x) and 'x_cov' (outer(x, x)).
            """
            prior_std = jnp.sqrt(jnp.diag(prior_cov))
            prior_transform = PriorChain().push(
                MVNDiagPrior('x', prior_mu, prior_std))
            # Alternatives tried previously:
            # prior_transform = LaplacePrior(prior_mu, jnp.sqrt(jnp.diag(prior_cov)))
            # prior_transform = UniformPrior(-20.*jnp.ones(ndims), 20.*jnp.ones(ndims))

            def param_mean(x, **args):
                return x

            def param_covariance(x, **args):
                # NOTE(review): presumably averaged over the posterior, this
                # yields E[x x^T] (the second moment).  Subtract
                # outer(mean, mean) to obtain the covariance.
                return jnp.outer(x, x)

            sampler_options = dict(depth=3, num_slices=2)
            ns = NestedSampler(log_likelihood,
                               prior_transform,
                               sampler_name='slice',
                               x_mean=param_mean,
                               x_cov=param_covariance)
            return ns(key=key,
                      num_live_points=n,
                      max_samples=1e5,
                      collect_samples=True,
                      termination_frac=0.01,
                      stoachastic_uncertainty=False,
                      sampler_kwargs=sampler_options)
    def run_block(block_idx):
        """Nested sampling over (lengthscale, sigma, bottom) for one block.

        The log-likelihood sums ``lookup_func`` over every entry of
        ``log_prob[block_idx]``.  'width' is fixed at 50 by a DeltaPrior.
        ``lookup_func``, ``log_prob`` and the ``*_array`` grids come from an
        enclosing scope (not visible here).

        Args:
            block_idx: index into ``log_prob`` selecting this block's tables.

        Returns:
            Results of the jitted NestedSampler run.
        """
        def log_likelihood(bottom, width, lengthscale, sigma, **kwargs):
            def single(entry_log_prob):
                return lookup_func(entry_log_prob, bottom, width,
                                   lengthscale, sigma)

            return jnp.sum(vmap(single)(log_prob[block_idx]))

        bottom = UniformPrior('bottom', bottom_array.min(), bottom_array.max())
        width = DeltaPrior('width', 50., tracked=False)
        lengthscale = UniformPrior('lengthscale', jnp.min(lengthscale_array),
                                   jnp.max(lengthscale_array))
        sigma = UniformPrior('sigma', sigma_array.min(), sigma_array.max())
        prior_chain = PriorChain(lengthscale, sigma, bottom, width)

        num_dims = prior_chain.U_ndims
        ns = jit(NestedSampler(loglikelihood=log_likelihood,
                               prior_chain=prior_chain,
                               sampler_name='slice',
                               sampler_kwargs=dict(num_slices=num_dims * 5),
                               num_live_points=num_dims * 50))
        return ns(random.PRNGKey(42), termination_frac=0.001)
Example #17
0
def test_half_laplace():
    """HalfLaplacePrior must map a 100-point grid on [0, 1] (endpoints included) to non-NaN values."""
    p = PriorChain().push(HalfLaplacePrior('x', 1.))
    U = jnp.linspace(0., 1., 100)[:, None]
    samples = vmap(p)(U)['x']
    # Use `not` rather than `~`: bitwise inversion of a plain Python bool is
    # always truthy (~True == -2, ~False == -1), so `~` is fragile here.
    assert not jnp.any(jnp.isnan(samples))
Example #18
0
def test_prior_chain():
    """Smoke-test PriorChain with MVNDiag, Laplace and DiagGaussianWalk priors.

    Each chain is printed, evaluated at one uniform draw of length
    ``chain.U_ndims``, and the resulting sample is printed.
    """
    from jax import random

    def _run(chain):
        # Print the chain and one sample drawn through it.
        print(chain)
        U = random.uniform(random.PRNGKey(0), shape=(chain.U_ndims, ))
        y = chain(U)
        print(y)

    mu = MVNDiagPrior('mu', jnp.array([0., 0.]), 1.)
    X = MVNDiagPrior('x', mu, jnp.array([1.]))
    _run(PriorChain().push(mu).push(X))

    mu = MVNDiagPrior('mu', jnp.array([0., 0.]), 1.)
    X = LaplacePrior('x', mu, jnp.array([1.]))
    _run(PriorChain().push(mu).push(X))

    x0 = MVNDiagPrior('x0', jnp.array([0., 0.]), 1.)
    X = DiagGaussianWalkPrior('W', 2, x0, 1.)
    # Push x0, the parent of X -- not the `mu` left over from the previous
    # chain, which would leave x0 unused.
    _run(PriorChain().push(x0).push(X))
Example #19
0
def main(kernel):
    """Fit a GP with ``kernel`` to synthetic 1-D data containing outliers.

    Data:
      - N=100 points on [-2, 2] drawn from an RBF GP (l=0.2, sigma=1);
      - Gaussian noise of std 0.2 added;
      - points with 50 < index < 60 replaced by independent standard-normal
        outliers.

    Nested sampling runs over 'l', 'sigma' and 'uncert'.  The evidence, the
    posterior-marginalised predictive mean and the +/- 1 std band are
    plotted.

    Args:
        kernel: kernel object passed to GaussianProcessKernelPrior.

    Returns:
        ``(logZ, logZerr)`` from the nested-sampling results.
    """
    print("Working on Kernel: {}".format(kernel.__class__.__name__))

    def log_normal(x, mean, cov):
        # Multivariate normal log-density via the Cholesky factor cov = L L^T.
        L = jnp.linalg.cholesky(cov)
        half_log_det = jnp.sum(jnp.log(jnp.diag(L)))  # equals 0.5 * log|cov|
        dx = solve_triangular(L, x - mean, lower=True)
        maha = dx @ dx  # squared Mahalanobis distance
        return -0.5 * x.size * jnp.log(2. * jnp.pi) \
               - half_log_det \
               - 0.5 * maha

    N = 100
    X = jnp.linspace(-2., 2., N)[:, None]
    true_sigma, true_l, true_uncert = 1., 0.2, 0.2
    data_mu = jnp.zeros((N, ))
    # Tiny diagonal jitter, presumably so the Cholesky below succeeds.
    prior_cov = RBF()(X, X, true_l, true_sigma) + 1e-13 * jnp.eye(N)
    Y = jnp.linalg.cholesky(prior_cov) @ random.normal(random.PRNGKey(0),
                                                       shape=(N, )) + data_mu
    Y_obs = Y + true_uncert * random.normal(random.PRNGKey(1), shape=(N, ))
    # Replace indices 51..59 with standard-normal outliers.  Use a fresh key:
    # drawing from PRNGKey(1) again would make each outlier exactly
    # 1 / true_uncert times the noise already added at that index.
    index = jnp.arange(N)
    Y_obs = jnp.where((index > 50) & (index < 60),
                      random.normal(random.PRNGKey(2), shape=(N, )), Y_obs)

    def log_likelihood(K, uncert, **kwargs):
        """Return log N[Y_obs | 0, K + uncert**2 * I].

        Args:
            K: covariance matrix sampled from the 'K' kernel prior.
            uncert: observation-noise standard deviation.

        Returns:
            Scalar Gaussian log-likelihood of ``Y_obs``.
        """
        data_cov = jnp.square(uncert) * jnp.eye(X.shape[0])
        mu = jnp.zeros_like(Y_obs)
        return log_normal(Y_obs, mu, K + data_cov)

    def predict_f(K, uncert, **kwargs):
        # GP posterior mean at the training inputs.
        data_cov = jnp.square(uncert) * jnp.eye(X.shape[0])
        mu = jnp.zeros_like(Y_obs)
        return mu + K @ jnp.linalg.solve(K + data_cov, Y_obs)

    def predict_fvar(K, uncert, **kwargs):
        # Diagonal of the GP posterior covariance at the training inputs.
        data_cov = jnp.square(uncert) * jnp.eye(X.shape[0])
        return jnp.diag(K - K @ jnp.linalg.solve(K + data_cov, K))

    l = UniformPrior('l', 0., 2.)
    uncert = UniformPrior('uncert', 0., 2.)
    sigma = UniformPrior('sigma', 0., 2.)
    cov = GaussianProcessKernelPrior('K', kernel, X, l, sigma)
    prior_chain = PriorChain().push(uncert).push(cov)

    ns = NestedSampler(log_likelihood,
                       prior_chain,
                       sampler_name='multi_ellipsoid',
                       predict_f=predict_f,
                       predict_fvar=predict_fvar)

    def run_with_n(n):
        @jit
        def run(key):
            return ns(key=key,
                      num_live_points=n,
                      max_samples=1e5,
                      collect_samples=True,
                      termination_frac=0.01,
                      stoachastic_uncertainty=False,
                      sampler_kwargs=dict(depth=3))

        # Both calls use the same key; the second measures execution only.
        t0 = default_timer()
        results = run(random.PRNGKey(6))
        print(results.efficiency)
        print(
            "Time to execute (including compile): {}".format(default_timer() -
                                                             t0))
        t0 = default_timer()
        results = run(random.PRNGKey(6))
        print(results.efficiency)
        print("Time to execute (not including compile): {}".format(
            (default_timer() - t0)))
        return results

    for n in [100]:
        results = run_with_n(n)
        plt.scatter(n, results.logZ)
        plt.errorbar(n, results.logZ, yerr=results.logZerr)
    plt.title("Kernel: {}".format(kernel.__class__.__name__))
    plt.ylabel('log Z')
    plt.show()

    f_mean = results.marginalised['predict_f']
    f_std = jnp.sqrt(results.marginalised['predict_fvar'])
    plt.scatter(X[:, 0], Y_obs, label='data')
    plt.plot(X[:, 0], Y, label='underlying')
    plt.plot(X[:, 0], f_mean, label='marginalised')
    plt.plot(X[:, 0], f_mean + f_std, ls='dotted', c='black')
    plt.plot(X[:, 0], f_mean - f_std, ls='dotted', c='black')
    plt.title("Kernel: {}".format(kernel.__class__.__name__))
    plt.legend()
    plt.show()
    plot_diagnostics(results)
    plot_cornerplot(results)
    return results.logZ, results.logZerr
Example #20
0
def main():
    """Draft: fit a two-component power-law model to a stack of images.

    Each image at frequency ``freq`` is modelled as
    ``S150_mw * (freq / 150e6) ** alpha_mw + S150_cw * (freq / 150e6) ** alpha_cw``.
    The spectral indices get Gaussian-process priors, the amplitudes get
    uniform priors bounded by ``I150``, and the noise level ``uncert`` gets a
    half-Laplace prior. The chain is wired into a ``NestedSampler``.

    NOTE(review): this example is unfinished:
      * ``images`` and ``freqs`` are not defined in this function (see the TODO
        in ``log_likelihood``);
      * ``run_with_n`` is defined but never called, and the function returns
        ``None``;
      * ``predict_f``/``predict_fvar`` take ``sigma`` and ``K`` arguments, but no
        prior pushed onto ``prior_chain`` is named ``sigma`` or ``K``;
      * the GP kernels are built over the 1-D grid ``X`` (``N`` points) while the
        amplitudes have shape ``image_shape``. Confirm the intended shapes
        before running.
    """
    def log_normal(x, mean, uncert):
        """Log-density of ``x`` under an isotropic Gaussian N(mean, uncert**2 I).

        The quadratic form is a sum over all elements, so this is valid for
        inputs of any rank. (``dx @ dx`` would be a matrix product for 2-D
        images, not a sum of squares.)
        """
        dx = (x - mean) / uncert
        return -0.5 * x.size * jnp.log(2. * jnp.pi) - x.size * jnp.log(uncert) \
            - 0.5 * jnp.sum(jnp.square(dx))

    N = 100
    X = jnp.linspace(-2., 2., N)[:, None]
    true_alpha, true_sigma, true_l, true_uncert = 1., 1., 0.2, 0.25
    data_mu = jnp.zeros((N, ))
    prior_cov = RBF()(X, X, true_l, true_sigma)
    Y = jnp.linalg.cholesky(prior_cov) @ random.normal(random.PRNGKey(0),
                                                       shape=(N, )) + data_mu
    Y_obs = Y + true_uncert * random.normal(random.PRNGKey(1), shape=(N, ))

    def predict_f(sigma, K, uncert, **kwargs):
        """Posterior GP mean at ``X`` for kernel matrix ``K`` and noise ``uncert``."""
        data_cov = jnp.square(uncert) * jnp.eye(X.shape[0])
        mu = jnp.zeros_like(Y_obs)
        return mu + K @ jnp.linalg.solve(K + data_cov, Y_obs)

    def predict_fvar(sigma, K, uncert, **kwargs):
        """Posterior GP marginal variance at ``X`` for kernel ``K`` and noise ``uncert``."""
        data_cov = jnp.square(uncert) * jnp.eye(X.shape[0])
        return jnp.diag(K - K @ jnp.linalg.solve(K + data_cov, K))

    ###
    # define the prior chain
    # Here we assume each image is represented by pixels.
    # Alternatively, you could choose regions arranged non-uniformly over the image.

    image_shape = (128, 128)
    I150 = jnp.ones(image_shape)

    alpha_cw_gp_sigma = HalfLaplacePrior('alpha_cw_gp_sigma', 1.)
    alpha_mw_gp_sigma = HalfLaplacePrior('alpha_mw_gp_sigma', 1.)
    l_cw = UniformPrior('l_cw', 0., 0.5)  #degrees
    l_mw = UniformPrior('l_mw', 0.5, 2.)  #degrees
    K_cw = GaussianProcessKernelPrior('K_cw', RBF(), X, l_cw,
                                      alpha_cw_gp_sigma)
    K_mw = GaussianProcessKernelPrior('K_mw', RBF(), X, l_mw,
                                      alpha_mw_gp_sigma)
    alpha_cw = MVNPrior('alpha_cw', -1.5, K_cw)
    alpha_mw = MVNPrior('alpha_mw', -2.5, K_mw)
    S_cw_150 = UniformPrior('S150_cw', 0., I150)
    S_mw_150 = UniformPrior('S150_mw', 0., I150)
    uncert = HalfLaplacePrior('uncert', 1.)

    def log_likelihood(uncert, alpha_cw, alpha_mw, S150_cw, S150_mw, **kwargs):
        """Sum of per-image Gaussian log-likelihoods.

        The parameter names must match the names given to the priors above
        ('uncert', 'alpha_cw', 'alpha_mw', 'S150_cw', 'S150_mw').
        ``**kwargs`` presumably absorbs any other sampled quantities the
        sampler passes in (e.g. the GP hyperparameters). Verify against
        ``NestedSampler``.
        """
        log_prob = 0
        # TODO: ``images`` and ``freqs`` are not defined in this function.
        for img, freq in zip(images, freqs):  # <- need to define these
            I_total = S150_mw * (freq / 150e6)**(alpha_mw) + S150_cw * (
                freq / 150e6)**(alpha_cw)
            log_prob += log_normal(img, I_total, uncert)
        return log_prob

    prior_chain = PriorChain()\
        .push(alpha_cw).push(S_cw_150)\
        .push(alpha_mw).push(S_mw_150)\
        .push(uncert)
    print(prior_chain)

    ns = NestedSampler(log_likelihood,
                       prior_chain,
                       sampler_name='ellipsoid',
                       predict_f=predict_f,
                       predict_fvar=predict_fvar)

    def run_with_n(n):
        """Run the sampler with ``n`` live points (fixed PRNG key 0)."""
        @jit
        def run():
            return ns(key=random.PRNGKey(0),
                      num_live_points=n,
                      max_samples=1e3,
                      collect_samples=True,
                      termination_frac=0.01,
                      stoachastic_uncertainty=True)

        results = run()
        return results
Example #21
0
def main():
    """Fit a GP with a RationalQuadratic kernel to synthetic 1-D data by nested sampling.

    Draws ``Y`` from a zero-mean GP with a RationalQuadratic kernel and adds
    Gaussian noise to get ``Y_obs``. It then samples the kernel
    hyperparameters (``l``, ``sigma``, ``alpha``) and the noise level
    ``uncert`` with a ``NestedSampler``. Finally it plots the evidence, the
    marginalised posterior mean +/- one standard deviation, and the sampler
    diagnostics.

    Returns:
        tuple ``(logZ, logZerr)`` taken from the nested-sampling results.
    """
    kernel_name = RationalQuadratic.__name__

    def log_normal(x, mean, cov):
        """Log-density of ``x`` under N(mean, cov), via a Cholesky factorisation."""
        L = jnp.linalg.cholesky(cov)
        dx = x - mean
        dx = solve_triangular(L, dx, lower=True)
        # log|cov| = 2 * sum(log(diag(L))), hence -0.5 * log|cov| = -sum(log(diag(L))).
        return -0.5 * x.size * jnp.log(2. * jnp.pi) - jnp.sum(jnp.log(jnp.diag(L))) \
               - 0.5 * dx @ dx

    N = 100
    X = jnp.linspace(-2., 2., N)[:, None]
    true_alpha, true_sigma, true_l, true_uncert = 1., 1., 0.2, 0.25
    data_mu = jnp.zeros((N, ))
    prior_cov = RationalQuadratic()(X, X, true_l, true_sigma, true_alpha)
    Y = jnp.linalg.cholesky(prior_cov) @ random.normal(random.PRNGKey(0),
                                                       shape=(N, )) + data_mu
    Y_obs = Y + true_uncert * random.normal(random.PRNGKey(1), shape=(N, ))

    # Y_obs = jnp.where((jnp.arange(N) > 50) & (jnp.arange(N) < 60),
    #                   random.normal(random.PRNGKey(1), shape_dict=(N, )),
    #                   Y_obs)

    # plt.scatter(X[:, 0], Y_obs, label='data')
    # plt.plot(X[:, 0], Y, label='underlying')
    # plt.legend()
    # plt.show()

    def log_likelihood(K, uncert, **kwargs):
        """
        log P(Y_obs | K, uncert) = log N[Y_obs | 0, K + uncert^2 I]

        Args:
            K: GP prior covariance matrix evaluated at ``X``.
            uncert: standard deviation of the i.i.d. observation noise.
            **kwargs: any other sampled quantities; ignored.

        Returns:
            scalar log-likelihood.
        """
        data_cov = jnp.square(uncert) * jnp.eye(X.shape[0])
        mu = jnp.zeros_like(Y_obs)
        return log_normal(Y_obs, mu, K + data_cov)

    def predict_f(K, uncert, **kwargs):
        """Posterior GP mean at ``X``: K (K + uncert^2 I)^-1 Y_obs."""
        data_cov = jnp.square(uncert) * jnp.eye(X.shape[0])
        mu = jnp.zeros_like(Y_obs)
        return mu + K @ jnp.linalg.solve(K + data_cov, Y_obs)

    def predict_fvar(K, uncert, **kwargs):
        """Posterior GP marginal variance at ``X``: diag(K - K (K + uncert^2 I)^-1 K)."""
        data_cov = jnp.square(uncert) * jnp.eye(X.shape[0])
        return jnp.diag(K - K @ jnp.linalg.solve(K + data_cov, K))

    prior_chain = PriorChain() \
        .push(GaussianProcessKernelPrior('K', RationalQuadratic(), X,
                                         UniformPrior('l', 0., 4.),
                                         UniformPrior('sigma', 0., 4.),
                                         UniformPrior('alpha', 0., 4.))) \
        .push(UniformPrior('uncert', 0., 2.))

    ns = NestedSampler(log_likelihood,
                       prior_chain,
                       sampler_name='multi_ellipsoid',
                       predict_f=predict_f,
                       predict_fvar=predict_fvar)

    def run_with_n(n):
        """Run the sampler with ``n`` live points (fixed PRNG key 0)."""
        @jit
        def run():
            return ns(key=random.PRNGKey(0),
                      num_live_points=n,
                      max_samples=1e4,
                      collect_samples=True,
                      termination_frac=0.01,
                      stoachastic_uncertainty=False,
                      sampler_kwargs=dict(depth=4))

        return run()

    for n in [200]:
        results = run_with_n(n)
        plt.scatter(n, results.logZ)
        plt.errorbar(n, results.logZ, yerr=results.logZerr)
    plt.title("Kernel: {}".format(kernel_name))
    plt.ylabel('log Z')
    plt.show()

    f_mean = results.marginalised['predict_f']
    f_std = jnp.sqrt(results.marginalised['predict_fvar'])
    plt.scatter(X[:, 0], Y_obs, label='data')
    plt.plot(X[:, 0], Y, label='underlying')
    plt.plot(X[:, 0], f_mean, label='marginalised')
    plt.plot(X[:, 0], f_mean + f_std, ls='dotted', c='black')
    plt.plot(X[:, 0], f_mean - f_std, ls='dotted', c='black')
    plt.title("Kernel: {}".format(kernel_name))
    plt.legend()
    plt.show()

    plot_diagnostics(results)
    plot_cornerplot(results)
    return results.logZ, results.logZerr
Example #22
0
def test_generic_kmeans():
    """Visual check of ellipsoidal clustering on a toy likelihood.

    Samples the unit cube uniformly and keeps the points whose likelihood
    exceeds a threshold. It clusters them with ``ellipsoid_clustering`` and
    plots the bounding ellipsoid of each cluster over its points.
    ``data`` selects the toy problem: 'eggbox' or 'shells'.
    """
    from jaxns.prior_transforms import PriorChain, UniformPrior
    from jax import vmap, disable_jit, jit
    import pylab as plt

    data = 'shells'
    if data == 'eggbox':
        def log_likelihood(theta, **kwargs):
            return (2. + jnp.prod(jnp.cos(0.5 * theta))) ** 5

        prior_chain = PriorChain() \
            .push(UniformPrior('theta', low=jnp.zeros(2), high=jnp.pi * 10. * jnp.ones(2)))

        U = vmap(lambda key: random.uniform(key, (prior_chain.U_ndims,)))(random.split(random.PRNGKey(0), 1000))
        theta = vmap(lambda u: prior_chain(u))(U)
        lik = vmap(lambda theta: log_likelihood(**theta))(theta)
        select = lik > 100.

    elif data == 'shells':

        def log_likelihood(theta, **kwargs):
            def log_circ(theta, c, r, w):
                # Gaussian ring of radius r and width w centred on c.
                return -0.5 * (jnp.linalg.norm(theta - c) - r) ** 2 / w ** 2 - jnp.log(jnp.sqrt(2 * jnp.pi * w ** 2))
            w1 = w2 = jnp.array(0.1)
            r1 = r2 = jnp.array(2.)
            c1 = jnp.array([0., -4.])
            c2 = jnp.array([0., 4.])
            return jnp.logaddexp(log_circ(theta, c1, r1, w1), log_circ(theta, c2, r2, w2))

        prior_chain = PriorChain() \
            .push(UniformPrior('theta', low=-12. * jnp.ones(2), high=12. * jnp.ones(2)))

        U = vmap(lambda key: random.uniform(key, (prior_chain.U_ndims,)))(random.split(random.PRNGKey(0), 40000))
        theta = vmap(lambda u: prior_chain(u))(U)
        lik = vmap(lambda theta: log_likelihood(**theta))(theta)
        select = lik > 1.

    else:
        # Without this branch U/lik/select would be unbound below.
        raise ValueError("Unknown data set {!r}".format(data))

    print("Selecting", jnp.sum(select))
    # Fraction of the unit cube occupied by the selected region.
    log_VS = jnp.log(jnp.sum(select) / select.size)
    print("V(S)", jnp.exp(log_VS))

    points = U[select, :]
    sc = plt.scatter(U[:, 0], U[:, 1], c=jnp.exp(lik))
    plt.colorbar(sc)
    plt.show()
    # mask and K = 18 are only used by the commented-out alternatives below.
    mask = jnp.ones(points.shape[0], dtype=jnp.bool_)
    K = 18
    with disable_jit():
        # state = generic_kmeans(random.PRNGKey(0), points, mask, method='ellipsoid',K=K,meta=dict(log_VS=log_VS))
        # state = generic_kmeans(random.PRNGKey(0), points, mask, method='mahalanobis',K=K)
        # state = generic_kmeans(random.PRNGKey(0), points, mask, method='euclidean',K=K)
        # cluster_id, log_cluster_VS = hierarchical_clustering(random.PRNGKey(0), points, 7, log_VS)
        cluster_id, ellipsoid_parameters = \
            jit(lambda key, points, log_VS: ellipsoid_clustering(key, points, 7, log_VS)
                )(random.PRNGKey(0), points, log_VS)
        # mu, radii, rotation = ellipsoid_parameters
        K = int(jnp.max(cluster_id) + 1)

    cluster_mu, cluster_C = vmap(lambda k: bounding_ellipsoid(points, cluster_id == k))(jnp.arange(K))
    cluster_radii, cluster_rotation = vmap(ellipsoid_params)(cluster_C)

    angle = jnp.linspace(0., jnp.pi * 2, 100)
    unit_circle = jnp.stack([jnp.cos(angle), jnp.sin(angle)], axis=0)

    for i, (mu, radii, rotation) in enumerate(zip(cluster_mu, cluster_radii, cluster_rotation)):
        colour = plt.cm.jet(i / K)
        # Map the unit circle onto this cluster's ellipse boundary.
        y = mu[:, None] + rotation @ jnp.diag(radii) @ unit_circle
        plt.plot(y[0, :], y[1, :], c=colour)
        in_cluster = cluster_id == i
        plt.scatter(points[in_cluster, 0], points[in_cluster, 1], c=jnp.atleast_2d(colour))
    plt.xlim(-1, 2)
    plt.ylim(-1, 2)
    plt.show()
Example #23
0
def main():
    """Infer a TEC time series from simulated complex phases by nested sampling.

    ``generate_data()`` supplies the observations. The TEC gets a Gaussian
    random-walk prior over ``T``. The likelihood compares ``cos``/``sin`` of the
    model phase ``tec * TEC_CONV / freqs`` with ``Y_obs`` under Gaussian noise
    ``uncert``. The sampler is run twice (first run includes compilation) and
    timed, then the marginalised TEC is plotted against the true ``tec``.
    """
    from jax.tree_util import tree_leaves

    Sigma, T, Y_obs, amp, tec, freqs = generate_data()
    TEC_CONV = -8.4479745e6  # mTECU/Hz

    def log_normal(x, mean, uncert):
        """Log-density of a 1-D ``x`` under N(mean, uncert**2 I)."""
        dx = (x - mean) / uncert
        return -0.5 * x.size * jnp.log(2. * jnp.pi) - x.size * jnp.log(uncert) \
               - 0.5 * dx @ dx

    def log_likelihood(tec, uncert, **kwargs):
        """Gaussian log-likelihood of ``Y_obs`` given the TEC series and noise level."""
        phase = tec * (TEC_CONV / freqs)  # + clock *(jnp.pi*2)*freqs#+ clock
        Y = jnp.concatenate([jnp.cos(phase), jnp.sin(phase)], axis=-1)
        # Evaluate the density row by row and sum.
        return jnp.sum(vmap(lambda Y, Y_obs: log_normal(Y, Y_obs, uncert))(Y, Y_obs))

    prior_chain = PriorChain() \
        .push(DiagGaussianWalkPrior('tec', T, LaplacePrior('tec0', 0., 100.), UniformPrior('omega', 1, 15))) \
        .push(UniformPrior('uncert', 0.01, 0.5))

    ns = NestedSampler(log_likelihood, prior_chain, sampler_name='slice',
                       tec_mean=lambda tec, **kwargs: tec)

    @jit
    def run(key):
        return ns(key=key,
                  num_live_points=500,
                  max_samples=1e5,
                  collect_samples=True,
                  termination_frac=0.01,
                  stoachastic_uncertainty=False,
                  sampler_kwargs=dict(depth=7))

    def timed_run(key):
        """Run the sampler and return ``(results, elapsed_seconds)``."""
        t0 = default_timer()
        results = run(key)
        # JAX dispatches asynchronously: wait for every output before reading
        # the clock, otherwise execution time may be excluded.
        for leaf in tree_leaves(results):
            if hasattr(leaf, 'block_until_ready'):
                leaf.block_until_ready()
        return results, default_timer() - t0

    # with disable_jit():
    results, elapsed = timed_run(random.PRNGKey(0))
    print("Time with compile efficiency normalised", results.efficiency * elapsed)
    print("Time with compile", elapsed)
    results, elapsed = timed_run(random.PRNGKey(1))
    print("Time no compile efficiency normalised", results.efficiency * elapsed)
    print("Time no compile", elapsed)

    plt.plot(tec)
    plt.plot(results.marginalised['tec_mean'])
    plt.show()
    plt.plot(results.marginalised['tec_mean'][:, 0] - tec)
    plt.show()

    plot_diagnostics(results)
Example #24
0
def unconstrained_solve(freqs, key, phase_obs, phase_outliers):
    """Fit TEC, a constant phase and a frequency-dependent noise level to a
    pair of consecutive phase observations with nested sampling.

    Args:
        freqs: channel frequencies, presumably 1-D of length Nf (they are
            broadcast against ``tec[:, None]``).
        key: PRNG key.
        phase_obs: observed phases of shape (Nt, Nf) with Nt == 2.
        phase_outliers: boolean mask broadcastable to (2, Nf). Entries that are
            True contribute nothing to the likelihood.

    Returns:
        tuple ``(tec_mean, tec_std, const_mean, const_std, uncert_mean)``:
        - ``tec_mean``, ``tec_std``: posterior mean and std of the TEC at the
          two time steps.
        - ``const_mean``: circular posterior mean of the constant phase.
        - ``const_std``: RMS wrapped deviation of the constant phase from
          ``const_mean``.
        - ``uncert_mean``: posterior mean of the average noise level.
        The last three are broadcast to shape (Nt,).
    """
    # Only key1 and key2 are used; the split count is kept at 4 so that the
    # derived keys stay the same.
    key1, key2, _, _ = random.split(key, 4)
    Nt, Nf = phase_obs.shape
    assert Nt == 2, "Observations should be consequentive pairs of 2"

    tec0_array = jnp.linspace(-300., 300., 30)
    # 30 mTECU / 30 s is the maximum change between the two observations, so
    # dtec spans [-30, 30]. A zero-width range would pin dtec to a single value.
    dtec_array = jnp.linspace(-30., 30., 30)
    const_array = jnp.linspace(-jnp.pi, jnp.pi, 10)
    uncert0_array = jnp.linspace(0., 1., 10)
    uncert1_array = jnp.linspace(0., 1., 10)

    # Frequency coordinate normalised to [0, 1]. It is used to interpolate the
    # noise from uncert0 (lowest freq) to uncert1 (highest freq). It does not
    # depend on the sampled parameters, so it is computed once. A single
    # channel (zero span) maps to 0 instead of dividing by zero.
    t = freqs - jnp.min(freqs)
    t_span = jnp.max(t)
    t = t / jnp.where(t_span > 0., t_span, 1.)

    def log_likelihood(tec0, dtec, const, uncert0, uncert1, **kwargs):
        tec = jnp.asarray([tec0, tec0 + dtec])
        uncert = uncert0 + (uncert1 - uncert0) * t
        phase = tec[:, None] * (TEC_CONV / freqs) + const  # 2,Nf
        # Compare wrapped phases; flagged outliers contribute zero.
        logL = jnp.sum(
            jnp.where(
                phase_outliers, 0.,
                log_normal(wrap(wrap(phase) - wrap(phase_obs)), 0., uncert)))
        return logL

    # X = make_coord_array(tec0_array[:, None], dtec_array[:, None], const_array[:, None], uncert0_array[:, None], uncert1_array[:, None], flat=True)
    # log_prob_array = vmap(lambda x: log_likelihood(x[0], x[1], x[2], x[3], x[4]))(X)
    # log_prob_array = log_prob_array.reshape((tec0_array.size, dtec_array.size, const_array.size, uncert0_array.size, uncert1_array.size))
    #
    # lookup_func = build_lookup_index(tec0_array, dtec_array, const_array, uncert0_array, uncert1_array)
    #
    # def efficient_log_likelihood(tec0, dtec, const, uncert0, uncert1, **kwargs):
    #     b = 0.5
    #     log_prob_uncert0 = - uncert0 / b - jnp.log(b)
    #     log_prob_uncert1 = - uncert1 / b - jnp.log(b)
    #     return lookup_func(log_prob_array, tec0, dtec, const, uncert0, uncert1) + log_prob_uncert0 + log_prob_uncert1

    tec0 = UniformPrior('tec0', tec0_array.min(), tec0_array.max())
    # 30mTECU/30seconds is the maximum change
    dtec = UniformPrior('dtec', dtec_array.min(), dtec_array.max())
    const = UniformPrior('const', const_array.min(), const_array.max())
    uncert0 = UniformPrior('uncert0', uncert0_array.min(), uncert0_array.max())
    uncert1 = UniformPrior('uncert1', uncert1_array.min(), uncert1_array.max())
    prior_chain = PriorChain(tec0, dtec, const, uncert0, uncert1)

    ns = NestedSampler(log_likelihood,
                       prior_chain,
                       sampler_name='slice',
                       num_live_points=20 * prior_chain.U_ndims,
                       sampler_kwargs=dict(num_slices=prior_chain.U_ndims * 4))

    results = ns(key=key1, termination_evidence_frac=0.3)

    ESS = 900  # emperically estimated for this problem

    def first_moments(tec0, dtec, const, uncert0, uncert1, **kwargs):
        """Quantities whose posterior means give TEC moments, the circular mean
        of ``const`` (via cos/sin) and the average noise level."""
        tec = jnp.asarray([tec0, tec0 + dtec])
        return tec, tec**2, jnp.cos(const), jnp.sin(const), 0.5 * (uncert0 +
                                                                   uncert1)

    tec_mean, tec2_mean, const_real, const_imag, uncert_mean = marginalise_static(
        key2, results.samples, results.log_p, ESS, first_moments)

    # E[x^2] - E[x]^2 can dip slightly below zero from round-off; clamp
    # before the square root to avoid NaN.
    tec_std = jnp.sqrt(jnp.maximum(tec2_mean - tec_mean**2, 0.))
    const_mean = jnp.arctan2(const_imag, const_real)

    def const_sq_deviation(const, **kwargs):
        """Squared wrapped distance of ``const`` from the circular mean."""
        return wrap(wrap(const) - wrap(const_mean))**2

    const_var = marginalise_static(key2, results.samples, results.log_p, ESS,
                                   const_sq_deviation)
    const_std = jnp.sqrt(const_var)

    return tec_mean, tec_std, const_mean * jnp.ones(Nt), const_std * jnp.ones(
        Nt), uncert_mean * jnp.ones(Nt)