def residual_vis(time_bin_indices, time_bin_counts, antenna1, antenna2,
                 jones, vis, flag, model):
    """Return a kernel that computes residual visibilities.

    The outer function only inspects ``jones``/``vis`` (via ``check_type``)
    to select a ``subtract_model`` routine; the returned
    ``_residual_vis_fn`` does the actual work.
    NOTE(review): this looks like a numba ``generated_jit`` factory whose
    decorator is not visible here — confirm.

    The kernel walks time bins ``t``. The rows of bin ``t`` start at
    ``time_bin_indices[t]`` (relative to the smallest entry of
    ``time_bin_indices``) and span ``time_bin_counts[t]`` rows. Each row uses
    gains ``jones[t, antenna1[row]]`` and ``jones[t, antenna2[row]]``.
    Channels where any element of ``flag[row, chan]`` is set are skipped and
    stay zero in the output.

    Returns:
        A function returning an array with the shape and dtype of ``vis``
        filled by ``subtract_model`` (presumably ``vis`` minus the
        gain-corrupted ``model`` — semantics of ``subtract_model_factory``
        are not visible here).
    """
    mode = check_type(jones, vis)
    subtract_model = subtract_model_factory(mode)

    @wraps(residual_vis)
    def _residual_vis_fn(time_bin_indices, time_bin_counts, antenna1,
                         antenna2, jones, vis, flag, model):
        # Dask chunks carry absolute bin start indices; make them relative to
        # this block. A scalar offset is used instead of an in-place ``-=`` so
        # the caller's (possibly shared, view-backed) array is not modified.
        t0 = time_bin_indices.min()
        n_tim = np.shape(time_bin_indices)[0]
        vis_shape = np.shape(vis)
        n_chan = vis_shape[1]
        residual = np.zeros(vis_shape, dtype=vis.dtype)
        for t in range(n_tim):
            start = time_bin_indices[t] - t0
            for row in range(start, start + time_bin_counts[t]):
                p = int(antenna1[row])
                q = int(antenna2[row])
                gp = jones[t, p]
                gq = jones[t, q]
                for nu in range(n_chan):
                    if not np.any(flag[row, nu]):
                        subtract_model(gp[nu], vis[row, nu], gq[nu],
                                       model[row, nu], residual[row, nu])
        return residual

    return _residual_vis_fn
def corrupt_vis(time_bin_indices, time_bin_counts, antenna1, antenna2,
                jones, model):
    """Return a kernel that applies gains to a model.

    ``check_type(jones, model, vis_type='model')`` selects the ``jones_mul``
    routine. The returned ``_corrupt_vis_fn`` writes
    ``jones_mul(gp[chan], model[row, chan], gq[chan], out)`` for every row
    and channel, where ``gp``/``gq`` are ``jones[t, antenna1[row]]`` and
    ``jones[t, antenna2[row]]`` for the time bin ``t`` containing ``row``.
    NOTE(review): presumably wrapped by numba's ``generated_jit`` elsewhere.

    Returns:
        A function returning an array of ``model.dtype`` whose shape is
        ``model.shape`` with axis 2 removed (``model.shape[:2] +
        model.shape[3:]``).
    """
    mode = check_type(jones, model, vis_type='model')
    jones_mul = jones_mul_factory(mode)

    def _corrupt_vis_fn(time_bin_indices, time_bin_counts, antenna1,
                        antenna2, jones, model):
        # Make bin start indices relative to this (possibly dask) block
        # without mutating the caller's array in place.
        t0 = time_bin_indices.min()
        n_tim = np.shape(time_bin_indices)[0]
        model_shape = np.shape(model)
        # Axis 2 of ``model`` is not present in the output.
        vis_shape = model.shape[:2] + model.shape[3:]
        vis = np.zeros(vis_shape, dtype=model.dtype)
        n_chan = model_shape[1]
        for t in range(n_tim):
            start = time_bin_indices[t] - t0
            for row in range(start, start + time_bin_counts[t]):
                p = int(antenna1[row])
                q = int(antenna2[row])
                gp = jones[t, p]
                gq = jones[t, q]
                for nu in range(n_chan):
                    jones_mul(gp[nu], model[row, nu], gq[nu], vis[row, nu])
        return vis

    return _corrupt_vis_fn
