def main():
    """Infer TEC, a constant phase offset and a noise scale from synthetic phase data.

    Data come from ``generate_data()``. The model phase is
    ``tec * TEC_CONV / freqs + const``. It is compared to the data through its
    (cos, sin) components, scaled by ``amp``, under a Laplace likelihood.

    The slice nested sampler runs twice: the first call includes JIT
    compilation, the second does not. The function then prints the
    posterior-marginalised ``tec`` and ``const`` and shows diagnostic plots.
    """
    Y_obs, amp, true_tec, freqs = generate_data()
    TEC_CONV = -8.4479745e6  # mTECU/Hz

    def log_laplace(x, mean, scale):
        """Summed log-density of independent Laplace(mean, scale) variables."""
        dx = jnp.abs(x - mean) / scale
        return -x.size * jnp.log(2. * scale) - jnp.sum(dx)

    def log_likelihood(tec, const, uncert, **kwargs):
        # Dispersive TEC term (proportional to 1/freq) plus a constant phase offset.
        phase = tec * (TEC_CONV / freqs) + const
        # cos/sin are 2*pi-periodic, so comparing these components avoids
        # any phase-wrapping ambiguity in the residual.
        Y = jnp.concatenate([amp * jnp.cos(phase), amp * jnp.sin(phase)], axis=-1)
        return log_laplace(Y, Y_obs, uncert[0])

    prior_chain = PriorChain() \
        .push(UniformPrior('tec', -100., 100.)) \
        .push(UniformPrior('const', -jnp.pi, jnp.pi)) \
        .push(HalfLaplacePrior('uncert', 0.25))

    print("Probabilistic model:\n{}".format(prior_chain))

    # Extra keyword callables are evaluated on the posterior samples and their
    # marginalised values are read back from ``results.marginalised[<name>]``.
    ns = NestedSampler(
        log_likelihood,
        prior_chain,
        sampler_name='slice',
        tec_mean=lambda tec, **kw: tec,
        const_mean=lambda const, **kw: const,
    )

    # NOTE: 'stoachastic_uncertainty' is the keyword spelling used throughout
    # this file for the sampler call; do not "correct" it here alone.
    run = jit(lambda key: ns(key=key,
                             num_live_points=1000,
                             max_samples=1e5,
                             collect_samples=True,
                             termination_frac=0.01,
                             stoachastic_uncertainty=False,
                             sampler_kwargs=dict(depth=4, num_slices=1)))

    t0 = default_timer()
    results = run(random.PRNGKey(2364))
    print(results.efficiency)
    print("Time compile", default_timer() - t0)

    t0 = default_timer()
    results = run(random.PRNGKey(1324))
    print(results.efficiency)
    print("Time no compile", default_timer() - t0)

    print(results.marginalised['tec_mean'])
    print(results.marginalised['const_mean'])
    # Ground truth from the generator, for comparison with the posterior mean.
    print("True tec", true_tec)

    plot_diagnostics(results)
    plot_cornerplot(results)
def diagnostics(self):
    """
    Show the diagnostic and corner plots for the stored nested-sampling results.

    Delegates to :func:`jaxns.plotting.plot_diagnostics` and then
    :func:`jaxns.plotting.plot_cornerplot`.

    Raises:
        RuntimeError: when ``self._results`` is still ``None``.
    """
    results = self._results
    if results is not None:
        plot_diagnostics(results)
        plot_cornerplot(results)
        return
    raise RuntimeError(
        "NestedSampler.run(...) method should be called first to obtain results."
    )
