def square_norm_smooth(kappa, g1, g2, reg=1.):
    """Shear least-squares misfit plus a squared-L2 penalty on the convergence.

    Parameters
    ----------
    kappa : pair ``(kE, kB)`` of E- and B-mode convergence maps
    g1, g2 : observed shear components
    reg : weight applied to the penalty term

    Returns
    -------
    ``sum((g1 - gamma1)**2) + sum((g2 - gamma2)**2) + reg * (sum(kE**2) + sum(kB**2))``
    where ``(gamma1, gamma2) = ks93inv(kE, kB)``.
    """
    kE, kB = kappa
    model_g1, model_g2 = ks93inv(kE, kB)
    res1 = g1 - model_g1
    res2 = g2 - model_g2
    misfit = np.sum(res1 * res1) + np.sum(res2 * res2)
    penalty = np.sum(kE * kE) + np.sum(kB * kB)
    return misfit + reg * penalty
def square_norm_sparse(kappa, g1, g2, reg=.01):
    """Shear least-squares misfit plus an L1 (absolute-value) penalty on the convergence.

    Parameters
    ----------
    kappa : pair ``(kE, kB)`` of E- and B-mode convergence maps
    g1, g2 : observed shear components
    reg : weight applied to the L1 penalty

    Returns
    -------
    ``sum((g1 - gamma1)**2) + sum((g2 - gamma2)**2) + reg * (sum|kE| + sum|kB|)``
    where ``(gamma1, gamma2) = ks93inv(kE, kB)``.
    """
    kE, kB = kappa
    model_g1, model_g2 = ks93inv(kE, kB)
    res1 = g1 - model_g1
    res2 = g2 - model_g2
    misfit = np.sum(res1 * res1) + np.sum(res2 * res2)
    penalty = np.sum(np.abs(kE)) + np.sum(np.abs(kB))
    return misfit + reg * penalty
def log_likelihood(x, sigma, meas_shear, mask, sigma_mask):
    """Gaussian log-likelihood of the measured shear for a flattened convergence map.

    ``x`` is reshaped to a 360x360 E-mode map (B-mode set to zero), or to a
    360x360x2 (E, B) stack when the free name ``b_mode`` is truthy. The map is
    converted to a model shear with ``ks93inv`` and compared to ``meas_shear``.

    NOTE(review): ``b_mode`` and ``sigma_gamma`` are not parameters; this
    top-level copy only runs if they exist at module level. ``mask`` is unused.

    Returns
    -------
    ``-0.5 * sum((model_shear - meas_shear)**2 / (sigma_gamma**2 + sigma**2 + sigma_mask))``
    """
    if not b_mode:
        ke = x.reshape((360, 360))
        kb = jnp.zeros(ke.shape)
    else:
        eb_maps = x.reshape((360, 360, 2))
        ke, kb = eb_maps[..., 0], eb_maps[..., 1]

    model_shear = jnp.stack(ks93inv(ke, kb), axis=-1)
    variance = (sigma_gamma)**2 + sigma**2 + sigma_mask
    squared_residual = (model_shear - meas_shear)**2
    return -jnp.sum(squared_residual / variance) / 2.
def body_fun(i, val):
    """One messenger-field Wiener-filter iteration on a (Q, U) signal pair.

    ``i`` is not used. The ``(i, val) -> val`` signature presumably matches a
    ``jax.lax.fori_loop`` body (TODO confirm against the caller).
    The step:
      1. Builds the messenger field ``t = operationA + operationB * s`` in Q/U.
      2. Converts it to E/B with ``ks93``.
      3. Applies ``scov / (scov + tcov_ft)`` to each mode in Fourier space.
      4. Maps the result back to Q/U with ``ks93inv``.
    ``q_operationA/B``, ``u_operationA/B``, ``scov_ft_E/B`` and ``tcov_ft`` are
    free names supplied by the enclosing scope.

    Returns
    -------
    tuple ``(s_q, s_u)`` of updated Q and U signals.
    """
    prev_q, prev_u = val

    # Messenger field in (Q, U) representation
    t_Q = q_operationA + jnp.multiply(q_operationB, prev_q)
    t_U = u_operationA + jnp.multiply(u_operationB, prev_u)

    # Wiener-filter each mode in (E, B) Fourier space
    t_E, t_B = ks93(t_Q, t_U)
    filter_E = scov_ft_E / (scov_ft_E + tcov_ft)
    filter_B = scov_ft_B / (scov_ft_B + tcov_ft)
    s_E = jnp.fft.ifft2(filter_E * jnp.fft.fft2(t_E))
    s_B = jnp.fft.ifft2(filter_B * jnp.fft.fft2(t_B))

    # Back to (Q, U) representation
    new_q, new_u = ks93inv(s_E, s_B)
    return new_q, new_u