def compute_jhr(time_bin_indices, time_bin_counts, antenna1, antenna2,
                jones, residual, model, flag):
    """Build a dask graph that evaluates ``np_compute_jhr`` per chunk.

    Only the DIAG-DIAG mode is supported. The output uses the same index
    layout as ``jones``, so its leading ("row") axis carries the time
    chunking of ``jones``, not the row chunking of ``antenna1``.

    Raises:
        NotImplementedError: if ``check_type`` reports a mode other than
            ``DIAG_DIAG``.
    """
    mode = check_type(jones, residual)
    if mode != DIAG_DIAG:
        raise NotImplementedError("Only DIAG-DIAG case has been implemented")

    # "row" labels the leading chunked axis of every argument: time bins for
    # jones/time_bin_*, data rows for antenna*/residual/model/flag.
    jones_shape = ('row', 'ant', 'chan', 'dir', 'corr')
    vis_shape = ('row', 'chan', 'corr')
    model_shape = ('row', 'chan', 'dir', 'corr')

    return blockwise(
        np_compute_jhr, jones_shape,
        time_bin_indices, ('row', ),
        time_bin_counts, ('row', ),
        antenna1, ('row', ),
        antenna2, ('row', ),
        jones, jones_shape,
        residual, vis_shape,
        model, model_shape,
        flag, vis_shape,
        # The output is jones-shaped: its leading axis has jones' time
        # chunks. (Previously antenna1's row chunks were used here, which
        # mislabels the output's chunk sizes.)
        adjust_chunks={"row": jones.chunks[0]},
        # NOTE(review): "corr2" appears in no index of this call; kept from
        # the original — confirm whether the dask workaround used by the
        # other wrappers (dask issue #5550) is actually needed here.
        new_axes={"corr2": 2},
        dtype=model.dtype,
        align_arrays=False)
def correct_vis(time_bin_indices, time_bin_counts, antenna1, antenna2,
                jones, vis, flag):
    """Return a kernel that removes direction-independent gains from ``vis``.

    ``check_type(jones, vis)`` selects a ``jones_inverse_mul`` routine. The
    returned ``_correct_vis_fn`` applies it per row and channel using
    ``jones[t, antenna1[row], chan, 0]`` and ``jones[t, antenna2[row], chan,
    0]``. Flagged channels (any element of ``flag[row, chan]`` set) stay zero.
    NOTE(review): presumably wrapped by numba's ``generated_jit`` elsewhere.

    The returned kernel raises ``ValueError`` if ``jones`` has more than one
    direction (axis 3), since only direction-independent gains can be
    corrected.
    """
    mode = check_type(jones, vis)
    jones_inverse_mul = jones_inverse_mul_factory(mode)

    def _correct_vis_fn(time_bin_indices, time_bin_counts, antenna1,
                        antenna2, jones, vis, flag):
        # Make bin start indices relative to this (possibly dask) block
        # without mutating the caller's array in place.
        t0 = time_bin_indices.min()
        jones_shape = np.shape(jones)
        n_tim = jones_shape[0]
        n_dir = jones_shape[3]
        if n_dir > 1:
            raise ValueError("Jones has n_dir > 1. Cannot correct "
                             "for direction dependent gains")
        n_chan = jones_shape[2]
        corrected_vis = np.zeros_like(vis, dtype=vis.dtype)
        for t in range(n_tim):
            start = time_bin_indices[t] - t0
            for row in range(start, start + time_bin_counts[t]):
                p = int(antenna1[row])
                q = int(antenna2[row])
                gp = jones[t, p]
                gq = jones[t, q]
                for nu in range(n_chan):
                    if not np.any(flag[row, nu]):
                        jones_inverse_mul(gp[nu, 0], vis[row, nu],
                                          gq[nu, 0], corrected_vis[row, nu])
        return corrected_vis

    return _correct_vis_fn
