def test_tomographic_kernel():
    """Visual sanity check: evaluate the tomographic kernel between a fan of
    rays and a single reference ray, and scatter-plot the covariance."""
    from jax import random
    from jaxns.gaussian_process.kernels import RBF
    import pylab as plt

    # 300 rays from antenna a1, pointing within ~4 degrees of zenith.
    n = 300
    a1 = jnp.array([[-1, 0., 0.]])
    k1 = jnp.stack([
        4. * jnp.pi / 180. * random.uniform(random.PRNGKey(0), shape=(n,), minval=-1, maxval=1),
        4. * jnp.pi / 180. * random.uniform(random.PRNGKey(1), shape=(n,), minval=-1, maxval=1),
        jnp.ones(n)
    ], axis=1)
    k1 /= jnp.linalg.norm(k1, axis=-1, keepdims=True)

    # A single zenith-pointing reference ray from antenna a2.
    n = 1
    a2 = jnp.array([[1., 0., 0.]])
    k2 = jnp.stack([jnp.zeros(n), jnp.zeros(n), jnp.ones(n)], axis=1)
    k2 /= jnp.linalg.norm(k2, axis=-1, keepdims=True)

    x0 = jnp.zeros(3)
    K = tomographic_kernel(a1, a2, k1, k2, x0, RBF(),
                           height=10., width=2., l=1., sigma=1., S=25)
    # Colour each fan direction by its covariance with the reference ray.
    sc = plt.scatter(k1[:, 0], k1[:, 1], c=K[:, 0])
    plt.colorbar(sc)
    plt.show()
def rbf_dtec(nant, ndir, ntime, height, width, sigma, l, uncert, v):
    """
    In frozen flow the screen moves with velocity v.

    fed(x,t) = fed(x-v*t,0)

    so that the tomographic kernel transforms as,

    K(x1,k1,t1,x2,k2,t2) = K(x1-v * t1,k1,0,x2-v * t2,k2,0)
    """
    import pylab as plt

    # Antenna positions: random in a 10 km box on the z=0 plane.
    a = jnp.concatenate([
        10. * random.uniform(random.PRNGKey(0), shape=(nant, 2)),
        jnp.zeros((nant, 1))
    ], axis=1)
    # Unit pointing directions within ~4 degrees of zenith.
    k = jnp.concatenate([
        4. * jnp.pi / 180. * random.uniform(
            random.PRNGKey(0), shape=(ndir, 2), minval=-1, maxval=1),
        jnp.ones((ndir, 1))
    ], axis=1)
    k = k / jnp.linalg.norm(k, axis=1, keepdims=True)
    t = jnp.arange(ntime)[:, None] * 30.  # seconds

    # Cartesian product of (antenna, direction, time) coordinates.
    X = make_coord_array(a, k, t)
    a = X[:, 0:3]
    k = X[:, 3:6]
    t = X[:, 6:7]
    x0 = a[0, :]

    kernel = TomographicKernel(x0, RBF(), S_marg=100, S_gamma=100)
    # Frozen-flow: shift spatial coordinates by -v*t before evaluating.
    K = kernel(X[:, :6] - jnp.concatenate([v, jnp.zeros(3)]) * t,
               X[:, :6] - jnp.concatenate([v, jnp.zeros(3)]) * t,
               height, width, l, sigma)

    plt.imshow(K)
    plt.colorbar()
    plt.show()
    plt.plot(jnp.sqrt(jnp.diag(K)))
    plt.show()

    # Sample TEC from the kernel; msqrt is a matrix square root.
    L = msqrt(K)  # jnp.linalg.cholesky(K + jnp.eye(K.shape_dict[0])*1e-3)
    tec = L @ random.normal(random.PRNGKey(2), shape=(L.shape[0],))
    tec = tec.reshape((nant, ndir, ntime))
    # Differential TEC relative to the first antenna.
    dtec = tec - tec[0, :, :]
    dtec = dtec.reshape((-1,))
    plt.plot(dtec)
    plt.show()
    return X, dtec, dtec + uncert * random.normal(random.PRNGKey(3), shape=dtec.shape)
