def main():
    """Fit ``tec``, ``const`` and ``uncert`` to simulated data with nested sampling.

    Data come from ``generate_data``. The slice sampler is run twice under
    ``jit`` (the first call includes compilation); the efficiency and wall-clock
    time of each run are printed, followed by the marginalised ``tec`` and
    ``const`` and the diagnostic / corner plots of the second run.
    """
    # NOTE: the ``tec`` returned here is not read again in this function; the
    # likelihood below receives its own ``tec`` argument.
    Y_obs, amp, tec, freqs = generate_data()
    TEC_CONV = -8.4479745e6  # mTECU/Hz

    def log_normal(x, mean, scale):
        # Gaussian alternative to log_laplace; currently not called.
        resid = (x - mean) / scale
        norm_term = -0.5 * x.size * jnp.log(2. * jnp.pi)
        scale_term = x.size * jnp.log(scale)
        return norm_term - scale_term - 0.5 * resid @ resid

    def log_laplace(x, mean, scale):
        # Sum of independent Laplace log-densities with a shared scale.
        abs_resid = jnp.abs(x - mean) / scale
        return -x.size * jnp.log(2. * scale) - jnp.sum(abs_resid)

    def log_likelihood(tec, const, uncert, **kwargs):
        phase = tec * (TEC_CONV / freqs) + const
        real_part = amp * jnp.cos(phase)
        imag_part = amp * jnp.sin(phase)
        Y = jnp.concatenate([real_part, imag_part], axis=-1)
        # Only the first entry of ``uncert`` is used as the Laplace scale.
        return log_laplace(Y, Y_obs, uncert[0])

    prior_chain = PriorChain()
    prior_chain = prior_chain.push(UniformPrior('tec', -100., 100.))
    prior_chain = prior_chain.push(UniformPrior('const', -jnp.pi, jnp.pi))
    prior_chain = prior_chain.push(HalfLaplacePrior('uncert', 0.25))

    print("Probabilistic model:\n{}".format(prior_chain))

    # The two extra keyword callables are the functions to be marginalised over
    # the posterior; their results are read from ``results.marginalised`` below.
    ns = NestedSampler(log_likelihood,
                       prior_chain,
                       sampler_name='slice',
                       tec_mean=lambda tec, **kw: tec,
                       const_mean=lambda const, **kw: const)

    @jit
    def run(key):
        return ns(key=key,
                  num_live_points=1000,
                  max_samples=1e5,
                  collect_samples=True,
                  termination_frac=0.01,
                  stoachastic_uncertainty=False,
                  sampler_kwargs=dict(depth=4, num_slices=1))

    # First pass pays for compilation, second pass is the compiled run.
    for seed, label in ((2364, "Time compile"), (1324, "Time no compile")):
        start = default_timer()
        results = run(random.PRNGKey(seed))
        print(results.efficiency)
        print(label, default_timer() - start)

    marginalised = results.marginalised
    print(marginalised['tec_mean'])
    print(marginalised['const_mean'])
    plot_diagnostics(results)
    plot_cornerplot(results)
def main():
    """Benchmark the slice nested sampler on an isotropic Gaussian likelihood.

    The prior is uniform on ``[-1, 1]`` in each of ``ndims`` dimensions. The
    sampler is run twice per live-point setting (with and without compilation),
    timings are printed, and logZ with its error is plotted before the
    diagnostic and corner plots.
    """
    ndims = 4
    sigma = 0.1

    def log_likelihood(theta, **kwargs):
        # Zero-mean isotropic Gaussian log-density with standard deviation sigma.
        sq_radius = jnp.sum(theta**2)
        log_norm = -0.5 * jnp.log(2. * jnp.pi * sigma**2) * ndims
        return log_norm + -0.5 * sq_radius / sigma**2

    box_prior = UniformPrior('theta', -jnp.ones(ndims), jnp.ones(ndims))
    prior_transform = PriorChain().push(box_prior)
    ns = NestedSampler(log_likelihood, prior_transform, sampler_name='slice')

    def run_with_n(n):
        @jit
        def run(key):
            return ns(key=key,
                      num_live_points=n,
                      max_samples=1e5,
                      collect_samples=True,
                      termination_frac=0.01,
                      stoachastic_uncertainty=False,
                      sampler_kwargs=dict(depth=3))

        timed_passes = (
            (0, "Time to run including compile:"),
            (1, "Time to run no compile:"),
        )
        for seed, label in timed_passes:
            start = default_timer()
            results = run(random.PRNGKey(seed))
            print(results.efficiency)
            print(label, default_timer() - start)
            print("Time efficiency normalised:",
                  results.efficiency * (default_timer() - start))
        return results

    for n in [1000]:
        results = run_with_n(n)
        plt.scatter(n, results.logZ)
        plt.errorbar(n, results.logZ, yerr=results.logZerr)
    plt.show()

    # plot_samples_development(results, save_name='./example.mp4')
    plot_diagnostics(results)
    plot_cornerplot(results)
def main():
    """Run the slice nested sampler on a product-of-cosines likelihood in 2-D.

    First scatter-plots the likelihood at 10000 prior draws, then runs the
    sampler (with and without compilation), plots logZ, and shows the
    diagnostic and corner plots.

    Returns:
        Tuple ``(logZ, logZerr)`` from the final sampler results.
    """

    def log_likelihood(theta, **kwargs):
        return 5. * (2. + jnp.prod(jnp.cos(0.5 * theta)))

    theta_prior = UniformPrior('theta',
                               low=jnp.zeros(2),
                               high=jnp.pi * 10. * jnp.ones(2))
    prior_chain = PriorChain().push(theta_prior)

    # Visualise the likelihood surface on samples pushed through the prior.
    def draw_from_prior(key):
        return prior_chain(random.uniform(key, (prior_chain.U_ndims, )))

    prior_keys = random.split(random.PRNGKey(0), 10000)
    prior_draws = vmap(draw_from_prior)(prior_keys)
    lik = vmap(lambda draw: log_likelihood(**draw))(prior_draws)
    points = prior_draws['theta']
    sc = plt.scatter(points[:, 0], points[:, 1], c=lik)
    plt.colorbar(sc)
    plt.show()

    ns = NestedSampler(log_likelihood, prior_chain, sampler_name='slice')

    def run_with_n(n):
        @jit
        def run(key):
            return ns(key=key,
                      num_live_points=n,
                      max_samples=1e5,
                      collect_samples=True,
                      termination_frac=0.01,
                      stoachastic_uncertainty=False,
                      sampler_kwargs=dict(depth=7))

        # First call: timing includes compilation.
        start = default_timer()
        # with disable_jit():
        results = run(random.PRNGKey(0))
        print("Efficiency", results.efficiency)
        print("Time to run (including compile)", default_timer() - start)

        # Second call: compiled function only.
        start = default_timer()
        results = run(random.PRNGKey(1))
        print(results.efficiency)
        print("Time to run (no compile)", default_timer() - start)
        return results

    for n in [500]:
        results = run_with_n(n)
        plt.scatter(n, results.logZ)
        plt.errorbar(n, results.logZ, yerr=results.logZerr)
    plt.ylabel('log Z')
    plt.show()

    plot_diagnostics(results)
    plot_cornerplot(results)
    return results.logZ, results.logZerr