def jhj_and_jhr(time_bin_indices, time_bin_counts, antenna1, antenna2,
                jones, residual, model, flag):
    """Return a kernel accumulating JHJ and JHR per (time, antenna, chan, dir).

    Only the DIAG-DIAG mode is supported. For every unflagged row/channel
    the kernel calls ``jacobian`` with sign ``1j`` for ``antenna1[row]`` and
    ``-1j`` for ``antenna2[row]``, then accumulates ``|jac|**2`` into ``jhj``
    and ``conj(jac) * residual[row, chan]`` into ``jhr`` for that antenna.
    A row with ``antenna1 == antenna2`` contributes only once, with sign
    ``1j``.

    ``time_bin_indices`` are used as given (no per-chunk rebasing).

    Raises:
        NotImplementedError: if ``check_type`` reports a mode other than
            ``DIAG_DIAG``.
    """
    mode = check_type(jones, residual)
    # Explicit comparison instead of relying on DIAG_DIAG being falsy.
    if mode != DIAG_DIAG:
        raise NotImplementedError("Only DIAG-DIAG case has been implemented")
    jacobian = jacobian_factory(mode)

    def _jhj_and_jhr_fn(time_bin_indices, time_bin_counts, antenna1,
                        antenna2, jones, residual, model, flag):
        jones_shape = np.shape(jones)
        n_tim = jones_shape[0]
        n_chan = jones_shape[2]
        n_dir = jones_shape[3]
        # Scratch buffer shaped like a single jones correlation entry.
        tmp_out_array = np.zeros_like(jones[0, 0, 0, 0], dtype=jones.dtype)
        jhr = np.zeros(jones.shape, dtype=jones.dtype)
        jhj = np.zeros(jones.shape, dtype=jones.real.dtype)
        # Single pass over rows. Each output element still receives its
        # contributions in row order, so the sums match a per-antenna scan.
        for t in range(n_tim):
            start = time_bin_indices[t]
            for row in range(start, start + time_bin_counts[t]):
                p = antenna1[row]
                q = antenna2[row]
                for nu in range(n_chan):
                    if np.any(flag[row, nu]):
                        continue
                    for s in range(n_dir):
                        # derivative w.r.t. antenna p
                        jacobian(jones[t, p, nu, s], model[row, nu, s],
                                 jones[t, q, nu, s], 1.0j, tmp_out_array)
                        jhj[t, p, nu, s] += (np.conj(tmp_out_array) *
                                             tmp_out_array).real
                        jhr[t, p, nu, s] += (np.conj(tmp_out_array) *
                                             residual[row, nu])
                        if q == p:
                            # autocorrelation: counted once, as antenna p
                            continue
                        # derivative w.r.t. antenna q
                        jacobian(jones[t, p, nu, s], model[row, nu, s],
                                 jones[t, q, nu, s], -1.0j, tmp_out_array)
                        jhj[t, q, nu, s] += (np.conj(tmp_out_array) *
                                             tmp_out_array).real
                        jhr[t, q, nu, s] += (np.conj(tmp_out_array) *
                                             residual[row, nu])
        return jhj, jhr

    return _jhj_and_jhr_fn
