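# All of the examples on this page funnel their batched work through chunked_pmap. The repo's
# own helper is not reproduced here; the following is a minimal sketch of what such a helper
# typically does (pad the leading axis, give each device an equal share, loop sequentially on
# device), assuming a single-array output. The real chunked_pmap may differ, e.g. in pytree
# handling or the batch_size argument used further down.
import jax
import jax.numpy as jnp


def chunked_pmap_sketch(fn, *arrays, chunksize=None):
    """Map `fn` over the leading axis of `arrays`, `chunksize` elements at a time in parallel."""
    if chunksize is None:
        chunksize = jax.local_device_count()
    N = arrays[0].shape[0]
    remainder = N % chunksize
    if remainder != 0:
        # pad so the batch divides evenly over the devices; the padding is dropped at the end
        pad = chunksize - remainder
        arrays = tuple(jnp.concatenate([a, a[:pad]], axis=0) for a in arrays)

    def per_device(args):
        # each device loops sequentially over its share of the batch
        return jax.lax.map(lambda a: fn(*a), args)

    # [chunksize, N_padded // chunksize, ...] per input, one leading slice per device
    arrays = tuple(a.reshape((chunksize, -1) + a.shape[1:]) for a in arrays)
    result = jax.pmap(per_device)(arrays)
    # undo the device split and drop the padding
    return result.reshape((-1,) + result.shape[2:])[:N]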
def precompute_log_prob_components_without_wind(kernel,
                                                X,
                                                dtec,
                                                dtec_uncert,
                                                bottom_array,
                                                width_array,
                                                lengthscale_array,
                                                sigma_array,
                                                chunksize=2):
    """
    Precompute the data log-probability for each point of the
    (bottom, width, lengthscale, sigma) parameter grid.

    Args:
        kernel: callable evaluating the tomographic kernel, called as
            kernel(X, X, bottom, width, lengthscale, 1., wind_velocity=None).
        X: [N, ...] coordinates of the data points.
        dtec: [M, N] DTEC datasets.
        dtec_uncert: [M, N] DTEC uncertainties.
        bottom_array: [Nb] grid of ionosphere layer bottoms.
        width_array: [Nw] grid of layer widths.
        lengthscale_array: [Nl] grid of lengthscales.
        sigma_array: [Ns] grid of sigmas.
        chunksize: int, number of devices to parallelise the grid over.

    Returns:
        log_prob: [M, Nb, Nw, Nl, Ns] log-probability per dataset per grid point.
    """

    arrays = jnp.meshgrid(bottom_array,
                          width_array,
                          lengthscale_array,
                          indexing='ij')
    arrays = [a.ravel() for a in arrays]

    def compute_log_prob_components(bottom, width, lengthscale):
        # N, N
        K = kernel(X, X, bottom, width, lengthscale, 1., wind_velocity=None)

        def _compute_with_sigma(sigma):
            def _compute(dtec, dtec_uncert):
                return log_normal_with_outliers(dtec, 0., sigma**2 * K,
                                                dtec_uncert)

            # M
            return chunked_pmap(_compute, dtec, dtec_uncert, chunksize=1)

        # Ns,M
        return chunked_pmap(_compute_with_sigma, sigma_array, chunksize=1)

    Nb = bottom_array.shape[0]
    Nw = width_array.shape[0]
    Nl = lengthscale_array.shape[0]
    Ns = sigma_array.shape[0]

    # Nb*Nw*Nl,Ns,M
    log_prob = chunked_pmap(compute_log_prob_components,
                            *arrays,
                            chunksize=chunksize)
    # M, Nb,Nw,Nl,Ns
    log_prob = log_prob.reshape((Nb * Nw * Nl * Ns, dtec.shape[0])).transpose(
        (1, 0)).reshape((dtec.shape[0], Nb, Nw, Nl, Ns))
    return log_prob
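# precompute_log_prob_components_without_wind evaluates the data log-probability once for every
# point of the (bottom, width, lengthscale, sigma) grid so that the sampler can later look values
# up instead of re-factorising the kernel. A self-contained miniature of that pattern, with a toy
# Gaussian log-likelihood standing in for log_normal_with_outliers and vmap standing in for
# chunked_pmap (names and shapes here are illustrative only):
import jax
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular

# toy data: M independent datasets of N points each
M, N = 8, 16
y = jax.random.normal(jax.random.PRNGKey(0), (M, N))
x = jnp.linspace(0., 1., N)

lengthscale_array = jnp.linspace(0.5, 3.0, 4)   # Nl
sigma_array = jnp.linspace(0.5, 2.0, 5)         # Ns


def log_prob_one(lengthscale, sigma, y_single):
    # toy stand-in: zero-mean Gaussian with an RBF covariance
    K = sigma ** 2 * jnp.exp(-0.5 * ((x[:, None] - x[None, :]) / lengthscale) ** 2)
    L = jnp.linalg.cholesky(K + 1e-6 * jnp.eye(N))
    dx = solve_triangular(L, y_single, lower=True)
    return -0.5 * dx @ dx - jnp.sum(jnp.log(jnp.diag(L))) - 0.5 * N * jnp.log(2. * jnp.pi)


# flatten the parameter grid, evaluate everywhere, then reshape back to [M, Nl, Ns]
ls_grid, sig_grid = [a.ravel() for a in jnp.meshgrid(lengthscale_array, sigma_array, indexing='ij')]
log_prob = jax.vmap(                                                   # over grid points
    lambda l, s: jax.vmap(lambda yy: log_prob_one(l, s, yy))(y))(ls_grid, sig_grid)  # over datasets
log_prob = log_prob.reshape((lengthscale_array.size, sigma_array.size, M)).transpose((2, 0, 1))
print(log_prob.shape)  # (8, 4, 5)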
Example #6
def get_data(solution_file):
    logger.info("Getting DDS4 data.")
    with DataPack(solution_file, readonly=True) as h:
        select = dict(pol=slice(0, 1, 1),
                      ant=slice(50, 51),
                      dir=slice(0, None, 1),
                      time=slice(0, 100, 1))
        h.select(**select)
        phase, axes = h.phase
        phase = phase[0, ...]
        gain_outliers, _ = h.weights_phase
        gain_outliers = gain_outliers[0, ...] == 1
        amp, axes = h.amplitude
        amp = amp[0, ...]
        _, freqs = h.get_freqs(axes['freq'])
        freqs = freqs.to(au.Hz).value
        _, times = h.get_times(axes['time'])
        times = jnp.asarray(times.mjd, dtype=jnp.float64) * 86400.
        times = times - times[0]
        logger.info("Shape: {}".format(phase.shape))

        (Nd, Na, Nf, Nt) = phase.shape

        @jit
        def smooth(amp, outliers):
            '''
            Smooth amplitudes
            Args:
                amp: [Nt, Nf]
                outliers: [Nt, Nf]
            '''
            weights = jnp.where(outliers, 0., 1.)
            log_amp = jnp.log(amp)
            log_amp = vmap(lambda log_amp, weights: poly_smooth(
                times, log_amp, deg=3, weights=weights))(log_amp.T,
                                                         weights.T).T
            log_amp = vmap(lambda log_amp, weights: poly_smooth(
                freqs, log_amp, deg=3, weights=weights))(log_amp, weights)
            amp = jnp.exp(log_amp)
            return amp

        logger.info("Smoothing amplitudes")
        amp = chunked_pmap(smooth,
                           amp.reshape((Nd * Na, Nf, Nt)).transpose((0, 2, 1)),
                           gain_outliers.reshape((Nd * Na, Nf, Nt)).transpose(
                               (0, 2, 1)))  # Nd*Na,Nt,Nf
        amp = amp.transpose((0, 2, 1)).reshape((Nd, Na, Nf, Nt))
    return gain_outliers, phase, amp, times, freqs
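# get_data smooths log-amplitudes with poly_smooth, giving flagged samples zero weight so that
# outliers do not drag the fit. poly_smooth itself is not shown on this page; below is a minimal
# stand-in (a weighted least-squares polynomial fit evaluated back on the same grid), assuming
# the real helper behaves similarly:
import jax.numpy as jnp


def poly_smooth_sketch(x, y, deg=3, weights=None):
    """Weighted polynomial fit of y(x), returned evaluated at x. Weight 0 removes a point."""
    if weights is None:
        weights = jnp.ones_like(y)
    # Vandermonde design matrix [len(x), deg + 1]
    A = jnp.stack([x ** d for d in range(deg + 1)], axis=1)
    w = jnp.sqrt(weights)
    coeffs, _, _, _ = jnp.linalg.lstsq(A * w[:, None], y * w, rcond=None)
    return A @ coeffs


# usage: smooth a noisy series with two corrupted samples masked out
x = jnp.linspace(0., 1., 50)
y = jnp.sin(2. * x) + 0.01 * jnp.cos(40. * x)
y = y.at[10].set(10.).at[30].set(-10.)                      # inject outliers
weights = jnp.ones_like(y).at[10].set(0.).at[30].set(0.)    # and mask them
y_smooth = poly_smooth_sketch(x, y, deg=3, weights=weights)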
def detect_tec_outliers(times, tec_mean, tec_std):
    """
    Detect outliers in dphase (in batch)
    Args:
        tec: [Nd, Na, Nt] tec uncert
        times: [Nt]
    Returns:
        outliers [Nd, Na,  Nt]
    """

    times, tec_mean, tec_std = jnp.asarray(times), jnp.asarray(
        tec_mean), jnp.asarray(tec_std)
    Nd, Na, Nt = tec_mean.shape
    tec_mean = tec_mean.reshape((Nd * Na, Nt))
    tec_std = tec_std.reshape((Nd * Na, Nt))
    res = chunked_pmap(lambda tec_mean, tec_std: single_detect_tec_outliers(
        times, tec_mean, tec_std),
                       tec_mean,
                       tec_std,
                       chunksize=None)
    res = tree_map(lambda x: x.reshape((Nd, Na, Nt)), res)
    return res
def detect_dphase_outliers(dphase):
    """
    Detect outliers in dphase (in batch)
    Args:
        dphase: [Nd, Na, Nf, Nt] tec uncert
        times: [Nt]
    Returns:
        outliers [Nd, Na, Nf, Nt]
    """

    Nd, Na, Nf, Nt = dphase.shape
    dphase = dphase.reshape((Nd * Na * Nf, Nt))
    outliers = jnp.abs(dphase) > 1.
    outliers = outliers | (jnp.abs(dphase) >
                           5. * jnp.sqrt(jnp.mean(dphase**2)))
    print("Outliers before windowed detection:", outliers.sum())
    outliers = chunked_pmap(lambda dphase, outliers: single_detect_outliers(
        dphase, window=15, init_outliers=outliers),
                            dphase,
                            outliers,
                            chunksize=None)
    outliers = outliers.reshape((Nd, Na, Nf, Nt))
    print("Outliers after windowed detection:", outliers.sum())
    return outliers
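# Both detectors above seed an initial flag mask and then hand each row to a single_detect_*
# routine via chunked_pmap. Those routines are not shown on this page; below is an illustrative
# stand-in of the same flavour (a running-median residual test with a MAD-based scale), whose
# window and threshold are arbitrary choices rather than the repo's:
import jax.numpy as jnp
from jax import vmap


def single_detect_outliers_sketch(x, window=15, init_outliers=None, nsigma=5.):
    """Flag samples that sit far from the running median of their neighbourhood."""
    Nt = x.shape[0]
    half = window // 2
    # indices of each sample's window, clipped at the edges
    idx = jnp.clip(jnp.arange(Nt)[:, None] + jnp.arange(-half, half + 1)[None, :], 0, Nt - 1)
    running_median = jnp.median(x[idx], axis=1)
    resid = x - running_median
    # robust scale estimate from the median absolute deviation
    mad = jnp.median(jnp.abs(resid)) + 1e-12
    outliers = jnp.abs(resid) > nsigma * 1.4826 * mad
    if init_outliers is not None:
        outliers = outliers | init_outliers
    return outliers


# batched use mirrors the reshape-to-(rows, Nt) pattern above, e.g.:
#   flags = vmap(single_detect_outliers_sketch)(dphase.reshape((Nd * Na * Nf, Nt)))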
def solve_with_vanilla_kernel(key, dtec, dtec_uncert, X, Xstar, fed_kernel,
                              time_block_size, chunksize):
    """
    Precompute look-up tables for all blocks.

    Args:
        key: PRNG key
        dtec: [Nd, Na, Nt] TECU
        dtec_uncert: [Nd, Na, Nt] TECU
        X: [Nd,2] coordinates in deg
        Xstar: [Nd_screen, 2] screen coordinates
        fed_kernel: StationaryKernel
        time_block_size: int
        chunksize: int number of parallel devices to use.

    Returns:
        mean: [Nd_screen, Na, Nt] screen dtec posterior mean
        uncert: [Nd_screen, Na, Nt] screen dtec posterior uncertainty
        mean_lengthscale: [Na, Nt] posterior mean lengthscale per block
        mean_sigma: [Na, Nt] posterior mean sigma per block
        ESS: [Na, Nt] nested sampling effective sample size per block
        logZ: [Na, Nt] log-evidence per block
        likelihood_evals: [Na, Nt] number of likelihood evaluations per block
    """
    field_of_view = 4.  # deg
    min_separation_arcmin = 4.  # arcmin
    min_separation_deg = min_separation_arcmin / 60.
    lengthscale_array = jnp.linspace(min_separation_deg, field_of_view, 120)
    sigma_array = jnp.linspace(0., 150., 150)
    kernel = fed_kernel
    lookup_func = build_lookup_index(lengthscale_array, sigma_array)

    dtec_uncert = jnp.maximum(dtec_uncert, 1e-6)

    Nd, Na, Nt = dtec.shape
    remainder = Nt % time_block_size
    extra = time_block_size - remainder
    dtec = jnp.concatenate([dtec, dtec[:, :, Nt - extra:]], axis=-1)
    dtec_uncert = jnp.concatenate(
        [dtec_uncert, dtec_uncert[:, :, Nt - extra:]], axis=-1)
    Nt = dtec.shape[-1]
    size_dict = dict(a=Na, d=Nd, b=time_block_size)
    dtec = axes_move(dtec, ['d', 'a', 'tb'], ['atb', 'd'], size_dict=size_dict)
    dtec_uncert = axes_move(dtec_uncert, ['d', 'a', 'tb'], ['atb', 'd'],
                            size_dict=size_dict)

    def compute_log_prob_components(lengthscale):
        # N, N
        K = kernel(X, X, lengthscale, 1.)

        def _compute_with_sigma(sigma):
            def _compute(dtec, dtec_uncert):
                #each [Nd]
                return log_normal_with_outliers(dtec, 0., sigma**2 * K,
                                                dtec_uncert)

            return chunked_pmap(_compute, dtec, dtec_uncert, chunksize=1)  #M

        # Ns,M
        return chunked_pmap(_compute_with_sigma, sigma_array, chunksize=1)

    # Nl,Ns,M
    log_prob = chunked_pmap(compute_log_prob_components,
                            lengthscale_array,
                            chunksize=chunksize)
    # Na * (Nt//time_block_size),block_size,Nl,Ns
    log_prob = axes_move(log_prob, ['l', 's', 'atb'], ['at', 'b', 'l', 's'],
                         size_dict=size_dict)
    # Na * (Nt//time_block_size),Nl,Ns
    log_prob = jnp.sum(log_prob, axis=1)  #independent datasets summed up.

    def run_block(key, dtec, dtec_uncert, log_prob):
        key1, key2 = random.split(key, 2)

        def log_likelihood(lengthscale, sigma, **kwargs):
            # K = kernel(X, X, lengthscale, sigma)
            # def _compute(dtec, dtec_uncert):
            #     #each [Nd]
            #     return log_normal_with_outliers(dtec, 0., K, jnp.maximum(1e-6, dtec_uncert))
            # return chunked_pmap(_compute, dtec, dtec_uncert, chunksize=1).sum()
            return lookup_func(log_prob, lengthscale, sigma)

        lengthscale = UniformPrior('lengthscale', jnp.min(lengthscale_array),
                                   jnp.max(lengthscale_array))
        sigma = UniformPrior('sigma', sigma_array.min(), sigma_array.max())
        prior_chain = PriorChain(lengthscale, sigma)

        ns = NestedSampler(loglikelihood=log_likelihood,
                           prior_chain=prior_chain,
                           sampler_kwargs=dict(num_slices=prior_chain.U_ndims *
                                               1),
                           num_live_points=prior_chain.U_ndims * 50)
        ns = jit(ns)
        results = ns(key1, termination_evidence_frac=0.1)

        def marg_func(lengthscale, sigma, **kwargs):
            def screen(dtec, dtec_uncert, **kw):
                K = kernel(X, X, lengthscale, sigma)
                Kstar = kernel(X, Xstar, lengthscale, sigma)
                L = jnp.linalg.cholesky(
                    K / (dtec_uncert[:, None] * dtec_uncert[None, :]) +
                    jnp.eye(dtec.shape[0]))
                # L = jnp.where(jnp.isnan(L), jnp.eye(L.shape[0])/sigma, L)
                dx = solve_triangular(L, dtec / dtec_uncert, lower=True)
                JT = solve_triangular(L,
                                      Kstar / dtec_uncert[:, None],
                                      lower=True)
                #var_ik = JT_ji JT_jk
                mean = JT.T @ dx
                var = jnp.sum(JT * JT, axis=0)
                return mean, var

            return vmap(screen)(dtec, dtec_uncert), lengthscale, jnp.log(
                sigma
            )  #[time_block_size,  Nd_screen], [time_block_size,  Nd_screen]

        #[time_block_size,  Nd_screen], [time_block_size,  Nd_screen], [time_block_size]
        (mean, var), mean_lengthscale, mean_logsigma = marginalise_static(
            key2, results.samples, results.log_p, 500, marg_func)
        uncert = jnp.sqrt(var)
        mean_sigma = jnp.exp(mean_logsigma)
        mean_lengthscale = jnp.ones(time_block_size) * mean_lengthscale
        mean_sigma = jnp.ones(time_block_size) * mean_sigma
        ESS = results.ESS * jnp.ones(time_block_size)
        logZ = results.logZ * jnp.ones(time_block_size)
        likelihood_evals = results.num_likelihood_evaluations * jnp.ones(
            time_block_size)
        return mean, uncert, mean_lengthscale, mean_sigma, ESS, logZ, likelihood_evals

    T = Na * (Nt // time_block_size)
    keys = random.split(key, T)
    # [T, time_block_size, Nd_screen], [T, time_block_size, Nd_screen], [T, time_block_size], [T, time_block_size]
    dtec = axes_move(dtec, ['atb', 'd'], ['at', 'b', 'd'], size_dict=size_dict)
    dtec_uncert = axes_move(dtec_uncert, ['atb', 'd'], ['at', 'b', 'd'],
                            size_dict=size_dict)
    mean, uncert, mean_lengthscale, mean_sigma, ESS, logZ, likelihood_evals = chunked_pmap(
        run_block, keys, dtec, dtec_uncert, log_prob, chunksize=chunksize)
    mean = axes_move(mean, ['at', 'b', 'n'], ['n', 'a', 'tb'],
                     size_dict=size_dict)
    uncert = axes_move(uncert, ['at', 'b', 'n'], ['n', 'a', 'tb'],
                       size_dict=size_dict)
    mean_lengthscale = axes_move(mean_lengthscale, ['at', 'b'], ['a', 'tb'],
                                 size_dict=size_dict)
    mean_sigma = axes_move(mean_sigma, ['at', 'b'], ['a', 'tb'],
                           size_dict=size_dict)
    ESS = axes_move(ESS, ['at', 'b'], ['a', 'tb'], size_dict=size_dict)
    logZ = axes_move(logZ, ['at', 'b'], ['a', 'tb'], size_dict=size_dict)
    likelihood_evals = axes_move(likelihood_evals, ['at', 'b'], ['a', 'tb'],
                                 size_dict=size_dict)
    # strip the padded timesteps off the end
    return (mean[..., :Nt - extra], uncert[..., :Nt - extra],
            mean_lengthscale[..., :Nt - extra], mean_sigma[..., :Nt - extra],
            ESS[..., :Nt - extra], logZ[..., :Nt - extra],
            likelihood_evals[..., :Nt - extra])
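# The `screen` closure inside marg_func above is a standard Gaussian-process conditional:
# whiten by the per-point uncertainties, factorise once, and read the screen mean and the
# variance term off two triangular solves. A self-contained version of that algebra using a
# plain RBF kernel in place of the repo's fed_kernel (names here are illustrative):
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular


def rbf(Xa, Xb, lengthscale, sigma):
    # simple stationary kernel standing in for fed_kernel
    d2 = jnp.sum((Xa[:, None, :] - Xb[None, :, :]) ** 2, axis=-1)
    return sigma ** 2 * jnp.exp(-0.5 * d2 / lengthscale ** 2)


def gp_screen(dtec, dtec_uncert, X, Xstar, lengthscale, sigma):
    """Mirror of the `screen` closure: dividing K by the outer product of the uncertainties
    turns the noise term into the identity, so one Cholesky serves both mean and variance."""
    K = rbf(X, X, lengthscale, sigma)
    Kstar = rbf(X, Xstar, lengthscale, sigma)
    L = jnp.linalg.cholesky(K / (dtec_uncert[:, None] * dtec_uncert[None, :])
                            + jnp.eye(dtec.shape[0]))
    dx = solve_triangular(L, dtec / dtec_uncert, lower=True)
    JT = solve_triangular(L, Kstar / dtec_uncert[:, None], lower=True)
    mean = JT.T @ dx                # [Nd_screen] posterior mean
    var = jnp.sum(JT * JT, axis=0)  # diag(Kstar.T @ (K + diag(uncert^2))^-1 @ Kstar),
                                    # the quantity `screen` above calls `var`
    return mean, var


# usage on toy coordinates (shapes only; values are arbitrary)
X = jnp.stack(jnp.meshgrid(jnp.linspace(0., 4., 5), jnp.linspace(0., 4., 5)), axis=-1).reshape((-1, 2))
Xstar = X + 0.1
dtec = jnp.sin(X[:, 0])
mean, var = gp_screen(dtec, 0.3 * jnp.ones_like(dtec), X, Xstar, lengthscale=1.5, sigma=10.)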
def solve_with_tomographic_kernel(dtec, dtec_uncert, X, x0, fed_kernel,
                                  time_block_size):
    """
    Precompute look-up tables for all blocks.
    Assumes that each antenna is independent and doesn't take into account time.

    Args:
        dtec: [Nd, Na, Nt]
        dtec_uncert: [Nd, Na, Nt]
        X: [Nd,6]
        x0: [3]
        fed_kernel: StationaryKernel
        time_block_size: int
    """
    scale = jnp.std(dtec) / 35.
    dtec /= scale
    dtec_uncert /= scale
    bottom_array = jnp.linspace(200., 400., 5)
    width_array = jnp.linspace(50., 50., 1)
    lengthscale_array = jnp.linspace(0.1, 7.5, 7)
    sigma_array = jnp.linspace(0.1, 2., 11)
    kernel = TomographicKernel(x0,
                               x0,
                               fed_kernel,
                               S_marg=25,
                               compute_tec=False)

    lookup_func = build_lookup_index(bottom_array, width_array,
                                     lengthscale_array, sigma_array)
    Nd, Na, Nt = dtec.shape
    remainder = Nt % time_block_size
    extra = time_block_size - remainder
    # pad with copies of the last timesteps so Nt divides evenly into blocks
    dtec = jnp.concatenate([dtec, dtec[:, :, Nt - extra:]], axis=-1)
    dtec_uncert = jnp.concatenate(
        [dtec_uncert, dtec_uncert[:, :, Nt - extra:]], axis=-1)
    Nt = dtec.shape[-1]
    dtec = dtec.transpose((2, 1, 0)).reshape((Nt * Na, Nd))
    dtec_uncert = dtec_uncert.transpose((2, 1, 0)).reshape((Nt * Na, Nd))
    # Nt*Na, ...
    log_prob = precompute_log_prob_components_without_wind(kernel,
                                                           X,
                                                           dtec,
                                                           dtec_uncert,
                                                           bottom_array,
                                                           width_array,
                                                           lengthscale_array,
                                                           sigma_array,
                                                           chunksize=4)

    log_prob = jnp.reshape(log_prob, (Nt // time_block_size, time_block_size,
                                      Na) + log_prob.shape[1:])

    def run_block(block_idx):
        def log_likelihood(bottom, width, lengthscale, sigma, **kwargs):
            return jnp.sum(
                vmap(lambda log_prob: lookup_func(log_prob, bottom, width,
                                                  lengthscale, sigma))(
                                                      log_prob[block_idx]))

        bottom = UniformPrior('bottom', bottom_array.min(), bottom_array.max())
        width = DeltaPrior('width', 50., tracked=False)
        lengthscale = UniformPrior('lengthscale', jnp.min(lengthscale_array),
                                   jnp.max(lengthscale_array))
        sigma = UniformPrior('sigma', sigma_array.min(), sigma_array.max())
        prior_chain = PriorChain(lengthscale, sigma, bottom, width)

        ns = NestedSampler(loglikelihood=log_likelihood,
                           prior_chain=prior_chain,
                           sampler_name='slice',
                           sampler_kwargs=dict(num_slices=prior_chain.U_ndims *
                                               5),
                           num_live_points=prior_chain.U_ndims * 50)
        ns = jit(ns)
        results = ns(random.PRNGKey(42), termination_frac=0.001)

        return results
        # results.efficiency.block_until_ready()
        # t0 = default_timer()
        # results = ns(random.PRNGKey(42), termination_frac=0.001)
        # summary(results)
        # print(default_timer() - t0)

        # def screen(bottom, lengthscale, east_wind_speed, north_wind_speed, sigma, **kw):
        #     wind_velocity = jnp.asarray([east_wind_speed, north_wind_speed, 0.])
        #     K = kernel(X, X, bottom, 50., lengthscale, sigma, wind_velocity=wind_velocity)
        #     Kstar = kernel(X, Xstar, bottom, 50., lengthscale, sigma)
        #     L = jnp.linalg.cholesky(K + jnp.diag(jnp.maximum(1e-6, dtec_uncert) ** 2))
        #     dx = solve_triangular(L, dtec, lower=True)
        #     return solve_triangular(L, Kstar, lower=True).T @ dx

        # summary(results)
        # plot_diagnostics(results)
        # plot_cornerplot(results)

        # screen_mean = marginalise_static(random.PRNGKey(4325325), results.samples, results.log_p, int(results.ESS), screen)

        # print(screen_mean)
        # plot_vornoi_map(Xstar[:, 3:5], screen_mean)
        # plt.show()
        # plot_vornoi_map(X[:, 3:5], dtec)
        # plt.show()

        # return screen_mean

    results = chunked_pmap(run_block, jnp.arange(Nt // time_block_size))
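# Both solvers above replace exact likelihood evaluations with a table precomputed on a parameter
# grid and indexed through build_lookup_index. That helper is not shown on this page; below is a
# minimal 2-D version of the idea (multilinear interpolation into a precomputed log_prob grid),
# with names of my own choosing and no claim to match the repo's implementation:
import jax.numpy as jnp


def build_lookup_index_sketch(grid0, grid1):
    """Return a function interpolating a table defined on the outer product of two 1-D grids."""

    def lookup(table, x0, x1):
        # left-neighbour indices and fractional positions on each axis
        i0 = jnp.clip(jnp.searchsorted(grid0, x0) - 1, 0, grid0.size - 2)
        i1 = jnp.clip(jnp.searchsorted(grid1, x1) - 1, 0, grid1.size - 2)
        f0 = (x0 - grid0[i0]) / (grid0[i0 + 1] - grid0[i0])
        f1 = (x1 - grid1[i1]) / (grid1[i1 + 1] - grid1[i1])
        # bilinear interpolation between the four surrounding grid points
        return ((1. - f0) * (1. - f1) * table[i0, i1]
                + (1. - f0) * f1 * table[i0, i1 + 1]
                + f0 * (1. - f1) * table[i0 + 1, i1]
                + f0 * f1 * table[i0 + 1, i1 + 1])

    return lookup


# usage: a log-likelihood backed by a table instead of a fresh kernel factorisation
lengthscale_array = jnp.linspace(0.1, 4., 30)
sigma_array = jnp.linspace(0.1, 100., 40)
log_prob_table = -((lengthscale_array[:, None] - 2.) ** 2 + (sigma_array[None, :] - 50.) ** 2)
lookup_func = build_lookup_index_sketch(lengthscale_array, sigma_array)
print(lookup_func(log_prob_table, 1.9, 52.3))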
Example #11
def solve_and_smooth(gain_outliers, phase_obs, times, freqs):
    logger.info("Performing solve for tec and const from phases.")
    Nd, Na, Nf, Nt = phase_obs.shape

    logger.info("Number of nan: {}".format(jnp.sum(jnp.isnan(phase_obs))))
    logger.info("Number of inf: {}".format(jnp.sum(jnp.isinf(phase_obs))))

    # blocksize chosen to maximise Fisher information, which is 2 for tec+const, and 3 for tec+const+clock
    blocksize = 2

    remainder = Nt % blocksize
    if remainder != 0:
        if blocksize > Nt:
            raise ValueError(
                f"Block size {blocksize} too big for number of timesteps {Nt}."
            )
        (gain_outliers, phase_obs) = tree_map(
            lambda x: jnp.concatenate(
                [x, jnp.repeat(x[..., -1:], remainder, axis=-1)], axis=-1),
            (gain_outliers, phase_obs))
        Nt = Nt + remainder
        times = jnp.concatenate([
            times, times[-1] +
            jnp.arange(1, 1 + remainder) * jnp.mean(jnp.diff(times))
        ])

    size_dict = dict(d=Nd, a=Na, f=Nf, b=blocksize)

    # [Nd*Na*(Nt//blocksize), blocksize, Nf]
    gain_outliers = axes_move(gain_outliers, ['d', 'a', 'f', 'tb'],
                              ['dat', 'b', 'f'],
                              size_dict=size_dict)
    phase_obs = axes_move(phase_obs, ['d', 'a', 'f', 'tb'], ['dat', 'b', 'f'],
                          size_dict=size_dict)

    T = Nd * Na * (Nt // blocksize)  # Nd * Na * (Nt // blocksize)
    keys = random.split(random.PRNGKey(int(1000 * default_timer())), T)

    # each output: [Nd*Na*(Nt//blocksize), blocksize]
    tec_mean, tec_std, const_mean, const_std, uncert_mean = chunked_pmap(
        lambda *args: unconstrained_solve(freqs, *args), keys, phase_obs,
        gain_outliers)  # Nd*Na*(Nt//blocksize), blocksize

    const_weights = 1. / const_std**2

    def smooth(y, weights):
        y = axes_move(y, ['dat', 'b'], ['da', 'tb'], size_dict=size_dict)
        weights = axes_move(weights, ['dat', 'b'], ['da', 'tb'],
                            size_dict=size_dict)
        y = chunked_pmap(
            lambda y, weights: poly_smooth(times, y, deg=5, weights=weights),
            y, weights)
        y = axes_move(y, ['da', 'tb'], ['dat', 'b'], size_dict=size_dict)
        return y

    logger.info("Smoothing and outlier rejection of const (a weak prior).")
    # Nd,Na,Nt/blocksize, blocksize
    const_real_mean = smooth(jnp.cos(const_mean),
                             const_weights)  # Nd*Na*(Nt//blocksize), blocksize
    const_imag_mean = smooth(jnp.sin(const_mean),
                             const_weights)  # Nd*Na*(Nt//blocksize), blocksize
    const_mean_smoothed = jnp.arctan2(
        const_imag_mean, const_real_mean)  # Nd*Na*(Nt//blocksize), blocksize

    # the uncertainty proxy was empirically calibrated to flag sigma(tec - tec_true) > 6 mTECU;
    # the threshold is currently set to 0, so every block gets the refined constrained solve.
    which_reprocess = jnp.any(uncert_mean > 0.,
                              axis=1)  # Nd*Na*(Nt//blocksize)
    replace_map = jnp.where(which_reprocess)

    logger.info("Performing refined tec-only solve, with fixed const.")
    keys = random.split(random.PRNGKey(int(1000 * default_timer())),
                        jnp.sum(which_reprocess))
    # [Nd*Na*(Nt//blocksize), blocksize]
    (tec_mean_constrained, tec_std_constrained, const_mean_constrained, const_std_constrained) = \
        chunked_pmap(lambda *args: constrained_solve(freqs, *args),
                     keys,
                     phase_obs[which_reprocess],
                     gain_outliers[which_reprocess],
                     const_mean_smoothed[which_reprocess],
                     const_std[which_reprocess]
                     )
    tec_mean = tec_mean.at[replace_map].set(tec_mean_constrained)
    tec_std = tec_std.at[replace_map].set(tec_std_constrained)
    const_std = const_std.at[replace_map].set(const_std_constrained)
    const_mean = const_mean.at[replace_map].set(const_mean_constrained)

    (tec_mean, tec_std, const_mean,
     const_std) = tree_map(lambda x: x.reshape((Nd, Na, Nt)),
                           (tec_mean, tec_std, const_mean, const_std))

    # Nd, Na, Nt
    logger.info("Performing outlier detection on tec values.")
    tec_est, tec_outliers = detect_tec_outliers(times, tec_mean, tec_std)
    tec_std = jnp.where(tec_outliers, jnp.inf, tec_std)

    # remove remainder at the end
    if remainder != 0:
        (tec_mean, tec_std, const_mean,
         const_std) = tree_map(lambda x: x[..., :Nt - remainder],
                               (tec_mean, tec_std, const_mean, const_std))

    # compute phase mean with outlier-suppressed tec.
    phase_mean = tec_mean[..., None, :] * (
        TEC_CONV / freqs[:, None]) + const_mean[..., None, :]
    phase_uncert = jnp.sqrt((tec_std[..., None, :] *
                             (TEC_CONV / freqs[:, None]))**2 +
                            (const_std[..., None, :])**2)

    return phase_mean, phase_uncert, tec_mean, tec_std, tec_outliers, const_mean, const_std
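# Two details of solve_and_smooth are worth spelling out: const is an angle, so it is smoothed
# through its cosine and sine and recombined with arctan2 rather than smoothed directly, and the
# final phase model is tec * TEC_CONV / freq + const. A compact sketch of both; the TEC_CONV
# value below is an assumed standard conversion (rad * Hz per mTECU), not taken from this repo:
import jax.numpy as jnp

TEC_CONV = -8.4479745e6  # rad * Hz / mTECU (assumed value)


def smooth_angle(const, smooth_fn):
    """Smooth an angle by smoothing cos and sin separately, avoiding 2*pi wraps in the fit."""
    real = smooth_fn(jnp.cos(const))
    imag = smooth_fn(jnp.sin(const))
    return jnp.arctan2(imag, real)


def phase_model(tec_mean, tec_std, const_mean, const_std, freqs):
    """Reassemble the phase mean and uncertainty from tec and const, as at the end of
    solve_and_smooth; tec/const inputs are [Nd, Na, Nt] and freqs is [Nf]."""
    phase_mean = tec_mean[..., None, :] * (TEC_CONV / freqs[:, None]) + const_mean[..., None, :]
    phase_uncert = jnp.sqrt((tec_std[..., None, :] * (TEC_CONV / freqs[:, None])) ** 2
                            + const_std[..., None, :] ** 2)
    return phase_mean, phase_uncert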
Example #12
    def run(self, output_h5parm, ncpu, avg_direction_spacing,
            field_of_view_diameter, duration, time_resolution, start_time,
            array_name, phase_tracking):

        Nd = get_num_directions(
            avg_direction_spacing,
            field_of_view_diameter,
        )
        Nf = 2  # 8000
        Nt = int(duration / time_resolution) + 1
        min_freq = 700.
        max_freq = 2000.
        dp = create_empty_datapack(
            Nd,
            Nf,
            Nt,
            pols=None,
            field_of_view_diameter=field_of_view_diameter,
            start_time=start_time,
            time_resolution=time_resolution,
            min_freq=min_freq,
            max_freq=max_freq,
            array_file=ARRAYS[array_name],
            phase_tracking=(phase_tracking.ra.deg, phase_tracking.dec.deg),
            save_name=output_h5parm,
            clobber=True)

        with dp:
            dp.current_solset = 'sol000'
            dp.select(pol=slice(0, 1, 1))
            axes = dp.axes_tec
            patch_names, directions = dp.get_directions(axes['dir'])
            antenna_labels, antennas = dp.get_antennas(axes['ant'])
            timestamps, times = dp.get_times(axes['time'])
            ref_ant = antennas[0]
            ref_time = times[0]

        Na = len(antennas)
        Nd = len(directions)
        Nt = len(times)

        logger.info(f"Number of directions: {Nd}")
        logger.info(f"Number of antennas: {Na}")
        logger.info(f"Number of times: {Nt}")
        logger.info(f"Reference Ant: {ref_ant}")
        logger.info(f"Reference Time: {ref_time.isot}")

        # Plot Antenna Layout in East North Up frame
        ref_frame = ENU(obstime=ref_time, location=ref_ant.earth_location)

        _antennas = ac.ITRS(*antennas.cartesian.xyz,
                            obstime=ref_time).transform_to(ref_frame)
        # plt.scatter(_antennas.east, _antennas.north, marker='+')
        # plt.xlabel(f"East (m)")
        # plt.ylabel(f"North (m)")
        # plt.show()

        x0 = ac.ITRS(
            *antennas[0].cartesian.xyz,
            obstime=ref_time).transform_to(ref_frame).cartesian.xyz.to(
                au.km).value
        earth_centre_x = ac.ITRS(
            x=0 * au.m, y=0 * au.m, z=0. * au.m,
            obstime=ref_time).transform_to(ref_frame).cartesian.xyz.to(
                au.km).value
        self._kernel = TomographicKernel(x0,
                                         earth_centre_x,
                                         M32(),
                                         S_marg=20,
                                         compute_tec=False)

        k = directions.transform_to(ref_frame).cartesian.xyz.value.T

        t = times.mjd * 86400.
        t -= t[0]

        X1 = GeodesicTuple(x=[], k=[], t=[], ref_x=[])

        logger.info("Computing coordinates in frame ...")

        for i, time in tqdm(enumerate(times)):
            x = ac.ITRS(*antennas.cartesian.xyz,
                        obstime=time).transform_to(ref_frame).cartesian.xyz.to(
                            au.km).value.T
            ref_ant_x = ac.ITRS(
                *ref_ant.cartesian.xyz,
                obstime=time).transform_to(ref_frame).cartesian.xyz.to(
                    au.km).value

            X = make_coord_array(x,
                                 k,
                                 t[i:i + 1, None],
                                 ref_ant_x[None, :],
                                 flat=True)

            X1.x.append(X[:, 0:3])
            X1.k.append(X[:, 3:6])
            X1.t.append(X[:, 6:7])
            X1.ref_x.append(X[:, 7:8])

        X1 = X1._replace(
            x=jnp.concatenate(X1.x, axis=0),
            k=jnp.concatenate(X1.k, axis=0),
            t=jnp.concatenate(X1.t, axis=0),
            ref_x=jnp.concatenate(X1.ref_x, axis=0),
        )

        logger.info(f"Total number of coordinates: {X1.x.shape[0]}")

        def compute_covariance_row(X1: GeodesicTuple, X2: GeodesicTuple):
            K = self._kernel(X1,
                             X2,
                             self._bottom,
                             self._width,
                             self._fed_sigma,
                             self._fed_kernel_params,
                             wind_velocity=self._wind_vector)  # 1, N
            return K[0, :]

        covariance_row = lambda X: compute_covariance_row(
            tree_map(lambda x: x.reshape((1, -1)), X), X1)

        mean = jit(lambda X1: self._kernel.mean_function(X1,
                                                         self._bottom,
                                                         self._width,
                                                         self._fed_mu,
                                                         wind_velocity=self.
                                                         _wind_vector))(X1)

        cov = chunked_pmap(covariance_row,
                           X1,
                           batch_size=X1.x.shape[0],
                           chunksize=ncpu)

        plt.imshow(cov)
        plt.show()

        Z = random.normal(random.PRNGKey(42), (cov.shape[0], 1),
                          dtype=cov.dtype)

        t0 = default_timer()
        jitter = 1e-6
        logger.info(f"Computing Cholesky with jitter: {jitter}")
        L = jnp.linalg.cholesky(cov + jitter * jnp.eye(cov.shape[0]))
        if np.any(np.isnan(L)):
            logger.info("Numerically instable. Using SVD.")
            L = msqrt(cov)

        logger.info(f"Cholesky took {default_timer() - t0} seconds.")

        dtec = (L @ Z + mean[:, None])[:, 0].reshape((Na, Nd, Nt)).transpose(
            (1, 0, 2))

        logger.info(f"Saving result to {output_h5parm}")
        with dp:
            dp.current_solset = 'sol000'
            dp.select(pol=slice(0, 1, 1))
            dp.tec = np.asarray(dtec[None])
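# The simulation above draws one realisation of the dtec field by factorising the full covariance:
# dtec = L @ Z + mean, with a small diagonal jitter and a matrix square root as fall-back when the
# Cholesky produces NaNs. A self-contained sketch of that draw; msqrt_sketch is an
# eigendecomposition-based stand-in for the repo's msqrt helper:
import jax.numpy as jnp
from jax import random


def msqrt_sketch(cov):
    # symmetric matrix square root via eigendecomposition
    vals, vecs = jnp.linalg.eigh(cov)
    return (vecs * jnp.sqrt(jnp.maximum(vals, 0.))[None, :]) @ vecs.T


def draw_correlated_sample(key, cov, mean, jitter=1e-6):
    """Draw one sample of a Gaussian field with covariance `cov` and mean `mean`."""
    Z = random.normal(key, (cov.shape[0], 1), dtype=cov.dtype)
    L = jnp.linalg.cholesky(cov + jitter * jnp.eye(cov.shape[0]))
    # fall back to a symmetric square root if the Cholesky failed, mirroring the msqrt branch above
    L = jnp.where(jnp.any(jnp.isnan(L)), msqrt_sketch(cov), L)
    return (L @ Z + mean[:, None])[:, 0]


# usage on a toy covariance
x = jnp.linspace(0., 1., 100)
cov = jnp.exp(-0.5 * (x[:, None] - x[None, :]) ** 2 / 0.1 ** 2)
sample = draw_correlated_sample(random.PRNGKey(42), cov, jnp.zeros(100))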