Exemple #1
0
def test_tomographic_kernel():
    """Plot one column of ``tomographic_kernel`` over a cloud of directions.

    A set of 300 unit direction vectors is built from two independent uniform
    draws (scaled by 4 * pi / 180) plus a unit third component, and compared
    against a single direction ``[0, 0, 1]``.  The resulting kernel column is
    shown as a coloured scatter plot; nothing is asserted or returned.
    """
    import pylab as plt
    from jax import random
    from jaxns.gaussian_process.kernels import RBF

    def _unit_rows(vectors):
        # Scale every row to unit Euclidean length.
        return vectors / jnp.linalg.norm(vectors, axis=-1, keepdims=True)

    def _offsets(seed, count):
        # Uniform draw in [-1, 1) scaled by 4 * pi / 180.
        return 4. * jnp.pi / 180. * random.uniform(
            random.PRNGKey(seed), shape=(count, ), minval=-1, maxval=1)

    # First set: a single antenna row and many scattered directions.
    num_scattered = 300
    a1 = jnp.array([[-1, 0., 0.]])
    k1 = _unit_rows(
        jnp.stack([
            _offsets(0, num_scattered),
            _offsets(1, num_scattered),
            jnp.ones(num_scattered),
        ],
                  axis=1))

    # Second set: a single antenna row and a single direction [0, 0, 1].
    num_single = 1
    a2 = jnp.array([[1., 0., 0.]])
    k2 = _unit_rows(
        jnp.stack([
            jnp.zeros(num_single),
            jnp.zeros(num_single),
            jnp.ones(num_single),
        ],
                  axis=1))

    origin = jnp.zeros(3)
    K = tomographic_kernel(a1,
                           a2,
                           k1,
                           k2,
                           origin,
                           RBF(),
                           height=10.,
                           width=2.,
                           l=1.,
                           sigma=1.,
                           S=25)

    # Colour each scattered direction by its kernel value against k2.
    points = plt.scatter(k1[:, 0], k1[:, 1], c=K[:, 0])
    plt.colorbar(points)
    plt.show()
Exemple #2
0
def rbf_dtec(nant, ndir, ntime, height, width, sigma, l, uncert, v):
    """
    Draw a differential-TEC sample from a tomographic RBF kernel with a
    moving (frozen-flow) screen.

    In frozen flow the screen moves with velocity v,
        fed(x, t) = fed(x - v * t, 0)
    so that the tomographic kernel transforms as
        K(x1, k1, t1, x2, k2, t2) = K(x1 - v * t1, k1, 0, x2 - v * t2, k2, 0)

    Args:
        nant: number of antenna rows to draw.
        ndir: number of direction rows to draw.
        ntime: number of time steps; steps are 30 (seconds) apart.
        height, width, l, sigma: forwarded positionally to the kernel call in
            the order (height, width, l, sigma).
        uncert: scale applied to the normal noise in the noisy output.
        v: velocity; it is concatenated with three zeros and subtracted from
            the first six coordinate columns, so it must have length 3.

    Returns:
        Tuple ``(X, dtec, dtec_noisy)`` where ``dtec`` is flattened from
        shape ``(nant, ndir, ntime)`` after subtracting the first antenna.

    Side effects:
        Shows three pylab figures (kernel image, sqrt of its diagonal, dtec).
    """
    import pylab as plt

    # NOTE(review): antennas and directions are both drawn with PRNGKey(0);
    # kept as-is to preserve the generated data -- confirm this is intended.
    antenna_xy = 10. * random.uniform(random.PRNGKey(0), shape=(nant, 2))
    antennas = jnp.concatenate([antenna_xy, jnp.zeros((nant, 1))], axis=1)

    direction_xy = 4. * jnp.pi / 180. * random.uniform(
        random.PRNGKey(0), shape=(ndir, 2), minval=-1, maxval=1)
    directions = jnp.concatenate([direction_xy, jnp.ones((ndir, 1))], axis=1)
    directions = directions / jnp.linalg.norm(
        directions, axis=1, keepdims=True)

    times = jnp.arange(ntime)[:, None] * 30.  # seconds

    X = make_coord_array(antennas, directions, times)
    x0 = X[0, 0:3]
    time_column = X[:, 6:7]

    # Shift only the position columns by v * t; direction columns get zeros.
    drift = jnp.concatenate([v, jnp.zeros(3)])
    shifted = X[:, :6] - drift * time_column

    kernel = TomographicKernel(x0, RBF(), S_marg=100, S_gamma=100)
    K = kernel(shifted, shifted, height, width, l, sigma)

    plt.imshow(K)
    plt.colorbar()
    plt.show()
    plt.plot(jnp.sqrt(jnp.diag(K)))
    plt.show()

    # Matrix square root of the covariance, used to colour white noise.
    L = msqrt(K)
    white = random.normal(random.PRNGKey(2), shape=(L.shape[0], ))
    tec = (L @ white).reshape((nant, ndir, ntime))

    # Reference everything to the first antenna, then flatten.
    dtec = (tec - tec[0]).reshape((-1, ))
    plt.plot(dtec)
    plt.show()

    noisy = dtec + uncert * random.normal(random.PRNGKey(3), shape=dtec.shape)
    return X, dtec, noisy
