Exemplo n.º 1
0
def residual_vis(time_bin_indices, time_bin_counts, antenna1, antenna2, jones,
                 vis, flag, model):
    """Build the routine that forms residual visibilities.

    The arguments are used here only to select an implementation:
    ``check_type(jones, vis)`` picks a mode and ``subtract_model_factory``
    returns the matching per-channel kernel. The returned function takes
    the same eight arguments and does the actual work.

    Parameters
    ----------
    time_bin_indices : array of int
        First row of each time bin; one entry per time bin.
    time_bin_counts : array of int
        Number of rows in each time bin.
    antenna1, antenna2 : array of int
        Per-row antenna indices, used to index the second axis of ``jones``.
    jones : array
        Gains, indexed as ``jones[time_bin, antenna][chan]``.
    vis : array
        Visibilities, indexed as ``vis[row, chan]``; its shape and dtype
        define the output.
    flag : array
        Indexed as ``flag[row, chan]``; an entry is skipped when any
        element of it is set.
    model : array
        Indexed as ``model[row, chan]``.

    Returns
    -------
    callable
        Function returning a new array shaped like ``vis``. Flagged
        entries are left at zero.
    """

    mode = check_type(jones, vis)
    subtract_model = subtract_model_factory(mode)

    @wraps(residual_vis)
    def _residual_vis_fn(time_bin_indices, time_bin_counts, antenna1, antenna2,
                         jones, vis, flag, model):
        # for dask arrays we need to adjust the chunks to start counting
        # from zero. The rebased starts go into a new array so that the
        # caller's time_bin_indices is not modified.
        bin_starts = time_bin_indices - time_bin_indices.min()
        n_tim = np.shape(bin_starts)[0]
        vis_shape = np.shape(vis)
        n_chan = vis_shape[1]
        residual = np.zeros(vis_shape, dtype=vis.dtype)
        for t in range(n_tim):
            first_row = bin_starts[t]
            for row in range(first_row, first_row + time_bin_counts[t]):
                p = int(antenna1[row])
                q = int(antenna2[row])
                gp = jones[t, p]
                gq = jones[t, q]
                for nu in range(n_chan):
                    if np.any(flag[row, nu]):
                        continue
                    # the kernel's return value is not used, so it is
                    # expected to write its result into residual[row, nu]
                    subtract_model(gp[nu], vis[row, nu], gq[nu],
                                   model[row, nu], residual[row, nu])
        return residual

    return _residual_vis_fn
Exemplo n.º 2
0
def corrupt_vis(time_bin_indices, time_bin_counts, antenna1, antenna2, jones,
                model):
    """Build the routine that applies gains to a model to produce visibilities.

    The arguments are used here only to select an implementation:
    ``check_type(jones, model, vis_type='model')`` picks a mode and
    ``jones_mul_factory`` returns the matching per-channel kernel. The
    returned function takes the same six arguments.

    Parameters
    ----------
    time_bin_indices : array of int
        First row of each time bin; one entry per time bin.
    time_bin_counts : array of int
        Number of rows in each time bin.
    antenna1, antenna2 : array of int
        Per-row antenna indices, used to index the second axis of ``jones``.
    jones : array
        Gains, indexed as ``jones[time_bin, antenna][chan]``.
    model : array
        Indexed as ``model[row, chan]``. Its third axis is dropped to form
        the output shape (presumably the direction axis -- confirm).

    Returns
    -------
    callable
        Function returning a new array of shape
        ``model.shape[:2] + model.shape[3:]`` with the dtype of ``model``.
    """

    mode = check_type(jones, model, vis_type='model')
    jones_mul = jones_mul_factory(mode)

    def _corrupt_vis_fn(time_bin_indices, time_bin_counts, antenna1, antenna2,
                        jones, model):
        # for dask arrays we need to adjust the chunks to start counting
        # from zero. The rebased starts go into a new array so that the
        # caller's time_bin_indices is not modified.
        bin_starts = time_bin_indices - time_bin_indices.min()
        n_tim = np.shape(bin_starts)[0]
        model_shape = np.shape(model)
        n_chan = model_shape[1]
        # output keeps every model axis except axis 2
        vis_shape = model_shape[:2] + model_shape[3:]
        vis = np.zeros(vis_shape, dtype=model.dtype)
        for t in range(n_tim):
            first_row = bin_starts[t]
            for row in range(first_row, first_row + time_bin_counts[t]):
                p = int(antenna1[row])
                q = int(antenna2[row])
                gp = jones[t, p]
                gq = jones[t, q]
                for nu in range(n_chan):
                    # the kernel's return value is not used, so it is
                    # expected to write its result into vis[row, nu]
                    jones_mul(gp[nu], model[row, nu], gq[nu], vis[row, nu])
        return vis

    return _corrupt_vis_fn
