def _update_equation_counts(self, unflagged):
    """
    Recomputes equation counters from the current flagging information.

    Extends MasterMachine._update_equation_counts by also working out, per solution
    interval, the number of equations, the number of unknowns, which intervals are
    solvable (more equations than unknowns), and a reduced chi-square correction.

    Args:
        unflagged (np.ndarray): Mask of unflagged data points, passed unchanged to the
            base-class implementation.
    """
    MasterMachine._update_equation_counts(self, unflagged)

    eqs = self.interval_sum(self.eqs_per_tf_slot)
    self.eqs_per_interval = eqs

    # DD terms solve for every non-fixed direction; DI terms for a single one.
    n_solved_dirs = (self.n_dir - len(self.fix_directions)) if self.dd_term else 1
    unknowns = self.dof_per_antenna * self.n_ant * n_solved_dirs
    self.num_unknowns = unknowns

    # An interval is only solvable if it has strictly more equations than unknowns.
    valid = eqs > unknowns
    self.valid_intervals = valid
    n_valid = valid.sum()
    self.num_valid_intervals = n_valid
    self.n_valid_sols = n_valid * self.n_dir

    if not n_valid:
        return

    # MasterMachine normalises chi-square as 1/N_eq; rescale by N_eq/(N_eq - N_dof) so the
    # result is the reduced chi-square statistic. Non-solvable intervals get a factor of 0.
    with np.errstate(invalid='ignore', divide='ignore'):
        correction = eqs.astype(float) / (eqs - unknowns)
    correction[~valid] = 0
    self._chisq_tf_norm_factor *= self.unpack_intervals(correction)
    self._chisq_norm_factor *= correction.sum() / n_valid
def precompute_attributes(self, model_arr, flags_arr, inv_var_chan):
    """
    Precomputes stats for the whole chain before starting a solution.

    The base MasterMachine precomputation runs for the chain itself. Every solvable
    Jones term in the chain then runs its own precomputation on the same arrays, in
    chain order. Non-solvable terms are not touched.

    Args:
        model_arr (np.ndarray): Model visibilities.
        flags_arr (np.ndarray): Data flags (individual terms may update these in place).
        inv_var_chan: Inverse noise variance, passed through unchanged.
    """
    # NOTE(review): the base-class return value is discarded here and this method
    # returns None; confirm no caller of the chain expects a result.
    MasterMachine.precompute_attributes(self, model_arr, flags_arr, inv_var_chan)
    solvable_terms = (term for term in self.jones_terms if term.solvable)
    for solvable_term in solvable_terms:
        solvable_term.precompute_attributes(model_arr, flags_arr, inv_var_chan)
def next_iteration(self):
    """
    Advances the chain by one iteration.

    If the currently active term has converged or stalled, its final convergence status
    is logged and recorded, and the chain moves on to the next term. In every case the
    (possibly new) active term and the chain itself then advance their iteration counters.

    Returns:
        tuple: (first element of MasterMachine.next_iteration()'s result, major_step), where
        major_step is True if the previous active term had converged or stalled and
        _next_chain_term() was called during this iteration.
    """
    self.last_active_index = self.active_index
    current = self.active_term

    switched = False
    if current.has_converged or current.has_stalled:
        outcome = "converged" if current.has_converged else "stalled"
        print("term {} {} ({} iters): {}".format(current.jones_label, outcome, current.iters,
                                                 current.final_convergence_status_string),
              file=log(1))
        self._convergence_states.append(current.final_convergence_status_string)
        self._convergence_states_finalized = True
        self._next_chain_term()
        switched = True

    # Re-read active_term: _next_chain_term() may have changed it.
    self.active_term.next_iteration()
    return MasterMachine.next_iteration(self)[0], switched