def run_jaxns(num_live_points):
    """Time one compiled JaxNS run on an isotropic Gaussian likelihood.

    ``sigma`` and ``ndims`` are read from the enclosing/module scope.

    Args:
        num_live_points: Number of live points passed to the sampler.

    Returns:
        Wall-clock seconds of the second (already compiled) sampler call.

    Raises:
        ImportError: If the JaxNS modules cannot be imported. The original
            import failure is attached as ``__cause__``.
    """
    try:
        from jaxns.nested_sampling import NestedSampler
        from jaxns.prior_transforms import PriorChain, UniformPrior
    except ImportError as err:
        # Only translate genuine import failures; a bare ``except`` would also
        # swallow KeyboardInterrupt/SystemExit and hide the real cause.
        raise ImportError("Install JaxNS!") from err
    from timeit import default_timer
    from jax import random, jit
    import jax.numpy as jnp

    def log_likelihood(theta, **kwargs):
        # Zero-mean isotropic Gaussian log-density with standard deviation sigma.
        r2 = jnp.sum(theta ** 2)
        logL = -0.5 * jnp.log(2. * jnp.pi * sigma ** 2) * ndims
        logL += -0.5 * r2 / sigma ** 2
        return logL

    prior_transform = PriorChain().push(
        UniformPrior('theta', -jnp.ones(ndims), jnp.ones(ndims)))
    ns = NestedSampler(log_likelihood, prior_transform, sampler_name='slice')

    def run_with_n(n):
        @jit
        def run(key):
            return ns(key=key,
                      num_live_points=n,
                      max_samples=1e6,
                      collect_samples=False,
                      termination_frac=0.01,
                      stoachastic_uncertainty=False,
                      sampler_kwargs=dict(depth=3, num_slices=2))

        # Warm-up call so compilation is excluded from the measurement.
        results = run(random.PRNGKey(0))
        results.logZ.block_until_ready()

        t0 = default_timer()
        results = run(random.PRNGKey(1))
        # Wait for the device explicitly and stop the clock before printing, so
        # the reported time does not include stdout I/O.
        results.logZ.block_until_ready()
        run_time = default_timer() - t0
        print("Efficiency and logZ", results.efficiency, results.logZ)
        return run_time

    return run_with_n(num_live_points)
def E_update(self, prior_mu, prior_Gamma, Y, Sigma, *control_params):
    """Compute the posterior mean and covariance of ``param`` by nested sampling.

    Args:
        prior_mu: Mean of the multivariate-normal prior on ``param``.
        prior_Gamma: Covariance of the multivariate-normal prior on ``param``.
        Y: Observed data compared against ``self.forward_model``.
        Sigma: Observation covariance matrix; only its diagonal is used.
        *control_params: Extra arguments forwarded to ``self.forward_model``.
            ``control_params[1]`` is used as the PRNG key for the sampler.

    Returns:
        Tuple ``(post_mu, post_Gamma)`` read from ``results.param_mean['param']``
        and ``results.param_covariance['param']``.
    """
    # amp = control_params[0]
    key = control_params[1]

    prior_chain = PriorChain() \
        .push(MVNPrior('param', prior_mu, prior_Gamma))
    # .push(HalfLaplacePrior('uncert', jnp.sqrt(jnp.mean(jnp.diag(Sigma)))))

    def log_normal(x, mean, cov):
        """Gaussian log-density using only the diagonal of ``cov``."""
        # Per-component standard deviations (1-D vector).
        std = jnp.sqrt(jnp.diag(cov))
        dx = (x - mean) / std
        # log|cov|^(1/2) for a diagonal covariance is sum(log(std)).
        # The previous code called jnp.diag on this 1-D vector, which builds a
        # matrix with zero off-diagonals, so log(0) = -inf made the result +inf.
        log_det_half = jnp.sum(jnp.log(std))
        return -0.5 * x.size * jnp.log(2. * jnp.pi) - log_det_half \
               - 0.5 * dx @ dx

    def log_likelihood(param, **kwargs):
        Y_model = self.forward_model(param, *control_params)
        # Sigma = uncert**2 * jnp.eye(Y.shape[-1])
        return log_normal(Y_model, Y, Sigma)

    ns = NestedSampler(log_likelihood, prior_chain, sampler_name='whitened_ellipsoid')
    # The number of live points scales with the phase-basis size.
    results = ns(key,
                 self._phase_basis_size * 15,
                 max_samples=1e5,
                 collect_samples=False,
                 termination_frac=0.01,
                 stoachastic_uncertainty=True)

    post_mu = results.param_mean['param']
    post_Gamma = results.param_covariance['param']
    return post_mu, post_Gamma
def test_unit_cube_mixture_prior():
    """Smoke-test the ellipsoid sampler with a ``GMMMarginalPrior`` on ``x``.

    Runs the sampler on a two-component mixture likelihood and shows the
    diagnostic and corner plots. There are no assertions.
    """
    import jax.numpy as jnp
    from jax import random
    from jaxns.nested_sampling import NestedSampler
    from jaxns.plotting import plot_cornerplot, plot_diagnostics

    # NOTE(review): PriorChain and GMMMarginalPrior are not imported here;
    # presumably they are available at module scope - confirm.
    # prior_chain = PriorChain().push(MultiCubeMixturePrior('x', 2, 1, -5., 15.))
    prior_chain = PriorChain().push(GMMMarginalPrior('x', 2, -5., 15.))

    def loglikelihood(x, **kwargs):
        # NOTE(review): ``jnp.sum(x)**2`` squares the sum of the components,
        # not the sum of squares - confirm this is the intended likelihood.
        norm = jnp.sqrt(2. * jnp.pi)
        mode_at_origin = 0.5 * jnp.exp(-0.5 * jnp.sum(x)**2) / norm
        mode_at_ten = 0.5 * jnp.exp(-0.5 * jnp.sum(x - 10.)**2) / norm
        return jnp.log(mode_at_origin + mode_at_ten)

    ns = NestedSampler(loglikelihood, prior_chain, sampler_name='ellipsoid')
    sampler_settings = dict(max_samples=1e5,
                            collect_samples=True,
                            termination_frac=0.05,
                            stoachastic_uncertainty=True)
    results = ns(random.PRNGKey(0), 100, **sampler_settings)

    plot_diagnostics(results)
    plot_cornerplot(results)
