Пример #1
0
def square_norm_smooth(kappa, g1, g2, reg=1.):
    """Squared shear residual plus a squared-L2 penalty on the convergence.

    Parameters
    ----------
    kappa : pair ``(kE, kB)`` of E- and B-mode convergence maps
    g1, g2 : observed shear components
    reg : weight of the squared-L2 penalty on both convergence maps

    Returns
    -------
    ``sum((g - ks93inv(kappa))**2) + reg * (sum(kE**2) + sum(kB**2))``
    """
    kE, kB = kappa
    model_g1, model_g2 = ks93inv(kE, kB)
    resid1 = g1 - model_g1
    resid2 = g2 - model_g2
    data_term = np.sum(resid1 * resid1) + np.sum(resid2 * resid2)
    penalty = np.sum(kE * kE) + np.sum(kB * kB)
    return data_term + reg * penalty
Пример #2
0
def square_norm_sparse(kappa, g1, g2, reg=.01):
    """Squared shear residual plus an L1 (sparsity) penalty on the convergence.

    Parameters
    ----------
    kappa : pair ``(kE, kB)`` of E- and B-mode convergence maps
    g1, g2 : observed shear components
    reg : weight of the L1 penalty on both convergence maps

    Returns
    -------
    ``sum((g - ks93inv(kappa))**2) + reg * (sum(|kE|) + sum(|kB|))``
    """
    kE, kB = kappa
    model_g1, model_g2 = ks93inv(kE, kB)
    resid1 = g1 - model_g1
    resid2 = g2 - model_g2
    data_term = np.sum(resid1 * resid1) + np.sum(resid2 * resid2)
    penalty = np.sum(np.abs(kE)) + np.sum(np.abs(kB))
    return data_term + reg * penalty
Пример #3
0
    def log_likelihood(x, sigma, meas_shear, mask, sigma_mask):
        """Gaussian log-likelihood of the measured shear given a convergence map.

        ``x`` is a flattened 360x360 convergence map: E and B stacked on the
        last axis when ``b_mode`` is set, otherwise E only (B taken as zero).
        The model shear ``ks93inv(kE, kB)`` is compared to ``meas_shear`` with
        per-pixel variance ``sigma_gamma**2 + sigma**2 + sigma_mask``.
        ``mask`` is accepted but not used.
        """
        if not b_mode:
            kappa_e = x.reshape((360, 360))
            kappa_b = jnp.zeros(kappa_e.shape)
        else:
            kappa = x.reshape((360, 360, 2))
            kappa_e, kappa_b = kappa[..., 0], kappa[..., 1]

        predicted = jnp.stack(ks93inv(kappa_e, kappa_b), axis=-1)
        variance = (sigma_gamma)**2 + sigma**2 + sigma_mask
        chi2 = jnp.sum((predicted - meas_shear)**2 / variance)
        return -chi2 / 2.
Пример #4
0
    def body_fun(i, val):
        """Single messenger-field Wiener-filter iteration.

        The ``(i, val)`` signature matches a ``jax.lax.fori_loop`` body; the
        index ``i`` is not used. ``val`` is the current ``(s_q, s_u)`` signal
        estimate and the updated pair is returned. The operators
        ``q_operationA/B``, ``u_operationA/B``, the signal power maps and
        ``tcov_ft`` come from the enclosing scope.
        """
        prev_q, prev_u = val

        # Messenger field in Q, U representation.
        msg_q = q_operationA + jnp.multiply(q_operationB, prev_q)
        msg_u = u_operationA + jnp.multiply(u_operationB, prev_u)

        # Wiener-filter each mode in Fourier space (E, B representation).
        msg_e, msg_b = ks93(msg_q, msg_u)
        filt_e = jnp.fft.ifft2(
            (scov_ft_E / (scov_ft_E + tcov_ft)) * jnp.fft.fft2(msg_e))
        filt_b = jnp.fft.ifft2(
            (scov_ft_B / (scov_ft_B + tcov_ft)) * jnp.fft.fft2(msg_b))

        # Back to Q, U representation.
        new_q, new_u = ks93inv(filt_e, filt_b)
        return new_q, new_u