def __init__(self, label, data_arr, ndir, nmod, times, frequencies, chunk_label, jones_options):
    """
    Initialises a chain of complex 2x2 gain machines.

    Each element of the chain is itself a gain machine; this machine owns them and
    coordinates parameter updates between them.

    Args:
        label (str): Label identifying the Jones term.
        data_arr (np.ndarray): Shape (n_mod, n_tim, n_fre, n_ant, n_ant, n_cor, n_cor) array
            containing observed visibilities.
        ndir (int): Number of directions.
        nmod (int): Number of models.
        times (np.ndarray): Times for the data being processed.
        frequencies (np.ndarray): Frequencies for the data being processed.
        chunk_label: Label of the data chunk; passed to every term and to MasterMachine.
        jones_options (dict): Dictionary of options pertaining to the chain.

    Raises:
        UserInputError: if a term has an unknown type, if a solvable term is neither a
            complex-2x2 nor a robust-2x2 machine, or if the term-iters setting is invalid.
    """
    from cubical.main import UserInputError

    # Instantiate every element of the chain. When DD terms are combined with DI terms,
    # the DI terms are initialised using a single direction (by slicing rather than summation).
    self.jones_terms = []
    # NOTE(review): despite its name, this ends up holding the index of the last DI term
    # that precedes the first DD term (and 0 if the chain starts with a DD term); confirm
    # against the code that reads it.
    self.num_left_di_terms = 0
    seen_dd_term = False

    for idx, opts in enumerate(jones_options['chain']):
        machine_cls = machine_types.get_machine_class(opts['type'])
        if machine_cls is None:
            raise UserInputError("unknown Jones class '{}'".format(opts['type']))
        is_2x2 = issubclass(machine_cls, Complex2x2Gains) or issubclass(machine_cls, ComplexW2x2Gains)
        if not is_2x2 and opts['solvable']:
            raise UserInputError(
                "only complex-2x2 or robust-2x2 terms can be made solvable in a Jones chain")
        term = machine_cls(opts["label"], data_arr, ndir, nmod, times, frequencies, chunk_label, opts)
        self.jones_terms.append(term)
        if term.dd_term:
            seen_dd_term = True
        elif not seen_dd_term:
            self.num_left_di_terms = idx

    MasterMachine.__init__(self, label, data_arr, ndir, nmod, times, frequencies, chunk_label, jones_options)

    # kernel used for compute_residuals and such
    self.chain = cubical.kernels.import_kernel("chain")
    self.kernel = Complex2x2Gains.get_full_kernel(jones_options, diag_gains=self.is_diagonal)

    self.n_dir, self.n_mod = ndir, nmod
    _, self.n_tim, self.n_fre, self.n_ant, self.n_ant, self.n_cor, self.n_cor = data_arr.shape
    self.n_terms = len(self.jones_terms)

    # Number of iterations per solvable term; defaults to each solvable term's own maxiter.
    # This list is updated as terms converge, so it is always a private copy.
    requested = jones_options['sol']['term-iters']
    if not requested:
        iters_per_term = [term.maxiter for term in self.jones_terms if term.solvable]
    elif type(requested) is int:
        iters_per_term = [requested]
    elif isinstance(requested, (list, tuple)):
        iters_per_term = list(requested)
    else:
        raise UserInputError("invalid term-iters={} setting".format(requested))
    self.term_iters = iters_per_term

    self.solvable = bool(self.term_iters) and any(term.solvable for term in self.jones_terms)

    # index of the active term; None until the first solvable term is selected
    self.active_index = None

    # per-term convergence status strings, accumulated as terms finish
    self._convergence_states = []
    # True once the last active term has had its convergence status queried
    self._convergence_states_finalized = False

    self.cached_model_arr = self._r = self._m = None
def next_iteration(self):
    """
    Advances the iteration counter, first snapshotting the current gains.

    The current gains are copied into the pre-allocated old_gains buffer (in place, via
    np.copyto), so the previous solution is still available after the next update.

    Returns:
        Whatever MasterMachine.next_iteration returns.
    """
    previous, current = self.old_gains, self.gains
    np.copyto(previous, current)
    result = MasterMachine.next_iteration(self)
    return result