def run(key):
    """Run the slice nested sampler with a diagonal-Gaussian prior on ``x``.

    Reads ``prior_mu``, ``prior_cov``, ``log_likelihood`` and ``n`` from the
    surrounding scope.

    Args:
        key: PRNG key forwarded to the sampler.

    Returns:
        The object returned by the ``NestedSampler`` call.
    """
    # Only the diagonal of the prior covariance is used (as standard deviations).
    prior_std = jnp.sqrt(jnp.diag(prior_cov))
    prior_transform = PriorChain().push(MVNDiagPrior('x', prior_mu, prior_std))
    # prior_transform = LaplacePrior(prior_mu, jnp.sqrt(jnp.diag(prior_cov)))
    # prior_transform = UniformPrior(-20.*jnp.ones(ndims), 20.*jnp.ones(ndims))

    def param_mean(x, **args):
        return x

    def param_covariance(x, **args):
        # Outer product of x with itself (not centred on a mean).
        return jnp.outer(x, x)

    ns = NestedSampler(log_likelihood,
                       prior_transform,
                       sampler_name='slice',
                       x_mean=param_mean,
                       x_cov=param_covariance)

    slice_options = dict(depth=3, num_slices=2)
    return ns(key=key,
              num_live_points=n,
              max_samples=1e5,
              collect_samples=True,
              termination_frac=0.01,
              stoachastic_uncertainty=False,
              sampler_kwargs=slice_options)
def main():
    """Infer a TEC sequence with a Gaussian-walk prior using nested sampling.

    Data come from ``generate_data``. The slice sampler is run twice under
    ``jit`` (with and without compilation) and timings are printed. The
    marginalised ``tec`` is then plotted against the simulated ``tec``, followed
    by the sampler diagnostics.
    """
    # NOTE: ``Sigma`` and ``amp`` are returned by generate_data but not used below.
    Sigma, T, Y_obs, amp, tec, freqs = generate_data()
    TEC_CONV = -8.4479745e6  # mTECU/Hz

    def log_mvnormal(x, mean, cov):
        # Full-covariance Gaussian log-density via Cholesky; currently not called.
        L = jnp.linalg.cholesky(cov)
        whitened = solve_triangular(L, x - mean, lower=True)
        norm_term = -0.5 * x.size * jnp.log(2. * jnp.pi)
        half_log_det = jnp.sum(jnp.log(jnp.diag(L)))
        return norm_term - half_log_det - 0.5 * whitened @ whitened

    def log_normal(x, mean, uncert):
        # Independent Gaussian log-density with a shared standard deviation.
        resid = (x - mean) / uncert
        norm_term = -0.5 * x.size * jnp.log(2. * jnp.pi)
        scale_term = x.size * jnp.log(uncert)
        return norm_term - scale_term - 0.5 * resid @ resid

    def log_likelihood(tec, uncert, **kwargs):
        phase = tec * (TEC_CONV / freqs)
        Y = jnp.concatenate([jnp.cos(phase), jnp.sin(phase)], axis=-1)
        # Evaluate the density per leading-axis entry, then sum.
        per_entry = vmap(
            lambda y_model, y_data: log_normal(y_model, y_data, uncert))(Y, Y_obs)
        return jnp.sum(per_entry)

    # prior_transform = MVNDiagPrior(prior_mu, jnp.sqrt(jnp.diag(prior_cov)))
    # prior_transform = LaplacePrior(prior_mu, jnp.sqrt(jnp.diag(prior_cov)))
    tec_walk = DiagGaussianWalkPrior('tec', T,
                                     LaplacePrior('tec0', 0., 100.),
                                     UniformPrior('omega', 1, 15))
    prior_chain = PriorChain().push(tec_walk).push(
        UniformPrior('uncert', 0.01, 0.5))

    ns = NestedSampler(log_likelihood,
                       prior_chain,
                       sampler_name='slice',
                       tec_mean=lambda tec, **kwargs: tec)

    @jit
    def run(key):
        return ns(key=key,
                  num_live_points=500,
                  max_samples=1e5,
                  collect_samples=True,
                  termination_frac=0.01,
                  stoachastic_uncertainty=False,
                  sampler_kwargs=dict(depth=7))

    # with disable_jit():
    timed_passes = (
        (0, "Time with compile efficiency normalised", "Time with compile"),
        (1, "Time no compile efficiency normalised", "Time no compile"),
    )
    for seed, normalised_label, plain_label in timed_passes:
        start = default_timer()
        results = run(random.PRNGKey(seed))
        print(normalised_label, results.efficiency * (default_timer() - start))
        print(plain_label, default_timer() - start)

    tec_estimate = results.marginalised['tec_mean']
    plt.plot(tec)
    plt.plot(tec_estimate)
    plt.show()

    # Residual between the first column of the estimate and the simulated tec.
    plt.plot(tec_estimate[:, 0] - tec)
    plt.show()

    plot_diagnostics(results)
