Beispiel #1
0
    def _update_equation_counts(self, unflagged):
        """Set up equation counters from flagging information.

        Extends the base-class version. After calling it, this computes:

        * per-interval equation counts and the number of unknowns;
        * which solution intervals are valid (more equations than unknowns);
        * rescaled chi-squared normalisation factors, converting the base class's 1/N_eq
          into the reduced chi-square form 1/(N_eq - N_dof).

        Args:
            unflagged (np.ndarray):
                Mask of unflagged data points, passed through to the base class.
        """
        MasterMachine._update_equation_counts(self, unflagged)

        eqs = self.eqs_per_interval = self.interval_sum(self.eqs_per_tf_slot)

        # For a direction-dependent term, only the non-fixed directions carry unknowns.
        if self.dd_term:
            n_solved_dirs = self.n_dir - len(self.fix_directions)
        else:
            n_solved_dirs = 1
        self.num_unknowns = self.dof_per_antenna * self.n_ant * n_solved_dirs

        # An interval is usable only when it has more equations than unknowns.
        valid = self.valid_intervals = eqs > self.num_unknowns
        n_valid = self.num_valid_intervals = valid.sum()
        self.n_valid_sols = n_valid * self.n_dir

        if not n_valid:
            return

        # Per-interval correction factor N_eq / (N_eq - N_dof). It turns the base-class
        # normalisation 1/N_eq into 1/(N_eq - N_dof). Invalid intervals get a factor of 0.
        with np.errstate(invalid='ignore', divide='ignore'):
            correction = eqs.astype(float) / (eqs - self.num_unknowns)
        correction[~valid] = 0

        self._chisq_tf_norm_factor *= self.unpack_intervals(correction)
        self._chisq_norm_factor *= correction.sum() / n_valid
Beispiel #2
0
 def precompute_attributes(self, model_arr, flags_arr, inv_var_chan):
     """Precompute per-solution statistics for the chain and its solvable terms.

     The chain's own base-class attributes are set up first. The same model, flags and
     inverse-variance arrays are then handed to every term marked solvable; other terms
     are skipped.
     """
     MasterMachine.precompute_attributes(self, model_arr, flags_arr, inv_var_chan)
     # Lazy generator: each term's `solvable` is read just before that term is processed.
     for jt in (t for t in self.jones_terms if t.solvable):
         jt.precompute_attributes(model_arr, flags_arr, inv_var_chan)
Beispiel #3
0
    def next_iteration(self):
        """
        Advance the iteration count of the active element of the Jones chain.

        If the active term reports convergence or stalling:

        * its final status string is logged and appended to the convergence states;
        * the chain moves on to its next term (via ``_next_chain_term``).

        The (possibly new) active term's iteration counter is then advanced.

        Returns:
            tuple: (first element of MasterMachine.next_iteration's result, major_step).
            major_step is True when the active term was switched during this call.
        """
        self.last_active_index = self.active_index
        major_step = False

        current = self.active_term
        if current.has_converged or current.has_stalled:
            message = "term {} {} ({} iters): {}".format(
                current.jones_label,
                "converged" if current.has_converged else "stalled",
                current.iters,
                current.final_convergence_status_string)
            print(message, file=log(1))
            self._convergence_states.append(current.final_convergence_status_string)
            self._convergence_states_finalized = True
            self._next_chain_term()
            major_step = True

        # Re-read active_term: _next_chain_term() above may have changed it.
        self.active_term.next_iteration()

        base_result = MasterMachine.next_iteration(self)
        return base_result[0], major_step