Exemple #3
0
def rbf_dtec(nant, ndir, height, width, sigma, l, uncert=1.):
    """
    Draw a differential-TEC sample from a tomographic RBF kernel and convert
    it to cos/sin phase observables over 24 frequencies.

    Args:
        nant: number of antenna rows to draw.
        ndir: number of direction rows to draw.
        height, width, l, sigma: forwarded positionally to the kernel call in
            the order (height, width, l, sigma).
        uncert: scale applied to the normal noise in the noisy output.

    Returns:
        Tuple ``(X, dtec, Y, Y_noisy, tec_conv)`` where ``Y`` concatenates
        ``cos(phase)`` and ``sin(phase)`` along the last axis and
        ``tec_conv`` is ``TEC_CONV / freqs``.

    Side effects:
        Shows three pylab figures (kernel image, sqrt of its diagonal, dtec).
    """
    import pylab as plt

    # NOTE(review): antennas and directions are both drawn with PRNGKey(0);
    # kept as-is to preserve the generated data -- confirm this is intended.
    antenna_xy = 10. * random.uniform(random.PRNGKey(0), shape=(nant, 2))
    antennas = jnp.concatenate([antenna_xy, jnp.zeros((nant, 1))], axis=1)

    direction_xy = 4. * jnp.pi / 180. * random.uniform(
        random.PRNGKey(0), shape=(ndir, 2), minval=-1, maxval=1)
    directions = jnp.concatenate([direction_xy, jnp.ones((ndir, 1))], axis=1)
    directions = directions / jnp.linalg.norm(
        directions, axis=1, keepdims=True)

    X = make_coord_array(antennas, directions)
    x0 = X[0, 0:3]

    kernel = TomographicKernel(x0, RBF(), S_marg=100, S_gamma=100)
    K = kernel(X, X, height, width, l, sigma)

    plt.imshow(K)
    plt.colorbar()
    plt.show()
    plt.plot(jnp.sqrt(jnp.diag(K)))
    plt.show()

    # Matrix square root of the covariance, used to colour white noise.
    L = msqrt(K)
    white = random.normal(random.PRNGKey(2), shape=(L.shape[0], ))
    tec = (L @ white).reshape((nant, ndir))

    # Reference everything to the first antenna, then flatten.
    dtec = jnp.reshape(tec - tec[0], (-1, ))

    TEC_CONV = -8.4479745e6  # mTECU/Hz
    freqs = jnp.linspace(121e6, 168e6, 24)
    tec_conv = TEC_CONV / freqs
    phase = dtec[:, None] * tec_conv
    Y = jnp.concatenate([jnp.cos(phase), jnp.sin(phase)], axis=-1)

    plt.plot(dtec)
    plt.show()

    Y_noisy = Y + uncert * random.normal(random.PRNGKey(3), shape=Y.shape)
    return X, dtec, Y, Y_noisy, tec_conv