def spin_wiener_filter(data_q, data_u, ncov_diag_Q, ncov_diag_U,
                       input_ps_map_E, input_ps_map_B, iterations=10):
    """Wiener filter a spin-2 field with the Elsner-Wandelt messenger-field method.

    Adapted for spin-2 fields (CMB polarization or galaxy weak lensing).

    Parameters
    ----------
    data_q : Q square image data (e.g. gamma1)
    data_u : U square image data (e.g. gamma2)
    ncov_diag_Q : Q noise variance per pixel (assumed uncorrelated)
    ncov_diag_U : U noise variance per pixel (assumed uncorrelated)
    input_ps_map_E : 1D E-mode signal power spectrum P(k) evaluated on the 2D
        (k1, k2) grid, as a square image (``fftshift`` is applied before use)
    input_ps_map_B : same as ``input_ps_map_E`` for the B-mode signal
    iterations : number of messenger-field iterations

    Returns
    -------
    s_q, s_u : Wiener-filtered Q and U signals
    """
    # Messenger covariance: the smallest noise variance across Q and U.
    tcov_diag = jnp.min(jnp.array([ncov_diag_Q, ncov_diag_U]))
    scov_ft_E = jnp.fft.fftshift(input_ps_map_E)
    scov_ft_B = jnp.fft.fftshift(input_ps_map_B)

    # These weights and filters do not change between iterations; compute once.
    data_weight_Q = tcov_diag / ncov_diag_Q
    signal_weight_Q = (ncov_diag_Q - tcov_diag) / ncov_diag_Q
    data_weight_U = tcov_diag / ncov_diag_U
    signal_weight_U = (ncov_diag_U - tcov_diag) / ncov_diag_U
    filter_E = scov_ft_E / (scov_ft_E + tcov_diag)
    filter_B = scov_ft_B / (scov_ft_B + tcov_diag)

    s_q = jnp.zeros(data_q.shape)
    s_u = jnp.zeros(data_u.shape)
    # Plain Python range: iterating a jnp array would create a device scalar
    # per step for no benefit.
    for _ in range(iterations):
        # Messenger field in (Q, U) representation
        t_Q = data_weight_Q * data_q + signal_weight_Q * s_q
        t_U = data_weight_U * data_u + signal_weight_U * s_u

        # Wiener filter in (E, B) Fourier representation
        t_E, t_B = ks93(t_Q, t_U)
        s_E = jnp.fft.ifft2(filter_E * jnp.fft.fft2(t_E))
        s_B = jnp.fft.ifft2(filter_B * jnp.fft.fft2(t_B))

        # Back to (Q, U) representation
        s_q, s_u = ks93inv(s_E, s_B)
    return s_q, s_u