def main():
    """Benchmark the slice nested sampler on an isotropic Gaussian likelihood.

    The likelihood is an ``ndims``-dimensional normal density centred at the
    origin with standard deviation ``sigma``. The prior is uniform on
    [-1, 1]^ndims.

    Each configuration runs twice: the first call includes JIT compilation,
    the second does not. For each run the function prints the efficiency, the
    wall time, and the wall time multiplied by the efficiency. It then plots
    log Z with its error bar and shows diagnostic plots.
    """
    ndims = 4
    sigma = 0.1

    def log_likelihood(theta, **kwargs):
        r2 = jnp.sum(theta ** 2)
        logL = -0.5 * jnp.log(2. * jnp.pi * sigma ** 2) * ndims
        logL += -0.5 * r2 / sigma ** 2
        return logL

    prior_transform = PriorChain().push(
        UniformPrior('theta', -jnp.ones(ndims), jnp.ones(ndims)))
    ns = NestedSampler(log_likelihood, prior_transform, sampler_name='slice')

    def run_with_n(n):
        @jit
        def run(key):
            return ns(key=key,
                      num_live_points=n,
                      max_samples=1e5,
                      collect_samples=True,
                      termination_frac=0.01,
                      stoachastic_uncertainty=False,
                      sampler_kwargs=dict(depth=3))

        def timed_run(key, label):
            """Run once and print timings, all based on a single elapsed measurement."""
            t0 = default_timer()
            results = run(key)
            # Printing a value forces JAX's asynchronous computation to finish
            # before the clock is read.
            print(results.efficiency)
            elapsed = default_timer() - t0
            print(label, elapsed)
            print("Time efficiency normalised:", results.efficiency * elapsed)
            return results

        timed_run(random.PRNGKey(0), "Time to run including compile:")
        return timed_run(random.PRNGKey(1), "Time to run no compile:")

    for n in [1000]:
        results = run_with_n(n)
        plt.scatter(n, results.logZ)
        plt.errorbar(n, results.logZ, yerr=results.logZerr)
    plt.show()

    plot_diagnostics(results)
    plot_cornerplot(results)
def main():
    """Run the slice nested sampler on a 2-D multimodal "egg-box"-style likelihood.

    The function first scatters the likelihood over 10000 prior draws. It then
    runs the sampler twice; the first call includes JIT compilation. Finally it
    plots log Z with its error bar, shows diagnostic and corner plots, and
    returns ``(logZ, logZerr)``.
    """
    # NOTE(review): the canonical egg-box test uses logL = (2 + prod cos)**5;
    # this uses 5 * (2 + prod cos). Confirm which was intended.
    def log_likelihood(theta, **kwargs):
        return 5. * (2. + jnp.prod(jnp.cos(0.5 * theta)))

    upper = jnp.pi * 10. * jnp.ones(2)
    prior_chain = PriorChain().push(UniformPrior('theta', low=jnp.zeros(2), high=upper))

    def draw_from_prior(key):
        """Map one uniform unit-cube draw through the prior chain."""
        u = random.uniform(key, (prior_chain.U_ndims,))
        return prior_chain(u)

    keys = random.split(random.PRNGKey(0), 10000)
    samples = vmap(draw_from_prior)(keys)
    lik = vmap(lambda sample: log_likelihood(**sample))(samples)

    points = samples['theta']
    sc = plt.scatter(points[:, 0], points[:, 1], c=lik)
    plt.colorbar(sc)
    plt.show()

    ns = NestedSampler(log_likelihood, prior_chain, sampler_name='slice')

    def run_with_n(n):
        sampler_options = dict(
            num_live_points=n,
            max_samples=1e5,
            collect_samples=True,
            termination_frac=0.01,
            stoachastic_uncertainty=False,
            sampler_kwargs=dict(depth=7),
        )

        @jit
        def run(key):
            return ns(key=key, **sampler_options)

        start = default_timer()
        results = run(random.PRNGKey(0))
        print("Efficiency", results.efficiency)
        print("Time to run (including compile)", default_timer() - start)

        start = default_timer()
        results = run(random.PRNGKey(1))
        print(results.efficiency)
        print("Time to run (no compile)", default_timer() - start)
        return results

    for n in [500]:
        results = run_with_n(n)
        plt.scatter(n, results.logZ)
        plt.errorbar(n, results.logZ, yerr=results.logZerr)
    plt.ylabel('log Z')
    plt.show()

    plot_diagnostics(results)
    plot_cornerplot(results)
    return results.logZ, results.logZerr