Exemplo n.º 3
0
def compute_jhr(time_bin_indices, time_bin_counts, antenna1, antenna2, jones,
                residual, model, flag):
    """Lazily apply ``np_compute_jhr`` to every chunk via ``blockwise``.

    Parameters
    ----------
    time_bin_indices, time_bin_counts, antenna1, antenna2 : chunked array
        One-dimensional inputs, all labelled with the ``row`` index.
    jones : chunked array
        Labelled ``(row, ant, chan, dir, corr)``; the output uses the same
        labels.
    residual, flag : chunked array
        Labelled ``(row, chan, corr)``.
    model : chunked array
        Labelled ``(row, chan, dir, corr)``.

    Returns
    -------
    The array produced by ``blockwise``, declared with ``model.dtype``.

    Raises
    ------
    NotImplementedError
        If ``check_type(jones, residual)`` is not ``DIAG_DIAG``.
    """

    mode = check_type(jones, residual)

    if mode != DIAG_DIAG:
        raise NotImplementedError("Only DIAG-DIAG case has been implemented")

    jones_shape = ('row', 'ant', 'chan', 'dir', 'corr')
    vis_shape = ('row', 'chan', 'corr')
    model_shape = ('row', 'chan', 'dir', 'corr')
    return blockwise(
        np_compute_jhr,
        jones_shape,
        time_bin_indices,
        ('row', ),
        time_bin_counts,
        ('row', ),
        antenna1,
        ('row', ),
        antenna2,
        ('row', ),
        jones,
        jones_shape,
        residual,
        vis_shape,
        model,
        model_shape,
        flag,
        vis_shape,
        # The output is labelled like jones, so its leading axis has to be
        # declared with the jones chunk sizes. Declaring the per-row chunk
        # sizes of antenna1 here would give the result a wrong shape.
        # Assumes np_compute_jhr returns a block shaped like its jones
        # chunk -- TODO confirm.
        adjust_chunks={"row": jones.chunks[0]},
        # NOTE(review): presumably the same workaround as described in
        # https://github.com/dask/dask/issues/5550 -- confirm.
        new_axes={"corr2": 2},
        dtype=model.dtype,
        align_arrays=False)
Exemplo n.º 4
0
def correct_vis(time_bin_indices, time_bin_counts, antenna1, antenna2, jones,
                vis, flag):
    """Build the routine that applies inverse gains to visibilities.

    The arguments are used here only to select an implementation:
    ``check_type(jones, vis)`` picks a mode and ``jones_inverse_mul_factory``
    returns the matching per-channel kernel. The returned function takes
    the same seven arguments.

    Parameters
    ----------
    time_bin_indices : array of int
        First row of each time bin.
    time_bin_counts : array of int
        Number of rows in each time bin.
    antenna1, antenna2 : array of int
        Per-row antenna indices, used to index the second axis of ``jones``.
    jones : array
        Gains whose first, third and fourth axes are read as time bin,
        channel and direction. Only direction 0 is used.
    vis : array
        Indexed as ``vis[row, chan]``; defines the output shape and dtype.
    flag : array
        Indexed as ``flag[row, chan]``; an entry is skipped when any
        element of it is set.

    Returns
    -------
    callable
        Function returning a new array like ``vis``. Flagged entries are
        left at zero. It raises ``ValueError`` when ``jones`` has more than
        one direction.
    """

    mode = check_type(jones, vis)
    jones_inverse_mul = jones_inverse_mul_factory(mode)

    def _correct_vis_fn(time_bin_indices, time_bin_counts, antenna1, antenna2,
                        jones, vis, flag):
        jones_shape = np.shape(jones)
        n_tim = jones_shape[0]
        n_chan = jones_shape[2]
        n_dir = jones_shape[3]
        if n_dir > 1:
            raise ValueError("Jones has n_dir > 1. Cannot correct "
                             "for direction dependent gains")
        # for dask arrays we need to adjust the chunks to start counting
        # from zero. The rebased starts go into a new array so that the
        # caller's time_bin_indices is not modified.
        bin_starts = time_bin_indices - time_bin_indices.min()
        corrected_vis = np.zeros_like(vis)
        for t in range(n_tim):
            first_row = bin_starts[t]
            for row in range(first_row, first_row + time_bin_counts[t]):
                p = int(antenna1[row])
                q = int(antenna2[row])
                gp = jones[t, p]
                gq = jones[t, q]
                for nu in range(n_chan):
                    if np.any(flag[row, nu]):
                        continue
                    # direction 0 is the only one (checked above); the
                    # kernel is expected to write into corrected_vis
                    jones_inverse_mul(gp[nu, 0], vis[row, nu], gq[nu, 0],
                                      corrected_vis[row, nu])
        return corrected_vis

    return _correct_vis_fn