def rbf_dtec(nant, ndir, height, width, sigma, l, uncert=1.):
    """Simulate differential TEC over random antennas/directions and map it
    to noisy cos/sin phase observables across a band of frequencies."""
    import pylab as plt

    # Antenna positions: random in a 10 km box on the z=0 plane.
    a = jnp.concatenate([
        10. * random.uniform(random.PRNGKey(0), shape=(nant, 2)),
        jnp.zeros((nant, 1))
    ], axis=1)
    # Unit pointing directions within ~4 degrees of zenith.
    k = jnp.concatenate([
        4. * jnp.pi / 180. * random.uniform(
            random.PRNGKey(0), shape=(ndir, 2), minval=-1, maxval=1),
        jnp.ones((ndir, 1))
    ], axis=1)
    k = k / jnp.linalg.norm(k, axis=1, keepdims=True)

    X = make_coord_array(a, k)
    a = X[:, 0:3]
    k = X[:, 3:6]
    x0 = a[0, :]

    kernel = TomographicKernel(x0, RBF(), S_marg=100, S_gamma=100)
    K = kernel(X, X, height, width, l, sigma)

    plt.imshow(K)
    plt.colorbar()
    plt.show()
    plt.plot(jnp.sqrt(jnp.diag(K)))
    plt.show()

    # Sample TEC from the kernel; msqrt is a matrix square root.
    L = msqrt(K)  # jnp.linalg.cholesky(K + jnp.eye(K.shape_dict[0])*1e-3)
    tec = L @ random.normal(random.PRNGKey(2), shape=(L.shape[0],))
    tec = tec.reshape((nant, ndir))
    # Differential TEC relative to the first antenna.
    dtec = tec - tec[0, :]
    dtec = jnp.reshape(dtec, (-1,))

    TEC_CONV = -8.4479745e6  # mTECU/Hz
    freqs = jnp.linspace(121e6, 168e6, 24)
    tec_conv = TEC_CONV / freqs
    # Phase per frequency, stacked as [cos(phase), sin(phase)].
    phase = dtec[:, None] * tec_conv
    Y = jnp.concatenate([jnp.cos(phase), jnp.sin(phase)], axis=-1)
    plt.plot(dtec)
    plt.show()
    return X, dtec, Y, Y + uncert * random.normal(random.PRNGKey(3), shape=Y.shape), tec_conv
def test_tomographic_kernel():
    """Draw a DTEC sample from the tomographic kernel for one antenna of an
    example datapack and plot it as a Voronoi map over direction."""
    dp = make_example_datapack(500, 24, 1, clobber=True)
    with dp:
        select = dict(pol=slice(0, 1, 1), ant=slice(0, None, 1))
        dp.current_solset = 'sol000'
        dp.select(**select)
        tec_mean, axes = dp.tec
        tec_mean = tec_mean[0, ...]
        patch_names, directions = dp.get_directions(axes['dir'])
        antenna_labels, antennas = dp.get_antennas(axes['ant'])
        timestamps, times = dp.get_times(axes['time'])

    # Convert antennas/directions to the local ENU frame of the reference antenna.
    antennas = ac.ITRS(*antennas.cartesian.xyz, obstime=times[0])
    ref_ant = antennas[0]
    frame = ENU(obstime=times[0], location=ref_ant.earth_location)
    antennas = antennas.transform_to(frame)
    ref_ant = antennas[0]
    directions = directions.transform_to(frame)

    x = antennas.cartesian.xyz.to(au.km).value.T
    k = directions.cartesian.xyz.value.T
    # One antenna (index 50) against every direction.
    X = make_coord_array(x[50:51, :], k)
    x0 = ref_ant.cartesian.xyz.to(au.km).value
    print(k.shape)

    kernel = TomographicKernel(x0, x0, RBF(), S_marg=25)
    K = jit(lambda X: kernel(
        X, X, bottom=200., width=50., fed_kernel_params=dict(l=7., sigma=1.)))(
            jnp.asarray(X))
    # K /= jnp.outer(jnp.sqrt(jnp.diag(K)), jnp.sqrt(jnp.diag(K)))

    plt.imshow(K)
    plt.colorbar()
    plt.show()

    # Jitter the diagonal so Cholesky succeeds, then draw one sample.
    L = jnp.linalg.cholesky(K + 1e-6 * jnp.eye(K.shape[0]))
    print(L)
    dtec = L @ random.normal(random.PRNGKey(24532), shape=(K.shape[0],))
    print(jnp.std(dtec))

    ax = plot_vornoi_map(k[:, 0:2], dtec)
    ax.set_xlabel(r"$k_{\rm east}$")
    ax.set_ylabel(r"$k_{\rm north}$")
    ax.set_xlim(-0.1, 0.1)
    ax.set_ylim(-0.1, 0.1)
    plt.show()