Пример #5
0
def spin_wiener_filter(data_q,
                       data_u,
                       ncov_diag_Q,
                       ncov_diag_U,
                       input_ps_map_E,
                       input_ps_map_B,
                       iterations=10):
    """
    Wiener filter Elsner-Wandelt messenger field adapted for spin-2 fields (CMB polarization or galaxy weak lensing)

    Parameters
    ----------
    data_q : Q square image data (e.g. gamma1)
    data_u : U square image data (e.g. gamma2)
    ncov_diag_Q : Q noise variance per pixel (assumed uncorrelated)
    ncov_diag_U : U noise variance per pixel (assumed uncorrelated)
    input_ps_map_E : 1D power P(k) for E-mode signal power spectrum evaluated 2D components k1,k2 as a square image
    input_ps_map_B : 1D power P(k) for B-mode signal power spectrum evaluated 2D components k1,k2 as a square image
    iterations : number of iterations

    Returns
    -------
    s_q,s_u : Wiener filtered q and u signals
    """
    # Messenger covariance T: the smallest noise variance over both
    # components, so that N - T is non-negative everywhere.
    tcov_diag = jnp.min(jnp.array([ncov_diag_Q, ncov_diag_U]))
    # Reorder the power maps to the fft2 frequency layout
    # (assumes the input maps are centred -- TODO confirm).
    scov_ft_E = jnp.fft.fftshift(input_ps_map_E)
    scov_ft_B = jnp.fft.fftshift(input_ps_map_B)

    # Loop-invariant quantities, computed once.
    q_data_term = (tcov_diag / ncov_diag_Q) * data_q
    u_data_term = (tcov_diag / ncov_diag_U) * data_u
    q_prev_weight = (ncov_diag_Q - tcov_diag) / ncov_diag_Q
    u_prev_weight = (ncov_diag_U - tcov_diag) / ncov_diag_U
    wiener_E = scov_ft_E / (scov_ft_E + tcov_diag)
    wiener_B = scov_ft_B / (scov_ft_B + tcov_diag)

    s_q = jnp.zeros(data_q.shape)
    s_u = jnp.zeros(data_q.shape)

    # Plain Python range: the counter is only used for looping, so there is
    # no need to materialise a device array and iterate its scalars.
    for _ in range(iterations):
        # in Q, U representation
        t_Q = q_data_term + q_prev_weight * s_q
        t_U = u_data_term + u_prev_weight * s_u
        # in E, B representation
        t_E, t_B = ks93(t_Q, t_U)
        s_E = jnp.fft.ifft2(wiener_E * jnp.fft.fft2(t_E))
        s_B = jnp.fft.ifft2(wiener_B * jnp.fft.fft2(t_B))
        # in Q, U representation
        s_q, s_u = ks93inv(s_E, s_B)

    return s_q, s_u
Пример #6
0
  def preprocess_batch(rng_key, batch):
    """Creates a noisy, masked Kaiser-Squires map as an input to the model.

    The map ``batch['x'][..., 0]`` is converted to shear with inverse
    Kaiser-Squires (``ks93inv``, zero B-mode). Noise is added and the mask is
    applied, and the result is mapped back with ``ks93``.

    Args:
      rng_key: JAX PRNG key; it is split and only its sub-keys are consumed.
      batch: mapping whose ``'x'`` entry holds the input maps in channel 0.

    Returns:
      ks_map: the two ``ks93`` outputs stacked on the last axis.
      input_map: the clean input maps.
    """
    key1, key2 = jax.random.split(rng_key, 2)
    input_map = batch['x'][..., 0]
    # Noiseless shear from the input map (inverse Kaiser-Squires).
    g1, g2 = ks93inv(input_map, jnp.zeros_like(input_map))

    if FLAGS.gaussian_noise:
      # Add Gaussian noise and mask
      g1 = mask * (g1 + std1 * jax.random.normal(key1, g1.shape))
      g2 = mask * (g2 + std2 * jax.random.normal(key2, g2.shape))
    else:
      # COSMOS noise realisations. Use a sub-key: ``rng_key`` was already
      # split above, and a key must not be reused once it has been split.
      random_e1, random_e2 = random_rotations(cat_e1, cat_e2, g1.shape[0],
                                              key1)
      noise_e1 = jax.vmap(b2d)(v=random_e1)
      noise_e2 = jax.vmap(b2d)(v=random_e2)

      g1 = mask * (g1 + noise_e1)
      g2 = mask * (g2 + noise_e2)

    # Kaiser-Squires reconstruction from the noisy, masked shear.
    ks_map = jnp.stack(ks93(g1, g2), axis=-1)
    return ks_map, input_map