Exemplo n.º 5
0
def jhj_and_jhr(time_bin_indices, time_bin_counts, antenna1,
                antenna2, jones, residual, model, flag):
    """Build the routine that accumulates ``jhj`` and ``jhr``.

    The arguments are used here only to select an implementation:
    ``check_type(jones, residual)`` picks a mode and ``jacobian_factory``
    returns the matching kernel. The returned function takes the same
    eight arguments.

    Parameters
    ----------
    time_bin_indices : array of int
        First row of each time bin.
    time_bin_counts : array of int
        Number of rows in each time bin.
    antenna1, antenna2 : array of int
        Per-row antenna indices. They must be valid indices into the
        second axis of ``jones``.
    jones : array
        Gains, indexed as ``jones[time_bin, antenna, chan, dir]``.
    residual : array
        Indexed as ``residual[row, chan]``.
    model : array
        Indexed as ``model[row, chan, dir]``.
    flag : array
        Indexed as ``flag[row, chan]``; an entry is skipped when any
        element of it is set.

    Returns
    -------
    callable
        Function returning ``(jhj, jhr)``, both shaped like ``jones``.
        ``jhj`` has the real dtype of ``jones`` and ``jhr`` its dtype.

    Raises
    ------
    NotImplementedError
        If the mode returned by ``check_type`` is truthy.
    """

    mode = check_type(jones, residual)

    # NOTE(review): any truthy mode is rejected, so this relies on the
    # DIAG-DIAG mode value being falsy -- confirm against check_type.
    if mode:
        raise NotImplementedError("Only DIAG-DIAG case has been implemented")

    jacobian = jacobian_factory(mode)

    def _jhj_and_jhr_fn(time_bin_indices, time_bin_counts, antenna1,
                        antenna2, jones, residual, model, flag):
        jones_shape = np.shape(jones)
        n_tim = jones_shape[0]
        n_chan = jones_shape[2]
        n_dir = jones_shape[3]

        jhr = np.zeros(jones.shape, dtype=jones.dtype)
        jhj = np.zeros(jones.shape, dtype=jones.real.dtype)
        # scratch buffer handed to the jacobian kernel on every call
        jac = np.zeros_like(jones[0, 0, 0, 0], dtype=jones.dtype)

        for t in range(n_tim):
            # Each row is visited once and credited to both of its
            # antennas. Rows are still processed in increasing order, so
            # every output element sums its terms in the same order as a
            # per-antenna scan of the bin would.
            for row in range(time_bin_indices[t],
                             time_bin_indices[t] + time_bin_counts[t]):
                p = antenna1[row]
                q = antenna2[row]
                for nu in range(n_chan):
                    if np.any(flag[row, nu]):
                        continue
                    for s in range(n_dir):
                        # contribution to antenna p (sign +1j)
                        jacobian(jones[t, p, nu, s], model[row, nu, s],
                                 jones[t, q, nu, s], 1.0j, jac)
                        jhj[t, p, nu, s] += (np.conj(jac) * jac).real
                        jhr[t, p, nu, s] += (np.conj(jac) *
                                             residual[row, nu])
                        # contribution to antenna q (sign -1j); a row with
                        # p == q is only counted once, with the +1j sign
                        if q != p:
                            jacobian(jones[t, p, nu, s], model[row, nu, s],
                                     jones[t, q, nu, s], -1.0j, jac)
                            jhj[t, q, nu, s] += (np.conj(jac) * jac).real
                            jhr[t, q, nu, s] += (np.conj(jac) *
                                                 residual[row, nu])
        return jhj, jhr
    return _jhj_and_jhr_fn