# plt.errorbar(n, results.logZ, yerr=results.logZerr) plt.title("Kernel: {}".format(kernel.__class__.__name__)) plt.ylabel('log Z') plt.show() fstd = jnp.sqrt(results.marginalised['predict_fvar'] - results.marginalised['predict_f'] ** 2) plt.scatter(jnp.arange(Y.size),Y_obs, marker='+', label='data') plt.scatter(jnp.arange(Y.size),Y, marker="o", label='underlying') plt.scatter(jnp.arange(Y.size), results.marginalised['predict_f'], marker=".", label='underlying') plt.errorbar(jnp.arange(Y.size), results.marginalised['predict_f'], yerr=fstd, label='marginalised') plt.title("Kernel: {}".format(kernel.__class__.__name__)) plt.legend() plt.show() # plot_samples_development(results,save_name='./ray_integral_solution.mp4') plot_diagnostics(results) plot_cornerplot(results) return results.logZ, results.logZerr if __name__ == '__main__': logZ_rbf, logZerr_rbf = main(RBF()) # logZ_m12, logZerr_m12 = main(M12()) # plt.errorbar(['rbf', 'm12'], # [logZ_rbf, logZ_m12], # [logZerr_rbf, logZerr_m12]) # plt.ylabel("log Z") # plt.show()
def main(kernel):
    """Run nested sampling for a GP regression model with the given kernel,
    plot the marginalised predictions, and return (logZ, logZerr)."""
    print(("Working on Kernel: {}".format(kernel.__class__.__name__)))

    def log_normal(x, mean, cov):
        # Multivariate normal log-density via a Cholesky factor.
        L = jnp.linalg.cholesky(cov)
        # U, S, Vh = jnp.linalg.svd(cov)
        # sum(log diag(L)) == 0.5 * log|cov|, so subtracting it once below is correct.
        log_det = jnp.sum(jnp.log(jnp.diag(L)))  # jnp.sum(jnp.log(S))#
        dx = x - mean
        dx = solve_triangular(L, dx, lower=True)
        # U S Vh V 1/S Uh
        # pinv = (Vh.T.conj() * jnp.where(S!=0., jnp.reciprocal(S), 0.)) @ U.T.conj()
        maha = dx @ dx  # dx @ pinv @ dx#solve_triangular(L, dx, lower=True)
        log_likelihood = -0.5 * x.size * jnp.log(2. * jnp.pi) \
                         - log_det \
                         - 0.5 * maha
        # print(log_likelihood)
        return log_likelihood

    # Synthetic data: GP draw plus noise, with an outlier window at [51, 59].
    N = 100
    X = jnp.linspace(-2., 2., N)[:, None]
    true_sigma, true_l, true_uncert = 1., 0.2, 0.2
    data_mu = jnp.zeros((N,))
    prior_cov = RBF()(X, X, true_l, true_sigma) + 1e-13 * jnp.eye(N)
    # print(jnp.linalg.cholesky(prior_cov), jnp.linalg.eigvals(prior_cov))
    # return
    Y = jnp.linalg.cholesky(prior_cov) @ random.normal(random.PRNGKey(0), shape=(N,)) + data_mu
    Y_obs = Y + true_uncert * random.normal(random.PRNGKey(1), shape=(N,))
    Y_obs = jnp.where((jnp.arange(N) > 50) & (jnp.arange(N) < 60),
                      random.normal(random.PRNGKey(1), shape=(N,)),
                      Y_obs)
    # plt.scatter(X[:, 0], Y_obs, label='data')
    # plt.plot(X[:, 0], Y, label='underlying')
    # plt.legend()
    # plt.show()

    def log_likelihood(K, uncert, **kwargs):
        """
        P(Y|sigma, half_width) = N[Y, mu, K]
        Args:
            sigma:
            l:

        Returns:

        """
        data_cov = jnp.square(uncert) * jnp.eye(X.shape[0])
        mu = jnp.zeros_like(Y_obs)
        return log_normal(Y_obs, mu, K + data_cov)

    def predict_f(K, uncert, **kwargs):
        # GP posterior mean at the training inputs.
        data_cov = jnp.square(uncert) * jnp.eye(X.shape[0])
        mu = jnp.zeros_like(Y_obs)
        return mu + K @ jnp.linalg.solve(K + data_cov, Y_obs)

    def predict_fvar(K, uncert, **kwargs):
        # GP posterior marginal variance at the training inputs.
        data_cov = jnp.square(uncert) * jnp.eye(X.shape[0])
        mu = jnp.zeros_like(Y_obs)
        return jnp.diag(K - K @ jnp.linalg.solve(K + data_cov, K))

    l = UniformPrior('l', 0., 2.)
    uncert = UniformPrior('uncert', 0., 2.)
    sigma = UniformPrior('sigma', 0., 2.)
    cov = GaussianProcessKernelPrior('K', kernel, X, l, sigma)
    prior_chain = PriorChain().push(uncert).push(cov)
    # print(prior_chain)

    ns = NestedSampler(log_likelihood,
                       prior_chain,
                       sampler_name='multi_ellipsoid',
                       predict_f=predict_f,
                       predict_fvar=predict_fvar)

    def run_with_n(n):
        # Run twice with the same key to time compile vs. steady-state.
        @jit
        def run(key):
            return ns(key=key,
                      num_live_points=n,
                      max_samples=1e5,
                      collect_samples=True,
                      termination_frac=0.01,
                      stoachastic_uncertainty=False,
                      sampler_kwargs=dict(depth=3))

        t0 = default_timer()
        # with disable_jit():
        results = run(random.PRNGKey(6))
        print(results.efficiency)
        print("Time to execute (including compile): {}".format(default_timer() - t0))
        t0 = default_timer()
        results = run(random.PRNGKey(6))
        print(results.efficiency)
        print("Time to execute (not including compile): {}".format((default_timer() - t0)))
        return results

    for n in [100]:
        results = run_with_n(n)
        plt.scatter(n, results.logZ)
        plt.errorbar(n, results.logZ, yerr=results.logZerr)

    plt.title("Kernel: {}".format(kernel.__class__.__name__))
    plt.ylabel('log Z')
    plt.show()

    plt.scatter(X[:, 0], Y_obs, label='data')
    plt.plot(X[:, 0], Y, label='underlying')
    plt.plot(X[:, 0], results.marginalised['predict_f'], label='marginalised')
    plt.plot(X[:, 0],
             results.marginalised['predict_f'] + jnp.sqrt(results.marginalised['predict_fvar']),
             ls='dotted', c='black')
    plt.plot(X[:, 0],
             results.marginalised['predict_f'] - jnp.sqrt(results.marginalised['predict_fvar']),
             ls='dotted', c='black')
    plt.title("Kernel: {}".format(kernel.__class__.__name__))
    plt.legend()
    plt.show()
    plot_diagnostics(results)
    plot_cornerplot(results)
    return results.logZ, results.logZerr