def main():
    """Find a MAP solution for simulated visibilities with BFGS and plot it.

    Simulates data with ``fake_vis``, minimises the negative log-likelihood over
    the prior's unit-cube coordinates, and scatter-plots the inferred ``theta``
    and ``gamma`` against the true values.

    NOTE: the function returns right after the MAP plots; the nested-sampling
    section below that ``return`` is currently unreachable.
    """
    nant, ndir = 5, 20
    uncert = 0.1
    theta_true, gamma_true, y, y_obs = fake_vis(nant, ndir, uncert)

    def log_likelihood(delta, **kwargs):
        # Real and imaginary residuals are treated as independent Gaussians with
        # the same standard deviation ``uncert``.
        resid = delta - y_obs
        chi2 = jnp.sum(jnp.real(resid)**2) / uncert**2
        chi2 = chi2 + jnp.sum(jnp.imag(resid)**2) / uncert**2
        return -0.5 * chi2 - jnp.log(2. * jnp.pi * uncert**2) * resid.size

    prior_transform = build_prior(nant, ndir)

    ### MAP with BFGS
    def constrain(U):
        # Squash an unconstrained vector into the interval (0.05, 0.95).
        return 0.05 + sigmoid(U) * 0.9

    def loss(U):
        cube_point = constrain(U)
        return -log_likelihood(**prior_transform(cube_point))

    start_point = jnp.zeros(prior_transform.U_ndims)
    print(loss(start_point))

    @jit
    def do_minimisation():
        results = minimize(loss,
                           jnp.zeros(prior_transform.U_ndims),
                           method='BFGS',
                           options=dict(gtol=1e-10, line_search_maxiter=200))
        print(results.message)
        cube_solution = constrain(results.x)
        return prior_transform(cube_solution), constrain(results.x), results.status

    results = do_minimisation()
    map_params, status = results[0], results[2]
    print('Status', status)
    print(results)

    theta_index = jnp.arange(nant * ndir)
    plt.scatter(theta_index, map_params['theta'], label='inferred')
    plt.scatter(jnp.arange(nant * ndir), theta_true, label='true')
    plt.legend()
    plt.show()

    gamma_index = jnp.arange(ndir)
    plt.scatter(gamma_index, map_params['gamma'], label='inferred')
    plt.scatter(jnp.arange(ndir), gamma_true, label='true')
    plt.legend()
    plt.show()
    return

    # ---- Unreachable below this point (kept as in the original) ----
    ns = NestedSampler(log_likelihood,
                       prior_transform,
                       sampler_name='multi_ellipsoid')

    def run_with_n(n):
        @jit
        def run(key):
            return ns(key=key,
                      num_live_points=n,
                      max_samples=1e5,
                      collect_samples=False,
                      termination_frac=0.01,
                      stoachastic_uncertainty=False,
                      sampler_kwargs=dict(depth=3))

        start = default_timer()
        results = run(random.PRNGKey(0))
        print(results.efficiency)
        print("Time to run including compile:", default_timer() - start)
        print("Time efficiency normalised:",
              results.efficiency * (default_timer() - start))
        return results

    for n in [100]:
        results = run_with_n(n)
        plt.scatter(n, results.logZ)
        plt.errorbar(n, results.logZ, yerr=results.logZerr)
    plt.show()

    # plot_diagnostics(results)
    theta_err = jnp.sqrt(jnp.diag(results.param_covariance['theta']))
    plt.errorbar(jnp.arange(nant * ndir),
                 results.param_mean['theta'],
                 yerr=theta_err,
                 label='inferred')
    plt.scatter(jnp.arange(nant * ndir), theta_true, label='true')
    plt.legend()
    plt.show()

    gamma_err = jnp.sqrt(jnp.diag(results.param_covariance['gamma']))
    plt.errorbar(jnp.arange(ndir),
                 results.param_mean['gamma'],
                 yerr=gamma_err,
                 label='inferred')
    plt.scatter(jnp.arange(ndir), gamma_true, label='true')
    plt.legend()
    plt.show()
def main(kernel):
    """Infer differential TEC on simulated data with the slice nested sampler.

    Data come from ``rbf_dtec`` and the prior from ``build_prior``. The sampler
    is run twice (with and without compilation); logZ, the marginalised
    prediction with error bars, and the diagnostic / corner plots are shown.

    Args:
        kernel: Kernel object passed to ``build_prior``; its class name is also
            used in the plot titles.

    Returns:
        Tuple ``(logZ, logZerr)`` from the sampler results.
    """

    def log_normal(x, mean, cov):
        """Multivariate-normal log-density evaluated through a Cholesky factor."""
        L = jnp.linalg.cholesky(cov)
        dx = solve_triangular(L, x - mean, lower=True)
        maha = dx @ dx
        # Sum of the log-diagonal of the Cholesky factor is 0.5 * log|cov|.
        logdet = jnp.sum(jnp.log(jnp.diag(L)))
        return -0.5 * x.size * jnp.log(2. * jnp.pi) - logdet - 0.5 * maha

    true_height, true_width, true_sigma, true_l, true_uncert = 200., 100., 1., 10., 2.5
    nant = 5
    ndir = 5
    X, Y, Y_obs = rbf_dtec(nant, ndir, true_height, true_width, true_sigma,
                           true_l, true_uncert)
    # The first three columns of X hold what is used as the reference point x0
    # (taken from row 0).
    a = X[:, 0:3]
    x0 = a[0, :]

    def log_likelihood(dtec, uncert, **kwargs):
        """Gaussian log-likelihood of ``Y_obs`` given ``dtec``.

        Args:
            dtec: Model prediction compared with ``Y_obs``.
            uncert: Observation standard deviation, shared by all points.

        Returns:
            Scalar log-likelihood.
        """
        data_cov = jnp.square(uncert) * jnp.eye(X.shape[0])
        return log_normal(Y_obs, dtec, data_cov)

    def predict_f(dtec, uncert, **kwargs):
        return dtec

    def predict_fvar(dtec, uncert, **kwargs):
        # Second moment; the variance is formed after marginalisation below.
        return dtec ** 2

    def tec_to_dtec(tec):
        # Difference each antenna's tec against antenna 0, then flatten.
        tec = tec.reshape((nant, ndir))
        return jnp.reshape(tec - tec[0, :], (-1,))

    prior_chain = build_prior(X, kernel, tec_to_dtec, x0)
    print(prior_chain)

    # Sanity check: count (and show) prior draws that give a NaN log-likelihood.
    U_test = jnp.array([
        random.uniform(key, shape=(prior_chain.U_ndims,))
        for key in random.split(random.PRNGKey(4325), 1000)
    ])
    log_lik = jnp.array([log_likelihood(**prior_chain(U)) for U in U_test])
    print(jnp.sum(jnp.isnan(log_lik)))
    print(U_test[jnp.isnan(log_lik)])

    ns = NestedSampler(log_likelihood,
                       prior_chain,
                       sampler_name='slice',
                       predict_f=predict_f,
                       predict_fvar=predict_fvar)

    def run_with_n(n):
        @jit
        def run(key):
            return ns(key=key,
                      num_live_points=n,
                      max_samples=1e5,
                      collect_samples=True,
                      termination_frac=0.01,
                      stoachastic_uncertainty=False,
                      sampler_kwargs=dict(depth=7, num_slices=1))

        t0 = default_timer()
        results = run(random.PRNGKey(0))
        print("Efficiency", results.efficiency)
        print("Time to run (including compile)", default_timer() - t0)
        t0 = default_timer()
        results = run(random.PRNGKey(1))
        print("Efficiency", results.efficiency)
        print("Time to run (no compile)", default_timer() - t0)
        print("Time efficiency normalised",
              (default_timer() - t0) * results.efficiency)
        return results

    for n in [1000]:
        results = run_with_n(n)
        plt.scatter(n, results.logZ)
        plt.errorbar(n, results.logZ, yerr=results.logZerr)

    # A second, disabled pass used to be sketched here: it rebuilt the prior
    # from MVN priors centred on this run's param_mean / param_covariance and
    # re-ran with the 'multi_ellipsoid' sampler.

    plt.title("Kernel: {}".format(kernel.__class__.__name__))
    plt.ylabel('log Z')
    plt.show()

    f_mean = results.marginalised['predict_f']
    # E[f^2] - E[f]^2 can come out slightly negative through round-off or
    # sampling error; clip at zero so the error bars never become NaN.
    f_var = jnp.maximum(results.marginalised['predict_fvar'] - f_mean ** 2, 0.)
    fstd = jnp.sqrt(f_var)

    index = jnp.arange(Y.size)
    plt.scatter(index, Y_obs, marker='+', label='data')
    plt.scatter(index, Y, marker="o", label='underlying')
    # This series used to share the 'underlying' label with the true signal.
    plt.scatter(index, f_mean, marker=".", label='marginalised mean')
    plt.errorbar(index, f_mean, yerr=fstd, label='marginalised')
    plt.title("Kernel: {}".format(kernel.__class__.__name__))
    plt.legend()
    plt.show()

    # plot_samples_development(results,save_name='./ray_integral_solution.mp4')
    plot_diagnostics(results)
    plot_cornerplot(results)
    return results.logZ, results.logZerr