def test_tomographic_kernel():
    """Build a tomographic kernel from an example datapack and plot a sample.

    Antennas and directions are read from solset ``'sol000'``, moved into an
    ENU frame anchored at the first antenna, and combined into coordinates
    for antenna index 50 against every direction.  The kernel matrix is shown
    as an image, one sample is drawn through its Cholesky factor, and the
    sample is drawn with ``plot_vornoi_map``.  Nothing is asserted or
    returned.
    """
    dp = make_example_datapack(500, 24, 1, clobber=True)
    with dp:
        dp.current_solset = 'sol000'
        dp.select(pol=slice(0, 1, 1), ant=slice(0, None, 1))
        _, axes = dp.tec
        _, directions = dp.get_directions(axes['dir'])
        _, antennas = dp.get_antennas(axes['ant'])
        _, times = dp.get_times(axes['time'])

    # Express antennas and directions in an ENU frame at the first antenna.
    epoch = times[0]
    antennas = ac.ITRS(*antennas.cartesian.xyz, obstime=epoch)
    frame = ENU(obstime=epoch, location=antennas[0].earth_location)
    antennas = antennas.transform_to(frame)
    directions = directions.transform_to(frame)
    ref_ant = antennas[0]

    x = antennas.cartesian.xyz.to(au.km).value.T
    k = directions.cartesian.xyz.value.T
    X = make_coord_array(x[50:51, :], k)
    x0 = ref_ant.cartesian.xyz.to(au.km).value
    print(k.shape)

    kernel = TomographicKernel(x0, x0, RBF(), S_marg=25)

    @jit
    def _gram(coords):
        return kernel(coords,
                      coords,
                      bottom=200.,
                      width=50.,
                      fed_kernel_params=dict(l=7., sigma=1.))

    K = _gram(jnp.asarray(X))
    plt.imshow(K)
    plt.colorbar()
    plt.show()

    # Small diagonal jitter before factorising.
    L = jnp.linalg.cholesky(K + 1e-6 * jnp.eye(K.shape[0]))
    print(L)
    dtec = L @ random.normal(random.PRNGKey(24532), shape=(K.shape[0], ))
    print(jnp.std(dtec))

    ax = plot_vornoi_map(k[:, 0:2], dtec)
    ax.set_xlabel(r"$k_{\rm east}$")
    ax.set_ylabel(r"$k_{\rm north}$")
    ax.set_xlim(-0.1, 0.1)
    ax.set_ylim(-0.1, 0.1)
    plt.show()
Exemple #5
0
    #     plt.errorbar(n, results.logZ, yerr=results.logZerr)


    # NOTE(review): the start of this function is not visible here; the
    # names kernel, results, Y and Y_obs come from the part that is cut off.
    plt.title("Kernel: {}".format(kernel.__class__.__name__))
    plt.ylabel('log Z')
    plt.show()

    # Error-bar size for the marginalised prediction.
    # NOTE(review): this subtracts predict_f**2 from 'predict_fvar', i.e. it
    # treats 'predict_fvar' as a second moment -- confirm against how the
    # predictor is defined in the truncated part of this function.
    fstd = jnp.sqrt(results.marginalised['predict_fvar'] - results.marginalised['predict_f'] ** 2)
    plt.scatter(jnp.arange(Y.size),Y_obs, marker='+', label='data')
    plt.scatter(jnp.arange(Y.size),Y, marker="o", label='underlying')
    # NOTE(review): this series is the marginalised 'predict_f' but reuses
    # the label 'underlying' from the line above, so the legend shows the
    # same entry twice -- confirm the intended label.
    plt.scatter(jnp.arange(Y.size), results.marginalised['predict_f'], marker=".", label='underlying')
    plt.errorbar(jnp.arange(Y.size), results.marginalised['predict_f'], yerr=fstd, label='marginalised')
    plt.title("Kernel: {}".format(kernel.__class__.__name__))
    plt.legend()
    plt.show()

    # plot_samples_development(results,save_name='./ray_integral_solution.mp4')
    plot_diagnostics(results)
    plot_cornerplot(results)
    # Hand back the log-evidence and its error to the caller.
    return results.logZ, results.logZerr


