def precompute_attributes(self, model_arr, flags_arr, inv_var_chan):
    """Precompute solution statistics for the chain and its solvable terms.

    First runs ``MasterMachine.precompute_attributes`` on this machine, then
    forwards the same ``model_arr``, ``flags_arr`` and ``inv_var_chan`` to
    every entry of ``self.jones_terms`` whose ``solvable`` attribute is
    truthy, in list order. Nothing is returned.
    """
    MasterMachine.precompute_attributes(self, model_arr, flags_arr, inv_var_chan)
    for jones in self.jones_terms:
        # non-solvable terms need no precomputation
        if not jones.solvable:
            continue
        jones.precompute_attributes(model_arr, flags_arr, inv_var_chan)
def precompute_attributes(self, model_arr, flags_arr, inv_var_chan):
    """Precompute per-interval statistics before starting a solution.

    Steps:

    * calls ``MasterMachine.precompute_attributes`` to obtain ``unflagged``;
    * pre-flags (``FL.MISSING``) gain intervals whose input data is entirely
      MISSING|PRIOR-flagged, and records ``self.missing_gain_fraction``;
    * estimates ``self.prior_gain_error`` per direction/interval/antenna from
      the channel noise (``1 / inv_var_chan``) and the mean ``|model|^2``;
    * when ``self.max_gain_error`` is truthy, raises ``FL.LOWSNR`` on gains
      whose prior error exceeds it, and propagates LOWSNR into ``flags_arr``
      (clearing ``unflagged``) for intervals that are bad in every direction.
      A falsy ``max_gain_error`` (0 or None) disables this check;
    * refreshes ``self.flagged`` and ``self.n_flagged`` from ``self.gflags``.

    Args:
        model_arr: model array. Axis 0 is the direction axis (collapsed when
            ``self.dd_term`` is false); axis 1 and the last three axes are
            summed when forming the mean |model|^2 (presumably model index,
            second antenna and correlations -- TODO confirm).
        flags_arr: integer flag bitmask array; LOWSNR is OR-ed in place.
        inv_var_chan: per-channel inverse noise variance.

    Returns:
        The ``unflagged`` mask from the base class, with newly LOWSNR-flagged
        slots cleared in place.
    """
    unflagged = MasterMachine.precompute_attributes(
        self, model_arr, flags_arr, inv_var_chan)

    # Pre-flag gain solution intervals that are completely flagged in the input
    # data (i.e. MISSING|PRIOR). This has shape (n_timint, n_freint, n_ant).
    missing_intervals = self.interval_and(
        (flags_arr & (FL.MISSING | FL.PRIOR) != 0).all(axis=-1))
    self.missing_gain_fraction = missing_intervals.sum() / float(missing_intervals.size)

    # convert the intervals array to gain shape, and apply flags
    self.gflags[:, self._interval_to_gainres(missing_intervals)] = FL.MISSING

    # number of data points per time/frequency/antenna
    numeq_tfa = unflagged.sum(axis=-1)

    # Compute error estimates per direction, antenna, and interval. Empty slots
    # divide by zero; the warnings are silenced and those slots zeroed explicitly.
    with np.errstate(invalid='ignore', divide='ignore'):
        sigmasq = 1 / inv_var_chan  # squared noise per channel; infinite if no data

        # collapse direction axis, if not directional
        if not self.dd_term:
            model_arr = model_arr.sum(axis=0, keepdims=True)

        # mean |model|^2 per direction+TFA
        modelsq = (model_arr * np.conj(model_arr)).real.sum(axis=(1, -1, -2, -3)) / \
            (self.n_mod * self.n_cor * self.n_cor * numeq_tfa)
        modelsq[:, numeq_tfa == 0] = 0

        # inverse SNR^2 per direction+TFA
        inv_snr2 = sigmasq[np.newaxis, np.newaxis, :, np.newaxis] / modelsq
        inv_snr2[:, numeq_tfa == 0] = 0

        # take the mean SNR^-2 over each interval;
        # numeq_tfa becomes the number of points per interval, antenna
        numeq_tfa = self.interval_sum(numeq_tfa)
        inv_snr2_int = self.interval_sum(inv_snr2, 1) / numeq_tfa[np.newaxis, ...]
        inv_snr2_int[:, numeq_tfa == 0] = 0

        # convert that into a gain error per direction, interval, antenna
        dof = self.eqs_per_interval - self.num_unknowns
        self.prior_gain_error = np.sqrt(
            inv_snr2_int / dof[np.newaxis, :, :, np.newaxis])

    self.prior_gain_error[:, ~self.valid_intervals, :] = 0

    # reset to 0 for fixed directions
    if self.dd_term:
        self.prior_gain_error[self.fix_directions, ...] = 0

    # Flag gains on max error. A falsy threshold (0/None) disables the check,
    # consistent with the other per-interval implementation; previously None
    # raised TypeError and 0 flagged every gain with a non-zero error.
    if self.max_gain_error:
        bad_gain_intervals = self.prior_gain_error > self.max_gain_error  # dir,time,freq,ant
    else:
        bad_gain_intervals = np.zeros_like(self.prior_gain_error, dtype=bool)

    self._n_flagged_on_max_error = None
    if bad_gain_intervals.any():
        # (n_dir,) array showing how many were flagged per direction
        self._n_flagged_on_max_error = bad_gain_intervals.sum(axis=(1, 2, 3))
        # raise corresponding gain flags
        self.gflags[self._interval_to_gainres(bad_gain_intervals, 1)] |= FL.LOWSNR
        self.prior_gain_error[bad_gain_intervals] = 0
        # flag intervals where all directions are bad, and propagate that out into flags
        bad_intervals = bad_gain_intervals.all(axis=0)
        if bad_intervals.any():
            bad_slots = self.unpack_intervals(bad_intervals)
            flags_arr[bad_slots, ...] |= FL.LOWSNR
            unflagged[bad_slots, ...] = False
            self._update_equation_counts(unflagged)

    self._n_flagged_on_max_posterior_error = None

    self.flagged = self.gflags != 0
    self.n_flagged = self.flagged.sum()

    return unflagged