def main(kernel):
    """Gaussian-process regression on 1-D data with outliers via nested sampling.

    Simulates ``N = 100`` points from an RBF Gaussian process, replaces a short
    window of observations with standard-normal draws, then infers the kernel
    hyper-parameters and noise level with the ``multi_ellipsoid`` sampler.

    Args:
        kernel: Kernel object used for the ``GaussianProcessKernelPrior``; its
            class name is also used in messages and plot titles.

    Returns:
        Tuple ``(logZ, logZerr)`` from the sampler results.
    """
    print(("Working on Kernel: {}".format(kernel.__class__.__name__)))

    def log_normal(x, mean, cov):
        # Multivariate-normal log-density through a Cholesky factorisation.
        chol = jnp.linalg.cholesky(cov)
        half_log_det = jnp.sum(jnp.log(jnp.diag(chol)))
        whitened = solve_triangular(chol, x - mean, lower=True)
        maha = whitened @ whitened
        return -0.5 * x.size * jnp.log(2. * jnp.pi) \
               - half_log_det \
               - 0.5 * maha

    # --- simulated data ---
    N = 100
    X = jnp.linspace(-2., 2., N)[:, None]
    true_sigma, true_l, true_uncert = 1., 0.2, 0.2
    data_mu = jnp.zeros((N, ))
    prior_cov = RBF()(X, X, true_l, true_sigma) + 1e-13 * jnp.eye(N)
    latent_draw = random.normal(random.PRNGKey(0), shape=(N, ))
    Y = jnp.linalg.cholesky(prior_cov) @ latent_draw + data_mu
    # The same key (1) seeds both the observation noise and the outlier values,
    # so a single draw is reused for both.
    unit_noise = random.normal(random.PRNGKey(1), shape=(N, ))
    Y_obs = Y + true_uncert * unit_noise
    sample_index = jnp.arange(N)
    outlier_mask = (sample_index > 50) & (sample_index < 60)
    Y_obs = jnp.where(outlier_mask, unit_noise, Y_obs)

    def noise_cov(uncert):
        return jnp.square(uncert) * jnp.eye(X.shape[0])

    def log_likelihood(K, uncert, **kwargs):
        """Log marginal likelihood of ``Y_obs`` under a zero-mean GP with noise."""
        zero_mean = jnp.zeros_like(Y_obs)
        return log_normal(Y_obs, zero_mean, K + noise_cov(uncert))

    def predict_f(K, uncert, **kwargs):
        # GP posterior mean at the training inputs.
        zero_mean = jnp.zeros_like(Y_obs)
        return zero_mean + K @ jnp.linalg.solve(K + noise_cov(uncert), Y_obs)

    def predict_fvar(K, uncert, **kwargs):
        # Diagonal of the GP posterior covariance at the training inputs.
        return jnp.diag(K - K @ jnp.linalg.solve(K + noise_cov(uncert), K))

    # --- priors ---
    l = UniformPrior('l', 0., 2.)
    uncert = UniformPrior('uncert', 0., 2.)
    sigma = UniformPrior('sigma', 0., 2.)
    cov = GaussianProcessKernelPrior('K', kernel, X, l, sigma)
    prior_chain = PriorChain().push(uncert).push(cov)

    ns = NestedSampler(log_likelihood,
                       prior_chain,
                       sampler_name='multi_ellipsoid',
                       predict_f=predict_f,
                       predict_fvar=predict_fvar)

    def run_with_n(n):
        @jit
        def run(key):
            return ns(key=key,
                      num_live_points=n,
                      max_samples=1e5,
                      collect_samples=True,
                      termination_frac=0.01,
                      stoachastic_uncertainty=False,
                      sampler_kwargs=dict(depth=3))

        start = default_timer()
        # with disable_jit():
        results = run(random.PRNGKey(6))
        print(results.efficiency)
        print("Time to execute (including compile): {}".format(
            default_timer() - start))

        # Same key again: only the timing differs from the first call.
        start = default_timer()
        results = run(random.PRNGKey(6))
        print(results.efficiency)
        print("Time to execute (not including compile): {}".format(
            (default_timer() - start)))
        return results

    for n in [100]:
        results = run_with_n(n)
        plt.scatter(n, results.logZ)
        plt.errorbar(n, results.logZ, yerr=results.logZerr)
    plt.title("Kernel: {}".format(kernel.__class__.__name__))
    plt.ylabel('log Z')
    plt.show()

    x_axis = X[:, 0]
    post_mean = results.marginalised['predict_f']
    post_std = jnp.sqrt(results.marginalised['predict_fvar'])
    plt.scatter(x_axis, Y_obs, label='data')
    plt.plot(x_axis, Y, label='underlying')
    plt.plot(x_axis, post_mean, label='marginalised')
    for band_edge in (post_mean + post_std, post_mean - post_std):
        plt.plot(x_axis, band_edge, ls='dotted', c='black')
    plt.title("Kernel: {}".format(kernel.__class__.__name__))
    plt.legend()
    plt.show()

    plot_diagnostics(results)
    plot_cornerplot(results)
    return results.logZ, results.logZerr