def precompute_attributes(self, model_arr, flags_arr, inv_var_chan):
    """
    Precomputes various stats before starting a solution.

    Pre-flags gain intervals whose input data are entirely MISSING|PRIOR, estimates a
    prior gain error per direction/interval/antenna from the noise and the model power,
    and, when max-prior-error is set, flags gains whose prior error exceeds it.

    Args:
        model_arr (np.ndarray): Model visibilities. Axis 0 is summed over for DI terms.
        flags_arr (np.ndarray): Data flags. Updated in place with FL.LOWSNR for slots
            where every direction's gain was flagged.
        inv_var_chan (np.ndarray): Inverse noise variance (the code treats 1/inv_var_chan
            as squared noise per channel).

    Returns:
        The unflagged-data mask from MasterMachine.precompute_attributes, with slots
        newly flagged as LOWSNR set to False.
    """
    unflagged = MasterMachine.precompute_attributes(self, model_arr, flags_arr, inv_var_chan)

    # Pre-flag gain solution intervals that are completely flagged in the input data
    # (i.e. MISSING|PRIOR). This has shape (n_timint, n_freint, n_ant).
    missing_intervals = self.interval_and(((flags_arr & (FL.MISSING | FL.PRIOR)) != 0).all(axis=-1))
    self.missing_gain_fraction = missing_intervals.sum() / float(missing_intervals.size)

    # convert the intervals array to gain shape, and apply flags
    self.gflags[:, self._interval_to_gainres(missing_intervals)] = FL.MISSING

    # number of data points per time/frequency/antenna
    numeq_tfa = unflagged.sum(axis=-1)

    # Always define these, whatever path is taken below.
    self._n_flagged_on_max_error = None
    self._n_flagged_on_max_posterior_error = None

    # compute error estimates per direction, antenna, and interval
    with np.errstate(invalid='ignore', divide='ignore'):
        # squared noise per channel; could be infinite if there is no data
        sigmasq = 1 / inv_var_chan
        # collapse direction axis, if not directional
        if not self.dd_term:
            model_arr = model_arr.sum(axis=0, keepdims=True)
        # mean |model|^2 per direction+TFA
        modelsq = (model_arr * np.conj(model_arr)).real.sum(axis=(1, -1, -2, -3)) / \
                  (self.n_mod * self.n_cor * self.n_cor * numeq_tfa)
        modelsq[:, numeq_tfa == 0] = 0
        # inverse SNR^2 per direction+TFA
        inv_snr2 = sigmasq[np.newaxis, np.newaxis, :, np.newaxis] / modelsq
        inv_snr2[:, numeq_tfa == 0] = 0
        # mean SNR^-2 over each interval, using the number of points per interval, antenna
        numeq_int = self.interval_sum(numeq_tfa)
        inv_snr2_int = self.interval_sum(inv_snr2, 1) / numeq_int[np.newaxis, ...]
        inv_snr2_int[:, numeq_int == 0] = 0
        # convert that into a gain error per direction, interval, antenna
        dof_margin = (self.eqs_per_interval - self.num_unknowns)[np.newaxis, :, :, np.newaxis]
        self.prior_gain_error = np.sqrt(inv_snr2_int / dof_margin)
        self.prior_gain_error[:, ~self.valid_intervals, :] = 0
        # reset to 0 for fixed directions
        if self.dd_term:
            self.prior_gain_error[self.fix_directions, ...] = 0

        # Flag gains on max error. A falsy max-prior-error (0 or None) disables the check:
        # comparing against None raises TypeError under Python 3, and comparing against 0
        # would flag every gain with a non-zero error estimate.
        if self.max_gain_error:
            bad_gain_intervals = self.prior_gain_error > self.max_gain_error  # dir,time,freq,ant
            if bad_gain_intervals.any():
                # (n_dir,) array showing how many were flagged per direction
                self._n_flagged_on_max_error = bad_gain_intervals.sum(axis=(1, 2, 3))
                # raise corresponding gain flags
                self.gflags[self._interval_to_gainres(bad_gain_intervals, 1)] |= FL.LOWSNR
                self.prior_gain_error[bad_gain_intervals] = 0
                # flag intervals where all directions are bad, and propagate that out into flags
                bad_intervals = bad_gain_intervals.all(axis=0)
                if bad_intervals.any():
                    bad_slots = self.unpack_intervals(bad_intervals)
                    flags_arr[bad_slots, ...] |= FL.LOWSNR
                    unflagged[bad_slots, ...] = False
                    self._update_equation_counts(unflagged)

    self.flagged = self.gflags != 0
    self.n_flagged = self.flagged.sum()

    return unflagged
