def weight_spectrum_normaliser_factory(present):
    """
    Return a jitted function that copies one weight-spectrum entry.

    When ``present`` is truthy the generated function copies
    ``wt_spec_in[row, chan, corr]`` into ``wt_spec_out[row, chan, corr]``
    (no scaling is applied despite the factory name). Otherwise the
    generated function does nothing.
    """
    def copy_entry(wt_spec_out, wt_spec_in, row, chan, corr):
        wt_spec_out[row, chan, corr] = wt_spec_in[row, chan, corr]

    def noop(wt_spec_out, wt_spec_in, row, chan, corr):
        pass

    return njit(nogil=True, cache=True)(copy_entry if present else noop)
def is_chan_flagged_factory(present):
    """
    Return a jitted predicate telling whether a channel flag is set.

    If ``present`` is truthy the predicate returns ``flag[r, f, c]``;
    otherwise ``flag`` is ignored and ``False`` is always returned.
    """
    def read_flag(flag, r, f, c):
        return flag[r, f, c]

    def never_flagged(flag, r, f, c):
        return False

    impl = read_flag if present else never_flagged
    return njit(nogil=True, cache=True)(impl)
def matching_flag_factory(present):
    """
    Return a jitted predicate comparing an input and an output row flag.

    With ``present`` truthy it returns
    ``flag_row[ri] == out_flag_row[ro]``; otherwise both arrays are
    ignored and ``True`` is returned.
    """
    if not present:
        def impl(flag_row, ri, out_flag_row, ro):
            return True
    else:
        def impl(flag_row, ri, out_flag_row, ro):
            return flag_row[ri] == out_flag_row[ro]

    return njit(nogil=True, cache=True)(impl)
def promote_base_factory(is_base_list):
    """
    Return a jitted function expanding ``base`` to ``npol`` entries.

    If ``is_base_list`` is truthy, ``base`` is treated as a list. A new
    list is returned, padded by repeating ``base[-1]`` until it holds
    ``npol`` items. A list already ``npol`` or more long gets no padding
    and is not truncated. Otherwise ``base`` is replicated ``npol`` times
    into a new list.
    """
    def pad_list(base, npol):
        missing = npol - len(base)
        return base + [base[-1]] * missing

    def replicate_scalar(base, npol):
        return [base] * npol

    chosen = pad_list if is_base_list else replicate_scalar
    return njit(nogil=True, cache=True)(chosen)
def add_pol_dim_factory(have_pol_dim):
    """
    Return a jitted function that guarantees a trailing polarisation axis.

    If ``have_pol_dim`` is truthy the array is returned unchanged;
    otherwise it is reshaped with a length-1 axis appended to its shape.
    """
    if have_pol_dim:
        def impl(array):
            return array
        return njit(nogil=True, cache=True)(impl)

    def impl(array):
        # shape (...) -> (..., 1)
        return array.reshape(array.shape + (1,))

    return njit(nogil=True, cache=True)(impl)
def set_flagged_factory(present):
    """
    Return a jitted function that marks a single flag entry.

    With ``present`` truthy it sets ``flag[r, f, c] = 1``; otherwise it
    leaves ``flag`` untouched.
    """
    def mark(flag, r, f, c):
        flag[r, f, c] = 1

    def skip(flag, r, f, c):
        pass

    impl = mark if present else skip
    return njit(nogil=True, cache=True)(impl)