def preprocess_batch(rng_key, batch):
    """Create a noisy, masked Kaiser-Squires convergence map as model input.

    Steps, in the order the code performs them:
      1. Take the true convergence ``batch['x'][..., 0]``.
      2. Convert it to shear with ``ks93inv``, with the B-mode set to zero.
      3. Add noise. With ``FLAGS.gaussian_noise`` this is Gaussian noise scaled
         by ``std1``/``std2``. Otherwise it is COSMOS noise realisations built
         with ``random_rotations`` on ``cat_e1``/``cat_e2`` and binned with
         ``b2d``.
      4. Multiply by ``mask``.
      5. Map back to (E, B) convergence with ``ks93``.

    Returns
    -------
    ks_map : the two ``ks93`` outputs stacked on the last axis
    input_map : the noiseless input convergence
    """
    noise_key1, noise_key2 = jax.random.split(rng_key, 2)
    kappa = batch['x'][..., 0]
    shear1, shear2 = ks93inv(kappa, jnp.zeros_like(kappa))

    if not FLAGS.gaussian_noise:
        # COSMOS noise realisations
        rotated_e1, rotated_e2 = random_rotations(cat_e1, cat_e2, shear1.shape[0], rng_key)
        binned_e1 = jax.vmap(b2d)(v=rotated_e1)
        binned_e2 = jax.vmap(b2d)(v=rotated_e2)
        shear1 = mask * (shear1 + binned_e1)
        shear2 = mask * (shear2 + binned_e2)
    else:
        # Gaussian noise with per-pixel std, then mask
        shear1 = mask * (shear1 + std1 * jax.random.normal(noise_key1, shear1.shape))
        shear2 = mask * (shear2 + std2 * jax.random.normal(noise_key2, shear2.shape))

    return jnp.stack(ks93(shear1, shear2), axis=-1), kappa
def square_norm(kappa, g1, g2):
    """Sum of squared residuals between the observed shear and the shear implied by kappa.

    ``kappa`` is the pair ``(kE, kB)``; the model shear is ``ks93inv(kE, kB)``.
    """
    kE, kB = kappa
    pred_g1, pred_g2 = ks93inv(kE, kB)
    diff1 = g1 - pred_g1
    diff2 = g2 - pred_g2
    return np.sum(diff1 * diff1) + np.sum(diff2 * diff2)
def least_squares(g1, g2, kE, kB):
    """2-norm of the residual between the observed shear and ``ks93inv(kE, kB)``.

    Observed and model components are stacked with ``np.vstack``.
    ``np.linalg.norm`` with default ``ord`` then gives the Frobenius norm for
    2-D input (the 2-norm of the flattened array).
    """
    model_g1, model_g2 = ks93inv(kE, kB)
    observed = np.vstack([g1, g2])
    predicted = np.vstack([model_g1, model_g2])
    return np.linalg.norm(observed - predicted)
if __name__ == '__main__':
    # Round-trip check of ks93 / ks93inv on a random shear field, followed by
    # an inverse-problem reconstruction of the convergence.
    key = random.PRNGKey(0)
    # (g1, g2) should in practice be measurements from a real galaxy survey
    g1, g2 = 0.1 * random.normal(key, (2, 32, 32)) + 0.1 * np.ones((2, 32, 32))
    kE, kB = ks93(g1, g2)

    # Computing shear from convergence (least_squares is the module-level helper)
    gamma1, gamma2 = ks93inv(kE, kB)
    print('gamma1.shape', gamma1.shape)
    print('gamma error', least_squares(g1, g2, kE, kB))

    # Computing convergence from shear
    kappaE, kappaB = ks93(gamma1, gamma2)
    print('kappaE.mean()', kappaE.mean(), 'should be 0')
    print('kappa error', np.linalg.norm(kE - kappaE))
    print('')
    print('Recovering kappa with SGD')
    kEhat, kBhat = inverse_problem(g1, g2, obj=square_norm_smooth, kappa_shape=kE.shape)