def main():
    """Set up a two-component (compact + diffuse) spectral-index image model
    and a nested-sampling runner over it."""

    def log_normal(x, mean, uncert):
        # Log-density of independent normals with a shared scalar uncertainty.
        dx = x - mean
        dx = dx / uncert
        return -0.5 * x.size * jnp.log(2. * jnp.pi) - x.size * jnp.log(uncert) - 0.5 * dx @ dx

    # Synthetic 1D GP data used by the predictors below.
    N = 100
    X = jnp.linspace(-2., 2., N)[:, None]
    true_alpha, true_sigma, true_l, true_uncert = 1., 1., 0.2, 0.25
    data_mu = jnp.zeros((N,))
    prior_cov = RBF()(X, X, true_l, true_sigma)
    Y = jnp.linalg.cholesky(prior_cov) @ random.normal(random.PRNGKey(0), shape=(N,)) + data_mu
    Y_obs = Y + true_uncert * random.normal(random.PRNGKey(1), shape=(N,))

    def predict_f(sigma, K, uncert, **kwargs):
        # GP posterior mean at the training inputs.
        data_cov = jnp.square(uncert) * jnp.eye(X.shape[0])
        mu = jnp.zeros_like(Y_obs)
        return mu + K @ jnp.linalg.solve(K + data_cov, Y_obs)

    def predict_fvar(sigma, K, uncert, **kwargs):
        # GP posterior marginal variance at the training inputs.
        data_cov = jnp.square(uncert) * jnp.eye(X.shape[0])
        mu = jnp.zeros_like(Y_obs)
        return jnp.diag(K - K @ jnp.linalg.solve(K + data_cov, K))

    ###
    # define the prior chain
    # Here we assume each image is represented by pixels.
    # Alternatively, you could choose regions arranged non-uniformly over the image.
    image_shape = (128, 128)
    npix = image_shape[0] * image_shape[1]
    I150 = jnp.ones(image_shape)

    # GP priors over the compact-wavelength (cw) and mid-wavelength (mw)
    # spectral-index fields, plus uniform priors on the 150 MHz fluxes.
    alpha_cw_gp_sigma = HalfLaplacePrior('alpha_cw_gp_sigma', 1.)
    alpha_mw_gp_sigma = HalfLaplacePrior('alpha_mw_gp_sigma', 1.)
    l_cw = UniformPrior('l_cw', 0., 0.5)  # degrees
    l_mw = UniformPrior('l_mw', 0.5, 2.)  # degrees
    K_cw = GaussianProcessKernelPrior('K_cw', RBF(), X, l_cw, alpha_cw_gp_sigma)
    K_mw = GaussianProcessKernelPrior('K_mw', RBF(), X, l_mw, alpha_mw_gp_sigma)
    alpha_cw = MVNPrior('alpha_cw', -1.5, K_cw)
    alpha_mw = MVNPrior('alpha_mw', -2.5, K_mw)
    S_cw_150 = UniformPrior('S150_cw', 0., I150)
    S_mw_150 = UniformPrior('S150_mw', 0., I150)
    uncert = HalfLaplacePrior('uncert', 1.)

    def log_likelihood(uncert, alpha_cw, alpha_mw, S_cw_150, S_mw_150):
        # Power-law spectra referenced to 150 MHz, summed over both components.
        log_prob = 0
        for img, freq in zip(images, freqs):  # <- need to define these
            I_total = S_mw_150 * (freq / 150e6)**(alpha_mw) \
                      + S_cw_150 * (freq / 150e6)**(alpha_cw)
            log_prob += log_normal(img, I_total, uncert)
        return log_prob

    prior_chain = PriorChain()\
        .push(alpha_cw).push(S_cw_150)\
        .push(alpha_mw).push(S_mw_150)\
        .push(uncert)
    print(prior_chain)

    ns = NestedSampler(log_likelihood,
                       prior_chain,
                       sampler_name='ellipsoid',
                       predict_f=predict_f,
                       predict_fvar=predict_fvar)

    def run_with_n(n):
        @jit
        def run():
            return ns(key=random.PRNGKey(0),
                      num_live_points=n,
                      max_samples=1e3,
                      collect_samples=True,
                      termination_frac=0.01,
                      stoachastic_uncertainty=True)

        results = run()
        return results