def jones_inverse_mul_factory(mode):
    """
    Return a jitted function computing ``inv(A1) @ B @ inv(A2)^H``.

    The generated function has signature ``(a1j, blj, a2j, out)`` and
    writes its result into ``out`` in place.

    Parameters
    ----------
    mode : ``DIAG_DIAG``, ``DIAG`` or ``FULL``
        ``DIAG_DIAG``: element-wise ``out = blj / (a1j * conj(a2j))``.
        ``DIAG``: ``a1j``/``a2j`` are indexed as length-2 diagonals and
        ``blj``/``out`` as 2x2 matrices.
        ``FULL``: ``a1j``, ``blj``, ``a2j`` and ``out`` are all indexed as
        2x2 matrices.

    Raises
    ------
    ValueError
        If ``mode`` is not one of the supported values.
    """
    if mode == DIAG_DIAG:
        def jones_inverse_mul(a1j, blj, a2j, out):
            out[...] = blj / (a1j * np.conj(a2j))
    elif mode == DIAG:
        def jones_inverse_mul(a1j, blj, a2j, out):
            out[0, 0] = blj[0, 0] / (a1j[0] * np.conj(a2j[0]))
            out[0, 1] = blj[0, 1] / (a1j[0] * np.conj(a2j[1]))
            out[1, 0] = blj[1, 0] / (a1j[1] * np.conj(a2j[0]))
            out[1, 1] = blj[1, 1] / (a1j[1] * np.conj(a2j[1]))
    elif mode == FULL:
        def jones_inverse_mul(a1j, blj, a2j, out):
            # inv([[a, b], [c, d]]) = [[d, -b], [-c, a]] / det
            deta1j = a1j[0, 0] * a1j[1, 1] - a1j[0, 1] * a1j[1, 0]
            a00 = a1j[1, 1] / deta1j
            a01 = -a1j[0, 1] / deta1j
            a10 = -a1j[1, 0] / deta1j
            a11 = a1j[0, 0] / deta1j

            # conjugate transpose of inv(a2j): element [i, j] is
            # conj(inv(a2j)[j, i])
            deta2j = a2j[0, 0] * a2j[1, 1] - a2j[0, 1] * a2j[1, 0]
            b00 = np.conj(a2j[1, 1] / deta2j)
            b01 = np.conj(-a2j[1, 0] / deta2j)
            b10 = np.conj(-a2j[0, 1] / deta2j)
            b11 = np.conj(a2j[0, 0] / deta2j)

            # first row of inv(a1j) @ blj, then right-multiply by b
            m00 = a00 * blj[0, 0] + a01 * blj[1, 0]
            m01 = a00 * blj[0, 1] + a01 * blj[1, 1]
            out[0, 0] = m00 * b00 + m01 * b10
            out[0, 1] = m00 * b01 + m01 * b11

            # second row of inv(a1j) @ blj, then right-multiply by b
            m10 = a10 * blj[0, 0] + a11 * blj[1, 0]
            m11 = a10 * blj[0, 1] + a11 * blj[1, 1]
            out[1, 0] = m10 * b00 + m11 * b10
            out[1, 1] = m10 * b01 + m11 * b11
    else:
        raise ValueError("Unknown jones mode: {!r}".format(mode))

    return njit(nogil=True)(jones_inverse_mul)
def subtract_model_factory(mode):
    """
    Return a jitted function computing ``blj - sum_s A1[s] M[s] A2[s]^H``.

    The generated function has signature ``(a1j, blj, a2j, model, out)``.
    It first copies ``blj`` into ``out``, then subtracts one model term
    per direction ``s`` in ``range(model.shape[0])``.

    Parameters
    ----------
    mode : ``DIAG_DIAG``, ``DIAG`` or ``FULL``
        ``DIAG_DIAG``: element-wise ``a1j[s] * model[s] * conj(a2j[s])``.
        ``DIAG``: ``a1j[s]``/``a2j[s]`` are indexed as length-2 diagonals
        and ``model[s]``/``out`` as 2x2 matrices.
        ``FULL``: all operands are indexed as 2x2 matrices per direction.

    Raises
    ------
    ValueError
        If ``mode`` is not one of the supported values.
    """
    if mode == DIAG_DIAG:
        def subtract_model(a1j, blj, a2j, model, out):
            n_dir = np.shape(model)[0]
            out[...] = blj
            for s in range(n_dir):
                out -= a1j[s] * model[s] * np.conj(a2j[s])
    elif mode == DIAG:
        def subtract_model(a1j, blj, a2j, model, out):
            n_dir = np.shape(model)[0]
            out[...] = blj
            for s in range(n_dir):
                out[0, 0] -= a1j[s, 0] * model[s, 0, 0] * np.conj(a2j[s, 0])
                out[0, 1] -= a1j[s, 0] * model[s, 0, 1] * np.conj(a2j[s, 1])
                out[1, 0] -= a1j[s, 1] * model[s, 1, 0] * np.conj(a2j[s, 0])
                out[1, 1] -= a1j[s, 1] * model[s, 1, 1] * np.conj(a2j[s, 1])
    elif mode == FULL:
        def subtract_model(a1j, blj, a2j, model, out):
            n_dir = np.shape(model)[0]
            # Start from the data once; every direction is then subtracted
            # (previously each direction overwrote out, keeping only the
            # last one).
            out[...] = blj
            for s in range(n_dir):
                # rows of a1j[s] @ model[s]
                m00 = (a1j[s, 0, 0] * model[s, 0, 0] +
                       a1j[s, 0, 1] * model[s, 1, 0])
                m01 = (a1j[s, 0, 0] * model[s, 0, 1] +
                       a1j[s, 0, 1] * model[s, 1, 1])
                m10 = (a1j[s, 1, 0] * model[s, 0, 0] +
                       a1j[s, 1, 1] * model[s, 1, 0])
                m11 = (a1j[s, 1, 0] * model[s, 0, 1] +
                       a1j[s, 1, 1] * model[s, 1, 1])

                # conjugate transpose of a2j[s]
                tmp = np.conj(a2j[s].T)

                # subtract the whole product term (all four partial
                # products, not just the first one)
                out[0, 0] -= m00 * tmp[0, 0] + m01 * tmp[1, 0]
                out[0, 1] -= m00 * tmp[0, 1] + m01 * tmp[1, 1]
                out[1, 0] -= m10 * tmp[0, 0] + m11 * tmp[1, 0]
                out[1, 1] -= m10 * tmp[0, 1] + m11 * tmp[1, 1]
    else:
        raise ValueError("Unknown jones mode: {!r}".format(mode))

    return njit(nogil=True)(subtract_model)
