Code example #1
def make_filename(self, filename, jones_label=None):
    """
    Helper method: expands a full filename from a templated filename. This uses the standard
    str.format() function, passing in self.global_options, as well as JONES=jones_label, as keyword
    arguments. This allows for filename templates that include strings from the global options
    dictionary, e.g. "{data[ms]}-ddid{sel[ddid]}".

    Args:
        filename (str):
            the templated filename
        jones_label (str, optional):
            Jones matrix label, overrides self.jones_label if specified.

    Returns:
        str:
            Expanded filename
    """
    if not filename:
        return None
    try:
        # substitute recursively (a template may expand into further templates), but only up to a limit
        for i in range(10):
            fname = filename.format(JONES=jones_label or self.jones_label, **self.global_options)
            if fname == filename:
                break
            filename = fname
        return filename
    except Exception as exc:
        print("{}({})\n {}".format(type(exc).__name__, exc, traceback.format_exc()), file=log)
        print(ModColor.Str("Error parsing filename '{}', see above".format(filename)), file=log)
        raise ValueError(filename)
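
The recursive expansion above is easy to demonstrate in isolation. A minimal, self-contained sketch of the same idea (the options dict and template are invented for illustration, not CubiCal defaults):

def expand_template(template, options, max_depth=10):
    """Repeatedly apply str.format() until the string stops changing."""
    for _ in range(max_depth):
        expanded = template.format(**options)
        if expanded == template:  # fixed point: nothing left to expand
            return expanded
        template = expanded
    return template

options = {"data": {"ms": "obs.ms"}, "sel": {"ddid": 0}}
print(expand_template("{data[ms]}-ddid{sel[ddid]}", options))
# -> obs.ms-ddid0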
Code example #2
    def format_chunk_stats(self, format_string, ncol=8, threshold=None):
        """
        :param format_string: format string applied to each stats record
        :param ncol: maximum number of columns to allocate
        :param threshold: optional list of (field, value) pairs; a record exceeding any threshold is highlighted
        :return: list of formatted rows, one string per row
        """
        nt, nf = self.chunk.shape
        nt_per_col = 1
        nf_per_col = None
        if nf < ncol:
            nt_per_col = ncol // nf
        else:
            nf_per_col = ncol
        # convert stats to list of columns
        output_rows = [[("", False)]]
        for itime in range(nt):
            # start new line every NT_PER_COL-th time chunk
            if itime % nt_per_col == 0:
                output_rows.append([])
            for ifreq in range(nf):
                # start new line every NF_PER_COL-th freq chunk, if frequencies span lines
                if nf_per_col is not None and output_rows[-1] and ifreq % nf_per_col == 0:
                    output_rows.append([])
                statrec = self.chunk[itime, ifreq]
                statrec_dict = {
                    field: statrec[field]
                    for field in self.chunk.dtype.fields
                }
                # new line: prepend chunk label
                if not output_rows[-1]:
                    output_rows[-1].append((statrec.label, False))
                # put it in header as well
                if len(output_rows) == 2:
                    output_rows[0].append((statrec.label, False))
                # check for threshold
                warn = False
                if threshold is not None:
                    for field, value in threshold:
                        if statrec[field] > value:
                            warn = True
                text = format_string.format(**statrec_dict)
                output_rows[-1].append((text, warn))

        # now work out column widths and format
        ncol = max([len(row) for row in output_rows])
        colwidths = [
            max([len(row[icol][0]) for row in output_rows if icol < len(row)])
            for icol in range(ncol)
        ]
        colformat = ["{{:{}}}  ".format(w) for w in colwidths]

        output_rows = [[(colformat[icol].format(col), warn)
                        for icol, (col, warn) in enumerate(row)]
                       for row in output_rows]

        return [
            "".join([(ModColor.Str(col, 'red') if warn else col)
                     for col, warn in row]) for row in output_rows
        ]
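
The column-layout step at the end of format_chunk_stats is worth seeing in isolation: given rows of (text, warn) cells, it computes per-column widths, pads each cell, and highlights flagged cells. A minimal sketch of just that step, with highlight() standing in for ModColor.Str:

def format_rows(rows):
    ncol = max(len(row) for row in rows)
    widths = [max(len(row[i][0]) for row in rows if i < len(row))
              for i in range(ncol)]
    fmts = ["{{:{}}}  ".format(w) for w in widths]
    highlight = lambda s: "*{}*".format(s)  # placeholder for ModColor.Str(s, 'red')
    return ["".join(highlight(fmts[i].format(text)) if warn else fmts[i].format(text)
                    for i, (text, warn) in enumerate(row))
            for row in rows]

rows = [[("", False), ("D0F0", False), ("D0F1", False)],
        [("T0", False), ("1.02", False), ("9.87", True)]]
print("\n".join(format_rows(rows)))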
Code example #3
File: ifr_gain_machine.py Project: ratt-ru/CubiCal
 def __init__(self, gmfactory, ifrgain_opts, compute=True):
     """
     Initializes the IFR-based gains machinery.

     Args:
         gmfactory:      a GainMachine factory, used to manage the solution databases
         ifrgain_opts:   dict of options
         compute:        if False, gains are not computed even if the options ask for them
     """
     from cubical.main import expand_templated_name
     self.gmfactory = gmfactory
     load_from = expand_templated_name(ifrgain_opts['load-from'])
     save_to = expand_templated_name(ifrgain_opts['save-to'])
     self._ifrgains_per_chan = ifrgain_opts['per-chan']
     self._ifrgain = None
     self._nfreq = gmfactory.grid["freq"]
     nfreq, nant, ncorr = [
         len(gmfactory.grid[axis]) for axis in ("freq", "ant", "corr")
     ]
     if load_from:
         filename = load_from
         print(ModColor.Str(
             "applying baseline-based corrections (BBCs) from {}".format(
                 filename),
             col="green"),
               file=log(0))
         if "//" in filename:
             filename, prefix = filename.rsplit("//", 1)
         else:
             filename, prefix = filename, "BBC"
         parm = param_db.load(filename).get(prefix)
         if parm is None:
             print(ModColor.Str("  no solutions for '{}' in {}".format(
                 prefix, filename)),
                   file=log(0))
         else:
             self._ifrgain = parm.reinterpolate(
                 freq=gmfactory.grid["freq"]).filled()
             if tuple(self._ifrgain.shape) != (nfreq, nant, nant, ncorr,
                                               ncorr):
                 print(ModColor.Str(
                     "  invalid BBC shape {}, will ignore".format(
                         self._ifrgain.shape)),
                       file=log(0))
                 self._ifrgain = None
             else:
                 print("  loaded per-channel BBCs of shape {}".format(
                     filename, self._ifrgain.shape),
                       file=log(0))
                 if not self._ifrgains_per_chan:
                     print("  using one mean value across band",
                           file=log(0))
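                      # broadcast the band-averaged BBC into every channel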
                     self._ifrgain[np.newaxis,
                                   ...] = self._ifrgain.mean(axis=0)
                 # reset off-diagonal values, if needed
                 if ifrgain_opts["apply-2x2"]:
                     print(ModColor.Str(
                         "  using full 2x2 BBCs. You'd better know what you're doing!",
                         col="green"),
                           file=log(0))
                 else:
                     self._ifrgain[..., (0, 1), (1, 0)] = 1
                     print("  using parallel-hand BBCs only", file=log(0))
     if save_to and compute:
         self._compute_2x2 = ifrgain_opts["compute-2x2"]
         # setup axes for IFR-based gains
         axes = ["freq", "ant1", "ant2", "corr1", "corr2"]
         # define the ifrgain parameter
         self._save_filename = save_to
         parm = gmfactory.define_param(self._save_filename,
                                       "BBC",
                                       1 + 0j,
                                       axes,
                                       interpolation_axes=["freq"])
         self._ifrgains_grid = {
             axis: parm.grid[i]
             for i, axis in enumerate(axes)
         }
         # initialize accumulators for M.D^H and D.D^H terms
         self._mdh_sum = np.ma.zeros(parm.shape,
                                     gmfactory.ctype,
                                     fill_value=0)
         self._ddh_sum = np.ma.zeros(parm.shape,
                                     gmfactory.ctype,
                                     fill_value=0)
         #
         print(
             "will compute & save suggested baseline-based corrections (BBCs) to {}"
             .format(self._save_filename),
             file=log(0))
         print(
             "  (these can optionally be applied in a subsequent CubiCal run)",
             file=log(0))
     else:
         self._ifrgains_grid = None
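
The load-from option above accepts an optional "//prefix" suffix that selects a named parameter inside the solutions database, defaulting to "BBC". A sketch of just that convention (the filenames are illustrative):

def split_db_spec(spec, default_prefix="BBC"):
    """Split 'filename//prefix' into its parts, falling back to the default prefix."""
    if "//" in spec:
        filename, prefix = spec.rsplit("//", 1)
    else:
        filename, prefix = spec, default_prefix
    return filename, prefix

print(split_db_spec("solutions.parmdb//BBC0"))  # ('solutions.parmdb', 'BBC0')
print(split_db_spec("solutions.parmdb"))        # ('solutions.parmdb', 'BBC')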
Code example #4
File: main.py Project: etrangerlx/CubiCal
def main(debugging=False):
    """
    Main cubical driver function. Reads options, sets up MS and solvers, calls the solver, etc.

    Args:
        debugging (bool, optional):
            If True, run in debugging mode.

    Raises:
        UserInputError:
            If --model-list was not specified when a model is required.
        UserInputError:
            If no Jones terms are enabled.
        UserInputError:
            If --out-mode is invalid.
        ValueError:
            If unknown Jones type is specified.
        RuntimeError:
            If I/O job on a tile failed.
    """

    # this will be set below if a custom parset is specified on the command line
    custom_parset_file = None
    # "GD" is a global defaults dict, containing options set up from parset + command line
    global GD, enable_pdb

    # keep a list of messages here, until we have a logfile open
    prelog_messages = []

    def prelog_print(level, message):
        prelog_messages.append((level, message))

    try:
        if debugging:
            print("initializing from cubical.last", file=log)
            GD = pickle.load(open("cubical.last", "rb"))
            basename = GD["out"]["name"]
            parser = None
        else:
            default_parset = parsets.Parset("%s/DefaultParset.cfg" % os.path.dirname(__file__))

            # if first argument is a filename, treat it as a parset

            if len(sys.argv) > 1 and not sys.argv[1].startswith('-'):
                custom_parset_file = sys.argv[1]
                print("reading defaults from {}".format(custom_parset_file), file=log)
                try:
                    parset = parsets.Parset(custom_parset_file)
                except:
                    import traceback
                    traceback.print_exc()
                    raise UserInputError("'{}' must be a valid parset file. Use -h for help.".format(custom_parset_file))
                if not parset.success:
                    raise UserInputError("'{}' must be a valid parset file. Use -h for help.".format(custom_parset_file))
                # update default parameters with values from parset
                default_parset.update_values(parset, other_filename=' in {}'.format(custom_parset_file))

            import cubical
            parser = dynoptparse.DynamicOptionParser(usage='Usage: %prog [parset file] <options>',
                    description="""Questions, bug reports, suggestions: https://github.com/ratt-ru/CubiCal""",
                    version='%prog version {}'.format(cubical.VERSION),
                    defaults=default_parset.value_dict,
                    attributes=default_parset.attr_dict)

            # now read the full input from command line
            # "GD" is a global defaults dict, containing options set up from parset + command line
            GD = parser.read_input()

            # if a single argument is given, it should have been the parset
            if len(parser.get_arguments()) != (1 if custom_parset_file else 0):
                raise UserInputError("Unexpected number of arguments. Use -h for help.")

            # get dirname and basename for all output files
            outdir = expand_templated_name(GD["out"]["dir"]).strip()
            basename = expand_templated_name(GD["out"]["name"]).strip()
            can_overwrite = GD["out"]["overwrite"]
            can_backup = GD["out"]["backup"]

            explicit_basename_path = "/" in basename
            folder_is_ccout  = False

            if explicit_basename_path:
                prelog_print(0, "output basename explicitly set to {}, --out-dir setting ignored".format(basename))
                outdir = os.path.dirname(basename)
            elif outdir == "." or not outdir:
                outdir = None
                prelog_print(0, "using output basename {} in current directory".format(basename))
            else:
                # append implicit .cc-out suffix, unless the name ends with a slash or already ends with .cc-out
                if not outdir.endswith("/"):
                    if outdir.endswith(".cc-out"):
                        outdir += "/"
                    else:
                        outdir += ".cc-out/"
                folder_is_ccout = outdir.endswith(".cc-out/")
                basename = outdir + basename
                if outdir != "/":
                    outdir = outdir.rstrip("/")
                prelog_print(0, "using output basename {}".format(basename))

            # create directory for output files, if specified, and it doesn't exist
            if outdir and not os.path.exists(outdir):
                prelog_print(0, "creating new output directory {}".format(outdir))
                os.mkdir(outdir)

            # are we going to be overwriting a previous run?
            out_parset = "{}.parset".format(basename)
            if os.path.exists(out_parset):
                prelog_print(0, "{} already exists, possibly from a previous run".format(out_parset))

                if can_backup:
                    if folder_is_ccout:
                        # find non-existing directory name for backup
                        backup_dir = outdir + ".0"
                        N = 0
                        while os.path.exists(backup_dir):
                            N += 1
                            backup_dir = "{}.{}".format(outdir, N)
                        # rename old directory, if we ended up manipulating the directory name
                        os.rename(outdir, backup_dir)
                        os.mkdir(outdir)
                        prelog_print(0, ModColor.Str("backed up existing {} to {}".format(outdir, backup_dir), "blue"))
                    else:
                        prelog_print(0, "refusing to auto-backup output directory, since it is not a .cc-out dir")

                if os.path.exists(out_parset):
                    if can_overwrite:
                        prelog_print(0, "proceeding anyway since --out-overwrite is set")
                    else:
                        if folder_is_ccout:
                            prelog_print(0, "won't proceed without --out-overwrite and/or --out-backup")
                        else:
                            prelog_print(0, "won't proceed without --out-overwrite")
                        raise UserInputError("{} already exists: won't overwrite previous run".format(out_parset))

            GD["out"]["name"] = basename

            # "GD" is a global defaults dict, containing options set up from parset + command line
            pickle.dump(GD, open("cubical.last", "wb"))

            # save parset with all settings
            parser.write_to_parset(out_parset)

        enable_pdb = GD["debug"]["pdb"]

        # now setup logging
        logger.logToFile(basename + ".log", append=GD["log"]["append"])
        logger.enableMemoryLogging(GD["log"]["memory"])
        logger.setBoring(GD["log"]["boring"])
        logger.setGlobalVerbosity(GD["log"]["verbose"])
        logger.setGlobalLogVerbosity(GD["log"]["file-verbose"])

        if not debugging:
            print("started " + " ".join(sys.argv), file=log)

        # dump accumulated messages from before log was open
        for level, message in prelog_messages:
            print(message, file=log(level))
        prelog_messages = []

        # clean up shared memory from any previous runs
        shm_utils.cleanupStaleShm()

        # disable matplotlib's tk backend if we're not going to be showing plots
        if GD['out']['plots'] == 'show' or GD['madmax']['plot'] == 'show':
            import pylab
            try:
                pylab.figure()
                pylab.close()
            except Exception as exc:
                import traceback
                print(ModColor.Str("Error initializing matplotlib: {}({})\n {}".format(type(exc).__name__,
                                                                                       exc, traceback.format_exc())), file=log)
                raise UserInputError("matplotlib can't connect to X11. Can't use --out-plots show or --madmax-plot show.")
        else:
            matplotlib.use("Agg")

        # print current options
        if parser is not None:
            parser.print_config(dest=log)

        double_precision = GD["sol"]["precision"] == 64

        # set up RIME

        solver_opts = GD["sol"]
        debug_opts  = GD["debug"]
        out_opts = GD["out"]
        sol_jones = solver_opts["jones"]
        if isinstance(sol_jones, string_types):
            sol_jones = set(sol_jones.split(','))
        jones_opts = [GD[j.lower()] for j in sol_jones]
        # collect list of options from enabled Jones matrices
        if not len(jones_opts):
            raise UserInputError("No Jones terms are enabled")
        print(ModColor.Str("Enabling {}-Jones".format(",".join(sol_jones)), col="green"), file=log)

        have_dd_jones = any([jo['dd-term'] for jo in jones_opts])

        solver.GD = GD

        # set up data handler

        solver_type = GD['out']['mode']
        if solver_type not in solver.SOLVERS:
            raise UserInputError("invalid setting --out-mode {}".format(solver_type))
        solver_mode_name = solver.SOLVERS[solver_type].__name__.replace("_", " ")
        print(ModColor.Str("mode: {}".format(solver_mode_name), col='green'), file=log)
        # these flags are used below to tweak the behaviour of gain machines and model loaders
        apply_only = solver.SOLVERS[solver_type].is_apply_only
        print("solver is apply-only type: {}".format(apply_only), file=log(0))
        load_model = solver.SOLVERS[solver_type].is_model_required
        print("solver requires model: {}".format(load_model), file=log(0))

        if load_model and not GD["model"]["list"]:
            raise UserInputError("--model-list must be specified")

        ms = MSDataHandler(GD["data"]["ms"],
                           GD["data"]["column"],
                           output_column=GD["out"]["column"],
                           output_model_column=GD["out"]["model-column"],
                           output_weight_column=GD["out"]["weight-column"],
                           reinit_output_column=GD["out"]["reinit-column"],
                           taql=GD["sel"]["taql"],
                           fid=GD["sel"]["field"],
                           ddid=GD["sel"]["ddid"],
                           channels=GD["sel"]["chan"],
                           diag=GD["sel"]["diag"],
                           beam_pattern=GD["model"]["beam-pattern"],
                           beam_l_axis=GD["model"]["beam-l-axis"],
                           beam_m_axis=GD["model"]["beam-m-axis"],
                           active_subset=GD["sol"]["subset"],
                           min_baseline=GD["sol"]["min-bl"],
                           max_baseline=GD["sol"]["max-bl"],
                           chunk_freq=GD["data"]["freq-chunk"],
                           rebin_freq=GD["data"]["rebin-freq"],
                           do_load_CASA_kwtables = GD["out"]["casa-gaintables"],
                           feed_rotate_model=GD["model"]["feed-rotate"],
                           pa_rotate_model=GD["model"]["pa-rotate"],
                           pa_rotate_montblanc=GD["montblanc"]["pa-rotate"],
                           derotate_output=GD["out"]["derotate"],
                           )

        solver.metadata = ms.metadata
        # if using dual-corr mode, propagate this into Jones options
        if ms.ncorr == 2:
            for jo in jones_opts:
                jo['diag-only'] = True
                jo['diag-data'] = True
            solver_opts['diag-only'] = True
            solver_opts['diag-data'] = True

        # With a single Jones term, create a gain machine factory based on its type.
        # With multiple Jones, create a ChainMachine factory
        term_iters = solver_opts["term-iters"]
        if type(term_iters) is int:
            term_iters = [term_iters] * len(jones_opts)
            solver_opts["term-iters"] = term_iters
            len(jones_opts) > 1 and log.warn("Multiple gain terms specified, but no sol-term-iters recipe given. "
                                             "This may indicate user error. Assuming the same number of iterations per term, "
                                             "stopping on the last term in the chain.")
        elif type(term_iters) is list and len(term_iters) == 1:
            term_iters = term_iters * len(jones_opts)
            solver_opts["term-iters"] = term_iters
            len(jones_opts) > 1 and log.warn("Multiple gain terms specified, but no sol-term-iters recipe given. "
                                             "This may indicate user error. Assuming the same number of iterations per term, "
                                             "stopping on the last term in the chain.")
        elif type(term_iters) is list and len(term_iters) < len(jones_opts):
            raise ValueError("sol-term-iters is a list, but does not match or exceed the number of gain terms being solved. "
                             "Please either only set a single value to be used or provide a list to construct a iteration recipe")
        elif type(term_iters) is list and len(term_iters) >= len(jones_opts):
            pass # user is executing a recipe
        else:
            raise TypeError("sol-term-iters is neither a list, nor a int. Check your parset")

        if len(jones_opts) == 1:
            jones_opts = jones_opts[0]
            # for just one term, propagate --sol-term-iters, if set, into its max-iter setting
            term_iters = solver_opts["term-iters"]
            if term_iters:
                jones_opts["max-iter"] = term_iters[0] if hasattr(term_iters,'__getitem__') else term_iters
            # create a gain machine factory
            jones_class = machine_types.get_machine_class(jones_opts['type'])
            if jones_class is None:
                raise UserInputError("unknown Jones type '{}'".format(jones_opts['type']))
        elif jones_opts[0]['type'] == "robust-2x2":
            jones_class = jones_chain_robust_machine.JonesChain
        else:
            jones_class = jones_chain_machine.JonesChain

        # init models
        dde_mode = GD["model"]["ddes"]

        if dde_mode == 'always' and not have_dd_jones:
            raise UserInputError("we have '--model-ddes always', but no direction dependent Jones terms enabled")

        # force floats in Montblanc calculations
        mb_opts = GD["montblanc"]
        # mb_opts['dtype'] = 'float'

        ms.init_models(str(GD["model"]["list"]).split(","),
                       GD["weight"]["column"].split(",") if GD["weight"]["column"] else None,
                       fill_offdiag_weights=GD["weight"]["fill-offdiag"],
                       mb_opts=GD["montblanc"],
                       use_ddes=have_dd_jones and dde_mode != 'never',
                       degrid_opts=GD["degridding"])

        if len(ms.model_directions) < 2 and have_dd_jones and dde_mode == 'auto':
            raise UserInputError("--model-list does not specify directions. "
                    "Have you forgotten a @dE tag perhaps? Rerun with '--model-ddes never' to proceed anyway.")

        if load_model:
            # set up subtraction options
            solver_opts["subtract-model"] = smod = GD["out"]["subtract-model"]
            if smod < 0 or smod >= len(ms.models):
                raise UserInputError("--out-subtract-model {} out of range for {} model(s)".format(smod, len(ms.models)))

            # parse subtraction directions as a slice or list
            subdirs = GD["out"]["subtract-dirs"]
            if type(subdirs) is int:
                subdirs = [subdirs]
            if subdirs:
                if isinstance(subdirs, string_types):
                    try:
                        if ',' in subdirs:
                            subdirs = list(map(int, subdirs.split(",")))
                        else:
                            subdirs = eval("np.s_[{}]".format(subdirs))
                    except:
                        raise UserInputError("invalid --out-subtract-model option '{}'".format(subdirs))
                elif type(subdirs) is not list:
                    raise UserInputError("invalid --out-subtract-dirs option '{}'".format(subdirs))
                # check ranges
                if type(subdirs) is list:
                    out_of_range = [ d for d in subdirs if d < 0 or d >= len(ms.model_directions) ]
                    if out_of_range:
                        raise UserInputError("--out-subtract-dirs {} out of range for {} model direction(s)".format(
                                ",".join(map(str, out_of_range)), len(ms.model_directions)))
                print("subtraction directions set to {}".format(subdirs), file=log(0))
            else:
                subdirs = slice(None)
            solver_opts["subtract-dirs"] = subdirs

        # create gain machine factory
        # TODO: pass in proper antenna and correlation names, rather than number

        grid = dict(ant=ms.antnames, corr=ms.feeds, time=ms.uniq_times, freq=ms.all_freqs)
        solver.gm_factory = jones_class.create_factory(grid=grid,
                                                       apply_only=apply_only,
                                                       double_precision=double_precision,
                                                       global_options=GD, jones_options=jones_opts)
                                                       
        # create IFR-based gain machine. Only compute gains if we're loading a model
        # (i.e. not in load-apply mode)
        solver.ifrgain_machine = ifr_gain_machine.IfrGainMachine(solver.gm_factory, GD["bbc"], compute=load_model)

        solver.legacy_version12_weights = GD["weight"]["legacy-v1-2"]

        single_chunk = GD["data"]["single-chunk"]
        single_tile = GD["data"]["single-tile"]

        # setup worker process properties

        workers.setup_parallelism(GD["dist"]["ncpu"], GD["dist"]["nworker"], GD["dist"]["nthread"],
                                  debugging or single_chunk,
                                  GD["dist"]["pin"], GD["dist"]["pin-io"], GD["dist"]["pin-main"],
                                  ms.use_montblanc, GD["montblanc"]["threads"])

        # set up chunking

        chunk_by = GD["data"]["chunk-by"]
        if isinstance(chunk_by, string_types):
            chunk_by = chunk_by.split(",")
        jump = float(GD["data"]["chunk-by-jump"])

        chunks_per_tile = max(GD["dist"]["min-chunks"], workers.num_workers, 1)
        if GD["dist"]["max-chunks"]:
            chunks_per_tile = max(GD["dist"]["max-chunks"], chunks_per_tile)

        print("defining chunks (time {}, freq {}{})".format(GD["data"]["time-chunk"], GD["data"]["freq-chunk"],
            ", also when {} jumps > {}".format(", ".join(chunk_by), jump) if chunk_by else ""), file=log)

        chunks_per_tile, tile_list = ms.define_chunk(GD["data"]["time-chunk"], GD["data"]["rebin-time"],
                                            GD["data"]["freq-chunk"],
                                            chunk_by=chunk_by, chunk_by_jump=jump,
                                            chunks_per_tile=chunks_per_tile, max_chunks_per_tile=GD["dist"]["max-chunks"])

        # now that we have tiles, define the flagging situation (since this may involve a one-off iteration through the
        # MS to populate the column)
        ms.define_flags(tile_list, flagopts=GD["flags"])

        # single-chunk implies single-tile
        if single_tile >= 0:
            tile_list = tile_list[single_tile:single_tile+1]
            print("--data-single-tile {} set, will process only the one tile".format(single_tile), file=log(0, "blue"))
        elif single_chunk:
            match = re.match("D([0-9]+)T([0-9]+)", single_chunk)
            if not match:
                raise ValueError("invalid setting: --data-single-chunk {}".format(single_chunk))
            ddid_tchunk = int(match.group(1)), int(match.group(2))
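            # e.g. --data-single-chunk D0T1 selects DDID 0, time chunk 1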

            tilemap = { (rc.ddid, rc.tchunk): (tile, rc) for tile in tile_list for rc in tile.rowchunks }
            single_tile_rc = tilemap.get(ddid_tchunk)
            if single_tile_rc:
                tile, rc = single_tile_rc
                tile_list = [tile]
                print("--data-single-chunk {} in {}, rows {}:{}".format(
                    single_chunk, tile.label, min(rc.rows0), max(rc.rows0)+1), file=log(0, "blue"))
            else:
                raise ValueError("--data-single-chunk {}: chunk with this ID not found".format(single_chunk))

        # run the main loop

        t0 = time()

        stats_dict = workers.run_process_loop(ms, tile_list, load_model, single_chunk, solver_type, solver_opts, debug_opts, out_opts)


        print(ModColor.Str("Time taken for {}: {} seconds".format(solver_mode_name, time() - t0), col="green"), file=log)

        # print flagging stats
        print(ModColor.Str("Flagging stats: ",col="green") + " ".join(ms.get_flag_counts()), file=log)

        if not apply_only:
            # now summarize the stats
            print("computing summary statistics", file=log)
            st = SolverStats(stats_dict)
            filename = basename + ".stats.pickle"
            st.save(filename)
            print("saved summary statistics to %s" % filename, file=log)
            print_stats = GD["log"]["stats"]
            if print_stats:
                print("printing some summary statistics below", file=log(0))
                thresholds = []
                for thr in GD["log"]["stats-warn"].split(","):
                    field, value = thr.split(":")
                    thresholds.append((field, float(value)))
                    print("  highlighting {}>{}".format(field, float(value)), file=log(0))
                if print_stats == "all":
                    print_stats = st.get_notrivial_chunk_statfields()
                else:
                    print_stats = print_stats.split("//")
                for stats in print_stats:
                    if stats[0] != "{":
                        stats = "{{{}}}".format(stats)
                    lines = st.format_chunk_stats(stats, threshold=thresholds)
                    print("  summary stats for {}:\n  {}".format(stats, "\n  ".join(lines)), file=log(0))

            if GD["postmortem"]["enable"]:
                # flag based on summary stats
                flag3 = flagging.flag_chisq(st, GD, basename, ms.nddid_actual)

                if flag3 is not None:
                    st.apply_flagcube(flag3)
                    if GD["flags"]["save"] and flag3.any() and not GD["data"]["single-chunk"]:
                        print("regenerating output flags based on post-solution flagging", file=log)
                        flagcol = ms.flag3_to_col(flag3)
                        ms.save_flags(flagcol)

            # make plots
            if GD["out"]["plots"]:
                import cubical.plots
                try:
                    cubical.plots.make_summary_plots(st, ms, GD, basename)
                except Exception as exc:
                    if GD["debug"]["escalate-warnings"]:
                        raise
                    import traceback
                    print(ModColor.Str("An error has occurred while making summary plots: {}({})\n {}".format(
                        type(exc).__name__, exc, traceback.format_exc())), file=log)
                    print(ModColor.Str("This is not fatal, but should be reported (and your plots have gone missing!)"), file=log)

        # make BBC plots
        if solver.ifrgain_machine and solver.ifrgain_machine.is_computing() and GD["bbc"]["plot"] and GD["out"]["plots"]:
            import cubical.plots.ifrgains
            if GD["debug"]["escalate-warnings"]:
                with warnings.catch_warnings():
                    warnings.simplefilter("error", np.ComplexWarning)
                    cubical.plots.ifrgains.make_ifrgain_plots(solver.ifrgain_machine.reload(), ms, GD, basename)
            else:
                try:
                    cubical.plots.ifrgains.make_ifrgain_plots(solver.ifrgain_machine.reload(), ms, GD, basename)
                except Exception as exc:
                    import traceback
                    print(ModColor.Str("An error has occurred while making BBC plots: {}({})\n {}".format(
                        type(exc).__name__, exc, traceback.format_exc())), file=log)
                    print(ModColor.Str("This is not fatal, but should be reported (and your plots have gone missing!)"), file=log)

        ms.close()

        print(ModColor.Str("completed successfully", col="green"), file=log)

    except Exception as exc:
        for level, message in prelog_messages:
            print(message, file=log(level))

        if type(exc) is UserInputError:
            print(ModColor.Str(exc), file=log)
        else:
            import traceback
            print(ModColor.Str("Exiting with exception: {}({})\n {}".format(type(exc).__name__,
                                                                    exc, traceback.format_exc())), file=log)
            if enable_pdb and not type(exc) is UserInputError:
                from cubical.tools import pdb
                exc, value, tb = sys.exc_info()
                pdb.post_mortem(tb)
        sys.exit(2 if type(exc) is UserInputError else 1)
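
One detail of main() worth isolating is the auto-backup naming above: it appends a ".N" counter to the output directory until it finds a name that does not exist yet. A standalone sketch (directory names are hypothetical):

import os

def free_backup_name(outdir):
    """Find the first non-existing 'outdir.N' name to back up into."""
    backup_dir = outdir + ".0"
    n = 0
    while os.path.exists(backup_dir):
        n += 1
        backup_dir = "{}.{}".format(outdir, n)
    return backup_dir

# e.g. if run.cc-out.0 and run.cc-out.1 already exist, this returns "run.cc-out.2"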
Code example #5
def _solve_gains(gm,
                 stats,
                 madmax,
                 obser_arr,
                 model_arr,
                 flags_arr,
                 sol_opts,
                 label="",
                 compute_residuals=None):
    """
    Main body of the GN/LM method. Handles iterations and convergence tests.

    Args:
        gm (:obj:`~cubical.machines.abstract_machine.MasterMachine`):
            The gain machine which will be used in the solver loop.
        stats (:obj:`~cubical.statistics.SolverStats`):
            The solver statistics object for the current chunk, updated in place.
        madmax:
            The MadMax flagging machine used for residual-based flagging.
        obser_arr (np.ndarray):
            Shape (n_mod, n_tim, n_fre, n_ant, n_ant, n_cor, n_cor) array containing observed
            visibilities.
        model_arr (np.ndarray):
            Shape (n_dir, n_mod, n_tim, n_fre, n_ant, n_ant, n_cor, n_cor) array containing model
            visibilities.
        flags_arr (np.ndarray):
            Shape (n_tim, n_fre, n_ant, n_ant) integer array containing flag data.
        sol_opts (dict):
            Solver options (see [sol] section in DefaultParset.cfg).
        label (str, optional):
            Label identifying the current chunk (e.g. "D0T1F2").
        compute_residuals (bool, optional):
            If set, the final residuals will be computed and returned.

    Returns:
        3-element tuple

            - resid (np.ndarray)
                The final residuals (if compute_residuals is set), else None.
            - stats (:obj:`~cubical.statistics.SolverStats`)
                An object containing solver statistics.
            - robust_weights (np.ndarray or None)
                Output visibility weights from the robust solver, if it produced any, else None.
    """
    chi_interval = sol_opts["chi-int"]
    stall_quorum = sol_opts["stall-quorum"]

    diverging = ""

    # for all solvers that do not output weights, and for the robust solver when there are no valid solutions
    gm.output_weights = None

    # Estimates the overall noise level and the inverse variance per channel and per antenna as
    # noise varies across the band. This is used to normalize chi^2.

    stats.chunk.noise, inv_var_antchan, inv_var_ant, inv_var_chan = \
                                        stats.estimate_noise(obser_arr, flags_arr)

    # if we have directions in the model, but the gain machine is non-DD, collapse them
    if not gm.dd_term and model_arr.shape[0] > 1:
        model_arr = model_arr.sum(axis=0, keepdims=True)

    # This works out the conditioning of the solution, sets up various chi-sq normalization
    # factors etc, and does any other precomputation required by the current gain machine.

    gm.precompute_attributes(obser_arr, model_arr, flags_arr, inv_var_chan)

    # apply any flags raised in the precompute

    new_flags = (flags_arr & ~(FL.MISSING | FL.PRIOR)) != 0
    model_arr[:, :, new_flags, :, :] = 0
    obser_arr[:, new_flags, :, :] = 0

    def get_flagging_stats():
        """Returns a string describing per-flagset statistics"""
        fstats = []

        for flag, mask in FL.categories().items():
            n_flag = ((flags_arr & mask) != 0).sum()
            if n_flag:
                fstats.append("{}:{}({:.2%})".format(
                    flag, n_flag, n_flag / float(flags_arr.size)))

        return " ".join(fstats)

    def update_stats(flags, statfields):
        """
        This function updates the solver stats object with a count of valid data points used for chi-sq 
        calculations
        """
        unflagged = (flags == 0)
        # Compute number of terms in each chi-square sum. Shape is (n_tim, n_fre, n_ant).

        nterms = 2 * gm.n_cor * gm.n_cor * np.sum(unflagged, axis=3)

        # Update stats object accordingly.

        for field in statfields:
            getattr(stats.chanant, field + 'n')[...] = np.sum(nterms, axis=0)
            getattr(stats.timeant, field + 'n')[...] = np.sum(nterms, axis=1)
            getattr(stats.timechan, field + 'n')[...] = np.sum(nterms, axis=2)

    update_stats(flags_arr, ('initchi2', 'chi2'))

    # Initialize a residual array.

    resid_shape = [
        gm.n_mod, gm.n_tim, gm.n_fre, gm.n_ant, gm.n_ant, gm.n_cor, gm.n_cor
    ]

    resid_arr = gm.allocate_vis_array(resid_shape, obser_arr.dtype, zeros=True)
    gm.compute_residual(obser_arr, model_arr, resid_arr, require_full=True)
    resid_arr[:, flags_arr != 0] = 0

    # apply MAD flagging
    madmax.set_mode(GD['madmax']['enable'])

    # do mad max flagging, if requested
    thr1, thr2 = madmax.get_mad_thresholds()
    if thr1 or thr2:
        if madmax.beyond_thunderdome(resid_arr, obser_arr, model_arr,
                                     flags_arr, thr1, thr2,
                                     "{} initial".format(label)):
            gm.update_equation_counts(flags_arr != 0)
            stats.chunk.num_mad_flagged = ((flags_arr & FL.MAD) != 0).sum()

    # apply robust flag if robust machine (this uses the madmax flag)
    if hasattr(gm, 'is_robust'):
        if gm.robust_flag_weights:
            gm.robust_flag(flags_arr, model_arr, obser_arr)
            stats.chunk.num_mad_flagged = ((flags_arr & FL.MAD) != 0).sum()

    # In the event that there are no solutions with valid data, this will log some of the
    # flag information and break out of the function.
    stats.chunk.num_solutions = gm.num_solutions
    stats.chunk.num_sol_flagged = gm.num_gain_flags()[0]

    # every chunk stat set above now copied to stats.chunk.field_0
    stats.save_chunk_stats(step=0)

    # raise warnings from prior conditioning, before the loop
    for d in gm.collect_warnings():
        log.write(d["msg"],
                  level=d["level"],
                  print_once=d["raise_once"],
                  verbosity=d["verbosity"],
                  color=d["color"])

    if not gm.has_valid_solutions:
        log.error("{} no solutions: {}; flags {}".format(
            label, gm.conditioning_status_string, get_flagging_stats()))
        return (obser_arr if compute_residuals else None), stats, None

    def compute_chisq(statfield=None, full=True):
        """
        Computes chi-squared statistic based on current residuals and noise estimates.
        Populates the stats object with it.

        full=True is used at the beginning and end of a solution; the flag is passed through to gm.compute_chisq().
        """
        chisq, chisq_per_tf_slot, chisq_tot = gm.compute_chisq(
            resid_arr, inv_var_chan, require_full=full)

        if statfield:
            getattr(stats.chanant, statfield)[...] = np.sum(chisq, axis=0)
            getattr(stats.timeant, statfield)[...] = np.sum(chisq, axis=1)
            getattr(stats.timechan, statfield)[...] = np.sum(chisq, axis=2)

        return chisq_per_tf_slot, chisq_tot

    chi, stats.chunk.chi2u = compute_chisq(statfield='initchi2', full=True)
    stats.chunk.chi2_0 = stats.chunk.chi2u_0 = stats.chunk.chi2u

    # The following provides conditioning information when verbose is set to > 0.
    if log.verbosity() > 0:
        log(1).print("{} chi^2_0 {:.4}; {}; noise {:.3}, flags: {}".format(
            label, stats.chunk.chi2_0, gm.conditioning_status_string,
            float(stats.chunk.noise_0), get_flagging_stats()))

    # Main loop of the GN/LM method. Terminates after quorum is reached in either converged or
    # stalled solutions or when the maximum number of iterations is exceeded.

    major_step = 0  # keeps track of "major" solution steps, for purposes of collecting stats

    while not (gm.has_converged) and not (gm.has_stalled):

        num_iter, update_major_step = gm.next_iteration()

        if update_major_step:
            major_step += 1
            stats.save_chunk_stats(step=major_step)

        # This is currently an awkward necessity - if we have a chain of jones terms, we need to
        # make sure that the active term is correct and need to support some sort of decision making
        # for testing convergence. I think doing the iter increment here might be the best choice,
        # with an additional bit of functionality for Jones chains. I suspect I will still need to
        # change the while loop component to be compatible with the idea of partial convergence.
        # Perhaps this should all be done right at the top of the function? A better idea is to let
        # individual machines be aware of their own stalled/converged status, and make those
        # properties more complicated on the chain. This should allow for fairly easy substitution
        # between the various machines.

        gm.compute_update(model_arr, obser_arr)

        # flag solutions. This returns True if any flags have been propagated out to the data.
        if gm.flag_solutions(flags_arr, False):

            update_stats(flags_arr, ('chi2', ))

            # Re-zero the model and data at newly flagged points.
            # TODO: is this needed?
            # TODO: should we perhaps just zero the model per flagged direction, and only flag the data?
            # OMS: probably not: flag propagation is now handled inside the gain machine. If a flag is
            # propagated out to the data, then that slot is gone for good and should be zeroed everywhere.

            new_flags = (flags_arr & ~(FL.MISSING | FL.PRIOR)) != 0

            model_arr[:, :, new_flags, :, :] = 0
            obser_arr[:, new_flags, :, :] = 0

            stats.chunk.num_sol_flagged = gm.num_gain_flags()[0]

        # Compute values used in convergence tests. This check implicitly marks flagged gains as
        # converged.

        gm.check_convergence(gm.epsilon)

        stats.chunk.iters = num_iter
        stats.chunk.num_converged = gm.num_converged_solutions
        stats.chunk.frac_converged = gm.num_solutions and gm.num_converged_solutions / float(
            gm.num_solutions)

        # Break out of the solver loop if we find ourselves with no valid solution intervals (e.g. due to gain flagging)
        if not gm.has_valid_solutions:
            break

        # Check residual behaviour after a number of iterations equal to chi_interval. This is
        # expensive, so we do it as infrequently as possible.

        if (num_iter % chi_interval) == 0 or num_iter <= 1:

            old_chi, old_mean_chi = chi, float(stats.chunk.chi2u)

            gm.compute_residual(obser_arr,
                                model_arr,
                                resid_arr,
                                require_full=False)
            resid_arr[:, flags_arr != 0] = 0

            # do mad max flagging, if requested
            thr1, thr2 = madmax.get_mad_thresholds()
            if thr1 or thr2:
                num_mad_flagged_prior = int(stats.chunk.num_mad_flagged)

                if madmax.beyond_thunderdome(
                        resid_arr, obser_arr, model_arr, flags_arr, thr1, thr2,
                        "{} iter {} ({})".format(label, num_iter,
                                                 gm.jones_label)):
                    gm.update_equation_counts(flags_arr != 0)
                    stats.chunk.num_mad_flagged = ((flags_arr & FL.MAD) !=
                                                   0).sum()
                    if stats.chunk.num_mad_flagged != num_mad_flagged_prior:
                        log(2).print("{}: {} new MadMax flags".format(
                            label, stats.chunk.num_mad_flagged -
                            num_mad_flagged_prior))

            chi, stats.chunk.chi2u = compute_chisq(full=False)

            # Check for stalled solutions - solutions for which the residual is no longer improving.
            # Don't do this on a major step (i.e. when going from term to term in a chain), as the
            # reduced chisq (which compute_chisq() returns) can actually jump when going to the next term

            if update_major_step:
                stats.chunk.num_stalled = stats.chunk.num_diverged = 0
            else:
                delta_chi = old_chi - chi
                stats.chunk.num_stalled = np.sum(
                    (delta_chi <= gm.delta_chi * old_chi))
                diverged_tf_slots = delta_chi < -0.1 * old_chi
                stats.chunk.num_diverged = diverged_tf_slots.sum()
                # at first iteration, flag immediate divergers
                if sol_opts[
                        'flag-divergence'] and stats.chunk.num_diverged and num_iter == 1:
                    model_arr[:, :, diverged_tf_slots] = 0
                    obser_arr[:, diverged_tf_slots] = 0

                    # find previously unflagged visibilities that have become flagged due to divergence
                    new_flags = (flags_arr == 0)
                    new_flags[~diverged_tf_slots] = 0

                    flags_arr[diverged_tf_slots] |= FL.DIVERGE

                    num_nf = new_flags.sum()
                    log.warn(
                        "{}: {:.2%} slots diverging, {} new data flags".format(
                            label,
                            diverged_tf_slots.sum() /
                            float(diverged_tf_slots.size), num_nf))

            stats.chunk.frac_stalled = stats.chunk.num_stalled / float(
                chi.size)
            stats.chunk.frac_diverged = stats.chunk.num_diverged / float(
                chi.size)

            gm.has_stalled = (stats.chunk.frac_stalled >= stall_quorum)

            # if gm.has_stalled:
            #     import pdb; pdb.set_trace()

            if log.verbosity() > 1:
                if update_major_step:
                    delta_chi_max = delta_chi_mean = 0.
                else:
                    wh = old_chi != 0
                    delta_chi[wh] /= old_chi[wh]
                    delta_chi_max = delta_chi.max()
                    chi_mean = float(stats.chunk.chi2u)
                    delta_chi_mean = (old_mean_chi - chi_mean
                                      ) / chi_mean if chi_mean != 0 else 0.

                if stats.chunk.num_diverged:
                    diverging = ", " + ModColor.Str(
                        "diverging {:.2%}".format(stats.chunk.frac_diverged),
                        "red")
                else:
                    diverging = ""

                log(2).print(
                    "{} {} chi2 {:.4}, rel delta {:.4} max {:.4}, active {:.2%}{}"
                    .format(label, gm.current_convergence_status_string,
                            stats.chunk.chi2u, delta_chi_mean, delta_chi_max,
                            float(1 - stats.chunk.frac_stalled), diverging))

        # for the robust solver, propagate the flags into the weights as well
        if hasattr(gm, 'is_robust'):
            gm.update_weight_flags(flags_arr)

    # num_valid_solutions will go to 0 if all solution intervals were flagged. If this is not the
    # case, generate residuals etc.

    if gm.has_valid_solutions:
        # Final round of flagging
        flagged = gm.flag_solutions(flags_arr, final=True)
        stats.chunk.num_sol_flagged = gm.num_gain_flags(final=True)[0]
    else:
        flagged = None

    # check this again, because final round of flagging could have killed us
    if gm.has_valid_solutions:
        # Do we need to recompute the final residuals?
        if (sol_opts['last-rites'] or compute_residuals):
            gm.compute_residual(obser_arr,
                                model_arr,
                                resid_arr,
                                require_full=True)
            resid_arr[:, flags_arr != 0] = 0

            # do mad max flagging, if requested
            thr1, thr2 = madmax.get_mad_thresholds()
            if thr1 or thr2:
                if madmax.beyond_thunderdome(
                        resid_arr, obser_arr, model_arr, flags_arr, thr1, thr2,
                        "{} final".format(label)) and sol_opts['last-rites']:
                    gm.update_equation_counts(flags_arr != 0)

        # Re-estimate the noise using the final residuals, if last rites are needed.

        if sol_opts['last-rites']:
            stats.chunk.noise, inv_var_antchan, inv_var_ant, inv_var_chan = \
                                        stats.estimate_noise(resid_arr, flags_arr, residuals=True)
            chi1, stats.chunk.chi2 = compute_chisq(statfield='chi2', full=True)
        else:
            stats.chunk.chi2 = stats.chunk.chi2u

        message = "{} (end solve) {}, stall {:.2%}{}, chi^2 {:.4} -> {:.4}".format(
            label, gm.final_convergence_status_string,
            float(stats.chunk.frac_stalled), diverging,
            float(stats.chunk.chi2_0), stats.chunk.chi2u)

        should_warn = float(stats.chunk.chi2_0) < float(
            stats.chunk.chi2u) or diverging
        if sol_opts['last-rites'] and (should_warn or log.verbosity() > 0):
            message = "{} ({:.4}), noise {:.3} -> {:.3}".format(
                message, float(stats.chunk.chi2), float(stats.chunk.noise_0),
                float(stats.chunk.noise))
        if should_warn:
            message += " Shows signs of divergence. If you see this message often you may have significant RFI present in your data or your solution intervals are too short."
        if should_warn:
            log.warn(message)
        elif log.verbosity() > 0:
            log.info(message)

    # If everything has been flagged, no valid solutions are generated.

    else:
        log.error("{} (end solve) {}: completely flagged?".format(
            label, gm.final_convergence_status_string))

        # record zero chi2 in the stats (a bare local "chi2 = chi2u = 0" here would be a no-op)
        stats.chunk.chi2 = stats.chunk.chi2u = 0
        resid_arr = obser_arr

    robust_weights = None
    if hasattr(gm, 'is_robust'):

        # do a last round of robust flagging and save the weights

        if gm.robust_flag_weights and not gm.robust_flag_disable:
            gm.robust_flag(flags_arr, model_arr, obser_arr, final=True)
            stats.chunk.num_mad_flagged = ((flags_arr & FL.MAD) != 0).sum()

        if gm.save_weights:
            newshape = gm.weights.shape[1:-1] + (2, 2)
            robust_weights = np.repeat(gm.weights.real, 4, axis=-1)
            robust_weights = np.reshape(robust_weights, newshape)
            gm.output_weights = robust_weights

    # After the solver loop check for warnings from the solvers
    for d in gm.collect_warnings():
        log.write(d["msg"],
                  level=d["level"],
                  print_once=d["raise_once"],
                  verbosity=d["verbosity"],
                  color=d["color"])

    return (resid_arr if compute_residuals else None), stats, robust_weights
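
The stall/divergence bookkeeping in the solver loop reduces to elementwise tests on per-slot chi-squared deltas. A minimal numpy sketch (delta_chi_thresh stands in for gm.delta_chi and the values are illustrative; the -0.1 divergence factor matches the code above):

import numpy as np

old_chi = np.array([10.0, 10.0, 10.0, 10.0])
chi     = np.array([ 9.0,  9.999, 12.0, 10.0])
delta_chi = old_chi - chi

delta_chi_thresh = 1e-3                             # stand-in for gm.delta_chi
stalled  = delta_chi <= delta_chi_thresh * old_chi  # no longer improving
diverged = delta_chi < -0.1 * old_chi               # chi2 grew by more than 10%

print("stalled:", stalled.sum(), "diverged:", diverged.sum())
# stalled: 3 (a diverging slot also counts as stalled), diverged: 1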
Code example #6
File: flagging.py Project: sjperkins/CubiCal
def flag_chisq(st, GD, basename, nddid):
    """
    Flags timeslots and channels based on accumulated chi-squared statistics.

    Args:
        st (:obj:`~cubical.statistics.SolverStats`):
            Object containing solver statistics.
        GD (dict):
            Dictionary of global options.
        basename (str):
            Base name for output plots.
        nddid (int):
            Number of data descriptor identifiers.

    Returns:
        np.ndarray:
            Flag cube of shape (n_times, n_ddids, n_chans).
    """

    chi2 = np.ma.masked_array(st.timechan.chi2, st.timechan.chi2 == 0)
    total = (~chi2.mask).sum()
    if not total:
        print(ModColor.Str(
            "no valid solutions anywhere: skipping post-solution flagging."), file=log)
        return None

    chi2n = st.timechan.chi2n
    chi2n = np.ma.masked_array(chi2n, chi2n == 0)

    median = np.ma.median(chi2)
    median_np = np.ma.median(chi2n)
    print >> log, "median chi2 value is {:.3} from {} valid t/f slots".format(
        median, total)
    print >> log, "median count per slot is {}".format(median_np)

    chi_median_thresh = GD["postmortem"]["tf-chisq-median"]
    np_median_thresh = GD["postmortem"]["tf-np-median"]
    time_density = GD["postmortem"]["time-density"]
    chan_density = GD["postmortem"]["chan-density"]
    ddid_density = GD["postmortem"]["ddid-density"]

    make_plots = GD["out"]["plots"]
    show_plots = make_plots == "show"

    if make_plots:
        import pylab
        pylab.figure(figsize=(32, 10))
        pylab.subplot(161)
        if chi2.count():
            pylab.imshow(chi2, vmin=0, vmax=5 * median)
        pylab.title("$\chi^2$")
        pylab.colorbar()
        pylab.subplot(162)
        if chi2n.count():
            pylab.imshow(chi2n)
        pylab.title("counts")
        pylab.colorbar()

    flag = (chi2 > chi_median_thresh * median)
    chi2[flag] = np.ma.masked
    nflag = flag.sum()
    print >> log, "{} slots ({:.2%}) flagged on chi2 > {}*median".format(
        nflag, nflag / float(total), chi_median_thresh)

    if make_plots:
        pylab.subplot(163)
        if chi2.count():
            pylab.imshow(chi2)
        pylab.title("$\chi^2$ median flagged")
        pylab.colorbar()

    flag2 = (chi2n < np_median_thresh * median_np)
    n_new = (flag2 & ~flag).sum()
    print >> log, "{} more slots ({:.2%}) flagged on counts < {}*median".format(
        n_new, n_new / float(total), np_median_thresh)
    flag |= flag2

    chi2[flag] = np.ma.masked
    if make_plots:
        pylab.subplot(164)
        if chi2.count():
            pylab.imshow(chi2)
        pylab.title("counts flagged")
        pylab.colorbar()

    nt, nf = flag.shape

    # flag channels with overdense flagging
    freqcount = flag.sum(axis=0)
    freqflags = freqcount > nt * chan_density
    n_new = (freqflags & ~(freqcount == nt)).sum()
    print >> log, "{} more channels flagged on density > {}".format(
        n_new, chan_density)

    # flag timeslots with overdense flagging
    timecount = flag.sum(axis=1)
    timeflags = timecount > nf * time_density
    n_new = (timeflags & ~(timecount == nf)).sum()
    print >> log, "{} more timeslots flagged on density > {}".format(
        n_new, time_density)

    flag = flag | freqflags[np.newaxis, :] | timeflags[:, np.newaxis]
    chi2[flag] = np.ma.masked
    if make_plots:
        pylab.subplot(165)
        if chi2.count():
            pylab.imshow(chi2)
        pylab.title("overdense flagged")
        pylab.colorbar()

    # reshape flag array into time, ddid, channel
    flag3 = flag.reshape((nt, nddid, nf // nddid))

    # flag entire DDIDs with overdense flagging
    maxcount = nt * nf // nddid
    ddidcounts = flag3.sum(axis=(0, 2))
    ddidflags = ddidcounts > maxcount * ddid_density
    n_new = (ddidflags & ~(ddidcounts == maxcount)).sum()
    print >> log, "{} more ddids flagged on density > {}".format(
        n_new, ddid_density)

    flag3 |= ddidflags[np.newaxis, :, np.newaxis]
    chi2[flag] = np.ma.masked
    if make_plots:
        pylab.subplot(166)
        pylab.title("overdense DDID")
        if chi2.count():
            pylab.imshow(chi2)
        pylab.colorbar()
        filename = basename + ".chiflag.png"
        pylab.savefig(filename, dpi=plots.DPI)
        print("saved chi-sq flagging plot to " + filename, file=log)
        if show_plots:
            pylab.show()

    return flag3
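
The density-based flag extension in flag_chisq follows one pattern: if more than some fraction of a channel's (or timeslot's, or DDID's) slots are already flagged, flag it entirely. A minimal sketch for the channel case (the 50% threshold is illustrative, not the parset default):

import numpy as np

flag = np.zeros((4, 3), bool)  # (n_times, n_chans)
flag[:3, 0] = True             # channel 0: 3 of 4 timeslots already flagged

chan_density = 0.5
nt = flag.shape[0]
freqflags = flag.sum(axis=0) > nt * chan_density  # [3, 0, 0] > 2 -> [True, False, False]
flag |= freqflags[np.newaxis, :]                  # channel 0 is now flagged in full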
Code example #7
def main(debugging=False):
    """
    Main cubical driver function. Reads options, sets up MS and solvers, calls the solver, etc.

    Args:
        debugging (bool, optional):
            If True, run in debugging mode.

    Raises:
        UserInputError:
            If neither --model-lsm nor --model-column were specified.
        UserInputError:
            If no Jones terms are enabled.
        UserInputError:
            If --out-mode is invalid.
        ValueError:
            If unknown Jones type is specified.
        RuntimeError:
            If I/O job on a tile failed.
    """

    # this will be set below if a custom parset is specified on the command line
    custom_parset_file = None
    # "GD" is a global defaults dict, containing options set up from parset + command line
    global GD, enable_pdb

    try:
        if debugging:
            print >> log, "initializing from cubical.last"
            GD = cPickle.load(open("cubical.last"))
            basename = GD["out"]["name"]
            parser = None
        else:
            default_parset = parsets.Parset("%s/DefaultParset.cfg" %
                                            os.path.dirname(__file__))

            # if first argument is a filename, treat it as a parset

            if len(sys.argv) > 1 and not sys.argv[1].startswith('-'):
                custom_parset_file = sys.argv[1]
                print >> log, "reading defaults from {}".format(
                    custom_parset_file)
                try:
                    parset = parsets.Parset(custom_parset_file)
                except Exception:
                    import traceback
                    traceback.print_exc()
                    raise UserInputError(
                        "'{}' must be a valid parset file. Use -h for help.".
                        format(custom_parset_file))
                if not parset.success:
                    raise UserInputError(
                        "'{}' must be a valid parset file. Use -h for help.".
                        format(custom_parset_file))
                # update default parameters with values from parset
                default_parset.update_values(
                    parset, other_filename=' in {}'.format(custom_parset_file))

            import cubical
            parser = dynoptparse.DynamicOptionParser(
                usage='Usage: %prog [parset file] <options>',
                description=
                """Questions, bug reports, suggestions: https://github.com/ratt-ru/CubiCal""",
                version='%prog version {}'.format(cubical.VERSION),
                defaults=default_parset.value_dict,
                attributes=default_parset.attr_dict)

            # now read the full input from command line
            # "GD" is a global defaults dict, containing options set up from parset + command line
            GD = parser.read_input()

            # if a single argument is given, it should have been the parset
            if len(parser.get_arguments()) != (1 if custom_parset_file else 0):
                raise UserInputError(
                    "Unexpected number of arguments. Use -h for help.")

            # "GD" is a global defaults dict, containing options set up from parset + command line
            cPickle.dump(GD, open("cubical.last", "w"))

            # get basename for all output files
            basename = GD["out"]["name"]
            if not basename:
                basename = "out"

            # create directory for output files, if it doesn't exist
            dirname = os.path.dirname(basename)
            if dirname and not os.path.exists(dirname):
                os.mkdir(dirname)

            # save parset with all settings. We refuse to clobber a parset with itself
            # (so e.g. "gocubical test.parset --Section-Option foo" does not overwrite test.parset)
            save_parset = basename + ".parset"
            if custom_parset_file and os.path.exists(custom_parset_file) and os.path.exists(save_parset) and \
                    os.path.samefile(save_parset, custom_parset_file):
                basename = "~" + basename
                save_parset = basename + ".parset"
                print >> log, ModColor.Str(
                    "Your --Output-Name would overwrite its own parset. Using %s instead."
                    % basename)
            parser.write_to_parset(save_parset)

        enable_pdb = GD["debug"]["pdb"]
        # clean up shared memory from any previous runs
        shm_utils.cleanupStaleShm()

        # now setup logging
        logger.logToFile(basename + ".log", append=GD["log"]["append"])
        logger.enableMemoryLogging(GD["log"]["memory"])
        logger.setBoring(GD["log"]["boring"])
        logger.setGlobalVerbosity(GD["log"]["verbose"])
        logger.setGlobalLogVerbosity(GD["log"]["file-verbose"])

        if not debugging:
            print >> log, "started " + " ".join(sys.argv)

        # disable matplotlib's tk backend if we're not going to be showing plots
        if GD['out']['plots-show']:
            import pylab
            try:
                pylab.figure()
            except Exception, exc:
                import traceback
                print >> log, ModColor.Str(
                    "Error initializing matplotlib: {}({})\n {}".format(
                        type(exc).__name__, exc, traceback.format_exc()))
                raise UserInputError(
                    "matplotlib can't connect to X11. Suggest disabling --out-plots-show."
                )
        else:
Code Example #8
        double_precision = GD["sol"]["precision"] == 64

        # set up RIME

        solver_opts = GD["sol"]
        debug_opts = GD["debug"]
        sol_jones = solver_opts["jones"]
        if type(sol_jones) is str:
            sol_jones = set(sol_jones.split(','))
        jones_opts = [GD[j.lower()] for j in sol_jones]
        # collect list of options from enabled Jones matrices
        if not len(jones_opts):
            raise UserInputError("No Jones terms are enabled")
        print >> log, ModColor.Str("Enabling {}-Jones".format(
            ",".join(sol_jones)),
                                   col="green")

        have_dd_jones = any([jo['dd-term'] for jo in jones_opts])

        # TODO: in this case data_handler can be told to only load diagonal elements. Save memory!
        # top-level diag-diag enforced across jones terms
        if solver_opts['diag-diag']:
            for jo in jones_opts:
                jo['diag-diag'] = True
        else:
            solver_opts['diag-diag'] = all(
                [jo['diag-diag'] for jo in jones_opts])

        # set up data handler
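
The Jones-term collection above can be exercised in isolation. Here is a hedged sketch with a toy GD dict; the real defaults dict has many more sections, and the "g"/"b" section names below are assumptions for illustration only.

GD = {
    "sol": {"jones": "G,B", "diag-diag": False},
    "g": {"dd-term": False, "diag-diag": True},
    "b": {"dd-term": False, "diag-diag": True},
}

sol_jones = GD["sol"]["jones"]
if isinstance(sol_jones, str):
    sol_jones = set(sol_jones.split(','))

# collect the option sub-dicts of the enabled Jones terms
jones_opts = [GD[j.lower()] for j in sol_jones]
if not jones_opts:
    raise ValueError("No Jones terms are enabled")

# diag-diag is either enforced top-down, or inferred bottom-up
if GD["sol"]["diag-diag"]:
    for jo in jones_opts:
        jo["diag-diag"] = True
else:
    GD["sol"]["diag-diag"] = all(jo["diag-diag"] for jo in jones_opts)

print(GD["sol"]["diag-diag"])   # True: every enabled term is diag-diag
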
Code Example #9
File: pickled_db.py  Project: ratt-ru/CubiCal
    def _load(self, filename):
        """
        Loads database from file. This will create arrays corresponding to the stored parameter
        shapes.

        Args:
            filename (str):
                Name of file to load.
        """

        self.mode = "load"
        self.filename = filename

        db = self._Unpickler(filename)
        print("reading {} in {} mode".format(self.filename, db.mode),
              file=log(0))
        self.metadata = db.metadata
        for key, value in self.metadata.items():
            if key != "mode":
                print("  metadata '{}': {}".format(key, value), file=log(1))

        # now load differently depending on mode
        # in consolidated mode, just unpickle the parameter objects
        if db.mode == PickledDatabase.MODE_CONSOLIDATED:
            self._parameters = next(db)
            for parm in self._parameters.values():
                print("  read {} of shape {}".format(
                    parm.name, 'x'.join(map(str, parm.shape))),
                      file=log(1))
            return

        # otherwise we're in fragmented mode
        if db.mode != PickledDatabase.MODE_FRAGMENTED:
            raise IOError("{}: invalid mode".format(self.filename,
                                                    self.metadata.mode))

        # in fragmented mode: try to read the desc file
        descfile = filename + '.skel'
        self._parameters = None
        if not os.path.exists(descfile):
            print(ModColor.Str(
                "{} does not exist, will try to rebuild".format(descfile)),
                  file=log(0))
        elif os.path.getmtime(descfile) < os.path.getmtime(self.filename):
            print(ModColor.Str(
                "{} older than database: will try to rebuild".format(
                    descfile)),
                  file=log(0))
        elif os.path.getmtime(descfile) < os.path.getmtime(__file__):
            print(ModColor.Str(
                "{} older than this code: will try to rebuild".format(
                    descfile)),
                  file=log(0))
        else:
            try:
                with open(descfile, 'rb') as pf:
                    self._parameters = pickle.load(pf)
            except Exception:
                traceback.print_exc()
                print(ModColor.Str(
                    "error loading {}, will try to rebuild".format(descfile)),
                      file=log(0))
        # rebuild the skeletons, if they weren't loaded
        if self._parameters is None:
            self._parameters = {}
            for item in db:
                if isinstance(item, Parameter):
                    self._parameters[item.name] = item
                elif isinstance(item, _ParmSegment):
                    self._parameters[item.name]._update_shape(
                        item.array.shape, item.grid)
                else:
                    raise IOError("{}: unexpected entry of type '{}'".format(
                        self.filename, type(item)))
            self._save_desc()

        # initialize arrays
        for parm in self._parameters.values():
            parm._init_arrays()

        # go over all slices to paste them into the arrays
        db = self._Unpickler(filename)
        for item in db:
            if isinstance(item, Parameter):
                pass
            elif isinstance(item, _ParmSegment):
                parm = self._parameters.get(item.name)
                if parm is None:
                    raise IOError("{}: no parm found for {}'".format(
                        filename, item.name))
                parm._paste_slice(item)
            else:
                raise IOError("{}: unknown item type '{}'".format(
                    filename, type(item)))

        # ok, now arrays and flags each contain a full-sized array. Break it up into slices.
        for parm in self._parameters.values():
            parm._finalize_arrays()
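
The _Unpickler that _load iterates over reads a file holding a sequence of objects pickled back to back. A minimal sketch of that pattern with the standard pickle module; iter_pickles is an illustrative helper, not CubiCal's actual class.

import pickle

def iter_pickles(filename):
    """Yield successive objects pickled back-to-back into one file."""
    with open(filename, "rb") as f:
        while True:
            try:
                yield pickle.load(f)
            except EOFError:
                return

# write a few objects into one file, then stream them back
with open("demo.db", "wb") as f:
    for obj in ({"mode": "fragmented"}, [1, 2, 3], "segment"):
        pickle.dump(obj, f)

for item in iter_pickles("demo.db"):
    print(type(item).__name__, item)
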
Code Example #10
File: solver.py  Project: saopicc/CubiCal
def run_solver(solver_type, itile, chunk_key, sol_opts, debug_opts):
    """
    Initialises a gain machine and invokes the solver for the current chunk.

    Args:
        solver_type (str):
            Specifies type of solver to use.
        itile (int):
            Index of current Tile object.
        chunk_key (str):
            Label identifying the current chunk (e.g. "D0T1F2").
        sol_opts (dict):
            Solver options (see [sol] section in DefaultParset.cfg).
        debug_opts (dict):
            Debugging options (see [debug] section in DefaultParset.cfg).

    Returns:
        :obj:`~cubical.statistics.SolverStats`:
            An object containing solver statistics.

    Raises:
        RuntimeError:
            If gain factory has not been initialised.
    """
    import cubical.workers
    cubical.workers._init_worker()

    label = None
    try:
        tile = Tile.tile_list[itile]
        label = chunk_key
        solver = SOLVERS[solver_type]
        # initialize the gain machine for this chunk

        if gm_factory is None:
            raise RuntimeError("Gain machine factory has not been initialized")

        # Get chunk data from tile.

        # need to know which kernel to use to allocate visibility and flag arrays
        kernel = gm_factory.get_kernel()

        obser_arr, model_arr, flags_arr, weight_arr = tile.get_chunk_cubes(
            chunk_key,
            allocator=kernel.allocate_vis_array,
            flag_allocator=kernel.allocate_flag_array)

        chunk_ts, chunk_fs, _, freq_slice = tile.get_chunk_tfs(chunk_key)

        # apply IFR-based gains, if any
        ifrgain_machine.apply(obser_arr, freq_slice)

        # create subdict in shared dict for solutions etc.
        soldict = tile.create_solutions_chunk_dict(chunk_key)

        # create VisDataManager for this chunk

        vdm = _VisDataManager(obser_arr, model_arr, flags_arr, weight_arr,
                              freq_slice)

        n_dir, n_mod = model_arr.shape[0:2] if model_arr is not None else (1, 1)

        # create GainMachine
        vdm.gm = gm_factory.create_machine(vdm.weighted_obser, n_dir, n_mod,
                                           chunk_ts, chunk_fs, label)

        # Invoke solver method
        if debug_opts['stop-before-solver']:
            import pdb
            pdb.set_trace()

        corr_vis, stats = solver(vdm, soldict, label, sol_opts)

        # Panic if amplitude has gone crazy

        if debug_opts['panic-amplitude']:
            if corr_vis is not None:
                unflagged = flags_arr == 0
                if unflagged.any() and \
                        abs(corr_vis[unflagged, :, :]).max() > debug_opts['panic-amplitude']:
                    raise RuntimeError(
                        "excessive amplitude in chunk {}".format(label))

        # Copy results back into tile.

        tile.set_chunk_cubes(
            corr_vis,
            flags_arr if (stats and stats.chunk.num_sol_flagged) else None,
            chunk_key)

        # Ask the gain machine to store its solutions in the shared dict.
        gm_factory.export_solutions(vdm.gm, soldict)

        return stats

    except Exception, exc:
        print >> log, ModColor.Str(
            "Solver for tile {} chunk {} failed with exception: {}".format(
                itile, label, exc))
        print >> log, traceback.format_exc()
        raise
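
The panic-amplitude guard above is just a boolean-mask reduction over unflagged visibilities. A toy, self-contained sketch; the threshold value and array shapes are assumptions for illustration.

import numpy as np

PANIC_AMPLITUDE = 1e3
corr_vis = np.random.randn(4, 4, 2, 2) + 1j*np.random.randn(4, 4, 2, 2)
flags_arr = np.zeros((4, 4), dtype=int)
corr_vis[2, 3] *= 1e6   # inject a corrupted slot...
flags_arr[2, 3] = 1     # ...but flag it, so the check should still pass

unflagged = flags_arr == 0
if unflagged.any() and abs(corr_vis[unflagged, :, :]).max() > PANIC_AMPLITUDE:
    raise RuntimeError("excessive amplitude in chunk")
print("amplitude check passed")
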
Code Example #11
File: solver.py  Project: saopicc/CubiCal
def _solve_gains(gm,
                 obser_arr,
                 model_arr,
                 flags_arr,
                 sol_opts,
                 label="",
                 compute_residuals=None):
    """
    Main body of the GN/LM method. Handles iterations and convergence tests.

    Args:
        gm (:obj:`~cubical.machines.abstract_machine.MasterMachine`): 
            The gain machine which will be used in the solver loop.
        obser_arr (np.ndarray): 
            Shape (n_mod, n_tim, n_fre, n_ant, n_ant, n_cor, n_cor) array containing observed 
            visibilities. 
        model_arr (np.ndarray): 
            Shape (n_dir, n_mod, n_tim, n_fre, n_ant, n_ant, n_cor, n_cor) array containing model 
            visibilities. 
        flags_arr (np.ndarray): 
            Shape (n_tim, n_fre, n_ant, n_ant) integer array containing flag data.
        sol_opts (dict): 
            Solver options (see [sol] section in DefaultParset.cfg).
        label (str, optional):             
            Label identifying the current chunk (e.g. "D0T1F2").
        compute_residuals (bool, optional): 
            If set, the final residuals will be computed and returned.

    Returns:
        2-element tuple
            
            - resid (np.ndarray)
                The final residuals (if compute_residuals is set), else None.
            - stats (:obj:`~cubical.statistics.SolverStats`)
                An object containing solver statistics.
    """
    min_delta_g = sol_opts["delta-g"]
    chi_tol = sol_opts["delta-chi"]
    chi_interval = sol_opts["chi-int"]
    stall_quorum = sol_opts["stall-quorum"]

    # Initialise stat object.

    stats = SolverStats(obser_arr)
    stats.chunk.label = label

    n_stall = 0
    frac_stall = 0
    n_original_flags = (flags_arr & ~(FL.PRIOR | FL.MISSING) != 0).sum()

    # initialize iteration counter

    num_iter = 0

    # Estimates the overall noise level and the inverse variance per channel and per antenna as
    # noise varies across the band. This is used to normalize chi^2.

    stats.chunk.init_noise, inv_var_antchan, inv_var_ant, inv_var_chan = \
                                                        stats.estimate_noise(obser_arr, flags_arr)

    # if we have directions in the model, but the gain machine is non-DD, collapse them
    if not gm.dd_term and model_arr.shape[0] > 1:
        model_arr = model_arr.sum(axis=0, keepdims=True)

    # This works out the conditioning of the solution, sets up various chi-sq normalization
    # factors etc, and does any other precomputation required by the current gain machine.

    gm.precompute_attributes(model_arr, flags_arr, inv_var_chan)

    def get_flagging_stats():
        """Returns a string describing per-flagset statistics"""
        fstats = []

        for flag, mask in FL.categories().iteritems():
            n_flag = ((flags_arr & mask) != 0).sum()
            if n_flag:
                fstats.append("{}:{}({:.2%})".format(
                    flag, n_flag, n_flag / float(flags_arr.size)))

        return " ".join(fstats)

    def update_stats(flags, statfields):
        """
        This function updates the solver stats object with a count of valid data points used for chi-sq 
        calculations
        """
        unflagged = (flags == 0)
        # Compute number of terms in each chi-square sum. Shape is (n_tim, n_fre, n_ant).

        nterms = 2 * gm.n_cor * gm.n_cor * np.sum(unflagged, axis=3)

        # Update stats object accordingly.

        for field in statfields:
            getattr(stats.chanant, field + 'n')[...] = np.sum(nterms, axis=0)
            getattr(stats.timeant, field + 'n')[...] = np.sum(nterms, axis=1)
            getattr(stats.timechan, field + 'n')[...] = np.sum(nterms, axis=2)

    update_stats(flags_arr, ('initchi2', 'chi2'))

    # In the event that there are no solutions with valid data, this will log some of the
    # flag information and break out of the function.

    if not gm.has_valid_solutions:
        stats.chunk.num_sol_flagged, _ = gm.num_gain_flags()

        print >> log, ModColor.Str("{} no solutions: {}; flags {}".format(
            label, gm.conditioning_status_string, get_flagging_stats()))
        return (obser_arr if compute_residuals else None), stats

    # Initialize a residual array.

    resid_shape = [
        gm.n_mod, gm.n_tim, gm.n_fre, gm.n_ant, gm.n_ant, gm.n_cor, gm.n_cor
    ]

    resid_arr = gm.cykernel.allocate_vis_array(resid_shape,
                                               obser_arr.dtype,
                                               zeros=True)
    gm.compute_residual(obser_arr, model_arr, resid_arr)
    resid_arr[:, flags_arr != 0] = 0

    # This flag is set to True when we have an up-to-date residual in resid_arr.

    have_residuals = True

    def compute_chisq(statfield=None):
        """
        Computes chi-squared statistic based on current residuals and noise estimates.
        Populates the stats object with it.
        """
        chisq, chisq_per_tf_slot, chisq_tot = gm.compute_chisq(
            resid_arr, inv_var_chan)

        if statfield:
            getattr(stats.chanant, statfield)[...] = np.sum(chisq, axis=0)
            getattr(stats.timeant, statfield)[...] = np.sum(chisq, axis=1)
            getattr(stats.timechan, statfield)[...] = np.sum(chisq, axis=2)

        return chisq_per_tf_slot, chisq_tot

    chi, mean_chi = compute_chisq(statfield='initchi2')
    stats.chunk.init_chi2 = mean_chi

    # The following provides conditioning information when verbose is set to > 0.
    if log.verbosity() > 0:

        print >> log, "{} chi^2_0 {:.4}; {}; noise {:.3}, flags: {}".format(
            label, mean_chi, gm.conditioning_status_string,
            float(stats.chunk.init_noise), get_flagging_stats())

    # Main loop of the GN/LM method. Terminates after quorum is reached in either converged or
    # stalled solutions, or when the maximum number of iterations is exceeded.

    while not (gm.has_converged) and not (gm.has_stalled):

        num_iter = gm.next_iteration()

        # This is currently an awkward necessity - if we have a chain of jones terms, we need to
        # make sure that the active term is correct and need to support some sort of decision making
        # for testing convergence. I think doing the iter increment here might be the best choice,
        # with an additional bit of functionality for Jones chains. I suspect I will still need to
        # change the while loop component to be compatible with the idea of partial convergence.
        # Perhaps this should all be done right at the top of the function? A better idea is to let
        # individual machines be aware of their own stalled/converged status, and make those
        # properties more complicated on the chain. This should allow for fairly easy substitution
        # between the various machines.

        gm.compute_update(model_arr, obser_arr)

        # flag solutions. This returns True if any flags have been propagated out to the data.
        if gm.flag_solutions(flags_arr, False):

            update_stats(flags_arr, ('chi2', ))

            # Re-zero the model and data at newly flagged points.
            # TODO: is this needed?
            # TODO: should we perhaps just zero the model per flagged direction, and only flag the data?
            # OMS: probably not: flag propagation is now handled inside the gain machine. If a flag is
            # propagated out to the data, then that slot is gone for good and should be zeroed everywhere.

            new_flags = flags_arr & ~(FL.MISSING | FL.PRIOR) != 0
            model_arr[:, :, new_flags, :, :] = 0
            obser_arr[:, new_flags, :, :] = 0

            # Break out of the solver loop if we find ourselves with no valid solution intervals.

            if not gm.has_valid_solutions:
                break

        have_residuals = False

        # Compute values used in convergence tests. This check implicitly marks flagged gains as
        # converged.

        gm.check_convergence(min_delta_g)

        # Check residual behaviour after a number of iterations equal to chi_interval. This is
        # expensive, so we do it as infrequently as possible.

        if (num_iter % chi_interval) == 0:

            old_chi, old_mean_chi = chi, mean_chi

            gm.compute_residual(obser_arr, model_arr, resid_arr)
            resid_arr[:, flags_arr != 0] = 0

            chi, mean_chi = compute_chisq()

            have_residuals = True

            # Check for stalled solutions - solutions for which the residual is no longer improving.

            n_stall = float(np.sum(((old_chi - chi) < chi_tol * old_chi)))
            frac_stall = n_stall / chi.size

            gm.has_stalled = (frac_stall >= stall_quorum)

            if log.verbosity() > 1:

                delta_chi = (old_mean_chi - mean_chi) / old_mean_chi

                print >> log(2), (
                    "{} {} chi2 {:.4}, delta {:.4}, stall {:.2%}").format(
                        label, gm.current_convergence_status_string, mean_chi,
                        delta_chi, frac_stall)

    # num_valid_solutions will go to 0 if all solution intervals were flagged. If this is not the
    # case, generate residuals etc.

    if gm.has_valid_solutions:
        # Final round of flagging
        flagged = gm.flag_solutions(flags_arr, True)

    # check this again, because final round of flagging could have killed us
    if gm.has_valid_solutions:
        # Do we need to recompute the final residuals?
        if (sol_opts['last-rites']
                or compute_residuals) and (not have_residuals or flagged):
            gm.compute_residual(obser_arr, model_arr, resid_arr)
            resid_arr[:, flags_arr != 0] = 0
            if sol_opts['last-rites']:
                # Recompute chi-squared based on original noise statistics.
                chi, mean_chi = compute_chisq(statfield='chi2')

        # Re-estimate the noise using the final residuals, if last rites are needed.

        if sol_opts['last-rites']:
            stats.chunk.noise, inv_var_antchan, inv_var_ant, inv_var_chan = \
                                        stats.estimate_noise(resid_arr, flags_arr, residuals=True)
            chi1, mean_chi1 = compute_chisq(statfield='chi2')

        stats.chunk.chi2 = mean_chi

        message = "{} {}, stall {:.2%}, chi^2 {:.4} -> {:.4}".format(
            label, gm.final_convergence_status_string, frac_stall,
            float(stats.chunk.init_chi2), mean_chi)

        if sol_opts['last-rites']:

            message = "{} ({:.4}), noise {:.3} -> {:.3}".format(
                message, float(mean_chi1), float(stats.chunk.init_noise),
                float(stats.chunk.noise))

        print >> log, message

    # If everything has been flagged, no valid solutions are generated.

    else:

        print >> log(0, "red"), "{} {}: completely flagged".format(
            label, gm.final_convergence_status_string)

        stats.chunk.chi2 = 0
        resid_arr = obser_arr

    stats.chunk.iters = num_iter
    stats.chunk.num_converged = gm.num_converged_solutions
    stats.chunk.num_stalled = n_stall

    # copy out flags, if we raised any
    stats.chunk.num_sol_flagged, _ = gm.num_gain_flags()
    if stats.chunk.num_sol_flagged:
        # also form up a message with flagging stats
        fstats = ""
        for flagname, mask in FL.categories().iteritems():
            if mask != FL.MISSING:
                n_flag, n_tot = gm.num_gain_flags(mask)
                if n_flag:
                    fstats += "{}:{}({:.2%}) ".format(flagname, n_flag,
                                                      n_flag / float(n_tot))
        n_new_flags = (flags_arr & ~(FL.PRIOR | FL.MISSING) !=
                       0).sum() - n_original_flags
        print >> log, ModColor.Str(
            "{} solver flags raised: {}-> {:.2%} data flags".format(
                label, fstats, n_new_flags / float(flags_arr.size)))

    return (resid_arr if compute_residuals else None), stats
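
The stall test inside the loop above can be read in isolation: a t/f slot has stalled once its relative chi^2 improvement drops below chi_tol, and the solver stops when the stalled fraction reaches stall_quorum. A sketch with toy numbers:

import numpy as np

chi_tol, stall_quorum = 1e-6, 0.99

old_chi = np.array([10.0, 10.0, 10.0, 10.0])
chi     = np.array([ 9.0, 10.0, 10.0, 10.0])   # only one slot still improving

n_stall = float(np.sum((old_chi - chi) < chi_tol * old_chi))
frac_stall = n_stall / chi.size

has_stalled = frac_stall >= stall_quorum
print(frac_stall, has_stalled)   # 0.75 False: keep iterating
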
Code Example #12
            parser.print_config(dest=log)

        double_precision = GD["sol"]["precision"] == 64

        # set up RIME

        solver_opts = GD["sol"]
        debug_opts  = GD["debug"]
        sol_jones = solver_opts["jones"]
        if type(sol_jones) is str:
            sol_jones = set(sol_jones.split(','))
        jones_opts = [GD[j.lower()] for j in sol_jones]
        # collect list of options from enabled Jones matrices
        if not len(jones_opts):
            raise UserInputError("No Jones terms are enabled")
        print>> log, ModColor.Str("Enabling {}-Jones".format(",".join(sol_jones)), col="green")

        have_dd_jones = any([jo['dd-term'] for jo in jones_opts])

        # TODO: in this case data_handler can be told to only load diagonal elements. Save memory!
        # top-level diag-diag enforced across jones terms
        if solver_opts['diag-diag']:
            for jo in jones_opts:
                jo['diag-diag'] = True
        else:
            solver_opts['diag-diag'] = all([jo['diag-diag'] for jo in jones_opts])

        # set up data handler

        solver_type = GD['out']['mode']
        if solver_type not in solver.SOLVERS:
Code Example #13
    def flag_weights(self):
        """Trying flagging visiblities with very very low weights and see if this improves the solution
        like mad max flagger

        wstd: the weights standard deviation
        """

        if self.flaground and not self.robust_flag_disable:

            if self._final_flaground:
                _pre_or_post = "after solving"
            else:
                _pre_or_post = "before solving"

            _nvis = self.new_flags.size
            nflag0 = np.sum(self.new_flags!=0)
            wfrac0 = nflag0/_nvis

            # This correction factor is to ensure that we only flag when v is low.
            # The 1.26 factor just makes it work somehow
            _v_corr = 1.26/self.v  # 2*self.npol/self.v if self.v < 5 else self.npol/self.v

            wlow = _v_corr*(self.v + self.npol)/(self.v + self.sigma_thresh)

            # print("rb-2x2 {} : {} iters: wlow is  {:.3} while 1/v is {:.3}".format(self.label, self.iters, wlow, 1/self.v), file=log(2))

            self.weight_flags = np.where((self.weights < wlow) & (self.weights != 0))

            if len(self.weight_flags[0])>0:

                self.weights[self.weight_flags] = 0
                self.new_flags[self.weight_flags[1:-1]] |= FL.MAD
                self.residuals[self.weight_flags[:-1]] = 0

                nflag = np.sum(self.new_flags!=0)

                any_new = nflag-nflag0

                if any_new:
                    wfrac = any_new/_nvis
                    
                    print(ModColor.Str("rb-2x2 {} : {} flag round {} : number of weight flags {} ({:.4%}), prior flags {} ({:.4%})".format(self.label, _pre_or_post, self._count+1, any_new, wfrac, nflag0, wfrac0), "blue"), file=log(2))
                
                self.any_new = True

                self._count += 1 
                
                return nflag != _nvis  # False if everything is now flagged

            else:
                if self._count==0:
                    self.any_new = False 
                return True
        
        else:
            self.any_new = False
            self.weight_flags = None 
            
            return True
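
The wlow cutoff used above can be evaluated on its own: weights below it (but not exact zeros, which mark already-dead points) get flagged. A sketch with toy parameter values; the 1.26 correction factor is taken as-is from the method.

import numpy as np

v, npol, sigma_thresh = 2.0, 2.0, 3.0
wlow = (1.26/v) * (v + npol) / (v + sigma_thresh)   # ~0.504 for these values

weights = np.array([0.0, 0.1, 0.6, 0.9, 1.2])
weight_flags = np.where((weights < wlow) & (weights != 0))

weights[weight_flags] = 0
print(wlow, weights)   # only the 0.1 weight is flagged; exact zeros untouched
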
Code Example #14
    def compute_covinv(self):
        
        """
        This functions computes the 4x4 covariance matrix of the residuals visibilities, 
        and it approximtes it inverse. I self.cov_type is set to 1, the covariance maxtrix is 
        assumed to be the Identity matrix as in the Robust-t paper.

        Args:
            residuals (np.array) : Array containing the residuals.
                Shape is n_dir, n_tim, n_fre, n_ant, a_ant, n_cor, n_cor

        Returns:
            covinv (np.array) : Shape is ncor*n_cor x ncor*n_cor (4x4)
            Array containing the inverse covariance matrix

        """

        if self.cov_type == "identity":

            covinv = np.eye(4, dtype=self.dtype)
          
        else:

            Nvis = self.Nvis/2.  # only half of the visibilities are used for covariance computation

            ompstd = np.zeros((4,4), dtype=self.dtype)

            self.kernel_robust.compute_cov(self.residuals, ompstd, self.weights)

            # removing the off-diagonal correlations
            std = np.diagonal(ompstd/Nvis) + self.eps**2  # to avoid division by zero

            if np.any(std > self.cov_thresh):
                self.fixed_v = True

                if self.flaground:
                    print(ModColor.Str("rb-2x2 {} : flag round {}: warning: covariance too high, probably due to RFI; fixing v to 2 and cov to 1".format(self.label, self._count+1), "red"), file=log(2))
                else:
                    print(ModColor.Str("rb-2x2 {} : {} iters: warning: covariance too high, probably due to RFI; fixing v to 2 and cov to 1".format(self.label, self.iters), "red"), file=log(2))

            else:
                # scaling the variance in this case improves the robust solver performance
                self.fixed_v = False
                if self.cov_scale and not self.flaground:
                    std /= self.cov_scale

            if self.iters % 5 == 0 or self.iters == 1:
                if self.flaground:
                    print("rb-2x2 {} : flag round {}: covariance diagonal : [{}]".format(self.label, self._count+1, ", ".join('{:.2g}'.format(x) for x in std)), file=log(2))
                    # print("rb-2x2 {} : flag round {}: covariance diagonal : [{}]".format(self.label, self._count+1, ", ".join('{:.2g}'.format(x) for x in std2)), file=log(2))
                else:
                    print("rb-2x2 {} : {} iters: covariance diagonal : [{}]".format(self.label, self.iters, ", ".join('{:.2g}'.format(x) for x in std)), file=log(2))
                    # print("rb-2x2 {} : {} iters: covariance diagonal : [{}]".format(self.label, self.iters, ", ".join('{:.2g}'.format(x) for x in std2)), file=log(2))


            # Can we disable flagging in the solver, to avoid flagging unmodelled sources?
            # UMS: the thought here is that if the data contains unmodelled sources rather than RFI,
            # the xx and yy covariances should be very close, so the solver should disable flagging in this case
            xx_close_to_yy = 0.8 <= np.abs(std[0])/np.abs(std[3]) <= 1.2  # compare xx (std[0]) with yy (std[3])
            cov_low = np.average([std[0], std[3]]).real < 2e-2

            if xx_close_to_yy and cov_low and self.flaground and self._count==0:
                self.robust_flag_disable = True
                print(ModColor.Str("rb-2x2 {} : flag round {}: Warning: the covariance is low and the xx and yy variances are very close. Flagging will be disable".format(self.label, self._count+1), "red"), file=log(2))


            covinv = np.eye(4, dtype=self.dtype)

            if self.fixed_v:
                covinv *= 1/self.cov_thresh
            else:
                if self.cov_type == "hybrid":
                    if np.max(np.abs(std)) < 1:
                        covinv[np.diag_indices(4)]= 1/std
                
                elif self.cov_type == "compute":
                    covinv[np.diag_indices(4)]= 1/std
                else:
                    raise RuntimeError("unknown robust-cov setting")
        
        if self.npol == 2:
            covinv[(1,2), (1,2)] = 0


        self.covinv = covinv
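
The final inversion reduces to inverting the per-correlation variances on the diagonal and, for two polarisations, zeroing the cross-hand terms. A toy sketch; the (xx, xy, yx, yy) ordering is inferred from the npol == 2 handling above.

import numpy as np

std = np.array([0.05, 0.01, 0.01, 0.04]) + 1e-8   # xx, xy, yx, yy variances (+eps)
npol = 2

covinv = np.eye(4, dtype=np.complex128)
covinv[np.diag_indices(4)] = 1.0/std

if npol == 2:
    covinv[(1, 2), (1, 2)] = 0   # drop the cross-hand correlations

print(np.diagonal(covinv).real)
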