Beispiel #4
0
    def __init__(self, label, data_arr, ndir, nmod, times, frequencies,
                 chunk_label, jones_options):
        """
        Initialise a chain of complex 2x2 gain machines.

        Each element of the chain is a gain machine in its own right. This machine owns
        them and coordinates the bookkeeping between their parameter updates.

        Args:
            label (str):
                Label identifying the Jones term.
            data_arr (np.ndarray):
                Array of observed visibilities. Its shape is unpacked below as
                (n_mod, n_tim, n_fre, n_ant, n_ant, n_cor, n_cor).
            ndir (int):
                Number of directions.
            nmod (int):
                Number of models.
            times (np.ndarray):
                Times for the data being processed.
            frequencies (np.ndarray):
                Frequencies for the data being processed.
            chunk_label:
                Label of the data chunk; passed to every term and to the base class.
            jones_options (dict):
                Options for the chain. Reads the 'chain' list of per-term option dicts
                and jones_options['sol']['term-iters'].

        Raises:
            UserInputError:
                If a term type is unknown, if a term that is neither complex-2x2 nor
                robust-2x2 is marked solvable, or if term-iters has an unsupported type.
        """
        from cubical.main import UserInputError

        self.jones_terms = []
        # Intended as the number of DI terms at the left of the chain.
        # NOTE(review): the loop stores the *index* of the last such term (0 when only
        # term 0 is DI), not a count. Confirm how consumers use it before changing it.
        self.num_left_di_terms = 0
        dd_encountered = False

        for index, opts in enumerate(jones_options['chain']):
            cls = machine_types.get_machine_class(opts['type'])
            if cls is None:
                raise UserInputError("unknown Jones class '{}'".format(
                    opts['type']))
            is_2x2 = issubclass(cls, Complex2x2Gains) or issubclass(
                cls, ComplexW2x2Gains)
            if not is_2x2 and opts['solvable']:
                raise UserInputError(
                    "only complex-2x2 or robust-2x2 terms can be made solvable in a Jones chain"
                )
            member = cls(opts["label"], data_arr, ndir, nmod, times,
                         frequencies, chunk_label, opts)
            self.jones_terms.append(member)
            if member.dd_term:
                dd_encountered = True
            elif not dd_encountered:
                self.num_left_di_terms = index

        # Base-class initialisation runs after all chain terms have been built.
        MasterMachine.__init__(self, label, data_arr, ndir, nmod, times,
                               frequencies, chunk_label, jones_options)

        self.chain = cubical.kernels.import_kernel("chain")
        # Kernel used for compute_residuals and similar operations.
        self.kernel = Complex2x2Gains.get_full_kernel(
            jones_options, diag_gains=self.is_diagonal)

        self.n_dir, self.n_mod = ndir, nmod
        (_, self.n_tim, self.n_fre, self.n_ant, self.n_ant,
         self.n_cor, self.n_cor) = data_arr.shape

        self.n_terms = len(self.jones_terms)

        # Iteration budget per solvable term, defaulting to each solvable term's maxiter.
        # A new list is always built, because this list is updated as terms converge.
        requested = jones_options['sol']['term-iters']
        if not requested:
            self.term_iters = [
                jt.maxiter for jt in self.jones_terms if jt.solvable
            ]
        elif type(requested) is int:
            self.term_iters = [requested]
        elif isinstance(requested, (list, tuple)):
            self.term_iters = list(requested)
        else:
            raise UserInputError(
                "invalid term-iters={} setting".format(requested))

        self.solvable = bool(self.term_iters) and any(
            [jt.solvable for jt in self.jones_terms])

        # No active term has been selected yet.
        self.active_index = None

        # Per-term convergence status strings, accumulated as terms finish.
        self._convergence_states = []
        # True once the last active term has had its convergence status queried.
        self._convergence_states_finalized = False

        self.cached_model_arr = self._r = self._m = None
Beispiel #5
0
 def next_iteration(self):
     """Snapshot the current gains into old_gains, then defer to the base class.

     Returns:
         The value returned by MasterMachine.next_iteration.
     """
     # In-place copy into the existing old_gains buffer.
     np.copyto(self.old_gains, self.gains)
     result = MasterMachine.next_iteration(self)
     return result