Exemplo n.º 6
0
def compute_and_corrupt_vis(time_bin_indices, time_bin_counts, antenna1,
                            antenna2, jones, model, uvw, freq, lm):
    """Lazily apply ``_compute_and_corrupt_vis_wrapper`` via ``blockwise``.

    Several axes must live in a single chunk; the first axis found to be
    split raises ``ValueError``. The index labels handed to ``blockwise``
    depend on the mode returned by
    ``check_type(jones, model, vis_type='model')``.

    Parameters
    ----------
    time_bin_indices, time_bin_counts, antenna1, antenna2 : chunked array
        One-dimensional inputs, all labelled with the ``row`` index.
    jones : chunked array
        Axes 1 and 3 (antenna, direction) must not be chunked.
    model : chunked array
        Axis 2 (direction) must not be chunked; supplies the output dtype.
    uvw : chunked array
        Labelled ``(row, three)``; axis 1 must not be chunked.
    freq : chunked array
        Labelled ``(chan,)``.
    lm : chunked array
        Labelled ``(row, dir, two)``; axes 1 and 2 must not be chunked.

    Returns
    -------
    The array produced by ``blockwise``.

    Raises
    ------
    ValueError
        If a protected axis is chunked or the mode is not recognised.
    """

    # (array, axis, message), tested in this order
    single_chunk_axes = (
        (jones, 1, "Cannot chunk jones over antenna"),
        (jones, 3, "Cannot chunk jones over direction"),
        (model, 2, "Cannot chunk model over direction"),
        (uvw, 1, "Cannot chunk uvw over last axis"),
        (lm, 1, "Cannot chunks lm over direction"),
        (lm, 2, "Cannot chunks lm over last axis"),
    )
    for array, axis, message in single_chunk_axes:
        if array.chunks[axis][0] != array.shape[axis]:
            raise ValueError(message)

    mode = check_type(jones, model, vis_type='model')

    # trailing correlation labels of the visibilities and of the gains
    if mode == DIAG_DIAG:
        vis_corr = ("corr1", )
        jones_corr = ("corr1", )
    elif mode == DIAG:
        vis_corr = ("corr1", "corr2")
        jones_corr = ("corr1", )
    elif mode == FULL:
        vis_corr = ("corr1", "corr2")
        jones_corr = ("corr1", "corr2")
    else:
        raise ValueError("Unknown mode argument of %s" % mode)

    out_shape = ("row", "chan") + vis_corr
    model_shape = ("row", "chan", "dir") + vis_corr
    jones_shape = ("row", "ant", "chan", "dir") + jones_corr

    row_only = ("row", )
    # alternating (array, index labels) pairs in the wrapper's argument order
    operands = (
        time_bin_indices, row_only,
        time_bin_counts, row_only,
        antenna1, row_only,
        antenna2, row_only,
        jones, jones_shape,
        model, model_shape,
        uvw, ("row", "three"),
        freq, ("chan", ),
        lm, ("row", "dir", "two"),
    )

    # the new_axes={"corr2": 2} is required because of a dask bug
    # see https://github.com/dask/dask/issues/5550
    return blockwise(_compute_and_corrupt_vis_wrapper,
                     out_shape,
                     *operands,
                     adjust_chunks={"row": antenna1.chunks[0]},
                     new_axes={"corr2": 2},
                     dtype=model.dtype,
                     align_arrays=False)