def unconstrained_solve(freqs, key, phase_obs, phase_outliers):
    """Estimate tec, constant phase and uncertainty for a pair of observations.

    Runs a slice nested sampler over ``tec0``, ``dtec``, ``const``, ``uncert0``
    and ``uncert1`` and marginalises the posterior samples.

    Args:
        freqs: Frequencies of the observed phases.
        key: PRNG key; split internally for the sampler and the marginalisation.
        phase_obs: Observed phases with shape ``(2, Nf)``.
        phase_outliers: Mask; where true, the point contributes 0 to the
            log-likelihood.

    Returns:
        Tuple ``(tec_mean, tec_std, const_mean, const_std, uncert_mean)``; the
        last three are multiplied by ``jnp.ones(Nt)``.

    Raises:
        AssertionError: If ``phase_obs`` does not have exactly two rows.
    """
    # Split count kept at 4 so key1/key2 are unchanged; the last two are unused.
    key1, key2, _, _ = random.split(key, 4)
    Nt, Nf = phase_obs.shape
    assert Nt == 2, "Observations should be consequentive pairs of 2"

    # Grids whose extremes define the uniform prior ranges below.
    tec0_array = jnp.linspace(-300., 300., 30)
    # 30mTECU/30seconds is the maximum change, in either direction. This was
    # linspace(30., 30., 30), which pinned dtec to exactly +30.
    dtec_array = jnp.linspace(-30., 30., 30)
    const_array = jnp.linspace(-jnp.pi, jnp.pi, 10)
    uncert0_array = jnp.linspace(0., 1., 10)
    uncert1_array = jnp.linspace(0., 1., 10)

    def log_likelihood(tec0, dtec, const, uncert0, uncert1, **kwargs):
        tec = jnp.asarray([tec0, tec0 + dtec])
        # Uncertainty varies linearly from uncert0 to uncert1 across the band.
        # NOTE(review): normalising by t[-1] assumes the last frequency is the
        # largest - confirm freqs is sorted ascending.
        t = freqs - jnp.min(freqs)
        t = t / t[-1]
        uncert = uncert0 + (uncert1 - uncert0) * t
        phase = tec[:, None] * (TEC_CONV / freqs) + const  # 2,Nf
        phase_residual = wrap(wrap(phase) - wrap(phase_obs))
        # Flagged outliers contribute nothing to the likelihood.
        logL = jnp.sum(
            jnp.where(phase_outliers, 0.,
                      log_normal(phase_residual, 0., uncert)))
        return logL

    # A lookup-table variant of the likelihood, precomputed over the grids
    # above, used to be sketched here and is disabled.

    tec0 = UniformPrior('tec0', tec0_array.min(), tec0_array.max())
    dtec = UniformPrior('dtec', dtec_array.min(), dtec_array.max())
    const = UniformPrior('const', const_array.min(), const_array.max())
    uncert0 = UniformPrior('uncert0', uncert0_array.min(), uncert0_array.max())
    uncert1 = UniformPrior('uncert1', uncert1_array.min(), uncert1_array.max())
    prior_chain = PriorChain(tec0, dtec, const, uncert0, uncert1)

    ns = NestedSampler(log_likelihood,
                       prior_chain,
                       sampler_name='slice',
                       num_live_points=20 * prior_chain.U_ndims,
                       sampler_kwargs=dict(num_slices=prior_chain.U_ndims * 4))

    results = ns(key=key1, termination_evidence_frac=0.3)

    ESS = 900  # empirically estimated for this problem

    def first_pass(tec0, dtec, const, uncert0, uncert1, **kwargs):
        tec = jnp.asarray([tec0, tec0 + dtec])
        # const is averaged on the unit circle (cos/sin), so wrapping is respected.
        return tec, tec**2, jnp.cos(const), jnp.sin(const), 0.5 * (uncert0 + uncert1)

    tec_mean, tec2_mean, const_real, const_imag, uncert_mean = marginalise_static(
        key2, results.samples, results.log_p, ESS, first_pass)

    # Clip at zero: E[x^2] - E[x]^2 can be slightly negative numerically, which
    # would turn the standard deviation into NaN.
    tec_std = jnp.sqrt(jnp.maximum(tec2_mean - tec_mean**2, 0.))
    const_mean = jnp.arctan2(const_imag, const_real)

    def second_pass(const, **kwargs):
        # Squared wrapped deviation from the circular mean.
        return wrap(wrap(const) - wrap(const_mean))**2

    # The same key (key2) is used as in the first pass, as in the original.
    const_var = marginalise_static(key2, results.samples, results.log_p, ESS,
                                   second_pass)
    const_std = jnp.sqrt(const_var)

    return tec_mean, tec_std, const_mean * jnp.ones(Nt), const_std * jnp.ones(
        Nt), uncert_mean * jnp.ones(Nt)