def weight_sum_output_factory(present):
    """
    Returns function producing vis weight sum if vis present.

    With ``present`` truthy the generated function returns
    ``np.zeros(shape)`` using the dtype of ``array.real`` (the real
    counterpart of a complex ``array``). Otherwise it returns ``None``.
    """
    if not present:
        def impl(shape, array):
            pass
        return njit(nogil=True, cache=True)(impl)

    def impl(shape, array):
        return np.zeros(shape, dtype=array.real.dtype)

    return njit(nogil=True, cache=True)(impl)
def chan_output_factory(present):
    """
    Returns function producing outputs if the array is present.

    With ``present`` truthy the generated function returns
    ``np.zeros(shape, dtype=array.dtype)``; otherwise it returns ``None``.
    """
    def allocate(shape, array):
        return np.zeros(shape, dtype=array.dtype)

    def no_output(shape, array):
        pass

    return njit(nogil=True, cache=True)(allocate if present else no_output)
def normaliser_factory(present):
    """
    Returns function for normalising data in a bin.

    With ``present`` truthy the generated function divides ``data[row]``
    by ``bin_size`` in place; otherwise it does nothing.
    """
    if not present:
        def impl(data, row, bin_size):
            pass
        return njit(nogil=True, cache=True)(impl)

    def impl(data, row, bin_size):
        data[row] /= bin_size

    return njit(nogil=True, cache=True)(impl)
def chan_add_factory(present):
    """
    Returns function for adding data to a bin.

    With ``present`` truthy the generated function performs
    ``output[orow, ochan, corr] += input[irow, ichan, corr]``;
    otherwise it does nothing.
    """
    if not present:
        def impl(output, input, orow, ochan, irow, ichan, corr):
            pass
    else:
        def impl(output, input, orow, ochan, irow, ichan, corr):
            # ``input`` shadows the builtin, but it is part of the
            # generated function's signature so the name is kept.
            output[orow, ochan, corr] += input[irow, ichan, corr]

    return njit(nogil=True, cache=True)(impl)
def chan_normaliser_factory(present):
    """
    Returns function normalising channel data in a bin.

    With ``present`` truthy the generated function stores
    ``data_in[row, chan, corr] / bin_size`` into
    ``data_out[row, chan, corr]``; otherwise it does nothing.
    """
    def normalise(data_out, data_in, row, chan, corr, bin_size):
        value = data_in[row, chan, corr]
        data_out[row, chan, corr] = value / bin_size

    def noop(data_out, data_in, row, chan, corr, bin_size):
        pass

    return njit(nogil=True, cache=True)(normalise if present else noop)
def add_coh_factory(have_bvis):
    """
    Return a jitted function that optionally adds base visibilities.

    With ``have_bvis`` truthy the generated function performs
    ``out += base_vis`` in place; otherwise it does nothing.
    """
    def accumulate(base_vis, out):
        out += base_vis

    def noop(base_vis, out):
        pass

    return njit(nogil=True)(accumulate if have_bvis else noop)
def output_factory(have_flag_row):
    """
    Return a jitted allocator for an output row-flag array.

    With ``have_flag_row`` truthy the generated function returns
    ``np.zeros(rows)`` with the dtype of ``flag_row``; otherwise it
    returns ``None``.
    """
    def allocate(rows, flag_row):
        return np.zeros(rows, dtype=flag_row.dtype)

    def no_output(rows, flag_row):
        return None

    impl = allocate if have_flag_row else no_output
    return njit(nogil=True, cache=True)(impl)