def compute_and_corrupt_vis(time_bin_indices, time_bin_counts, antenna1,
                            antenna2, jones, model, uvw, freq, lm):
    """Build a dask graph running ``_compute_and_corrupt_vis_wrapper``.

    The per-block kernel needs the full antenna and direction axes of
    ``jones``, the direction axis of ``model``, and the trailing axes of
    ``uvw`` and ``lm``. Chunking any of those is rejected up front.

    The output index layout depends on the mode reported by
    ``check_type(jones, model, vis_type='model')``. Its "row" axis takes the
    row chunks of ``antenna1``.

    Raises:
        ValueError: if a required axis is chunked, or the mode is unknown.
    """
    # (array, axis, message) for axes that must be a single chunk; checked
    # in this order so the first offending axis is reported.
    unchunkable = (
        (jones, 1, "Cannot chunk jones over antenna"),
        (jones, 3, "Cannot chunk jones over direction"),
        (model, 2, "Cannot chunk model over direction"),
        (uvw, 1, "Cannot chunk uvw over last axis"),
        (lm, 1, "Cannot chunk lm over direction"),
        (lm, 2, "Cannot chunk lm over last axis"),
    )
    for arr, axis, msg in unchunkable:
        if arr.chunks[axis][0] != arr.shape[axis]:
            raise ValueError(msg)

    mode = check_type(jones, model, vis_type='model')
    if mode == DIAG_DIAG:
        out_shape = ("row", "chan", "corr1")
        model_shape = ("row", "chan", "dir", "corr1")
        jones_shape = ("row", "ant", "chan", "dir", "corr1")
    elif mode == DIAG:
        out_shape = ("row", "chan", "corr1", "corr2")
        model_shape = ("row", "chan", "dir", "corr1", "corr2")
        jones_shape = ("row", "ant", "chan", "dir", "corr1")
    elif mode == FULL:
        out_shape = ("row", "chan", "corr1", "corr2")
        model_shape = ("row", "chan", "dir", "corr1", "corr2")
        jones_shape = ("row", "ant", "chan", "dir", "corr1", "corr2")
    else:
        raise ValueError("Unknown mode argument of %s" % mode)

    # the new_axes={"corr2": 2} is required because of a dask bug
    # see https://github.com/dask/dask/issues/5550
    return blockwise(_compute_and_corrupt_vis_wrapper, out_shape,
                     time_bin_indices, ("row", ),
                     time_bin_counts, ("row", ),
                     antenna1, ("row", ),
                     antenna2, ("row", ),
                     jones, jones_shape,
                     model, model_shape,
                     uvw, ("row", "three"),
                     freq, ("chan", ),
                     lm, ("row", "dir", "two"),
                     adjust_chunks={"row": antenna1.chunks[0]},
                     new_axes={"corr2": 2},
                     dtype=model.dtype,
                     align_arrays=False)
def residual_vis(time_bin_indices, time_bin_counts, antenna1, antenna2,
                 jones, vis, flag, model):
    """Build a dask graph running ``_residual_vis_wrapper`` per row chunk.

    The antenna and direction axes of ``jones`` and the direction axis of
    ``model`` must each be a single chunk. ``vis`` and ``flag`` share the
    output's index layout, which is chosen from ``check_type(jones, vis)``.
    The output's "row" axis takes the row chunks of ``antenna1``.

    Raises:
        ValueError: if a required axis is chunked, or the mode is unknown.
    """
    def first_chunk_is_whole(arr, axis):
        # True when the first chunk along ``axis`` spans the entire axis.
        return arr.chunks[axis][0] == arr.shape[axis]

    if not first_chunk_is_whole(jones, 1):
        raise ValueError("Cannot chunk jones over antenna")
    if not first_chunk_is_whole(jones, 3):
        raise ValueError("Cannot chunk jones over direction")
    if not first_chunk_is_whole(model, 2):
        raise ValueError("Cannot chunk model over direction")

    mode = check_type(jones, vis)
    if mode == DIAG_DIAG:
        vis_ind, model_ind, jones_ind = (
            ("row", "chan", "corr1"),
            ("row", "chan", "dir", "corr1"),
            ("row", "ant", "chan", "dir", "corr1"),
        )
    elif mode == DIAG:
        vis_ind, model_ind, jones_ind = (
            ("row", "chan", "corr1", "corr2"),
            ("row", "chan", "dir", "corr1", "corr2"),
            ("row", "ant", "chan", "dir", "corr1"),
        )
    elif mode == FULL:
        vis_ind, model_ind, jones_ind = (
            ("row", "chan", "corr1", "corr2"),
            ("row", "chan", "dir", "corr1", "corr2"),
            ("row", "ant", "chan", "dir", "corr1", "corr2"),
        )
    else:
        raise ValueError("Unknown mode argument of %s" % mode)

    # new_axes={"corr2": 2} works around a dask bug:
    # https://github.com/dask/dask/issues/5550
    return blockwise(
        _residual_vis_wrapper, vis_ind,
        time_bin_indices, ("row",),
        time_bin_counts, ("row",),
        antenna1, ("row",),
        antenna2, ("row",),
        jones, jones_ind,
        vis, vis_ind,
        flag, vis_ind,
        model, model_ind,
        adjust_chunks={"row": antenna1.chunks[0]},
        new_axes={"corr2": 2},
        dtype=vis.dtype,
        align_arrays=False,
    )