Beispiel #6
0
    def precompute_attributes(self, model_arr, flags_arr, inv_var_chan):
        """Precomputes various stats before starting a solution.

        Steps:

        1. Pre-flag gain intervals whose input data are entirely MISSING|PRIOR.
        2. Estimate a prior gain error per direction/interval/antenna from the model power
           and the per-channel noise.
        3. If ``max_gain_error`` is set (truthy), flag gains whose prior error exceeds it.
           Intervals where every direction is flagged are also flagged in ``flags_arr``,
           and the equation counts are recomputed.

        Args:
            model_arr (np.ndarray): model visibilities; direction is axis 0.
            flags_arr (np.ndarray): data flags; updated in place with FL.LOWSNR.
            inv_var_chan (np.ndarray): per-channel inverse variance.

        Returns:
            The unflagged mask from the base class, with any low-SNR intervals cleared.
        """
        unflagged = MasterMachine.precompute_attributes(
            self, model_arr, flags_arr, inv_var_chan)

        # Pre-flag gain solution intervals that are completely flagged in the input data
        # (i.e. MISSING|PRIOR). This has shape (n_timint, n_freint, n_ant).

        missing_intervals = self.interval_and(
            ((flags_arr & (FL.MISSING | FL.PRIOR)) != 0).all(axis=-1))

        self.missing_gain_fraction = missing_intervals.sum() / float(
            missing_intervals.size)

        # convert the intervals array to gain shape, and apply flags
        self.gflags[:,
                    self._interval_to_gainres(missing_intervals)] = FL.MISSING

        # number of data points per time/frequency/antenna
        numeq_tfa = unflagged.sum(axis=-1)

        # compute error estimates per direction, antenna, and interval
        with np.errstate(invalid='ignore', divide='ignore'):
            sigmasq = 1 / inv_var_chan  # squared noise per channel. Could be infinite if no data
            # collapse direction axis, if not directional
            if not self.dd_term:
                model_arr = model_arr.sum(axis=0, keepdims=True)
            # mean |model|^2 per direction+TFA
            modelsq = (model_arr*np.conj(model_arr)).real.sum(axis=(1,-1,-2,-3)) / \
                      (self.n_mod*self.n_cor*self.n_cor*numeq_tfa)
            modelsq[:, numeq_tfa == 0] = 0
            # inverse SNR^2 per direction+TFA
            inv_snr2 = sigmasq[np.newaxis, np.newaxis, :, np.newaxis] / modelsq
            inv_snr2[:, numeq_tfa == 0] = 0
            # take the mean SNR^-2 over each interval
            # numeq_tfa becomes number of points per interval, antenna
            numeq_tfa = self.interval_sum(numeq_tfa)
            inv_snr2_int = self.interval_sum(inv_snr2,
                                             1) / numeq_tfa[np.newaxis, ...]
            inv_snr2_int[:, numeq_tfa == 0] = 0
            # convert that into a gain error per direction,interval,antenna
            self.prior_gain_error = np.sqrt(
                inv_snr2_int /
                (self.eqs_per_interval - self.num_unknowns)[np.newaxis, :, :,
                                                            np.newaxis])

        self.prior_gain_error[:, ~self.valid_intervals, :] = 0
        # reset to 0 for fixed directions
        if self.dd_term:
            self.prior_gain_error[self.fix_directions, ...] = 0

        # Flag gains on max error. A falsy max_gain_error (0 or None) disables the check.
        # Without the guard, comparing against 0 would flag every gain with a non-zero
        # error estimate, and comparing against None raises TypeError.
        self._n_flagged_on_max_error = None
        if self.max_gain_error:
            bad_gain_intervals = self.prior_gain_error > self.max_gain_error  # dir,time,freq,ant
            if bad_gain_intervals.any():
                # (n_dir,) array showing how many were flagged per direction
                self._n_flagged_on_max_error = bad_gain_intervals.sum(axis=(1, 2, 3))
                # raise corresponding gain flags
                self.gflags[self._interval_to_gainres(bad_gain_intervals,
                                                      1)] |= FL.LOWSNR
                self.prior_gain_error[bad_gain_intervals] = 0
                # flag intervals where all directions are bad, and propagate that out into flags
                bad_intervals = bad_gain_intervals.all(axis=0)
                if bad_intervals.any():
                    bad_slots = self.unpack_intervals(bad_intervals)
                    flags_arr[bad_slots, ...] |= FL.LOWSNR
                    unflagged[bad_slots, ...] = False
                    self._update_equation_counts(unflagged)

        self._n_flagged_on_max_posterior_error = None
        self.flagged = self.gflags != 0
        self.n_flagged = self.flagged.sum()

        return unflagged