def __init__(self, label, data_arr, ndir, nmod, times, frequencies, chunk_label, options, cykernel):
    """
    Initialises a gain machine which supports solution intervals.

    Args:
        label (str): Label identifying the Jones term.
        data_arr (np.ndarray): Shape (n_mod, n_tim, n_fre, n_ant, n_ant, n_cor, n_cor) array
            containing observed visibilities.
        ndir (int): Number of directions.
        nmod (int): Number of models.
        times (np.ndarray): Times for the data being processed.
        frequencies (np.ndarray): Frequencies for the data being processed.
        chunk_label: Label of the current data chunk; passed through to MasterMachine.
        options (dict): Dictionary of options. "fix-dirs" may be an int, a list of ints,
            or a comma-separated string of integers such as "0,2".
        cykernel: Kernel object, stored as self.cykernel.
    """
    import re

    MasterMachine.__init__(self, label, data_arr, ndir, nmod, times, frequencies, chunk_label, options)
    self.cykernel = cykernel

    # Solution interval sizes; a falsy setting means "the whole chunk" along that axis.
    self.t_int = options["time-int"] or self.n_tim
    self.f_int = options["freq-int"] or self.n_fre
    self.eps = 1e-6

    # Initialise attributes used for computing values over intervals.
    # n_tim and n_fre are the time and frequency dimensions of the data arrays.
    # n_timint and n_freint are the time and frequency dimensions of the gains.
    # Bin starts are stored as lists (range() is a lazy range object under Python 3).
    self.t_bins = list(range(0, self.n_tim, self.t_int))
    self.f_bins = list(range(0, self.n_fre, self.f_int))
    self.n_timint = len(self.t_bins)
    self.n_freint = len(self.f_bins)
    self.n_tf_ints = self.n_timint * self.n_freint

    # number of valid solutions
    self.n_valid_sols = self.n_dir * self.n_tf_ints

    # split grids into intervals, and find the centre of gravity of each
    timebins = np.split(times, self.t_bins[1:])
    freqbins = np.split(frequencies, self.f_bins[1:])
    timegrid = np.array([float(x.mean()) for x in timebins])
    freqgrid = np.array([float(x.mean()) for x in freqbins])

    # interval_grid determines the per-interval grid points
    self.interval_grid = dict(time=timegrid, freq=freqgrid)
    # data_grid determines the full resolution grid
    self.data_grid = dict(time=times, freq=frequencies)

    # compute index from each data point to interval number
    t_ind = np.arange(self.n_tim) // self.t_int
    f_ind = np.arange(self.n_fre) // self.f_int
    self.t_mapping, self.f_mapping = np.meshgrid(t_ind, f_ind, indexing="ij")

    # Initialise attributes used in convergence testing. n_cnvgd is the number
    # of solutions which have converged.
    self._has_stalled = False
    self.n_cnvgd = 0
    self._frac_cnvgd = 0
    self.iters = 0
    self.min_quorum = options["conv-quorum"]
    self.update_type = options["update-type"]
    self.ref_ant = options["ref-ant"]

    fix_dirs = options["fix-dirs"]
    if type(fix_dirs) is int:
        # Test for int before truthiness: 0 is a legitimate direction to fix, and
        # "0 or []" would silently drop it.
        fix_dirs = [fix_dirs]
    elif not fix_dirs:
        fix_dirs = []
    elif type(fix_dirs) is str and re.match(r"^\W*\d{1,}(\W*,\W*\d{1,})*\W*$", fix_dirs):
        # Comma-separated integers such as "0,2" or "[0, 2]". Left as a string, these would
        # later be counted/indexed character by character.
        fix_dirs = [int(num) for num in re.findall(r"\d+", fix_dirs)]
    self.fix_directions = fix_dirs

    # True if gains are loaded from a DB
    self._gains_loaded = False

    # Construct flag array and populate flagging attributes.
    self.max_gain_error = options["max-prior-error"]
    self.max_post_error = options["max-post-error"]
    self.clip_lower = options["clip-low"]
    self.clip_upper = options["clip-high"]
    self.clip_after = options["clip-after"]

    self.init_gains()
    self.old_gains = self.gains.copy()

    # Gain error estimates. Populated by subclasses, if available.
    # Should be array of same shape as the gains.
    self.prior_gain_error = None
    self.posterior_gain_error = None

    # buffers for arrays used in internal updates
    self._jh = self._jhr = self._jhj = self._gh = self._r = self._ginv = self._ghinv = None
    self._update = None

    # flag: have gains been updated
    self._gh_update = self._ghinv_update = True
