def test_forced_identifiability_prior():
    from jax import random
    prior = PriorChain().push(ForcedIdentifiabilityPrior('x', 10, 0., 10.))
    for i in range(10):
        out = prior(random.uniform(random.PRNGKey(i), shape=(prior.U_ndims,)))
        assert jnp.all(jnp.sort(out['x'], axis=0) == out['x'])
        assert jnp.all((out['x'] >= 0.) & (out['x'] <= 10.))
    prior = PriorChain().push(
        ForcedIdentifiabilityPrior('x', 10, jnp.array([0., 0.]), 10.))
    for i in range(10):
        out = prior(random.uniform(random.PRNGKey(i), shape=(prior.U_ndims,)))
        assert out['x'].shape == (10, 2)
        assert jnp.all(jnp.sort(out['x'], axis=0) == out['x'])
        assert jnp.all((out['x'] >= 0.) & (out['x'] <= 10.))
def build_prior(X, kernel, tec_to_dtec, x0):
    K = GaussianProcessKernelPrior('K',
                                   TomographicKernel(x0, kernel, S_marg=100, S_gamma=10),
                                   X,
                                   UniformPrior('height', 100., 300.),
                                   UniformPrior('width', 50., 150.),
                                   UniformPrior('l', 7., 20.),
                                   UniformPrior('sigma', 0.3, 2.),
                                   tracked=False)
    tec = MVNPrior('tec', jnp.zeros((X.shape[0],)), K, ill_cond=False, tracked=False)
    dtec = DeterministicTransformPrior('dtec', tec_to_dtec, tec.to_shape, tec, tracked=False)
    prior_chain = PriorChain() \
        .push(dtec) \
        .push(UniformPrior('uncert', 0., 5.))
    return prior_chain
def main():
    Y_obs, amp, tec, freqs = generate_data()
    TEC_CONV = -8.4479745e6  # mTECU/Hz

    def log_normal(x, mean, scale):
        dx = (x - mean) / scale
        return -0.5 * x.size * jnp.log(2. * jnp.pi) - x.size * jnp.log(scale) \
               - 0.5 * dx @ dx

    def log_laplace(x, mean, scale):
        dx = jnp.abs(x - mean) / scale
        return -x.size * jnp.log(2. * scale) - jnp.sum(dx)

    def log_likelihood(tec, const, uncert, **kwargs):
        phase = tec * (TEC_CONV / freqs) + const
        Y = jnp.concatenate([amp * jnp.cos(phase), amp * jnp.sin(phase)], axis=-1)
        log_prob = log_laplace(Y, Y_obs, uncert[0])
        return log_prob

    prior_chain = PriorChain() \
        .push(UniformPrior('tec', -100., 100.)) \
        .push(UniformPrior('const', -jnp.pi, jnp.pi)) \
        .push(HalfLaplacePrior('uncert', 0.25))

    print("Probabilistic model:\n{}".format(prior_chain))

    ns = NestedSampler(
        log_likelihood,
        prior_chain,
        sampler_name='slice',
        tec_mean=lambda tec, **kw: tec,       # marginalise this function over the posterior
        const_mean=lambda const, **kw: const  # marginalise this function over the posterior
    )

    run = jit(lambda key: ns(key=key,
                             num_live_points=1000,
                             max_samples=1e5,
                             collect_samples=True,
                             termination_frac=0.01,
                             stoachastic_uncertainty=False,
                             sampler_kwargs=dict(depth=4, num_slices=1)))

    t0 = default_timer()
    results = run(random.PRNGKey(2364))
    print(results.efficiency)
    print("Time compile", default_timer() - t0)

    t0 = default_timer()
    results = run(random.PRNGKey(1324))
    print(results.efficiency)
    print("Time no compile", default_timer() - t0)

    ###
    print(results.marginalised['tec_mean'])
    print(results.marginalised['const_mean'])

    plot_diagnostics(results)
    plot_cornerplot(results)
def build_prior(X, kernel, tec_to_dtec, x0, tec_conv):
    K = GaussianProcessKernelPrior('K',
                                   TomographicKernel(x0, kernel, S_marg=100, S_gamma=10),
                                   X,
                                   UniformPrior('height', 100., 300.),
                                   UniformPrior('width', 50., 150.),
                                   UniformPrior('l', 7., 20.),
                                   UniformPrior('sigma', 0.3, 2.),
                                   tracked=False)
    tec = MVNPrior('tec', jnp.zeros((X.shape[0],)), K, ill_cond=True, tracked=False)
    dtec = DeterministicTransformPrior('dtec', tec_to_dtec, tec.to_shape, tec, tracked=False)
    Y = DeterministicTransformPrior('Y',
                                    lambda dtec: jnp.concatenate([
                                        jnp.cos(dtec[:, None] * tec_conv),
                                        jnp.sin(dtec[:, None] * tec_conv)
                                    ], axis=-1),
                                    dtec.to_shape + (tec_conv.size * 2,),
                                    dtec,
                                    tracked=False)
    prior_chain = PriorChain() \
        .push(Y) \
        .push(UniformPrior('uncert', 0.01, 1.))
    return prior_chain
def run_block(key, dtec, dtec_uncert, log_prob):
    key1, key2 = random.split(key, 2)

    def log_likelihood(lengthscale, sigma, **kwargs):
        # K = kernel(X, X, lengthscale, sigma)
        # def _compute(dtec, dtec_uncert):
        #     # each [Nd]
        #     return log_normal_with_outliers(dtec, 0., K, jnp.maximum(1e-6, dtec_uncert))
        # return chunked_pmap(_compute, dtec, dtec_uncert, chunksize=1).sum()
        return lookup_func(log_prob, lengthscale, sigma)

    lengthscale = UniformPrior('lengthscale', jnp.min(lengthscale_array), jnp.max(lengthscale_array))
    sigma = UniformPrior('sigma', sigma_array.min(), sigma_array.max())
    prior_chain = PriorChain(lengthscale, sigma)

    ns = NestedSampler(loglikelihood=log_likelihood,
                       prior_chain=prior_chain,
                       sampler_kwargs=dict(num_slices=prior_chain.U_ndims * 1),
                       num_live_points=prior_chain.U_ndims * 50)
    ns = jit(ns)
    results = ns(key1, termination_evidence_frac=0.1)

    def marg_func(lengthscale, sigma, **kwargs):
        def screen(dtec, dtec_uncert, **kw):
            K = kernel(X, X, lengthscale, sigma)
            Kstar = kernel(X, Xstar, lengthscale, sigma)
            L = jnp.linalg.cholesky(
                K / (dtec_uncert[:, None] * dtec_uncert[None, :]) + jnp.eye(dtec.shape[0]))
            # L = jnp.where(jnp.isnan(L), jnp.eye(L.shape[0]) / sigma, L)
            dx = solve_triangular(L, dtec / dtec_uncert, lower=True)
            JT = solve_triangular(L, Kstar / dtec_uncert[:, None], lower=True)
            # var_ik = JT_ji JT_jk
            mean = JT.T @ dx
            var = jnp.sum(JT * JT, axis=0)
            return mean, var

        # [time_block_size, Nd_screen], [time_block_size, Nd_screen]
        return vmap(screen)(dtec, dtec_uncert), lengthscale, jnp.log(sigma)

    # [time_block_size, Nd_screen], [time_block_size, Nd_screen], [time_block_size]
    (mean, var), mean_lengthscale, mean_logsigma = marginalise_static(
        key2, results.samples, results.log_p, 500, marg_func)
    uncert = jnp.sqrt(var)
    mean_sigma = jnp.exp(mean_logsigma)
    mean_lengthscale = jnp.ones(time_block_size) * mean_lengthscale
    mean_sigma = jnp.ones(time_block_size) * mean_sigma
    ESS = results.ESS * jnp.ones(time_block_size)
    logZ = results.logZ * jnp.ones(time_block_size)
    likelihood_evals = results.num_likelihood_evaluations * jnp.ones(time_block_size)
    return mean, uncert, mean_lengthscale, mean_sigma, ESS, logZ, likelihood_evals
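# The `lookup_func` used in the likelihood above is assumed to interpolate a precomputed
# table of log-probabilities over the (lengthscale_array, sigma_array) grid. Below is a
# minimal sketch of such a helper, not the original implementation: the name
# `build_lookup_index` is borrowed from commented-out code elsewhere in these scripts,
# and uniformly spaced axes are assumed.
from jax.scipy.ndimage import map_coordinates


def build_lookup_index(*axes):
    """Return a function that multilinearly interpolates a table defined on `axes`."""
    def lookup(table, *coords):
        # Map physical coordinates onto fractional grid indices (uniform axes assumed),
        # then interpolate the table with order-1 (multilinear) interpolation.
        idx = [(c - ax[0]) / (ax[-1] - ax[0]) * (ax.size - 1)
               for c, ax in zip(coords, axes)]
        return map_coordinates(table, [jnp.asarray(i) for i in idx], order=1)
    return lookup

# e.g. lookup_func = build_lookup_index(lengthscale_array, sigma_array)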
def build_layered_prior(X, kernel, x0, tec_to_dtec):
    layer_edges = jnp.linspace(80., 500., int((500. - 80.) / 50.) + 1)
    layer_kernels = []
    for i in range(len(layer_edges) - 1):
        height = 0.5 * (layer_edges[i] + layer_edges[i + 1])
        width = layer_edges[i + 1] - layer_edges[i]
        # Efficiency 0.39664684771546416
        # Time to run (including compile) 246.36920081824064
        # 0.39198953960498245
        # Time to run (no compile) 130.1565416753292
        # Efficiency normalised time 51.020025508433804
        K = GaussianProcessKernelPrior('K{}'.format(i),
                                       TomographicKernel(x0, kernel, S_marg=100, S_gamma=20),
                                       X,
                                       DeltaPrior('height{}'.format(i), height, tracked=False),
                                       DeltaPrior('width{}'.format(i), width, tracked=False),
                                       UniformPrior('l{}'.format(i), 7., 20., tracked=False),
                                       UniformPrior('sigma{}'.format(i), 0.3, 2., tracked=False),
                                       tracked=False)
        layer_kernels.append(K)
    logits = jnp.zeros(len(layer_kernels))
    select = CategoricalPrior('j', logits, tracked=True)
    K = DeterministicTransformPrior('K',
                                    lambda j, *K: jnp.stack(K, axis=0)[j[0], :, :],
                                    layer_kernels[0].to_shape,
                                    select, *layer_kernels,
                                    tracked=False)
    tec = MVNPrior('tec', jnp.zeros((X.shape[0],)), K, ill_cond=True, tracked=False)
    dtec = DeterministicTransformPrior('dtec', tec_to_dtec, tec.to_shape, tec, tracked=False)
    prior_chain = PriorChain() \
        .push(dtec) \
        .push(UniformPrior('uncert', 2., 3.))
    return prior_chain
def main():
    ndims = 4
    sigma = 0.1

    def log_likelihood(theta, **kwargs):
        r2 = jnp.sum(theta ** 2)
        logL = -0.5 * jnp.log(2. * jnp.pi * sigma ** 2) * ndims
        logL += -0.5 * r2 / sigma ** 2
        return logL

    prior_transform = PriorChain().push(
        UniformPrior('theta', -jnp.ones(ndims), jnp.ones(ndims)))
    ns = NestedSampler(log_likelihood, prior_transform, sampler_name='slice')

    def run_with_n(n):
        @jit
        def run(key):
            return ns(key=key,
                      num_live_points=n,
                      max_samples=1e5,
                      collect_samples=True,
                      termination_frac=0.01,
                      stoachastic_uncertainty=False,
                      sampler_kwargs=dict(depth=3))

        t0 = default_timer()
        results = run(random.PRNGKey(0))
        print(results.efficiency)
        print("Time to run including compile:", default_timer() - t0)
        print("Time efficiency normalised:", results.efficiency * (default_timer() - t0))

        t0 = default_timer()
        results = run(random.PRNGKey(1))
        print(results.efficiency)
        print("Time to run no compile:", default_timer() - t0)
        print("Time efficiency normalised:", results.efficiency * (default_timer() - t0))
        return results

    for n in [1000]:
        results = run_with_n(n)
        plt.scatter(n, results.logZ)
        plt.errorbar(n, results.logZ, yerr=results.logZerr)
    plt.show()

    # plot_samples_development(results, save_name='./example.mp4')
    plot_diagnostics(results)
    plot_cornerplot(results)
def main():
    def log_likelihood(theta, **kwargs):
        return 5. * (2. + jnp.prod(jnp.cos(0.5 * theta)))

    prior_chain = PriorChain() \
        .push(UniformPrior('theta', low=jnp.zeros(2), high=jnp.pi * 10. * jnp.ones(2)))

    theta = vmap(lambda key: prior_chain(random.uniform(key, (prior_chain.U_ndims,))))(
        random.split(random.PRNGKey(0), 10000))
    lik = vmap(lambda theta: log_likelihood(**theta))(theta)
    sc = plt.scatter(theta['theta'][:, 0], theta['theta'][:, 1], c=lik)
    plt.colorbar(sc)
    plt.show()

    ns = NestedSampler(log_likelihood, prior_chain, sampler_name='slice')

    def run_with_n(n):
        @jit
        def run(key):
            return ns(key=key,
                      num_live_points=n,
                      max_samples=1e5,
                      collect_samples=True,
                      termination_frac=0.01,
                      stoachastic_uncertainty=False,
                      sampler_kwargs=dict(depth=7))

        t0 = default_timer()
        # with disable_jit():
        results = run(random.PRNGKey(0))
        print("Efficiency", results.efficiency)
        print("Time to run (including compile)", default_timer() - t0)

        t0 = default_timer()
        results = run(random.PRNGKey(1))
        print(results.efficiency)
        print("Time to run (no compile)", default_timer() - t0)
        return results

    for n in [500]:
        results = run_with_n(n)
        plt.scatter(n, results.logZ)
        plt.errorbar(n, results.logZ, yerr=results.logZerr)
    plt.ylabel('log Z')
    plt.show()

    plot_diagnostics(results)
    plot_cornerplot(results)
    return results.logZ, results.logZerr
def run_jaxns(num_live_points):
    try:
        from jaxns.nested_sampling import NestedSampler
        from jaxns.prior_transforms import PriorChain, UniformPrior
    except ImportError:
        raise ImportError("Install JaxNS!")
    from timeit import default_timer
    from jax import random, jit
    import jax.numpy as jnp

    def log_likelihood(theta, **kwargs):
        r2 = jnp.sum(theta ** 2)
        logL = -0.5 * jnp.log(2. * jnp.pi * sigma ** 2) * ndims
        logL += -0.5 * r2 / sigma ** 2
        return logL

    prior_transform = PriorChain().push(UniformPrior('theta', -jnp.ones(ndims), jnp.ones(ndims)))
    ns = NestedSampler(log_likelihood, prior_transform, sampler_name='slice')

    def run_with_n(n):
        @jit
        def run(key):
            return ns(key=key,
                      num_live_points=n,
                      max_samples=1e6,
                      collect_samples=False,
                      termination_frac=0.01,
                      stoachastic_uncertainty=False,
                      sampler_kwargs=dict(depth=3, num_slices=2))

        results = run(random.PRNGKey(0))
        results.logZ.block_until_ready()
        t0 = default_timer()
        results = run(random.PRNGKey(1))
        print("Efficiency and logZ", results.efficiency, results.logZ)
        run_time = default_timer() - t0
        return run_time

    return run_with_n(num_live_points)
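# A hypothetical driver for the benchmark above (illustrative sketch only): run_jaxns
# reads `ndims` and `sigma` from the enclosing module scope, so they are assumed to be
# defined there; the values below mirror the other examples in this collection.
ndims = 4
sigma = 0.1

if __name__ == '__main__':
    for n in [100, 500, 1000]:
        print("num_live_points={}: run time {:.2f} s".format(n, run_jaxns(n)))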
def main():
    def log_likelihood(theta, **kwargs):
        return (2. + jnp.prod(jnp.cos(0.5 * theta))) ** 5

    prior_chain = PriorChain() \
        .push(UniformPrior('theta', low=jnp.zeros(2), high=jnp.pi * 10. * jnp.ones(2)))

    U = vmap(lambda key: random.uniform(key, (prior_chain.U_ndims,)))(
        random.split(random.PRNGKey(0), 700))
    theta = vmap(lambda u: prior_chain(u))(U)
    lik = vmap(lambda theta: log_likelihood(**theta))(theta)
    select = lik > 150.
    print("Selecting", jnp.sum(select), "need", 18 * 3)
    log_VS = jnp.log(jnp.sum(select) / select.size)
    print("V(S)", jnp.exp(log_VS))
    U = U[select, :]

    with disable_jit():
        cluster_id, ellipsoid_parameters = \
            jit(lambda key, points, log_VS: ellipsoid_clustering(random.PRNGKey(0), points, 7, log_VS))(
                random.PRNGKey(0), U, log_VS)
    mu, radii, rotation = ellipsoid_parameters

    theta = jnp.linspace(0., jnp.pi * 2, 100)
    x = jnp.stack([jnp.cos(theta), jnp.sin(theta)], axis=0)
    for i, (mu, radii, rotation) in enumerate(zip(mu, radii, rotation)):
        y = mu[:, None] + rotation @ jnp.diag(radii) @ x
        plt.plot(y[0, :], y[1, :])
        mask = cluster_id == i
        plt.scatter(U[mask, 0], U[mask, 1],
                    c=jnp.atleast_2d(plt.cm.jet(i / len(ellipsoid_parameters))))
    plt.show()
def E_update(self, prior_mu, prior_Gamma, Y, Sigma, *control_params):
    # amp = control_params[0]
    key = control_params[1]
    prior_chain = PriorChain() \
        .push(MVNPrior('param', prior_mu, prior_Gamma))
    # .push(HalfLaplacePrior('uncert', jnp.sqrt(jnp.mean(jnp.diag(Sigma)))))

    def log_normal(x, mean, cov):
        dx = x - mean
        # L = jnp.linalg.cholesky(cov)
        # dx = solve_triangular(L, dx, lower=True)
        L = jnp.sqrt(jnp.diag(cov))
        dx = dx / L
        return -0.5 * x.size * jnp.log(2. * jnp.pi) - jnp.sum(jnp.log(L)) \
               - 0.5 * dx @ dx

    def log_likelihood(param, **kwargs):
        Y_model = self.forward_model(param, *control_params)
        # Sigma = uncert ** 2 * jnp.eye(Y.shape[-1])
        return log_normal(Y_model, Y, Sigma)

    ns = NestedSampler(log_likelihood, prior_chain, sampler_name='whitened_ellipsoid')
    results = ns(key,
                 self._phase_basis_size * 15,
                 max_samples=1e5,
                 collect_samples=False,
                 termination_frac=0.01,
                 stoachastic_uncertainty=True)
    post_mu = results.param_mean['param']
    post_Gamma = results.param_covariance['param']
    return post_mu, post_Gamma
def build_frozen_flow_prior(X, kernel, tec_to_dtec, x0):
    v_dir = DeterministicTransformPrior('v_dir',
                                        lambda n: n / jnp.linalg.norm(n),
                                        (3,),
                                        MVNDiagPrior('n', jnp.zeros(3), jnp.ones(3), tracked=False),
                                        tracked=False)
    v_mag = UniformPrior('v_mag', 0., 0.5, tracked=False)
    v = DeterministicTransformPrior('v',
                                    lambda v_dir, v_mag: v_mag * v_dir,
                                    (3,),
                                    v_dir, v_mag,
                                    tracked=True)
    X_frozen_flow = DeterministicTransformPrior('X',
                                                lambda v: X[:, 0:6] - jnp.concatenate([v, jnp.zeros(3)]) * X[:, 6:7],
                                                X[:, 0:6].shape,
                                                v,
                                                tracked=False)
    K = GaussianProcessKernelPrior('K',
                                   TomographicKernel(x0, kernel, S_marg=20, S_gamma=10),
                                   X_frozen_flow,
                                   UniformPrior('height', 100., 300.),
                                   UniformPrior('width', 50., 150.),
                                   UniformPrior('l', 0., 20.),
                                   UniformPrior('sigma', 0., 2.),
                                   tracked=False)
    tec = MVNPrior('tec', jnp.zeros((X.shape[0],)), K, ill_cond=True, tracked=False)
    dtec = DeterministicTransformPrior('dtec', tec_to_dtec, tec.to_shape, tec, tracked=False)
    prior_chain = PriorChain() \
        .push(dtec) \
        .push(UniformPrior('uncert', 0., 5.))
    return prior_chain
def test_unit_cube_mixture_prior():
    import jax.numpy as jnp
    from jax import random
    from jaxns.nested_sampling import NestedSampler
    from jaxns.plotting import plot_cornerplot, plot_diagnostics

    # prior_chain = PriorChain().push(MultiCubeMixturePrior('x', 2, 1, -5., 15.))
    prior_chain = PriorChain().push(GMMMarginalPrior('x', 2, -5., 15.))

    def loglikelihood(x, **kwargs):
        return jnp.log(0.5 * jnp.exp(-0.5 * jnp.sum(x) ** 2) / jnp.sqrt(2. * jnp.pi)
                       + 0.5 * jnp.exp(-0.5 * jnp.sum(x - 10.) ** 2) / jnp.sqrt(2. * jnp.pi))

    ns = NestedSampler(loglikelihood, prior_chain, sampler_name='ellipsoid')
    results = ns(random.PRNGKey(0),
                 100,
                 max_samples=1e5,
                 collect_samples=True,
                 termination_frac=0.05,
                 stoachastic_uncertainty=True)
    plot_diagnostics(results)
    plot_cornerplot(results)
def build_prior(nant, ndir):
    theta = MVNDiagPrior('theta', jnp.zeros(nant * ndir), jnp.ones(nant * ndir), tracked=True)
    gamma = MVNDiagPrior('gamma', jnp.zeros(ndir), 0. * jnp.ones(ndir), tracked=True)

    def vis(theta, gamma, **kwargs):
        theta = theta.reshape((nant, ndir))
        diff = 1j * (theta[:, None, :] - theta)
        delta = jnp.mean(jnp.exp(-gamma + diff), axis=-1)
        return delta

    delta = DeterministicTransformPrior('delta', vis, (nant, nant), theta, gamma, tracked=False)
    prior = PriorChain().push(delta)
    return prior
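# A minimal sanity check for the prior above (a sketch, not part of the original test
# suite): draw one prior sample the same way the other tests in this collection do and
# confirm the visibility matrix has the declared (nant, nant) shape. The nant/ndir
# values are arbitrary illustrative choices.
def test_build_prior_shape():
    from jax import random
    nant, ndir = 5, 3
    prior = build_prior(nant, ndir)
    out = prior(random.uniform(random.PRNGKey(0), shape=(prior.U_ndims,)))
    assert out['delta'].shape == (nant, nant)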
def run(key):
    prior_transform = PriorChain().push(
        MVNDiagPrior('x', prior_mu, jnp.sqrt(jnp.diag(prior_cov))))
    # prior_transform = LaplacePrior(prior_mu, jnp.sqrt(jnp.diag(prior_cov)))
    # prior_transform = UniformPrior(-20. * jnp.ones(ndims), 20. * jnp.ones(ndims))

    def param_mean(x, **args):
        return x

    def param_covariance(x, **args):
        return jnp.outer(x, x)

    ns = NestedSampler(log_likelihood,
                       prior_transform,
                       sampler_name='slice',
                       x_mean=param_mean,
                       x_cov=param_covariance)
    return ns(key=key,
              num_live_points=n,
              max_samples=1e5,
              collect_samples=True,
              termination_frac=0.01,
              stoachastic_uncertainty=False,
              sampler_kwargs=dict(depth=3, num_slices=2))
def run_block(block_idx):
    def log_likelihood(bottom, width, lengthscale, sigma, **kwargs):
        return jnp.sum(
            vmap(lambda log_prob: lookup_func(log_prob, bottom, width, lengthscale, sigma))(
                log_prob[block_idx]))

    bottom = UniformPrior('bottom', bottom_array.min(), bottom_array.max())
    width = DeltaPrior('width', 50., tracked=False)
    lengthscale = UniformPrior('lengthscale', jnp.min(lengthscale_array), jnp.max(lengthscale_array))
    sigma = UniformPrior('sigma', sigma_array.min(), sigma_array.max())
    prior_chain = PriorChain(lengthscale, sigma, bottom, width)

    ns = NestedSampler(loglikelihood=log_likelihood,
                       prior_chain=prior_chain,
                       sampler_name='slice',
                       sampler_kwargs=dict(num_slices=prior_chain.U_ndims * 5),
                       num_live_points=prior_chain.U_ndims * 50)
    ns = jit(ns)
    results = ns(random.PRNGKey(42), termination_frac=0.001)
    return results
def test_half_laplace():
    p = PriorChain().push(HalfLaplacePrior('x', 1.))
    U = jnp.linspace(0., 1., 100)[:, None]
    assert ~jnp.any(jnp.isnan(vmap(p)(U)['x']))
def test_prior_chain():
    from jax import random

    chain = PriorChain()
    mu = MVNDiagPrior('mu', jnp.array([0., 0.]), 1.)
    gamma = jnp.array([1.])
    X = MVNDiagPrior('x', mu, gamma)
    chain.push(mu).push(X)
    print(chain)
    U = random.uniform(random.PRNGKey(0), shape=(chain.U_ndims,))
    y = chain(U)
    print(y)

    chain = PriorChain()
    mu = MVNDiagPrior('mu', jnp.array([0., 0.]), 1.)
    gamma = jnp.array([1.])
    X = LaplacePrior('x', mu, gamma)
    chain.push(mu).push(X)
    print(chain)
    U = random.uniform(random.PRNGKey(0), shape=(chain.U_ndims,))
    y = chain(U)
    print(y)

    chain = PriorChain()
    x0 = MVNDiagPrior('x0', jnp.array([0., 0.]), 1.)
    gamma = 1.
    X = DiagGaussianWalkPrior('W', 2, x0, gamma)
    chain.push(x0).push(X)
    print(chain)
    U = random.uniform(random.PRNGKey(0), shape=(chain.U_ndims,))
    y = chain(U)
    print(y)
def main(kernel):
    print("Working on Kernel: {}".format(kernel.__class__.__name__))

    def log_normal(x, mean, cov):
        L = jnp.linalg.cholesky(cov)
        # U, S, Vh = jnp.linalg.svd(cov)
        log_det = jnp.sum(jnp.log(jnp.diag(L)))  # jnp.sum(jnp.log(S))
        dx = x - mean
        dx = solve_triangular(L, dx, lower=True)
        # U S Vh V 1/S Uh
        # pinv = (Vh.T.conj() * jnp.where(S != 0., jnp.reciprocal(S), 0.)) @ U.T.conj()
        maha = dx @ dx  # dx @ pinv @ dx
        log_likelihood = -0.5 * x.size * jnp.log(2. * jnp.pi) \
                         - log_det \
                         - 0.5 * maha
        # print(log_likelihood)
        return log_likelihood

    N = 100
    X = jnp.linspace(-2., 2., N)[:, None]
    true_sigma, true_l, true_uncert = 1., 0.2, 0.2
    data_mu = jnp.zeros((N,))
    prior_cov = RBF()(X, X, true_l, true_sigma) + 1e-13 * jnp.eye(N)
    # print(jnp.linalg.cholesky(prior_cov), jnp.linalg.eigvals(prior_cov))
    # return
    Y = jnp.linalg.cholesky(prior_cov) @ random.normal(random.PRNGKey(0), shape=(N,)) + data_mu
    Y_obs = Y + true_uncert * random.normal(random.PRNGKey(1), shape=(N,))
    Y_obs = jnp.where((jnp.arange(N) > 50) & (jnp.arange(N) < 60),
                      random.normal(random.PRNGKey(1), shape=(N,)),
                      Y_obs)
    # plt.scatter(X[:, 0], Y_obs, label='data')
    # plt.plot(X[:, 0], Y, label='underlying')
    # plt.legend()
    # plt.show()

    def log_likelihood(K, uncert, **kwargs):
        """P(Y | K, uncert) = N[Y; mu, K + uncert^2 I]"""
        data_cov = jnp.square(uncert) * jnp.eye(X.shape[0])
        mu = jnp.zeros_like(Y_obs)
        return log_normal(Y_obs, mu, K + data_cov)

    def predict_f(K, uncert, **kwargs):
        data_cov = jnp.square(uncert) * jnp.eye(X.shape[0])
        mu = jnp.zeros_like(Y_obs)
        return mu + K @ jnp.linalg.solve(K + data_cov, Y_obs)

    def predict_fvar(K, uncert, **kwargs):
        data_cov = jnp.square(uncert) * jnp.eye(X.shape[0])
        mu = jnp.zeros_like(Y_obs)
        return jnp.diag(K - K @ jnp.linalg.solve(K + data_cov, K))

    l = UniformPrior('l', 0., 2.)
    uncert = UniformPrior('uncert', 0., 2.)
    sigma = UniformPrior('sigma', 0., 2.)
    cov = GaussianProcessKernelPrior('K', kernel, X, l, sigma)
    prior_chain = PriorChain().push(uncert).push(cov)
    # print(prior_chain)

    ns = NestedSampler(log_likelihood,
                       prior_chain,
                       sampler_name='multi_ellipsoid',
                       predict_f=predict_f,
                       predict_fvar=predict_fvar)

    def run_with_n(n):
        @jit
        def run(key):
            return ns(key=key,
                      num_live_points=n,
                      max_samples=1e5,
                      collect_samples=True,
                      termination_frac=0.01,
                      stoachastic_uncertainty=False,
                      sampler_kwargs=dict(depth=3))

        t0 = default_timer()
        # with disable_jit():
        results = run(random.PRNGKey(6))
        print(results.efficiency)
        print("Time to execute (including compile): {}".format(default_timer() - t0))

        t0 = default_timer()
        results = run(random.PRNGKey(6))
        print(results.efficiency)
        print("Time to execute (not including compile): {}".format(default_timer() - t0))
        return results

    for n in [100]:
        results = run_with_n(n)
        plt.scatter(n, results.logZ)
        plt.errorbar(n, results.logZ, yerr=results.logZerr)
    plt.title("Kernel: {}".format(kernel.__class__.__name__))
    plt.ylabel('log Z')
    plt.show()

    plt.scatter(X[:, 0], Y_obs, label='data')
    plt.plot(X[:, 0], Y, label='underlying')
    plt.plot(X[:, 0], results.marginalised['predict_f'], label='marginalised')
    plt.plot(X[:, 0], results.marginalised['predict_f'] + jnp.sqrt(results.marginalised['predict_fvar']),
             ls='dotted', c='black')
    plt.plot(X[:, 0], results.marginalised['predict_f'] - jnp.sqrt(results.marginalised['predict_fvar']),
             ls='dotted', c='black')
    plt.title("Kernel: {}".format(kernel.__class__.__name__))
    plt.legend()
    plt.show()

    plot_diagnostics(results)
    plot_cornerplot(results)
    return results.logZ, results.logZerr
def main():
    def log_normal(x, mean, uncert):
        dx = (x - mean) / uncert
        return -0.5 * x.size * jnp.log(2. * jnp.pi) - x.size * jnp.log(uncert) - 0.5 * dx @ dx

    N = 100
    X = jnp.linspace(-2., 2., N)[:, None]
    true_alpha, true_sigma, true_l, true_uncert = 1., 1., 0.2, 0.25
    data_mu = jnp.zeros((N,))
    prior_cov = RBF()(X, X, true_l, true_sigma)
    Y = jnp.linalg.cholesky(prior_cov) @ random.normal(random.PRNGKey(0), shape=(N,)) + data_mu
    Y_obs = Y + true_uncert * random.normal(random.PRNGKey(1), shape=(N,))

    def predict_f(sigma, K, uncert, **kwargs):
        data_cov = jnp.square(uncert) * jnp.eye(X.shape[0])
        mu = jnp.zeros_like(Y_obs)
        return mu + K @ jnp.linalg.solve(K + data_cov, Y_obs)

    def predict_fvar(sigma, K, uncert, **kwargs):
        data_cov = jnp.square(uncert) * jnp.eye(X.shape[0])
        mu = jnp.zeros_like(Y_obs)
        return jnp.diag(K - K @ jnp.linalg.solve(K + data_cov, K))

    ###
    # Define the prior chain.
    # Here we assume each image is represented by pixels.
    # Alternatively, you could choose regions arranged non-uniformly over the image.
    image_shape = (128, 128)
    npix = image_shape[0] * image_shape[1]
    I150 = jnp.ones(image_shape)

    alpha_cw_gp_sigma = HalfLaplacePrior('alpha_cw_gp_sigma', 1.)
    alpha_mw_gp_sigma = HalfLaplacePrior('alpha_mw_gp_sigma', 1.)
    l_cw = UniformPrior('l_cw', 0., 0.5)  # degrees
    l_mw = UniformPrior('l_mw', 0.5, 2.)  # degrees
    K_cw = GaussianProcessKernelPrior('K_cw', RBF(), X, l_cw, alpha_cw_gp_sigma)
    K_mw = GaussianProcessKernelPrior('K_mw', RBF(), X, l_mw, alpha_mw_gp_sigma)
    alpha_cw = MVNPrior('alpha_cw', -1.5, K_cw)
    alpha_mw = MVNPrior('alpha_mw', -2.5, K_mw)
    S_cw_150 = UniformPrior('S150_cw', 0., I150)
    S_mw_150 = UniformPrior('S150_mw', 0., I150)
    uncert = HalfLaplacePrior('uncert', 1.)

    def log_likelihood(uncert, alpha_cw, alpha_mw, S_cw_150, S_mw_150, **kwargs):
        log_prob = 0.
        for img, freq in zip(images, freqs):  # <- `images` and `freqs` need to be defined
            I_total = S_mw_150 * (freq / 150e6) ** alpha_mw + S_cw_150 * (freq / 150e6) ** alpha_cw
            log_prob += log_normal(img, I_total, uncert)
        return log_prob

    prior_chain = PriorChain() \
        .push(alpha_cw).push(S_cw_150) \
        .push(alpha_mw).push(S_mw_150) \
        .push(uncert)
    print(prior_chain)

    ns = NestedSampler(log_likelihood,
                       prior_chain,
                       sampler_name='ellipsoid',
                       predict_f=predict_f,
                       predict_fvar=predict_fvar)

    def run_with_n(n):
        @jit
        def run():
            return ns(key=random.PRNGKey(0),
                      num_live_points=n,
                      max_samples=1e3,
                      collect_samples=True,
                      termination_frac=0.01,
                      stoachastic_uncertainty=True)

        results = run()
        return results
def main():
    def log_normal(x, mean, cov):
        L = jnp.linalg.cholesky(cov)
        dx = x - mean
        dx = solve_triangular(L, dx, lower=True)
        return -0.5 * x.size * jnp.log(2. * jnp.pi) - jnp.sum(jnp.log(jnp.diag(L))) \
               - 0.5 * dx @ dx

    N = 100
    X = jnp.linspace(-2., 2., N)[:, None]
    true_alpha, true_sigma, true_l, true_uncert = 1., 1., 0.2, 0.25
    data_mu = jnp.zeros((N,))
    prior_cov = RationalQuadratic()(X, X, true_l, true_sigma, true_alpha)
    Y = jnp.linalg.cholesky(prior_cov) @ random.normal(random.PRNGKey(0), shape=(N,)) + data_mu
    Y_obs = Y + true_uncert * random.normal(random.PRNGKey(1), shape=(N,))
    # Y_obs = jnp.where((jnp.arange(N) > 50) & (jnp.arange(N) < 60),
    #                   random.normal(random.PRNGKey(1), shape=(N,)),
    #                   Y_obs)
    # plt.scatter(X[:, 0], Y_obs, label='data')
    # plt.plot(X[:, 0], Y, label='underlying')
    # plt.legend()
    # plt.show()

    def log_likelihood(K, uncert, **kwargs):
        """P(Y | K, uncert) = N[Y; mu, K + uncert^2 I]"""
        data_cov = jnp.square(uncert) * jnp.eye(X.shape[0])
        mu = jnp.zeros_like(Y_obs)
        log_prob = log_normal(Y_obs, mu, K + data_cov)
        # print(log_prob)
        return log_prob

    def predict_f(K, uncert, **kwargs):
        data_cov = jnp.square(uncert) * jnp.eye(X.shape[0])
        mu = jnp.zeros_like(Y_obs)
        return mu + K @ jnp.linalg.solve(K + data_cov, Y_obs)

    def predict_fvar(K, uncert, **kwargs):
        data_cov = jnp.square(uncert) * jnp.eye(X.shape[0])
        mu = jnp.zeros_like(Y_obs)
        return jnp.diag(K - K @ jnp.linalg.solve(K + data_cov, K))

    prior_chain = PriorChain() \
        .push(GaussianProcessKernelPrior('K', RationalQuadratic(), X,
                                         UniformPrior('l', 0., 4.),
                                         UniformPrior('sigma', 0., 4.),
                                         UniformPrior('alpha', 0., 4.))) \
        .push(UniformPrior('uncert', 0., 2.))

    ns = NestedSampler(log_likelihood,
                       prior_chain,
                       sampler_name='multi_ellipsoid',
                       predict_f=predict_f,
                       predict_fvar=predict_fvar)

    def run_with_n(n):
        @jit
        def run():
            return ns(key=random.PRNGKey(0),
                      num_live_points=n,
                      max_samples=1e4,
                      collect_samples=True,
                      termination_frac=0.01,
                      stoachastic_uncertainty=False,
                      sampler_kwargs=dict(depth=4))

        results = run()
        return results

    for n in [200]:
        results = run_with_n(n)
        plt.scatter(n, results.logZ)
        plt.errorbar(n, results.logZ, yerr=results.logZerr)
    plt.title("Kernel: {}".format(RationalQuadratic.__name__))
    plt.ylabel('log Z')
    plt.show()

    plt.scatter(X[:, 0], Y_obs, label='data')
    plt.plot(X[:, 0], Y, label='underlying')
    plt.plot(X[:, 0], results.marginalised['predict_f'], label='marginalised')
    plt.plot(X[:, 0], results.marginalised['predict_f'] + jnp.sqrt(results.marginalised['predict_fvar']),
             ls='dotted', c='black')
    plt.plot(X[:, 0], results.marginalised['predict_f'] - jnp.sqrt(results.marginalised['predict_fvar']),
             ls='dotted', c='black')
    plt.title("Kernel: {}".format(RationalQuadratic.__name__))
    plt.legend()
    plt.show()

    plot_diagnostics(results)
    plot_cornerplot(results)
    return results.logZ, results.logZerr
def test_generic_kmeans():
    from jaxns.prior_transforms import PriorChain, UniformPrior
    from jax import vmap, disable_jit, jit
    import pylab as plt

    data = 'shells'
    if data == 'eggbox':
        def log_likelihood(theta, **kwargs):
            return (2. + jnp.prod(jnp.cos(0.5 * theta))) ** 5

        prior_chain = PriorChain() \
            .push(UniformPrior('theta', low=jnp.zeros(2), high=jnp.pi * 10. * jnp.ones(2)))

        U = vmap(lambda key: random.uniform(key, (prior_chain.U_ndims,)))(
            random.split(random.PRNGKey(0), 1000))
        theta = vmap(lambda u: prior_chain(u))(U)
        lik = vmap(lambda theta: log_likelihood(**theta))(theta)
        select = lik > 100.
    if data == 'shells':
        def log_likelihood(theta, **kwargs):
            def log_circ(theta, c, r, w):
                return -0.5 * (jnp.linalg.norm(theta - c) - r) ** 2 / w ** 2 \
                       - jnp.log(jnp.sqrt(2 * jnp.pi * w ** 2))

            w1 = w2 = jnp.array(0.1)
            r1 = r2 = jnp.array(2.)
            c1 = jnp.array([0., -4.])
            c2 = jnp.array([0., 4.])
            return jnp.logaddexp(log_circ(theta, c1, r1, w1), log_circ(theta, c2, r2, w2))

        prior_chain = PriorChain() \
            .push(UniformPrior('theta', low=-12. * jnp.ones(2), high=12. * jnp.ones(2)))

        U = vmap(lambda key: random.uniform(key, (prior_chain.U_ndims,)))(
            random.split(random.PRNGKey(0), 40000))
        theta = vmap(lambda u: prior_chain(u))(U)
        lik = vmap(lambda theta: log_likelihood(**theta))(theta)
        select = lik > 1.

    print("Selecting", jnp.sum(select))
    log_VS = jnp.log(jnp.sum(select) / select.size)
    print("V(S)", jnp.exp(log_VS))

    points = U[select, :]
    sc = plt.scatter(U[:, 0], U[:, 1], c=jnp.exp(lik))
    plt.colorbar(sc)
    plt.show()

    mask = jnp.ones(points.shape[0], dtype=jnp.bool_)
    K = 18
    with disable_jit():
        # state = generic_kmeans(random.PRNGKey(0), points, mask, method='ellipsoid', K=K, meta=dict(log_VS=log_VS))
        # state = generic_kmeans(random.PRNGKey(0), points, mask, method='mahalanobis', K=K)
        # state = generic_kmeans(random.PRNGKey(0), points, mask, method='euclidean', K=K)
        # cluster_id, log_cluster_VS = hierarchical_clustering(random.PRNGKey(0), points, 7, log_VS)
        cluster_id, ellipsoid_parameters = \
            jit(lambda key, points, log_VS: ellipsoid_clustering(random.PRNGKey(0), points, 7, log_VS))(
                random.PRNGKey(0), points, log_VS)
        # mu, radii, rotation = ellipsoid_parameters

    K = int(jnp.max(cluster_id) + 1)
    mu, C = vmap(lambda k: bounding_ellipsoid(points, cluster_id == k))(jnp.arange(K))
    radii, rotation = vmap(ellipsoid_params)(C)

    theta = jnp.linspace(0., jnp.pi * 2, 100)
    x = jnp.stack([jnp.cos(theta), jnp.sin(theta)], axis=0)
    for i, (mu, radii, rotation) in enumerate(zip(mu, radii, rotation)):
        y = mu[:, None] + rotation @ jnp.diag(radii) @ x
        plt.plot(y[0, :], y[1, :], c=plt.cm.jet(i / K))
        mask = cluster_id == i
        plt.scatter(points[mask, 0], points[mask, 1], c=jnp.atleast_2d(plt.cm.jet(i / K)))
    plt.xlim(-1, 2)
    plt.ylim(-1, 2)
    plt.show()
def main():
    Sigma, T, Y_obs, amp, tec, freqs = generate_data()
    TEC_CONV = -8.4479745e6  # mTECU/Hz

    def log_mvnormal(x, mean, cov):
        L = jnp.linalg.cholesky(cov)
        dx = x - mean
        dx = solve_triangular(L, dx, lower=True)
        return -0.5 * x.size * jnp.log(2. * jnp.pi) - jnp.sum(jnp.log(jnp.diag(L))) \
               - 0.5 * dx @ dx

    def log_normal(x, mean, uncert):
        dx = (x - mean) / uncert
        return -0.5 * x.size * jnp.log(2. * jnp.pi) - x.size * jnp.log(uncert) \
               - 0.5 * dx @ dx

    def log_likelihood(tec, uncert, **kwargs):
        # clock = x[2] * 1e-9
        phase = tec * (TEC_CONV / freqs)  # + clock * (jnp.pi * 2) * freqs
        Y = jnp.concatenate([jnp.cos(phase), jnp.sin(phase)], axis=-1)
        return jnp.sum(vmap(lambda Y, Y_obs: log_normal(Y, Y_obs, uncert))(Y, Y_obs))

    # prior_transform = MVNDiagPrior(prior_mu, jnp.sqrt(jnp.diag(prior_cov)))
    # prior_transform = LaplacePrior(prior_mu, jnp.sqrt(jnp.diag(prior_cov)))
    prior_chain = PriorChain() \
        .push(DiagGaussianWalkPrior('tec', T, LaplacePrior('tec0', 0., 100.), UniformPrior('omega', 1, 15))) \
        .push(UniformPrior('uncert', 0.01, 0.5))

    ns = NestedSampler(log_likelihood, prior_chain, sampler_name='slice',
                       tec_mean=lambda tec, **kwargs: tec)

    @jit
    def run(key):
        return ns(key=key,
                  num_live_points=500,
                  max_samples=1e5,
                  collect_samples=True,
                  termination_frac=0.01,
                  stoachastic_uncertainty=False,
                  sampler_kwargs=dict(depth=7))

    # with disable_jit():
    t0 = default_timer()
    results = run(random.PRNGKey(0))
    print("Time with compile efficiency normalised", results.efficiency * (default_timer() - t0))
    print("Time with compile", default_timer() - t0)

    t0 = default_timer()
    results = run(random.PRNGKey(1))
    print("Time no compile efficiency normalised", results.efficiency * (default_timer() - t0))
    print("Time no compile", default_timer() - t0)

    plt.plot(tec)
    plt.plot(results.marginalised['tec_mean'])
    plt.show()
    plt.plot(results.marginalised['tec_mean'][:, 0] - tec)
    plt.show()

    ###
    plot_diagnostics(results)
def unconstrained_solve(freqs, key, phase_obs, phase_outliers):
    key1, key2, key3, key4 = random.split(key, 4)
    Nt, Nf = phase_obs.shape
    assert Nt == 2, "Observations should be consecutive pairs of 2"

    tec0_array = jnp.linspace(-300., 300., 30)
    dtec_array = jnp.linspace(-30., 30., 30)
    const_array = jnp.linspace(-jnp.pi, jnp.pi, 10)
    uncert0_array = jnp.linspace(0., 1., 10)
    uncert1_array = jnp.linspace(0., 1., 10)

    def log_likelihood(tec0, dtec, const, uncert0, uncert1, **kwargs):
        tec = jnp.asarray([tec0, tec0 + dtec])
        t = freqs - jnp.min(freqs)
        t /= t[-1]
        uncert = uncert0 + (uncert1 - uncert0) * t
        phase = tec[:, None] * (TEC_CONV / freqs) + const  # [2, Nf]
        logL = jnp.sum(
            jnp.where(phase_outliers,
                      0.,
                      log_normal(wrap(wrap(phase) - wrap(phase_obs)), 0., uncert)))
        return logL

    # X = make_coord_array(tec0_array[:, None], dtec_array[:, None], const_array[:, None],
    #                      uncert0_array[:, None], uncert1_array[:, None], flat=True)
    # log_prob_array = vmap(lambda x: log_likelihood(x[0], x[1], x[2], x[3], x[4]))(X)
    # log_prob_array = log_prob_array.reshape((tec0_array.size, dtec_array.size, const_array.size,
    #                                          uncert0_array.size, uncert1_array.size))
    #
    # lookup_func = build_lookup_index(tec0_array, dtec_array, const_array, uncert0_array, uncert1_array)
    #
    # def efficient_log_likelihood(tec0, dtec, const, uncert0, uncert1, **kwargs):
    #     b = 0.5
    #     log_prob_uncert0 = -uncert0 / b - jnp.log(b)
    #     log_prob_uncert1 = -uncert1 / b - jnp.log(b)
    #     return lookup_func(log_prob_array, tec0, dtec, const, uncert0, uncert1) + log_prob_uncert0 + log_prob_uncert1

    tec0 = UniformPrior('tec0', tec0_array.min(), tec0_array.max())
    # 30 mTECU / 30 seconds is the maximum change
    dtec = UniformPrior('dtec', dtec_array.min(), dtec_array.max())
    const = UniformPrior('const', const_array.min(), const_array.max())
    uncert0 = UniformPrior('uncert0', uncert0_array.min(), uncert0_array.max())
    uncert1 = UniformPrior('uncert1', uncert1_array.min(), uncert1_array.max())
    prior_chain = PriorChain(tec0, dtec, const, uncert0, uncert1)

    ns = NestedSampler(log_likelihood,
                       prior_chain,
                       sampler_name='slice',
                       num_live_points=20 * prior_chain.U_ndims,
                       sampler_kwargs=dict(num_slices=prior_chain.U_ndims * 4))
    results = ns(key=key1, termination_evidence_frac=0.3)

    ESS = 900  # empirically estimated for this problem

    def marginalisation(tec0, dtec, const, uncert0, uncert1, **kwargs):
        tec = jnp.asarray([tec0, tec0 + dtec])
        return tec, tec ** 2, jnp.cos(const), jnp.sin(const), 0.5 * (uncert0 + uncert1)

    tec_mean, tec2_mean, const_real, const_imag, uncert_mean = marginalise_static(
        key2, results.samples, results.log_p, ESS, marginalisation)
    tec_std = jnp.sqrt(tec2_mean - tec_mean ** 2)
    const_mean = jnp.arctan2(const_imag, const_real)

    def marginalisation(const, **kwargs):
        return wrap(wrap(const) - wrap(const_mean)) ** 2

    const_var = marginalise_static(key2, results.samples, results.log_p, ESS, marginalisation)
    const_std = jnp.sqrt(const_var)

    return tec_mean, tec_std, const_mean * jnp.ones(Nt), const_std * jnp.ones(Nt), uncert_mean * jnp.ones(Nt)
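# The routine above relies on helpers defined elsewhere (wrap, an elementwise log_normal,
# TEC_CONV, and jaxns' marginalise_static). A minimal sketch of the two small ones, under
# the assumptions that wrap maps angles into [-pi, pi) and that log_normal returns the
# pointwise Gaussian log-density (no reduction) so it can be masked elementwise by
# jnp.where in the likelihood; these are illustrative stand-ins, not the originals.
import jax.numpy as jnp


def wrap(phi):
    # Wrap angles into [-pi, pi) using the standard arctan2 trick.
    return jnp.arctan2(jnp.sin(phi), jnp.cos(phi))


def log_normal(x, mean, uncert):
    # Elementwise Gaussian log-density, so callers can mask individual entries
    # before summing.
    dx = (x - mean) / uncert
    return -0.5 * jnp.log(2. * jnp.pi) - jnp.log(uncert) - 0.5 * dx ** 2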