Beispiel #7
0
    def __init__(self, label, data_arr, ndir, nmod, times, frequencies,
                 chunk_label, options, cykernel):
        """
        Initialises a gain machine which supports solution intervals.

        Args:
            label (str):
                Label identifying the Jones term.
            data_arr (np.ndarray):
                Shape (n_mod, n_tim, n_fre, n_ant, n_ant, n_cor, n_cor) array containing observed
                visibilities.
            ndir (int):
                Number of directions.
            nmod (int):
                Number of models.
            times (np.ndarray):
                Times for the data being processed.
            frequencies (np.ndarray):
                Frequencies for the data being processed.
            chunk_label:
                Label of the data chunk; passed through to the base class.
            options (dict):
                Dictionary of options. Keys read here:
                "time-int", "freq-int", "conv-quorum", "update-type", "ref-ant",
                "fix-dirs", "max-prior-error", "max-post-error", "clip-low",
                "clip-high" and "clip-after".
            cykernel:
                Kernel object for this machine; stored as ``self.cykernel``.
        """

        MasterMachine.__init__(self, label, data_arr, ndir, nmod, times,
                               frequencies, chunk_label, options)

        self.cykernel = cykernel

        self.t_int = options["time-int"] or self.n_tim
        self.f_int = options["freq-int"] or self.n_fre
        self.eps = 1e-6

        # Initialise attributes used for computing values over intervals.
        # n_tim and n_fre are the time and frequency dimensions of the data arrays.
        # n_timint and n_freint are the time and frequency dimensions of the gains.
        # Bin start indices are materialised as lists, not Python 3 range objects, so they
        # support list operations (as range() did on Python 2).

        self.t_bins = list(range(0, self.n_tim, self.t_int))
        self.f_bins = list(range(0, self.n_fre, self.f_int))

        self.n_timint = len(self.t_bins)
        self.n_freint = len(self.f_bins)
        self.n_tf_ints = self.n_timint * self.n_freint

        # number of valid solutions
        self.n_valid_sols = self.n_dir * self.n_tf_ints

        # split grids into intervals, and find the centre of gravity of each
        timebins = np.split(times, self.t_bins[1:])
        freqbins = np.split(frequencies, self.f_bins[1:])
        timegrid = np.array([float(x.mean()) for x in timebins])
        freqgrid = np.array([float(x.mean()) for x in freqbins])

        # interval_grid determines the per-interval grid points
        self.interval_grid = dict(time=timegrid, freq=freqgrid)
        # data_grid determines the full resolution grid
        self.data_grid = dict(time=times, freq=frequencies)

        # compute index from each data point to interval number
        t_ind = np.arange(self.n_tim) // self.t_int
        f_ind = np.arange(self.n_fre) // self.f_int

        self.t_mapping, self.f_mapping = np.meshgrid(t_ind,
                                                     f_ind,
                                                     indexing="ij")

        # Initialise attributes used in convergence testing. n_cnvgd is the number
        # of solutions which have converged.

        self._has_stalled = False
        self.n_cnvgd = 0
        self._frac_cnvgd = 0
        self.iters = 0
        self.min_quorum = options["conv-quorum"]
        self.update_type = options["update-type"]
        self.ref_ant = options["ref-ant"]
        self.fix_directions = options["fix-dirs"] or []
        if type(self.fix_directions) is int:
            self.fix_directions = [self.fix_directions]
        # True if gains are loaded from a DB
        self._gains_loaded = False

        # Construct flag array and populate flagging attributes.
        self.max_gain_error = options["max-prior-error"]
        self.max_post_error = options["max-post-error"]

        self.clip_lower = options["clip-low"]
        self.clip_upper = options["clip-high"]
        self.clip_after = options["clip-after"]

        self.init_gains()
        self.old_gains = self.gains.copy()

        # Gain error estimates. Populated by subclasses, if available
        # Should be array of same shape as the gains
        self.prior_gain_error = None
        self.posterior_gain_error = None

        # buffers for arrays used in internal updates
        self._jh = self._jhr = self._jhj = self._gh = self._r = self._ginv = self._ghinv = None
        self._update = None

        # flag: have gains been updated
        self._gh_update = self._ghinv_update = True
    def precompute_attributes(self, data_arr, model_arr, flags_arr, inv_var_chan):
        """Precomputes various attributes of the machine before starting a solution.

        Steps:

        1. Pre-flag gain intervals whose input data are entirely flagged (any flag bit).
        2. When per-channel noise (``inv_var_chan``) is given:

           * estimate a prior gain error per direction/interval/antenna;
           * warn about fully flagged directions, invalid models and low-SNR conditions;
           * flag gains whose error estimate is NaN/inf or exceeds ``max_gain_error``
             (the latter only when ``max_gain_error`` is truthy).

           Intervals where every direction is bad are flagged in ``flags_arr``, and the
           equation counts are updated.

        Args:
            data_arr (np.ndarray): observed data; passed through to the base class.
            model_arr (np.ndarray): model visibilities; direction is axis 0.
            flags_arr (np.ndarray): data flags; updated in place with FL.LOWSNR.
            inv_var_chan (np.ndarray or None): per-channel inverse variance. If None,
                the prior-error estimate and max-error flagging are skipped.

        Returns:
            The unflagged mask from the base class, with any low-SNR intervals cleared.
        """
        unflagged = MasterMachine.precompute_attributes(self, data_arr, model_arr, flags_arr, inv_var_chan)

        # Pre-flag gain solution intervals that are completely flagged in the input data.
        # Any input flag counts, not only MISSING|PRIOR, so that flags such as SKIPSOL
        # and NULLDATA are honoured too. This has shape (n_timint, n_freint, n_ant).
        missing_intervals = self.interval_and((flags_arr!=0).all(axis=-1))

        self.missing_gain_fraction = missing_intervals.sum() / float(missing_intervals.size)

        # convert the intervals array to gain shape, and apply flags
        self.gflags[:, self._interval_to_gainres(missing_intervals)] = FL.MISSING

        # number of data points per time/frequency/antenna
        numeq_tfa = unflagged.sum(axis=-1)

        # Always defined, even when inv_var_chan is None and the max-error flagging below
        # is skipped.
        self._n_flagged_on_max_error = None

        # compute error estimates per direction, antenna, and interval
        if inv_var_chan is not None:
            with np.errstate(invalid='ignore', divide='ignore'):
                # collapse direction axis, if not directional
                if not self.dd_term:
                    model_arr = model_arr.sum(axis=0, keepdims=True)
                # mean |model|^2 per direction+TFA
                modelsq = (model_arr*np.conj(model_arr)).real.sum(axis=(1,-1,-2,-3)) / \
                          (self.n_mod*self.n_cor*self.n_cor*numeq_tfa)
                modelsq[:, numeq_tfa==0] = 0

                sigmasq = 1.0/inv_var_chan                        # squared noise per channel. Could be infinite if no data
                # take the sigma (in quadrature) over each interval
                # divided by quadrature unflagged contributing interferometers per interval
                # this yields var<g>
                # (numeq_tfa becomes number of unflagged points per interval, antenna)
                numeq_tfa = self.interval_sum(numeq_tfa)
                sigmasq[np.logical_or(np.isnan(sigmasq), np.isinf(sigmasq))] = 0.0
                modelsq[np.logical_or(np.isnan(modelsq), np.isinf(modelsq))] = 0.0
                # per-interval model power, computed once and reused below
                modelsq_int = self.interval_sum(modelsq, 1)
                NSR_int = self.interval_sum(np.ones_like(modelsq) * (sigmasq)[None, None, :, None], 1) / \
                              (modelsq_int * numeq_tfa)
                # convert that into a gain error per direction,interval,antenna
                self.prior_gain_error = np.sqrt(NSR_int)
                # Reset to 0 for fixed directions. This is done before the NaN/inf check, so
                # fixed directions are never marked invalid.
                if self.dd_term:
                    self.prior_gain_error[self.fix_directions, ...] = 0

                pge_flag_invalid = np.logical_or(np.isnan(self.prior_gain_error),
                                                 np.isinf(self.prior_gain_error))

                invalid_models = np.logical_or(modelsq_int == 0,
                                               np.logical_or(np.isnan(modelsq_int),
                                                             np.isinf(modelsq_int)))
                if np.any(np.all(numeq_tfa == 0, axis=-1)) and log.verbosity() > 1:
                    self.raise_userwarning(
                        logging.CRITICAL,
                        "One or more directions (or its frequency intervals) are already fully flagged.",
                        90, raise_once="prior_fully_flagged_dirs", verbosity=2, color="red")

                if np.any(np.all(invalid_models, axis=-1)) and log.verbosity() > 1:
                    self.raise_userwarning(
                        logging.CRITICAL,
                        "One or more directions (or its frequency intervals) have invalid or 0 models.",
                        90, raise_once="invalid_models", verbosity=2, color="red")

            # Zero the error for invalid intervals. Fixed directions were already zeroed above.
            self.prior_gain_error[:, ~self.valid_intervals, :] = 0

            # flag gains on max error
            bad_gain_intervals = pge_flag_invalid
            if self.max_gain_error:
                low_snr = self.prior_gain_error > self.max_gain_error
                if low_snr.all(axis=0).all():
                    msg = "'{0:s}' {1:s} All directions flagged, either due to low SNR. "\
                          "You need to check your tagged directions and your max-prior-error and/or solution intervals. "\
                          "New flags will be raised for this chunk of data".format(
                                self.jones_label, self.chunk_label)
                    self.raise_userwarning(logging.CRITICAL, msg, 70, verbosity=log.verbosity(), color="red")

                else:
                    if low_snr.all(axis=-1).all(axis=-1).all(axis=-1).any(): #all antennas fully flagged of some direction
                        dir_snr = {}
                        for d in range(self.prior_gain_error.shape[0]):
                            percflagged = np.sum(low_snr[d]) * 100.0 / low_snr[d].size
                            if percflagged > self.low_snr_warn and d not in self.fix_directions: dir_snr[d] = percflagged

                        if len(dir_snr) > 0:
                            if log.verbosity() > 2:
                                msg = "Low SNR in one or more directions of gain '{0:s}' chunk '{1:s}':".format(
                                        self.jones_label, self.chunk_label) +\
                                      "\n{0:s}\n".format("\n".join(["\t direction {0:s}: {1:.3f}% gains affected".format(
                                                            str(d), dir_snr[d]) for d in sorted(dir_snr)])) +\
                                      "Check your settings for gain solution intervals and max-prior-error. "
                            else:
                                msg = "'{0:s}' {1:s} Low SNR in directions {2:s}. Increase solution intervals or raise max-prior-error!".format(
                                    self.jones_label, self.chunk_label, ", ".join(map(str, sorted(dir_snr))))
                            self.raise_userwarning(logging.CRITICAL, msg, 50, verbosity=log.verbosity(), color="red")

                    if low_snr.all(axis=0).all(axis=0).all(axis=-1).any():
                        msg = "'{0:s}' {1:s} All time of one or more frequency intervals flagged due to low SNR. "\
                              "You need to check your max-prior-error and/or solution intervals. "\
                              "New flags will be raised for this chunk of data".format(
                                    self.jones_label, self.chunk_label)
                        self.raise_userwarning(logging.WARNING, msg, 70, verbosity=log.verbosity())

                    if low_snr.all(axis=0).all(axis=1).all(axis=-1).any():
                        msg = "'{0:s}' {1:s} All channels of one or more time intervals flagged due to low SNR. "\
                              "You need to check your max-prior-error and/or solution intervals. "\
                              "New flags will be raised for this chunk of data".format(
                                    self.jones_label, self.chunk_label)
                        self.raise_userwarning(logging.WARNING, msg, 70, verbosity=log.verbosity())
                    stationflags = np.argwhere(low_snr.all(axis=0).all(axis=0).all(axis=0)).flatten()
                    if stationflags.size > 0:
                        msg = "'{0:s}' {1:s} Stations {2:s} ({3:d}/{4:d}) fully flagged due to low SNR. "\
                              "These stations may be faulty or your SNR requirements (max-prior-error) are not met. "\
                              "New flags will be raised for this chunk of data".format(
                                    self.jones_label, self.chunk_label, ", ".join(map(str, stationflags)),
                                    np.sum(low_snr.all(axis=0).all(axis=0).all(axis=0)), low_snr.shape[3])
                        self.raise_userwarning(logging.WARNING, msg, 70, verbosity=log.verbosity())

                bad_gain_intervals = np.logical_or(bad_gain_intervals,
                                                   low_snr)    # dir,time,freq,ant

            if bad_gain_intervals.any():
                # (n_dir,) array showing how many were flagged per direction
                self._n_flagged_on_max_error = bad_gain_intervals.sum(axis=(1,2,3))
                # raise corresponding gain flags
                self.gflags[self._interval_to_gainres(bad_gain_intervals,1)] |= FL.LOWSNR
                self.prior_gain_error[bad_gain_intervals] = 0
                # flag intervals where all directions are bad, and propagate that out into flags
                bad_intervals = bad_gain_intervals.all(axis=0)
                if bad_intervals.any():
                    bad_slots = self.unpack_intervals(bad_intervals)
                    flags_arr[bad_slots,...] |= FL.LOWSNR
                    unflagged[bad_slots,...] = False
                    self.update_equation_counts(unflagged)

        self._n_flagged_on_max_posterior_error = None
        self.flagged = self.gflags != 0
        self.n_flagged = self.flagged.sum()

        return unflagged
    def __init__(self, label, data_arr, ndir, nmod, times, frequencies, chunk_label, options):
        """
        Initialises a gain machine which supports solution intervals.

        Args:
            label (str):
                Label identifying the Jones term.
            data_arr (np.ndarray):
                Shape (n_mod, n_tim, n_fre, n_ant, n_ant, n_cor, n_cor) array containing observed
                visibilities.
            ndir (int):
                Number of directions.
            nmod (int):
                Number of models.
            times (np.ndarray):
                Times for the data being processed.
            frequencies (np.ndarray):
                Frequencies for the data being processed.
            chunk_label:
                Label of the data chunk; passed through to the base class.
            options (dict):
                Dictionary of options. "fix-dirs" may be:

                * None or "" (no fixed directions);
                * an int;
                * a list of ints;
                * a string of comma-separated ints such as "0,1" or "[0, 1]".

                Whether gains are diagonal is taken from ``self.is_diagonal``, not from
                options.

        Raises:
            ArgumentError:
                If "fix-dirs" cannot be interpreted as a list of ints.
        """

        MasterMachine.__init__(self, label, data_arr, ndir, nmod, times, frequencies,
                               chunk_label, options)

        # select which kernels to use for computing full data
        self.kernel = self.get_full_kernel(options, self.is_diagonal)

        # kernel used in solver is diag-diag in diag mode, else uses full kernel version
        if options.get('diag-data') or options.get('diag-only'):
            self.kernel_solve = cubical.kernels.import_kernel('diagdiag_complex')
        else:
            self.kernel_solve = self.kernel

        log(2).print("{} kernels are {} {}".format(label, self.kernel, self.kernel_solve))

        self.t_int = options["time-int"] or self.n_tim
        self.f_int = options["freq-int"] or self.n_fre
        self.eps = 1e-6

        # Initialise attributes used for computing values over intervals.
        # n_tim and n_fre are the time and frequency dimensions of the data arrays.
        # n_timint and n_freint are the time and frequency dimensions of the gains.

        self.t_bins = list(range(0, self.n_tim, self.t_int))
        self.f_bins = list(range(0, self.n_fre, self.f_int))

        self.n_timint = len(self.t_bins)
        self.n_freint = len(self.f_bins)
        self.n_tf_ints = self.n_timint * self.n_freint

        # number of valid solutions
        self.n_valid_sols = self.n_dir * self.n_tf_ints

        # split grids into intervals, and find the centre of gravity of each
        timebins = np.split(times, self.t_bins[1:])
        freqbins = np.split(frequencies, self.f_bins[1:])
        timegrid = np.array([float(x.mean()) for x in timebins])
        freqgrid = np.array([float(x.mean()) for x in freqbins])

        # interval_grid determines the per-interval grid points
        self.interval_grid = dict(time=timegrid, freq=freqgrid)
        # data_grid determines the full resolution grid
        self.data_grid = dict(time=times, freq=frequencies)

        # compute index from each data point to interval number
        t_ind = np.arange(self.n_tim)//self.t_int
        f_ind = np.arange(self.n_fre)//self.f_int

        self.t_mapping, self.f_mapping = np.meshgrid(t_ind, f_ind, indexing="ij")

        # Initialise attributes used in convergence testing. n_cnvgd is the number
        # of solutions which have converged.

        self._has_stalled = False
        self.n_cnvgd = 0
        self._frac_cnvgd = 0
        self.iters = 0
        self.min_quorum = options["conv-quorum"]
        self.update_type = options["update-type"]
        self.ref_ant = options["ref-ant"]

        # Normalise "fix-dirs" into a list of ints.
        fix_dirs = options["fix-dirs"]
        self.fix_directions = fix_dirs if fix_dirs is not None and fix_dirs != "" else []

        if type(self.fix_directions) is int:
            self.fix_directions = [self.fix_directions]
        if type(self.fix_directions) is str and re.match(r"^\W*\d{1,}(\W*,\W*\d{1,})*\W*$", self.fix_directions):
            # The regex admits only digit groups separated by commas, with optional non-word
            # padding such as spaces or brackets, so the digit runs are exactly the
            # direction indices. An explicit list is built because the check below
            # requires one.
            self.fix_directions = [int(d) for d in re.findall(r"\d+", self.fix_directions)]

        # NOTE(review): ArgumentError is not a Python builtin. Confirm it is imported or
        # defined at module level; otherwise this raise itself fails with NameError.
        if not (type(self.fix_directions) is list and
                all(type(x) is int for x in self.fix_directions)):
            raise ArgumentError("Fix directions must be number or list of numbers")

        # True if gains are loaded from a DB
        self._gains_loaded = False

        # Construct flag array and populate flagging attributes.
        self.max_gain_error = options["max-prior-error"]
        self.max_post_error = options["max-post-error"]
        self.low_snr_warn = options["low-snr-warn"]
        self.high_gain_var_warn = options["high-gain-var-warn"]
        self.clip_lower = options["clip-low"]
        self.clip_upper = options["clip-high"]
        self.clip_after = options["clip-after"]

        self.init_gains()
        self.old_gains = self.gains.copy()

        # Gain error estimates. Populated by subclasses, if available
        # Should be array of same shape as the gains
        self.prior_gain_error = None
        self.posterior_gain_error = None

        # buffers for arrays used in internal updates
        self._jh = self._jhr = self._jhj = self._gh = self._r = self._ginv = self._ghinv = None
        self._update = None

        # flag: have gains been updated
        self._gh_update = self._ghinv_update = True