Exemplo n.º 7
0
def residual_vis(time_bin_indices, time_bin_counts, antenna1, antenna2, jones,
                 vis, flag, model):
    """Lazily apply ``_residual_vis_wrapper`` to every chunk via ``blockwise``.

    The antenna and direction axes must live in a single chunk; the first
    axis found to be split raises ``ValueError``. The index labels handed
    to ``blockwise`` depend on the mode returned by
    ``check_type(jones, vis)``.

    Parameters
    ----------
    time_bin_indices, time_bin_counts, antenna1, antenna2 : chunked array
        One-dimensional inputs, all labelled with the ``row`` index.
    jones : chunked array
        Axes 1 and 3 (antenna, direction) must not be chunked.
    vis, flag : chunked array
        Both carry the same index labels as the output; ``vis`` supplies
        the output dtype.
    model : chunked array
        Axis 2 (direction) must not be chunked.

    Returns
    -------
    The array produced by ``blockwise``.

    Raises
    ------
    ValueError
        If a protected axis is chunked or the mode is not recognised.
    """

    # (array, axis, message), tested in this order
    single_chunk_axes = (
        (jones, 1, "Cannot chunk jones over antenna"),
        (jones, 3, "Cannot chunk jones over direction"),
        (model, 2, "Cannot chunk model over direction"),
    )
    for array, axis, message in single_chunk_axes:
        if array.chunks[axis][0] != array.shape[axis]:
            raise ValueError(message)

    mode = check_type(jones, vis)

    # trailing correlation labels of the visibilities and of the gains
    if mode == DIAG_DIAG:
        vis_corr = ("corr1", )
        jones_corr = ("corr1", )
    elif mode == DIAG:
        vis_corr = ("corr1", "corr2")
        jones_corr = ("corr1", )
    elif mode == FULL:
        vis_corr = ("corr1", "corr2")
        jones_corr = ("corr1", "corr2")
    else:
        raise ValueError("Unknown mode argument of %s" % mode)

    out_shape = ("row", "chan") + vis_corr
    model_shape = ("row", "chan", "dir") + vis_corr
    jones_shape = ("row", "ant", "chan", "dir") + jones_corr

    row_only = ("row", )
    # alternating (array, index labels) pairs in the wrapper's argument order
    operands = (
        time_bin_indices, row_only,
        time_bin_counts, row_only,
        antenna1, row_only,
        antenna2, row_only,
        jones, jones_shape,
        vis, out_shape,
        flag, out_shape,
        model, model_shape,
    )

    # the new_axes={"corr2": 2} is required because of a dask bug
    # see https://github.com/dask/dask/issues/5550
    return blockwise(_residual_vis_wrapper,
                     out_shape,
                     *operands,
                     adjust_chunks={"row": antenna1.chunks[0]},
                     new_axes={"corr2": 2},
                     dtype=vis.dtype,
                     align_arrays=False)