def compute_jhj_and_jhr(time_bin_indices, time_bin_counts, antenna1,
                        antenna2, jones, residual, model, flag):
    """Return a kernel accumulating JHJ and JHR per (time, antenna, chan, dir).

    Only the DIAG-DIAG mode is supported. For every unflagged row/channel
    and direction, ``jacobian`` is evaluated with sign ``1j`` (antenna p =
    ``antenna1[row]``) and ``-1j`` (antenna q = ``antenna2[row]``). For each,
    ``|jac|**2`` is added to ``jhj`` and ``conj(jac) * residual[row, chan]``
    is added to ``jhr`` at that antenna.
    NOTE(review): presumably wrapped by numba's ``generated_jit`` elsewhere.

    Returns:
        A function returning ``(jhj, jhr)``, both shaped like ``jones``;
        ``jhj`` has the real dtype of ``jones``.

    Raises:
        NotImplementedError: if ``check_type`` reports a mode other than
            ``DIAG_DIAG``.
    """
    mode = check_type(jones, residual)
    if mode != DIAG_DIAG:
        raise NotImplementedError("Only DIAG-DIAG case has been implemented")
    jacobian = jacobian_factory(mode)

    def _jhj_and_jhr_fn(time_bin_indices, time_bin_counts, antenna1,
                        antenna2, jones, residual, model, flag):
        # For chunked dask arrays the bin start indices are absolute; make
        # them relative to this block via a scalar offset rather than an
        # in-place ``-=``, which would overwrite the caller's (possibly
        # shared) array.
        t0 = time_bin_indices.min()
        jones_shape = np.shape(jones)
        n_tim = jones_shape[0]
        n_chan = jones_shape[2]
        n_dir = jones_shape[3]
        # storage arrays
        jhr = np.zeros(jones.shape, dtype=jones.dtype)
        jhj = np.zeros(jones.shape, dtype=jones.real.dtype)
        # scratch array shaped like a single jones correlation entry
        jac = np.zeros_like(jones[0, 0, 0, 0], dtype=jones.dtype)
        for t in range(n_tim):
            start = time_bin_indices[t] - t0
            for row in range(start, start + time_bin_counts[t]):
                p = antenna1[row]
                q = antenna2[row]
                for nu in range(n_chan):
                    if np.any(flag[row, nu]):
                        continue
                    gp = jones[t, p, nu]
                    gq = jones[t, q, nu]
                    for s in range(n_dir):
                        # derivative w.r.t. antenna p
                        jacobian(gp[s], model[row, nu, s], gq[s], 1.0j, jac)
                        jhj[t, p, nu, s] += (np.conj(jac) * jac).real
                        jhr[t, p, nu, s] += (np.conj(jac) * residual[row, nu])
                        # derivative w.r.t. antenna q
                        jacobian(gp[s], model[row, nu, s], gq[s], -1.0j, jac)
                        jhj[t, q, nu, s] += (np.conj(jac) * jac).real
                        jhr[t, q, nu, s] += (np.conj(jac) * residual[row, nu])
        return jhj, jhr

    return _jhj_and_jhr_fn