def main(_):
    """Sample convergence maps from the shear posterior, then denoise the samples.

    All configuration comes from the module-level ``FLAGS``. Pipeline:
      1. Build the shear likelihood score from the per-pixel noise maps
         ``FLAGS.std1`` / ``FLAGS.std2``.
      2. Build the prior score from a Haiku network (``forward_fn``). Its
         weights are unpickled from ``FLAGS.model_weights`` unless
         ``FLAGS.gaussian_only``. It is optionally combined with a Gaussian
         prior built from the power spectrum in ``FLAGS.gaussian_path``.
      3. Build the measured shear. It is either simulated from
         ``FLAGS.convergence`` (optional NFW cluster plus noise) or read from
         the hard-coded COSMOS FITS files. It is then multiplied by the
         ``FLAGS.mask`` map.
      4. Run ``tfp.mcmc.sample_chain`` with a ``TemperedMC`` kernel wrapping
         ``ScoreHamiltonianMonteCarlo``. The chain starts from a noisy
         Kaiser-Squires map.
      5. Integrate ``dx/dt = -0.5 * score(x, sigma=sqrt(t))`` with
         ``scipy.integrate.solve_ivp`` from the final noise level down to
         t = 1e-5.
      6. Write the HMC samples and the integrated ("denoised") samples to
         ``./results/<FLAGS.output_folder>/`` as FITS files.

    Args:
      _: unused positional argument (presumably argv from an absl ``app.run``
        entry point - TODO confirm).

    NOTE(review): the 360x360 map size is hard-coded throughout even though
    ``map_size`` is read from the mask.
    """
    # E-mode only; when True the state would carry (E, B) channels.
    b_mode = False

    # Per-pixel shear noise std for each component; expand_dims adds a trailing
    # channel axis (assumes 2-D FITS images - TODO confirm).
    std1 = jnp.expand_dims(fits.getdata(FLAGS.std1).astype('float32'), -1)
    std2 = jnp.expand_dims(fits.getdata(FLAGS.std2).astype('float32'), -1)
    sigma_gamma = jnp.concatenate([std1, std2], axis=-1)
    #fits.writeto("./sigma_gamma.fits", onp.array(sigma_gamma), overwrite=False)

    def log_likelihood(x, sigma, meas_shear, mask, sigma_mask):
        """Gaussian log-likelihood of the measured shear for a flattened kappa.

        The variance per pixel is sigma_gamma**2 + sigma**2 + sigma_mask, so the
        sampler noise level ``sigma`` inflates the likelihood width. ``mask``
        is unused here.
        """
        if b_mode:
            x = x.reshape((360, 360, 2))
            ke = x[..., 0]
            kb = x[..., 1]
        else:
            ke = x.reshape((360, 360))
            kb = jnp.zeros(ke.shape)

        model_shear = jnp.stack(ks93inv(ke, kb), axis=-1)

        return -jnp.sum((model_shear - meas_shear)**2 / ((sigma_gamma)**2 + sigma**2 + sigma_mask)) / 2.

    # Gradient w.r.t. x, batched over (x, sigma); shear, mask and sigma_mask are shared.
    likelihood_score = jax.vmap(jax.grad(log_likelihood), in_axes=[0, 0, None, None, None])

    map_size = fits.getdata(FLAGS.mask).astype('float32').shape[0]

    # Make the network
    #model = hk.transform_with_state(forward_fn)
    model = hk.without_apply_rng(hk.transform_with_state(forward_fn))

    rng_seq = hk.PRNGSequence(42)
    # Initialised with a 2-channel input: [kappa, scaled Gaussian prior score].
    params, state = model.init(next(rng_seq), jnp.zeros((1, map_size, map_size, 2)),
                               jnp.zeros((1, 1, 1, 1)), is_training=True)

    # Load the weights of the neural network.
    # NOTE(review): pickle.load executes arbitrary code; only load trusted files.
    if not FLAGS.gaussian_only:
        with open(FLAGS.model_weights, 'rb') as file:
            params, state, sn_state = pickle.load(file)

    # NOTE(review): unused below. Because the model is wrapped with
    # without_apply_rng, next(rng_seq) would be bound as the first data
    # argument of apply, not as an RNG.
    residual_prior_score = partial(model.apply, params, state, next(rng_seq), is_training=True)

    pixel_size = jnp.pi * FLAGS.resolution / 180. / 60.  # rad/pixel (FLAGS.resolution presumably in arcmin)

    # Load prior power spectrum: row 0 holds ell, row 1 the power spectrum
    ps_data = onp.load(FLAGS.gaussian_path).astype('float32')
    ell = jnp.array(ps_data[0, :])
    # Original note: "4th channel for massivenu" - presumably the row to use
    # instead of row 1 for a massive-neutrino spectrum (TODO confirm).
    ps_halofit = jnp.array(ps_data[1, :] / pixel_size**2)  # normalisation by pixel size
    # convert to pixel units of our simple power spectrum calculator
    kell = ell / 2 / jnp.pi * 360 * pixel_size / map_size
    # Interpolate the Power Spectrum in Fourier Space
    power_map = jnp.array(make_power_map(ps_halofit, map_size, kps=kell))

    # Build the measured shear: simulated from a noiseless convergence map,
    # or loaded from the COSMOS catalogue maps.
    if not FLAGS.COSMOS:
        print('i am here')
        convergence = fits.getdata(FLAGS.convergence).astype('float32')

        # Get the corresponding shear (B-mode set to zero)
        gamma1, gamma2 = ks93inv(convergence, onp.zeros_like(convergence))

        if not FLAGS.no_cluster:
            print('adding a cluster')
            # Compute NFW profile shear map
            g1_NFW, g2_NFW = gen_nfw_shear(x_cen=FLAGS.x_cluster, y_cen=FLAGS.y_cluster,
                                           resolution=FLAGS.resolution, nx=map_size,
                                           ny=map_size, z=FLAGS.z_halo,
                                           m=FLAGS.mass_halo, zs=FLAGS.zs)
            # Shear with added NFW cluster
            gamma1 += g1_NFW
            gamma2 += g2_NFW

            # Target convergence map with the added cluster
            #ke_cluster, kb_cluster = ks93(g1_cluster, g2_cluster)

        # Add noise to the shear map: a stored COSMOS noise realisation, or
        # Gaussian noise with the per-pixel std (fixed keys -> reproducible).
        if FLAGS.cosmos_noise_realisation:
            print('cosmos noise real')
            gamma1 += fits.getdata(FLAGS.cosmos_noise_e1).astype('float32')
            gamma2 += fits.getdata(FLAGS.cosmos_noise_e2).astype('float32')
        else:
            gamma1 += std1[..., 0] * jax.random.normal(
                jax.random.PRNGKey(42), gamma1.shape)  #onp.random.randn(map_size,map_size)
            gamma2 += std2[..., 0] * jax.random.normal(
                jax.random.PRNGKey(43), gamma2.shape)  #onp.random.randn(map_size,map_size)

        gamma = onp.stack(
            [gamma1, gamma2], -1)  # Shear is expected in the format [map_size,map_size,2]
    else:
        # Load the COSMOS shear maps (paths are hard-coded)
        g1 = fits.getdata('../data/COSMOS/cosmos_full_e1_0.29arcmin360.fits'
                          ).astype('float32').reshape([map_size, map_size, 1])
        g2 = fits.getdata('../data/COSMOS/cosmos_full_e2_0.29arcmin360.fits'
                          ).astype('float32').reshape([map_size, map_size, 1])
        gamma = onp.concatenate([g1, g2], axis=-1)

    mask = jnp.expand_dims(fits.getdata(FLAGS.mask).astype('float32'), -1)  # has shape [map_size,map_size,1]
    masked_true_shear = gamma * mask
    #fits.writeto("./input_shear.fits", onp.array(masked_true_shear), overwrite=False)

    # Adds 1e10 to the likelihood variance where mask == 0, so those pixels
    # contribute ~nothing (assumes a 0/1 mask - TODO confirm).
    sigma_mask = (1 - mask) * 1e10

    def score_fn(params, state, x, sigma, is_training=False):
        """Network prior score plus optional Gaussian prior scores.

        Returns ``(_, res, gs, gsb)``. ``res`` is the network output. ``gs`` and
        ``gsb`` are the E/B Gaussian prior scores, or zeros when unused.
        NOTE(review): ``_`` is main's closure variable, rebound later to the
        B-mode KS map by ``init_image, _ = ks93(...)``; callers discard it.
        """
        if b_mode:
            x = x.reshape((-1, 360, 360, 2))
            ke = x[..., 0]
            kb = x[..., 1]
        else:
            ke = x.reshape((-1, 360, 360))

        if FLAGS.gaussian_prior:
            # If requested, first compute the Gaussian prior
            gs = gaussian_prior_score(ke, sigma.reshape((-1, 1, 1)), power_map)
            gs = jnp.expand_dims(gs, axis=-1)
            #print((jnp.abs(sigma.reshape((-1,1,1,1)))**2).shape, (gs).shape)
            # Network input: kappa and sigma**2-scaled Gaussian score as 2 channels
            net_input = jnp.concatenate([
                ke.reshape((-1, 360, 360, 1)),
                jnp.abs(sigma.reshape((-1, 1, 1, 1)))**2 * gs
            ], axis=-1)
            res, state = model.apply(params, state, net_input,
                                     sigma.reshape((-1, 1, 1, 1)),
                                     is_training=is_training)
            if b_mode:
                gsb = gaussian_prior_score_b(kb, sigma.reshape((-1, 1, 1)))
                gsb = jnp.expand_dims(gsb, axis=-1)
            else:
                gsb = jnp.zeros_like(res)
        else:
            # NOTE(review): feeds a 1-channel input although the model was
            # initialised with 2 channels above - confirm this path works.
            res, state = model.apply(params, state,
                                     ke.reshape((-1, 360, 360, 1)),
                                     sigma.reshape((-1, 1, 1, 1)),
                                     is_training=is_training)
            gs = jnp.zeros_like(res)
            gsb = jnp.zeros_like(res)
        return _, res, gs, gsb

    score_fn = partial(score_fn, params, state)

    def score_prior(x, sigma):
        """Total prior score, flattened to (batch, 360*360) or (batch, 360*360, 2) in b_mode."""
        if b_mode:
            _, res, gaussian_score, gsb = score_fn(x.reshape(-1, 360, 360, 2), sigma.reshape(-1, 1, 1, 1))
        else:
            _, res, gaussian_score, gsb = score_fn(x.reshape(-1, 360, 360), sigma.reshape(-1, 1, 1))
        ke = (res[..., 0:1] + gaussian_score).reshape(-1, 360 * 360)
        kb = gsb[..., 0].reshape(-1, 360 * 360)
        if b_mode:
            return jnp.stack([ke, kb], axis=-1)
        else:
            return ke

    def total_score_fn(x, sigma):
        """Posterior score = likelihood score + prior score, flattened per batch element."""
        if b_mode:
            sl = likelihood_score(x, sigma, masked_true_shear, mask, sigma_mask).reshape(-1, 360 * 360, 2)
        else:
            sl = likelihood_score(x, sigma, masked_true_shear, mask, sigma_mask).reshape(-1, 360 * 360)
        sp = score_prior(x, sigma)
        if b_mode:
            return (sl + sp).reshape(-1, 360 * 360 * 2)
        else:
            return (sl + sp).reshape(-1, 360 * 360)
        #return (sp).reshape(-1, 360*360,2)

    # Prepare the input with a high noise level map
    initial_temperature = FLAGS.initial_temperature
    delta_tmp = initial_temperature  #onp.sqrt(initial_temperature**2 - 0.148**2)
    initial_step_size = FLAGS.initial_step_size  #0.018
    min_steps_per_temp = FLAGS.min_steps_per_temp  #10

    # Start from the Kaiser-Squires E-mode map of the masked shear. This also
    # rebinds main's `_` to the B-mode map.
    init_image, _ = ks93(mask[..., 0] * masked_true_shear[..., 0],
                         mask[..., 0] * masked_true_shear[..., 1])
    init_image = jnp.expand_dims(init_image, axis=0)
    init_image = jnp.repeat(init_image, FLAGS.batch_size, axis=0)
    # NOTE(review): unseeded numpy RNG, so the starting noise is not reproducible.
    init_image += (delta_tmp * onp.random.randn(FLAGS.batch_size, 360, 360))

    def make_kernel_fn(target_log_prob_fn, target_score_fn, sigma):
        # Step size scales as sqrt(max(sigma) / initial_temperature).
        return ScoreHamiltonianMonteCarlo(
            target_log_prob_fn=target_log_prob_fn,
            target_score_fn=target_score_fn,
            step_size=initial_step_size * (jnp.max(sigma) / initial_temperature)**0.5,
            num_leapfrog_steps=3,
            num_delta_logp_steps=4)

    # NOTE(review): the "inverse_temperatures" values are used as noise levels
    # sigma (see `noise` below), not as 1/T.
    tmc = TemperedMC(
        target_score_fn=total_score_fn,  #score_prior,
        inverse_temperatures=initial_temperature * jnp.ones([FLAGS.batch_size]),
        make_kernel_fn=make_kernel_fn,
        gamma=0.98,
        min_temp=8e-3,
        min_steps_per_temp=min_steps_per_temp,
        num_delta_logp_steps=4)

    num_burnin_steps = int(0)

    samples, trace = tfp.mcmc.sample_chain(
        num_results=2,  #FLAGS.num_steps,
        current_state=init_image.reshape([FLAGS.batch_size, -1]),
        kernel=tmc,
        num_burnin_steps=num_burnin_steps,
        num_steps_between_results=6000,  #num_results//FLAGS.num_steps,
        # trace[0]: acceptance flags, trace[1]: post-tempering levels, trace[2]: log accept ratio
        trace_fn=lambda _, pkr: (pkr.pre_tempering_results.is_accepted,
                                 pkr.post_tempering_inverse_temperatures,
                                 pkr.tempering_log_accept_ratio),
        seed=jax.random.PRNGKey(int(time.time())))

    # Keep only the last recorded state of each chain
    sol = samples[-1, ...].reshape(-1, 360, 360)

    from scipy import integrate

    @jax.jit
    def dynamics(t, x):
        # ODE right-hand side: -0.5 * posterior score at noise level sqrt(t)
        if b_mode:
            x = x.reshape([-1, 360, 360, 2])
            return -0.5 * total_score_fn(
                x, sigma=jnp.ones((FLAGS.batch_size, 1, 1, 1)) * jnp.sqrt(t)).reshape([-1])
        else:
            x = x.reshape([-1, 360, 360])
            return -0.5 * total_score_fn(
                x, sigma=jnp.ones((FLAGS.batch_size, 1, 1)) * jnp.sqrt(t)).reshape([-1])

    init_ode = sol

    # Final noise level reached by the tempering, averaged over chains
    last_trace = jnp.mean(trace[1][-1])
    noise = last_trace
    start_and_end_times = jnp.logspace(jnp.log10(0.99 * noise**2), -5, num=50)

    # Integrate t from noise**2 down to 1e-5; the last t_eval column is the result.
    solution = integrate.solve_ivp(dynamics, [noise**2, (1e-5)], init_ode.flatten(),
                                   t_eval=start_and_end_times)

    denoised = solution.y[:, -1].reshape([FLAGS.batch_size, 360, 360])

    # overwrite=False: fails if the output files already exist
    fits.writeto("./results/" + FLAGS.output_folder + "/samples_hmc_" + FLAGS.output_file + ".fits",
                 onp.array(sol), overwrite=False)
    fits.writeto("./results/" + FLAGS.output_folder + "/samples_denoised_" + FLAGS.output_file + ".fits",
                 onp.array(denoised), overwrite=False)

    print('end of sampling')