def precompute_attributes(self, data_arr, model_arr, flags_arr, inv_var_chan):
    """
    Precomputes various attributes of the machine before starting a solution.

    Pre-flags gain intervals whose input data are entirely flagged. When a noise estimate
    is available, it also derives a prior gain error per direction/interval/antenna from
    the noise-to-model power ratio. Gains with a non-finite prior error, or (if
    max-prior-error is set) a prior error above max-prior-error, are flagged as LOWSNR.
    Slots where every direction is flagged are propagated into the data flags.

    Args:
        data_arr (np.ndarray): Observed visibilities; passed through to MasterMachine.
        model_arr (np.ndarray): Model visibilities. Axis 0 is summed over for DI terms.
        flags_arr (np.ndarray): Data flags. Updated in place with FL.LOWSNR where all
            directions of a gain interval are flagged.
        inv_var_chan (np.ndarray or None): Inverse noise variance per channel. If None, no
            prior gain error is computed and no SNR-based flagging is done.

    Returns:
        The unflagged-data mask from MasterMachine.precompute_attributes, with slots
        newly flagged as LOWSNR set to False.
    """
    unflagged = MasterMachine.precompute_attributes(self, data_arr, model_arr, flags_arr, inv_var_chan)

    ## NB: not sure why I used to apply MISSING|PRIOR here. Surely other input flags must be honoured
    ## (SKIPSOL, NULLDATA, etc.)?
    ### Pre-flag gain solution intervals that are completely flagged in the input data.
    ### This has shape (n_timint, n_freint, n_ant).
    missing_intervals = self.interval_and((flags_arr != 0).all(axis=-1))
    self.missing_gain_fraction = missing_intervals.sum() / float(missing_intervals.size)

    # convert the intervals array to gain shape, and apply flags
    self.gflags[:, self._interval_to_gainres(missing_intervals)] = FL.MISSING

    # number of data points per time/frequency/antenna
    numeq_tfa = unflagged.sum(axis=-1)

    # Set unconditionally, so these attributes exist even when inv_var_chan is None and
    # the error estimation below is skipped.
    self._n_flagged_on_max_error = None
    self._n_flagged_on_max_posterior_error = None

    # compute error estimates per direction, antenna, and interval
    if inv_var_chan is not None:
        with np.errstate(invalid='ignore', divide='ignore'):
            # collapse direction axis, if not directional
            if not self.dd_term:
                model_arr = model_arr.sum(axis=0, keepdims=True)
            # mean |model|^2 per direction+TFA
            modelsq = (model_arr * np.conj(model_arr)).real.sum(axis=(1, -1, -2, -3)) / \
                      (self.n_mod * self.n_cor * self.n_cor * numeq_tfa)
            modelsq[:, numeq_tfa == 0] = 0

            # squared noise per channel; could be infinite if there is no data
            sigmasq = 1.0 / inv_var_chan
            # number of unflagged points per interval, antenna
            numeq_int = self.interval_sum(numeq_tfa)
            # non-finite noise/model values contribute nothing
            sigmasq[~np.isfinite(sigmasq)] = 0.0
            modelsq[~np.isfinite(modelsq)] = 0.0
            # total model power per direction and interval (reused several times below)
            modelsq_int = self.interval_sum(modelsq, 1)

            # Take the sigma (in quadrature) over each interval, divided by the model power
            # times the number of unflagged contributing points: this yields var<g>.
            NSR_int = self.interval_sum(np.ones_like(modelsq) * sigmasq[None, None, :, None], 1) / \
                      (modelsq_int * numeq_int)

            # convert that into a gain error per direction, interval, antenna
            self.prior_gain_error = np.sqrt(NSR_int)
            # reset to 0 for fixed directions
            if self.dd_term:
                self.prior_gain_error[self.fix_directions, ...] = 0
            # non-finite errors (e.g. from division by zero above) are flagged below
            pge_flag_invalid = ~np.isfinite(self.prior_gain_error)
            invalid_models = (modelsq_int == 0) | ~np.isfinite(modelsq_int)

            if np.any(np.all(numeq_int == 0, axis=-1)) and log.verbosity() > 1:
                self.raise_userwarning(
                    logging.CRITICAL,
                    "One or more directions (or its frequency intervals) are already fully flagged.",
                    90, raise_once="prior_fully_flagged_dirs", verbosity=2, color="red")
            if np.any(np.all(invalid_models, axis=-1)) and log.verbosity() > 1:
                self.raise_userwarning(
                    logging.CRITICAL,
                    "One or more directions (or its frequency intervals) have invalid or 0 models.",
                    90, raise_once="invalid_models", verbosity=2, color="red")

            # intervals not marked as valid get no error estimate
            self.prior_gain_error[:, ~self.valid_intervals, :] = 0

            # flag gains on max error; bad_gain_intervals is dir,time,freq,ant
            bad_gain_intervals = pge_flag_invalid
            if self.max_gain_error:
                low_snr = self.prior_gain_error > self.max_gain_error
                # low SNR in every direction: time,freq,ant
                all_dirs_low = low_snr.all(axis=0)
                if all_dirs_low.all():
                    msg = "'{0:s}' {1:s} All directions flagged, either due to low SNR. "\
                          "You need to check your tagged directions and your max-prior-error and/or solution intervals. "\
                          "New flags will be raised for this chunk of data".format(
                              self.jones_label, self.chunk_label)
                    self.raise_userwarning(logging.CRITICAL, msg, 70, verbosity=log.verbosity(), color="red")
                else:
                    # some direction is low-SNR for all times, frequencies and antennas
                    if low_snr.all(axis=-1).all(axis=-1).all(axis=-1).any():
                        dir_snr = {}
                        for d in range(self.prior_gain_error.shape[0]):
                            percflagged = np.sum(low_snr[d]) * 100.0 / low_snr[d].size
                            if percflagged > self.low_snr_warn and d not in self.fix_directions:
                                dir_snr[d] = percflagged
                        if len(dir_snr) > 0:
                            if log.verbosity() > 2:
                                msg = "Low SNR in one or more directions of gain '{0:s}' chunk '{1:s}':".format(
                                          self.jones_label, self.chunk_label) +\
                                      "\n{0:s}\n".format("\n".join(["\t direction {0:s}: {1:.3f}% gains affected".format(
                                          str(d), dir_snr[d]) for d in sorted(dir_snr)])) +\
                                      "Check your settings for gain solution intervals and max-prior-error. "
                            else:
                                msg = "'{0:s}' {1:s} Low SNR in directions {2:s}. Increase solution intervals or raise max-prior-error!".format(
                                          self.jones_label, self.chunk_label, ", ".join(map(str, sorted(dir_snr))))
                            self.raise_userwarning(logging.CRITICAL, msg, 50, verbosity=log.verbosity(), color="red")

                    # every time of some frequency interval is low-SNR (all dirs, all antennas)
                    if all_dirs_low.all(axis=0).all(axis=-1).any():
                        msg = "'{0:s}' {1:s} All time of one or more frequency intervals flagged due to low SNR. "\
                              "You need to check your max-prior-error and/or solution intervals. "\
                              "New flags will be raised for this chunk of data".format(
                                  self.jones_label, self.chunk_label)
                        self.raise_userwarning(logging.WARNING, msg, 70, verbosity=log.verbosity())

                    # every channel of some time interval is low-SNR (all dirs, all antennas)
                    if all_dirs_low.all(axis=1).all(axis=-1).any():
                        msg = "'{0:s}' {1:s} All channels of one or more time intervals flagged due to low SNR. "\
                              "You need to check your max-prior-error and/or solution intervals. "\
                              "New flags will be raised for this chunk of data".format(
                                  self.jones_label, self.chunk_label)
                        self.raise_userwarning(logging.WARNING, msg, 70, verbosity=log.verbosity())

                    # antennas that are low-SNR in every direction, time and frequency
                    station_low = all_dirs_low.all(axis=0).all(axis=0)
                    stationflags = np.argwhere(station_low).flatten()
                    if stationflags.size > 0:
                        msg = "'{0:s}' {1:s} Stations {2:s} ({3:d}/{4:d}) fully flagged due to low SNR. "\
                              "These stations may be faulty or your SNR requirements (max-prior-error) are not met. "\
                              "New flags will be raised for this chunk of data".format(
                                  self.jones_label, self.chunk_label,
                                  ", ".join(map(str, stationflags)),
                                  np.sum(station_low), low_snr.shape[3])
                        self.raise_userwarning(logging.WARNING, msg, 70, verbosity=log.verbosity())

                bad_gain_intervals = np.logical_or(bad_gain_intervals, low_snr)

            if bad_gain_intervals.any():
                # (n_dir,) array showing how many were flagged per direction
                self._n_flagged_on_max_error = bad_gain_intervals.sum(axis=(1, 2, 3))
                # raise corresponding gain flags
                self.gflags[self._interval_to_gainres(bad_gain_intervals, 1)] |= FL.LOWSNR
                self.prior_gain_error[bad_gain_intervals] = 0
                # flag intervals where all directions are bad, and propagate that out into flags
                bad_intervals = bad_gain_intervals.all(axis=0)
                if bad_intervals.any():
                    bad_slots = self.unpack_intervals(bad_intervals)
                    flags_arr[bad_slots, ...] |= FL.LOWSNR
                    unflagged[bad_slots, ...] = False
                    self.update_equation_counts(unflagged)

    self.flagged = self.gflags != 0
    self.n_flagged = self.flagged.sum()

    return unflagged