def output_factory(present):
    """
    Returns function creating an output if present.

    With ``present`` truthy the generated function returns a zero array
    shaped ``(rows,) + array.shape[1:]`` with ``array``'s dtype, so only
    the leading axis is resized. Otherwise it returns ``None``.
    """
    if not present:
        def impl(rows, array):
            return None
        return njit(nogil=True, cache=True)(impl)

    def impl(rows, array):
        out_shape = (rows,) + array.shape[1:]
        return np.zeros(out_shape, array.dtype)

    return njit(nogil=True, cache=True)(impl)
def sum_coherencies_factory(have_ddes, have_coh, jones_type):
    """
    Factory function generating a function that sums coherencies.

    The generated function has signature
    ``(time, ant1, ant2, a1j, blj, a2j, tmin, out)``:

    - ``have_ddes`` truthy: it iterates over every source ``s``
      (``a1j.shape[0]``), row ``r`` (``time.shape[0]``) and channel ``f``
      (``a1j.shape[3]``). For each it calls the ``jones_mul`` helper from
      ``jones_mul_factory`` with ``a1j[s, time[r] - tmin, ant1[r], f]``,
      ``a2j[s, time[r] - tmin, ant2[r], f]`` and the view ``out[r, f]``.
      ``blj[s, r, f]`` is passed between the two Jones terms only when
      ``have_coh`` is truthy. What happens to ``out[r, f]`` is up to that
      helper (presumably accumulation; confirm).
    - ``have_ddes`` falsy, ``have_coh`` truthy: ``blj`` is summed over its
      first axis into ``out`` element by element. It uses two trailing
      correlation axes when ``jones_type == JONES_2X2`` and one otherwise.
    - neither: the function does nothing.
    """
    jones_mul = jones_mul_factory(have_ddes, have_coh, jones_type, True)

    if have_ddes:
        if have_coh:
            def sum_coh_fn(time, ant1, ant2, a1j, blj, a2j, tmin, out):
                for src in range(a1j.shape[0]):
                    for row in range(time.shape[0]):
                        t = time[row] - tmin
                        p = ant1[row]
                        q = ant2[row]
                        for chan in range(a1j.shape[3]):
                            jones_mul(a1j[src, t, p, chan],
                                      blj[src, row, chan],
                                      a2j[src, t, q, chan],
                                      out[row, chan])
        else:
            def sum_coh_fn(time, ant1, ant2, a1j, blj, a2j, tmin, out):
                for src in range(a1j.shape[0]):
                    for row in range(time.shape[0]):
                        t = time[row] - tmin
                        p = ant1[row]
                        q = ant2[row]
                        for chan in range(a1j.shape[3]):
                            jones_mul(a1j[src, t, p, chan],
                                      a2j[src, t, q, chan],
                                      out[row, chan])
    elif have_coh:
        if jones_type == JONES_2X2:
            def sum_coh_fn(time, ant1, ant2, a1j, blj, a2j, tmin, out):
                for src in range(blj.shape[0]):
                    for row in range(blj.shape[1]):
                        for ch in range(blj.shape[2]):
                            for c1 in range(blj.shape[3]):
                                for c2 in range(blj.shape[4]):
                                    out[row, ch, c1, c2] += (
                                        blj[src, row, ch, c1, c2])
        else:
            def sum_coh_fn(time, ant1, ant2, a1j, blj, a2j, tmin, out):
                for src in range(blj.shape[0]):
                    for row in range(blj.shape[1]):
                        for ch in range(blj.shape[2]):
                            for corr in range(blj.shape[3]):
                                out[row, ch, corr] += blj[src, row, ch, corr]
    else:
        # Nothing to sum
        def sum_coh_fn(time, ant1, ant2, a1j, blj, a2j, tmin, out):
            pass

    return njit(nogil=True, inline='always')(sum_coh_fn)
def jacobian_factory(mode):
    """
    Return a jitted function that fills ``out`` with a Jacobian term.

    The generated function has signature ``(a1j, blj, a2j, sign, out)``.

    - ``DIAG_DIAG``: ``out[...] = sign * a1j * blj * conj(a2j)``
      (element-wise).
    - ``DIAG`` and ``FULL``: ``out`` is zero-filled.

    Raises
    ------
    ValueError
        If ``mode`` is not one of the supported values.
    """
    if mode == DIAG_DIAG:
        def jacobian(a1j, blj, a2j, sign, out):
            out[...] = sign * a1j * blj * np.conj(a2j)
    elif mode == DIAG or mode == FULL:
        # NOTE(review): placeholder. No DIAG/FULL Jacobian is computed
        # here, the output is simply zeroed; confirm this is intended.
        def jacobian(a1j, blj, a2j, sign, out):
            out[...] = 0
    else:
        raise ValueError("Unknown jacobian mode: {!r}".format(mode))

    return njit(nogil=True)(jacobian)