def main(kernel):
    """Infer differential TEC over antennas, directions and time by nested sampling.

    Data come from ``rbf_dtec`` and the prior from ``build_frozen_flow_prior``.
    The slice sampler is run once (timing includes compilation); logZ, the
    marginalised prediction with error bars, and the diagnostic / corner plots
    are shown.

    Args:
        kernel: Kernel object passed to ``build_frozen_flow_prior``; its class
            name is also used in the plot titles.

    Returns:
        Tuple ``(logZ, logZerr)`` from the sampler results.
    """

    def log_normal(x, mean, cov):
        """Multivariate-normal log-density evaluated through a Cholesky factor."""
        L = jnp.linalg.cholesky(cov)
        dx = solve_triangular(L, x - mean, lower=True)
        maha = dx @ dx
        # Sum of the log-diagonal of the Cholesky factor is 0.5 * log|cov|.
        logdet = jnp.sum(jnp.log(jnp.diag(L)))
        return -0.5 * x.size * jnp.log(2. * jnp.pi) - logdet - 0.5 * maha

    true_height, true_width, true_sigma, true_l, true_uncert, true_v = \
        200., 100., 1., 10., 2.5, jnp.array([0.3, 0., 0.])
    nant = 2
    ndir = 1
    ntime = 20
    X, Y, Y_obs = rbf_dtec(nant, ndir, ntime, true_height, true_width,
                           true_sigma, true_l, true_uncert, true_v)
    # The first three columns of X hold what is used as the reference point x0
    # (taken from row 0).
    a = X[:, 0:3]
    x0 = a[0, :]

    def log_likelihood(dtec, uncert, **kwargs):
        """Gaussian log-likelihood of ``Y_obs`` given ``dtec``.

        Args:
            dtec: Model prediction compared with ``Y_obs``.
            uncert: Observation standard deviation, shared by all points.

        Returns:
            Scalar log-likelihood.
        """
        data_cov = jnp.square(uncert) * jnp.eye(X.shape[0])
        return log_normal(Y_obs, dtec, data_cov)

    def predict_f(dtec, **kwargs):
        return dtec

    def predict_fvar(dtec, **kwargs):
        # Second moment; the variance is formed after marginalisation below.
        return dtec**2

    def tec_to_dtec(tec):
        # Difference each antenna's tec against antenna 0, then flatten.
        tec = tec.reshape((nant, ndir, ntime))
        return jnp.reshape(tec - tec[0, :, :], (-1, ))

    prior_chain = build_frozen_flow_prior(X, kernel, tec_to_dtec, x0)

    ns = NestedSampler(log_likelihood,
                       prior_chain,
                       sampler_name='slice',
                       predict_f=predict_f,
                       predict_fvar=predict_fvar)

    def run_with_n(n):
        @jit
        def run(key):
            return ns(key=key,
                      num_live_points=n,
                      max_samples=1e5,
                      collect_samples=True,
                      termination_frac=0.01,
                      stoachastic_uncertainty=False,
                      sampler_kwargs=dict(depth=5, num_slices=1))

        t0 = default_timer()
        results = run(random.PRNGKey(0))
        print("Efficiency", results.efficiency)
        print("Time to run (including compile)", default_timer() - t0)
        # A second, compile-free timing run is disabled here.
        return results

    for n in [100]:
        results = run_with_n(n)
        plt.scatter(n, results.logZ)
        plt.errorbar(n, results.logZ, yerr=results.logZerr)
    plt.title("Kernel: {}".format(kernel.__class__.__name__))
    plt.ylabel('log Z')
    plt.show()

    f_mean = results.marginalised['predict_f']
    # E[f^2] - E[f]^2 can come out slightly negative through round-off or
    # sampling error; clip at zero so the error bars never become NaN.
    f_var = jnp.maximum(results.marginalised['predict_fvar'] - f_mean**2, 0.)
    fstd = jnp.sqrt(f_var)

    index = jnp.arange(Y.size)
    plt.scatter(index, Y_obs, marker='+', label='data')
    plt.scatter(index, Y, marker="o", label='underlying')
    # This series used to share the 'underlying' label with the true signal.
    plt.scatter(index, f_mean, marker=".", label='marginalised mean')
    plt.errorbar(index, f_mean, yerr=fstd, label='marginalised')
    plt.title("Kernel: {}".format(kernel.__class__.__name__))
    plt.legend()
    plt.show()

    plot_diagnostics(results)
    plot_cornerplot(results)
    return results.logZ, results.logZerr
def main():
    """Set up (but do not run) a two-component spectral-index image model.

    Builds Gaussian-process priors for two spectral-index fields, uniform priors
    for two flux images, and a noise prior, prints the prior chain and constructs
    a ``NestedSampler``. ``run_with_n`` is defined but never called, so no
    sampling happens and nothing is returned.

    NOTE(review): this block looks unfinished - see the notes below.
    """

    def log_normal(x, mean, uncert):
        # Independent Gaussian log-density with a shared standard deviation.
        resid = (x - mean) / uncert
        norm_term = -0.5 * x.size * jnp.log(2. * jnp.pi)
        scale_term = x.size * jnp.log(uncert)
        return norm_term - scale_term - 0.5 * resid @ resid

    # --- simulated 1-D GP data (only Y_obs is referenced, by the predictors) ---
    N = 100
    X = jnp.linspace(-2., 2., N)[:, None]
    true_alpha, true_sigma, true_l, true_uncert = 1., 1., 0.2, 0.25
    data_mu = jnp.zeros((N, ))
    prior_cov = RBF()(X, X, true_l, true_sigma)
    latent_draw = random.normal(random.PRNGKey(0), shape=(N, ))
    Y = jnp.linalg.cholesky(prior_cov) @ latent_draw + data_mu
    Y_obs = Y + true_uncert * random.normal(random.PRNGKey(1), shape=(N, ))

    def noise_cov(uncert):
        return jnp.square(uncert) * jnp.eye(X.shape[0])

    # NOTE(review): these predictors take ``sigma`` and ``K``, but no prior with
    # those names is pushed onto the chain below - confirm they can be called.
    def predict_f(sigma, K, uncert, **kwargs):
        zero_mean = jnp.zeros_like(Y_obs)
        return zero_mean + K @ jnp.linalg.solve(K + noise_cov(uncert), Y_obs)

    def predict_fvar(sigma, K, uncert, **kwargs):
        return jnp.diag(K - K @ jnp.linalg.solve(K + noise_cov(uncert), K))

    ###
    # define the prior chain
    # Here we assume each image is represented by pixels.
    # Alternatively, you could choose regions arranged non-uniformly over the image.
    image_shape = (128, 128)
    npix = image_shape[0] * image_shape[1]
    I150 = jnp.ones(image_shape)

    alpha_cw_gp_sigma = HalfLaplacePrior('alpha_cw_gp_sigma', 1.)
    alpha_mw_gp_sigma = HalfLaplacePrior('alpha_mw_gp_sigma', 1.)
    l_cw = UniformPrior('l_cw', 0., 0.5)  # degrees
    l_mw = UniformPrior('l_mw', 0.5, 2.)  # degrees
    K_cw = GaussianProcessKernelPrior('K_cw', RBF(), X, l_cw, alpha_cw_gp_sigma)
    K_mw = GaussianProcessKernelPrior('K_mw', RBF(), X, l_mw, alpha_mw_gp_sigma)
    alpha_cw = MVNPrior('alpha_cw', -1.5, K_cw)
    alpha_mw = MVNPrior('alpha_mw', -2.5, K_mw)
    # NOTE(review): the priors are named 'S150_cw'/'S150_mw' while the
    # likelihood arguments are S_cw_150/S_mw_150; presumably values are passed
    # by prior name - verify the names line up.
    S_cw_150 = UniformPrior('S150_cw', 0., I150)
    S_mw_150 = UniformPrior('S150_mw', 0., I150)
    uncert = HalfLaplacePrior('uncert', 1.)

    def log_likelihood(uncert, alpha_cw, alpha_mw, S_cw_150, S_mw_150):
        # Two power-law components referenced to 150 MHz, summed per frequency.
        # ``images`` and ``freqs`` still need to be defined.
        def image_term(img, freq):
            ratio = freq / 150e6
            I_total = S_mw_150 * ratio**(alpha_mw) + S_cw_150 * ratio**(alpha_cw)
            return log_normal(img, I_total, uncert)

        return sum((image_term(img, freq) for img, freq in zip(images, freqs)), 0)

    prior_chain = PriorChain()
    for prior in (alpha_cw, S_cw_150, alpha_mw, S_mw_150, uncert):
        prior_chain = prior_chain.push(prior)

    print(prior_chain)

    ns = NestedSampler(log_likelihood,
                       prior_chain,
                       sampler_name='ellipsoid',
                       predict_f=predict_f,
                       predict_fvar=predict_fvar)

    def run_with_n(n):
        @jit
        def run():
            return ns(key=random.PRNGKey(0),
                      num_live_points=n,
                      max_samples=1e3,
                      collect_samples=True,
                      termination_frac=0.01,
                      stoachastic_uncertainty=True)

        return run()