def spin_wiener_sampler(data_q, data_u, ncov_diag_Q, ncov_diag_U,
                        input_ps_map_E, input_ps_map_B, iterations=10,
                        initial_map=None, thinning=1, verbose=False):
    """Draw Wiener posterior samples with the Elsner-Wandelt messenger field method.

    Adapted for spin-2 fields (CMB polarization or galaxy weak lensing).

    Parameters
    ----------
    data_q : Q square image data (e.g. gamma1)
    data_u : U square image data (e.g. gamma2)
    ncov_diag_Q : Q noise variance per pixel (assumed uncorrelated)
    ncov_diag_U : U noise variance per pixel (assumed uncorrelated)
    input_ps_map_E : 1D E-mode signal power spectrum P(k) evaluated on the 2D
        (k1, k2) grid, as a square image (``fftshift`` is applied before use)
    input_ps_map_B : same as ``input_ps_map_E`` for the B-mode signal
    iterations : number of sampler iterations
    initial_map : starting state of the sampler. It is either a single map
        ``Q + 1j * U`` with the same ndim as ``data_q``, or a stack whose first
        two entries are the Q and U maps. Defaults to the data.
    thinning : keep one sample every ``thinning`` iterations (``iterations``
        must be divisible by ``thinning``)
    verbose : if True, print the mean B-mode signal variance and each
        iteration index

    Returns
    -------
    samples_E, samples_B : complex128 arrays of shape
        ``(iterations // thinning, size, size)`` with the E and B samples

    Raises
    ------
    ValueError : if ``iterations`` is not divisible by ``thinning``
    """
    if iterations % thinning != 0:
        raise ValueError(
            f"iterations ({iterations}) must be divisible by thinning ({thinning})")

    size = data_q.shape[0]
    # Messenger covariance: the smallest noise variance across Q and U.
    tcov_diag = np.min(np.array([ncov_diag_Q, ncov_diag_U]))
    tcov_ft = tcov_diag  # unnecessary really, but convention dependent
    scov_ft_E = np.fft.fftshift(input_ps_map_E)
    scov_ft_B = np.fft.fftshift(input_ps_map_B)

    # Conditional variances of the messenger field t (pixel space) and of the
    # signal s (Fourier space).
    sigma_t_squared_Q = tcov_diag - tcov_diag * tcov_diag / ncov_diag_Q
    sigma_t_squared_U = tcov_diag - tcov_diag * tcov_diag / ncov_diag_U
    sigma_s_squared_E = scov_ft_E * tcov_ft / (tcov_ft + scov_ft_E)
    sigma_s_squared_B = scov_ft_B * tcov_ft / (tcov_ft + scov_ft_B)
    if verbose:
        print(sigma_s_squared_B.mean())

    # Loop-invariant weights, filters and standard deviations.
    data_weight_Q = tcov_diag / ncov_diag_Q
    signal_weight_Q = (ncov_diag_Q - tcov_diag) / ncov_diag_Q
    data_weight_U = tcov_diag / ncov_diag_U
    signal_weight_U = (ncov_diag_U - tcov_diag) / ncov_diag_U
    std_t_Q = np.sqrt(sigma_t_squared_Q.real)
    std_t_U = np.sqrt(sigma_t_squared_U.real)
    filter_E = scov_ft_E / (scov_ft_E + tcov_ft)
    filter_B = scov_ft_B / (scov_ft_B + tcov_ft)
    # `size` factor: presumably the unnormalised-FFT scaling of the noise - TODO confirm
    std_s_E = np.sqrt(sigma_s_squared_E.real) * size
    std_s_B = np.sqrt(sigma_s_squared_B.real) * size

    # Keep Q and U as two separate maps. Indexing a single complex Q+iU map
    # with [0]/[1] would select image rows, not the Q and U components.
    if initial_map is None:
        s_q, s_u = data_q, data_u
    else:
        init = np.copy(initial_map)
        if init.ndim == np.ndim(data_q):
            s_q, s_u = init.real, init.imag
        else:
            s_q, s_u = init[0], init[1]

    n_samples = iterations // thinning
    samples_E = np.zeros(shape=(n_samples, size, size), dtype=np.complex128)
    samples_B = np.zeros(shape=(n_samples, size, size), dtype=np.complex128)

    for i in range(iterations):
        # Messenger field in (Q, U) representation: conditional mean, then a draw
        t_Q = data_weight_Q * data_q + signal_weight_Q * s_q
        t_U = data_weight_U * data_u + signal_weight_U * s_u
        t_Q = np.random.normal(t_Q.real, std_t_Q)
        t_U = np.random.normal(t_U.real, std_t_U)

        # Signal in (E, B) Fourier representation: Wiener mean plus fluctuation
        t_E, t_B = ks93(t_Q, t_U)
        s_E = filter_E * np.fft.fft2(t_E)
        s_B = filter_B * np.fft.fft2(t_B)
        s_E = np.random.normal(s_E.real * 0., std_s_E) + s_E
        s_B = np.random.normal(s_B.real * 0., std_s_B) + s_B

        # NOTE(review): real + imaginary parts are summed after the inverse FFT
        # (a Hartley-like real projection); kept as in the original - confirm intent.
        s_E = np.fft.ifft2(s_E)
        s_E = s_E.real + s_E.imag
        s_B = np.fft.ifft2(s_B)
        s_B = s_B.real + s_B.imag

        # Back to (Q, U) representation
        s_q, s_u = ks93inv(s_E, s_B)

        if i % thinning == 0:
            samples_E[i // thinning] = s_E
            samples_B[i // thinning] = s_B
        if verbose:
            print(i)
    return samples_E, samples_B