Пример #7
0
def square_norm(kappa, g1, g2):
    """Sum of squared residuals between observed shear and ``ks93inv(kappa)``.

    Parameters
    ----------
    kappa : pair ``(kE, kB)`` of E- and B-mode convergence maps
    g1, g2 : observed shear components
    """
    kE, kB = kappa
    model_g1, model_g2 = ks93inv(kE, kB)
    residuals = (g1 - model_g1, g2 - model_g2)
    return sum(np.sum(r * r) for r in residuals)
Пример #8
0
 def least_squares(g1, g2, kE, kB):
     """Norm of the residual between observed shear and ``ks93inv(kE, kB)``.

     Observed and model components are each stacked with ``vstack`` and the
     ``np.linalg.norm`` of their difference is returned.
     """
     model_g1, model_g2 = ks93inv(kE, kB)
     observed = np.vstack([g1, g2])
     predicted = np.vstack([model_g1, model_g2])
     return np.linalg.norm(observed - predicted)
Пример #9
0
if __name__ == '__main__':

    def least_squares(g1, g2, kE, kB):
        """Norm of the stacked residual between (g1, g2) and ks93inv(kE, kB)."""
        model_g1, model_g2 = ks93inv(kE, kB)
        observed = np.vstack([g1, g2])
        predicted = np.vstack([model_g1, model_g2])
        return np.linalg.norm(observed - predicted)

    key = random.PRNGKey(0)

    # Synthetic shear maps (random values around 0.1); in practice (g1, g2)
    # would be measurements from a real galaxy survey.
    g1, g2 = 0.1 * random.normal(key, (2, 32, 32)) + 0.1 * np.ones((2, 32, 32))
    kE, kB = ks93(g1, g2)

    # Round trip 1: shear computed from the Kaiser-Squires convergence.
    gamma1, gamma2 = ks93inv(kE, kB)
    print('gamma1.shape', gamma1.shape)
    print('gamma error', least_squares(g1, g2, kE, kB))

    # Round trip 2: convergence computed from the reconstructed shear.
    kappaE, kappaB = ks93(gamma1, gamma2)
    print('kappaE.mean()', kappaE.mean(), 'should be 0')
    print('kappa error', np.linalg.norm(kE - kappaE))

    print('')
    print('Recovering kappa with SGD')

    kEhat, kBhat = inverse_problem(
        g1, g2, obj=square_norm_smooth, kappa_shape=kE.shape)