def main():
    """Gaussian-process regression with a RationalQuadratic kernel via nested sampling.

    Simulates ``N = 100`` noisy points from a RationalQuadratic Gaussian
    process, infers the kernel hyper-parameters and noise level with the
    ``multi_ellipsoid`` sampler, and plots logZ, the marginalised prediction
    with a one-standard-deviation band, and the diagnostic / corner plots.

    Returns:
        Tuple ``(logZ, logZerr)`` from the sampler results.
    """

    def log_normal(x, mean, cov):
        # Multivariate-normal log-density through a Cholesky factorisation.
        chol = jnp.linalg.cholesky(cov)
        whitened = solve_triangular(chol, x - mean, lower=True)
        norm_term = -0.5 * x.size * jnp.log(2. * jnp.pi)
        half_log_det = jnp.sum(jnp.log(jnp.diag(chol)))
        return norm_term - half_log_det - 0.5 * whitened @ whitened

    # --- simulated data ---
    N = 100
    X = jnp.linspace(-2., 2., N)[:, None]
    true_alpha, true_sigma, true_l, true_uncert = 1., 1., 0.2, 0.25
    data_mu = jnp.zeros((N, ))
    prior_cov = RationalQuadratic()(X, X, true_l, true_sigma, true_alpha)
    latent_draw = random.normal(random.PRNGKey(0), shape=(N, ))
    Y = jnp.linalg.cholesky(prior_cov) @ latent_draw + data_mu
    Y_obs = Y + true_uncert * random.normal(random.PRNGKey(1), shape=(N, ))
    # An outlier-injection step and a raw-data plot are disabled here.

    def noise_cov(uncert):
        return jnp.square(uncert) * jnp.eye(X.shape[0])

    def log_likelihood(K, uncert, **kwargs):
        """Log marginal likelihood of ``Y_obs`` under a zero-mean GP with noise."""
        zero_mean = jnp.zeros_like(Y_obs)
        return log_normal(Y_obs, zero_mean, K + noise_cov(uncert))

    def predict_f(K, uncert, **kwargs):
        # GP posterior mean at the training inputs.
        zero_mean = jnp.zeros_like(Y_obs)
        return zero_mean + K @ jnp.linalg.solve(K + noise_cov(uncert), Y_obs)

    def predict_fvar(K, uncert, **kwargs):
        # Diagonal of the GP posterior covariance at the training inputs.
        return jnp.diag(K - K @ jnp.linalg.solve(K + noise_cov(uncert), K))

    kernel_prior = GaussianProcessKernelPrior('K', RationalQuadratic(), X,
                                              UniformPrior('l', 0., 4.),
                                              UniformPrior('sigma', 0., 4.),
                                              UniformPrior('alpha', 0., 4.))
    prior_chain = PriorChain().push(kernel_prior).push(
        UniformPrior('uncert', 0., 2.))

    ns = NestedSampler(log_likelihood,
                       prior_chain,
                       sampler_name='multi_ellipsoid',
                       predict_f=predict_f,
                       predict_fvar=predict_fvar)

    def run_with_n(n):
        @jit
        def run():
            return ns(key=random.PRNGKey(0),
                      num_live_points=n,
                      max_samples=1e4,
                      collect_samples=True,
                      termination_frac=0.01,
                      stoachastic_uncertainty=False,
                      sampler_kwargs=dict(depth=4))

        return run()

    for n in [200]:
        results = run_with_n(n)
        plt.scatter(n, results.logZ)
        plt.errorbar(n, results.logZ, yerr=results.logZerr)
    plt.title("Kernel: {}".format(RationalQuadratic.__name__))
    plt.ylabel('log Z')
    plt.show()

    x_axis = X[:, 0]
    post_mean = results.marginalised['predict_f']
    post_std = jnp.sqrt(results.marginalised['predict_fvar'])
    plt.scatter(x_axis, Y_obs, label='data')
    plt.plot(x_axis, Y, label='underlying')
    plt.plot(x_axis, post_mean, label='marginalised')
    for band_edge in (post_mean + post_std, post_mean - post_std):
        plt.plot(x_axis, band_edge, ls='dotted', c='black')
    plt.title("Kernel: {}".format(RationalQuadratic.__name__))
    plt.legend()
    plt.show()

    plot_diagnostics(results)
    plot_cornerplot(results)
    return results.logZ, results.logZerr