Ejemplo n.º 1
0
    def flagging_stats_string(self):
        """
        Summarise flag occupancy per flag category as a single string.

        For every (name, mask) pair returned by FL.categories(), counts the elements of
        self.gflags that have any bit of the mask set. Categories with a zero count are
        omitted.

        Returns:
            str: space-separated entries of the form "name:count(percentage)", where the
            percentage is taken relative to the total number of elements in self.gflags.
            Empty if no category has any flags raised.
        """
        total = float(self.gflags.size)
        entries = []

        for name, bitmask in FL.categories().items():
            count = ((self.gflags & bitmask) != 0).sum()
            if not count:
                # nothing raised in this category, leave it out of the summary
                continue
            entries.append("{}:{}({:.2%})".format(name, count, count / total))

        return " ".join(entries)
Ejemplo n.º 2
0
    def finalize(self, corr_vis):
        """
        Wrap up processing of the output visibilities.

        If the GD['madmax']['residuals'] option is set, a last MadMax flagging round is run
        on the residuals. Afterwards a one-line summary of gain-solution flags and MadMax
        flags is written to the log, if there is anything to report.

        Args:
            corr_vis:
                Used directly as the residuals for the final MadMax round when
                self.outputs_full_corrected_residuals is set; not otherwise read here.
        """
        chunk_stats = self.stats.chunk

        # MAD flags raised while madmax was in trial mode are discarded
        if chunk_stats.num_mad_flagged and self.madmax.trial_mode:
            self.vdm.flags_arr &= ~FL.MAD
            chunk_stats.num_mad_flagged = 0
        mad_count_before = int(chunk_stats.num_mad_flagged)

        # optional final MadMax round on the residuals
        residual_mode = GD['madmax']['residuals']
        if residual_mode:
            if not self.outputs_full_corrected_residuals:
                # residuals are not available from the caller: recompute them and
                # apply the inverse gains into a freshly allocated array
                log(0).print(
                    "{}: computing full residuals for final MadMax round".
                    format(self.label))
                corrupted = self.vdm.corrupt_residual(
                    self.sol_opts["subtract-model"], slice(None))
                resid_vis = np.zeros_like(corrupted)
                self.gm.apply_inv_gains(corrupted,
                                        resid_vis,
                                        full2x2=True,
                                        direction=self.sol_opts["correct-dir"])
                del corrupted
            else:
                resid_vis = corr_vis
                log(0).print(
                    "{}: doing final MadMax round on residuals".format(
                        self.label))

            # drop SKIPSOL so that data omitted from the solutions gets flagged too
            self.vdm.flags_arr &= ~FL.SKIPSOL

            self.madmax.set_mode(residual_mode)
            thr1, thr2 = self.madmax.get_mad_thresholds()
            if (thr1 or thr2) and self.madmax.beyond_thunderdome(
                    resid_vis, None, None, self.vdm.flags_arr, thr1, thr2,
                    "{} residual".format(self.label)):
                # the flagger reported a change: recount MAD-flagged visibilities
                chunk_stats.num_mad_flagged = (
                    (self.vdm.flags_arr & FL.MAD) != 0).sum()
            # drop our reference so a newly created array can be freed
            del resid_vis

        # gather messages from the various flagging sources
        messages = []

        if chunk_stats.num_sol_flagged:
            # per-category breakdown of gain flags (the MISSING category is skipped)
            per_category = []
            for category, bitmask in FL.categories().items():
                if bitmask == FL.MISSING:
                    continue
                n_flag, n_tot = self.gm.num_gain_flags(bitmask, final=True)
                if n_flag:
                    per_category.append("{}:{}({:.2%})".format(
                        category, n_flag, n_flag / float(n_tot)))

            nfl, nsol = self.gm.num_gain_flags(final=True)
            messages.append("gain flags {} ({:.2%} total)".format(
                " ".join(per_category), nfl / float(nsol)))

        if chunk_stats.num_mad_flagged:
            messages.append(
                "MadMax took out {} visibilities ({} in final round)".format(
                    chunk_stats.num_mad_flagged,
                    chunk_stats.num_mad_flagged - mad_count_before))

        if not messages:
            return

        # count flags other than MISSING/SKIPSOL, less those already present beforehand
        total = self.vdm.flags_arr.size
        flagged_now = (self.vdm.flags_arr & ~(FL.MISSING | FL.SKIPSOL) !=
                       0).sum()
        n_new_flags = flagged_now - chunk_stats.num_prior_flagged

        # the message is coloured red once the new flags reach the warning threshold
        warning = ""
        if n_new_flags < total * GD['flags']['warn-thr']:
            color = "blue"
        else:
            color = "red"

        log(0,
            color).print("{}{} has {} ({:.2%}) new data flags: {}".format(
                warning, self.label, n_new_flags,
                n_new_flags / float(total),
                ", ".join(messages)))