# Script entry point: run the example once with an RBF kernel and keep the
# two values that main() returns.
if __name__ == '__main__':
    logZ_rbf, logZerr_rbf = main(RBF())
    # Disabled comparison against an M12 kernel, with an error-bar plot of
    # both log Z values.
    # logZ_m12, logZerr_m12 = main(M12())
    # plt.errorbar(['rbf', 'm12'],
    #              [logZ_rbf, logZ_m12],
    #              [logZerr_rbf, logZerr_m12])
    # plt.ylabel("log Z")
    # plt.show()
Exemple #6
0
def main(kernel):
    """Fit a Gaussian-process model with nested sampling and plot the result.

    Synthetic data are drawn from an RBF prior on 100 points in [-2, 2],
    perturbed with normal noise, and the samples with index 51..59 are
    replaced by unscaled noise.  A ``NestedSampler`` is then run twice (to
    time it with and without compilation) using ``kernel`` inside a
    ``GaussianProcessKernelPrior``.

    Args:
        kernel: kernel object handed to ``GaussianProcessKernelPrior``.

    Returns:
        Tuple ``(results.logZ, results.logZerr)`` from the last run.

    Side effects:
        Prints timings and efficiencies and shows several figures.
    """
    print(("Working on Kernel: {}".format(kernel.__class__.__name__)))

    def log_normal(x, mean, cov):
        # Multivariate normal log-density via the Cholesky factor of cov.
        chol = jnp.linalg.cholesky(cov)
        # Sum of log-diagonal of the factor equals half of log|cov|.
        half_log_det = jnp.sum(jnp.log(jnp.diag(chol)))
        whitened = solve_triangular(chol, x - mean, lower=True)
        return -0.5 * x.size * jnp.log(2. * jnp.pi) \
               - half_log_det \
               - 0.5 * (whitened @ whitened)

    # ---- synthetic data -------------------------------------------------
    N = 100
    X = jnp.linspace(-2., 2., N)[:, None]
    true_sigma, true_l, true_uncert = 1., 0.2, 0.2
    data_mu = jnp.zeros((N, ))
    prior_cov = RBF()(X, X, true_l, true_sigma) + 1e-13 * jnp.eye(N)
    Y = jnp.linalg.cholesky(prior_cov) @ random.normal(random.PRNGKey(0),
                                                       shape=(N, )) + data_mu

    # The same key and shape are used for the noise and for the replaced
    # samples, so a single draw serves both purposes.
    noise = random.normal(random.PRNGKey(1), shape=(N, ))
    index = jnp.arange(N)
    replaced = (index > 50) & (index < 60)
    Y_obs = jnp.where(replaced, noise, Y + true_uncert * noise)

    # ---- likelihood and predictors --------------------------------------
    def _noise_cov(uncert):
        # Diagonal observation covariance uncert**2 * I.
        return jnp.square(uncert) * jnp.eye(X.shape[0])

    def log_likelihood(K, uncert, **kwargs):
        """
        Log-density of Y_obs under a zero-mean normal with covariance
        K + uncert**2 * I.
        """
        mu = jnp.zeros_like(Y_obs)
        return log_normal(Y_obs, mu, K + _noise_cov(uncert))

    def predict_f(K, uncert, **kwargs):
        # Posterior mean of the latent function at the training inputs.
        mu = jnp.zeros_like(Y_obs)
        return mu + K @ jnp.linalg.solve(K + _noise_cov(uncert), Y_obs)

    def predict_fvar(K, uncert, **kwargs):
        # Diagonal of the posterior covariance at the training inputs.
        return jnp.diag(K - K @ jnp.linalg.solve(K + _noise_cov(uncert), K))

    # ---- priors and sampler ---------------------------------------------
    l_prior = UniformPrior('l', 0., 2.)
    uncert_prior = UniformPrior('uncert', 0., 2.)
    sigma_prior = UniformPrior('sigma', 0., 2.)
    cov_prior = GaussianProcessKernelPrior('K', kernel, X, l_prior,
                                           sigma_prior)
    prior_chain = PriorChain().push(uncert_prior).push(cov_prior)

    ns = NestedSampler(log_likelihood,
                       prior_chain,
                       sampler_name='multi_ellipsoid',
                       predict_f=predict_f,
                       predict_fvar=predict_fvar)

    def run_with_n(n):
        @jit
        def run(key):
            return ns(key=key,
                      num_live_points=n,
                      max_samples=1e5,
                      collect_samples=True,
                      termination_frac=0.01,
                      stoachastic_uncertainty=False,
                      sampler_kwargs=dict(depth=3))

        # First pass includes compilation, second pass reuses it.
        timing_messages = (
            "Time to execute (including compile): {}",
            "Time to execute (not including compile): {}",
        )
        for message in timing_messages:
            started = default_timer()
            results = run(random.PRNGKey(6))
            print(results.efficiency)
            print(message.format(default_timer() - started))
        return results

    for n in (100, ):
        results = run_with_n(n)
        plt.scatter(n, results.logZ)
        plt.errorbar(n, results.logZ, yerr=results.logZerr)
    plt.title("Kernel: {}".format(kernel.__class__.__name__))
    plt.ylabel('log Z')
    plt.show()

    # ---- posterior plots ------------------------------------------------
    inputs = X[:, 0]
    post_mean = results.marginalised['predict_f']
    plt.scatter(inputs, Y_obs, label='data')
    plt.plot(inputs, Y, label='underlying')
    plt.plot(inputs, post_mean, label='marginalised')
    # One-sigma band, upper edge first and then lower edge.
    plt.plot(inputs,
             post_mean + jnp.sqrt(results.marginalised['predict_fvar']),
             ls='dotted',
             c='black')
    plt.plot(inputs,
             post_mean - jnp.sqrt(results.marginalised['predict_fvar']),
             ls='dotted',
             c='black')
    plt.title("Kernel: {}".format(kernel.__class__.__name__))
    plt.legend()
    plt.show()

    plot_diagnostics(results)
    plot_cornerplot(results)
    return results.logZ, results.logZerr