def is_flagged_factory(have_flag_row):
    """
    Return a jitted predicate reporting whether row ``r`` is flagged.

    With ``have_flag_row`` truthy it returns ``flag_row[r] != 0``;
    otherwise it always returns ``False``.
    """
    if have_flag_row:
        def impl(flag_row, r):
            # any non-zero flag value counts as flagged
            return flag_row[r] != 0
        return njit(nogil=True, cache=True)(impl)

    def impl(flag_row, r):
        return False

    return njit(nogil=True, cache=True)(impl)
def vis_normaliser_factory(present):
    """
    Return a jitted function normalising one binned visibility entry.

    With ``present`` truthy it writes
    ``vis_in[row, chan, corr] / weight_sum[row, chan, corr]`` into
    ``vis_out[row, chan, corr]``. The write is skipped when the weight sum
    is exactly zero, so ``vis_out`` keeps its existing value. Otherwise
    the function does nothing.
    """
    def normalise(vis_out, vis_in, row, chan, corr, weight_sum):
        wsum = weight_sum[row, chan, corr]
        # Guard against division by a zero weight sum
        if wsum == 0.0:
            return
        vis_out[row, chan, corr] = vis_in[row, chan, corr] / wsum

    def noop(vis_out, vis_in, row, chan, corr, weight_sum):
        pass

    return njit(nogil=True, cache=True)(normalise if present else noop)
def jacobian_factory(mode):
    """
    Return an always-inlined jitted function filling ``out`` with a
    Jacobian term.

    The generated function has signature ``(a1j, blj, a2j, sign, out)``.

    - ``DIAG_DIAG``: ``out[...] = sign * a1j * blj * a2j.conjugate()``
      (element-wise).
    - ``DIAG`` and ``FULL``: ``out`` is zero-filled.

    Raises
    ------
    ValueError
        If ``mode`` is not one of the supported values.
    """
    if mode == DIAG_DIAG:
        def jacobian(a1j, blj, a2j, sign, out):
            out[...] = sign * a1j * blj * a2j.conjugate()
    elif mode == DIAG or mode == FULL:
        # NOTE(review): placeholder. No DIAG/FULL Jacobian is computed
        # here, the output is simply zeroed; confirm this is intended.
        def jacobian(a1j, blj, a2j, sign, out):
            out[...] = 0
    else:
        raise ValueError("Unknown jacobian mode: {!r}".format(mode))

    return njit(nogil=True, inline='always')(jacobian)
def pol_getter_factory(npoldims):
    """
    Return a jitted function counting polarisations from a shape tuple.

    If ``npoldims == 0`` the function returns ``1`` regardless of its
    argument. Otherwise it returns the product of the entries of
    ``pol_shape``.
    """
    def single_pol(pol_shape):
        return 1

    def product_of_dims(pol_shape):
        total = 1
        for extent in pol_shape:
            total *= extent
        return total

    impl = single_pol if npoldims == 0 else product_of_dims
    return njit(nogil=True, cache=True)(impl)