def test_unit_cube_mixture_prior():
    """Smoke-test nested sampling with a ``GMMMarginalPrior`` on a bimodal likelihood.

    The likelihood is an equal-weight mixture of two unit normal densities in
    the scalar ``sum(x)``. One mode is where ``sum(x) = 0`` and the other is
    where ``sum(x - 10) = 0``.

    The test runs the ellipsoid sampler, checks the evidence estimate is
    finite, and shows the diagnostic plots.
    """
    import jax.numpy as jnp
    from jax import random
    from jaxns.nested_sampling import NestedSampler
    from jaxns.plotting import plot_cornerplot, plot_diagnostics

    prior_chain = PriorChain().push(GMMMarginalPrior('x', 2, -5., 15.))

    def loglikelihood(x, **kwargs):
        # Equals log(0.5*exp(a)/sqrt(2pi) + 0.5*exp(b)/sqrt(2pi)) exactly, but
        # uses logaddexp. The naive exp-then-log form underflows to
        # log(0) = -inf far from both modes, e.g. (sum x)^2 ~ 900 near the
        # prior edges.
        log_a = -0.5 * jnp.sum(x) ** 2
        log_b = -0.5 * jnp.sum(x - 10.) ** 2
        return jnp.logaddexp(log_a, log_b) + jnp.log(0.5) - 0.5 * jnp.log(2. * jnp.pi)

    ns = NestedSampler(loglikelihood, prior_chain, sampler_name='ellipsoid')
    results = ns(random.PRNGKey(0), 100,
                 max_samples=1e5,
                 collect_samples=True,
                 termination_frac=0.05,
                 stoachastic_uncertainty=True)

    # Minimal sanity check so the test can actually fail: the evidence must be finite.
    assert jnp.all(jnp.isfinite(results.logZ))

    plot_diagnostics(results)
    plot_cornerplot(results)
def main(kernel):
    """Fit differential TEC (dtec) to synthetic data over antennas x directions.

    Args:
        kernel: kernel instance forwarded to ``build_prior``. Its class name is
            used in plot titles.

    Returns:
        ``(logZ, logZerr)`` from the final nested-sampling run.
    """
    def log_normal(x, mean, cov):
        """Log-density of N(mean, cov) at ``x`` via a Cholesky factorisation."""
        L = jnp.linalg.cholesky(cov)
        dx = solve_triangular(L, x - mean, lower=True)
        maha = dx @ dx
        # sum(log(diag(L))) == 0.5 * log(det(cov)).
        half_logdet = jnp.sum(jnp.log(jnp.diag(L)))
        return -0.5 * x.size * jnp.log(2. * jnp.pi) - half_logdet - 0.5 * maha

    true_height, true_width, true_sigma, true_l, true_uncert = 200., 100., 1., 10., 2.5
    nant = 5
    ndir = 5
    X, Y, Y_obs = rbf_dtec(nant, ndir, true_height, true_width, true_sigma, true_l, true_uncert)
    a = X[:, 0:3]
    # First row of the first three columns of X is the reference point.
    # Presumably this is the reference antenna position; confirm.
    x0 = a[0, :]

    def log_likelihood(dtec, uncert, **kwargs):
        """Gaussian likelihood of Y_obs with mean ``dtec`` and iid noise std ``uncert``."""
        data_cov = jnp.square(uncert) * jnp.eye(X.shape[0])
        return log_normal(Y_obs, dtec, data_cov)

    def predict_f(dtec, uncert, **kwargs):
        return dtec

    def predict_fvar(dtec, uncert, **kwargs):
        # Second moment E[f^2]. It is combined with E[f] below to form the posterior std.
        return dtec ** 2

    def tec_to_dtec(tec):
        """Subtract the first antenna's TEC (per direction) and flatten."""
        tec = tec.reshape((nant, ndir))
        return jnp.reshape(tec - tec[0, :], (-1,))

    prior_chain = build_prior(X, kernel, tec_to_dtec, x0)
    print(prior_chain)

    # Sanity check: evaluate the likelihood at 1000 random prior draws and report NaNs.
    U_test = jnp.array([random.uniform(key, shape=(prior_chain.U_ndims,))
                        for key in random.split(random.PRNGKey(4325), 1000)])
    log_lik = jnp.array([log_likelihood(**prior_chain(U)) for U in U_test])
    nan_mask = jnp.isnan(log_lik)
    print(jnp.sum(nan_mask))
    print(U_test[nan_mask])

    ns = NestedSampler(log_likelihood, prior_chain, sampler_name='slice',
                       predict_f=predict_f, predict_fvar=predict_fvar)

    def run_with_n(n):
        @jit
        def run(key):
            return ns(key=key,
                      num_live_points=n,
                      max_samples=1e5,
                      collect_samples=True,
                      termination_frac=0.01,
                      stoachastic_uncertainty=False,
                      sampler_kwargs=dict(depth=7, num_slices=1))

        t0 = default_timer()
        results = run(random.PRNGKey(0))
        print("Efficiency", results.efficiency)
        print("Time to run (including compile)", default_timer() - t0)

        t0 = default_timer()
        results = run(random.PRNGKey(1))
        print("Efficiency", results.efficiency)
        # Measure once so both timing prints refer to the same interval.
        elapsed = default_timer() - t0
        print("Time to run (no compile)", elapsed)
        print("Time efficiency normalised", elapsed * results.efficiency)
        return results

    for n in [1000]:
        results = run_with_n(n)
        plt.scatter(n, results.logZ)
        plt.errorbar(n, results.logZ, yerr=results.logZerr)
    plt.title("Kernel: {}".format(kernel.__class__.__name__))
    plt.ylabel('log Z')
    plt.show()

    f_mean = results.marginalised['predict_f']
    # Var[f] = E[f^2] - E[f]^2. Round-off can make this slightly negative,
    # which would turn sqrt into NaN, so clamp at zero.
    fstd = jnp.sqrt(jnp.maximum(results.marginalised['predict_fvar'] - f_mean ** 2, 0.))
    idx = jnp.arange(Y.size)
    plt.scatter(idx, Y_obs, marker='+', label='data')
    plt.scatter(idx, Y, marker="o", label='underlying')
    # Unlabelled: the errorbar below already carries the 'marginalised' legend entry.
    plt.scatter(idx, f_mean, marker=".")
    plt.errorbar(idx, f_mean, yerr=fstd, label='marginalised')
    plt.title("Kernel: {}".format(kernel.__class__.__name__))
    plt.legend()
    plt.show()

    plot_diagnostics(results)
    plot_cornerplot(results)
    return results.logZ, results.logZerr