Exemplo n.º 8
0
def compute_jhj_and_jhr(time_bin_indices, time_bin_counts, antenna1,
                        antenna2, jones, residual, model, flag):
    """Build the routine that accumulates ``jhj`` and ``jhr`` in one pass.

    The arguments are used here only to select an implementation:
    ``check_type(jones, residual)`` picks a mode and ``jacobian_factory``
    returns the matching kernel. The returned function takes the same
    eight arguments.

    Parameters
    ----------
    time_bin_indices : array of int
        First row of each time bin.
    time_bin_counts : array of int
        Number of rows in each time bin.
    antenna1, antenna2 : array of int
        Per-row antenna indices, used to index the second axis of ``jones``.
    jones : array
        Gains, indexed as ``jones[time_bin, antenna, chan][dir]``.
    residual : array
        Indexed as ``residual[row, chan]``.
    model : array
        Indexed as ``model[row, chan, dir]``.
    flag : array
        Indexed as ``flag[row, chan]``; an entry is skipped when any
        element of it is set.

    Returns
    -------
    callable
        Function returning ``(jhj, jhr)``, both shaped like ``jones``.
        ``jhj`` has the real dtype of ``jones`` and ``jhr`` its dtype.

    Raises
    ------
    NotImplementedError
        If the mode is not ``DIAG_DIAG``.
    """

    mode = check_type(jones, residual)
    if mode != DIAG_DIAG:
        raise NotImplementedError("Only DIAG-DIAG case has been implemented")

    jacobian = jacobian_factory(mode)

    def _jhj_and_jhr_fn(time_bin_indices, time_bin_counts, antenna1,
                        antenna2, jones, residual, model, flag):
        # for chunked dask arrays we need to adjust the chunks to
        # start counting from zero (see also map_blocks). The rebased
        # starts go into a new array so that the caller's
        # time_bin_indices is not modified.
        bin_starts = time_bin_indices - time_bin_indices.min()
        jones_shape = np.shape(jones)
        n_tim = jones_shape[0]
        n_chan = jones_shape[2]
        n_dir = jones_shape[3]

        # storage arrays
        jhr = np.zeros(jones.shape, dtype=jones.dtype)
        jhj = np.zeros(jones.shape, dtype=jones.real.dtype)
        # tmp array the shape of jones_corr
        jac = np.zeros_like(jones[0, 0, 0, 0], dtype=jones.dtype)
        for t in range(n_tim):
            first_row = bin_starts[t]
            for row in range(first_row, first_row + time_bin_counts[t]):
                p = antenna1[row]
                q = antenna2[row]
                for nu in range(n_chan):
                    if np.any(flag[row, nu]):
                        continue
                    gp = jones[t, p, nu]
                    gq = jones[t, q, nu]
                    for s in range(n_dir):
                        # for the derivative w.r.t. antenna p
                        jacobian(gp[s], model[row, nu, s], gq[s], 1.0j, jac)
                        jhj[t, p, nu, s] += (np.conj(jac) * jac).real
                        jhr[t, p, nu, s] += (np.conj(jac) * residual[row, nu])
                        # for the derivative w.r.t. antenna q
                        jacobian(gp[s], model[row, nu, s], gq[s], -1.0j, jac)
                        jhj[t, q, nu, s] += (np.conj(jac) * jac).real
                        jhr[t, q, nu, s] += (np.conj(jac) * residual[row, nu])
        return jhj, jhr
    return _jhj_and_jhr_fn
Exemplo n.º 9
0
def gauss_newton(time_bin_indices, time_bin_counts, antenna1,
                 antenna2, jones, vis, flag, model,
                 weight, tol=1e-4, maxiter=100):
    """Iteratively update the phases of ``jones``.

    The data are whitened with ``sqrt(weight)``. Each iteration then forms
    the residual and ``jhr``, adds ``0.5 * (jhr / jhj).real`` to the
    current phases and rebuilds ``jones`` as ``exp(1j * phase)``. ``jhj``
    is computed once, before the loop. Iteration stops when the largest
    absolute phase change is at most ``tol`` or after ``maxiter`` steps.

    Parameters
    ----------
    time_bin_indices, time_bin_counts, antenna1, antenna2 :
        Passed unchanged to ``compute_jhj``, ``residual_vis`` and
        ``compute_jhr``.
    jones : array
        Initial gains; only their phase (``np.angle``) is carried forward.
    vis : array
        Observed visibilities. Not modified.
    flag : array
        Flags, passed unchanged to the helpers.
    model : array
        Model visibilities. Not modified.
    weight : array
        Weights; must broadcast against ``vis`` and, after inserting a new
        third axis, against ``model`` (presumably shaped like ``vis`` --
        confirm).
    tol : float, optional
        Convergence threshold on the maximum absolute phase change.
    maxiter : int, optional
        Maximum number of iterations.

    Returns
    -------
    tuple
        ``(jones, jhj, jhr, k)`` where ``k`` is the number of iterations
        performed. If no iteration runs, ``jhr`` is an array of zeros
        shaped like ``jones``.

    Raises
    ------
    NotImplementedError
        If ``check_type(jones, vis)`` is not ``DIAG_DIAG``.
    """

    # whiten data on private copies so the caller's vis and model are not
    # overwritten; scaling the copies in place keeps their dtypes
    sqrtweights = np.sqrt(weight)
    vis = vis.copy()
    vis *= sqrtweights
    model = model.copy()
    model *= sqrtweights[:, :, None]

    mode = check_type(jones, vis)

    # can avoid recomputing JHJ in DIAG_DIAG mode
    if mode == DIAG_DIAG:
        jhj = compute_jhj(time_bin_indices, time_bin_counts,
                          antenna1, antenna2, jones, model, flag)
    else:
        raise NotImplementedError("Only DIAG_DIAG mode implemented")

    # Entries with jhj == 0 get no phase update. Dividing by zero there
    # would put NaN into the phases, and a NaN eps ends the loop at once
    # because ``nan > tol`` is False.
    has_info = jhj != 0
    safe_jhj = np.where(has_info, jhj, 1.0)

    # placeholder so the return below is defined when the loop never runs
    jhr = np.zeros_like(jones)

    eps = 1.0
    k = 0
    while eps > tol and k < maxiter:
        # keep track of old phases
        phases = np.angle(jones)

        # get residual TODO - we can avoid this in DIE case
        residual = residual_vis(time_bin_indices, time_bin_counts, antenna1,
                                antenna2, jones, vis, flag, model)

        jhr = compute_jhr(time_bin_indices, time_bin_counts,
                          antenna1, antenna2,
                          jones, residual, model, flag)

        # implement update
        update = np.where(has_info, (jhr / safe_jhj).real, 0.0)
        phases_new = phases + 0.5 * update
        jones = np.exp(1.0j * phases_new)

        # check convergence/iteration control
        eps = np.abs(phases_new - phases).max()
        k += 1

    return jones, jhj, jhr, k