def jones_mul_factory(mode):
    """
    Return an always-inlined jitted function accumulating model
    visibilities ``sum_s A1[s] M[s] A2[s]^H`` into ``out``.

    The generated function has signature ``(a1j, model, a2j, out)``. It
    loops over ``range(model.shape[0])`` (one term per direction) and
    adds each term into ``out`` in place. ``out`` is not cleared first.

    Parameters
    ----------
    mode : ``DIAG_DIAG``, ``DIAG`` or ``FULL``
        ``DIAG_DIAG``: element-wise ``a1j[s] * model[s] * conj(a2j[s])``.
        ``DIAG``: ``a1j[s]``/``a2j[s]`` are indexed as length-2 diagonals
        and ``model[s]``/``out`` as 2x2 matrices.
        ``FULL``: all operands are indexed as 2x2 matrices per direction.

    Raises
    ------
    ValueError
        If ``mode`` is not one of the supported values.
    """
    if mode == DIAG_DIAG:
        def jones_mul(a1j, model, a2j, out):
            n_dir = np.shape(model)[0]
            for s in range(n_dir):
                out += a1j[s] * model[s] * np.conj(a2j[s])
    elif mode == DIAG:
        def jones_mul(a1j, model, a2j, out):
            n_dir = np.shape(model)[0]
            for s in range(n_dir):
                out[0, 0] += a1j[s, 0] * model[s, 0, 0] * np.conj(a2j[s, 0])
                out[0, 1] += a1j[s, 0] * model[s, 0, 1] * np.conj(a2j[s, 1])
                out[1, 0] += a1j[s, 1] * model[s, 1, 0] * np.conj(a2j[s, 0])
                out[1, 1] += a1j[s, 1] * model[s, 1, 1] * np.conj(a2j[s, 1])
    elif mode == FULL:
        def jones_mul(a1j, model, a2j, out):
            n_dir = np.shape(model)[0]
            for s in range(n_dir):
                # precompute reusable terms for row 0 of a1j[s] @ model[s]
                t1 = a1j[s, 0, 0] * model[s, 0, 0]
                t2 = a1j[s, 0, 1] * model[s, 1, 0]
                t3 = a1j[s, 0, 0] * model[s, 0, 1]
                t4 = a1j[s, 0, 1] * model[s, 1, 1]
                # conjugate transpose of a2j[s]
                tmp = np.conj(a2j[s].T)
                # accumulate (not overwrite) this direction's contribution
                out[0, 0] += (t1*tmp[0, 0] + t2*tmp[0, 0] +
                              t3*tmp[1, 0] + t4*tmp[1, 0])
                out[0, 1] += (t1*tmp[0, 1] + t2*tmp[0, 1] +
                              t3*tmp[1, 1] + t4*tmp[1, 1])
                # same for row 1 of a1j[s] @ model[s]
                t1 = a1j[s, 1, 0] * model[s, 0, 0]
                t2 = a1j[s, 1, 1] * model[s, 1, 0]
                t3 = a1j[s, 1, 0] * model[s, 0, 1]
                t4 = a1j[s, 1, 1] * model[s, 1, 1]
                out[1, 0] += (t1*tmp[0, 0] + t2*tmp[0, 0] +
                              t3*tmp[1, 0] + t4*tmp[1, 0])
                out[1, 1] += (t1*tmp[0, 1] + t2*tmp[0, 1] +
                              t3*tmp[1, 1] + t4*tmp[1, 1])
    else:
        raise ValueError("Unknown jones mode: {!r}".format(mode))

    return njit(nogil=True, inline='always')(jones_mul)
def comp_add_factory(present):
    """
    Returns function for adding data with components to a bin.
    Rows are assumed to be in the first dimension and components
    are assumed to be in the second.

    The generated function has signature ``(output, orow, input, irow)``
    and performs ``output[orow, c] += input[irow, c]`` for every ``c`` in
    ``range(output.shape[1])``. When ``present`` is falsy it is a no-op
    with the same signature.
    """
    if present:
        def impl(output, orow, input, irow):
            for c in range(output.shape[1]):
                output[orow, c] += input[irow, c]
    else:
        # Same parameter order as the active variant so both are
        # interchangeable (the old no-op listed them as
        # (input, irow, output, orow)).
        def impl(output, orow, input, irow):
            pass

    return njit(nogil=True, cache=True)(impl)
def sigma_spectrum_normaliser_factory(present):
    """
    Return a jitted function finalising one binned sigma-spectrum entry.

    With ``present`` truthy the generated function reads
    ``wsum = weight_sum[row, chan, corr]``. If ``wsum`` is exactly zero,
    ``sigma_out`` is left untouched. Otherwise it writes
    ``sqrt(sigma_in[row, chan, corr] / wsum**2)`` into
    ``sigma_out[row, chan, corr]``. With ``present`` falsy it is a no-op.
    """
    if not present:
        def impl(sigma_out, sigma_in, row, chan, corr, weight_sum):
            pass
        return njit(nogil=True, cache=True)(impl)

    def impl(sigma_out, sigma_in, row, chan, corr, weight_sum):
        wsum = weight_sum[row, chan, corr]
        if wsum != 0.0:
            # Assumes sigma_in holds sum(sigma**2 * weight**2) for the bin
            # (TODO confirm against the accumulating caller), giving
            # sqrt(sum(sigma**2 * weight**2) / sum(weight)**2).
            ratio = sigma_in[row, chan, corr] / (wsum ** 2)
            sigma_out[row, chan, corr] = np.sqrt(ratio)

    return njit(nogil=True, cache=True)(impl)