def main(kernel):
    """GP regression on 1-D synthetic data with kernel hyperparameters marginalised.

    Data are drawn from an RBF GP prior plus noise, and observations with
    indices 51..59 are replaced by standard-normal draws. The function then
    runs a multi-ellipsoid nested sampler over ``uncert`` and the kernel
    hyperparameters ``l`` and ``sigma``. It plots log Z and the marginalised
    predictive mean +/- one std.

    Args:
        kernel: kernel instance passed to ``GaussianProcessKernelPrior``. Its
            class name is used in titles.

    Returns:
        ``(logZ, logZerr)`` from the nested-sampling run.
    """
    print("Working on Kernel: {}".format(kernel.__class__.__name__))

    def log_normal(x, mean, cov):
        """Log-density of N(mean, cov) at ``x`` via a Cholesky factorisation."""
        L = jnp.linalg.cholesky(cov)
        # sum(log(diag(L))) == 0.5 * log(det(cov)).
        half_log_det = jnp.sum(jnp.log(jnp.diag(L)))
        dx = solve_triangular(L, x - mean, lower=True)
        maha = dx @ dx
        return -0.5 * x.size * jnp.log(2. * jnp.pi) - half_log_det - 0.5 * maha

    N = 100
    X = jnp.linspace(-2., 2., N)[:, None]
    true_sigma, true_l, true_uncert = 1., 0.2, 0.2
    data_mu = jnp.zeros((N,))
    prior_cov = RBF()(X, X, true_l, true_sigma) + 1e-13 * jnp.eye(N)
    Y = jnp.linalg.cholesky(prior_cov) @ random.normal(random.PRNGKey(0), shape=(N,)) + data_mu
    Y_obs = Y + true_uncert * random.normal(random.PRNGKey(1), shape=(N,))
    # Corrupt observations 51..59 with standard-normal draws.
    # NOTE(review): this reuses PRNGKey(1), the same key as the noise above.
    Y_obs = jnp.where((jnp.arange(N) > 50) & (jnp.arange(N) < 60),
                      random.normal(random.PRNGKey(1), shape=(N,)),
                      Y_obs)

    def noise_cov(uncert):
        """Isotropic observation-noise covariance uncert^2 * I."""
        return jnp.square(uncert) * jnp.eye(X.shape[0])

    def log_likelihood(K, uncert, **kwargs):
        """GP marginal likelihood log N[Y_obs | 0, K + uncert^2 I]."""
        mu = jnp.zeros_like(Y_obs)
        return log_normal(Y_obs, mu, K + noise_cov(uncert))

    def predict_f(K, uncert, **kwargs):
        """GP posterior mean at X for the given hyperparameters."""
        mu = jnp.zeros_like(Y_obs)
        return mu + K @ jnp.linalg.solve(K + noise_cov(uncert), Y_obs)

    def predict_fvar(K, uncert, **kwargs):
        """GP posterior marginal variance at X for the given hyperparameters."""
        return jnp.diag(K - K @ jnp.linalg.solve(K + noise_cov(uncert), K))

    l_prior = UniformPrior('l', 0., 2.)
    uncert_prior = UniformPrior('uncert', 0., 2.)
    sigma_prior = UniformPrior('sigma', 0., 2.)
    cov = GaussianProcessKernelPrior('K', kernel, X, l_prior, sigma_prior)
    prior_chain = PriorChain().push(uncert_prior).push(cov)

    ns = NestedSampler(log_likelihood, prior_chain, sampler_name='multi_ellipsoid',
                       predict_f=predict_f, predict_fvar=predict_fvar)

    def run_with_n(n):
        @jit
        def run(key):
            return ns(key=key,
                      num_live_points=n,
                      max_samples=1e5,
                      collect_samples=True,
                      termination_frac=0.01,
                      stoachastic_uncertainty=False,
                      sampler_kwargs=dict(depth=3))

        t0 = default_timer()
        results = run(random.PRNGKey(6))
        print(results.efficiency)
        print("Time to execute (including compile): {}".format(default_timer() - t0))
        # Same key: the second call measures only the already-compiled run.
        t0 = default_timer()
        results = run(random.PRNGKey(6))
        print(results.efficiency)
        print("Time to execute (not including compile): {}".format(default_timer() - t0))
        return results

    for n in [100]:
        results = run_with_n(n)
        plt.scatter(n, results.logZ)
        plt.errorbar(n, results.logZ, yerr=results.logZerr)
    plt.title("Kernel: {}".format(kernel.__class__.__name__))
    plt.ylabel('log Z')
    plt.show()

    f_mean = results.marginalised['predict_f']
    # The posterior variance diagonal can be slightly negative from round-off;
    # clamp it before sqrt to avoid NaNs.
    # NOTE(review): assuming `marginalised` is a posterior expectation, this is
    # E[Var[f | hyperparams]]. It omits Var[E[f | hyperparams]], so the band
    # may understate the total uncertainty.
    f_std = jnp.sqrt(jnp.maximum(results.marginalised['predict_fvar'], 0.))
    plt.scatter(X[:, 0], Y_obs, label='data')
    plt.plot(X[:, 0], Y, label='underlying')
    plt.plot(X[:, 0], f_mean, label='marginalised')
    plt.plot(X[:, 0], f_mean + f_std, ls='dotted', c='black')
    plt.plot(X[:, 0], f_mean - f_std, ls='dotted', c='black')
    plt.title("Kernel: {}".format(kernel.__class__.__name__))
    plt.legend()
    plt.show()

    plot_diagnostics(results)
    plot_cornerplot(results)
    return results.logZ, results.logZerr
