Example #1
0
 def precompute_attributes(self, model_arr, flags_arr, inv_var_chan):
     """Precomputes various stats before starting a solution.

     Runs the base MasterMachine precomputation for the chain itself, then
     forwards the same arrays to every solvable Jones term so that each term
     can precompute its own attributes.

     Args:
         model_arr: model visibilities, passed through unchanged.
         flags_arr: flag array, passed through to the base class and terms.
             NOTE(review): terms may modify it in place, depending on the
             term implementation — confirm.
         inv_var_chan: presumably the inverse noise variance per channel;
             passed through unchanged.

     Returns:
         Whatever MasterMachine.precompute_attributes returns (presumably the
         unflagged-data mask). It is computed before the terms run, so it may
         not reflect flags the terms raise afterwards — TODO confirm whether
         callers need a recomputed mask.
     """
     unflagged = MasterMachine.precompute_attributes(self, model_arr, flags_arr,
                                                     inv_var_chan)
     # Only solvable terms are given a chance to precompute their attributes.
     for term in self.jones_terms:
         if term.solvable:
             term.precompute_attributes(model_arr, flags_arr, inv_var_chan)
     # Propagate the base-class result instead of silently returning None, so
     # callers that use the return value get the same mask the base produced.
     return unflagged
Example #2
0
    def precompute_attributes(self, model_arr, flags_arr, inv_var_chan):
        """Precomputes various stats before starting a solution.

        Calls the base MasterMachine.precompute_attributes to obtain the mask of
        unflagged data points. It then pre-flags solution intervals whose input
        data is entirely MISSING|PRIOR, and estimates a prior gain error per
        direction/interval/antenna from the model power and per-channel noise.
        Finally it raises LOWSNR flags on gains whose prior error exceeds
        self.max_gain_error.

        Side effects: sets self.missing_gain_fraction, self.prior_gain_error,
        self._n_flagged_on_max_error, self._n_flagged_on_max_posterior_error,
        self.flagged and self.n_flagged, and updates self.gflags.

        Args:
            model_arr: model visibilities. The reductions below imply a direction
                axis first and a model axis second — exact layout TODO confirm.
            flags_arr: per-visibility bitflag array. It is modified in place:
                LOWSNR is OR-ed into slots where every direction is bad.
            inv_var_chan: inverse noise variance per channel (1/inv_var_chan is
                used as the squared noise per channel).

        Returns:
            The unflagged mask from the base class. Slots flagged LOWSNR here
            are set to False in it.
        """
        unflagged = MasterMachine.precompute_attributes(
            self, model_arr, flags_arr, inv_var_chan)

        # Pre-flag gain solution intervals that are completely flagged in the input data
        # (i.e. MISSING|PRIOR). This has shape (n_timint, n_freint, n_ant).
        # Note: `&` binds tighter than `!=` in Python, so this is (flags & mask) != 0.

        missing_intervals = self.interval_and(
            (flags_arr & (FL.MISSING | FL.PRIOR) != 0).all(axis=-1))

        # fraction of interval/antenna cells that are fully missing
        self.missing_gain_fraction = missing_intervals.sum() / float(
            missing_intervals.size)

        # convert the intervals array to gain shape, and apply flags
        self.gflags[:,
                    self._interval_to_gainres(missing_intervals)] = FL.MISSING

        # number of unflagged data points per time/frequency/antenna
        # (unflagged summed over its last axis)
        numeq_tfa = unflagged.sum(axis=-1)

        # compute error estimates per direction, antenna, and interval.
        # Divisions below can produce x/0 or 0/0; those entries are zeroed
        # explicitly afterwards, so the numpy warnings are suppressed here.
        with np.errstate(invalid='ignore', divide='ignore'):
            sigmasq = 1 / inv_var_chan  # squared noise per channel. Could be infinite if no data
            # collapse direction axis, if not directional
            if not self.dd_term:
                model_arr = model_arr.sum(axis=0, keepdims=True)
            # mean |model|^2 per direction+TFA
            modelsq = (model_arr*np.conj(model_arr)).real.sum(axis=(1,-1,-2,-3)) / \
                      (self.n_mod*self.n_cor*self.n_cor*numeq_tfa)
            # no unflagged data -> no usable model power
            modelsq[:, numeq_tfa == 0] = 0
            # inverse SNR^2 per direction+TFA (channel noise broadcast along freq)
            inv_snr2 = sigmasq[np.newaxis, np.newaxis, :, np.newaxis] / modelsq
            inv_snr2[:, numeq_tfa == 0] = 0
            # take the mean SNR^-2 over each interval
            # numeq_tfa becomes number of points per interval, antenna
            numeq_tfa = self.interval_sum(numeq_tfa)
            inv_snr2_int = self.interval_sum(inv_snr2,
                                             1) / numeq_tfa[np.newaxis, ...]
            inv_snr2_int[:, numeq_tfa == 0] = 0
            # convert that into a gain error per direction,interval,antenna,
            # scaled by the degrees of freedom (equations minus unknowns) per interval
            self.prior_gain_error = np.sqrt(
                inv_snr2_int /
                (self.eqs_per_interval - self.num_unknowns)[np.newaxis, :, :,
                                                            np.newaxis])

        # zero the error on intervals not marked valid
        self.prior_gain_error[:, ~self.valid_intervals, :] = 0
        # reset to 0 for fixed directions
        if self.dd_term:
            self.prior_gain_error[self.fix_directions, ...] = 0

        # flag gains on max error.
        # NOTE(review): NaN errors compare False here and are therefore not
        # flagged. max_gain_error is not special-cased: a value of 0 flags every
        # gain with a positive error — confirm this is intended.
        bad_gain_intervals = self.prior_gain_error > self.max_gain_error  # dir,time,freq,ant
        if bad_gain_intervals.any():
            # (n_dir,) array showing how many were flagged per direction
            self._n_flagged_on_max_error = bad_gain_intervals.sum(axis=(1, 2,
                                                                        3))
            # raised corresponding gain flags
            self.gflags[self._interval_to_gainres(bad_gain_intervals,
                                                  1)] |= FL.LOWSNR
            self.prior_gain_error[bad_gain_intervals] = 0
            # flag intervals where all directions are bad, and propagate that out into flags
            bad_intervals = bad_gain_intervals.all(axis=0)
            if bad_intervals.any():
                bad_slots = self.unpack_intervals(bad_intervals)
                flags_arr[bad_slots, ...] |= FL.LOWSNR
                unflagged[bad_slots, ...] = False
                # equation counts must reflect the newly flagged data
                self._update_equation_counts(unflagged)
        else:
            self._n_flagged_on_max_error = None

        self._n_flagged_on_max_posterior_error = None
        # summary of all gain flags raised so far
        self.flagged = self.gflags != 0
        self.n_flagged = self.flagged.sum()

        return unflagged
    def precompute_attributes(self, data_arr, model_arr, flags_arr, inv_var_chan):
        """Precomputes various attributes of the machine before starting a solution.

        Calls the base MasterMachine.precompute_attributes to obtain the mask of
        unflagged data points, and marks solution intervals whose input data is
        entirely flagged as MISSING. When inv_var_chan is given, it also:
        estimates a prior gain error per direction/interval/antenna from the
        model power and per-channel noise; warns about low-SNR patterns; and
        raises LOWSNR flags on gains whose prior error is non-finite or exceeds
        self.max_gain_error.

        Args:
            data_arr: data array; only forwarded to the base class here.
            model_arr: model visibilities. The reductions below imply a direction
                axis first and a model axis second — exact layout TODO confirm.
            flags_arr: per-visibility bitflag array. It is modified in place:
                LOWSNR is OR-ed into slots where every direction is bad.
            inv_var_chan: inverse noise variance per channel, or None. With None,
                the whole prior-error / LOWSNR section is skipped.
                NOTE(review): in that case self.prior_gain_error and
                self._n_flagged_on_max_error are not (re)set by this method —
                confirm they are initialised elsewhere.

        Returns:
            The unflagged mask from the base class. Slots flagged LOWSNR here
            are set to False in it.
        """
        unflagged = MasterMachine.precompute_attributes(self, data_arr, model_arr, flags_arr, inv_var_chan)

        # Pre-flag gain solution intervals whose input data is completely flagged.
        # Any nonzero input flag counts, not only MISSING|PRIOR as in an earlier
        # version, because other input flags (SKIPSOL, NULLDATA, etc.) must be
        # honoured too. Per the original comments, missing_intervals has shape
        # (n_timint, n_freint, n_ant).

        missing_intervals = self.interval_and((flags_arr!=0).all(axis=-1))

        # fraction of interval/antenna cells that are fully flagged
        self.missing_gain_fraction = missing_intervals.sum() / float(missing_intervals.size)

        # convert the intervals array to gain shape, and apply flags
        self.gflags[:, self._interval_to_gainres(missing_intervals)] = FL.MISSING

        # number of unflagged data points per time/frequency/antenna
        # (unflagged summed over its last axis)
        numeq_tfa = unflagged.sum(axis=-1)

        # compute error estimates per direction, antenna, and interval
        # (only possible when a noise estimate is supplied)
        if inv_var_chan is not None:
            # divisions below may hit x/0 or 0/0; the resulting nan/inf values
            # are handled explicitly, so suppress the numpy warnings
            with np.errstate(invalid='ignore', divide='ignore'):
                # collapse direction axis, if not directional
                if not self.dd_term:
                    model_arr = model_arr.sum(axis=0, keepdims=True)
                # mean |model|^2 per direction+TFA
                modelsq = (model_arr*np.conj(model_arr)).real.sum(axis=(1,-1,-2,-3)) / \
                          (self.n_mod*self.n_cor*self.n_cor*numeq_tfa)
                # no unflagged data -> no usable model power
                modelsq[:, numeq_tfa==0] = 0

                sigmasq = 1.0/inv_var_chan                        # squared noise per channel. Could be infinite if no data
                # Sum the per-channel noise variance (i.e. add sigmas in
                # quadrature) over each interval, then divide by the summed model
                # power times the number of unflagged points per interval. The
                # author describes the result as var<g>; its square root is used
                # as the prior gain error.
                # (numeq_tfa becomes number of unflagged points per interval, antenna)
                numeq_tfa = self.interval_sum(numeq_tfa)
                # treat undefined noise / model power as a zero contribution
                sigmasq[np.logical_or(np.isnan(sigmasq), np.isinf(sigmasq))] = 0.0
                modelsq[np.logical_or(np.isnan(modelsq), np.isinf(modelsq))] = 0.0
                # interval_sum(..., 1): the second argument presumably skips the
                # leading direction axis — TODO confirm against interval_sum
                NSR_int = self.interval_sum(np.ones_like(modelsq) * (sigmasq)[None, None, :, None], 1) / \
                              (self.interval_sum(modelsq, 1) * numeq_tfa)
                # convert that into a gain error per direction,interval,antenna
                self.prior_gain_error = np.sqrt(NSR_int)
                if self.dd_term:
                    self.prior_gain_error[self.fix_directions, ...] = 0

                # Non-finite errors (e.g. zero summed model power or zero
                # unflagged points in the denominator) are flagged further down.
                # NOTE(review): this mask is taken before the valid_intervals
                # reset below, so non-finite errors on invalid intervals are
                # still flagged LOWSNR.
                pge_flag_invalid = np.logical_or(np.isnan(self.prior_gain_error),
                                                 np.isinf(self.prior_gain_error))

                # summed model power per interval that is zero or non-finite
                invalid_models = np.logical_or(self.interval_sum(modelsq, 1) == 0,
                                               np.logical_or(np.isnan(self.interval_sum(modelsq, 1)),
                                                             np.isinf(self.interval_sum(modelsq, 1))))
                # warn (once) if some interval has no unflagged data on any antenna
                if np.any(np.all(numeq_tfa == 0, axis=-1)) and log.verbosity() > 1:
                    self.raise_userwarning(
                        logging.CRITICAL,
                        "One or more directions (or its frequency intervals) are already fully flagged.",
                        90, raise_once="prior_fully_flagged_dirs", verbosity=2, color="red")

                # warn (once) if some interval has invalid model power on every antenna
                if np.any(np.all(invalid_models, axis=-1)) and log.verbosity() > 1:
                    self.raise_userwarning(
                        logging.CRITICAL,
                        "One or more directions (or its frequency intervals) have invalid or 0 models.",
                        90, raise_once="invalid_models", verbosity=2, color="red")

            # zero the error on intervals not marked valid
            self.prior_gain_error[:, ~self.valid_intervals, :] = 0
            # reset to 0 for fixed directions (repeats the reset done above)
            if self.dd_term:
                self.prior_gain_error[self.fix_directions, ...] = 0

            # flag gains on max error; non-finite errors are always flagged
            self._n_flagged_on_max_error = None
            bad_gain_intervals = pge_flag_invalid
            # a falsy max_gain_error (e.g. 0 or None) disables max-error flagging
            if self.max_gain_error:
                low_snr = self.prior_gain_error > self.max_gain_error
                # every gain (all directions, intervals, antennas) exceeds the limit
                if low_snr.all(axis=0).all():
                    msg = "'{0:s}' {1:s} All directions flagged, either due to low SNR. "\
                          "You need to check your tagged directions and your max-prior-error and/or solution intervals. "\
                          "New flags will be raised for this chunk of data".format(
                                self.jones_label, self.chunk_label)
                    self.raise_userwarning(logging.CRITICAL, msg, 70, verbosity=log.verbosity(), color="red")

                else:
                    # some direction is low-SNR on every time interval, frequency interval and antenna
                    if low_snr.all(axis=-1).all(axis=-1).all(axis=-1).any(): #all antennas fully flagged of some direction
                        # percentage of low-SNR gains per non-fixed direction, kept
                        # only where it exceeds self.low_snr_warn
                        dir_snr = {}
                        for d in range(self.prior_gain_error.shape[0]):
                            percflagged = np.sum(low_snr[d]) * 100.0 / low_snr[d].size
                            if percflagged > self.low_snr_warn and d not in self.fix_directions: dir_snr[d] = percflagged

                        if len(dir_snr) > 0:
                            # verbose: per-direction breakdown; otherwise a one-line summary
                            if log.verbosity() > 2:
                                msg = "Low SNR in one or more directions of gain '{0:s}' chunk '{1:s}':".format(
                                        self.jones_label, self.chunk_label) +\
                                      "\n{0:s}\n".format("\n".join(["\t direction {0:s}: {1:.3f}% gains affected".format(
                                                            str(d), dir_snr[d]) for d in sorted(dir_snr)])) +\
                                      "Check your settings for gain solution intervals and max-prior-error. "
                            else:
                                msg = "'{0:s}' {1:s} Low SNR in directions {2:s}. Increase solution intervals or raise max-prior-error!".format(
                                    self.jones_label, self.chunk_label, ", ".join(map(str, sorted(dir_snr))))
                            self.raise_userwarning(logging.CRITICAL, msg, 70, verbosity=log.verbosity(), color="red") if False else self.raise_userwarning(logging.CRITICAL, msg, 50, verbosity=log.verbosity(), color="red")

                    # some frequency interval is low-SNR for all directions, times and antennas
                    if low_snr.all(axis=0).all(axis=0).all(axis=-1).any():
                        msg = "'{0:s}' {1:s} All time of one or more frequency intervals flagged due to low SNR. "\
                              "You need to check your max-prior-error and/or solution intervals. "\
                              "New flags will be raised for this chunk of data".format(
                                    self.jones_label, self.chunk_label)
                        self.raise_userwarning(logging.WARNING, msg, 70, verbosity=log.verbosity())

                    # some time interval is low-SNR for all directions, frequencies and antennas
                    if low_snr.all(axis=0).all(axis=1).all(axis=-1).any():
                        msg = "'{0:s}' {1:s} All channels of one or more time intervals flagged due to low SNR. "\
                              "You need to check your max-prior-error and/or solution intervals. "\
                              "New flags will be raised for this chunk of data".format(
                                    self.jones_label, self.chunk_label)
                        self.raise_userwarning(logging.WARNING, msg, 70, verbosity=log.verbosity())
                    # antennas that are low-SNR for all directions, times and frequencies
                    stationflags = np.argwhere(low_snr.all(axis=0).all(axis=0).all(axis=0)).flatten()
                    if stationflags.size > 0:
                        msg = "'{0:s}' {1:s} Stations {2:s} ({3:d}/{4:d}) fully flagged due to low SNR. "\
                              "These stations may be faulty or your SNR requirements (max-prior-error) are not met. "\
                              "New flags will be raised for this chunk of data".format(
                                    self.jones_label, self.chunk_label, ", ".join(map(str, stationflags)),
                                    np.sum(low_snr.all(axis=0).all(axis=0).all(axis=0)), low_snr.shape[3])
                        self.raise_userwarning(logging.WARNING, msg, 70, verbosity=log.verbosity())


                # combine non-finite errors with errors above the limit
                bad_gain_intervals = np.logical_or(bad_gain_intervals,
                                                   low_snr)    # dir,time,freq,ant

            if bad_gain_intervals.any():
                # (n_dir,) array showing how many were flagged per direction
                self._n_flagged_on_max_error = bad_gain_intervals.sum(axis=(1,2,3))
                # raised corresponding gain flags
                self.gflags[self._interval_to_gainres(bad_gain_intervals,1)] |= FL.LOWSNR
                self.prior_gain_error[bad_gain_intervals] = 0
                # flag intervals where all directions are bad, and propagate that out into flags
                bad_intervals = bad_gain_intervals.all(axis=0)
                if bad_intervals.any():
                    bad_slots = self.unpack_intervals(bad_intervals)
                    flags_arr[bad_slots,...] |= FL.LOWSNR
                    unflagged[bad_slots,...] = False
                    # equation counts must reflect the newly flagged data
                    self.update_equation_counts(unflagged)

        self._n_flagged_on_max_posterior_error = None
        # summary of all gain flags raised so far
        self.flagged = self.gflags != 0
        self.n_flagged = self.flagged.sum()

        return unflagged