def train_neural_network(datapack: DataPack, batch_size, learning_rate, num_batches):
    """Fit a neural surrogate of the tomographic kernel by regressing it, over
    random batches and random kernel hyperparameters, against the exact kernel."""
    with datapack:
        select = dict(pol=slice(0, 1, 1), ant=None, time=slice(0, 1, 1))
        datapack.current_solset = 'sol000'
        datapack.select(**select)
        axes = datapack.axes_tec
        patch_names, directions = datapack.get_directions(axes['dir'])
        antenna_labels, antennas = datapack.get_antennas(axes['ant'])
        timestamps, times = datapack.get_times(axes['time'])

    # Convert antennas/directions to the local ENU frame of the reference antenna.
    antennas = ac.ITRS(*antennas.cartesian.xyz, obstime=times[0])
    ref_ant = antennas[0]
    frame = ENU(obstime=times[0], location=ref_ant.earth_location)
    antennas = antennas.transform_to(frame)
    ref_ant = antennas[0]
    directions = directions.transform_to(frame)

    x = antennas.cartesian.xyz.to(au.km).value.T
    k = directions.cartesian.xyz.value.T
    # Centre the time axis and convert MJD days to seconds.
    t = times.mjd
    t -= t[len(t) // 2]
    t *= 86400.

    # Extra random screen directions to densify coverage of the sky.
    n_screen = 250
    kstar = random.uniform(random.PRNGKey(29428942), (n_screen, 3),
                           minval=jnp.min(k, axis=0),
                           maxval=jnp.max(k, axis=0))
    kstar /= jnp.linalg.norm(kstar, axis=-1, keepdims=True)
    X = jnp.asarray(
        make_coord_array(x, jnp.concatenate([k, kstar], axis=0), t[:, None]))

    x0 = jnp.asarray(antennas.cartesian.xyz.to(au.km).value.T[0, :])
    ref_ant = x0
    kernel = TomographicKernel(x0, ref_ant, RBF(), S_marg=100)
    neural_kernel = NeuralTomographicKernel(x0, ref_ant)

    def loss(params, key):
        # One training step's objective: squared error between the exact and
        # neural kernels on a random batch with random hyperparameters.
        keys = random.split(key, 5)
        indices = random.permutation(keys[0], jnp.arange(X.shape[0]))[:batch_size]
        X_batch = X[indices, :]
        # Wind velocity in km/s, drawn in the horizontal plane only.
        wind_velocity = random.uniform(keys[1],
                                       shape=(3,),
                                       minval=jnp.asarray([-200., -200., 0.]),
                                       maxval=jnp.asarray([200., 200., 0.])) / 1000.
        bottom = random.uniform(keys[2], minval=50., maxval=500.)
        width = random.uniform(keys[3], minval=40., maxval=300.)
        l = random.uniform(keys[4], minval=1., maxval=30.)
        sigma = 1.
        K = kernel(X_batch, X_batch, bottom, width, l, sigma,
                   wind_velocity=wind_velocity)
        neural_kernel.set_params(params)
        neural_K = neural_kernel(X_batch, X_batch, bottom, width, l, sigma,
                                 wind_velocity=wind_velocity)
        # Normalise by width**2 so wide layers don't dominate the loss.
        return jnp.mean((K - neural_K)**2) / width**2

    init_params = neural_kernel.init_params(random.PRNGKey(42))

    def train_one_batch(params, key):
        # Plain SGD step; returns the updated params and the batch loss.
        l, g = value_and_grad(lambda params: loss(params, key))(params)
        params = tree_multimap(lambda p, g: p - learning_rate * g, params, g)
        return params, l

    final_params, losses = jit(lambda key: scan(
        train_one_batch, init_params, random.split(key, num_batches)))(
            random.PRNGKey(42))

    plt.plot(losses)
    plt.yscale('log')
    plt.show()