def main(kernel):
    """Fit dtec on a synthetic frozen-flow screen (antennas x directions x times).

    Args:
        kernel: kernel instance forwarded to ``build_frozen_flow_prior``. Its
            class name is used in plot titles.

    Returns:
        ``(logZ, logZerr)`` from the nested-sampling run.
    """
    def log_normal(x, mean, cov):
        """Log-density of N(mean, cov) at ``x`` via a Cholesky factorisation."""
        L = jnp.linalg.cholesky(cov)
        dx = solve_triangular(L, x - mean, lower=True)
        maha = dx @ dx
        # sum(log(diag(L))) == 0.5 * log(det(cov)).
        half_logdet = jnp.sum(jnp.log(jnp.diag(L)))
        return -0.5 * x.size * jnp.log(2. * jnp.pi) - half_logdet - 0.5 * maha

    true_height, true_width, true_sigma, true_l, true_uncert = 200., 100., 1., 10., 2.5
    true_v = jnp.array([0.3, 0., 0.])
    nant = 2
    ndir = 1
    ntime = 20
    X, Y, Y_obs = rbf_dtec(nant, ndir, ntime, true_height, true_width, true_sigma,
                           true_l, true_uncert, true_v)
    a = X[:, 0:3]
    # First row of the first three columns of X is the reference point.
    # Presumably this is the reference antenna position; confirm.
    x0 = a[0, :]

    def log_likelihood(dtec, uncert, **kwargs):
        """Gaussian likelihood of Y_obs with mean ``dtec`` and iid noise std ``uncert``."""
        data_cov = jnp.square(uncert) * jnp.eye(X.shape[0])
        return log_normal(Y_obs, dtec, data_cov)

    def predict_f(dtec, **kwargs):
        return dtec

    def predict_fvar(dtec, **kwargs):
        # Second moment E[f^2]. It is combined with E[f] below to form the posterior std.
        return dtec ** 2

    def tec_to_dtec(tec):
        """Subtract the first antenna's TEC (per direction and time) and flatten."""
        tec = tec.reshape((nant, ndir, ntime))
        return jnp.reshape(tec - tec[0, :, :], (-1,))

    prior_chain = build_frozen_flow_prior(X, kernel, tec_to_dtec, x0)

    ns = NestedSampler(log_likelihood, prior_chain, sampler_name='slice',
                       predict_f=predict_f, predict_fvar=predict_fvar)

    def run_with_n(n):
        @jit
        def run(key):
            return ns(key=key,
                      num_live_points=n,
                      max_samples=1e5,
                      collect_samples=True,
                      termination_frac=0.01,
                      stoachastic_uncertainty=False,
                      sampler_kwargs=dict(depth=5, num_slices=1))

        t0 = default_timer()
        results = run(random.PRNGKey(0))
        print("Efficiency", results.efficiency)
        print("Time to run (including compile)", default_timer() - t0)
        return results

    for n in [100]:
        results = run_with_n(n)
        plt.scatter(n, results.logZ)
        plt.errorbar(n, results.logZ, yerr=results.logZerr)
    plt.title("Kernel: {}".format(kernel.__class__.__name__))
    plt.ylabel('log Z')
    plt.show()

    f_mean = results.marginalised['predict_f']
    # Var[f] = E[f^2] - E[f]^2. Round-off can make this slightly negative,
    # which would turn sqrt into NaN, so clamp at zero.
    fstd = jnp.sqrt(jnp.maximum(results.marginalised['predict_fvar'] - f_mean ** 2, 0.))
    idx = jnp.arange(Y.size)
    plt.scatter(idx, Y_obs, marker='+', label='data')
    plt.scatter(idx, Y, marker="o", label='underlying')
    # Unlabelled: the errorbar below already carries the 'marginalised' legend entry.
    plt.scatter(idx, f_mean, marker=".")
    plt.errorbar(idx, f_mean, yerr=fstd, label='marginalised')
    plt.title("Kernel: {}".format(kernel.__class__.__name__))
    plt.legend()
    plt.show()

    plot_diagnostics(results)
    plot_cornerplot(results)
    return results.logZ, results.logZerr