Ejemplo n.º 3
0
def _solve_gains(gm,
                 obser_arr,
                 model_arr,
                 flags_arr,
                 sol_opts,
                 label="",
                 compute_residuals=None):
    """
    Main body of the GN/LM method. Handles iterations and convergence tests.

    Repeatedly asks the gain machine for an update until it reports that it has converged
    or stalled, or until no valid solution intervals remain. Every "chi-int" iterations the
    residuals and chi-squared are recomputed to detect stalled solutions. A final flagging
    pass is then run and the solver statistics are filled in.

    Side effects: obser_arr is modified in place (slots carrying any flag other than
    MISSING/PRIOR are zeroed whenever the gain machine reports propagated flags). The same
    zeroing is applied to model_arr, or to its direction-collapsed copy if the gain machine
    is non-DD and the model has more than one direction. flags_arr is handed to
    gm.flag_solutions(), which presumably raises new flags in it in place -- confirm
    against the gain machine implementation.

    NOTE: this function uses Python 2 constructs (``print >> log`` and ``dict.iteritems()``).

    Args:
        gm (:obj:`~cubical.machines.abstract_machine.MasterMachine`):
            The gain machine which will be used in the solver loop.
        obser_arr (np.ndarray):
            Shape (n_mod, n_tim, n_fre, n_ant, n_ant, n_cor, n_cor) array containing observed
            visibilities.
        model_arr (np.ndarray):
            Shape (n_dir, n_mod, n_tim, n_fre, n_ant, n_ant, n_cor, n_cor) array containing model
            visibilities.
        flags_arr (np.ndarray):
            Shape (n_tim, n_fre, n_ant, n_ant) integer array containing flag data.
        sol_opts (dict):
            Solver options (see [sol] section in DefaultParset.cfg). The keys read here are
            "delta-g", "delta-chi", "chi-int", "stall-quorum" and "last-rites".
        label (str, optional):
            Label identifying the current chunk (e.g. "D0T1F2").
        compute_residuals (bool, optional):
            If set, the final residuals will be computed and returned.

    Returns:
        2-element tuple

            - resid (np.ndarray)
                The final residuals (if compute_residuals is set), else None. When there are
                no valid solutions (either at the start or after flagging), obser_arr itself
                is returned in place of residuals.
            - stats (:obj:`~cubical.statistics.SolverStats`)
                An object containing solver statistics.
    """
    min_delta_g = sol_opts["delta-g"]
    chi_tol = sol_opts["delta-chi"]
    chi_interval = sol_opts["chi-int"]
    stall_quorum = sol_opts["stall-quorum"]

    # Initialise stat object.

    stats = SolverStats(obser_arr)
    stats.chunk.label = label

    n_stall = 0
    frac_stall = 0
    # Number of slots that already carry a flag other than PRIOR/MISSING on entry. Used at
    # the end to work out how many data flags were raised by the solver.
    n_original_flags = (flags_arr & ~(FL.PRIOR | FL.MISSING) != 0).sum()

    # initialize iteration counter

    num_iter = 0

    # Estimates the overall noise level and the inverse variance per channel and per antenna as
    # noise varies across the band. This is used to normalize chi^2.

    stats.chunk.init_noise, inv_var_antchan, inv_var_ant, inv_var_chan = \
                                                        stats.estimate_noise(obser_arr, flags_arr)

    # if we have directions in the model, but the gain machine is non-DD, collapse them
    # (this rebinds model_arr to a new array, so the caller's array is left alone from here on)
    if not gm.dd_term and model_arr.shape[0] > 1:
        model_arr = model_arr.sum(axis=0, keepdims=True)

    # This works out the conditioning of the solution, sets up various chi-sq normalization
    # factors etc, and does any other precomputation required by the current gain machine.

    gm.precompute_attributes(model_arr, flags_arr, inv_var_chan)

    def get_flagging_stats():
        """
        Returns a string describing per-flagset statistics: one "name:count(percentage)" entry
        for every flag category that has at least one slot raised in flags_arr.
        """
        fstats = []

        for flag, mask in FL.categories().iteritems():
            n_flag = ((flags_arr & mask) != 0).sum()
            if n_flag:
                fstats.append("{}:{}({:.2%})".format(
                    flag, n_flag, n_flag / float(flags_arr.size)))

        return " ".join(fstats)

    def update_stats(flags, statfields):
        """
        This function updates the solver stats object with a count of valid data points used for
        chi-sq calculations.

        For each name in statfields, the counts are written into the attribute called
        name + 'n' of stats.chanant, stats.timeant and stats.timechan.
        """
        unflagged = (flags == 0)
        # Compute number of terms in each chi-square sum. Shape is (n_tim, n_fre, n_ant).

        nterms = 2 * gm.n_cor * gm.n_cor * np.sum(unflagged, axis=3)

        # Update stats object accordingly.

        for field in statfields:
            getattr(stats.chanant, field + 'n')[...] = np.sum(nterms, axis=0)
            getattr(stats.timeant, field + 'n')[...] = np.sum(nterms, axis=1)
            getattr(stats.timechan, field + 'n')[...] = np.sum(nterms, axis=2)

    update_stats(flags_arr, ('initchi2', 'chi2'))

    # In the event that there are no solutions with valid data, this will log some of the
    # flag information and break out of the function.

    if not gm.has_valid_solutions:
        stats.chunk.num_sol_flagged, _ = gm.num_gain_flags()

        print >> log, ModColor.Str("{} no solutions: {}; flags {}".format(
            label, gm.conditioning_status_string, get_flagging_stats()))
        # No residuals exist at this point, so the observed data is returned in their place.
        return (obser_arr if compute_residuals else None), stats

    # Initialize a residual array.

    resid_shape = [
        gm.n_mod, gm.n_tim, gm.n_fre, gm.n_ant, gm.n_ant, gm.n_cor, gm.n_cor
    ]

    resid_arr = gm.cykernel.allocate_vis_array(resid_shape,
                                               obser_arr.dtype,
                                               zeros=True)
    gm.compute_residual(obser_arr, model_arr, resid_arr)
    # Zero the residuals at every flagged slot.
    resid_arr[:, flags_arr != 0] = 0

    # This flag is set to True when we have an up-to-date residual in resid_arr.

    have_residuals = True

    def compute_chisq(statfield=None):
        """
        Computes chi-squared statistic based on current residuals and noise estimates.
        Populates the stats object with it, if statfield is given.

        Reads resid_arr and inv_var_chan from the enclosing scope at call time, so a later
        re-estimate of the noise in the enclosing function is picked up by subsequent calls.

        Returns the (chisq_per_tf_slot, chisq_tot) pair produced by gm.compute_chisq().
        """
        chisq, chisq_per_tf_slot, chisq_tot = gm.compute_chisq(
            resid_arr, inv_var_chan)

        if statfield:
            getattr(stats.chanant, statfield)[...] = np.sum(chisq, axis=0)
            getattr(stats.timeant, statfield)[...] = np.sum(chisq, axis=1)
            getattr(stats.timechan, statfield)[...] = np.sum(chisq, axis=2)

        return chisq_per_tf_slot, chisq_tot

    chi, mean_chi = compute_chisq(statfield='initchi2')
    stats.chunk.init_chi2 = mean_chi

    # The following provides conditioning information when verbose is set to > 0.
    if log.verbosity() > 0:

        print >> log, "{} chi^2_0 {:.4}; {}; noise {:.3}, flags: {}".format(
            label, mean_chi, gm.conditioning_status_string,
            float(stats.chunk.init_noise), get_flagging_stats())

    # Main loop of the NNLS method. Terminates after quorum is reached in either converged or
    # stalled solutions or when the maximum number of iterations is exceeded.
    # NOTE(review): no iteration cap is checked in this function; presumably the gain machine
    # folds it into has_converged/has_stalled -- confirm.

    while not (gm.has_converged) and not (gm.has_stalled):

        num_iter = gm.next_iteration()

        # This is currently an awkward necessity - if we have a chain of jones terms, we need to
        # make sure that the active term is correct and need to support some sort of decision making
        # for testing convergence. I think doing the iter increment here might be the best choice,
        # with an additional bit of functionality for Jones chains. I suspect I will still need to
        # change the while loop component to be compatible with the idea of partial convergence.
        # Perhaps this should all be done right at the top of the function? A better idea is to let
        # individual machines be aware of their own stalled/converged status, and make those
        # properties more complicated on the chain. This should allow for fairly easy substitution
        # between the various machines.

        gm.compute_update(model_arr, obser_arr)

        # flag solutions. This returns True if any flags have been propagated out to the data.
        if gm.flag_solutions(flags_arr, False):

            update_stats(flags_arr, ('chi2', ))

            # Re-zero the model and data at newly flagged points.
            # TODO: is this needed?
            # TODO: should we perhaps just zero the model per flagged direction, and only flag the data?
            # OMS: probably not: flag propagation is now handled inside the gain machine. If a flag is
            # propagated out to the data, then that slot is gone gone gone and should be zeroed everywhere.

            # Despite the name, this mask selects every slot with any flag other than
            # MISSING/PRIOR, not only the ones raised in this iteration.
            new_flags = flags_arr & ~(FL.MISSING | FL.PRIOR) != 0
            model_arr[:, :, new_flags, :, :] = 0
            obser_arr[:, new_flags, :, :] = 0

            # Break out of the solver loop if we find ourselves with no valid solution intervals.

            if not gm.has_valid_solutions:
                break

        # Debugging leftover:
        # print>>log,"{} {} {}".format(de.gains[1,5,2,5], de.posterior_gain_error[1,5,2,5], de.posterior_gain_error[1].mean())
        #
        # The gains have just been updated, so resid_arr is now out of date.
        have_residuals = False

        # Compute values used in convergence tests. This check implicitly marks flagged gains as
        # converged.

        gm.check_convergence(min_delta_g)

        # Check residual behaviour after a number of iterations equal to chi_interval. This is
        # expensive, so we do it as infrequently as possible.

        if (num_iter % chi_interval) == 0:

            old_chi, old_mean_chi = chi, mean_chi

            gm.compute_residual(obser_arr, model_arr, resid_arr)
            resid_arr[:, flags_arr != 0] = 0

            chi, mean_chi = compute_chisq()

            have_residuals = True

            # Check for stalled solutions - solutions for which the residual is no longer improving.
            # A slot counts as stalled when its chi-squared has dropped by less than a fraction
            # chi_tol of its previous value. The solver is declared stalled once the stalled
            # fraction reaches stall_quorum.

            n_stall = float(np.sum(((old_chi - chi) < chi_tol * old_chi)))
            frac_stall = n_stall / chi.size

            gm.has_stalled = (frac_stall >= stall_quorum)

            if log.verbosity() > 1:

                # NOTE(review): there is no guard here against old_mean_chi being zero.
                delta_chi = (old_mean_chi - mean_chi) / old_mean_chi

                print >> log(2), (
                    "{} {} chi2 {:.4}, delta {:.4}, stall {:.2%}").format(
                        label, gm.current_convergence_status_string, mean_chi,
                        delta_chi, frac_stall)

    # num_valid_solutions will go to 0 if all solution intervals were flagged. If this is not the
    # case, generate residuals etc.

    if gm.has_valid_solutions:
        # Final round of flagging.
        # `flagged` is only bound inside this branch; the only place that reads it is guarded
        # by the same gm.has_valid_solutions property below.
        flagged = gm.flag_solutions(flags_arr, True)

    # check this again, because final round of flagging could have killed us
    if gm.has_valid_solutions:
        # Do we need to recompute the final residuals? Only if somebody wants them (last rites
        # or the caller), and resid_arr is stale or the final flagging round changed the flags.
        if (sol_opts['last-rites']
                or compute_residuals) and (not have_residuals or flagged):
            gm.compute_residual(obser_arr, model_arr, resid_arr)
            resid_arr[:, flags_arr != 0] = 0
            if sol_opts['last-rites']:
                # Recompute chi-squared based on original noise statistics.
                chi, mean_chi = compute_chisq(statfield='chi2')

        # Re-estimate the noise using the final residuals, if last rites are needed.
        # This rebinds inv_var_chan, so the compute_chisq() call below uses the re-estimated
        # noise and overwrites the 'chi2' stat fields. mean_chi (still based on the original
        # noise) is what gets stored in stats.chunk.chi2; mean_chi1 is only used in the log
        # message.

        if sol_opts['last-rites']:
            stats.chunk.noise, inv_var_antchan, inv_var_ant, inv_var_chan = \
                                        stats.estimate_noise(resid_arr, flags_arr, residuals=True)
            chi1, mean_chi1 = compute_chisq(statfield='chi2')

        stats.chunk.chi2 = mean_chi

        message = "{} {}, stall {:.2%}, chi^2 {:.4} -> {:.4}".format(
            label, gm.final_convergence_status_string, frac_stall,
            float(stats.chunk.init_chi2), mean_chi)

        if sol_opts['last-rites']:

            message = "{} ({:.4}), noise {:.3} -> {:.3}".format(
                message, float(mean_chi1), float(stats.chunk.init_noise),
                float(stats.chunk.noise))

        print >> log, message

    # If everything has been flagged, no valid solutions are generated.

    else:

        print >> log(0, "red"), "{} {}: completely flagged".format(
            label, gm.final_convergence_status_string)

        stats.chunk.chi2 = 0
        # No meaningful residuals: hand back the observed data in their place.
        resid_arr = obser_arr

    stats.chunk.iters = num_iter
    stats.chunk.num_converged = gm.num_converged_solutions
    # n_stall is the count from the most recent chi-squared check (0 if none was done).
    stats.chunk.num_stalled = n_stall

    # copy out flags, if we raised any
    stats.chunk.num_sol_flagged, _ = gm.num_gain_flags()
    if stats.chunk.num_sol_flagged:
        # also form up message with flagging stats (the MISSING category is skipped)
        fstats = ""
        for flagname, mask in FL.categories().iteritems():
            if mask != FL.MISSING:
                n_flag, n_tot = gm.num_gain_flags(mask)
                if n_flag:
                    fstats += "{}:{}({:.2%}) ".format(flagname, n_flag,
                                                      n_flag / float(n_tot))
        # Data flags raised during the solve: current count of non-PRIOR/MISSING flags minus
        # the count taken on entry.
        n_new_flags = (flags_arr & ~(FL.PRIOR | FL.MISSING) !=
                       0).sum() - n_original_flags
        print >> log, ModColor.Str(
            "{} solver flags raised: {}-> {:.2%} data flags".format(
                label, fstats, n_new_flags / float(flags_arr.size)))

    return (resid_arr if compute_residuals else None), stats