Пример #10
0
def main(_):
    """Sample convergence maps from a score-based posterior, then denoise them.

    Loads the shear noise maps, the prior power spectrum, the network weights,
    the mask and the shear data. The shear is either simulated (with optional
    NFW cluster and noise) or read from the COSMOS maps. It then runs tempered
    score-based HMC (``TemperedMC`` wrapping ``ScoreHamiltonianMonteCarlo``),
    and integrates an ODE driven by the total score from the final tempering
    level down to t = 1e-5 with ``scipy.integrate.solve_ivp``. The raw and the
    ODE-denoised samples are written to FITS files under
    ``./results/<output_folder>/``.

    NOTE(review): ``main(_)`` together with ``FLAGS`` suggests an absl-style
    entry point where ``_`` receives argv -- confirm.
    NOTE(review): many reshapes below hard-code 360x360 maps even though
    ``map_size`` is read from the mask; they presumably assume
    ``map_size == 360`` -- confirm before using other map sizes.
    """
    # Only the E-mode convergence is sampled; B-mode is fixed to zero.
    b_mode = False

    # Per-pixel shear noise standard deviations for each component, stacked
    # on the last axis (presumably shape [map_size, map_size, 2]).
    std1 = jnp.expand_dims(fits.getdata(FLAGS.std1).astype('float32'), -1)
    std2 = jnp.expand_dims(fits.getdata(FLAGS.std2).astype('float32'), -1)
    sigma_gamma = jnp.concatenate([std1, std2], axis=-1)

    #fits.writeto("./sigma_gamma.fits", onp.array(sigma_gamma), overwrite=False)
    def log_likelihood(x, sigma, meas_shear, mask, sigma_mask):
        """Gaussian log-likelihood at the level of the measured shear.

        ``x`` is a flattened convergence map (E only unless ``b_mode``). The
        per-pixel variance combines the shape noise ``sigma_gamma``, the
        current noise level ``sigma`` and ``sigma_mask``. ``mask`` is unused.
        """
        if b_mode:
            x = x.reshape((360, 360, 2))
            ke = x[..., 0]
            kb = x[..., 1]
        else:
            ke = x.reshape((360, 360))
            kb = jnp.zeros(ke.shape)

        model_shear = jnp.stack(ks93inv(ke, kb), axis=-1)

        return -jnp.sum((model_shear - meas_shear)**2 /
                        ((sigma_gamma)**2 + sigma**2 + sigma_mask)) / 2.

    # Gradient of the log-likelihood w.r.t. x, vectorised over a batch of
    # (x, sigma) pairs; shear, mask and sigma_mask are shared by the batch.
    likelihood_score = jax.vmap(jax.grad(log_likelihood),
                                in_axes=[0, 0, None, None, None])

    map_size = fits.getdata(FLAGS.mask).astype('float32').shape[0]

    # Make the network
    #model = hk.transform_with_state(forward_fn)
    model = hk.without_apply_rng(hk.transform_with_state(forward_fn))

    rng_seq = hk.PRNGSequence(42)
    # Initialise with dummy inputs: a 2-channel [1, map_size, map_size, 2]
    # map and a [1, 1, 1, 1] noise level.
    params, state = model.init(next(rng_seq),
                               jnp.zeros((1, map_size, map_size, 2)),
                               jnp.zeros((1, 1, 1, 1)),
                               is_training=True)

    # Load the weights of the neural network (when FLAGS.gaussian_only is
    # set, the randomly initialised params above are kept).
    if not FLAGS.gaussian_only:
        # NOTE(review): pickle.load can execute arbitrary code; only load
        # trusted weight files. ``sn_state`` is not used afterwards.
        with open(FLAGS.model_weights, 'rb') as file:
            params, state, sn_state = pickle.load(file)
        # NOTE(review): ``residual_prior_score`` is never used below. Since
        # ``model`` comes from hk.without_apply_rng, ``next(rng_seq)`` here
        # would bind the network input, not an rng -- looks stale.
        residual_prior_score = partial(model.apply,
                                       params,
                                       state,
                                       next(rng_seq),
                                       is_training=True)

    # Resolution (arcmin/pixel) converted to radians per pixel.
    pixel_size = jnp.pi * FLAGS.resolution / 180. / 60.  #rad/pixel
    # Load prior power spectrum: row 0 holds ell, row 1 the spectrum.
    ps_data = onp.load(FLAGS.gaussian_path).astype('float32')
    ell = jnp.array(ps_data[0, :])
    # 4th channel for massivenu
    # NOTE(review): the line below reads row 1, not a 4th channel -- confirm.
    ps_halofit = jnp.array(ps_data[1, :] /
                           pixel_size**2)  # normalisation by pixel size
    # convert to pixel units of our simple power spectrum calculator
    kell = ell / 2 / jnp.pi * 360 * pixel_size / map_size
    # Interpolate the Power Spectrum in Fourier Space
    power_map = jnp.array(make_power_map(ps_halofit, map_size, kps=kell))

    # Load the noiseless convergence map
    if not FLAGS.COSMOS:
        # NOTE(review): leftover debug print.
        print('i am here')
        convergence = fits.getdata(FLAGS.convergence).astype('float32')

        # Get the corresponding shear
        gamma1, gamma2 = ks93inv(convergence, onp.zeros_like(convergence))

        if not FLAGS.no_cluster:
            print('adding a cluster')
            # Compute NFW profile shear map
            g1_NFW, g2_NFW = gen_nfw_shear(x_cen=FLAGS.x_cluster,
                                           y_cen=FLAGS.y_cluster,
                                           resolution=FLAGS.resolution,
                                           nx=map_size,
                                           ny=map_size,
                                           z=FLAGS.z_halo,
                                           m=FLAGS.mass_halo,
                                           zs=FLAGS.zs)
            # Shear with added NFW cluster
            gamma1 += g1_NFW
            gamma2 += g2_NFW

            # Target convergence map with the added cluster
            #ke_cluster, kb_cluster = ks93(g1_cluster, g2_cluster)

        # Add noise to the shear map
        if FLAGS.cosmos_noise_realisation:
            print('cosmos noise real')
            gamma1 += fits.getdata(FLAGS.cosmos_noise_e1).astype('float32')
            gamma2 += fits.getdata(FLAGS.cosmos_noise_e2).astype('float32')

        else:
            # Gaussian noise scaled by the per-pixel std maps (fixed seeds).
            gamma1 += std1[..., 0] * jax.random.normal(
                jax.random.PRNGKey(42),
                gamma1.shape)  #onp.random.randn(map_size,map_size)
            gamma2 += std2[..., 0] * jax.random.normal(
                jax.random.PRNGKey(43),
                gamma2.shape)  #onp.random.randn(map_size,map_size)

        # Stack the two shear components
        gamma = onp.stack(
            [gamma1, gamma2],
            -1)  # Shear is expected in the format [map_size,map_size,2]

    else:

        # Load the COSMOS shear maps (hard-coded paths); the mask is loaded
        # below for both branches.
        g1 = fits.getdata('../data/COSMOS/cosmos_full_e1_0.29arcmin360.fits'
                          ).astype('float32').reshape([map_size, map_size, 1])
        g2 = fits.getdata('../data/COSMOS/cosmos_full_e2_0.29arcmin360.fits'
                          ).astype('float32').reshape([map_size, map_size, 1])
        gamma = onp.concatenate([g1, g2], axis=-1)

    mask = jnp.expand_dims(fits.getdata(FLAGS.mask).astype('float32'),
                           -1)  # has shape [map_size,map_size,1]

    masked_true_shear = gamma * mask
    #fits.writeto("./input_shear.fits", onp.array(masked_true_shear), overwrite=False)

    # Very large extra variance where mask == 0, so masked pixels contribute
    # almost nothing to the log-likelihood.
    sigma_mask = (1 - mask) * 1e10

    def score_fn(params, state, x, sigma, is_training=False):
        """Prior score from the network, optionally with a Gaussian prior term.

        Returns ``(_, res, gs, gsb)``, where:
          - ``res`` is the network output;
          - ``gs`` is the Gaussian E-mode prior score, or zeros when
            FLAGS.gaussian_prior is off;
          - ``gsb`` is the B-mode Gaussian prior score when both ``b_mode``
            and FLAGS.gaussian_prior are set, and zeros otherwise.

        NOTE(review): ``_`` is not a placeholder here. It resolves to main's
        variable ``_``, i.e. main's argument, which is later rebound by
        ``init_image, _ = ks93(...)``. Callers discard it. The updated
        ``state`` is dropped, and ``kb`` is only defined when ``b_mode`` is set.
        NOTE(review): the network was initialised with a 2-channel input; the
        1-channel path (FLAGS.gaussian_prior off) may not match the loaded
        weights -- confirm.
        """
        if b_mode:
            x = x.reshape((-1, 360, 360, 2))
            ke = x[..., 0]
            kb = x[..., 1]
        else:
            ke = x.reshape((-1, 360, 360))

        if FLAGS.gaussian_prior:
            # If requested, first compute the Gaussian prior
            gs = gaussian_prior_score(ke, sigma.reshape((-1, 1, 1)), power_map)
            gs = jnp.expand_dims(gs, axis=-1)
            #print((jnp.abs(sigma.reshape((-1,1,1,1)))**2).shape, (gs).shape)
            # Network input: the map and the sigma**2-scaled Gaussian score
            # as two channels.
            net_input = jnp.concatenate([
                ke.reshape((-1, 360, 360, 1)),
                jnp.abs(sigma.reshape((-1, 1, 1, 1)))**2 * gs
            ],
                                        axis=-1)
            res, state = model.apply(params,
                                     state,
                                     net_input,
                                     sigma.reshape((-1, 1, 1, 1)),
                                     is_training=is_training)
            if b_mode:
                gsb = gaussian_prior_score_b(kb, sigma.reshape((-1, 1, 1)))
                gsb = jnp.expand_dims(gsb, axis=-1)
            else:
                gsb = jnp.zeros_like(res)
        else:
            res, state = model.apply(params,
                                     state,
                                     ke.reshape((-1, 360, 360, 1)),
                                     sigma.reshape((-1, 1, 1, 1)),
                                     is_training=is_training)
            gs = jnp.zeros_like(res)
            gsb = jnp.zeros_like(res)
        return _, res, gs, gsb

    # Bind the network weights; from here on calls take (x, sigma).
    score_fn = partial(score_fn, params, state)

    def score_prior(x, sigma):
        """Prior score (network + Gaussian term) flattened to 360*360 per sample.

        Returns the E-mode score of shape [batch, 360*360], or E and B scores
        stacked on a trailing axis when ``b_mode`` is set.
        """
        if b_mode:
            _, res, gaussian_score, gsb = score_fn(x.reshape(-1, 360, 360, 2),
                                                   sigma.reshape(-1, 1, 1, 1))
        else:
            _, res, gaussian_score, gsb = score_fn(x.reshape(-1, 360, 360),
                                                   sigma.reshape(-1, 1, 1))
        ke = (res[..., 0:1] + gaussian_score).reshape(-1, 360 * 360)
        kb = gsb[..., 0].reshape(-1, 360 * 360)
        if b_mode:
            return jnp.stack([ke, kb], axis=-1)
        else:
            return ke

    def total_score_fn(x, sigma):
        """Posterior score: likelihood score plus prior score, flattened per sample."""
        if b_mode:
            sl = likelihood_score(x, sigma, masked_true_shear, mask,
                                  sigma_mask).reshape(-1, 360 * 360, 2)
        else:
            sl = likelihood_score(x, sigma, masked_true_shear, mask,
                                  sigma_mask).reshape(-1, 360 * 360)
        sp = score_prior(x, sigma)
        if b_mode:
            return (sl + sp).reshape(-1, 360 * 360 * 2)
        else:
            return (sl + sp).reshape(-1, 360 * 360)
        #return (sp).reshape(-1, 360*360,2)

    # Prepare the input with a high noise level map

    initial_temperature = FLAGS.initial_temperature
    delta_tmp = initial_temperature  #onp.sqrt(initial_temperature**2 - 0.148**2)
    initial_step_size = FLAGS.initial_step_size  #0.018
    min_steps_per_temp = FLAGS.min_steps_per_temp  #10
    # Starting point: Kaiser-Squires E map of the masked shear, repeated over
    # the batch, plus Gaussian noise with std ``initial_temperature``.
    # NOTE(review): this also rebinds main's ``_`` (see score_fn).
    init_image, _ = ks93(mask[..., 0] * masked_true_shear[..., 0],
                         mask[..., 0] * masked_true_shear[..., 1])
    init_image = jnp.expand_dims(init_image, axis=0)
    init_image = jnp.repeat(init_image, FLAGS.batch_size, axis=0)
    init_image += (delta_tmp * onp.random.randn(FLAGS.batch_size, 360, 360))

    def make_kernel_fn(target_log_prob_fn, target_score_fn, sigma):
        """Build the HMC kernel; step size scales as sqrt(max(sigma) / initial_temperature)."""
        return ScoreHamiltonianMonteCarlo(
            target_log_prob_fn=target_log_prob_fn,
            target_score_fn=target_score_fn,
            step_size=initial_step_size *
            (jnp.max(sigma) / initial_temperature)**0.5,
            num_leapfrog_steps=3,
            num_delta_logp_steps=4)

    # Tempered MCMC with every chain starting at ``initial_temperature``;
    # presumably the temperature is decreased by ``gamma`` down to
    # ``min_temp`` -- confirm against TemperedMC.
    tmc = TemperedMC(
        target_score_fn=total_score_fn,  #score_prior,
        inverse_temperatures=initial_temperature *
        jnp.ones([FLAGS.batch_size]),
        make_kernel_fn=make_kernel_fn,
        gamma=0.98,
        min_temp=8e-3,
        min_steps_per_temp=min_steps_per_temp,
        num_delta_logp_steps=4)

    num_burnin_steps = int(0)

    # Traced per result: acceptance flags, post-tempering temperatures and
    # the tempering log-accept ratio. The seed comes from the wall clock,
    # so runs are not reproducible.
    samples, trace = tfp.mcmc.sample_chain(
        num_results=2,  #FLAGS.num_steps,
        current_state=init_image.reshape([FLAGS.batch_size, -1]),
        kernel=tmc,
        num_burnin_steps=num_burnin_steps,
        num_steps_between_results=6000,  #num_results//FLAGS.num_steps,
        trace_fn=lambda _, pkr:
        (pkr.pre_tempering_results.is_accepted, pkr.
         post_tempering_inverse_temperatures, pkr.tempering_log_accept_ratio),
        seed=jax.random.PRNGKey(int(time.time())))

    # Last retained sample of every chain.
    sol = samples[-1, ...].reshape(-1, 360, 360)

    from scipy import integrate

    @jax.jit
    def dynamics(t, x):
        """ODE right-hand side: -0.5 * total score at noise level sigma = sqrt(t).

        NOTE(review): presumably the probability-flow ODE of a
        variance-exploding diffusion with sigma**2 = t -- confirm.
        """
        if b_mode:
            x = x.reshape([-1, 360, 360, 2])
            return -0.5 * total_score_fn(
                x, sigma=jnp.ones(
                    (FLAGS.batch_size, 1, 1, 1)) * jnp.sqrt(t)).reshape([-1])
        else:
            x = x.reshape([-1, 360, 360])
            return -0.5 * total_score_fn(
                x, sigma=jnp.ones(
                    (FLAGS.batch_size, 1, 1)) * jnp.sqrt(t)).reshape([-1])

    init_ode = sol

    # Final tempering level (trace index 1 = post_tempering_inverse_temperatures
    # of the last result), averaged and used as the ODE's starting noise level.
    last_trace = jnp.mean(trace[1][-1])
    noise = last_trace
    # Integrate from t = noise**2 down to t = 1e-5; only the last evaluation
    # point is kept below.
    start_and_end_times = jnp.logspace(jnp.log10(0.99 * noise**2), -5, num=50)

    solution = integrate.solve_ivp(dynamics, [noise**2, (1e-5)],
                                   init_ode.flatten(),
                                   t_eval=start_and_end_times)

    denoised = solution.y[:, -1].reshape([FLAGS.batch_size, 360, 360])

    fits.writeto("./results/" + FLAGS.output_folder + "/samples_hmc_" +
                 FLAGS.output_file + ".fits",
                 onp.array(sol),
                 overwrite=False)
    fits.writeto("./results/" + FLAGS.output_folder + "/samples_denoised_" +
                 FLAGS.output_file + ".fits",
                 onp.array(denoised),
                 overwrite=False)

    print('end of sampling')