def set_flag_row_factory(have_flag_row):
    """
    Return a jitted function recording an output row flag.

    With ``have_flag_row`` truthy the generated function first raises
    ``RowMapperError`` if ``flag_row[in_row]`` is zero while ``flagged``
    is truthy. It then sets ``out_flag_row[out_row]`` to ``1`` when
    ``flagged`` and ``0`` otherwise. Without a flag row it does nothing.
    """
    if not have_flag_row:
        def impl(flag_row, in_row, out_flag_row, out_row, flagged):
            pass
        return njit(nogil=True, cache=True)(impl)

    def impl(flag_row, in_row, out_flag_row, out_row, flagged):
        # Consistency check: a flagged output row must not receive
        # contributions from an unflagged input row.
        if flag_row[in_row] == 0 and flagged:
            raise RowMapperError("Unflagged input row contributing to "
                                 "flagged output row. "
                                 "This should never happen!")

        if flagged:
            out_flag_row[out_row] = 1
        else:
            out_flag_row[out_row] = 0

    return njit(nogil=True, cache=True)(impl)
def sigma_normaliser_factory(present):
    """
    Returns function for normalising sigma in a bin.

    With ``present`` truthy, for every column ``c`` in
    ``range(sigma.shape[1])`` whose ``weight_sum[row, c]`` is non-zero,
    ``sigma[row, c]`` is replaced in place by
    ``sqrt(sigma[row, c] / weight_sum[row, c]**2)``. Columns with a zero
    weight sum are left untouched. With ``present`` falsy it is a no-op.
    """
    def normalise(sigma, row, weight_sum):
        for comp in range(sigma.shape[1]):
            wsum = weight_sum[row, comp]
            if wsum != 0.0:
                # assumes sigma[row, comp] holds an accumulated
                # sum(sigma**2 * weight**2); TODO confirm with caller
                sigma[row, comp] = np.sqrt(sigma[row, comp] / (wsum ** 2))

    def noop(sigma, row, weight_sum):
        pass

    return njit(nogil=True, cache=True)(normalise if present else noop)
def output_factory(have_ddes, have_coh, have_dies, have_base_vis, out_dtype):
    """
    Factory function generating a function that creates function output.

    The generated function takes ``(time_index, dde1_jones, source_coh,
    dde2_jones, die1_jones, base_vis, die2_jones)``. It returns a
    zero-filled array of dtype ``out_dtype`` shaped
    ``(time_index.shape[0], nchan) + corr_shape``. ``nchan`` and
    ``corr_shape`` come from the first available input, in this order of
    preference:

    - ``dde1_jones``: channels on axis 3, correlations from axis 4 on;
    - ``source_coh``: channels on axis 2, correlations from axis 3 on;
    - ``die1_jones``: channels on axis 2, correlations from axis 3 on;
    - ``base_vis``: channels on axis 1, correlations from axis 2 on.

    Raises
    ------
    ValueError
        If none of ``have_ddes``, ``have_coh``, ``have_dies`` or
        ``have_base_vis`` is truthy.
    """
    if not (have_ddes or have_coh or have_dies or have_base_vis):
        raise ValueError("Insufficient inputs were supplied "
                         "for determining the output shape")

    if have_ddes:
        def output(time_index, dde1_jones, source_coh, dde2_jones,
                   die1_jones, base_vis, die2_jones):
            nrow = time_index.shape[0]
            nchan = dde1_jones.shape[3]
            return np.zeros((nrow, nchan) + dde1_jones.shape[4:],
                            dtype=out_dtype)
    elif have_coh:
        def output(time_index, dde1_jones, source_coh, dde2_jones,
                   die1_jones, base_vis, die2_jones):
            nrow = time_index.shape[0]
            nchan = source_coh.shape[2]
            return np.zeros((nrow, nchan) + source_coh.shape[3:],
                            dtype=out_dtype)
    elif have_dies:
        def output(time_index, dde1_jones, source_coh, dde2_jones,
                   die1_jones, base_vis, die2_jones):
            nrow = time_index.shape[0]
            nchan = die1_jones.shape[2]
            return np.zeros((nrow, nchan) + die1_jones.shape[3:],
                            dtype=out_dtype)
    else:
        # Only base visibilities remain (guaranteed by the check above)
        def output(time_index, dde1_jones, source_coh, dde2_jones,
                   die1_jones, base_vis, die2_jones):
            nrow = time_index.shape[0]
            nchan = base_vis.shape[1]
            return np.zeros((nrow, nchan) + base_vis.shape[2:],
                            dtype=out_dtype)

    # TODO(sjperkins)
    # perhaps inline='always' on resolution of
    # https://github.com/numba/numba/issues/4691
    return njit(nogil=True, inline='never')(output)