Exemplo n.º 10
0
def compute_jhj(time_bin_indices, time_bin_counts, antenna1,
                antenna2, jones, model, flag):
    """Build the routine that accumulates ``jhj``.

    The arguments are used here only to select an implementation:
    ``check_type(jones, model, vis_type='model')`` picks a mode and
    ``jacobian_factory`` returns the matching kernel. The returned function
    takes the same seven arguments.

    Parameters
    ----------
    time_bin_indices : array of int
        First row of each time bin.
    time_bin_counts : array of int
        Number of rows in each time bin.
    antenna1, antenna2 : array of int
        Per-row antenna indices, used to index the second axis of ``jones``.
    jones : array
        Gains, indexed as ``jones[time_bin, antenna, chan][dir]``.
    model : array
        Indexed as ``model[row, chan, dir]``.
    flag : array
        Indexed as ``flag[row, chan]``; an entry is skipped when any
        element of it is set.

    Returns
    -------
    callable
        Function returning ``jhj``, shaped like ``jones`` and with its
        real dtype.
    """

    mode = check_type(jones, model, vis_type='model')

    jacobian = jacobian_factory(mode)

    def _compute_jhj_fn(time_bin_indices, time_bin_counts, antenna1,
                        antenna2, jones, model, flag):
        # for dask arrays we need to adjust the chunks to start counting
        # from zero. The rebased starts go into a new array so that the
        # caller's time_bin_indices is not modified.
        bin_starts = time_bin_indices - time_bin_indices.min()
        jones_shape = np.shape(jones)
        n_tim = jones_shape[0]
        n_chan = jones_shape[2]
        n_dir = jones_shape[3]

        jhj = np.zeros(jones.shape, dtype=jones.real.dtype)
        # tmp array the shape of jones_corr
        jac = np.zeros_like(jones[0, 0, 0, 0], dtype=jones.dtype)
        for t in range(n_tim):
            first_row = bin_starts[t]
            for row in range(first_row, first_row + time_bin_counts[t]):
                p = antenna1[row]
                q = antenna2[row]
                for nu in range(n_chan):
                    if np.any(flag[row, nu]):
                        continue
                    gp = jones[t, p, nu]
                    gq = jones[t, q, nu]
                    for s in range(n_dir):
                        # contribution to antenna p (sign +1j)
                        jacobian(gp[s], model[row, nu, s], gq[s], 1.0j, jac)
                        jhj[t, p, nu, s] += (jac.conjugate() * jac).real
                        # contribution to antenna q (sign -1j)
                        jacobian(gp[s], model[row, nu, s], gq[s], -1.0j, jac)
                        jhj[t, q, nu, s] += (jac.conjugate() * jac).real
        return jhj
    return _compute_jhj_fn