def main():
    """GP regression with a RationalQuadratic kernel and marginalised hyperparameters.

    Data are drawn from a RationalQuadratic GP prior plus iid noise. The
    function runs a multi-ellipsoid nested sampler over the kernel
    hyperparameters ``l``, ``sigma``, ``alpha`` and the noise ``uncert``. It
    plots log Z and the marginalised predictive mean +/- one std.

    Returns:
        ``(logZ, logZerr)`` from the nested-sampling run.
    """
    def log_normal(x, mean, cov):
        """Log-density of N(mean, cov) at ``x`` via a Cholesky factorisation."""
        L = jnp.linalg.cholesky(cov)
        dx = solve_triangular(L, x - mean, lower=True)
        # sum(log(diag(L))) == 0.5 * log(det(cov)).
        return -0.5 * x.size * jnp.log(2. * jnp.pi) - jnp.sum(jnp.log(jnp.diag(L))) \
            - 0.5 * dx @ dx

    N = 100
    X = jnp.linspace(-2., 2., N)[:, None]
    true_alpha, true_sigma, true_l, true_uncert = 1., 1., 0.2, 0.25
    data_mu = jnp.zeros((N,))
    # Diagonal jitter keeps the Cholesky of this densely sampled (and thus
    # ill-conditioned) kernel matrix from failing, as in the RBF example. Its
    # std of ~1e-3 is negligible next to true_uncert.
    prior_cov = RationalQuadratic()(X, X, true_l, true_sigma, true_alpha) + 1e-6 * jnp.eye(N)
    Y = jnp.linalg.cholesky(prior_cov) @ random.normal(random.PRNGKey(0), shape=(N,)) + data_mu
    Y_obs = Y + true_uncert * random.normal(random.PRNGKey(1), shape=(N,))

    def noise_cov(uncert):
        """Isotropic observation-noise covariance uncert^2 * I."""
        return jnp.square(uncert) * jnp.eye(X.shape[0])

    def log_likelihood(K, uncert, **kwargs):
        """GP marginal likelihood log N[Y_obs | 0, K + uncert^2 I]."""
        mu = jnp.zeros_like(Y_obs)
        return log_normal(Y_obs, mu, K + noise_cov(uncert))

    def predict_f(K, uncert, **kwargs):
        """GP posterior mean at X for the given hyperparameters."""
        mu = jnp.zeros_like(Y_obs)
        return mu + K @ jnp.linalg.solve(K + noise_cov(uncert), Y_obs)

    def predict_fvar(K, uncert, **kwargs):
        """GP posterior marginal variance at X for the given hyperparameters."""
        return jnp.diag(K - K @ jnp.linalg.solve(K + noise_cov(uncert), K))

    prior_chain = PriorChain() \
        .push(GaussianProcessKernelPrior('K', RationalQuadratic(), X,
                                         UniformPrior('l', 0., 4.),
                                         UniformPrior('sigma', 0., 4.),
                                         UniformPrior('alpha', 0., 4.))) \
        .push(UniformPrior('uncert', 0., 2.))

    ns = NestedSampler(log_likelihood, prior_chain, sampler_name='multi_ellipsoid',
                       predict_f=predict_f, predict_fvar=predict_fvar)

    def run_with_n(n):
        @jit
        def run():
            return ns(key=random.PRNGKey(0),
                      num_live_points=n,
                      max_samples=1e4,
                      collect_samples=True,
                      termination_frac=0.01,
                      stoachastic_uncertainty=False,
                      sampler_kwargs=dict(depth=4))

        return run()

    for n in [200]:
        results = run_with_n(n)
        plt.scatter(n, results.logZ)
        plt.errorbar(n, results.logZ, yerr=results.logZerr)
    plt.title("Kernel: {}".format(RationalQuadratic.__name__))
    plt.ylabel('log Z')
    plt.show()

    f_mean = results.marginalised['predict_f']
    # Clamp round-off negatives in the variance diagonal before sqrt to avoid NaNs.
    f_std = jnp.sqrt(jnp.maximum(results.marginalised['predict_fvar'], 0.))
    plt.scatter(X[:, 0], Y_obs, label='data')
    plt.plot(X[:, 0], Y, label='underlying')
    plt.plot(X[:, 0], f_mean, label='marginalised')
    plt.plot(X[:, 0], f_mean + f_std, ls='dotted', c='black')
    plt.plot(X[:, 0], f_mean - f_std, ls='dotted', c='black')
    plt.title("Kernel: {}".format(RationalQuadratic.__name__))
    plt.legend()
    plt.show()

    plot_diagnostics(results)
    plot_cornerplot(results)
    return results.logZ, results.logZerr