def apply_dies_factory(have_dies, have_bvis, jones_type):
    """
    Factory function returning a function that applies
    Direction Independent Effects.

    With ``have_dies`` truthy the generated function
    ``(time, ant1, ant2, die1_jones, die2_jones, tmin, out)`` visits every
    row ``r`` (``time.shape[0]``) and channel ``c`` (``out.shape[1]``).
    For each it calls the ``jones_mul`` helper from ``jones_mul_factory``
    with ``die1_jones[time[r] - tmin, ant1[r], c]``, ``out[r, c]``,
    ``die2_jones[time[r] - tmin, ant2[r], c]`` and ``out[r, c]`` again as
    the output buffer. With ``have_dies`` falsy it is a no-op.

    ``have_bvis`` is accepted for interface compatibility but does not
    change the generated code.
    """
    # We always "have visibilities", (the output array)
    jones_mul = jones_mul_factory(have_dies, True, jones_type, False)

    if have_dies:
        def apply_dies(time, ant1, ant2, die1_jones, die2_jones, tmin, out):
            # Iterate over rows
            for r in range(time.shape[0]):
                ti = time[r] - tmin
                a1 = ant1[r]
                a2 = ant2[r]

                # Iterate over channels
                for c in range(out.shape[1]):
                    # NOTE(review): out[r, c] is both the input visibility
                    # and the output buffer; this relies on jones_mul
                    # tolerating aliased arguments. Confirm.
                    jones_mul(die1_jones[ti, a1, c], out[r, c],
                              die2_jones[ti, a2, c], out[r, c])
    else:
        # noop
        def apply_dies(time, ant1, ant2, die1_jones, die2_jones, tmin, out):
            pass

    return njit(nogil=True, inline='always')(apply_dies)
def sigma_spectrum_add_factory(have_vis, have_weight, have_weight_spectrum):
    """
    Returns function adding weighted sigma to a bin.

    The generated function has signature ``(out_sigma, out_weight_sum,
    in_sigma, weight, weight_spectrum, orow, ochan, irow, ichan, corr)``.
    For the selected weight ``w`` it adds
    ``in_sigma[irow, ichan, corr]**2 * w**2`` to
    ``out_sigma[orow, ochan, corr]`` and ``w`` to
    ``out_weight_sum[orow, ochan, corr]``. The weight is chosen as:

    - ``weight_spectrum[irow, ichan, corr]`` if ``have_weight_spectrum``;
    - else ``weight[irow]`` if ``have_weight``;
    - else natural weighting: ``in_sigma**2`` is added directly and the
      weight sum grows by ``1.0``.

    If ``have_vis`` is falsy the function does nothing.
    """
    if not have_vis:
        def impl(out_sigma, out_weight_sum, in_sigma,
                 weight, weight_spectrum,
                 orow, ochan, irow, ichan, corr):
            pass
    elif have_weight_spectrum:
        # Per-(row, chan, corr) weights are preferred when available
        def impl(out_sigma, out_weight_sum, in_sigma,
                 weight, weight_spectrum,
                 orow, ochan, irow, ichan, corr):
            w = weight_spectrum[irow, ichan, corr]
            sigma_sq = in_sigma[irow, ichan, corr] ** 2
            out_sigma[orow, ochan, corr] += sigma_sq * w ** 2
            out_weight_sum[orow, ochan, corr] += w
    elif have_weight:
        # Fall back to a single weight per input row
        def impl(out_sigma, out_weight_sum, in_sigma,
                 weight, weight_spectrum,
                 orow, ochan, irow, ichan, corr):
            w = weight[irow]
            sigma_sq = in_sigma[irow, ichan, corr] ** 2
            out_sigma[orow, ochan, corr] += sigma_sq * w ** 2
            out_weight_sum[orow, ochan, corr] += w
    else:
        # Natural weighting: every sample carries weight 1
        def impl(out_sigma, out_weight_sum, in_sigma,
                 weight, weight_spectrum,
                 orow, ochan, irow, ichan, corr):
            out_sigma[orow, ochan, corr] += in_sigma[irow, ichan, corr] ** 2
            out_weight_sum[orow, ochan, corr] += 1.0

    return njit(nogil=True, cache=True)(impl)