Example 1
    def _translate_single_opinion(self, opinion, timestamp):

        flag = None
        if opinion.decision == "bad":
            # Translate the opinion's LSD into the flag's start and finish times.
            start_time = csd_to_unix(opinion.lsd)
            finish_time = csd_to_unix(opinion.lsd + 1)

            flag = DataFlag.create_flag(
                "vote",
                start_time,
                finish_time,
                opinion.freq,
                opinion.instrument,
                opinion.inputs,
            )
        client, _ = DataFlagClient.get_or_create(client_name=__name__,
                                                 client_version=__version__)
        vote = DataFlagVote.create(
            time=timestamp,
            mode=self.mode,
            client=client,
            flag=flag,
            revision=self.revision,
            lsd=opinion.lsd,
        )
        DataFlagVoteOpinion.create(vote=vote, opinion=opinion)
        return flag
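
A note on the pattern above: csd_to_unix(lsd) gives the UNIX time at which
sidereal day lsd begins, so csd_to_unix(lsd + 1) marks the end of that same
day. A minimal sketch, assuming the ch_util-style csd_to_unix used throughout
these examples (the day number 3000 is illustrative):

    # Sketch: compute the UNIX span of one whole (CHIME) sidereal day.
    from ch_util.ephemeris import csd_to_unix

    lsd = 3000
    start_time = csd_to_unix(lsd)       # UNIX time at the start of the day
    finish_time = csd_to_unix(lsd + 1)  # one sidereal day (~86164 s) later
    assert 86000.0 < finish_time - start_time < 86400.0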
Example 2
    def is_daytime(self, key, csd):
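        """Return 2 if the solar transit falls within the source's transit
        window, 1 if the window merely overlaps daytime, and 0 at night."""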

        src = self[key] if isinstance(key, str) else key

        is_daytime = 0

        src_ra, src_dec = ephemeris.object_coords(src.body, date=ephemeris.csd_to_unix(csd), deg=True)

        transit_start = ephemeris.csd_to_unix(csd + (src_ra - src.window) / 360.0)
        transit_end = ephemeris.csd_to_unix(csd + (src_ra + src.window) / 360.0)

        solar_rise = ephemeris.solar_rising(transit_start - 24.0*3600.0, end_time=transit_end)

        for rr in solar_rise:

            ss = ephemeris.solar_setting(rr)[0]

            rrex = rr + self._extend_night
            ssex = ss - self._extend_night

            if ((transit_start <= ssex) and (rrex <= transit_end)):

                is_daytime += 1

                tt = ephemeris.solar_transit(rr)[0]
                if (transit_start <= tt) and (tt <= transit_end):
                    is_daytime += 1

                break

        return is_daytime
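
The arithmetic above relies on the fractional part of a CSD mapping linearly
onto right ascension: csd + ra / 360.0 is the moment at which that RA transits
on day csd. A minimal sketch under the same assumption (the source RA and
window half-width are illustrative):

    from ch_util.ephemeris import csd_to_unix

    csd, src_ra, window = 3000, 83.6, 5.0
    transit_start = csd_to_unix(csd + (src_ra - window) / 360.0)
    transit_end = csd_to_unix(csd + (src_ra + window) / 360.0)
    # 10 degrees of RA is 10/360 of a sidereal day, roughly 2394 seconds
    assert 2300.0 < transit_end - transit_start < 2500.0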
Example 3
    def _csds_available_data(self, csds, filenames_online,
                             filenames_that_exist):
        """
        Return the subset of csds in `csds` for whom all files are online.

        `filenames_online` and `filenames_that_exist` are a list of tuples
        (start_time, finish_time)

        All 3 lists should be sorted.
        """
        csds_available = []

        for csd in csds:
            start_time = ephemeris.csd_to_unix(csd)
            end_time = ephemeris.csd_to_unix(csd + 1)

            # online - list of filenames that are online between start_time and end_time
            # index_online, the final index in which data was located
            online, index_online = self._files_in_timespan(
                start_time, end_time, filenames_online)
            exists, index_exists = self._files_in_timespan(
                start_time, end_time, filenames_that_exist)

            if (len(online) == len(exists)) and (len(online) != 0):
                csds_available.append(csd)

            # The final file in the span may contain more than one sidereal day
            index_online = max(index_online - 1, 0)
            index_exists = max(index_exists - 1, 0)

            filenames_online = filenames_online[index_online:]
            filenames_that_exist = filenames_that_exist[index_exists:]

        return csds_available
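
_files_in_timespan is not shown on this page. Judging from how it is called
above, a plausible minimal implementation (the name, arguments, and return
convention are inferred from the caller, not taken from the source) could be:

    def _files_in_timespan(start_time, end_time, filenames):
        # Hypothetical sketch: return the (start, finish) tuples overlapping
        # [start_time, end_time], plus the index where the scan stopped.
        # Assumes `filenames` is sorted by start time, as the docstring requires.
        overlapping = []
        index = 0
        for index, (start, finish) in enumerate(filenames):
            if start >= end_time:  # sorted, so no later file can overlap
                break
            if finish > start_time:
                overlapping.append((start, finish))
        return overlapping, index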
Example 4
 def _flags_mask(self, index_map_ra):
     if self._cache_flags:
         flag_time_spans = get_flags_cached(self.flags,
                                            self._cache_reset_time)
     else:
         flag_time_spans = get_flags(
             self.flags,
             csd_to_unix(self.lsd.lsd),
             csd_to_unix(self.lsd.lsd + 1),
         )
     csd_arr = self.lsd.lsd + index_map_ra / 360.0
     flag_mask = np.zeros_like(csd_arr, dtype=bool)
     for type_, ca, cb in flag_time_spans:
         flag_mask[(csd_arr > unix_to_csd(ca))
                   & (csd_arr < unix_to_csd(cb))] = True
     return flag_mask[:, np.newaxis]
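
A toy illustration of the mask construction above: the RA axis is mapped to
fractional CSDs, and samples falling inside a flagged UNIX-time span are set
to True. The flag span below is illustrative; unix_to_csd and csd_to_unix are
the same ch_util-style converters used in the method:

    import numpy as np
    from ch_util.ephemeris import csd_to_unix, unix_to_csd

    lsd = 3000
    ra = np.linspace(0.0, 360.0, 8, endpoint=False)
    csd_arr = lsd + ra / 360.0
    # Flag the span covering RA 90-180 degrees on this day
    ca, cb = csd_to_unix(lsd + 0.25), csd_to_unix(lsd + 0.5)
    flag_mask = (csd_arr > unix_to_csd(ca)) & (csd_arr < unix_to_csd(cb))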
Example 5
    def process(self, sstream):
        """Calculate the mean(median) over the sidereal day.

        Parameters
        ----------
        sstream : andata.CorrData or containers.SiderealStream
            Timestream or sidereal stream.

        Returns
        -------
        mustream : same as sstream
            Sidereal stream containing only the mean (or median) value.
        """
        from .flagging import daytime_flag, transit_flag

        # Make sure we are distributed over frequency
        sstream.redistribute("freq")

        # Extract lsd
        lsd = sstream.attrs[
            "lsd"] if "lsd" in sstream.attrs else sstream.attrs["csd"]
        lsd_list = lsd if hasattr(lsd, "__iter__") else [lsd]

        # Calculate the right ascension, method differs depending on input container
        if "ra" in sstream.index_map:
            ra = sstream.ra
            timestamp = {
                dd: ephemeris.csd_to_unix(dd + ra / 360.0)
                for dd in lsd_list
            }
            flag_quiet = np.ones(ra.size, dtype=bool)

        elif "time" in sstream.index_map:

            ra = ephemeris.lsa(sstream.time)
            timestamp = {lsd: sstream.time}
            flag_quiet = np.fix(ephemeris.unix_to_csd(sstream.time)) == lsd

        else:
            raise RuntimeError("Format of `sstream` argument is unknown.")

        # If requested, determine "quiet" region of sky.
        # In the case of a SiderealStack, there will be multiple LSDs and the
        # mask will be the logical AND of the mask from each individual LSDs.
        if self.mask_day:
            for dd, time_dd in timestamp.items():
                # Mask daytime data
                flag_quiet &= ~daytime_flag(time_dd)

        if self.mask_sources:
            for dd, time_dd in timestamp.items():
                # Mask data near bright source transits
                for body in self.body:
                    flag_quiet &= ~transit_flag(
                        body, time_dd, nsigma=self.nsigma)

        if self.mask_ra:
            # Only use data within user specified ranges of RA
            mask_ra = np.zeros(ra.size, dtype=bool)
            for ra_range in self.mask_ra:
                self.log.info("Using data between RA = [%0.2f, %0.2f] deg" %
                              tuple(ra_range))
                mask_ra |= (ra >= ra_range[0]) & (ra <= ra_range[1])
            flag_quiet &= mask_ra

        # Create output container
        newra = np.mean(ra[flag_quiet], keepdims=True)
        mustream = containers.SiderealStream(
            ra=newra,
            axes_from=sstream,
            attrs_from=sstream,
            distributed=True,
            comm=sstream.comm,
        )
        mustream.redistribute("freq")
        mustream.attrs["statistic"] = self._name_of_statistic

        # Dereference visibilities
        all_vis = sstream.vis[:].view(np.ndarray)
        mu_vis = mustream.vis[:].view(np.ndarray)

        # Combine the visibility weights with the quiet flag
        all_weight = sstream.weight[:].view(np.ndarray) * flag_quiet.astype(
            np.float32)
        if not self.inverse_variance:
            all_weight = (all_weight > 0.0).astype(np.float32)

        # Only include freqs/baselines where enough data is actually present
        frac_present = all_weight.sum(axis=-1) / flag_quiet.sum(axis=-1)
        all_weight *= (frac_present > self.missing_threshold)[..., np.newaxis]

        num_freq_missing_local = int(
            (frac_present < self.missing_threshold).all(axis=1).sum())
        num_freq_missing = self.comm.allreduce(num_freq_missing_local,
                                               op=MPI.SUM)

        self.log.info(
            "Cannot estimate a sidereal mean for "
            f"{100.0 * num_freq_missing / len(mustream.freq):.2f}% of all frequencies."
        )

        # Save the total number of nonzero samples as the weight dataset of the output
        # container
        mustream.weight[:] = np.sum(all_weight, axis=-1, keepdims=True)

        # If requested, compute median (requires loop over frequencies and baselines)
        if self.median:
            mu_vis[..., 0].real = weighted_median(all_vis.real.copy(),
                                                  all_weight)
            mu_vis[..., 0].imag = weighted_median(all_vis.imag.copy(),
                                                  all_weight)

            # Where all the weights are zero explicitly set the median to zero
            missing = ~(all_weight.any(axis=-1))
            mu_vis[missing, 0] = 0.0

        else:
            # Otherwise calculate the mean
            mu_vis[:] = np.sum(all_weight * all_vis, axis=-1, keepdims=True)
            mu_vis[:] *= tools.invert_no_zero(mustream.weight[:])

        # Return sidereal stream containing the mean value
        return mustream
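
The mean branch at the end is a weighted average, sum(w * v) / sum(w), with
the division guarded so that all-flagged samples yield zero rather than a
divide-by-zero. A self-contained sketch of that guard (mimicking
tools.invert_no_zero, whose exact implementation is not shown here):

    import numpy as np

    def invert_no_zero(x):
        # Reciprocal that maps 0 -> 0 instead of producing inf
        with np.errstate(divide="ignore", invalid="ignore"):
            return np.where(x == 0, 0.0, 1.0 / x)

    vis = np.array([1.0, 2.0, 4.0])
    weight = np.array([1.0, 1.0, 0.0])  # third sample flagged out
    mean = np.sum(weight * vis) * invert_no_zero(np.sum(weight))
    assert mean == 1.5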
Example 6
    def view(self):
        if self.lsd is None:
            return panel.pane.Markdown("No data selected.")
        try:
            if self.intercylinder_only:
                name = "ringmap_intercyl"
            else:
                name = "ringmap"
            container = self.data.load_file(self.revision, self.lsd, name)
        except DataError as err:
            return panel.pane.Markdown(
                f"Error: {str(err)}. Please report this problem."
            )

        # Index map for ra (x-axis)
        index_map_ra = container.index_map["ra"]
        axis_name_ra = "RA [degrees]"

        # Index map for sin(ZA)/sin(theta) (y-axis)
        index_map_el = container.index_map["el"]
        axis_name_el = "sin(\u03B8)"

        # Apply data selections
        sel_beam = np.where(container.index_map["beam"] == self.beam)[0]
        sel_freq = np.where(
            np.array([f[0] for f in container.index_map["freq"]]) == self.frequency
        )[0]
        if self.polarization == self.mean_pol_text:
            sel_pol = np.where(
                (container.index_map["pol"] == "XX")
                | (container.index_map["pol"] == "YY")
            )[0]
            rmap = np.squeeze(container.map[sel_beam, sel_pol, sel_freq])
            rmap = np.nanmean(rmap, axis=0)
        else:
            sel_pol = np.where(container.index_map["pol"] == self.polarization)[0]
            rmap = np.squeeze(container.map[sel_beam, sel_pol, sel_freq])

        if self.flag_mask:
            rmap = np.where(self._flags_mask(container.index_map["ra"]), np.nan, rmap)

        if self.weight_mask:
            try:
                rms = np.squeeze(container.rms[sel_pol, sel_freq])
            except IndexError:
                logger.error(
                    f"rms dataset of ringmap file for rev {self.revision} lsd "
                    f"{self.lsd} is missing [{sel_pol}, {sel_freq}] (polarization, "
                    f"frequency). rms has shape {container.rms.shape}"
                )
                self.weight_mask = False
            else:
                rmap = np.where(self._weights_mask(rms), np.nan, rmap)

        # Set flagged data to nan
        rmap = np.where(rmap == 0, np.nan, rmap)

        if self.crosstalk_removal:
            # The mean of an all-nan slice (masked?) is nan. We don't need a warning about that.
            with warnings.catch_warnings():
                warnings.filterwarnings("ignore", r"All-NaN slice encountered")
                rmap -= np.nanmedian(rmap, axis=0)

        if self.template_subtraction:
            try:
                rm_stack = self.data.load_file_from_path(
                    self._stack_path, ccontainers.RingMap
                )
            except DataError as err:
                return panel.pane.Markdown(
                    f"Error: {str(err)}. Please report this problem."
                )

            # The stack file has all polarizations, so we can't reuse sel_pol
            if self.polarization == self.mean_pol_text:
                stack_sel_pol = np.where(
                    (rm_stack.index_map["pol"] == "XX")
                    | (rm_stack.index_map["pol"] == "YY")
                )[0]
            else:
                stack_sel_pol = np.where(
                    rm_stack.index_map["pol"] == self.polarization
                )[0]

            try:
                rm_stack = np.squeeze(rm_stack.map[sel_beam, stack_sel_pol, sel_freq])
            except IndexError as err:
                logger.error(
                    f"map dataset of ringmap stack file "
                    f"is missing [{sel_beam}, {stack_sel_pol}, {sel_freq}] (beam, polarization, "
                    f"frequency). map has shape {rm_stack.map.shape}:\n{err}"
                )
                self.template_subtraction = False
            else:
                if self.polarization == self.mean_pol_text:
                    rm_stack = np.nanmean(rm_stack, axis=0)

                # FIXME: this is a hack. Remove when the ringmap stack file is fixed.
                rmap -= rm_stack.reshape(rm_stack.shape[0], -1, 2).mean(axis=-1)

        if self.transpose:
            rmap = rmap.T
            index_x = index_map_ra
            index_y = index_map_el
            axis_names = [axis_name_ra, axis_name_el]
            xlim, ylim = self.ylim, self.xlim
        else:
            index_x = index_map_el
            index_y = index_map_ra
            axis_names = [axis_name_el, axis_name_ra]
            xlim, ylim = self.xlim, self.ylim

        img = hv.Image(
            (index_x, index_y, rmap),
            datatype=["image", "grid"],
            kdims=axis_names,
        ).opts(
            clim=self.colormap_range,
            logz=self.logarithmic_colorscale,
            cmap=process_cmap("inferno", provider="matplotlib"),
            colorbar=True,
            xlim=xlim,
            ylim=ylim,
        )

        if self.serverside_rendering is not None:
            # set colormap
            cmap_inferno = copy.copy(matplotlib_cm.get_cmap("inferno"))
            cmap_inferno.set_under("black")
            cmap_inferno.set_bad("lightgray")

            # Set z-axis normalization (other possible values are 'eq_hist', 'cbrt').
            if self.logarithmic_colorscale:
                normalization = "log"
            else:
                normalization = "linear"

            # datashade/rasterize the image
            img = self.serverside_rendering(
                img,
                cmap=cmap_inferno,
                precompute=True,
                x_range=xlim,
                y_range=ylim,
                normalization=normalization,
            )

        if self.mark_moon:
            # Put a ring around the location of the moon if it transits on this day
            eph = skyfield_wrapper.ephemeris

            # Start and end times of the CSD
            st = csd_to_unix(self.lsd.lsd)
            et = csd_to_unix(self.lsd.lsd + 1)

            moon_time, moon_dec = chime.transit_times(
                eph["moon"], st, et, return_dec=True
            )

            if len(moon_time):
                lunar_transit = unix_to_csd(moon_time[0])
                lunar_dec = moon_dec[0]
                lunar_ra = (lunar_transit % 1) * 360.0
                lunar_za = np.sin(np.radians(lunar_dec - 49.0))
                if self.transpose:
                    img *= hv.Ellipse(lunar_ra, lunar_za, (5.5, 0.15))
                else:
                    img *= hv.Ellipse(lunar_za, lunar_ra, (0.04, 21))

        if self.mark_day_time:
            # Calculate the sun rise/set times on this sidereal day

            # Start and end times of the CSD
            start_time = csd_to_unix(self.lsd.lsd)
            end_time = csd_to_unix(self.lsd.lsd + 1)

            times, rises = chime.rise_set_times(
                skyfield_wrapper.ephemeris["sun"],
                start_time,
                end_time,
                diameter=-10,
            )
            sun_rise = 0
            sun_set = 0
            for t, r in zip(times, rises):
                if r:
                    sun_rise = (unix_to_csd(t) % 1) * 360
                else:
                    sun_set = (unix_to_csd(t) % 1) * 360

            # Highlight the day time data
            opts = {
                "color": "grey",
                "alpha": 0.5,
                "line_width": 1,
                "line_color": "black",
                "line_dash": "dashed",
            }
            if self.transpose:
                if sun_rise < sun_set:
                    img *= hv.VSpan(sun_rise, sun_set).opts(**opts)
                else:
                    img *= hv.VSpan(self.ylim[0], sun_set).opts(**opts)
                    img *= hv.VSpan(sun_rise, self.ylim[1]).opts(**opts)

            else:
                if sun_rise < sun_set:
                    img *= hv.HSpan(sun_rise, sun_set).opts(**opts)
                else:
                    img *= hv.HSpan(self.ylim[0], sun_set).opts(**opts)
                    img *= hv.HSpan(sun_rise, self.ylim[1]).opts(**opts)

        img.opts(
            # Fix height, but make width responsive
            height=self.height,
            responsive=True,
            shared_axes=True,
            bgcolor="lightgray",
        )

        return panel.Row(img, width_policy="max")
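
A short sketch of the moon-marker math above: the fractional part of the
transit CSD gives the right ascension, and sin(dec - 49) approximates
sin(ZA) at CHIME's roughly 49-degree latitude (the values below are
illustrative):

    import numpy as np

    lunar_transit = 3000.371  # fractional CSD of the transit
    lunar_dec = 18.0          # declination in degrees
    lunar_ra = (lunar_transit % 1) * 360.0           # ~133.6 degrees
    lunar_za = np.sin(np.radians(lunar_dec - 49.0))  # negative: south of zenith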
Example 7
def main(config_file=None, logging_params=DEFAULT_LOGGING):

    # Load config
    config = DEFAULTS.deepcopy()
    if config_file is not None:
        config.merge(NameSpace(load_yaml_config(config_file)))

    # Setup logging
    log.setup_logging(logging_params)
    logger = log.get_logger(__name__)

    ## Load data for flagging
    # Load fpga restarts
    time_fpga_restart = []
    if config.fpga_restart_file is not None:

        with open(config.fpga_restart_file, 'r') as handler:
            for line in handler:
                time_fpga_restart.append(
                    ephemeris.datetime_to_unix(
                        ephemeris.timestr_to_datetime(line.split('_')[0])))

    time_fpga_restart = np.array(time_fpga_restart)

    # Load housekeeping flag
    if config.housekeeping_file is not None:
        ftemp = TempData.from_acq_h5(config.housekeeping_file,
                                     datasets=["time_flag"])
    else:
        ftemp = None

    # Load jump data
    if config.jump_file is not None:
        with h5py.File(config.jump_file, 'r') as handler:
            jump_time = handler["time"][:]
            jump_size = handler["jump_size"][:]
    else:
        jump_time = None
        jump_size = None

    # Load rain data
    if config.rain_file is not None:
        with h5py.File(config.rain_file, 'r') as handler:
            rain_ranges = handler["time_range_conservative"][:]
    else:
        rain_ranges = []

    # Load data flags
    data_flags = {}
    if config.data_flags:
        finder.connect_database()
        flag_types = finder.DataFlagType.select()
        possible_data_flags = []
        for ft in flag_types:
            possible_data_flags.append(ft.name)
            if ft.name in config.data_flags:
                new_data_flags = finder.DataFlag.select().where(
                    finder.DataFlag.type == ft)
                data_flags[ft.name] = list(new_data_flags)

    # Set desired range of time
    start_time = (ephemeris.datetime_to_unix(
        datetime.datetime(
            *config.start_date)) if config.start_date is not None else None)
    end_time = (ephemeris.datetime_to_unix(datetime.datetime(
        *config.end_date)) if config.end_date is not None else None)

    ## Find gain files
    files = {}
    for src in config.sources:
        files[src] = sorted(
            glob.glob(
                os.path.join(config.directory, src.lower(),
                             "%s_%s_lsd_*.h5" % (
                                 config.prefix,
                                 src.lower(),
                             ))))
    csd = {}
    for src in config.sources:
        csd[src] = np.array(
            [int(os.path.splitext(ff)[0][-4:]) for ff in files[src]])

    for src in config.sources:
        logger.info("%s:  %d files" % (src, len(csd[src])))

    ## Remove files that occur during flag
    csd_flag = {}
    for src in config.sources:

        body = ephemeris.source_dictionary[src]

        csd_flag[src] = np.ones(csd[src].size, dtype=bool)

        for ii, cc in enumerate(csd[src][:]):

            ttrans = ephemeris.transit_times(body,
                                             ephemeris.csd_to_unix(cc))[0]

            if (start_time is not None) and (ttrans < start_time):
                csd_flag[src][ii] = False
                continue

            if (end_time is not None) and (ttrans > end_time):
                csd_flag[src][ii] = False
                continue

            # If requested, remove daytime transits
            if not config.include_daytime.get(
                    src, config.include_daytime.default) and daytime_flag(
                        ttrans)[0]:
                logger.info("%s CSD %d:  daytime transit" % (src, cc))
                csd_flag[src][ii] = False
                continue

            # Remove transits during HKP drop out
            if ftemp is not None:
                itemp = np.flatnonzero(
                    (ftemp.time[:] >= (ttrans - config.transit_window))
                    & (ftemp.time[:] <= (ttrans + config.transit_window)))
                tempflg = ftemp['time_flag'][itemp]
                if (tempflg.size == 0) or ((np.sum(tempflg, dtype=np.float32) /
                                            float(tempflg.size)) < 0.50):
                    logger.info("%s CSD %d:  no housekeeping" % (src, cc))
                    csd_flag[src][ii] = False
                    continue

            # Remove transits near jumps
            if jump_time is not None:
                njump = np.sum((jump_size > config.min_jump_size)
                               & (jump_time > (ttrans - config.jump_window))
                               & (jump_time < ttrans))
                if njump > config.max_njump:
                    logger.info("%s CSD %d:  %d jumps before" %
                                (src, cc, njump))
                    csd_flag[src][ii] = False
                    continue

            # Remove transits near rain
            for rng in rain_ranges:
                if (((ttrans - config.transit_window) <= rng[1])
                        and ((ttrans + config.transit_window) >= rng[0])):

                    logger.info("%s CSD %d:  during rain" % (src, cc))
                    csd_flag[src][ii] = False
                    break

            # Remove transits during data flag
            for name, flag_list in data_flags.items():

                if csd_flag[src][ii]:

                    for flg in flag_list:

                        if (((ttrans - config.transit_window) <=
                             flg.finish_time)
                                and ((ttrans + config.transit_window) >=
                                     flg.start_time)):

                            logger.info("%s CSD %d:  %s flag" %
                                        (src, cc, name))
                            csd_flag[src][ii] = False
                            break

    # Print number of files left after flagging
    for src in config.sources:
        logger.info("%s:  %d files (after flagging)" %
                    (src, np.sum(csd_flag[src])))

    ## Construct pair wise differences
    npair = len(config.diff_pair)
    shift = [nd * 24.0 * 3600.0 for nd in config.nday_shift]

    calmap = []
    calpair = []

    for (tsrc, csrc), sh in zip(config.diff_pair, shift):

        body_test = ephemeris.source_dictionary[tsrc]
        body_cal = ephemeris.source_dictionary[csrc]

        for ii, cc in enumerate(csd[tsrc]):

            if csd_flag[tsrc][ii]:

                test_transit = ephemeris.transit_times(
                    body_test, ephemeris.csd_to_unix(cc))[0]
                cal_transit = ephemeris.transit_times(body_cal,
                                                      test_transit + sh)[0]
                cal_csd = int(np.fix(ephemeris.unix_to_csd(cal_transit)))

                ttrans = np.sort([test_transit, cal_transit])

                if cal_csd in csd[csrc]:
                    jj = list(csd[csrc]).index(cal_csd)

                    if csd_flag[csrc][jj] and not np.any(
                        (time_fpga_restart >= ttrans[0])
                            & (time_fpga_restart <= ttrans[1])):
                        calmap.append([ii, jj])
                        calpair.append([tsrc, csrc])

    calmap = np.array(calmap)
    calpair = np.array(calpair)

    ntransit = calmap.shape[0]

    logger.info("%d total transit pairs" % ntransit)
    for ii in range(ntransit):

        t1 = ephemeris.transit_times(
            ephemeris.source_dictionary[calpair[ii, 0]],
            ephemeris.csd_to_unix(csd[calpair[ii, 0]][calmap[ii, 0]]))[0]
        t2 = ephemeris.transit_times(
            ephemeris.source_dictionary[calpair[ii, 1]],
            ephemeris.csd_to_unix(csd[calpair[ii, 1]][calmap[ii, 1]]))[0]

        logger.info("%s (%d) - %s (%d):  %0.1f hr" %
                    (calpair[ii, 0], csd_flag[calpair[ii, 0]][calmap[ii, 0]],
                     calpair[ii, 1], csd_flag[calpair[ii, 1]][calmap[ii, 1]],
                     (t1 - t2) / 3600.0))

    # Determine unique diff pairs
    diff_name = np.array(['%s/%s' % tuple(cp) for cp in calpair])
    uniq_diff, lbl_diff, cnt_diff = np.unique(diff_name,
                                              return_inverse=True,
                                              return_counts=True)
    ndiff = uniq_diff.size

    for ud, udcnt in zip(uniq_diff, cnt_diff):
        logger.info("%s:  %d transit pairs" % (ud, udcnt))

    ## Load gains
    inputmap = tools.get_correlator_inputs(datetime.datetime.utcnow(),
                                           correlator='chime')
    ninput = len(inputmap)
    nfreq = 1024

    # Set up gain arrays
    gain = np.zeros((2, nfreq, ninput, ntransit), dtype=np.complex64)
    weight = np.zeros((2, nfreq, ninput, ntransit), dtype=np.float32)
    input_sort = np.zeros((2, ninput, ntransit), dtype=int)

    kcsd = np.zeros((2, ntransit), dtype=np.float32)
    timestamp = np.zeros((2, ntransit), dtype=np.float64)
    is_daytime = np.zeros((2, ntransit), dtype=bool)

    for tt in range(ntransit):

        for kk, (src, ind) in enumerate(zip(calpair[tt], calmap[tt])):

            body = ephemeris.source_dictionary[src]
            filename = files[src][ind]

            logger.info("%s:  %s" % (src, filename))

            temp = containers.StaticGainData.from_file(filename)

            freq = temp.freq[:]
            inputs = temp.input[:]

            isort = reorder_inputs(inputmap, inputs)
            inputs = inputs[isort]

            gain[kk, :, :, tt] = temp.gain[:, isort]
            weight[kk, :, :, tt] = temp.weight[:, isort]
            input_sort[kk, :, tt] = isort

            kcsd[kk, tt] = temp.attrs['lsd']
            timestamp[kk, tt] = ephemeris.transit_times(
                body, ephemeris.csd_to_unix(kcsd[kk, tt]))[0]
            is_daytime[kk, tt] = daytime_flag(timestamp[kk, tt])[0]

            if np.any(isort != np.arange(isort.size)):
                logger.info("Input ordering has changed: %s" %
                            ephemeris.unix_to_datetime(
                                timestamp[kk, tt]).strftime("%Y-%m-%d"))

        logger.info("")

    inputs = np.array([(inp.id, inp.input_sn) for inp in inputmap],
                      dtype=[('chan_id', 'u2'), ('correlator_input', 'S32')])

    ## Load input flags
    inpflg = np.ones((2, ninput, ntransit), dtype=bool)

    min_flag_time = np.min(timestamp) - 7.0 * 24.0 * 60.0 * 60.0
    max_flag_time = np.max(timestamp) + 7.0 * 24.0 * 60.0 * 60.0

    flaginput_files = sorted(
        glob.glob(
            os.path.join(config.flaginput_dir, "*" + config.flaginput_suffix,
                         "*.h5")))

    if flaginput_files:
        logger.info("Found %d flaginput files." % len(flaginput_files))
        tmp = andata.FlagInputData.from_acq_h5(flaginput_files, datasets=())
        start, stop = [
            int(yy) for yy in np.percentile(
                np.flatnonzero((tmp.time[:] >= min_flag_time)
                               & (tmp.time[:] <= max_flag_time)), [0, 100])
        ]

        cont = andata.FlagInputData.from_acq_h5(flaginput_files,
                                                start=start,
                                                stop=stop,
                                                datasets=['flag'])

        for kk in range(2):
            inpflg[kk, :, :] = cont.resample('flag',
                                             timestamp[kk],
                                             transpose=True)

            logger.info("Flaginput time offsets in minutes (pair %d):" % kk)
            logger.info(
                str(
                    np.fix((cont.time[cont.search_update_time(timestamp[kk])] -
                            timestamp[kk]) / 60.0).astype(int)))

    # Sort flags so they are in same order
    for tt in range(ntransit):
        for kk in range(2):
            inpflg[kk, :, tt] = inpflg[kk, input_sort[kk, :, tt], tt]

    # Do not apply input flag to phase reference
    for ii in config.index_phase_ref:
        inpflg[:, ii, :] = True

    ## Flag out gains with high uncertainty and frequencies with large fraction of data flagged
    frac_err = tools.invert_no_zero(np.sqrt(weight) * np.abs(gain))

    flag = np.all((weight > 0.0) & (np.abs(gain) > 0.0) &
                  (frac_err < config.max_uncertainty),
                  axis=0)

    freq_flag = ((np.sum(flag, axis=(1, 2), dtype=np.float32) /
                  float(np.prod(flag.shape[1:]))) > config.freq_threshold)

    if config.apply_rfi_mask:
        freq_flag &= np.logical_not(rfi.frequency_mask(freq))

    flag = flag & freq_flag[:, np.newaxis, np.newaxis]

    good_freq = np.flatnonzero(freq_flag)

    logger.info("Number good frequencies %d" % good_freq.size)

    ## Generate flags with more conservative cuts on frequency
    c_flag = flag & np.all(frac_err < config.conservative.max_uncertainty,
                           axis=0)

    c_freq_flag = ((np.sum(c_flag, axis=(1, 2), dtype=np.float32) /
                    float(np.prod(c_flag.shape[1:]))) >
                   config.conservative.freq_threshold)

    if config.conservative.apply_rfi_mask:
        c_freq_flag &= np.logical_not(rfi.frequency_mask(freq))

    c_flag = c_flag & c_freq_flag[:, np.newaxis, np.newaxis]

    c_good_freq = np.flatnonzero(c_freq_flag)

    logger.info("Number good frequencies (conservative thresholds) %d" %
                c_good_freq.size)

    ## Apply input flags
    flag &= np.all(inpflg[:, np.newaxis, :, :], axis=0)

    ## Update flags based on beam flag
    if config.beam_flag_file is not None:

        dbeam = andata.BaseData.from_acq_h5(config.beam_flag_file)

        db_csd = np.floor(ephemeris.unix_to_csd(dbeam.index_map['time'][:]))

        for ii, name in enumerate(config.beam_flag_datasets):
            logger.info("Applying %s beam flag." % name)
            if not ii:
                db_flag = dbeam.flags[name][:]
            else:
                db_flag &= dbeam.flags[name][:]

        cnt = 0
        for ii, dbc in enumerate(db_csd):

            this_csd = np.flatnonzero(np.any(kcsd == dbc, axis=0))

            if this_csd.size > 0:

                logger.info("Beam flag for %d matches %s." %
                            (dbc, str(kcsd[:, this_csd])))

                flag[:, :, this_csd] &= db_flag[np.newaxis, :, ii, np.newaxis]

                cnt += 1

        logger.info("Applied %0.1f percent of the beam flags" %
                    (100.0 * cnt / float(db_csd.size), ))

    ## Flag inputs with large amount of missing data
    input_frac_flagged = (
        np.sum(flag[good_freq, :, :], axis=(0, 2), dtype=np.float32) /
        float(good_freq.size * ntransit))
    input_flag = input_frac_flagged > config.input_threshold

    for ii in config.index_phase_ref:
        logger.info("Phase reference %d has %0.3f fraction of data flagged." %
                    (ii, input_frac_flagged[ii]))
        input_flag[ii] = True

    good_input = np.flatnonzero(input_flag)

    flag = flag & input_flag[np.newaxis, :, np.newaxis]

    logger.info("Number good inputs %d" % good_input.size)

    ## Calibrate
    gaincal = gain[0] * tools.invert_no_zero(gain[1])

    frac_err_cal = np.sqrt(frac_err[0]**2 + frac_err[1]**2)

    count = np.sum(flag, axis=-1, dtype=int)
    stat_flag = count > config.min_num_transit

    ## Calculate phase
    amp = np.abs(gaincal)
    phi = np.angle(gaincal)

    ## Calculate polarisation groups
    pol_dict = {'E': 'X', 'S': 'Y'}
    cyl_dict = {2: 'A', 3: 'B', 4: 'C', 5: 'D'}

    if config.group_by_cyl:
        group_id = [
            (inp.pol,
             inp.cyl) if tools.is_chime(inp) and (ii in good_input) else None
            for ii, inp in enumerate(inputmap)
        ]
    else:
        group_id = [
            inp.pol if tools.is_chime(inp) and (ii in good_input) else None
            for ii, inp in enumerate(inputmap)
        ]

    ugroup_id = sorted([uidd for uidd in set(group_id) if uidd is not None])
    ngroup = len(ugroup_id)

    group_list_noref = [
        np.array([
            gg for gg, gid in enumerate(group_id)
            if (gid == ugid) and gg not in config.index_phase_ref
        ]) for ugid in ugroup_id
    ]

    group_list = [
        np.array([gg for gg, gid in enumerate(group_id) if gid == ugid])
        for ugid in ugroup_id
    ]

    if config.group_by_cyl:
        group_str = [
            "%s-%s" % (pol_dict[pol], cyl_dict[cyl]) for pol, cyl in ugroup_id
        ]
    else:
        group_str = [pol_dict[pol] for pol in ugroup_id]

    index_phase_ref = []
    for gstr, igroup in zip(group_str, group_list):
        candidate = [ii for ii in config.index_phase_ref if ii in igroup]
        if len(candidate) != 1:
            index_phase_ref.append(None)
        else:
            index_phase_ref.append(candidate[0])

    logger.info(
        "Phase reference: %s" %
        ', '.join(['%s = %s' % tpl
                   for tpl in zip(group_str, index_phase_ref)]))

    ## Apply thermal correction to amplitude
    if config.amp_thermal.enabled:

        logger.info("Applying thermal correction.")

        # Load the temperatures
        tdata = TempData.from_acq_h5(config.amp_thermal.filename)

        index = tdata.search_sensors(config.amp_thermal.sensor)[0]

        temp = tdata.datasets[config.amp_thermal.field][index]
        temp_func = scipy.interpolate.interp1d(tdata.time, temp,
                                               **config.amp_thermal.interp)

        itemp = temp_func(timestamp)
        dtemp = itemp[0] - itemp[1]

        flag_func = scipy.interpolate.interp1d(
            tdata.time, tdata.datasets['flag'][index].astype(np.float32),
            **config.amp_thermal.interp)

        dtemp_flag = np.all(flag_func(timestamp) == 1.0, axis=0)

        flag &= dtemp_flag[np.newaxis, np.newaxis, :]

        for gstr, igroup in zip(group_str, group_list):
            pstr = gstr[0]
            thermal_coeff = np.polyval(config.amp_thermal.coeff[pstr], freq)
            gthermal = 1.0 + thermal_coeff[:, np.newaxis, np.newaxis] * dtemp[
                np.newaxis, np.newaxis, :]

            amp[:, igroup, :] *= tools.invert_no_zero(gthermal)

    ## Compute common mode
    if config.subtract_common_mode_before:
        logger.info("Calculating common mode amplitude and phase.")
        cmn_amp, flag_cmn_amp = compute_common_mode(amp,
                                                    flag,
                                                    group_list_noref,
                                                    median=False)
        cmn_phi, flag_cmn_phi = compute_common_mode(phi,
                                                    flag,
                                                    group_list_noref,
                                                    median=False)

        # Subtract common mode (from phase only)
        logger.info("Subtracting common mode phase.")
        group_flag = np.zeros((ngroup, ninput), dtype=bool)
        for gg, igroup in enumerate(group_list):
            group_flag[gg, igroup] = True
            phi[:,
                igroup, :] = phi[:, igroup, :] - cmn_phi[:, gg, np.newaxis, :]

            for iref in index_phase_ref:
                if (iref is not None) and (iref in igroup):
                    flag[:, iref, :] = flag_cmn_phi[:, gg, :]

    ## If requested, determine and subtract a delay template
    if config.fit_delay_before:
        logger.info("Fitting delay template.")
        omega = timing.FREQ_TO_OMEGA * freq

        tau, tau_flag, _ = construct_delay_template(
            omega,
            phi,
            c_flag & flag,
            min_num_freq_for_delay_fit=config.min_num_freq_for_delay_fit)

        # Compute residuals
        logger.info("Subtracting delay template.")
        phi = phi - tau[np.newaxis, :, :] * omega[:, np.newaxis, np.newaxis]

    ## Normalize by median over time
    logger.info("Calculating median amplitude and phase.")
    med_amp = np.zeros((nfreq, ninput, ndiff), dtype=amp.dtype)
    med_phi = np.zeros((nfreq, ninput, ndiff), dtype=phi.dtype)

    count_by_diff = np.zeros((nfreq, ninput, ndiff), dtype=int)
    stat_flag_by_diff = np.zeros((nfreq, ninput, ndiff), dtype=bool)

    def weighted_mean(yy, ww, axis=-1):
        return np.sum(ww * yy, axis=axis) * tools.invert_no_zero(
            np.sum(ww, axis=axis))

    for dd in range(ndiff):

        this_diff = np.flatnonzero(lbl_diff == dd)

        this_flag = flag[:, :, this_diff]

        this_amp = amp[:, :, this_diff]
        this_amp_err = this_amp * frac_err_cal[:, :,
                                               this_diff] * this_flag.astype(
                                                   np.float32)

        this_phi = phi[:, :, this_diff]
        this_phi_err = frac_err_cal[:, :, this_diff] * this_flag.astype(
            np.float32)

        count_by_diff[:, :, dd] = np.sum(this_flag, axis=-1, dtype=int)
        stat_flag_by_diff[:, :,
                          dd] = count_by_diff[:, :,
                                              dd] > config.min_num_transit

        if config.weighted_mean == 2:
            logger.info("Calculating inverse variance weighted mean.")
            med_amp[:, :,
                    dd] = weighted_mean(this_amp,
                                        tools.invert_no_zero(this_amp_err**2),
                                        axis=-1)
            med_phi[:, :,
                    dd] = weighted_mean(this_phi,
                                        tools.invert_no_zero(this_phi_err**2),
                                        axis=-1)

        elif config.weighted_mean == 1:
            logger.info("Calculating uniform weighted mean.")
            med_amp[:, :, dd] = weighted_mean(this_amp,
                                              this_flag.astype(np.float32),
                                              axis=-1)
            med_phi[:, :, dd] = weighted_mean(this_phi,
                                              this_flag.astype(np.float32),
                                              axis=-1)

        else:
            logger.info("Calculating median value.")
            for ff in range(nfreq):
                for ii in range(ninput):
                    if np.any(this_flag[ff, ii, :]):
                        med_amp[ff, ii, dd] = wq.median(
                            this_amp[ff, ii, :],
                            this_flag[ff, ii, :].astype(np.float32))
                        med_phi[ff, ii, dd] = wq.median(
                            this_phi[ff, ii, :],
                            this_flag[ff, ii, :].astype(np.float32))

    damp = np.zeros_like(amp)
    dphi = np.zeros_like(phi)
    for dd in range(ndiff):
        this_diff = np.flatnonzero(lbl_diff == dd)
        damp[:, :, this_diff] = amp[:, :, this_diff] * tools.invert_no_zero(
            med_amp[:, :, dd, np.newaxis]) - 1.0
        dphi[:, :,
             this_diff] = phi[:, :, this_diff] - med_phi[:, :, dd, np.newaxis]

    # Compute common mode
    if not config.subtract_common_mode_before:
        logger.info("Calculating common mode amplitude and phase.")
        cmn_amp, flag_cmn_amp = compute_common_mode(damp,
                                                    flag,
                                                    group_list_noref,
                                                    median=True)
        cmn_phi, flag_cmn_phi = compute_common_mode(dphi,
                                                    flag,
                                                    group_list_noref,
                                                    median=True)

        # Subtract common mode (from phase only)
        logger.info("Subtracting common mode phase.")
        group_flag = np.zeros((ngroup, ninput), dtype=bool)
        for gg, igroup in enumerate(group_list):
            group_flag[gg, igroup] = True
            dphi[:, igroup, :] = dphi[:, igroup, :] - cmn_phi[:, gg,
                                                              np.newaxis, :]

            for iref in index_phase_ref:
                if (iref is not None) and (iref in igroup):
                    flag[:, iref, :] = flag_cmn_phi[:, gg, :]

    ## Compute RMS
    logger.info("Calculating RMS of amplitude and phase.")
    mad_amp = np.zeros((nfreq, ninput), dtype=amp.dtype)
    std_amp = np.zeros((nfreq, ninput), dtype=amp.dtype)

    mad_phi = np.zeros((nfreq, ninput), dtype=phi.dtype)
    std_phi = np.zeros((nfreq, ninput), dtype=phi.dtype)

    mad_amp_by_diff = np.zeros((nfreq, ninput, ndiff), dtype=amp.dtype)
    std_amp_by_diff = np.zeros((nfreq, ninput, ndiff), dtype=amp.dtype)

    mad_phi_by_diff = np.zeros((nfreq, ninput, ndiff), dtype=phi.dtype)
    std_phi_by_diff = np.zeros((nfreq, ninput, ndiff), dtype=phi.dtype)

    for ff in range(nfreq):
        for ii in range(ninput):
            this_flag = flag[ff, ii, :]
            if np.any(this_flag):
                std_amp[ff, ii] = np.std(damp[ff, ii, this_flag])
                std_phi[ff, ii] = np.std(dphi[ff, ii, this_flag])

                mad_amp[ff, ii] = 1.48625 * wq.median(
                    np.abs(damp[ff, ii, :]), this_flag.astype(np.float32))
                mad_phi[ff, ii] = 1.48625 * wq.median(
                    np.abs(dphi[ff, ii, :]), this_flag.astype(np.float32))

                for dd in range(ndiff):
                    this_diff = this_flag & (lbl_diff == dd)
                    if np.any(this_diff):

                        std_amp_by_diff[ff, ii, dd] = np.std(damp[ff, ii,
                                                                  this_diff])
                        std_phi_by_diff[ff, ii, dd] = np.std(dphi[ff, ii,
                                                                  this_diff])

                        mad_amp_by_diff[ff, ii, dd] = 1.48625 * wq.median(
                            np.abs(damp[ff, ii, :]),
                            this_diff.astype(np.float32))
                        mad_phi_by_diff[ff, ii, dd] = 1.48625 * wq.median(
                            np.abs(dphi[ff, ii, :]),
                            this_diff.astype(np.float32))

    ## Construct delay template
    if not config.fit_delay_before:
        logger.info("Fitting delay template.")
        omega = timing.FREQ_TO_OMEGA * freq

        tau, tau_flag, _ = construct_delay_template(
            omega,
            dphi,
            c_flag & flag,
            min_num_freq_for_delay_fit=config.min_num_freq_for_delay_fit)

        # Compute residuals
        logger.info("Subtracting delay template from phase.")
        resid = (dphi - tau[np.newaxis, :, :] *
                 omega[:, np.newaxis, np.newaxis]) * flag.astype(np.float32)

    else:
        resid = dphi

    tau_count = np.sum(tau_flag, axis=-1, dtype=int)
    tau_stat_flag = tau_count > config.min_num_transit

    tau_count_by_diff = np.zeros((ninput, ndiff), dtype=int)
    tau_stat_flag_by_diff = np.zeros((ninput, ndiff), dtype=bool)
    for dd in range(ndiff):
        this_diff = np.flatnonzero(lbl_diff == dd)
        tau_count_by_diff[:, dd] = np.sum(tau_flag[:, this_diff],
                                          axis=-1,
                                          dtype=int)
        tau_stat_flag_by_diff[:,
                              dd] = tau_count_by_diff[:,
                                                      dd] > config.min_num_transit

    ## Calculate statistics of residuals
    std_resid = np.zeros((nfreq, ninput), dtype=phi.dtype)
    mad_resid = np.zeros((nfreq, ninput), dtype=phi.dtype)

    std_resid_by_diff = np.zeros((nfreq, ninput, ndiff), dtype=phi.dtype)
    mad_resid_by_diff = np.zeros((nfreq, ninput, ndiff), dtype=phi.dtype)

    for ff in range(nfreq):
        for ii in range(ninput):
            this_flag = flag[ff, ii, :]
            if np.any(this_flag):
                std_resid[ff, ii] = np.std(resid[ff, ii, this_flag])
                mad_resid[ff, ii] = 1.48625 * wq.median(
                    np.abs(resid[ff, ii, :]), this_flag.astype(np.float32))

                for dd in range(ndiff):
                    this_diff = this_flag & (lbl_diff == dd)
                    if np.any(this_diff):
                        std_resid_by_diff[ff, ii,
                                          dd] = np.std(resid[ff, ii,
                                                             this_diff])
                        mad_resid_by_diff[ff, ii, dd] = 1.48625 * wq.median(
                            np.abs(resid[ff, ii, :]),
                            this_diff.astype(np.float32))

    ## Calculate statistics of delay template
    mad_tau = np.zeros((ninput, ), dtype=phi.dtype)
    std_tau = np.zeros((ninput, ), dtype=phi.dtype)

    mad_tau_by_diff = np.zeros((ninput, ndiff), dtype=phi.dtype)
    std_tau_by_diff = np.zeros((ninput, ndiff), dtype=phi.dtype)

    for ii in range(ninput):
        this_flag = tau_flag[ii]
        if np.any(this_flag):
            std_tau[ii] = np.std(tau[ii, this_flag])
            mad_tau[ii] = 1.48625 * wq.median(np.abs(tau[ii]),
                                              this_flag.astype(np.float32))

            for dd in range(ndiff):
                this_diff = this_flag & (lbl_diff == dd)
                if np.any(this_diff):
                    std_tau_by_diff[ii, dd] = np.std(tau[ii, this_diff])
                    mad_tau_by_diff[ii, dd] = 1.48625 * wq.median(
                        np.abs(tau[ii]), this_diff.astype(np.float32))

    ## Define output
    res = {
        "timestamp": {
            "data": timestamp,
            "axis": ["div", "time"]
        },
        "is_daytime": {
            "data": is_daytime,
            "axis": ["div", "time"]
        },
        "csd": {
            "data": kcsd,
            "axis": ["div", "time"]
        },
        "pair_map": {
            "data": lbl_diff,
            "axis": ["time"]
        },
        "pair_count": {
            "data": cnt_diff,
            "axis": ["pair"]
        },
        "gain": {
            "data": gaincal,
            "axis": ["freq", "input", "time"]
        },
        "frac_err": {
            "data": frac_err_cal,
            "axis": ["freq", "input", "time"]
        },
        "flags/gain": {
            "data": flag,
            "axis": ["freq", "input", "time"],
            "flag": True
        },
        "flags/gain_conservative": {
            "data": c_flag,
            "axis": ["freq", "input", "time"],
            "flag": True
        },
        "flags/count": {
            "data": count,
            "axis": ["freq", "input"],
            "flag": True
        },
        "flags/stat": {
            "data": stat_flag,
            "axis": ["freq", "input"],
            "flag": True
        },
        "flags/count_by_pair": {
            "data": count_by_diff,
            "axis": ["freq", "input", "pair"],
            "flag": True
        },
        "flags/stat_by_pair": {
            "data": stat_flag_by_diff,
            "axis": ["freq", "input", "pair"],
            "flag": True
        },
        "med_amp": {
            "data": med_amp,
            "axis": ["freq", "input", "pair"]
        },
        "med_phi": {
            "data": med_phi,
            "axis": ["freq", "input", "pair"]
        },
        "flags/group_flag": {
            "data": group_flag,
            "axis": ["group", "input"],
            "flag": True
        },
        "cmn_amp": {
            "data": cmn_amp,
            "axis": ["freq", "group", "time"]
        },
        "cmn_phi": {
            "data": cmn_phi,
            "axis": ["freq", "group", "time"]
        },
        "amp": {
            "data": damp,
            "axis": ["freq", "input", "time"]
        },
        "phi": {
            "data": dphi,
            "axis": ["freq", "input", "time"]
        },
        "std_amp": {
            "data": std_amp,
            "axis": ["freq", "input"]
        },
        "std_amp_by_pair": {
            "data": std_amp_by_diff,
            "axis": ["freq", "input", "pair"]
        },
        "mad_amp": {
            "data": mad_amp,
            "axis": ["freq", "input"]
        },
        "mad_amp_by_pair": {
            "data": mad_amp_by_diff,
            "axis": ["freq", "input", "pair"]
        },
        "std_phi": {
            "data": std_phi,
            "axis": ["freq", "input"]
        },
        "std_phi_by_pair": {
            "data": std_phi_by_diff,
            "axis": ["freq", "input", "pair"]
        },
        "mad_phi": {
            "data": mad_phi,
            "axis": ["freq", "input"]
        },
        "mad_phi_by_pair": {
            "data": mad_phi_by_diff,
            "axis": ["freq", "input", "pair"]
        },
        "tau": {
            "data": tau,
            "axis": ["input", "time"]
        },
        "flags/tau": {
            "data": tau_flag,
            "axis": ["input", "time"],
            "flag": True
        },
        "flags/tau_count": {
            "data": tau_count,
            "axis": ["input"],
            "flag": True
        },
        "flags/tau_stat": {
            "data": tau_stat_flag,
            "axis": ["input"],
            "flag": True
        },
        "flags/tau_count_by_pair": {
            "data": tau_count_by_diff,
            "axis": ["input", "pair"],
            "flag": True
        },
        "flags/tau_stat_by_pair": {
            "data": tau_stat_flag_by_diff,
            "axis": ["input", "pair"],
            "flag": True
        },
        "std_tau": {
            "data": std_tau,
            "axis": ["input"]
        },
        "std_tau_by_pair": {
            "data": std_tau_by_diff,
            "axis": ["input", "pair"]
        },
        "mad_tau": {
            "data": mad_tau,
            "axis": ["input"]
        },
        "mad_tau_by_pair": {
            "data": mad_tau_by_diff,
            "axis": ["input", "pair"]
        },
        "resid_phi": {
            "data": resid,
            "axis": ["freq", "input", "time"]
        },
        "std_resid_phi": {
            "data": std_resid,
            "axis": ["freq", "input"]
        },
        "std_resid_phi_by_pair": {
            "data": std_resid_by_diff,
            "axis": ["freq", "input", "pair"]
        },
        "mad_resid_phi": {
            "data": mad_resid,
            "axis": ["freq", "input"]
        },
        "mad_resid_phi_by_pair": {
            "data": mad_resid_by_diff,
            "axis": ["freq", "input", "pair"]
        },
    }

    ## Create the output container
    logger.info("Creating StabilityData container.")
    data = StabilityData()

    data.create_index_map(
        "div", np.array(["numerator", "denominator"], dtype=np.string_))
    data.create_index_map("pair", np.array(uniq_diff, dtype=np.string_))
    data.create_index_map("group", np.array(group_str, dtype=np.string_))

    data.create_index_map("freq", freq)
    data.create_index_map("input", inputs)
    data.create_index_map("time", timestamp[0, :])

    logger.info("Writing datsets to container.")
    for name, dct in res.iteritems():
        is_flag = dct.get('flag', False)
        if is_flag:
            dset = data.create_flag(name.split('/')[-1], data=dct['data'])
        else:
            dset = data.create_dataset(name, data=dct['data'])

        dset.attrs['axis'] = np.array(dct['axis'], dtype=np.string_)

    data.attrs['phase_ref'] = np.array(
        [iref for iref in index_phase_ref if iref is not None])

    # Determine the output filename and save results
    start_time, end_time = ephemeris.unix_to_datetime(
        np.percentile(timestamp, [0, 100]))
    tfmt = "%Y%m%d"
    night_str = 'night_' if not np.any(is_daytime) else ''
    output_file = os.path.join(
        config.output_dir, "%s_%s_%sraw_stability_data.h5" %
        (start_time.strftime(tfmt), end_time.strftime(tfmt), night_str))

    logger.info("Saving results to %s." % output_file)
    data.save(output_file)
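
The factor 1.48625 that appears throughout the statistics above rescales a
median absolute deviation (MAD) to a Gaussian standard deviation, since
1 / PPF(0.75) is approximately 1.4826; this makes the mad_* datasets directly
comparable to the std_* ones. A quick numerical check:

    import numpy as np

    rng = np.random.default_rng(0)
    x = rng.normal(0.0, 2.0, size=100_000)
    robust_std = 1.4826 * np.median(np.abs(x - np.median(x)))
    assert abs(robust_std - 2.0) < 0.05  # close to the true sigma of 2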
Example 8
    def process(self, sstream, suntrans, inputmap):
        """Clean the sun.

        Parameters
        ----------
        sstream : containers.SiderealStream
            Sidereal stream.
        suntrans : containers.SolarTransit
            Response to the sun.
        inputmap : list of :class:`CorrInput`
            A list describing the inputs as they are in the file.

        Returns
        -------
        mstream : containers.SiderealStream
            Sidereal stream with the sun removed.
        """

        sstream.redistribute("freq")
        suntrans.redistribute("freq")

        # Determine time mapping
        if hasattr(sstream, "time"):
            stime = sstream.time[:]
        else:
            ra = sstream.index_map["ra"][:]
            csd = (sstream.attrs["lsd"]
                   if "lsd" in sstream.attrs else sstream.attrs["csd"])
            stime = ephemeris.csd_to_unix(csd + ra / 360.0)

        # Extract gain array
        gtime = suntrans.time[:]
        gain = suntrans.response[:].view(np.ndarray)

        ninput = gain.shape[1]

        # Determine product map
        prod_map = sstream.index_map["prod"][:]
        nprod = prod_map.size

        if nprod != (ninput * (ninput + 1) // 2):
            raise ValueError(
                "Number of inputs does not match the number of products.")

        feed_list = [(inputmap[ii], inputmap[jj]) for ii, jj in prod_map]

        # Determine polarisation for each visibility
        same_pol = np.zeros(nprod, dtype=bool)
        for pp, (ii, jj) in enumerate(feed_list):
            if tools.is_chime(ii) and tools.is_chime(jj):
                same_pol[pp] = tools.is_chime_y(ii) == tools.is_chime_y(jj)

        # Match each gain timestamp to the nearest time sample in the stream
        match = np.array([np.argmin(np.abs(gt - stime)) for gt in gtime])

        # Loop over frequencies and products
        for lfi, fi in sstream.vis[:].enumerate(0):

            for pp in range(nprod):

                if same_pol[pp]:

                    ii, jj = prod_map[pp]

                    # Fetch the gains
                    gi = gain[lfi, ii, :]
                    gj = gain[lfi, jj, :].conj()

                    # Subtract the gains
                    sstream.vis[fi, pp, match] -= gi * gj

        # Return the clean sidereal stream
        return sstream
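
The inner loop subtracts gi * conj(gj) because a point source like the sun
adds an outer product of the per-input responses to each same-polarisation
visibility. A toy demonstration of why removing that rank-1 term cleans the
sun out:

    import numpy as np

    g = np.array([1 + 1j, 2 - 1j, 0.5 + 0.25j])  # per-input response to the sun
    vis_sun = np.outer(g, g.conj())              # sun's contribution to all products
    vis = vis_sun.copy()                         # pretend the sun dominates the data
    vis -= np.outer(g, g.conj())
    assert np.allclose(vis, 0.0)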
Example 9
 def end_time(self):
     return csd_to_unix(self._lsd + 1)
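
A matching start of day under the same convention would simply drop the + 1
(a sketch, not taken from the source):

    def start_time(self):
        # The sidereal day begins where csd_to_unix maps the bare LSD number
        return csd_to_unix(self._lsd)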
Example 10
 def from_lsd(cls, lsd: int):
     unix = csd_to_unix(lsd)
     lsd = int(lsd)
     date = unix_to_datetime(unix).date()
     day = cls(lsd, date)
     return day
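
from_lsd is presumably decorated with @classmethod in the source class (the
decorator appears to have been stripped by the extractor). Usage would look
like this, where the owning class name Day is hypothetical:

    day = Day.from_lsd(3000)
    day.date  # the calendar date on which CSD 3000 begins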
Example 11
    def _available_files(self, start_csd, end_csd):
        """
        Return chimestack files available in cedar_online between start_csd and end_csd, if all of the files for that period are available online.

        Return an empty list if files between start_csd and end_csd are only partially available online.

        Total file count is verified by checking files that exist everywhere.

        Parameters
        ----------
        start_csd : int
            Start date in sidereal day format
        end_csd : int
            End date in sidereal day format

        Returns
        -------
        list
            List contains the chimestack files available in the timespan, if all of them are available online

        """

        # Connect to databases
        db.connect()

        # Get timestamps in unix format
        # Needed for queries
        start_time = ephemeris.csd_to_unix(start_csd)
        end_time = ephemeris.csd_to_unix(end_csd)

        # We will want to know which files are in chime_online and nearline on cedar
        online_node = di.StorageNode.get(name="cedar_online", active=True)
        chimestack_inst = di.ArchiveInst.get(name="chimestack")

        # TODO: if the time range is so small that it is completely contained
        # within a single file, nothing will be returned; we would have to
        # special-case this by looking for files that start before the start
        # time and end after the end time.
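        # One possible fix (sketch only; the exact peewee expression is an
        # assumption modelled on the query built below): additionally accept
        # files that fully contain the requested span, e.g.
        #
        #   contained = archive_files.where(
        #       di.CorrFileInfo.start_time <= start_time,
        #       di.CorrFileInfo.finish_time >= end_time,
        #   )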

        archive_files = (di.ArchiveFileCopy.select(
            di.CorrFileInfo.start_time,
            di.CorrFileInfo.finish_time,
        ).join(di.ArchiveFile).join(di.ArchiveAcq).switch(di.ArchiveFile).join(
            di.CorrFileInfo))

        # chimestack files available online which include between start and end_time

        files_that_exist = archive_files.where(
            di.ArchiveAcq.inst ==
            chimestack_inst,  # specifically looking for chimestack files
            di.CorrFileInfo.start_time <
            end_time,  # which contain data that includes start time and end time
            di.CorrFileInfo.finish_time >= start_time,
            di.ArchiveFileCopy.has_file == "Y",
        )

        files_online = files_that_exist.where(
            di.ArchiveFileCopy.node == online_node,  # that are online
        )

        filenames_online = sorted([t for t in files_online.tuples()])

        # files_that_exist might contain the same file multiple times if it
        # exists in multiple locations (nearline, online, gossec, etc.); we
        # only want to include it once
        filenames_that_exist = sorted(
            list(set(t for t in files_that_exist.tuples())))

        return filenames_online, filenames_that_exist
Example No. 12
    def process(self, tstream):
        """Apply the timing correction to the input timestream.

        Parameters
        ----------
        tstream : andata.CorrData, containers.TimeStream, or containers.SiderealStream
            Apply the timing correction to the visibilities stored in this container.

        Returns
        -------
        tstream_corr : same as `tstream`
            Timestream with corrected visibilities.
        """
        # Determine times
        if "time" in tstream.index_map:
            timestamp = tstream.time
        else:
            csd = (tstream.attrs["lsd"]
                   if "lsd" in tstream.attrs else tstream.attrs["csd"])
            timestamp = ephemeris.csd_to_unix(csd + tstream.ra / 360.0)

        # Extract local frequencies
        tstream.redistribute("freq")

        nfreq = tstream.vis.local_shape[0]
        sfreq = tstream.vis.local_offset[0]
        efreq = sfreq + nfreq

        freq = tstream.freq[sfreq:efreq]

        # If requested, extract the input flags
        input_flags = tstream.input_flags[:] if self.use_input_flags else None

        # Check the needs_timing_correction flag time ranges to see if the
        # input timestream falls within one of the flagged intervals
        needs_timing_correction = False
        for flag in self.flags:
            if (timestamp[0] >= flag["start_time"]
                    and timestamp[0] <= flag["finish_time"]):
                if timestamp[-1] >= flag["finish_time"]:
                    raise PipelineRuntimeError(
                        f"Data covering {timestamp[0]} to {timestamp[-1]} partially overlaps "
                        f"needs_timing_correction DataFlag covering {ephemeris.unix_to_datetime(flag['start_time']).strftime('%Y%m%dT%H%M%SZ')} "
                        f"to {ephemeris.unix_to_datetime(flag['finish_time']).strftime('%Y%m%dT%H%M%SZ')}."
                    )
                else:
                    self.log.info(
                        f"Data covering {timestamp[0]} to {timestamp[-1]} flagged by "
                        f"needs_timing_correction DataFlag covering {ephemeris.unix_to_datetime(flag['start_time']).strftime('%Y%m%dT%H%M%SZ')} "
                        f"to {ephemeris.unix_to_datetime(flag['finish_time']).strftime('%Y%m%dT%H%M%SZ')}. Timing correction will be applied."
                    )
                    needs_timing_correction = True
                    break

        if not needs_timing_correction:
            self.log.info(
                f"Data in span {ephemeris.unix_to_datetime(timestamp[0]).strftime('%Y%m%dT%H%M%SZ')} to {ephemeris.unix_to_datetime(timestamp[-1]).strftime('%Y%m%dT%H%M%SZ')} does not need timing correction"
            )
            return tstream

        # Find the right timing correction
        for tcorr in self.tcorr:
            if (timestamp[0] >= tcorr.time[0]
                    and timestamp[-1] <= tcorr.time[-1]):
                break
        else:
            msg = (
                "Could not find timing correction file covering "
                "range of timestream data (%s to %s)." % tuple(
                    ephemeris.unix_to_datetime([timestamp[0], timestamp[-1]])))

            if self.pass_if_missing:
                self.log.warning(msg + " Doing nothing.")
                return tstream

            raise RuntimeError(msg)

        self.log.info("Using correction file %s" % tcorr.attrs["tag"])

        # If requested, reference the timing correction with respect to the source transit time
        if self.refer_to_transit:
            # First check for transit_time attribute in file
            ttrans = tstream.attrs.get("transit_time", None)
            if ttrans is None:
                source = tstream.attrs["source_name"]
                ttrans = ephemeris.transit_times(
                    ephemeris.source_dictionary[source],
                    tstream.time[0],
                    tstream.time[-1],
                )
                if ttrans.size != 1:
                    raise RuntimeError(
                        "Found %d transits of %s in timestream.  "
                        "Require single transit." % (ttrans.size, source))
                else:
                    ttrans = ttrans[0]

            self.log.info(
                "Referencing timing correction to %s (RA=%0.1f deg)." % (
                    ephemeris.unix_to_datetime(ttrans).strftime(
                        "%Y%m%dT%H%M%SZ"),
                    ephemeris.lsa(ttrans),
                ))

            tcorr.set_global_reference_time(ttrans,
                                            window=self.transit_window,
                                            interpolate=True,
                                            interp="linear")

        # Apply the timing correction
        tcorr.apply_timing_correction(tstream,
                                      time=timestamp,
                                      freq=freq,
                                      input_flags=input_flags,
                                      copy=False)

        return tstream
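
The flag check in this task enforces an all-or-nothing rule: a timestream
that starts inside a needs_timing_correction interval must also end inside
it, and a partial overlap is an error. A minimal standalone sketch of that
rule (illustrative numbers, not pipeline code):

def overlap_state(data_start, data_end, flag_start, flag_end):
    # Mirrors the logic above: only triggered when the data starts inside
    # the flagged interval.
    if flag_start <= data_start <= flag_end:
        return "partial" if data_end >= flag_end else "inside"
    return "outside"

print(overlap_state(10, 20, 0, 30))  # "inside": apply the correction
print(overlap_state(10, 40, 0, 30))  # "partial": raises in the task above
print(overlap_state(40, 50, 0, 30))  # "outside": try the next flag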
Example No. 13
    def view(self):
        if self.lsd is None:
            return panel.pane.Markdown("No data selected.")
        try:
            sens_container = self.data.load_file(self.revision, self.lsd,
                                                 "sensitivity")
        except DataError as err:
            return panel.pane.Markdown(
                f"Error: {str(err)}. Please report this problem.")

        # Index map for ra (x-axis)
        sens_csd = csd(sens_container.time)
        index_map_ra = (sens_csd - self.lsd.lsd) * 360
        axis_name_ra = "RA [degrees]"

        # Index map for frequency (y-axis)
        index_map_f = np.linspace(800.0, 400.0, 1024, endpoint=False)
        axis_name_f = "Frequency [MHz]"

        # Apply data selections
        if self.polarization == self.mean_pol_text:
            sel_pol = np.where((sens_container.index_map["pol"] == "XX")
                               | (sens_container.index_map["pol"] == "YY"))[0]
            sens = np.squeeze(sens_container.measured[:, sel_pol])
            sens = np.squeeze(np.nanmean(sens, axis=1))
        else:
            sel_pol = np.where(
                sens_container.index_map["pol"] == self.polarization)[0]
            sens = np.squeeze(sens_container.measured[:, sel_pol])

        if self.flag_mask:
            sens = np.where(self._flags_mask(index_map_ra).T, np.nan, sens)

        # Set flagged data to nan
        sens = np.where(sens == 0, np.nan, sens)

        if self.mask_rfi:
            try:
                rfi_container = self.data.load_file(self.revision, self.lsd,
                                                    "rfi")
            except DataError as err:
                return panel.pane.Markdown(
                    f"Error: {str(err)}. Please report this problem.")
            rfi = np.squeeze(rfi_container.mask[:])

            # calculate percentage masked to print later
            rfi_percentage = round(np.count_nonzero(rfi) / rfi.size * 100)

            sens *= np.where(rfi, np.nan, 1)

        if self.divide_by_estimate:
            estimate = np.squeeze(sens_container.radiometer[:, sel_pol])
            if self.polarization == self.mean_pol_text:
                estimate = np.squeeze(np.nanmean(estimate, axis=1))
            estimate = np.where(estimate == 0, np.nan, estimate)
            sens = sens / estimate

        if self.transpose:
            sens = sens.T
            index_x = index_map_f
            index_y = index_map_ra
            axis_names = [axis_name_f, axis_name_ra]
            xlim, ylim = self.ylim, self.xlim
        else:
            index_x = index_map_ra
            index_y = index_map_f
            axis_names = [axis_name_ra, axis_name_f]
            xlim, ylim = self.xlim, self.ylim

        image_opts = {
            "clim": self.colormap_range,
            "logz": self.logarithmic_colorscale,
            "cmap": process_cmap("viridis", provider="matplotlib"),
            "colorbar": True,
            "xticks": [0, 60, 120, 180, 240, 300, 360],
        }
        if self.mask_rfi:
            image_opts["title"] = f"RFI mask: {rfi_percentage}%"

        overlay_opts = {
            "xlim": xlim,
            "ylim": ylim,
        }

        # Fill in missing data
        img = hv_image_with_gaps(index_x,
                                 index_y,
                                 sens,
                                 opts=image_opts,
                                 kdims=axis_names).opts(**overlay_opts)

        if self.serverside_rendering is not None:
            # Set the colormap (viridis, matching the static image above)
            cmap = copy.copy(matplotlib_cm.get_cmap("viridis"))

            # Set z-axis normalization (other possible values are 'eq_hist', 'cbrt').
            if self.logarithmic_colorscale:
                normalization = "log"
            else:
                normalization = "linear"

            # datashade/rasterize the image
            img = self.serverside_rendering(
                img,
                cmap=cmap,
                precompute=True,
                x_range=xlim,
                y_range=ylim,
                normalization=normalization,
                # TODO: set xticks like above
            )

        if self.mark_day_time:
            # Calculate the sun rise/set times on this sidereal day

            # Start and end times of the CSD
            start_time = csd_to_unix(self.lsd.lsd)
            end_time = csd_to_unix(self.lsd.lsd + 1)

            times, rises = chime.rise_set_times(
                skyfield_wrapper.ephemeris["sun"],
                start_time,
                end_time,
                diameter=-10,
            )
            sun_rise = 0
            sun_set = 0
            for t, r in zip(times, rises):
                if r:
                    sun_rise = (unix_to_csd(t) % 1) * 360
                else:
                    sun_set = (unix_to_csd(t) % 1) * 360

            # Highlight the day time data
            opts = {
                "color": "grey",
                "alpha": 0.5,
                "line_width": 1,
                "line_color": "black",
                "line_dash": "dashed",
            }

            span = hv.HSpan if self.transpose else hv.VSpan
            if sun_rise < sun_set:
                img *= span(sun_rise, sun_set).opts(**opts)
            else:
                img *= span(self.xlim[0], sun_set).opts(**opts)
                img *= span(sun_rise, self.xlim[-1]).opts(**opts)

        img.opts(
            # Fix height, but make width responsive
            height=self.height,
            responsive=True,
            bgcolor="lightgray",
            shared_axes=True,
        )

        return panel.Row(img, width_policy="max")
Example No. 14
    def process(self, sstream, inputmap):
        """Clean the sun.

        Parameters
        ----------
        sstream : containers.SiderealStream
            Sidereal stream.
        inputmap : list of :class:`CorrInput`
            A list describing the inputs as they are in the file.

        Returns
        -------
        mstream : containers.SiderealStream
            Sidereal stream with sun projected out.
        """

        sstream.redistribute("freq")

        # Get array of CSDs for each sample
        ra = sstream.index_map["ra"][:]
        csd = (sstream.attrs["lsd"]
               if "lsd" in sstream.attrs else sstream.attrs["csd"])
        csd = csd + ra / 360.0

        nprod = len(sstream.index_map["prod"])

        # Get position of sun at every time sample
        times = ephemeris.csd_to_unix(csd)
        sun_pos = np.array([
            ra_dec_of(ephemeris.skyfield_wrapper.ephemeris["sun"], t)
            for t in times
        ])

        # Get hour angle and dec of sun, in radians
        ha = 2 * np.pi * (ra / 360.0) - sun_pos[:, 0]
        dec = sun_pos[:, 1]
        el = sun_pos[:, 2]

        # Construct baseline vector for each visibility
        feed_pos = tools.get_feed_positions(inputmap)
        vis_pos = np.array([
            feed_pos[ii] - feed_pos[ij]
            for ii, ij in sstream.index_map["prod"][:]
        ])

        feed_list = [(inputmap[fi], inputmap[fj])
                     for fi, fj in sstream.index_map["prod"][:]]

        # Determine polarisation for each visibility
        pol_ind = np.full(len(feed_list), -1, dtype=int)
        for ii, (fi, fj) in enumerate(feed_list):
            if tools.is_chime(fi) and tools.is_chime(fj):
                pol_ind[ii] = 2 * tools.is_chime_y(fi) + tools.is_chime_y(fj)

        # Change vis_pos for non-CHIME feeds from NaN to 0.0
        vis_pos[(pol_ind == -1), :] = 0.0

        # Initialise new container
        sscut = sstream.__class__(axes_from=sstream, attrs_from=sstream)
        sscut.redistribute("freq")

        wv = 3e2 / sstream.index_map["freq"]["centre"]

        # Iterate over frequencies and polarisations to null out the sun
        for lfi, fi in sstream.vis[:].enumerate(0):

            # Get the baselines in wavelengths
            u = vis_pos[:, 0] / wv[fi]
            v = vis_pos[:, 1] / wv[fi]

            # Loop over ra to reduce memory usage
            for ri in range(len(ra)):

                # Copy over the visibilities and weights
                vis = sstream.vis[fi, :, ri]
                weight = sstream.weight[fi, :, ri]
                sscut.vis[fi, :, ri] = vis
                sscut.weight[fi, :, ri] = weight

                # Check if sun has set
                if el[ri] > 0.0:

                    # Calculate the phase that the sun would have using the fringestop routine
                    sun_vis = tools.fringestop_phase(
                        ha[ri], np.radians(ephemeris.CHIMELATITUDE), dec[ri],
                        u, v)

                    # Calculate the visibility vector for the sun
                    sun_vis = sun_vis.conj()

                    # Mask out the auto-correlations
                    sun_vis *= np.logical_or(u != 0.0, v != 0.0)

                    # Iterate over polarisations to do projection independently for each.
                    # This is needed because of the different beams for each pol.
                    for pol in range(4):

                        # Mask out other polarisations in the visibility vector
                        sun_vis_pol = sun_vis * (pol_ind == pol)

                        # Calculate various projections
                        vds = (vis * sun_vis_pol.conj() * weight).sum(axis=0)
                        sds = (sun_vis_pol * sun_vis_pol.conj() *
                               weight).sum(axis=0)
                        isds = tools.invert_no_zero(sds)

                        # Subtract sun contribution from visibilities and place in new array
                        sscut.vis[fi, :, ri] -= sun_vis_pol * vds * isds

        # Return the clean sidereal stream
        return sscut
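
The per-polarisation loop above is a weighted projection: each visibility
vector v is replaced by v - s * (s^H W v) / (s^H W s), which nulls the
component along the solar model vector s. A self-contained numpy sketch
(illustrative values, not pipeline code):

import numpy as np

rng = np.random.default_rng(1)
n = 64
s = np.exp(2j * np.pi * rng.random(n))   # model (sun) visibility vector
w = rng.random(n)                        # per-baseline weights
v = 3.0 * s + 0.1 * (rng.standard_normal(n) + 1j * rng.standard_normal(n))

vds = np.sum(v * s.conj() * w)           # s^H W v
sds = np.sum(s * s.conj() * w)           # s^H W s
v_clean = v - s * (vds / sds)

# The weighted inner product of the residual with s is ~0.
print(abs(np.sum(v_clean * s.conj() * w)))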
Example No. 15
    def setup(self):
        """Query the database and fetch the files

        Returns
        -------
        files : list
            List of files to load
        """
        files = None

        # Query the database on rank=0 only, and broadcast to everywhere else
        if mpiutil.rank0:

            if self.run_name:
                return self.QueryRun()

            layout.connect_database()

            f = finder.Finder(node_spoof=self.node_spoof)

            f.filter_acqs(di.AcqType.name == self.acqtype)

            if self.instrument is not None:
                f.filter_acqs(di.ArchiveInst.name == self.instrument)

            if self.accept_all_global_flags:
                f.accept_all_global_flags()

            # Use start and end times if set, or try and use the start and end CSDs
            if self.start_time:
                st, et = self.start_time, self.end_time
            elif self.start_csd:
                st = ephemeris.csd_to_unix(self.start_csd)
                et = (
                    ephemeris.csd_to_unix(self.end_csd)
                    if self.end_csd is not None
                    else None
                )
            else:
                # Avoid a NameError below if neither a time nor a CSD range is set
                st, et = None, None

            # Note: include_time_interval includes the specified time interval
            # Using this instead of set_time_range, which only narrows the interval
            # f.include_time_interval(self.start_time, self.end_time)
            f.set_time_range(st, et)

            if self.start_RA and self.end_RA:
                f.include_RA_interval(self.start_RA, self.end_RA)
            elif self.start_RA or self.end_RA:
                self.log.warning(
                    "One but not both of start_RA and end_RA are set. "
                    "Ignoring both."
                )

            if self.exclude_daytime:
                f.exclude_daytime()

            if self.exclude_sun:
                f.exclude_sun(
                    time_delta=self.exclude_sun_time_delta,
                    time_delta_rise_set=self.exclude_sun_time_delta_rise_set,
                )

            if self.include_transits:
                time_delta = self.include_transits_time_delta
                ntime_delta = len(time_delta)
                if (ntime_delta > 1) and (ntime_delta < len(self.include_transits)):
                    raise ValueError(
                        "Must specify `time_delta` for each source in "
                        "`include_transits` or provide single value for all sources."
                    )
                for ss, src in enumerate(self.include_transits):
                    tdelta = time_delta[ss % ntime_delta] if ntime_delta > 0 else None
                    bdy = (
                        ephemeris.source_dictionary[src]
                        if isinstance(src, str)
                        else src
                    )
                    f.include_transits(bdy, time_delta=tdelta)

            if self.exclude_transits:
                time_delta = self.exclude_transits_time_delta
                ntime_delta = len(time_delta)
                if (ntime_delta > 1) and (ntime_delta < len(self.exclude_transits)):
                    raise ValueError(
                        "Must specify `time_delta` for each source in "
                        "`exclude_transits` or provide single value for all sources."
                    )
                for ss, src in enumerate(self.exclude_transits):
                    tdelta = time_delta[ss % ntime_delta] if ntime_delta > 0 else None
                    bdy = (
                        ephemeris.source_dictionary[src]
                        if isinstance(src, str)
                        else src
                    )
                    f.exclude_transits(bdy, time_delta=tdelta)

            if self.source_26m:
                f.include_26m_obs(self.source_26m)

            if len(self.exclude_data_flag_types) > 0:
                f.exclude_data_flag_type(self.exclude_data_flag_types)

            results = f.get_results()
            if not self.return_intervals:
                files = [fname for result in results for fname in result[0]]
                files.sort()
            else:
                files = results
                files.sort(key=lambda x: x[1][0])

        files = mpiutil.world.bcast(files, root=0)

        # Make sure all nodes have container before return
        mpiutil.world.Barrier()

        return files
Example No. 16
def offline_point_source_calibration(file_list,
                                     source,
                                     inputmap=None,
                                     start=None,
                                     stop=None,
                                     physical_freq=None,
                                     tcorr=None,
                                     logging_params=DEFAULT_LOGGING,
                                     **kwargs):
    # Load config
    config = DEFAULTS.deepcopy()
    config.merge(NameSpace(kwargs))

    # Setup logging
    log.setup_logging(logging_params)
    mlog = log.get_logger(__name__)

    mlog.info("ephemeris file: %s" % ephemeris.__file__)

    # Set the model to use
    fitter_function = utils.fit_point_source_transit
    model_function = utils.model_point_source_transit

    farg = inspect.getfullargspec(fitter_function)
    defaults = {
        key: val
        for key, val in zip(farg.args[-len(farg.defaults):], farg.defaults)
    }
    poly_deg_amp = kwargs.get('poly_deg_amp', defaults['poly_deg_amp'])
    poly_deg_phi = kwargs.get('poly_deg_phi', defaults['poly_deg_phi'])
    poly_type = kwargs.get('poly_type', defaults['poly_type'])

    param_name = ([
        '%s_poly_amp_coeff%d' % (poly_type, cc)
        for cc in range(poly_deg_amp + 1)
    ] + [
        '%s_poly_phi_coeff%d' % (poly_type, cc)
        for cc in range(poly_deg_phi + 1)
    ])

    model_kwargs = [('poly_deg_amp', poly_deg_amp),
                    ('poly_deg_phi', poly_deg_phi), ('poly_type', poly_type)]
    model_name = '.'.join(
        [getattr(model_function, key) for key in ['__module__', '__name__']])

    tval = {}

    # Set where to evaluate gain
    ha_eval_str = ['raw_transit']

    if config.multi_sample:
        ha_eval_str += ['transit', 'peak']
        ha_eval = [0.0, None]
        fitslc = slice(1, 3)

    ind_eval = ha_eval_str.index(config.evaluate_gain_at)

    # Determine dimensions
    direction = ['amp', 'phi']
    nparam = len(param_name)
    ngain = len(ha_eval_str)
    ndir = len(direction)

    # Determine frequencies
    data = andata.CorrData.from_acq_h5(file_list,
                                       datasets=(),
                                       start=start,
                                       stop=stop)
    freq = data.freq

    if physical_freq is not None:
        index_freq = np.array(
            [np.argmin(np.abs(ff - freq)) for ff in physical_freq])
        freq_sel = utils.convert_to_slice(index_freq)
        freq = freq[index_freq]
    else:
        index_freq = np.arange(freq.size)
        freq_sel = None

    nfreq = freq.size

    # Compute flux of source
    inv_rt_flux_density = tools.invert_no_zero(
        np.sqrt(FluxCatalog[source].predict_flux(freq)))

    # Read in the eigenvalues for all frequencies
    data = andata.CorrData.from_acq_h5(file_list,
                                       datasets=['erms', 'eval'],
                                       freq_sel=freq_sel,
                                       start=start,
                                       stop=stop)

    # Determine source coordinates
    this_csd = np.floor(ephemeris.unix_to_csd(np.median(data.time)))
    timestamp0 = ephemeris.transit_times(FluxCatalog[source].skyfield,
                                         ephemeris.csd_to_unix(this_csd))[0]
    src_ra, src_dec = ephemeris.object_coords(FluxCatalog[source].skyfield,
                                              date=timestamp0,
                                              deg=True)

    ra = ephemeris.lsa(data.time)
    ha = ra - src_ra
    ha = ha - (ha > 180.0) * 360.0 + (ha < -180.0) * 360.0
    ha = np.radians(ha)

    itrans = np.argmin(np.abs(ha))

    window = 0.75 * np.max(np.abs(ha))

    off_source = np.abs(ha) > window

    mlog.info("CSD %d" % this_csd)
    mlog.info("Hour angle at transit (%d of %d):  %0.2f deg   " %
              (itrans, len(ha), np.degrees(ha[itrans])))
    mlog.info("Hour angle off source: %0.2f deg" %
              np.median(np.abs(np.degrees(ha[off_source]))))

    src_dec = np.radians(src_dec)
    lat = np.radians(ephemeris.CHIMELATITUDE)

    # Determine division of frequencies
    ninput = data.ninput
    ntime = data.ntime
    nblock_freq = int(np.ceil(nfreq / float(config.nfreq_per_block)))

    # Determine bad inputs
    eps = 10.0 * np.finfo(data['erms'].dtype).eps
    good_freq = np.flatnonzero(np.all(data['erms'][:] > eps, axis=-1))
    ind_sub_freq = good_freq[slice(0, good_freq.size,
                                   max(int(good_freq.size / 10), 1))]

    tmp_data = andata.CorrData.from_acq_h5(file_list,
                                           datasets=['evec'],
                                           freq_sel=ind_sub_freq,
                                           start=start,
                                           stop=stop)
    eps = 10.0 * np.finfo(tmp_data['evec'].dtype).eps
    bad_input = np.flatnonzero(
        np.all(np.abs(tmp_data['evec'][:, 0]) < eps, axis=(0, 2)))

    input_axis = tmp_data.input.copy()

    del tmp_data

    # Query layout database for correlator inputs
    if inputmap is None:
        inputmap = tools.get_correlator_inputs(
            datetime.datetime.utcfromtimestamp(data.time[itrans]),
            correlator='chime')

    inputmap = tools.reorder_correlator_inputs(input_axis, inputmap)

    tools.change_chime_location(rotation=config.telescope_rotation)

    # Determine x and y pol index
    xfeeds = np.array([
        idf for idf, inp in enumerate(inputmap)
        if (idf not in bad_input) and tools.is_array_x(inp)
    ])
    yfeeds = np.array([
        idf for idf, inp in enumerate(inputmap)
        if (idf not in bad_input) and tools.is_array_y(inp)
    ])

    nfeed = xfeeds.size + yfeeds.size

    pol = [yfeeds, xfeeds]
    polstr = ['Y', 'X']
    npol = len(pol)

    neigen = min(max(npol, config.neigen), data['eval'].shape[1])

    phase_ref = config.phase_reference_index
    phase_ref_by_pol = [
        pol[pp].tolist().index(phase_ref[pp]) for pp in range(npol)
    ]

    # Calculate dynamic range
    eval0_off_source = np.median(data['eval'][:, 0, off_source], axis=-1)

    dyn = data['eval'][:, 1, :] * tools.invert_no_zero(
        eval0_off_source[:, np.newaxis])

    # Determine frequencies to mask
    not_rfi = np.ones((nfreq, 1), dtype=bool)
    if config.mask_rfi is not None:
        for frng in config.mask_rfi:
            not_rfi[:, 0] &= ((freq < frng[0]) | (freq > frng[1]))

    mlog.info("%0.1f percent of frequencies available after masking RFI." %
              (100.0 * np.sum(not_rfi, dtype=np.float32) / float(nfreq), ))

    #dyn_flg = utils.contiguous_flag(dyn > config.dyn_rng_threshold, centre=itrans)
    if source in config.dyn_rng_threshold:
        dyn_rng_threshold = config.dyn_rng_threshold[source]
    else:
        dyn_rng_threshold = config.dyn_rng_threshold.default

    mlog.info("Dynamic range threshold set to %0.1f." % dyn_rng_threshold)

    dyn_flg = dyn > dyn_rng_threshold

    # Calculate fit flag
    fit_flag = np.zeros((nfreq, npol, ntime), dtype=bool)
    for pp in range(npol):

        mlog.info("Dynamic Range Nsample, Pol %d:  %s" % (pp, ','.join([
            "%d" % xx for xx in np.percentile(np.sum(dyn_flg, axis=-1),
                                              [25, 50, 75, 100])
        ])))

        if config.nsigma1 is None:
            fit_flag[:, pp, :] = dyn_flg & not_rfi

        else:

            fit_window = config.nsigma1 * np.radians(
                utils.get_window(freq, pol=polstr[pp], dec=src_dec, deg=True))

            win_flg = np.abs(ha)[np.newaxis, :] <= fit_window[:, np.newaxis]

            fit_flag[:, pp, :] = (dyn_flg & win_flg & not_rfi)

    # Calculate base error
    base_err = data['erms'][:, np.newaxis, :]

    # Check for sign flips
    ref_resp = andata.CorrData.from_acq_h5(file_list,
                                           datasets=['evec'],
                                           input_sel=config.eigen_reference,
                                           freq_sel=freq_sel,
                                           start=start,
                                           stop=stop)['evec'][:, 0:neigen,
                                                              0, :]

    sign0 = 1.0 - 2.0 * (ref_resp.real < 0.0)

    # Check that we have the correct reference feed
    if np.any(np.abs(ref_resp.imag) > 0.0):
        raise ValueError("Reference feed %d is incorrect." %
                         config.eigen_reference)

    del ref_resp

    # Save index_map
    results = {}
    results['model'] = model_name
    results['param'] = param_name
    results['freq'] = data.index_map['freq'][:]
    results['input'] = input_axis
    results['eval'] = ha_eval_str
    results['dir'] = direction

    for key, val in model_kwargs:
        results[key] = val

    # Initialize numpy arrays to hold results
    if config.return_response:

        results['response'] = np.zeros((nfreq, ninput, ntime),
                                       dtype=np.complex64)
        results['response_err'] = np.zeros((nfreq, ninput, ntime),
                                           dtype=np.float32)
        results['fit_flag'] = fit_flag
        results['ha_axis'] = ha
        results['ra'] = ra

    else:

        results['gain_eval'] = np.zeros((nfreq, ninput, ngain),
                                        dtype=np.complex64)
        results['weight_eval'] = np.zeros((nfreq, ninput, ngain),
                                          dtype=np.float32)
        results['frac_gain_err'] = np.zeros((nfreq, ninput, ngain, ndir),
                                            dtype=np.float32)

        results['parameter'] = np.zeros((nfreq, ninput, nparam),
                                        dtype=np.float32)
        results['parameter_err'] = np.zeros((nfreq, ninput, nparam),
                                            dtype=np.float32)

        results['index_eval'] = np.full((nfreq, ninput), -1, dtype=np.int8)
        results['gain'] = np.zeros((nfreq, ninput), dtype=np.complex64)
        results['weight'] = np.zeros((nfreq, ninput), dtype=np.float32)

        results['ndof'] = np.zeros((nfreq, ninput, ndir), dtype=np.float32)
        results['chisq'] = np.zeros((nfreq, ninput, ndir), dtype=np.float32)

        results['timing'] = np.zeros((nfreq, ninput), dtype=np.complex64)

    # Initialize metric like variables
    results['runtime'] = np.zeros((nblock_freq, 2), dtype=np.float64)

    # Compute distances
    dist = tools.get_feed_positions(inputmap)
    for pp, feeds in enumerate(pol):
        dist[feeds, :] -= dist[phase_ref[pp], np.newaxis, :]

    # Loop over frequency blocks
    for gg in range(nblock_freq):

        mlog.info("Frequency block %d of %d." % (gg, nblock_freq))

        fstart = gg * config.nfreq_per_block
        fstop = min((gg + 1) * config.nfreq_per_block, nfreq)
        findex = np.arange(fstart, fstop)
        ngroup = findex.size

        freq_sel = utils.convert_to_slice(index_freq[findex])

        timeit_start_gg = time.time()

        # Determine the sample range to load for this frequency block
        if config.return_response:
            gstart = start
            gstop = stop

            tslc = slice(0, ntime)

        else:
            good_times = np.flatnonzero(np.any(fit_flag[findex], axis=(0, 1)))

            if good_times.size == 0:
                continue

            gstart = int(np.min(good_times))
            gstop = int(np.max(good_times)) + 1

            tslc = slice(gstart, gstop)

            gstart += start
            gstop += start

        hag = ha[tslc]
        itrans = np.argmin(np.abs(hag))

        # Load eigenvectors.
        nudata = andata.CorrData.from_acq_h5(
            file_list,
            datasets=['evec', 'vis', 'flags/vis_weight'],
            apply_gain=False,
            freq_sel=freq_sel,
            start=gstart,
            stop=gstop)

        # Save time to load data
        results['runtime'][gg, 0] = time.time() - timeit_start_gg
        timeit_start_gg = time.time()

        mlog.info("Time to load (per frequency):  %0.3f sec" %
                  (results['runtime'][gg, 0] / ngroup, ))

        # Loop over polarizations
        for pp, feeds in enumerate(pol):

            # Get timing correction
            if tcorr is not None:
                tgain = tcorr.get_gain(nudata.freq, nudata.input[feeds],
                                       nudata.time)
                tgain *= tgain[:, phase_ref_by_pol[pp], np.newaxis, :].conj()

                tgain_transit = tgain[:, :, itrans].copy()
                tgain *= tgain_transit[:, :, np.newaxis].conj()

            # Create the polarization masking vector
            P = np.zeros((1, ninput, 1), dtype=np.float64)
            P[:, feeds, :] = 1.0

            # Loop over frequencies
            for gff, ff in enumerate(findex):

                flg = fit_flag[ff, pp, tslc]

                if (2 * int(np.sum(flg))) < (nparam +
                                             1) and not config.return_response:
                    continue

                # Normalize by eigenvalue and correct for pi phase flips in process.
                resp = (nudata['evec'][gff, 0:neigen, :, :] *
                        np.sqrt(data['eval'][ff, 0:neigen, np.newaxis, tslc]) *
                        sign0[ff, :, np.newaxis, tslc])

                # Rotate to single-pol response
                # Move time to first axis for the matrix multiplication
                invL = tools.invert_no_zero(
                    np.rollaxis(data['eval'][ff, 0:neigen, np.newaxis, tslc],
                                -1, 0))

                UT = np.rollaxis(resp, -1, 0)
                U = np.swapaxes(UT, -1, -2)

                mu, vp = np.linalg.eigh(np.matmul(UT.conj(), P * U))

                rsign0 = (1.0 - 2.0 * (vp[:, 0, np.newaxis, :].real < 0.0))

                resp = mu[:, np.newaxis, :] * np.matmul(U, rsign0 * vp * invL)

                # Extract feeds of this pol
                # Transpose so that time is back to last axis
                resp = resp[:, feeds, -1].T

                # Compute error on response
                dataflg = ((nudata.weight[gff, feeds, :] > 0.0)
                           & np.isfinite(nudata.weight[gff, feeds, :])).astype(
                               np.float32)

                resp_err = dataflg * base_err[ff, :, tslc] * np.sqrt(
                    nudata.vis[gff, feeds, :].real) * tools.invert_no_zero(
                        np.sqrt(mu[np.newaxis, :, -1]))

                # Reference to specific input
                resp *= np.exp(
                    -1.0J *
                    np.angle(resp[phase_ref_by_pol[pp], np.newaxis, :]))

                # Apply timing correction
                if tcorr is not None:
                    resp *= tgain[gff]

                    results['timing'][ff, feeds] = tgain_transit[gff]

                # Fringestop
                lmbda = scipy.constants.c * 1e-6 / nudata.freq[gff]

                resp *= tools.fringestop_phase(
                    hag[np.newaxis, :], lat, src_dec,
                    dist[feeds, 0, np.newaxis] / lmbda,
                    dist[feeds, 1, np.newaxis] / lmbda)

                # Normalize by source flux
                resp *= inv_rt_flux_density[ff]
                resp_err *= inv_rt_flux_density[ff]

                # If requested, reference phase to the median value
                if config.med_phase_ref:
                    phi0 = np.angle(resp[:, itrans, np.newaxis])
                    resp *= np.exp(-1.0J * phi0)
                    resp *= np.exp(
                        -1.0J *
                        np.median(np.angle(resp), axis=0, keepdims=True))
                    resp *= np.exp(1.0J * phi0)

                # Check if return_response flag was set by user
                if not config.return_response:

                    if config.multi_sample:
                        moving_window = config.nsigma2 and config.nsigma2 * np.radians(
                            utils.get_window(nudata.freq[gff],
                                             pol=polstr[pp],
                                             dec=src_dec,
                                             deg=True))

                    # Loop over inputs
                    for pii, ii in enumerate(feeds):

                        is_good = flg & (np.abs(resp[pii, :]) >
                                         0.0) & (resp_err[pii, :] > 0.0)

                        # Set the initial gains based on raw response at transit
                        if is_good[itrans]:
                            results['gain_eval'][ff, ii, 0] = (
                                tools.invert_no_zero(resp[pii, itrans]))
                            results['frac_gain_err'][ff, ii, 0, :] = (
                                resp_err[pii, itrans] * tools.invert_no_zero(
                                    np.abs(resp[pii, itrans])))
                            results['weight_eval'][ff, ii, 0] = 0.5 * (
                                np.abs(resp[pii, itrans])**2 *
                                tools.invert_no_zero(resp_err[pii, itrans]))**2

                            results['index_eval'][ff, ii] = 0
                            results['gain'][ff, ii] = results['gain_eval'][ff, ii, 0]
                            results['weight'][ff, ii] = results['weight_eval'][ff, ii, 0]

                        # Exit if not performing multi time sample fit
                        if not config.multi_sample:
                            continue

                        if (2 * int(np.sum(is_good))) < (nparam + 1):
                            continue

                        try:
                            param, param_err, gain, gain_err, ndof, chisq, tval = fitter_function(
                                hag[is_good],
                                resp[pii, is_good],
                                resp_err[pii, is_good],
                                ha_eval,
                                window=moving_window,
                                tval=tval,
                                **config.fit)
                        except Exception as rex:
                            if config.verbose:
                                mlog.info(
                                    "Frequency %0.2f, Feed %d failed with error: %s"
                                    % (nudata.freq[gff], ii, rex))
                            continue

                        # Check for nan
                        wfit = (np.abs(gain) *
                                tools.invert_no_zero(np.abs(gain_err)))**2
                        if np.any(~np.isfinite(np.abs(gain))) or np.any(
                                ~np.isfinite(wfit)):
                            continue

                        # Save to results using the convention that you should *multiply* the visibilities by the gains
                        results['gain_eval'][
                            ff, ii, fitslc] = tools.invert_no_zero(gain)
                        results['frac_gain_err'][ff, ii, fitslc,
                                                 0] = gain_err.real
                        results['frac_gain_err'][ff, ii, fitslc,
                                                 1] = gain_err.imag
                        results['weight_eval'][ff, ii, fitslc] = wfit

                        results['parameter'][ff, ii, :] = param
                        results['parameter_err'][ff, ii, :] = param_err

                        results['ndof'][ff, ii, :] = ndof
                        results['chisq'][ff, ii, :] = chisq

                        # Check if the fit was successful and update the gain evaluation index appropriately
                        if np.all((chisq / ndof.astype(np.float32)
                                   ) <= config.chisq_per_dof_threshold):
                            results['index_eval'][ff, ii] = ind_eval
                            results['gain'][ff, ii] = results['gain_eval'][
                                ff, ii, ind_eval]
                            results['weight'][ff, ii] = results['weight_eval'][
                                ff, ii, ind_eval]

                else:

                    # Return response only (do not fit model)
                    results['response'][ff, feeds, :] = resp
                    results['response_err'][ff, feeds, :] = resp_err

        # Save time to fit data
        results['runtime'][gg, 1] = time.time() - timeit_start_gg

        mlog.info("Time to fit (per frequency):  %0.3f sec" %
                  (results['runtime'][gg, 1] / ngroup, ))

        # Clean up
        del nudata
        gc.collect()

    # Print total run time
    mlog.info("TOTAL TIME TO LOAD: %0.3f min" %
              (np.sum(results['runtime'][:, 0]) / 60.0, ))
    mlog.info("TOTAL TIME TO FIT:  %0.3f min" %
              (np.sum(results['runtime'][:, 1]) / 60.0, ))

    # Set the best estimate of the gain
    if not config.return_response:

        flag = results['index_eval'] >= 0
        gain = results['gain']

        # Compute amplitude
        amp = np.abs(gain)

        # Hard cutoffs on the amplitude
        med_amp = np.median(amp[flag])
        min_amp = med_amp * config.min_amp_scale_factor
        max_amp = med_amp * config.max_amp_scale_factor

        flag &= ((amp >= min_amp) & (amp <= max_amp))

        # Flag outliers in amplitude for each frequency
        for pp, feeds in enumerate(pol):

            med_amp_by_pol = np.zeros(nfreq, dtype=np.float32)
            sig_amp_by_pol = np.zeros(nfreq, dtype=np.float32)

            for ff in range(nfreq):

                this_flag = flag[ff, feeds]

                if np.any(this_flag):

                    med, slow, shigh = utils.estimate_directional_scale(
                        amp[ff, feeds[this_flag]])
                    lower = med - config.nsigma_outlier * slow
                    upper = med + config.nsigma_outlier * shigh

                    flag[ff, feeds] &= ((amp[ff, feeds] >= lower) &
                                        (amp[ff, feeds] <= upper))

                    med_amp_by_pol[ff] = med
                    sig_amp_by_pol[ff] = 0.5 * (shigh - slow) / np.sqrt(
                        np.sum(this_flag, dtype=np.float32))

            if config.nsigma_med_outlier:

                med_flag = med_amp_by_pol > 0.0

                not_outlier = flag_outliers(med_amp_by_pol,
                                            med_flag,
                                            window=config.window_med_outlier,
                                            nsigma=config.nsigma_med_outlier)
                flag[:, feeds] &= not_outlier[:, np.newaxis]

                mlog.info("Pol %s:  %d frequencies are outliers." %
                          (polstr[pp],
                           np.sum(~not_outlier & med_flag, dtype=int)))

        # Determine bad frequencies
        flag_freq = (np.sum(flag, axis=1, dtype=np.float32) /
                     float(ninput)) > config.threshold_good_freq
        good_freq = np.flatnonzero(flag_freq)

        # Determine bad inputs
        fraction_good = np.sum(flag[good_freq, :], axis=0,
                               dtype=np.float32) / float(good_freq.size)
        flag_input = fraction_good > config.threshold_good_input

        # Finalize flag
        flag &= (flag_freq[:, np.newaxis] & flag_input[np.newaxis, :])

        # Interpolate gains
        interp_gain, interp_weight = interpolate_gain(
            freq,
            gain,
            results['weight'],
            flag=flag,
            length_scale=config.interpolation_length_scale,
            mlog=mlog)
        # Save gains to object
        results['flag'] = flag
        results['gain'] = interp_gain
        results['weight'] = interp_weight

    # Return results
    return results
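
As noted in the comments, the returned gains follow the convention that
calibrated data are obtained by *multiplying* the visibilities by the gains,
vis_cal_ij = g_i * conj(g_j) * vis_ij. A short numpy sketch of that
convention (illustrative values, not pipeline code):

import numpy as np

rng = np.random.default_rng(2)
ninput = 3
true_gain = rng.standard_normal(ninput) + 1j * rng.standard_normal(ninput)

# Uncalibrated visibilities of a unit sky: v_ij = (1/g_i) * conj(1/g_j).
inv = 1.0 / true_gain
vis = inv[:, None] * inv[None, :].conj()

cal = true_gain[:, None] * true_gain[None, :].conj() * vis
print(np.allclose(cal, 1.0))  # True: multiplying by the gains restores the sky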
Example No. 17
    def _create_hook(self):
        """Create the revision.

        This tries to determine which days are good and bad, and partitions the
        available good days into the individual stacks.
        """

        days = {}

        core.connect()

        # Go over each revision and construct the set of LSDs we should stack, and save
        # the path to each.
        # NOTE: later entries in `daily_revisions` will override LSDs found in earlier
        # revisions.
        for rev in self.default_params["daily_revisions"]:

            daily_path = (
                self.root_path
                if self.default_params["daily_root"] is None
                else self.default_params["daily_root"]
            )
            daily_rev = daily.DailyProcessing(rev, root_path=daily_path)

            # Get all the bad days in this revision
            revision = df.DataRevision.get(name=rev)
            query = (
                df.DataFlagOpinion.select(df.DataFlagOpinion.lsd)
                .distinct()
                .where(
                    df.DataFlagOpinion.revision == revision,
                    df.DataFlagOpinion.decision == "bad",
                )
            )
            bad_days = [x[0] for x in query.tuples()]

            # Get all the good days
            query = (
                df.DataFlagOpinion.select(df.DataFlagOpinion.lsd)
                .distinct()
                .where(
                    df.DataFlagOpinion.revision == revision,
                    df.DataFlagOpinion.decision == "good",
                )
            )
            good_days = [x[0] for x in query.tuples()]

            for d in daily_rev.ls():

                # Filter out known bad days here
                if (int(d) in bad_days) or (int(d) not in good_days):
                    continue

                # Insert the day and path into the dict, this will replace the entries
                # from prior revisions
                path = daily_rev.base_path / d
                lsd = int(d)
                days[lsd] = path

        lsds = sorted(days)

        # Map each LSD into the quarter it belongs in and find which quarters we have
        # data for
        dates = ctime.unix_to_datetime(ephemeris.csd_to_unix(np.array(lsds)))
        yq = np.array([f"{d.year}q{(d.month - 1) // 3 + 1}" for d in dates])
        quarters = np.unique(yq)

        npart = self.default_params["partitions"]

        lsd_partitions = {}

        # For each quarter divide the LSDs it contains into a number of partitions to
        # give jack knifes
        for quarter in quarters:

            lsds_in_quarter = sorted(np.array(lsds)[yq == quarter])

            # Skip quarters with too few days in them
            if len(lsds_in_quarter) < self.default_params["min_days"] * npart:
                continue

            for i in range(npart):
                lsd_partitions[f"{quarter}p{i}"] = [
                    int(d) for d in lsds_in_quarter[i::npart]
                ]

        # Save the relevant parameters into the revisions configuration
        self.default_params["days"] = {
            int(day): str(path) for day, path in days.items()
        }
        self.default_params["stacks"] = lsd_partitions
Example No. 18
 def start_time(self):
     return csd_to_unix(self._lsd)
Example No. 19
    def parse_ant_logs(cls, logs, return_post_report_params=False):
        """
        Unzip and parse the .ANT log files output by nsched for John Galt
        Telescope observations.

        Parameters
        ----------
        logs : list of strings
            .ZIP filenames. Each .ZIP archive should include a .ANT file and
            a .POST_REPORT file. This method unzips the archive, uses
            `parse_post_report` to read the .POST_REPORT file and extract
            the CHIME sidereal day corresponding to the DRAO sidereal day,
            and then reads the lines in the .ANT file to obtain the pointing
            history of the Galt Telescope during this observation.

            (The DRAO sidereal day is days since the clock in Ev Sheehan's
            office at DRAO was reset. This clock is typically only reset every
            few years, but it does not correspond to any defined date, so the
            date must be figured out from the .POST_REPORT file, which reports
            both the DRAO sidereal day and the UTC date and time.

            Known reset dates: 2017-11-21, 2019-3-10)

        Returns
        -------

        if return_post_report_params is False:
            ant_data_list: a list of dictionaries (one per parsed log), each
                consisting of lists containing the LST, hour angle, RA, and
                dec (all as Skyfield Angle objects), the CHIME sidereal day,
                and the DRAO sidereal day.

        if return_post_report_params is True:
            post_report_list: list of dictionaries returned by
                `parse_post_report`
            and
            ant_data_list: described above

        Files
        -----
        The .ANT and .POST_REPORT files in the input .zip archive are
        extracted into /tmp/26mlog/<loginname>/.
        """

        from skyfield.positionlib import Angle
        from caput import time as ctime

        DRAO_lon = ephemeris.CHIMELONGITUDE * 24.0 / 360.0

        def sidlst_to_csd(sid, lst, sid_ref, t_ref):
            """
            Convert an integer DRAO sidereal day and LST to a float
            CHIME sidereal day

            Parameters
            ----------
            sid : int
                DRAO sidereal day
            lst : float, in hours
                local sidereal time
            sid_ref : int
                DRAO sidereal day at the reference time t_ref
            t_ref : skyfield time object, Julian days
                Reference time

            Returns
            -------
            output : float
                CHIME sidereal day
            """
            csd_ref = int(
                ephemeris.csd(ephemeris.datetime_to_unix(
                    t_ref.utc_datetime())))
            csd = sid - sid_ref + csd_ref
            return csd + lst / ephemeris.SIDEREAL_S / 24.0
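
        # Worked example (comment only; numbers hypothetical): with
        # csd_ref = 2500 and sid_ref = 700, a log entry with sid = 701 and
        # lst = 12 h maps to CHIME sidereal day
        #   (701 - 700) + 2500 + 12 / ephemeris.SIDEREAL_S / 24.0  ~  2501.5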

        ant_data_list = []
        post_report_list = []

        for log in logs:
            doobs = True

            filename = log.split("/")[-1]
            basedir = "/tmp/26mlog/{}/".format(os.getlogin())
            basename, extension = filename.split(".")
            post_report_file = basename + ".POST_REPORT"
            ant_file = basename + ".ANT"

            if extension == "zip":
                try:
                    zipfile.ZipFile(log).extract(post_report_file,
                                                 path=basedir)
                except Exception:
                    print(
                        "Failed to extract {} into {}. Moving right along...".
                        format(post_report_file, basedir))
                    doobs = False
                try:
                    zipfile.ZipFile(log).extract(ant_file, path=basedir)
                except Exception:
                    print(
                        "Failed to extract {} into {}. Moving right along...".
                        format(ant_file, basedir))
                    doobs = False

            if doobs:
                try:
                    post_report_params = cls.parse_post_report(
                        basedir + post_report_file)

                    with open(os.path.join(basedir, ant_file), "r") as f:
                        lines = [line for line in f]
                        ant_data = {"sid": np.array([])}
                        lsth = []
                        lstm = []
                        lsts = []

                        hah = []
                        ham = []
                        has = []

                        decd = []
                        decm = []
                        decs = []

                        for l in lines:
                            arr = l.split()

                            try:
                                lst_hms = [float(x) for x in arr[2].split(":")]

                                # do last element first: if this is going to
                                # crash because a line in the log is incomplete,
                                # we don't want it to append to any of the lists

                                decs.append(float(arr[8].replace('"', "")))
                                decm.append(float(arr[7].replace("'", "")))
                                decd.append(float(arr[6].replace("D", "")))

                                has.append(float(arr[5].replace("S", "")))
                                ham.append(float(arr[4].replace("M", "")))
                                hah.append(float(arr[3].replace("H", "")))

                                lsts.append(float(lst_hms[2]))
                                lstm.append(float(lst_hms[1]))
                                lsth.append(float(lst_hms[0]))

                                ant_data["sid"] = np.append(
                                    ant_data["sid"], int(arr[1]))
                            except Exception:
                                print("Failed in file {} for line \n{}".format(
                                    ant_file, l))
                                if len(ant_data["sid"]) != len(decs):
                                    print("WARNING: mismatch in list lengths.")

                        ant_data["lst"] = Angle(hours=(lsth, lstm, lsts))

                        ha = Angle(hours=(hah, ham, has))
                        dec = Angle(degrees=(decd, decm, decs))

                        ant_data["ha"] = Angle(
                            radians=ha.radians -
                            ephemeris.galt_pointing_model_ha(ha, dec).radians,
                            preference="hours",
                        )

                        ant_data["dec_cirs"] = Angle(
                            radians=dec.radians -
                            ephemeris.galt_pointing_model_dec(ha, dec).radians,
                            preference="degrees",
                        )

                        ant_data["csd"] = sidlst_to_csd(
                            np.array(ant_data["sid"]),
                            ant_data["lst"].hours,
                            post_report_params["SID"],
                            post_report_params["start_time"],
                        )

                    ant_data["t"] = ephemeris.unix_to_skyfield_time(
                        ephemeris.csd_to_unix(ant_data["csd"]))

                    # Correct RA from equinox to CIRS coords (both in radians)
                    era = np.radians(
                        ctime.unix_to_era(ephemeris.ensure_unix(
                            ant_data["t"])))
                    gast = ant_data["t"].gast * 2 * np.pi / 24.0

                    ant_data["ra_cirs"] = Angle(
                        radians=ant_data["lst"].radians -
                        ant_data["ha"].radians + (era - gast),
                        preference="hours",
                    )

                    obs = ephemeris.Star_cirs(
                        ra=ant_data["ra_cirs"],
                        dec=ant_data["dec_cirs"],
                        epoch=ant_data["t"],
                    )

                    ant_data["ra"] = obs.ra
                    ant_data["dec"] = obs.dec

                    ant_data_list.append(ant_data)
                    post_report_list.append(post_report_params)
                except Exception as e:
                    print("Parsing {} failed with error: {}".format(
                        post_report_file, e))

        if return_post_report_params:
            return post_report_list, ant_data_list
        return ant_data_list
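
The per-line parsing in the loop above is easier to follow in isolation. A minimal sketch of the same field layout, assuming each pointing-log line reads `<tag> <sid> <h:m:s LST> <H>H <M>M <S>S <D>D <M>' <S>"`; the helper name is illustrative, not from the original:

def parse_pointing_line(line):
    # Hypothetical standalone version of the try-block above: split one log
    # line into sidereal day, LST, hour angle and declination components.
    arr = line.split()
    sid = int(arr[1])
    lst = tuple(float(x) for x in arr[2].split(":"))  # (hours, minutes, seconds)
    ha = (
        float(arr[3].replace("H", "")),
        float(arr[4].replace("M", "")),
        float(arr[5].replace("S", "")),
    )  # (hours, minutes, seconds)
    dec = (
        float(arr[6].replace("D", "")),
        float(arr[7].replace("'", "")),
        float(arr[8].replace('"', "")),
    )  # (degrees, arcminutes, arcseconds)
    return sid, lst, ha, dec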
Exemplo n.º 20
0
def main(config_file=None, logging_params=DEFAULT_LOGGING):

    # Setup logging
    log.setup_logging(logging_params)
    mlog = log.get_logger(__name__)

    # Set config
    config = DEFAULTS.deepcopy()
    if config_file is not None:
        config.merge(NameSpace(load_yaml_config(config_file)))

    # Select candidate calibration sources
    source_list = config.source_list or FluxCatalog.sort()

    cal_list = [
        name for name, obj in FluxCatalog.iteritems()
        if (obj.dec >= config.min_dec) and (
            obj.predict_flux(config.freq_nominal) >= config.min_flux) and (
                name in source_list)
    ]

    if not cal_list:
        raise RuntimeError("No calibrators found.")

    # Sort the list by flux at the nominal frequency (ascending, so the
    # brightest calibrators come last)
    cal_list.sort(
        key=lambda name: FluxCatalog[name].predict_flux(config.freq_nominal))

    # Add to transit tracker
    transit_tracker = containers.TransitTrackerOffline(
        nsigma=config.nsigma_source, extend_night=config.extend_night)
    for name in cal_list:
        transit_tracker[name] = FluxCatalog[name].skyfield

    mlog.info("Initializing offline point source processing.")

    search_time = config.start_time or 0
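    # With no configured start time this falls back to 0 (the Unix epoch),
    # so the modification-time cut below keeps every file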

    # Find all calibration files
    all_files = sorted(
        glob.glob(
            os.path.join(config.acq_dir,
                         '*' + config.correlator + config.acq_suffix, '*.h5')))
    if not all_files:
        return

    # Remove files whose last modified time is before the time of the most recent update
    all_files = [
        ff for ff in all_files if (os.path.getmtime(ff) > search_time)
    ]
    if not all_files:
        return

    # Remove files that are currently locked
    all_files = [
        ff for ff in all_files
        if not os.path.isfile(os.path.splitext(ff)[0] + '.lock')
    ]
    if not all_files:
        return

    # Add files to transit tracker
    for ff in all_files:
        transit_tracker.add_file(ff)

    # Extract point source transits ready for analysis
    all_transits = transit_tracker.get_transits()

    # Create dictionary to hold results
    h5_psrc_fit = {}
    inputmap = None

    # Loop over transits
    for transit in all_transits:

        src, csd, is_day, files, start, stop = transit

        # Discard any point sources with unusual csd value
        if (csd < config.min_csd) or (csd > config.max_csd):
            continue

        # Discard any point sources transiting during the day
        if is_day > config.process_daytime:
            continue

        mlog.info(
            'Processing %s transit on CSD %d (%d files, %d time samples)' %
            (src, csd, len(files), stop - start + 1))

        # Load inputmap
        if inputmap is None:
            if config.inputmap is None:
                inputmap = tools.get_correlator_inputs(
                    ephemeris.unix_to_datetime(ephemeris.csd_to_unix(csd)),
                    correlator=config.correlator)
            else:
                # The input map is pickled, so open it in binary mode
                with open(config.inputmap, 'rb') as handler:
                    inputmap = pickle.load(handler)

        # Grab the timing correction for this transit
        tcorr = None
        if config.apply_timing:

            if config.timing_glob is not None:

                mlog.info(
                    "Loading timing correction from extended timing solutions."
                )

                timing_files = sorted(glob.glob(config.timing_glob))

                if timing_files:

                    try:
                        tcorr = search_extended_timing_solutions(
                            timing_files, ephemeris.csd_to_unix(csd))

                    except Exception as e:
                        mlog.error(
                            'search_extended_timing_solutions failed with error: %s'
                            % e)

                    else:
                        mlog.info(str(tcorr))

            if tcorr is None:

                mlog.info(
                    "Loading timing correction from chimetiming acquisitions.")

                try:
                    tcorr = timing.load_timing_correction(
                        files,
                        start=start,
                        stop=stop,
                        window=config.timing_window,
                        instrument=config.correlator)
                except Exception as e:
                    mlog.error(
                        'timing.load_timing_correction failed with error: %s' %
                        e)
                    mlog.warning(
                        'No timing correction applied to %s transit on CSD %d.'
                        % (src, csd))
                else:
                    mlog.info(str(tcorr))

        # Call the main routine to process data
        try:
            outdct = offline_cal.offline_point_source_calibration(
                files,
                src,
                start=start,
                stop=stop,
                inputmap=inputmap,
                tcorr=tcorr,
                logging_params=logging_params,
                **config.analysis.as_dict())

        except Exception as e:
            msg = 'offline_cal.offline_point_source_calibration failed with error: %s' % e
            mlog.error(msg)
            continue
            #raise RuntimeError(msg)

        # Find existing gain files for this particular point source
        if src not in h5_psrc_fit:

            output_files = find_files(config, psrc=src)
            if output_files is not None:
                output_files = output_files[-1]
                mlog.info('Writing %s transit on CSD %d to existing file %s.' %
                          (src, csd, output_files))

            h5_psrc_fit[src] = containers.PointSourceWriter(
                src,
                output_file=output_files,
                output_dir=config.output_dir,
                output_suffix=point_source_name_to_file_suffix(src),
                instrument=config.correlator,
                max_file_size=config.max_file_size,
                max_num=config.max_num_time,
                memory_size=0)

        # Associate this gain calibration to the transit time
        this_time = ephemeris.transit_times(FluxCatalog[src].skyfield,
                                            ephemeris.csd_to_unix(csd))[0]

        outdct['csd'] = csd
        outdct['is_daytime'] = is_day
        outdct['acquisition'] = os.path.basename(os.path.dirname(files[0]))

        # Write to output file
        mlog.info('Writing to disk results from %s transit on CSD %d.' %
                  (src, csd))
        h5_psrc_fit[src].write(this_time, **outdct)

        # Dump an individual file for this point source transit
        mlog.info('Dumping to disk single file for %s transit on CSD %d.' %
                  (src, csd))
        dump_dir = os.path.join(config.output_dir, 'point_source_gains')
        containers.mkdir(dump_dir)

        dump_file = os.path.join(dump_dir, '%s_csd_%d.h5' % (src.lower(), csd))
        h5_psrc_fit[src].dump(dump_file,
                              datasets=[
                                  'csd', 'acquisition', 'is_daytime', 'gain',
                                  'weight', 'timing', 'model'
                              ])

        mlog.info('Finished analysis of %s transit on CSD %d.' % (src, csd))
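
The three file filters at the top of main() amount to one selection step: match a glob, drop stale files, and skip locked acquisitions. A minimal standalone sketch of that logic; select_files is an illustrative name, not part of the original:

import glob
import os


def select_files(pattern, search_time):
    # Sorted files matching pattern, modified after search_time, with no
    # companion ".lock" file (i.e. not currently being written).
    files = sorted(glob.glob(pattern))
    files = [f for f in files if os.path.getmtime(f) > search_time]
    return [
        f for f in files
        if not os.path.isfile(os.path.splitext(f)[0] + ".lock")
    ]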
Exemplo n.º 21
0
    def process(self, sstream, inputmap, inputmask):
        """Determine calibration from a timestream.

        Parameters
        ----------
        sstream : andata.CorrData or containers.SiderealStream
            Timestream collected during the day.
        inputmap : list of :class:`CorrInput`
            A list describing the inputs as they are in the file.
        inputmask : containers.CorrInputMask
            Mask indicating which correlator inputs to use in the
            eigenvalue decomposition.

        Returns
        -------
        suntrans : containers.SunTransit
            Response to the sun.
        """

        from operator import itemgetter
        from itertools import groupby
        from .calibration import _extract_diagonal, solve_gain

        # Ensure that we are distributed over frequency
        sstream.redistribute("freq")

        # Find the local frequencies
        nfreq = sstream.vis.local_shape[0]
        sfreq = sstream.vis.local_offset[0]
        efreq = sfreq + nfreq

        # Get the local frequency axis
        freq = sstream.freq["centre"][sfreq:efreq]
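        # Wavelength in metres: freq is in MHz and c is roughly 3e2 m MHz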
        wv = 3e2 / freq

        # Get times
        if hasattr(sstream, "time"):
            time = sstream.time
            ra = ephemeris.transit_RA(time)
        else:
            ra = sstream.index_map["ra"][:]
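            # Reconstruct per-sample times from the integer day number plus
            # the RA expressed as a fraction of a sidereal day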
            csd = (sstream.attrs["lsd"]
                   if "lsd" in sstream.attrs else sstream.attrs["csd"])
            csd = csd + ra / 360.0
            time = ephemeris.csd_to_unix(csd)

        # Only examine data between sunrise and sunset
        time_flag = np.zeros(len(time), dtype=bool)
        rise = ephemeris.solar_rising(time[0] - 24.0 * 3600.0,
                                      end_time=time[-1])
        for rr in rise:
            ss = ephemeris.solar_setting(rr)[0]
            time_flag |= (time >= rr) & (time <= ss)

        if not np.any(time_flag):
            self.log.debug(
                "No daytime data between %s and %s.",
                ephemeris.unix_to_datetime(time[0]).strftime("%b %d %H:%M"),
                ephemeris.unix_to_datetime(time[-1]).strftime("%b %d %H:%M"),
            )
            return None

        # Convert boolean flag to slices
        time_index = np.where(time_flag)[0]

        time_slice = []
        ntime = 0
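        # Group consecutive indices into runs: (enumeration position - index)
        # is constant within each unbroken run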
        for key, group in groupby(
                enumerate(time_index),
                lambda index_item: index_item[0] - index_item[1]):
            group = list(map(itemgetter(1), group))
            ngroup = len(group)
            time_slice.append(
                (slice(group[0], group[-1] + 1), slice(ntime, ntime + ngroup)))
            ntime += ngroup

        time = np.concatenate([time[slc[0]] for slc in time_slice])
        ra = np.concatenate([ra[slc[0]] for slc in time_slice])

        # Get ra, dec, alt of sun
        sun_pos = np.array([
            ra_dec_of(ephemeris.skyfield_wrapper.ephemeris["sun"], t)
            for t in time
        ])

        # Convert from ra to hour angle
        sun_pos[:, 0] = np.radians(ra) - sun_pos[:, 0]

        # Determine good inputs
        nfeed = len(inputmap)
        good_input = np.arange(
            nfeed, dtype=int)[inputmask.datasets["input_mask"][:]]

        # Use input map to figure out which are the X and Y feeds
        xfeeds = np.array([
            idx for idx, inp in enumerate(inputmap)
            if tools.is_chime_x(inp) and (idx in good_input)
        ])
        yfeeds = np.array([
            idx for idx, inp in enumerate(inputmap)
            if tools.is_chime_y(inp) and (idx in good_input)
        ])

        self.log.debug(
            "Performing sun calibration with %d/%d good feeds (%d xpol, %d ypol).",
            len(good_input),
            nfeed,
            len(xfeeds),
            len(yfeeds),
        )

        # Construct baseline vector for each visibility
        feed_pos = tools.get_feed_positions(inputmap)
        vis_pos = np.array([
            feed_pos[ii] - feed_pos[ij]
            for ii, ij in sstream.index_map["prod"][:]
        ])
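        # Feeds without a known position give NaN baselines; zero them so
        # they do not propagate through the fringestop phase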
        vis_pos = np.where(np.isnan(vis_pos), np.zeros_like(vis_pos), vis_pos)

        u = (vis_pos[np.newaxis, :, 0] / wv[:, np.newaxis])[:, :, np.newaxis]
        v = (vis_pos[np.newaxis, :, 1] / wv[:, np.newaxis])[:, :, np.newaxis]

        # Create container to hold results of fit
        suntrans = containers.SunTransit(time=time,
                                         pol_x=xfeeds,
                                         pol_y=yfeeds,
                                         axes_from=sstream)
        for key in suntrans.datasets.keys():
            suntrans.datasets[key][:] = 0.0

        # Set coordinates
        suntrans.coord[:] = sun_pos

        # Loop over time slices
        for slc_in, slc_out in time_slice:

            # Extract visibility slice
            vis_slice = sstream.vis[..., slc_in].copy()

            ha = (sun_pos[slc_out, 0])[np.newaxis, np.newaxis, :]
            dec = (sun_pos[slc_out, 1])[np.newaxis, np.newaxis, :]

            # Extract the diagonal (to be used for weighting)
            norm = (_extract_diagonal(vis_slice, axis=1).real)**0.5
            norm = tools.invert_no_zero(norm)

            # Fringestop
            if self.fringestop:
                vis_slice *= tools.fringestop_phase(
                    ha, np.radians(ephemeris.CHIMELATITUDE), dec, u, v)

            # Solve for the point source response of each set of polarisations
            ev_x, resp_x, err_resp_x = solve_gain(vis_slice,
                                                  feeds=xfeeds,
                                                  norm=norm[:, xfeeds])
            ev_y, resp_y, err_resp_y = solve_gain(vis_slice,
                                                  feeds=yfeeds,
                                                  norm=norm[:, yfeeds])

            # Save to container
            suntrans.evalue_x[..., slc_out] = ev_x
            suntrans.evalue_y[..., slc_out] = ev_y

            suntrans.response[:, xfeeds, slc_out] = resp_x
            suntrans.response[:, yfeeds, slc_out] = resp_y

            suntrans.response_error[:, xfeeds, slc_out] = err_resp_x
            suntrans.response_error[:, yfeeds, slc_out] = err_resp_y

        # If requested, fit a model to the primary beam of the sun transit
        if self.model_fit:

            # Estimate peak RA
            i_transit = np.argmin(np.abs(sun_pos[:, 0]))

            body = ephemeris.skyfield_wrapper.ephemeris["sun"]
            obs = ephemeris._get_chime()
            obs.date = ephemeris.unix_to_ephem_time(time[i_transit])
            body.compute(obs)

            peak_ra = ephemeris.peak_RA(body)
            dra = ra - peak_ra
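            # Wrap offsets beyond pi back around so dra is the absolute
            # angular distance from transit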
            dra = np.abs(dra - (dra > np.pi) * 2.0 * np.pi)[np.newaxis,
                                                            np.newaxis, :]

            # Estimate FWHM
            sig_x = cal_utils.guess_fwhm(freq,
                                         pol="X",
                                         dec=body.dec,
                                         sigma=True)[:, np.newaxis, np.newaxis]
            sig_y = cal_utils.guess_fwhm(freq,
                                         pol="Y",
                                         dec=body.dec,
                                         sigma=True)[:, np.newaxis, np.newaxis]

            # Only fit ra values above the specified dynamic range threshold
            fit_flag = np.zeros([nfreq, nfeed, ntime], dtype=bool)
            fit_flag[:, xfeeds, :] = dra < (self.nsig * sig_x)
            fit_flag[:, yfeeds, :] = dra < (self.nsig * sig_y)

            # Fit model for the complex response of each feed to the point source
            param, param_cov = cal_utils.fit_point_source_transit(
                ra,
                suntrans.response[:],
                suntrans.response_error[:],
                flag=fit_flag)

            # Save to container
            suntrans.add_dataset("flag")
            suntrans.flag[:] = fit_flag

            suntrans.add_dataset("parameter")
            suntrans.parameter[:] = param

            suntrans.add_dataset("parameter_cov")
            suntrans.parameter_cov[:] = param_cov

        # Update attributes
        units = "sqrt(" + sstream.vis.attrs.get("units",
                                                "correlator-units") + ")"
        suntrans.response.attrs["units"] = units
        suntrans.response_error.attrs["units"] = units

        suntrans.attrs["source"] = "Sun"

        # Return sun transit
        return suntrans
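
The sunrise/sunset masking near the top of process() is self-contained enough to lift out. A minimal sketch, assuming ephemeris is ch_util.ephemeris and time is an array of Unix timestamps; daytime_mask is an illustrative name:

import numpy as np

from ch_util import ephemeris


def daytime_mask(time):
    # Flag every sample that falls between a sunrise and the following sunset.
    # Start the search a day early so a rise before time[0] is not missed.
    flag = np.zeros(len(time), dtype=bool)
    for rr in ephemeris.solar_rising(time[0] - 24.0 * 3600.0, end_time=time[-1]):
        ss = ephemeris.solar_setting(rr)[0]
        flag |= (time >= rr) & (time <= ss)
    return flag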