Exemple #7
0
def main():
    """Set up a nested-sampling model with two power-law spectral components.

    The likelihood models each image as
    ``S_mw_150 * (freq / 150e6)**alpha_mw + S_cw_150 * (freq / 150e6)**alpha_cw``
    with independent normal noise of scale ``uncert``.

    NOTE(review): as visible here this is an unfinished sketch -- ``images``
    and ``freqs`` are not defined in this function, ``run_with_n`` is defined
    but never called, and nothing is returned.
    """
    def log_normal(x, mean, uncert):
        # Log-density of independent normals with a shared scale `uncert`.
        dx = x - mean
        dx = dx / uncert
        return -0.5 * x.size * jnp.log(
            2. * jnp.pi) - x.size * jnp.log(uncert) - 0.5 * dx @ dx

    # Synthetic 1-D data drawn from an RBF prior plus normal noise; it is
    # only read by predict_f / predict_fvar below.
    N = 100
    X = jnp.linspace(-2., 2., N)[:, None]
    true_alpha, true_sigma, true_l, true_uncert = 1., 1., 0.2, 0.25
    data_mu = jnp.zeros((N, ))
    prior_cov = RBF()(X, X, true_l, true_sigma)
    Y = jnp.linalg.cholesky(prior_cov) @ random.normal(random.PRNGKey(0),
                                                       shape=(N, )) + data_mu
    Y_obs = Y + true_uncert * random.normal(random.PRNGKey(1), shape=(N, ))

    # NOTE(review): both predictors take `sigma` and `K`, but the prior
    # chain below only pushes alpha_cw/S150_cw/alpha_mw/S150_mw/uncert
    # (with K_cw and K_mw as parents) -- confirm these names resolve.
    def predict_f(sigma, K, uncert, **kwargs):
        data_cov = jnp.square(uncert) * jnp.eye(X.shape[0])
        mu = jnp.zeros_like(Y_obs)
        return mu + K @ jnp.linalg.solve(K + data_cov, Y_obs)

    def predict_fvar(sigma, K, uncert, **kwargs):
        data_cov = jnp.square(uncert) * jnp.eye(X.shape[0])
        mu = jnp.zeros_like(Y_obs)
        return jnp.diag(K - K @ jnp.linalg.solve(K + data_cov, K))

    ###
    # define the prior chain
    # Here we assume each image is represented by pixels.
    # Alternatively, you could choose regions arranged non-uniformly over the image.

    image_shape = (128, 128)
    npix = image_shape[0] * image_shape[1]
    I150 = jnp.ones(image_shape)

    alpha_cw_gp_sigma = HalfLaplacePrior('alpha_cw_gp_sigma', 1.)
    alpha_mw_gp_sigma = HalfLaplacePrior('alpha_mw_gp_sigma', 1.)
    l_cw = UniformPrior('l_cw', 0., 0.5)  #degrees
    l_mw = UniformPrior('l_mw', 0.5, 2.)  #degrees
    # NOTE(review): the kernel priors are built on X of shape (100, 1) while
    # the flux priors use the (128, 128) array I150 -- confirm the intended
    # shapes agree once `images` is defined.
    K_cw = GaussianProcessKernelPrior('K_cw', RBF(), X, l_cw,
                                      alpha_cw_gp_sigma)
    K_mw = GaussianProcessKernelPrior('K_mw', RBF(), X, l_mw,
                                      alpha_mw_gp_sigma)
    alpha_cw = MVNPrior('alpha_cw', -1.5, K_cw)
    alpha_mw = MVNPrior('alpha_mw', -2.5, K_mw)
    S_cw_150 = UniformPrior('S150_cw', 0., I150)
    S_mw_150 = UniformPrior('S150_mw', 0., I150)
    uncert = HalfLaplacePrior('uncert', 1.)

    # NOTE(review): the parameters here are named S_cw_150 / S_mw_150 while
    # the priors are registered as 'S150_cw' / 'S150_mw' -- confirm how the
    # sampler matches arguments to priors.
    def log_likelihood(uncert, alpha_cw, alpha_mw, S_cw_150, S_mw_150):
        # Sum of per-image log-densities under the two-component model.
        log_prob = 0
        for img, freq in zip(images, freqs):  # <- need to define these
            I_total = S_mw_150 * (freq / 150e6)**(alpha_mw) + S_cw_150 * (
                freq / 150e6)**(alpha_cw)
            log_prob += log_normal(img, I_total, uncert)
        return log_prob

    prior_chain = PriorChain()\
        .push(alpha_cw).push(S_cw_150)\
        .push(alpha_mw).push(S_mw_150)\
        .push(uncert)
    print(prior_chain)

    ns = NestedSampler(log_likelihood,
                       prior_chain,
                       sampler_name='ellipsoid',
                       predict_f=predict_f,
                       predict_fvar=predict_fvar)

    def run_with_n(n):
        # Run the sampler once with n live points under jit and a fixed key.
        @jit
        def run():
            return ns(key=random.PRNGKey(0),
                      num_live_points=n,
                      max_samples=1e3,
                      collect_samples=True,
                      termination_frac=0.01,
                      stoachastic_uncertainty=True)

        results = run()
        return results