def gauss_newton(time_bin_indices, time_bin_counts, antenna1, antenna2,
                 jones, vis, flag, model, weight, tol=1e-4, maxiter=100):
    """Iteratively update gain phases with a Gauss-Newton style step.

    ``vis`` and ``model`` are whitened by ``sqrt(weight)``; working copies
    are used, so the caller's arrays are left untouched. JHJ is computed once
    (DIAG_DIAG mode only). Each iteration then computes the residual and JHR,
    updates ``phases += 0.5 * Re(jhr / jhj)``, and rebuilds
    ``jones = exp(1j * phases)``; gain amplitudes are therefore discarded.
    Iteration stops when the largest phase change is ``<= tol`` or after
    ``maxiter`` iterations.

    Where ``jhj == 0`` (e.g. everything contributing to that entry is
    flagged), the phase step is zero instead of the NaN that ``0/0`` would
    give.

    Returns:
        ``(jones, jhj, jhr, k)``: the final gains, the JHJ computed up front,
        the JHR from the last iteration (zeros shaped like ``jones`` if no
        iteration ran), and the number of iterations performed.

    Raises:
        NotImplementedError: if ``check_type`` reports a mode other than
            ``DIAG_DIAG``.
    """
    # Whiten working copies: in-place ``*=`` on the arguments would re-weight
    # the caller's arrays and compound across repeated calls. ``copy`` +
    # ``*=`` keeps the original dtypes.
    sqrtweights = np.sqrt(weight)
    vis = vis.copy()
    vis *= sqrtweights
    model = model.copy()
    model *= sqrtweights[:, :, None]

    mode = check_type(jones, vis)
    # can avoid recomputing JHJ in DIAG_DIAG mode
    if mode == DIAG_DIAG:
        jhj = compute_jhj(time_bin_indices, time_bin_counts, antenna1,
                          antenna2, jones, model, flag)
    else:
        raise NotImplementedError("Only DIAG_DIAG mode implemented")

    # Entries with no data contribute no step (avoids 0/0 -> NaN gains).
    observed = jhj > 0
    safe_jhj = np.where(observed, jhj, 1)

    # Defined up front so the return works even if the loop never runs.
    jhr = np.zeros_like(jones)
    eps = 1.0
    k = 0
    while eps > tol and k < maxiter:
        # keep track of old phases
        phases = np.angle(jones)

        # get residual TODO - we can avoid this in DIE case
        residual = residual_vis(time_bin_indices, time_bin_counts, antenna1,
                                antenna2, jones, vis, flag, model)

        jhr = compute_jhr(time_bin_indices, time_bin_counts, antenna1,
                          antenna2, jones, residual, model, flag)

        # implement update
        step = np.where(observed, (jhr / safe_jhj).real, 0)
        phases_new = phases + 0.5 * step
        jones = np.exp(1.0j * phases_new)

        # check convergence/iteration control
        eps = np.abs(phases_new - phases).max()
        k += 1

    return jones, jhj, jhr, k
def compute_jhj(time_bin_indices, time_bin_counts, antenna1, antenna2,
                jones, model, flag):
    """Return a kernel accumulating JHJ per (time, antenna, chan, dir).

    ``check_type(jones, model, vis_type='model')`` selects the ``jacobian``
    routine. For every unflagged row/channel and direction, ``jacobian`` is
    evaluated with sign ``1j`` for antenna p = ``antenna1[row]`` and ``-1j``
    for antenna q = ``antenna2[row]``. ``|jac|**2`` is added to ``jhj`` at
    the corresponding antenna.
    NOTE(review): presumably wrapped by numba's ``generated_jit`` elsewhere.

    Returns:
        A function returning ``jhj``, shaped like ``jones`` with its real
        dtype.
    """
    mode = check_type(jones, model, vis_type='model')
    jacobian = jacobian_factory(mode)

    def _compute_jhj_fn(time_bin_indices, time_bin_counts, antenna1,
                        antenna2, jones, model, flag):
        # Make bin start indices relative to this (possibly dask) block
        # via a scalar offset; an in-place ``-=`` would overwrite the
        # caller's (possibly shared) array.
        t0 = time_bin_indices.min()
        jones_shape = np.shape(jones)
        n_tim = jones_shape[0]
        n_chan = jones_shape[2]
        n_dir = jones_shape[3]
        jhj = np.zeros(jones.shape, dtype=jones.real.dtype)
        # scratch array shaped like a single jones correlation entry
        jac = np.zeros_like(jones[0, 0, 0, 0], dtype=jones.dtype)
        for t in range(n_tim):
            start = time_bin_indices[t] - t0
            for row in range(start, start + time_bin_counts[t]):
                p = antenna1[row]
                q = antenna2[row]
                for nu in range(n_chan):
                    if np.any(flag[row, nu]):
                        continue
                    gp = jones[t, p, nu]
                    gq = jones[t, q, nu]
                    for s in range(n_dir):
                        jacobian(gp[s], model[row, nu, s], gq[s], 1.0j, jac)
                        jhj[t, p, nu, s] += (jac.conjugate() * jac).real
                        jacobian(gp[s], model[row, nu, s], gq[s], -1.0j, jac)
                        jhj[t, q, nu, s] += (jac.conjugate() * jac).real
        return jhj

    return _compute_jhj_fn