def __init__(self, label, data_arr, ndir, nmod, times, frequencies, chunk_label, options):
    """
    Initialises a gain machine which supports solution intervals.

    Args:
        label (str): Label identifying the Jones term.
        data_arr (np.ndarray): Shape (n_mod, n_tim, n_fre, n_ant, n_ant, n_cor, n_cor) array
            containing observed visibilities.
        ndir (int): Number of directions.
        nmod (int): Number of models.
        times (np.ndarray): Times for the data being processed.
        frequencies (np.ndarray): Frequencies for the data being processed.
        chunk_label: Label of the current data chunk; passed through to MasterMachine.
        options (dict): Dictionary of options. "fix-dirs" may be None/"" (no fixed
            directions), an int, a list of ints, or a comma-separated string of integers
            such as "0,2".

    Raises:
        UserInputError: if "fix-dirs" is not one of the accepted forms above.
    """
    MasterMachine.__init__(self, label, data_arr, ndir, nmod, times, frequencies, chunk_label, options)

    # select which kernels to use for computing full data
    self.kernel = self.get_full_kernel(options, self.is_diagonal)

    # kernel used in solver is diag-diag in diag mode, else uses full kernel version
    if options.get('diag-data') or options.get('diag-only'):
        self.kernel_solve = cubical.kernels.import_kernel('diagdiag_complex')
    else:
        self.kernel_solve = self.kernel

    log(2).print("{} kernels are {} {}".format(label, self.kernel, self.kernel_solve))

    # Solution interval sizes; a falsy setting means "the whole chunk" along that axis.
    self.t_int = options["time-int"] or self.n_tim
    self.f_int = options["freq-int"] or self.n_fre
    self.eps = 1e-6

    # Initialise attributes used for computing values over intervals.
    # n_tim and n_fre are the time and frequency dimensions of the data arrays.
    # n_timint and n_freint are the time and frequency dimensions of the gains.
    self.t_bins = list(range(0, self.n_tim, self.t_int))
    self.f_bins = list(range(0, self.n_fre, self.f_int))
    self.n_timint = len(self.t_bins)
    self.n_freint = len(self.f_bins)
    self.n_tf_ints = self.n_timint * self.n_freint

    # number of valid solutions
    self.n_valid_sols = self.n_dir * self.n_tf_ints

    # split grids into intervals, and find the centre of gravity of each
    timebins = np.split(times, self.t_bins[1:])
    freqbins = np.split(frequencies, self.f_bins[1:])
    timegrid = np.array([float(x.mean()) for x in timebins])
    freqgrid = np.array([float(x.mean()) for x in freqbins])

    # interval_grid determines the per-interval grid points
    self.interval_grid = dict(time=timegrid, freq=freqgrid)
    # data_grid determines the full resolution grid
    self.data_grid = dict(time=times, freq=frequencies)

    # compute index from each data point to interval number
    t_ind = np.arange(self.n_tim) // self.t_int
    f_ind = np.arange(self.n_fre) // self.f_int
    self.t_mapping, self.f_mapping = np.meshgrid(t_ind, f_ind, indexing="ij")

    # Initialise attributes used in convergence testing. n_cnvgd is the number
    # of solutions which have converged.
    self._has_stalled = False
    self.n_cnvgd = 0
    self._frac_cnvgd = 0
    self.iters = 0
    self.min_quorum = options["conv-quorum"]
    self.update_type = options["update-type"]
    self.ref_ant = options["ref-ant"]

    fix_dirs = options["fix-dirs"]
    if fix_dirs is None or fix_dirs == "":
        fix_dirs = []
    elif type(fix_dirs) is int:
        fix_dirs = [fix_dirs]
    elif type(fix_dirs) is str and re.match(r"^\W*\d{1,}(\W*,\W*\d{1,})*\W*$", fix_dirs):
        # The pattern guarantees comma-separated integers (optionally bracketed/spaced),
        # so extracting every digit run yields exactly the requested directions. Build a
        # real list: a map() iterator would fail the list check below.
        fix_dirs = [int(num) for num in re.findall(r"\d+", fix_dirs)]
    if not (type(fix_dirs) is list and all(type(x) is int for x in fix_dirs)):
        # ArgumentError is not a Python builtin; use the project's user-input error.
        from cubical.main import UserInputError
        raise UserInputError("Fix directions must be number or list of numbers")
    self.fix_directions = fix_dirs

    # True if gains are loaded from a DB
    self._gains_loaded = False

    # Construct flag array and populate flagging attributes.
    self.max_gain_error = options["max-prior-error"]
    self.max_post_error = options["max-post-error"]
    self.low_snr_warn = options["low-snr-warn"]
    self.high_gain_var_warn = options["high-gain-var-warn"]
    self.clip_lower = options["clip-low"]
    self.clip_upper = options["clip-high"]
    self.clip_after = options["clip-after"]

    self.init_gains()
    self.old_gains = self.gains.copy()

    # Gain error estimates. Populated by subclasses, if available.
    # Should be array of same shape as the gains.
    self.prior_gain_error = None
    self.posterior_gain_error = None

    # buffers for arrays used in internal updates
    self._jh = self._jhr = self._jhj = self._gh = self._r = self._ginv = self._ghinv = None
    self._update = None

    # flag: have gains been updated
    self._gh_update = self._ghinv_update = True