def train_neural_network(datapack: DataPack, batch_size, learning_rate,
                         num_batches):
    """Fit ``NeuralTomographicKernel`` to ``TomographicKernel`` by SGD.

    Coordinates are assembled from the datapack's antennas, its directions
    plus 250 extra random directions, and its (selected) times.  Every step
    draws a random batch of coordinates and random kernel hyper-parameters,
    and minimises the mean squared difference between the two kernels
    divided by ``width**2``.  The loss history is plotted on a log scale.

    Args:
        datapack: source of directions, antennas and times (solset
            ``'sol000'``).
        batch_size: number of coordinate rows used per step.
        learning_rate: step size of the plain gradient-descent update.
        num_batches: number of training steps.

    Returns:
        None.  NOTE(review): the trained parameters are computed but not
        returned or stored -- confirm whether that is intended.
    """

    with datapack:
        datapack.current_solset = 'sol000'
        datapack.select(pol=slice(0, 1, 1), ant=None, time=slice(0, 1, 1))
        axes = datapack.axes_tec
        _, directions = datapack.get_directions(axes['dir'])
        _, antennas = datapack.get_antennas(axes['ant'])
        _, times = datapack.get_times(axes['time'])

    # Express antennas and directions in an ENU frame at the first antenna.
    epoch = times[0]
    antennas = ac.ITRS(*antennas.cartesian.xyz, obstime=epoch)
    frame = ENU(obstime=epoch, location=antennas[0].earth_location)
    antennas = antennas.transform_to(frame)
    directions = directions.transform_to(frame)

    x = antennas.cartesian.xyz.to(au.km).value.T
    k = directions.cartesian.xyz.value.T

    # Times relative to the middle sample, converted from days to seconds.
    t = times.mjd
    t -= t[len(t) // 2]
    t *= 86400.

    # Extra directions drawn inside the bounding box of the observed ones.
    n_screen = 250
    kstar = random.uniform(random.PRNGKey(29428942), (n_screen, 3),
                           minval=jnp.min(k, axis=0),
                           maxval=jnp.max(k, axis=0))
    kstar = kstar / jnp.linalg.norm(kstar, axis=-1, keepdims=True)

    all_directions = jnp.concatenate([k, kstar], axis=0)
    X = jnp.asarray(make_coord_array(x, all_directions, t[:, None]))

    x0 = jnp.asarray(x[0, :])
    ref_ant = x0

    kernel = TomographicKernel(x0, ref_ant, RBF(), S_marg=100)
    neural_kernel = NeuralTomographicKernel(x0, ref_ant)

    def loss(params, key):
        # One sub-key per random quantity drawn in this step.
        key_rows, key_wind, key_bottom, key_width, key_l = random.split(
            key, 5)

        rows = random.permutation(key_rows,
                                  jnp.arange(X.shape[0]))[:batch_size]
        X_batch = X[rows, :]

        wind_low = jnp.asarray([-200., -200., 0.])
        wind_high = jnp.asarray([200., 200., 0.])
        wind_velocity = random.uniform(
            key_wind, shape=(3, ), minval=wind_low, maxval=wind_high) / 1000.
        bottom = random.uniform(key_bottom, minval=50., maxval=500.)
        width = random.uniform(key_width, minval=40., maxval=300.)
        l = random.uniform(key_l, minval=1., maxval=30.)
        sigma = 1.

        # Both kernels receive exactly the same arguments.
        shared_args = (X_batch, X_batch, bottom, width, l, sigma)
        K = kernel(*shared_args, wind_velocity=wind_velocity)
        neural_kernel.set_params(params)
        neural_K = neural_kernel(*shared_args, wind_velocity=wind_velocity)

        return jnp.mean((K - neural_K)**2) / width**2

    init_params = neural_kernel.init_params(random.PRNGKey(42))

    def train_one_batch(params, key):
        # Differentiates with respect to `params` only (first argument).
        step_loss, grads = value_and_grad(loss)(params, key)
        updated = tree_multimap(lambda p, g: p - learning_rate * g, params,
                                grads)
        return updated, step_loss

    @jit
    def _train(key):
        return scan(train_one_batch, init_params,
                    random.split(key, num_batches))

    final_params, losses = _train(random.PRNGKey(42))

    plt.plot(losses)
    plt.yscale('log')
    plt.show()