def precompute_attributes(self, data_arr, model_arr, flags_arr, inv_var_chan):
    """Precompute per-interval attributes of the machine before starting a solution.

    Steps:

    * calls ``MasterMachine.precompute_attributes`` to obtain ``unflagged``;
    * pre-flags (``FL.MISSING``) gain intervals in which every input data point
      carries a nonzero flag, and records ``self.missing_gain_fraction``;
    * if ``inv_var_chan`` is given:
        - estimates ``self.prior_gain_error`` per direction/interval/antenna
          from the channel noise and the mean ``|model|^2``;
        - raises ``FL.LOWSNR`` on gains whose estimate is non-finite or, when
          ``self.max_gain_error`` is truthy, exceeds that threshold (emitting
          user warnings describing the affected directions/intervals/stations);
        - propagates LOWSNR into ``flags_arr``/``unflagged`` for intervals
          that are bad in every direction;
    * refreshes ``self.flagged`` and ``self.n_flagged`` from ``self.gflags``.

    Args:
        data_arr: only forwarded to ``MasterMachine.precompute_attributes``.
        model_arr: model array. Axis 0 is the direction axis (collapsed when
            ``self.dd_term`` is false); axis 1 and the last three axes are
            summed when forming the mean |model|^2 (presumably model index,
            second antenna and correlations -- TODO confirm).
        flags_arr: integer flag bitmask array; LOWSNR is OR-ed in place.
        inv_var_chan: per-channel inverse noise variance, or ``None`` to skip
            the prior-error estimate and SNR-based flagging.

    Returns:
        The ``unflagged`` mask from the base class, with newly LOWSNR-flagged
        slots cleared in place.
    """
    unflagged = MasterMachine.precompute_attributes(self, data_arr, model_arr, flags_arr, inv_var_chan)

    # Pre-flag gain solution intervals whose input data is completely flagged.
    # Earlier versions only considered MISSING|PRIOR here; any nonzero flag now
    # counts, so that other input flags (SKIPSOL, NULLDATA, etc.) are honoured
    # too. This has shape (n_timint, n_freint, n_ant).
    missing_intervals = self.interval_and((flags_arr != 0).all(axis=-1))
    self.missing_gain_fraction = missing_intervals.sum() / float(missing_intervals.size)

    # convert the intervals array to gain shape, and apply flags
    self.gflags[:, self._interval_to_gainres(missing_intervals)] = FL.MISSING

    # number of data points per time/frequency/antenna
    numeq_tfa = unflagged.sum(axis=-1)

    # Reset the per-call LOWSNR counter unconditionally, so a value left over
    # from a previous call is never reported when no noise estimate is given.
    self._n_flagged_on_max_error = None

    # compute error estimates per direction, antenna, and interval
    if inv_var_chan is not None:
        with np.errstate(invalid='ignore', divide='ignore'):
            # collapse direction axis, if not directional
            if not self.dd_term:
                model_arr = model_arr.sum(axis=0, keepdims=True)

            # mean |model|^2 per direction+TFA
            modelsq = (model_arr * np.conj(model_arr)).real.sum(axis=(1, -1, -2, -3)) / \
                (self.n_mod * self.n_cor * self.n_cor * numeq_tfa)
            modelsq[:, numeq_tfa == 0] = 0

            sigmasq = 1.0 / inv_var_chan  # squared noise per channel; infinite if no data

            # Take the sigma (in quadrature) over each interval, divided by the
            # quadrature sum of unflagged contributing interferometers per
            # interval; this yields var<g>. numeq_tfa becomes the number of
            # unflagged points per interval, antenna.
            numeq_tfa = self.interval_sum(numeq_tfa)
            sigmasq[~np.isfinite(sigmasq)] = 0.0
            modelsq[~np.isfinite(modelsq)] = 0.0
            # interval-summed |model|^2; computed once and reused for the
            # validity checks below
            modelsq_int = self.interval_sum(modelsq, 1)
            NSR_int = self.interval_sum(np.ones_like(modelsq) * sigmasq[None, None, :, None], 1) / \
                (modelsq_int * numeq_tfa)

            # convert that into a gain error per direction, interval, antenna
            self.prior_gain_error = np.sqrt(NSR_int)
            # zero fixed directions before the validity check, so they are
            # not flagged as invalid below
            if self.dd_term:
                self.prior_gain_error[self.fix_directions, ...] = 0
            pge_flag_invalid = ~np.isfinite(self.prior_gain_error)

        invalid_models = (modelsq_int == 0) | ~np.isfinite(modelsq_int)

        if np.any(np.all(numeq_tfa == 0, axis=-1)) and log.verbosity() > 1:
            self.raise_userwarning(
                logging.CRITICAL,
                "One or more directions (or its frequency intervals) are already fully flagged.",
                90, raise_once="prior_fully_flagged_dirs", verbosity=2, color="red")
        if np.any(np.all(invalid_models, axis=-1)) and log.verbosity() > 1:
            self.raise_userwarning(
                logging.CRITICAL,
                "One or more directions (or its frequency intervals) have invalid or 0 models.",
                90, raise_once="invalid_models", verbosity=2, color="red")

        # Zero invalid intervals. Fixed directions were already zeroed above and
        # this only writes zeros, so no second reset is needed.
        self.prior_gain_error[:, ~self.valid_intervals, :] = 0

        # flag gains on max error, starting from gains with a non-finite estimate
        bad_gain_intervals = pge_flag_invalid
        if self.max_gain_error:
            low_snr = self.prior_gain_error > self.max_gain_error  # dir,time,freq,ant
            if low_snr.all():
                msg = "'{0:s}' {1:s} All directions flagged, either due to low SNR. "\
                      "You need to check your tagged directions and your max-prior-error and/or solution intervals. "\
                      "New flags will be raised for this chunk of data".format(
                          self.jones_label, self.chunk_label)
                self.raise_userwarning(logging.CRITICAL, msg, 70, verbosity=log.verbosity(), color="red")
            else:
                # low SNR in every direction, per time, freq, ant
                low_snr_all_dirs = low_snr.all(axis=0)

                # some direction is fully flagged across all times/freqs/antennas
                if low_snr.all(axis=(1, 2, 3)).any():
                    dir_snr = {}
                    for d in range(self.prior_gain_error.shape[0]):
                        percflagged = np.sum(low_snr[d]) * 100.0 / low_snr[d].size
                        if percflagged > self.low_snr_warn and d not in self.fix_directions:
                            dir_snr[d] = percflagged
                    if len(dir_snr) > 0:
                        if log.verbosity() > 2:
                            msg = "Low SNR in one or more directions of gain '{0:s}' chunk '{1:s}':".format(
                                self.jones_label, self.chunk_label) +\
                                "\n{0:s}\n".format("\n".join(["\t direction {0:s}: {1:.3f}% gains affected".format(
                                    str(d), dir_snr[d]) for d in sorted(dir_snr)])) +\
                                "Check your settings for gain solution intervals and max-prior-error. "
                        else:
                            msg = "'{0:s}' {1:s} Low SNR in directions {2:s}. Increase solution intervals or raise max-prior-error!".format(
                                self.jones_label, self.chunk_label, ", ".join(map(str, sorted(dir_snr))))
                        self.raise_userwarning(logging.CRITICAL, msg, 50, verbosity=log.verbosity(), color="red")

                if low_snr_all_dirs.all(axis=0).all(axis=-1).any():
                    msg = "'{0:s}' {1:s} All time of one or more frequency intervals flagged due to low SNR. "\
                          "You need to check your max-prior-error and/or solution intervals. "\
                          "New flags will be raised for this chunk of data".format(
                              self.jones_label, self.chunk_label)
                    self.raise_userwarning(logging.WARNING, msg, 70, verbosity=log.verbosity())

                if low_snr_all_dirs.all(axis=1).all(axis=-1).any():
                    msg = "'{0:s}' {1:s} All channels of one or more time intervals flagged due to low SNR. "\
                          "You need to check your max-prior-error and/or solution intervals. "\
                          "New flags will be raised for this chunk of data".format(
                              self.jones_label, self.chunk_label)
                    self.raise_userwarning(logging.WARNING, msg, 70, verbosity=log.verbosity())

                stationflags = np.argwhere(low_snr_all_dirs.all(axis=0).all(axis=0)).flatten()
                if stationflags.size > 0:
                    msg = "'{0:s}' {1:s} Stations {2:s} ({3:d}/{4:d}) fully flagged due to low SNR. "\
                          "These stations may be faulty or your SNR requirements (max-prior-error) are not met. "\
                          "New flags will be raised for this chunk of data".format(
                              self.jones_label, self.chunk_label, ", ".join(map(str, stationflags)),
                              stationflags.size, low_snr.shape[3])
                    self.raise_userwarning(logging.WARNING, msg, 70, verbosity=log.verbosity())

            bad_gain_intervals = np.logical_or(bad_gain_intervals, low_snr)  # dir,time,freq,ant

        if bad_gain_intervals.any():
            # (n_dir,) array showing how many were flagged per direction
            self._n_flagged_on_max_error = bad_gain_intervals.sum(axis=(1, 2, 3))
            # raise corresponding gain flags
            self.gflags[self._interval_to_gainres(bad_gain_intervals, 1)] |= FL.LOWSNR
            self.prior_gain_error[bad_gain_intervals] = 0
            # flag intervals where all directions are bad, and propagate that out into flags
            bad_intervals = bad_gain_intervals.all(axis=0)
            if bad_intervals.any():
                bad_slots = self.unpack_intervals(bad_intervals)
                flags_arr[bad_slots, ...] |= FL.LOWSNR
                unflagged[bad_slots, ...] = False
                self.update_equation_counts(unflagged)

    self._n_flagged_on_max_posterior_error = None

    self.flagged = self.gflags != 0
    self.n_flagged = self.flagged.sum()

    return unflagged