Пример #11
0
def spin_wiener_sampler(data_q,
                        data_u,
                        ncov_diag_Q,
                        ncov_diag_U,
                        input_ps_map_E,
                        input_ps_map_B,
                        iterations=10,
                        initial_map=None,
                        thinning=1,
                        verbose=False):
    """
    Wiener posterior sampler using Elsner-Wandelt messenger field adapted for spin-2 fields (CMB polarization or galaxy weak lensing)

    Parameters
    ----------
    data_q : Q square image data (e.g. gamma1)
    data_u : U square image data (e.g. gamma2)
    ncov_diag_Q : Q noise variance per pixel (assumed uncorrelated)
    ncov_diag_U : U noise variance per pixel (assumed uncorrelated)
    input_ps_map_E : 1D power P(k) for E-mode signal power spectrum evaluated 2D components k1,k2 as a square image
    input_ps_map_B : 1D power P(k) for B-mode signal power spectrum evaluated 2D components k1,k2 as a square image
    iterations : number of iterations
    initial_map : starting (Q, U) signal for the sampler. Either a pair or an
        array of shape (2, size, size), or a single 2D map read as Q + iU.
        Defaults to the data.
    thinning : thinning factor (iterations must be divisible by thinning factor)
    verbose : if True, print a diagnostic and the index of every stored sample

    Returns
    -------
    samples_E, samples_B : complex128 arrays of shape
        (iterations // thinning, size, size) with samples from Wiener posterior
    """
    size = data_q.shape[0]
    tcov_diag = np.min(np.array([ncov_diag_Q, ncov_diag_U]))
    tcov_ft = tcov_diag  # unnecessary really, but convention dependent
    scov_ft_E = np.fft.fftshift(input_ps_map_E)
    scov_ft_B = np.fft.fftshift(input_ps_map_B)
    # Variances of the messenger (t | s) and signal (s | t) conditional draws.
    sigma_t_squared_Q = tcov_diag - tcov_diag * tcov_diag / ncov_diag_Q
    sigma_t_squared_U = tcov_diag - tcov_diag * tcov_diag / ncov_diag_U
    sigma_s_squared_E = scov_ft_E * tcov_ft / (tcov_ft + scov_ft_E)
    sigma_s_squared_B = scov_ft_B * tcov_ft / (tcov_ft + scov_ft_B)

    if verbose:
        print(sigma_s_squared_B.mean())

    # The loop reads s[0] / s[1] as the Q / U maps, so the state must be a
    # (Q, U) pair. Indexing a single complex map would select its first two
    # ROWS instead.
    if initial_map is None:
        s = (data_q, data_u)
    else:
        init = np.array(initial_map)  # copy, never alias the caller's data
        if init.ndim == 2:
            s = (init.real, init.imag)  # single map interpreted as Q + iU
        else:
            s = (init[0], init[1])

    assert (iterations % thinning == 0)

    # Loop-invariant quantities, computed once.
    q_data_term = (tcov_diag / ncov_diag_Q) * data_q
    u_data_term = (tcov_diag / ncov_diag_U) * data_u
    q_prev_weight = (ncov_diag_Q - tcov_diag) / ncov_diag_Q
    u_prev_weight = (ncov_diag_U - tcov_diag) / ncov_diag_U
    std_t_Q = np.sqrt(sigma_t_squared_Q.real)
    std_t_U = np.sqrt(sigma_t_squared_U.real)
    wiener_E = scov_ft_E / (scov_ft_E + tcov_ft)
    wiener_B = scov_ft_B / (scov_ft_B + tcov_ft)
    # Noise std in Fourier space; the factor ``size`` presumably compensates
    # for numpy's unnormalised forward fft2 -- confirm.
    std_s_E = np.sqrt(sigma_s_squared_E.real) * size
    std_s_B = np.sqrt(sigma_s_squared_B.real) * size

    n_samples = iterations // thinning
    samples_E = np.zeros(shape=(n_samples, size, size), dtype=np.complex128)
    samples_B = np.zeros(shape=(n_samples, size, size), dtype=np.complex128)

    for i in range(iterations):
        # Messenger field draw, in Q, U representation
        t_Q = q_data_term + q_prev_weight * s[0]
        t_U = u_data_term + u_prev_weight * s[1]
        t_Q = np.random.normal(t_Q.real, std_t_Q)
        t_U = np.random.normal(t_U.real, std_t_U)
        # Signal draw, in E, B representation (Fourier space)
        t_E, t_B = ks93(t_Q, t_U)
        s_E = wiener_E * np.fft.fft2(t_E)
        s_B = wiener_B * np.fft.fft2(t_B)

        s_E = np.random.normal(s_E.real * 0., std_s_E) + s_E
        s_B = np.random.normal(s_B.real * 0., std_s_B) + s_B
        # Real map from real + imaginary parts of the inverse FFT
        # (kept from the original algorithm).
        s_E = np.fft.ifft2(s_E)
        s_E = s_E.real + s_E.imag
        s_B = np.fft.ifft2(s_B)
        s_B = s_B.real + s_B.imag
        # Back to Q, U for the next messenger step
        s = ks93inv(s_E, s_B)
        if i % thinning == 0:
            samples_E[i // thinning] = s_E
            samples_B[i // thinning] = s_B
            if verbose:
                print(i)
    return samples_E, samples_B