Example #1
    def _finalize_transit(self):
        """Concatenate grouped time streams for the currrent transit."""

        # Find where transit starts and ends
        if len(self.tstreams) == 0 or self.cur_transit is None:
            self.log.info("Did not find any transits.")
            return None
        self.log.debug(
            "Finalising transit for {}...".format(
                ephem.unix_to_datetime(self.cur_transit)
            )
        )
        all_t = np.concatenate([ts.time for ts in self.tstreams])
        start_ind = int(np.argmin(np.abs(all_t - self.start_t)))
        stop_ind = int(np.argmin(np.abs(all_t - self.end_t)))

        # Save list of filenames
        filenames = [ts.attrs["filename"] for ts in self.tstreams]

        dt = self.tstreams[0].time[1] - self.tstreams[0].time[0]
        if dt <= 0:
            self.log.warning(
                "Time steps are not positive definite: dt={:.3f}".format(dt)
                + " Skipping."
            )
            ts = None
        elif stop_ind - start_ind > int(self.min_span / 360.0 * SIDEREAL_DAY_SEC / dt):
            if len(self.tstreams) > 1:
                # Concatenate timestreams
                ts = tod.concatenate(self.tstreams, start=start_ind, stop=stop_ind)
            else:
                ts = self.tstreams[0]
            _, dec = ephem.object_coords(
                self.src, all_t[0], deg=True, obs=self.observer
            )
            ts.attrs["dec"] = dec
            ts.attrs["source_name"] = self.source
            ts.attrs["transit_time"] = self.cur_transit
            ts.attrs["observation_id"] = self.obs_id
            ts.attrs["tag"] = "{}_{:0>4d}_{}".format(
                self.source,
                self.obs_id,
                ephem.unix_to_datetime(self.cur_transit).strftime("%Y%m%dT%H%M%S"),
            )
            ts.attrs["archivefiles"] = filenames
        else:
            self.log.info("Transit too short. Skipping.")
            ts = None

        self.tstreams = []
        self.cur_transit = None

        return ts
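
The span check above converts the minimum hour-angle span in degrees into a minimum number of time samples. A minimal standalone sketch of that conversion, assuming a hypothetical 5 degree span and a 10 s cadence:

SIDEREAL_DAY_SEC = 86164.0905  # seconds per sidereal day

def min_samples(min_span_deg, dt_sec):
    # Fraction of a sidereal day covered by the span, divided by the sample cadence.
    return int(min_span_deg / 360.0 * SIDEREAL_DAY_SEC / dt_sec)

print(min_samples(5.0, 10.0))  # -> 119 samples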
Example #2
def _cut_non_chime(data, visi, chan_array, inputs=None):
    """
    Remove non CHIME channels (noise injection, RFI antenna,
    26m, etc...) from visibility. Also remove channels marked
    as powered-off in layout DB.
    """

    # Map of channels to corr. inputs:
    input_map = data.input
    tmstp = data.index_map["time"]["ctime"]  # time stamp
    # Datetime halfway through data:
    half_time = ch_eph.unix_to_datetime(tmstp[int(len(tmstp) // 2)])
    # Get information on correlator inputs, if not already supplied
    if inputs is None:
        inputs = tools.get_correlator_inputs(half_time)
    # Reorder inputs to have the same order as the input map (and data)
    inputs = tools.reorder_correlator_inputs(input_map, inputs)
    # Get noise source channel index:

    # Test if inputs are attached to CHIME antenna and powered on:
    pwds = tools.is_chime_on(inputs)

    for ii in range(len(inputs)):
        #        if ( (not tools.is_chime(inputs[ii]))
        if (not pwds[ii]) and (ii in chan_array):
            # Remove non-CHIME-on channels from visibility matrix...
            idx = np.where(chan_array == ii)[0][0]  # index of channel
            visi = np.delete(visi, idx, axis=0)
            # ...and from product array:
            chan_array = np.delete(chan_array, idx, axis=0)

    return visi, chan_array
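
A sketch of how _cut_non_chime might be called; data and visi are assumed to already exist (an andata.CorrData object and a visibility array with one row per channel, in input-map order):

import numpy as np

# chan_array lists the channel indices present in visi, here all of them
chan_array = np.arange(len(data.input))
visi_chime, chans_chime = _cut_non_chime(data, visi, chan_array)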
Example #3
def fs_from_file(filename, frq, src,
                 del_t=900, transposed=True, subtract_avg=False):

    f = h5py.File(filename, 'r')

    times = f['index_map']['time'].value['ctime'] + 10.6

    src_trans = eph.transit_times(src, times[0])

    # try to account for differential arrival time from cylinder rotation. 

    del_phi = (src._dec - np.radians(eph.CHIMELATITUDE)) * np.sin(np.radians(1.988))
    del_phi *= (24 * 3600.0) / (2 * np.pi)

    # Adjust the transit time accordingly                                                                                   
    src_trans += del_phi

    # Select +- del_t of transit, accounting for the mispointing      
    t_range = np.where((times < src_trans + del_t) & (times > src_trans - del_t))[0]

    times = times[t_range[0]:t_range[-1]]

    print("Time range:", times[0], times[-1])

    print("\n...... This data is from %s starting at RA: %f ...... \n"
          % (eph.unix_to_datetime(times[0]), eph.transit_RA(times[0])))


    if transposed is True:
        v = f['vis'][frq[0]:frq[-1]+1, :]
        v = v[..., t_range[0]:t_range[-1]]
        vis = v['r'] + 1j * v['i']

        del v

    # Read in time and freq slice if data has not yet been transposed
    if transposed is False:
         v = f['vis'][t_range[0]:t_range[-1], frq[0]:frq[-1]+1, :]
         vis = v['r'][:] + 1j * v['i'][:]
         del v
         vis = np.transpose(vis, (1, 2, 0))

    inp = gen_inp()[0]

    # Remove offset from galaxy                                                                                
    if subtract_avg is True:
        vis -= 0.5 * (vis[..., 0] + vis[..., -1])[..., np.newaxis]

    freq_MHZ = 800.0 - np.array(frq) / 1024.0 * 400.
    print(len(inp))

    baddies = np.where(np.isnan(tools.get_feed_positions(inp)[:, 0]))[0]

    # Fringestop to location of "src"

    data_fs = tools.fringestop_pathfinder(vis, eph.transit_RA(times), freq_MHZ, inp, src)
#    data_fs = fringestop_pathfinder(vis, eph.transit_RA(times), freq_MHZ, inp, src)


    return data_fs
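
A minimal usage sketch, assuming eph is ch_util.ephemeris as in the other examples; the archive path and frequency-bin range are placeholders:

from ch_util import ephemeris as eph

frq = range(305, 310)  # frequency bins to read and fringestop
data_fs = fs_from_file('/path/to/transit_file.h5', frq, eph.CasA,
                       del_t=900, transposed=True, subtract_avg=False)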
Example #4
def print_timestamp(fn):
    src_dict = {'CasA': eph.CasA, 'TauA': eph.TauA, 'CygA': eph.CygA, 'VirA': eph.VirA,
            '1929': 292.0, '0329': 54.0}

    try:
        f = h5py.File(fn, 'r')
    except IOError:
        print "File couldn't be opened"
        return 

    times = f['index_map']['time'].value['ctime'][:]

    print "-------------------------------------"
    print "start time %s in PT\n" % (eph.unix_to_datetime(times[0]))

    print "RA range %f : %f" % (np.round(eph.transit_RA(times[0]), 1), 
                                np.round(eph.transit_RA(times[-1]),1))
    print "-------------------------------------"

    srcz = []
    for src in src_dict:
        if eph.transit_times(src_dict[src], times[0])[0] < times[-1]:
            print "%s is in this file" % src
            srcz.append(src)

    return srcz
Example #5
def search_extended_timing_solutions(timing_files, timestamp):

    # Load the timing correction
    nfiles = len(timing_files)
    tstart = np.zeros(nfiles, dtype=np.float32)
    tstop = np.zeros(nfiles, dtype=np.float32)
    all_tcorr = []

    for ff, filename in enumerate(timing_files):

        kwargs = {}
        with h5py.File(filename, 'r') as handler:

            for key in ['tau', 'avg_phase', 'noise_source', 'time']:

                kwargs[key] = handler[key][:]

        tcorr = timing.TimingCorrection(**kwargs)

        all_tcorr.append(tcorr)
        tstart[ff] = tcorr.time[0]
        tstop[ff] = tcorr.time[-1]

    # Map timestamp to a timing correction object
    imatch = np.flatnonzero((timestamp >= tstart) & (timestamp <= tstop))

    if imatch.size > 1:
        raise ValueError("Timing corrections overlap!")
    elif imatch.size < 1:
        raise ValueError(
            "No timing correction for transit on %s (CSD %d)" %
            (ephemeris.unix_to_datetime(timestamp).strftime("%Y-%m-%d"),
             ephemeris.unix_to_csd(timestamp)))

    return all_tcorr[imatch[0]]
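
A usage sketch with placeholder paths; the timing files are assumed to be HDF5 files containing the 'tau', 'avg_phase', 'noise_source' and 'time' datasets read above:

import glob

timing_files = sorted(glob.glob('/path/to/timing_corrections/*_timing.h5'))
transit_timestamp = 1550000000.0  # Unix time of the transit of interest (placeholder)
tcorr = search_extended_timing_solutions(timing_files, transit_timestamp)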
Example #6
    def _transit_bounds(self):
        """Find the start and end times of this transit.

        Compares the desired HA span to the start and end times of the observation
        recorded in the database. Also gets the observation ID."""

        # pad around the transit by half the HA span, converted from degrees to seconds
        self.start_t = self.cur_transit - self.ha_span / 360.0 / 2.0 * SIDEREAL_DAY_SEC
        self.end_t = self.cur_transit + self.ha_span / 360.0 / 2.0 * SIDEREAL_DAY_SEC

        # get bounds of observation from database
        this_run = [
            r
            for r in self.db_runs
            if r[1][0] < self.cur_transit and r[1][1] > self.cur_transit
        ]
        if len(this_run) == 0:
            self.log.warning(
                "Could not find source transit in holography database for {}.".format(
                    ephem.unix_to_datetime(self.cur_transit)
                )
            )
            # skip this file
            self.cur_transit = None
        else:
            self.start_t = max(self.start_t, this_run[0][1][0])
            self.end_t = min(self.end_t, this_run[0][1][1])
            self.obs_id = this_run[0][0]
Example #7
def print_timestamp(fn):
    src_dict = {
        'CasA': eph.CasA,
        'TauA': eph.TauA,
        'CygA': eph.CygA,
        'VirA': eph.VirA,
        '1929': 292.0,
        '0329': 54.0
    }

    try:
        f = h5py.File(fn, 'r')
    except IOError:
        print "File couldn't be opened"
        return

    times = f['index_map']['time'].value['ctime'][:]

    print "-------------------------------------"
    print "start time %s in PT\n" % (eph.unix_to_datetime(times[0]))

    print "RA range %f : %f" % (np.round(eph.transit_RA(
        times[0]), 1), np.round(eph.transit_RA(times[-1]), 1))
    print "-------------------------------------"

    srcz = []
    for src in src_dict:
        if eph.transit_times(src_dict[src], times[0])[0] < times[-1]:
            print "%s is in this file" % src
            srcz.append(src)

    return srcz
Example #8
    def next(self, intervals):
        """Filter input files to exclude those already processed.

        Parameters
        ----------
        intervals: list of chimedb.data_index.DataInterval
            List of intervals to filter.

        Returns
        -------
        files: list of str
            List of files to be processed.
        """

        self.log.info("Starting next for task %s" % self.__class__.__name__)

        self.comm.Barrier()

        files = []
        for fi in intervals:
            start, end = fi[1]
            # find holography observation that overlaps this set
            this_obs = [
                o
                for o in self.hol_obs
                if (o.start_time >= start and o.start_time <= end)
                or (o.finish_time >= start and o.finish_time <= end)
                or (o.start_time <= start and o.finish_time >= end)
            ]

            if len(this_obs) == 0:
                self.log.warning(
                    "Could not find source transit in holography database for {}.".format(
                        ephem.unix_to_datetime(start)
                    )
                )
            elif this_obs[0].id in self.proc_transits:
                self.log.warning(
                    "Already processed transit for {}. Skipping.".format(
                        ephem.unix_to_datetime(start)
                    )
                )
            else:
                files += fi[0]

        self.log.info("Leaving next for task %s" % self.__class__.__name__)
        return files
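
The three-clause test above is the usual interval-overlap check; an equivalent, more compact form, sketched with a hypothetical stand-in for the holography observation records:

from collections import namedtuple

Obs = namedtuple("Obs", ["start_time", "finish_time"])  # stand-in for a holography observation

def overlaps(o, start, end):
    # Two intervals overlap iff each one starts before the other ends.
    return o.start_time <= end and o.finish_time >= start

print(overlaps(Obs(10.0, 20.0), 15.0, 30.0))  # True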
Example #9
    def set_acq_list(self):
        """This method sets four attributes.  The first two attributes
        are 'night_finder' and 'night_acq_list', which are the
        finder object and list of acquisitions that
        contain all night time data between self.t1 and self.t2.
        The second two attributes are 'finder' and 'acq_list',
        which are the finder object and list of acquisitions
        that contain all data beween self.t1 and self.t2 with the
        sunrise, sun transit, and sunset removed.
        """

        # Create a Finder object and focus on time range
        f = finder.Finder(node_spoof=_DEFAULT_NODE_SPOOF)
        f.filter_acqs((data_index.ArchiveInst.name == "pathfinder"))
        f.only_corr()
        f.set_time_range(self.t1, self.t2)

        # Create a list of acquisitions that only contain data collected at night
        f_night = copy.deepcopy(f)
        f_night.exclude_daytime()

        self.night_finder = f_night
        self.night_acq_list = f_night.get_results()

        # Create a list of acquisitions that flag out sunrise, sun transit, and sunset
        mm = ephemeris.unix_to_datetime(self.t1).month
        dd = ephemeris.unix_to_datetime(self.t1).day
        mm = mm + float(dd) / 30.0

        fct = 3.0
        tol1 = (np.arctan(
            (mm - 3.0) * fct) + np.pi / 2.0) * 10500.0 / np.pi + 1500.0
        tol2 = (np.pi / 2.0 - np.arctan(
            (mm - 11.0) * fct)) * 10500.0 / np.pi + 1500.0
        ttol = np.minimum(tol1, tol2)

        fct = 5.0
        tol1 = (np.arctan(
            (mm - 4.0) * fct) + np.pi / 2.0) * 2100.0 / np.pi + 6000.0
        tol2 = (np.pi / 2.0 - np.arctan(
            (mm - 10.0) * fct)) * 2100.0 / np.pi + 6000.0
        rstol = np.minimum(tol1, tol2)

        f.exclude_sun(time_delta=ttol, time_delta_rise_set=rstol)

        self.finder = f
        self.acq_list = f.get_results()
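
The arctan expressions above smoothly ramp the sun-exclusion tolerances from a small winter value to a large summer value as a function of the fractional month. A standalone sketch of the transit tolerance computed above, using the same constants:

import numpy as np

def transit_tolerance(month):
    # Ramps from roughly 2000 s near winter to roughly 12000 s near summer.
    fct = 3.0
    tol1 = (np.arctan((month - 3.0) * fct) + np.pi / 2.0) * 10500.0 / np.pi + 1500.0
    tol2 = (np.pi / 2.0 - np.arctan((month - 11.0) * fct)) * 10500.0 / np.pi + 1500.0
    return np.minimum(tol1, tol2)

print(transit_tolerance(1.0), transit_tolerance(7.0))  # small in January, large in July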
Example #10
def main():
    assert len(sys.argv) == 2, "src"

    basedir = '/mnt/gong/archive/'

    if run_ph_solver is True:
        while True:
            dir_recent = reorder_dir(basedir)[0]
  
            src_search = sys.argv[1]
            print(dir_recent)

            # Search the most recent directory for the source "src_search"
            trans_file = find_trans(basedir + dir_recent, src_search)

            if trans_file is None:
                print "Trans file doesn't exist in this directory, sleeping for 1hr"
                time.sleep(3600)
                continue

            tstring = make_outfile_name(trans_file)

            outfile = './solutions/' + tstring + src_search + '2.hdf5'
            
            # might be useful to just make sure gains are swapped
#            os.system('python solver_script.py %s %s -solve_gains 0 -gen_pkls 1' \
#                          % (trans_file, src_search))

            if os.path.exists(outfile):
                print "%s already exists, taking a nap. Will check again in 1 hour." % outfile
                time.sleep(3600)
                continue

            print "feed this", trans_file

            snm = time.time()

            os.system('python solver_script.py %s %s -solve_gains 1 -gen_pkls 1' \
                          % (trans_file, src_search))
            print "Going to sleep for a whole damn day"
            print eph.unix_to_datetime(time.time() - 8 * 3600)

            time.sleep(12 * 3600)
Example #12
    def process_finish(self):
        """Normalise the stack and return the result.

        Includes the sample variance over transits within the stack.

        Returns
        -------
        stack: draco.core.containers.TrackBeam
            Stacked transits.
        """
        # Divide by norm to get average transit
        inv_norm = invert_no_zero(self.norm)
        self.stack.beam[:] *= inv_norm
        self.stack.weight[:] = invert_no_zero(self.stack.weight[:]) * self.norm**2

        self.variance = self.variance * inv_norm - np.abs(self.stack.beam[:]) ** 2
        self.pseudo_variance = self.pseudo_variance * inv_norm - self.stack.beam[:] ** 2

        # Calculate the variance of the real and imaginary components and
        # their covariance from the accumulated variance and pseudo-variance
        self.stack.sample_variance[0] = 0.5 * (
            self.variance + self.pseudo_variance.real
        )
        self.stack.sample_variance[1] = 0.5 * self.pseudo_variance.imag
        self.stack.sample_variance[2] = 0.5 * (
            self.variance - self.pseudo_variance.real
        )

        # Create tag
        time_range = np.percentile(self.stack.attrs["transit_time"], [0, 100])
        self.stack.attrs["tag"] = "{}_{}_to_{}".format(
            self.stack.attrs["source_name"],
            ephem.unix_to_datetime(time_range[0]).strftime("%Y%m%dT%H%M%S"),
            ephem.unix_to_datetime(time_range[1]).strftime("%Y%m%dT%H%M%S"),
        )

        return self.stack
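
The three sample_variance entries follow the standard relations between the variance and pseudo-variance of a complex variable: Var(Re z) = (Var z + Re C) / 2, Cov(Re z, Im z) = Im C / 2, and Var(Im z) = (Var z - Re C) / 2, where C is the pseudo-variance. A quick numerical check of those identities:

import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=100000)
y = rng.normal(size=100000)
z = x + 1j * (0.5 * x + 0.3 * y)  # correlated real and imaginary parts

var = np.var(z)                       # E|z - Ez|^2, real valued
pvar = np.mean(z**2) - np.mean(z)**2  # pseudo-variance, complex valued
cov = np.mean(z.real * z.imag) - np.mean(z.real) * np.mean(z.imag)

print(np.var(z.real), 0.5 * (var + pvar.real))  # both ~1.0
print(np.var(z.imag), 0.5 * (var - pvar.real))  # both ~0.34
print(cov, 0.5 * pvar.imag)                     # both ~0.5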
Example #13
    def set_metadata(self, tms, input_map):
        """Sets self.corr_inputs, self.pwds, self.pstns, self.p1_idx, self.p2_idx"""
        from ch_util import tools

        # Get CHIME ON channels:
        half_time = ephemeris.unix_to_datetime(tms[int(len(tms) // 2)])
        corr_inputs = tools.get_correlator_inputs(half_time)
        self.corr_inputs = tools.reorder_correlator_inputs(
            input_map, corr_inputs)
        pwds = tools.is_chime_on(
            self.corr_inputs)  # Which inputs are CHIME ON antennas
        self.pwds = np.array(pwds, dtype=bool)
        # Get cylinders and polarizations
        self.pstns, self.p1_idx, self.p2_idx = self.get_pos_pol(
            self.corr_inputs, self.pwds)
Example #14
    def get_prod_sel(self, data):
        """ """
        from ch_util import tools

        input_map = data.input
        tms = data.time
        half_time = ephemeris.unix_to_datetime(tms[int(len(tms) // 2)])
        corr_inputs = tools.get_correlator_inputs(half_time)
        corr_inputs = tools.reorder_correlator_inputs(input_map, corr_inputs)
        pwds = tools.is_chime_on(
            corr_inputs)  # Which inputs are CHIME ON antennas

        wchp1, wchp2, echp1, echp2 = self.get_cyl_pol(corr_inputs, pwds)

        # Ensure base channels are CHIME and ON
        while not pwds[np.where(input_map["chan_id"] == self.bswp1)[0][0]]:
            self.bswp1 += 1
        while not pwds[np.where(input_map["chan_id"] == self.bswp2)[0][0]]:
            self.bswp2 += 1
        while not pwds[np.where(input_map["chan_id"] == self.bsep1)[0][0]]:
            self.bsep1 += 1
        while not pwds[np.where(input_map["chan_id"] == self.bsep2)[0][0]]:
            self.bsep2 += 1

        prod_sel = []
        for (ii, prod) in enumerate(data.prod):
            add_prod = False
            add_prod = add_prod or (
                (prod[0] == self.bswp1 and prod[1] in echp1) or
                (prod[1] == self.bswp1 and prod[0] in echp1))
            add_prod = add_prod or (
                (prod[0] == self.bswp2 and prod[1] in echp2) or
                (prod[1] == self.bswp2 and prod[0] in echp2))
            add_prod = add_prod or (
                (prod[0] == self.bsep1 and prod[1] in wchp1) or
                (prod[1] == self.bsep1 and prod[0] in wchp1))
            add_prod = add_prod or (
                (prod[0] == self.bsep2 and prod[1] in wchp2) or
                (prod[1] == self.bsep2 and prod[0] in wchp2))

            if add_prod:
                prod_sel.append(ii)

        prod_sel.sort()

        return prod_sel, pwds
Example #15
    def next(self, ts):
        """Generate an input description from the timestream passed in.

        Parameters
        ----------
        ts : andata.CorrData
            Timestream container.

        Returns
        -------
        inputs : list of :class:`CorrInput`
            A list describing the inputs as they are in the file.
        """

        # Fetch from the cache if we can
        if self.cache and self._cached_inputs:
            self.log.debug("Using cached inputs.")
            return self._cached_inputs

        inputs = None

        if mpiutil.rank0:

            # Get the datetime of the middle of the file
            time = ephemeris.unix_to_datetime(0.5 * (ts.time[0] + ts.time[-1]))
            inputs = tools.get_correlator_inputs(time)

            inputs = tools.reorder_correlator_inputs(ts.index_map["input"], inputs)

        # Broadcast input description to all ranks
        inputs = mpiutil.world.bcast(inputs, root=0)

        # Save into the cache for the next iteration
        if self.cache:
            self._cached_inputs = inputs

        # Make sure all nodes have container before return
        mpiutil.world.Barrier()

        return inputs
Example #16
def find_transit(time0='Now', src=eph.CasA):
    import time
    
    if time0 == 'Now':
        time0 = time.time()

    # Get reference datetime from unixtime
    dt_now = eph.unix_to_datetime(time0)
    dt_now = dt_now.isoformat()

    # Only use relevant characters in datetime string
    dt_str = dt_now.replace("-", "")[:7]
    dirnm = '/mnt/gong/archive/' + dt_str

    filelist = glob.glob(dirnm + '*/*h5')

    # Step through each file and find that day's transit
    for ff in filelist:
        
        try:
            andataReader = andata.Reader(ff)
            acqtimes = andataReader.time
            trans_time = eph.transit_times(src, time0)
            
            #print ff
#            print eph.transit_RA(trans_time), eph.transit_RA(acqtimes[0])
            del andataReader
            
            if np.abs(acqtimes - trans_time[0]).min() < 1000.0:
#                print "On ", eph.unix_to_datetime(trans_time[0])
#                print "foundit in %s \n" % ff
                return ff

        except (KeyError, ValueError, IOError):
            pass

    return None
Example #17
def find_transit(time0='Now', src=eph.CasA):
    import time

    if time0 == 'Now':
        time0 = time.time()

    # Get reference datetime from unixtime
    dt_now = eph.unix_to_datetime(time0)
    dt_now = dt_now.isoformat()

    # Only use relevant characters in datetime string
    dt_str = dt_now.replace("-", "")[:7]
    dirnm = '/mnt/gong/archive/' + dt_str

    filelist = glob.glob(dirnm + '*/*h5')

    # Step through each file and find that day's transit
    for ff in filelist:

        try:
            andataReader = andata.Reader(ff)
            acqtimes = andataReader.time
            trans_time = eph.transit_times(src, time0)

            #print ff
            #            print eph.transit_RA(trans_time), eph.transit_RA(acqtimes[0])
            del andataReader

            if np.abs(acqtimes - trans_time[0]).min() < 1000.0:
                #                print "On ", eph.unix_to_datetime(trans_time[0])
                #                print "foundit in %s \n" % ff
                return ff

        except (KeyError, ValueError, IOError):
            pass

    return None
Example #18
    def process(self, filelist):
        """Generate timing correction from an input list of files.

        Parameters
        ----------
        filelist : list of str
            List of files containing noise source (chimetiming) data.

        Returns
        -------
        tcorr : ch_util.timing.TimingCorrection
            Timing correction derived from noise source data.
        """
        # Determine the acquisition
        new_acq = np.unique(
            [os.path.basename(os.path.dirname(ff)) for ff in filelist])
        if new_acq.size > 1:
            raise RuntimeError(
                "Cannot process multiple acquisitions.  Received %d." %
                new_acq.size)
        else:
            new_acq = new_acq[0]

        # If this is a new acquisition, then ensure the
        # static phase and amplitude are recalculated
        if new_acq != self.current_acq:
            self.current_acq = new_acq

            for key in self._datasets_fixed_for_acq:
                self.kwargs[key] = None

        # Process the chimetiming data
        self.log.info("Processing %d files from %s." %
                      (len(filelist), self.current_acq))

        tcorr = timing.TimingData.from_acq_h5(
            filelist,
            only_correction=True,
            distributed=self.comm.size > 1,
            comm=self.comm,
            **self.kwargs,
        )

        self.log.info("Finished processing %d files from %s." %
                      (len(filelist), self.current_acq))

        # Save the static phase and amplitude to be used on subsequent iterations
        # within this acquisition
        for key in self._datasets_fixed_for_acq:
            if key in tcorr.datasets:
                self.kwargs[key] = tcorr.datasets[key][:]
            elif key in tcorr.flags:
                self.kwargs[key] = tcorr.flags[key][:]
            else:
                msg = "Dataset %s could not be found in timing correction object." % key
                raise RuntimeError(msg)

        # Save the names of the files used to construct the correction
        tcorr.attrs["archive_files"] = np.array(filelist)

        # Create a tag indicating the range of time processed
        tfmt = "%Y%m%dT%H%M%SZ"
        start_time = ephemeris.unix_to_datetime(tcorr.time[0]).strftime(tfmt)
        end_time = ephemeris.unix_to_datetime(tcorr.time[-1]).strftime(tfmt)
        tag = [start_time, "to", end_time]
        if self.output_suffix:
            tag.append(self.output_suffix)

        tcorr.attrs["tag"] = "_".join(tag)

        # Return timing correction
        return tcorr
Example #19
    @classmethod
    def from_lsd(cls, lsd: int):
        unix = csd_to_unix(lsd)
        lsd = int(lsd)
        date = unix_to_datetime(unix).date()
        day = cls(lsd, date)
        return day
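
For context, a minimal sketch of the CSD/Unix conversions this helper relies on, assuming they are imported from ch_util.ephemeris as in the other examples:

from ch_util.ephemeris import csd_to_unix, unix_to_csd, unix_to_datetime

unix = csd_to_unix(3000)  # start of CHIME sidereal day 3000
print(unix_to_datetime(unix).date(), unix_to_csd(unix))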
Example #20
    def process(self, tstream):
        """Apply the timing correction to the input timestream.

        Parameters
        ----------
        tstream : andata.CorrData, containers.TimeStream, or containers.SiderealStream
            Apply the timing correction to the visibilities stored in this container.

        Returns
        -------
        tstream_corr : same as `tstream`
            Timestream with corrected visibilities.
        """
        # Determine times
        if "time" in tstream.index_map:
            timestamp = tstream.time
        else:
            csd = (tstream.attrs["lsd"]
                   if "lsd" in tstream.attrs else tstream.attrs["csd"])
            timestamp = ephemeris.csd_to_unix(csd + tstream.ra / 360.0)

        # Extract local frequencies
        tstream.redistribute("freq")

        nfreq = tstream.vis.local_shape[0]
        sfreq = tstream.vis.local_offset[0]
        efreq = sfreq + nfreq

        freq = tstream.freq[sfreq:efreq]

        # If requested, extract the input flags
        input_flags = tstream.input_flags[:] if self.use_input_flags else None

        # Check the needs_timing_correction DataFlag time ranges to see if the
        # input timestream falls within one of the flagged intervals
        needs_timing_correction = False
        for flag in self.flags:
            if (timestamp[0] >= flag["start_time"]
                    and timestamp[0] <= flag["finish_time"]):
                if timestamp[-1] >= flag["finish_time"]:
                    raise PipelineRuntimeError(
                        f"Data covering {timestamp[0]} to {timestamp[-1]} partially overlaps "
                        f"needs_timing_correction DataFlag covering {ephemeris.unix_to_datetime(flag['start_time']).strftime('%Y%m%dT%H%M%SZ')} "
                        f"to {ephemeris.unix_to_datetime(flag['finish_time']).strftime('%Y%m%dT%H%M%SZ')}."
                    )
                else:
                    self.log.info(
                        f"Data covering {timestamp[0]} to {timestamp[-1]} flagged by "
                        f"needs_timing_correction DataFlag covering {ephemeris.unix_to_datetime(flag['start_time']).strftime('%Y%m%dT%H%M%SZ')} "
                        f"to {ephemeris.unix_to_datetime(flag['finish_time']).strftime('%Y%m%dT%H%M%SZ')}. Timing correction will be applied."
                    )
                    needs_timing_correction = True
                    break

        if not needs_timing_correction:
            self.log.info(
                f"Data in span {ephemeris.unix_to_datetime(timestamp[0]).strftime('%Y%m%dT%H%M%SZ')} to {ephemeris.unix_to_datetime(timestamp[-1]).strftime('%Y%m%dT%H%M%SZ')} does not need timing correction"
            )
            return tstream

        # Find the right timing correction
        for tcorr in self.tcorr:
            if timestamp[0] >= tcorr.time[0] and timestamp[-1] <= tcorr.time[
                    -1]:
                break
        else:
            msg = (
                "Could not find timing correction file covering "
                "range of timestream data (%s to %s)." % tuple(
                    ephemeris.unix_to_datetime([timestamp[0], timestamp[-1]])))

            if self.pass_if_missing:
                self.log.warning(msg + " Doing nothing.")
                return tstream

            raise RuntimeError(msg)

        self.log.info("Using correction file %s" % tcorr.attrs["tag"])

        # If requested, reference the timing correction with respect to the source transit time
        if self.refer_to_transit:
            # First check for transit_time attribute in file
            ttrans = tstream.attrs.get("transit_time", None)
            if ttrans is None:
                source = tstream.attrs["source_name"]
                ttrans = ephemeris.transit_times(
                    ephemeris.source_dictionary[source],
                    tstream.time[0],
                    tstream.time[-1],
                )
                if ttrans.size != 1:
                    raise RuntimeError(
                        "Found %d transits of %s in timestream.  "
                        "Require single transit." % (ttrans.size, source))
                else:
                    ttrans = ttrans[0]

            self.log.info(
                "Referencing timing correction to %s (RA=%0.1f deg)." % (
                    ephemeris.unix_to_datetime(ttrans).strftime(
                        "%Y%m%dT%H%M%SZ"),
                    ephemeris.lsa(ttrans),
                ))

            tcorr.set_global_reference_time(ttrans,
                                            window=self.transit_window,
                                            interpolate=True,
                                            interp="linear")

        # Apply the timing correction
        tcorr.apply_timing_correction(tstream,
                                      time=timestamp,
                                      freq=freq,
                                      input_flags=input_flags,
                                      copy=False)

        return tstream
Example #21
def svd_panels(data,
               results,
               rank,
               cmap=COLORMAP,
               vrng=None,
               nticks=20,
               cyl_size=256,
               input_range=None,
               input_map=None,
               corr_order=False,
               zoom_date=None,
               timezone='Canada/Pacific',
               phys_unit=True):

    inputs = data.index_map['input'][:]
    ninput = inputs.size

    timestamp = data.time[:]
    ntime = timestamp.size

    good_input = results['good_input']
    good_transit = results['good_transit']

    if (input_map is not None) and corr_order:
        cord = np.array([inp.corr_order for inp in input_map])
        cord = cord[good_input]
        isort = np.argsort(cord)
        cttl = "Correlator Order"
    else:
        cord = good_input
        isort = np.argsort(cord)
        cttl = "Feed Number"

    if input_range is None:
        input_range = [0, ninput]

    tfmt = "%b-%d"
    tz = pytz.timezone(timezone)

    skip = ntime // nticks
    xtck = np.arange(0, ntime, skip, dtype=np.int)
    xtcklbl = [
        tt.astimezone(tz).strftime(tfmt)
        for tt in ephemeris.unix_to_datetime(timestamp[::skip])
    ]

    ytck = np.arange(ninput // cyl_size + 1) * cyl_size

    if zoom_date is None:
        zoom_date = []
    else:
        zoom_date = [list(ephemeris.datetime_to_unix(zd)) for zd in zoom_date]

    zoom_date = [[timestamp[0], timestamp[-1]]] + zoom_date
    nzoom = len(zoom_date)

    gs = gridspec.GridSpec(2 + nzoom // 2,
                           2,
                           height_ratios=[3, 2] + [2] * (nzoom // 2),
                           hspace=0.30)

    this_rank = np.outer(results['S'][rank] * results['U'][:, rank],
                         results['VH'][rank, :])

    var = results['S']**2
    exp_var = 100.0 * var / np.sum(var)
    exp_var = exp_var[rank]

    if vrng is None:
        vrng = np.nanpercentile(this_rank, [2, 98])

    y = np.full((ninput, ntime), np.nan, dtype=np.float32)
    for ii, gi in enumerate(cord):
        y[gi, good_transit] = this_rank[:, ii]

    plt.subplot(gs[0, :])

    img = plt.imshow(y,
                     aspect='auto',
                     origin='lower',
                     interpolation='nearest',
                     extent=(0, ntime, 0, ninput),
                     vmin=vrng[0],
                     vmax=vrng[1],
                     cmap=cmap)

    cbar = plt.colorbar(img)
    cbar.ax.set_ylabel(r'$\tau$ [picosec]')

    plt.ylabel(cttl)
    plt.title("k = %d | Explains %0.1f%% of Variance" % (rank, exp_var))

    plt.xticks(xtck, xtcklbl, rotation=70.0)
    plt.yticks(ytck)
    plt.ylim(input_range)

    if phys_unit:
        scale = 1.48625 * np.median(
            np.abs(results['S'][rank] * results['U'][:, rank]))
        lbl = r"v$_{%d}$  x MED$(|\sigma_{%d} u_{%d}|)$ [picosec]" % (
            rank, rank, rank)
    else:
        scale = 1.0
        lbl = r"v$_{%d}$" % rank

    plt.subplot(gs[1, 0])
    plt.plot(cord,
             scale * results['VH'][rank, :],
             color='k',
             marker='.',
             linestyle='None')
    plt.grid()
    plt.xticks(ytck, rotation=70)
    plt.xlim(input_range)
    plt.xlabel(cttl)
    plt.ylabel(lbl)

    scolors = ['b', 'r', 'forestgreen', 'magenta', 'orange']
    list_good = list(results['good_transit'])

    if phys_unit:
        scale = 1.48625 * np.median(np.abs(results['VH'][rank, :]))
        lbl = r"$\sigma_{%d}$ u$_{%d}$  x MED$(|v_{%d}|)$ [picosec]" % (
            rank, rank, rank)
    else:
        scale = 1.0
        lbl = r"$\sigma_{%d}$ u$_{%d}$" % (rank, rank)

    classification = np.char.add(np.char.add(data['calibrator'][:], '/'),
                                 data['source'][:])
    usource, ind_source = np.unique(classification, return_inverse=True)
    nsource = usource.size

    for zz, (start_time, end_time) in enumerate(zoom_date):

        row = (zz + 1) // 2 + 1
        col = (zz + 1) % 2

        plt.subplot(gs[row, col])

        this_zoom = np.flatnonzero((data.time >= start_time)
                                   & (data.time <= end_time))
        nsamples = this_zoom.size

        skip = nsamples // (nticks // 2)

        xtck = np.arange(0, nsamples, skip, dtype=np.int)
        ztfmt = "%b-%d %H:%M" if zz > 0 else tfmt
        xtcklbl = np.array([
            tt.astimezone(tz).strftime(ztfmt)
            for tt in ephemeris.unix_to_datetime(data.time[this_zoom[xtck]])
        ])

        for ss, src in enumerate(usource):

            this_source = np.flatnonzero(ind_source[this_zoom] == ss)
            if this_source.size == 0:
                continue

            this_source = np.array(
                [ts for ts in this_source if this_zoom[ts] in list_good])
            this_source_good = np.array(
                [list_good.index(ind) for ind in this_zoom[this_source]])

            plt.plot(this_source,
                     scale * results['S'][rank] *
                     results['U'][this_source_good, rank],
                     color=scolors[ss],
                     marker='.',
                     markersize=4,
                     linestyle='None',
                     label=src)

        plt.xticks(xtck, xtcklbl, rotation=70.0)
        plt.xlim(0, nsamples)

        plt.grid()
        plt.ylabel(lbl)
        plt.legend(prop={'size': 10})
Example #22
def main(config_file=None, logging_params=DEFAULT_LOGGING):

    # Load config
    config = DEFAULTS.deepcopy()
    if config_file is not None:
        print(config_file)
        config.merge(NameSpace(load_yaml_config(config_file)))

    # Setup logging
    log.setup_logging(logging_params)
    logger = log.get_logger(__name__)

    timer = Timer(logger)

    # Load data
    sfile = config.data.filename if os.path.isabs(
        config.data.filename) else os.path.join(config.directory,
                                                config.data.filename)
    sdata = StabilityData.from_file(sfile)

    ninput, ntime = sdata['tau'].shape

    # Load temperature data
    tfile = (config.temperature.filename
             if os.path.isabs(config.temperature.filename) else os.path.join(
                 config.directory, config.temperature.filename))

    tkeys = ['flag', 'data_flag', 'outlier']
    if config.temperature.load:
        tkeys += config.temperature.load

    tdata = TempData.from_acq_h5(tfile, datasets=tkeys)

    # Query layout database
    inputmap = tools.get_correlator_inputs(ephemeris.unix_to_datetime(
        np.median(sdata.time[:])),
                                           correlator='chime')

    good_input = np.flatnonzero(np.any(sdata['flags']['tau'][:], axis=-1))
    pol = sutil.get_pol(sdata, inputmap)
    npol = len(pol)

    mezz_index, crate_index = sutil.get_mezz_and_crate(sdata, inputmap)

    if config.mezz_ref.enable:
        phase_ref = [
            ipol[mezz_index[ipol] == iref]
            for ipol, iref in zip(pol, config.mezz_ref.mezz)
        ]
    else:
        phase_ref = config.data.phase_ref

    # Load timing
    if config.timing.enable:

        # Extract filenames from config
        timing_files = [
            tf if os.path.isabs(tf) else os.path.join(config.directory, tf)
            for tf in config.timing.files
        ]
        timing_files_hpf = [
            os.path.join(os.path.dirname(tf), 'hpf', os.path.basename(tf))
            for tf in timing_files
        ]
        timing_files_lpf = [
            os.path.join(os.path.dirname(tf), 'lpf', os.path.basename(tf))
            for tf in timing_files
        ]

        # If requested, add the timing data back into the delay data
        if config.timing.add.enable:

            timer.start("Adding timing data to delay measurements.")

            ns_tau, _, ns_flag, ns_inputs = sutil.get_timing_correction(
                sdata, timing_files, **config.timing.add.kwargs)

            index = timing.map_input_to_noise_source(sdata.index_map['input'],
                                                     ns_inputs)

            timing_tau = ns_tau[index, :]
            timing_flag = ns_flag[index, :]
            for ipol, iref in zip(pol, config.data.phase_ref):
                timing_tau[ipol, :] = timing_tau[ipol, :] - timing_tau[
                    iref, np.newaxis, :]
                timing_flag[ipol, :] = timing_flag[ipol, :] & timing_flag[
                    iref, np.newaxis, :]

            sdata['tau'][:] = sdata['tau'][:] + timing_tau
            sdata['flags']['tau'][:] = sdata['flags']['tau'][:] & timing_flag

            timer.stop()

        # Extract the dependent variables from the timing dataset
        timer.start("Calculating timing dependence.")

        if config.timing.sep_delay:
            logger.info("Fitting HPF and LPF timing correction separately.")
            files = timing_files_hpf
            files2 = timing_files_lpf
        else:
            files2 = None
            if config.timing.hpf_delay:
                logger.info("Using HPF timing correction for delay.")
                files = timing_files_hpf
            elif config.timing.lpf_delay:
                logger.info("Using LPF timing correction for delay.")
                files = timing_files_lpf
            else:
                logger.info("Using full timing correction for delay.")
                files = timing_files

        kwargs = {}
        if config.timing.lpf_amp:
            logger.info("Using LPF timing correction for amplitude.")
            kwargs['afiles'] = timing_files_lpf
        elif config.timing.hpf_amp:
            logger.info("Using HPF timing correction for amplitude.")
            kwargs['afiles'] = timing_files_hpf
        else:
            logger.info("Using full timing correction for amplitude.")
            kwargs['afiles'] = timing_files

        for key in ['ns_ref', 'inter_cmn', 'fit_amp', 'ref_amp', 'cmn_amp']:
            if key in config.timing:
                kwargs[key] = config.timing[key]

        xtiming, xtiming_flag, xtiming_group = sutil.timing_dependence(
            sdata, files, inputmap, **kwargs)

        if files2 is not None:
            logger.info("Calculating second timing dependence.")
            kwargs['fit_amp'] = False
            xtiming2, xtiming2_flag, xtiming2_group = sutil.timing_dependence(
                sdata, files2, inputmap, **kwargs)

            xtiming = np.concatenate((xtiming, xtiming2), axis=-1)
            xtiming_flag = np.concatenate((xtiming_flag, xtiming2_flag),
                                          axis=-1)
            xtiming_group = np.concatenate((xtiming_group, xtiming2_group),
                                           axis=-1)

        timer.stop()

    else:
        xtiming = None
        xtiming_flag = None
        xtiming_group = None

    # Reference delay data to mezzanine
    if config.mezz_ref.enable:

        timer.start("Referencing delay measurements to mezzanine.")

        for ipol, iref in zip(pol, config.mezz_ref.mezz):

            this_mezz = ipol[mezz_index[ipol] == iref]

            wmezz = sdata['flags']['tau'][this_mezz, :].astype(np.float32)

            norm = np.sum(wmezz, axis=0)

            taut_mezz = np.sum(wmezz * sdata['tau'][this_mezz, :],
                               axis=0) * tools.invert_no_zero(norm)
            flagt_mezz = norm > 0.0

            sdata['tau'][
                ipol, :] = sdata['tau'][ipol, :] - taut_mezz[np.newaxis, :]
            sdata['flags']['tau'][ipol, :] = sdata['flags']['tau'][
                ipol, :] & flagt_mezz[np.newaxis, :]

        timer.stop()

    # Load cable monitor
    if config.cable_monitor.enable:

        timer.start("Calculating cable monitor dependence.")

        cbl = timing.TimingCorrection.from_acq_h5(
            config.cable_monitor.filename)

        kwargs = {'include_diff': config.cable_monitor.include_diff}

        xcable, xcable_flag, xcable_group = sutil.cable_monitor_dependence(
            sdata, cbl, **kwargs)

        timer.stop()

    else:
        xcable = None
        xcable_flag = None
        xcable_group = None

    # Load NS distance
    if config.ns_distance.enable:

        timer.start("Calculating NS distance dependence.")

        kwargs = {}
        kwargs['phase_ref'] = phase_ref

        for key in [
                'sensor', 'temp_field', 'sep_cyl', 'sep_feed',
                'include_offset', 'include_ha'
        ]:
            if key in config.ns_distance:
                kwargs[key] = config.ns_distance[key]

        if config.ns_distance.use_cable_monitor:
            kwargs['is_cable_monitor'] = True
            kwargs['use_alpha'] = config.ns_distance.use_alpha
            nsx = timing.TimingCorrection.from_acq_h5(
                config.cable_monitor.filename)
        else:
            kwargs['is_cable_monitor'] = False
            nsx = tdata

        xdist, xdist_flag, xdist_group = sutil.ns_distance_dependence(
            sdata, nsx, inputmap, **kwargs)

        if (config.ns_distance.deriv
                is not None) and (config.ns_distance.deriv > 0):

            for dd in range(1, config.ns_distance.deriv + 1):

                d_xdist, d_xdist_flag, d_xdist_group = sutil.ns_distance_dependence(
                    sdata, tdata, inputmap, deriv=dd, **kwargs)

                tind = np.atleast_1d(1)
                xdist = np.concatenate((xdist, d_xdist[:, :, tind]), axis=-1)
                xdist_flag = np.concatenate(
                    (xdist_flag, d_xdist_flag[:, :, tind]), axis=-1)
                xdist_group = np.concatenate(
                    (xdist_group, d_xdist_group[:, tind]), axis=-1)

        timer.stop()

    else:
        xdist = None
        xdist_flag = None
        xdist_group = None

    # Load temperatures
    if config.temperature.enable:

        timer.start("Calculating temperature dependence.")

        xtemp, xtemp_flag, xtemp_group, xtemp_name = sutil.temperature_dependence(
            sdata,
            tdata,
            config.temperature.sensor,
            field=config.temperature.temp_field,
            inputmap=inputmap,
            phase_ref=phase_ref,
            check_hut=config.temperature.check_hut)

        if (config.temperature.deriv
                is not None) and (config.temperature.deriv > 0):

            for dd in range(1, config.temperature.deriv + 1):

                d_xtemp, d_xtemp_flag, d_xtemp_group, d_xtemp_name = sutil.temperature_dependence(
                    sdata,
                    tdata,
                    config.temperature.sensor,
                    field=config.temperature.temp_field,
                    deriv=dd,
                    inputmap=inputmap,
                    phase_ref=phase_ref,
                    check_hut=config.temperature.check_hut)

                xtemp = np.concatenate((xtemp, d_xtemp), axis=-1)
                xtemp_flag = np.concatenate((xtemp_flag, d_xtemp_flag),
                                            axis=-1)
                xtemp_group = np.concatenate((xtemp_group, d_xtemp_group),
                                             axis=-1)
                xtemp_name += d_xtemp_name

        timer.stop()

    else:
        xtemp = None
        xtemp_flag = None
        xtemp_group = None
        xtemp_name = None

    # Combine into single feature matrix
    x, coeff_name = _concatenate(xdist,
                                 xtemp,
                                 xcable,
                                 xtiming,
                                 name_xtemp=xtemp_name)

    x_group, _ = _concatenate(xdist_group, xtemp_group, xcable_group,
                              xtiming_group)

    x_flag, _ = _concatenate(xdist_flag, xtemp_flag, xcable_flag, xtiming_flag)
    x_flag = np.all(x_flag, axis=-1) & sdata.flags['tau'][:]

    nfeature = x.shape[-1]

    logger.info("Fitting %d features." % nfeature)

    # Save data
    if config.preliminary_save.enable:

        if config.preliminary_save.filename is not None:
            ofile = (config.preliminary_save.filename if os.path.isabs(
                config.preliminary_save.filename) else os.path.join(
                    config.directory, config.preliminary_save.filename))
        else:
            ofile = os.path.splitext(
                sfile)[0] + '_%s.h5' % config.preliminary_save.suffix

        sdata.save(ofile, mode='w')

    # Subtract mean
    if config.mean_subtract:
        timer.start("Subtracting mean value.")

        tau, mu_tau, mu_tau_flag = sutil.mean_subtract(sdata,
                                                       sdata['tau'][:],
                                                       x_flag,
                                                       use_calibrator=True)

        mu_x = np.zeros(mu_tau.shape + (nfeature, ), dtype=x.dtype)
        mu_x_flag = np.zeros(mu_tau.shape + (nfeature, ), dtype=np.bool)
        x_no_mu = x.copy()
        for ff in range(nfeature):
            x_no_mu[..., ff], mu_x[...,
                                   ff], mu_x_flag[...,
                                                  ff] = sutil.mean_subtract(
                                                      sdata,
                                                      x[:, :, ff],
                                                      x_flag,
                                                      use_calibrator=True)
        timer.stop()

    else:
        x_no_mu = x.copy()
        tau = sdata['tau'][:].copy()

    # Calculate unique days
    csd_uniq, bmap = np.unique(sdata['csd'][:], return_inverse=True)
    ncsd = csd_uniq.size

    # Prepare unique sources
    classification = np.char.add(np.char.add(sdata['calibrator'][:], '/'),
                                 sdata['source'][:])

    # If requested, load existing coefficients
    if config.coeff is not None:
        coeff = andata.BaseData.from_acq_h5(config.coeff)
        evaluate_only = True
    else:
        evaluate_only = False

    # If requested, set up boot strapping
    if config.bootstrap.enable:

        nboot = config.bootstrap.number
        nchoices = ncsd if config.bootstrap.by_transit else ntime
        nsample = int(config.bootstrap.fraction * nchoices)

        bindex = np.zeros((nboot, nsample), dtype=np.int)
        for roll in range(nboot):
            bindex[roll, :] = np.sort(
                np.random.choice(nchoices,
                                 size=nsample,
                                 replace=config.bootstrap.replace))

    else:

        nboot = 1
        bindex = np.arange(ntime, dtype=np.int)[np.newaxis, :]

    # Prepare output
    if config.output.directory is not None:
        output_dir = config.output.directory
    else:
        output_dir = config.data.directory

    if config.output.suffix is not None:
        output_suffix = config.output.suffix
    else:
        output_suffix = os.path.splitext(os.path.basename(
            config.data.filename))[0]

    # Perform joint fit
    for bb, bind in enumerate(bindex):

        if config.bootstrap.enable and config.bootstrap.by_transit:
            tind = np.concatenate(
                tuple([np.flatnonzero(bmap == ii) for ii in bind]))
        else:
            tind = bind

        ntime = tind.size

        if config.jackknife.enable:
            start = int(
                config.jackknife.start * ncsd
            ) if config.jackknife.start <= 1.0 else config.jackknife.start
            end = int(
                config.jackknife.end *
                ncsd) if config.jackknife.end <= 1.0 else config.jackknife.end

            time_flag_fit = (bmap >= start) & (bmap < end)

            if config.jackknife.restrict_stat:
                time_flag_stat = np.logical_not(time_flag_fit)
            else:
                time_flag_stat = np.ones(ntime, dtype=np.bool)

        else:
            time_flag_fit = np.ones(ntime, dtype=np.bool)
            time_flag_stat = np.ones(ntime, dtype=np.bool)

        logger.info(
            "Fitting data between %s (CSD %d) and %s (CSD %d)" %
            (ephemeris.unix_to_datetime(np.min(
                sdata.time[tind[time_flag_fit]])).strftime("%Y-%m-%d"),
             np.min(sdata['csd'][:][tind[time_flag_fit]]),
             ephemeris.unix_to_datetime(np.max(
                 sdata.time[tind[time_flag_fit]])).strftime("%Y-%m-%d"),
             np.max(sdata['csd'][:][tind[time_flag_fit]])))

        logger.info(
            "Calculating statistics from data between %s (CSD %d) and %s (CSD %d)"
            % (ephemeris.unix_to_datetime(
                np.min(sdata.time[tind[time_flag_stat]])).strftime("%Y-%m-%d"),
               np.min(sdata['csd'][:][tind[time_flag_stat]]),
               ephemeris.unix_to_datetime(
                   np.max(
                       sdata.time[tind[time_flag_stat]])).strftime("%Y-%m-%d"),
               np.max(sdata['csd'][:][tind[time_flag_stat]])))

        if evaluate_only:
            timer.start("Evaluating coefficients provided.")
            fitter = sutil.JointTempEvaluation(
                x_no_mu[:, tind, :],
                tau[:, tind],
                coeff['coeff'][:],
                flag=x_flag[:, tind],
                coeff_name=coeff.index_map['feature'][:],
                feature_name=coeff_name,
                intercept=coeff['intercept'][:],
                intercept_name=coeff.index_map['classification'][:],
                classification=classification[tind])
            timer.stop()

        else:
            timer.start("Setting up fit.  Bootstrap %d of %d." %
                        (bb + 1, nboot))

            fitter = sutil.JointTempRegression(
                x_no_mu[:, tind, :],
                tau[:, tind],
                x_group,
                flag=x_flag[:, tind],
                classification=classification[tind],
                coeff_name=coeff_name)
            timer.stop()

            timer.start("Performing fit.  Bootstrap %d of %d." %
                        (bb + 1, nboot))
            fitter.fit_temp(time_flag=time_flag_fit, **config.fit_options)
            timer.stop()

        # If bootstrapping, append counter to filename
        if config.bootstrap.enable:
            output_suffix_bb = output_suffix + "_bootstrap_%04d" % (
                config.bootstrap.index_start + bb, )

            with open(
                    os.path.join(output_dir,
                                 "bootstrap_index_%s.json" % output_suffix_bb),
                    'w') as jhandler:
                json.dump({
                    "bind": bind.tolist(),
                    "tind": tind.tolist()
                }, jhandler)

        else:
            output_suffix_bb = output_suffix

        # Save statistics to file
        if config.output.stat:

            # If requested, break the model up into its various components for calculating statistics
            stat_key = ['data', 'model', 'resid']
            if config.refine_model.enable:
                stat_add = fitter.refine_model(config.refine_model.include)
                stat_key += stat_add

            # Redefine axes
            bdata = StabilityData()
            for dset in ["source", "csd", "calibrator", "calibrator_time"]:
                bdata.create_dataset(dset, data=sdata[dset][tind])

            bdata.create_index_map("time", sdata.index_map["time"][tind])
            bdata.create_index_map("input", sdata.index_map["input"][:])
            bdata.attrs["calibrator"] = sdata.attrs.get("calibrator", "CYG_A")

            # Calculate statistics
            stat = {}
            for statistic in ['std', 'mad']:
                for attr in stat_key:
                    for ref, ref_common in zip(['mezz', 'cmn'], [False, True]):
                        stat[(statistic, attr, ref)] = sutil.short_long_stat(
                            bdata,
                            getattr(fitter, attr),
                            fitter._flag & time_flag_stat[np.newaxis, :],
                            stat=statistic,
                            ref_common=ref_common,
                            pol=pol)

            output_filename = os.path.join(output_dir,
                                           "stat_%s.h5" % output_suffix_bb)

            write_stat(bdata, stat, fitter, output_filename)

        # Save coefficients to file
        if config.output.coeff:
            output_filename = os.path.join(output_dir,
                                           "coeff_%s.h5" % output_suffix_bb)

            write_coeff(sdata, fitter, output_filename)

        # Save residuals to file
        if config.output.resid:
            output_filename = os.path.join(output_dir,
                                           "resid_%s.h5" % output_suffix_bb)

            write_resid(sdata, fitter, output_filename)

        del fitter
        gc.collect()
Example #23
        def check_for_duplicates(t, src, start_tol, ignore_src_mismatch=False):
            """
            Check for duplicate holography observations, comparing the given
            observation to the existing database

            Inputs
            ------
            t: Skyfield Time object
                beginning time of observation
            src: HolographySource
                target source
            start_tol: float
                Tolerance in seconds within which to search for duplicates
            ignore_src_mismatch: bool (default: False)
                If True, consider observations a match if the time matches
                but the source does not

            Outputs
            -------
            If a duplicate is found: :py:class:`HolographyObservation` object for the
            existing entry in the database

            If no duplicate is found: None
            """
            ts = ephemeris.skyfield_wrapper.timescale

            unixt = ephemeris.ensure_unix(t)

            dup_found = False

            existing_db_entry = cls.select().where(
                cls.start_time.between(unixt - start_tol, unixt + start_tol))
            if len(existing_db_entry) > 0:
                if len(existing_db_entry) > 1:
                    print("Multiple entries found.")
                for entry in existing_db_entry:
                    tt = ts.utc(ephemeris.unix_to_datetime(entry.start_time))
                    # LST = GST + east longitude
                    ttlst = np.mod(tt.gmst + DRAO_lon, 24.0)

                    # Check if source name matches. If not, print a warning
                    # but proceed anyway.
                    if src.name.upper() == entry.source.name.upper():
                        dup_found = True
                        if verbose:
                            print("Observation is already in database.")
                    else:
                        if ignore_src_mismatch:
                            dup_found = True
                        print(
                            "** Observation at same time but with different " +
                            "sources in database: ",
                            src.name,
                            entry.source.name,
                            tt.utc_datetime().isoformat(),
                        )
                        # if the observations match in start time and source,
                        # call them the same observation. Not the most strict
                        # check possible.

                    if dup_found:
                        tf = ts.utc(
                            ephemeris.unix_to_datetime(entry.finish_time))
                        print("Tried to add  :  {} {}; LST={:.3f}".format(
                            src.name,
                            t.utc_datetime().strftime(DATE_FMT_STR), ttlst))
                        print("Existing entry:  {} {}; LST={:.3f}".format(
                            entry.source.name,
                            tt.utc_datetime().strftime(DATE_FMT_STR),
                            ttlst,
                        ))
            if dup_found:
                return existing_db_entry
            else:
                return None
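
The `ttlst` computation above uses the relation LST = GMST + east longitude, with all quantities expressed in hours. A self-contained sketch of the same conversion (editor's illustration; assumes the skyfield package, and the DRAO longitude value is approximate):

import numpy as np
from skyfield.api import load

ts = load.timescale()
t = ts.utc(2018, 1, 1, 12, 0, 0)

# Approximate DRAO east longitude, converted from degrees to hours.
DRAO_lon = -119.62 / 15.0

lst = np.mod(t.gmst + DRAO_lon, 24.0)
print("LST = %.3f h" % lst)
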
Example #24
def main(config_file=None, logging_params=DEFAULT_LOGGING):

    # Setup logging
    log.setup_logging(logging_params)
    mlog = log.get_logger(__name__)

    # Set config
    config = DEFAULTS.deepcopy()
    if config_file is not None:
        config.merge(NameSpace(load_yaml_config(config_file)))

    # Set niceness
    current_niceness = os.nice(0)
    os.nice(config.niceness - current_niceness)
    mlog.info('Changing process niceness from %d to %d.  Confirm:  %d' %
              (current_niceness, config.niceness, os.nice(0)))

    # Find acquisition files
    acq_files = sorted(glob(os.path.join(config.data_dir, config.acq, "*.h5")))
    nfiles = len(acq_files)

    # Determine time range of each file
    findex = []
    tindex = []
    for ii, filename in enumerate(acq_files):
        subdata = andata.CorrData.from_acq_h5(filename, datasets=())

        findex += [ii] * subdata.ntime
        tindex += range(subdata.ntime)

    findex = np.array(findex)
    tindex = np.array(tindex)
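    # (Example: if the first file has 3 time samples and the second has 2, then
    #  findex = [0, 0, 0, 1, 1] and tindex = [0, 1, 2, 0, 1], so every global
    #  sample maps back to a (file, local time index) pair.)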

    # Determine transits within these files
    transits = []

    data = andata.CorrData.from_acq_h5(acq_files, datasets=())

    solar_rise = ephemeris.solar_rising(data.time[0] - 24.0 * 3600.0,
                                        end_time=data.time[-1])

    for rr in solar_rise:

        ss = ephemeris.solar_setting(rr)[0]

        solar_flag = np.flatnonzero((data.time >= rr) & (data.time <= ss))

        if solar_flag.size > 0:

            solar_flag = solar_flag[::config.downsample]

            tval = data.time[solar_flag]

            this_findex = findex[solar_flag]
            this_tindex = tindex[solar_flag]

            file_list, tindices = [], []

            for ii in range(nfiles):

                this_file = np.flatnonzero(this_findex == ii)

                if this_file.size > 0:

                    file_list.append(acq_files[ii])
                    tindices.append(this_tindex[this_file])

            date = ephemeris.unix_to_datetime(rr).strftime('%Y%m%dT%H%M%SZ')
            transits.append((date, tval, file_list, tindices))

    # Create file prefix and suffix
    prefix = []

    prefix.append("redundant_calibration")

    if config.output_prefix is not None:
        prefix.append(config.output_prefix)

    prefix = '_'.join(prefix)

    suffix = []

    if config.include_auto:
        suffix.append("wauto")
    else:
        suffix.append("noauto")

    if config.include_intracyl:
        suffix.append("wintra")
    else:
        suffix.append("nointra")

    if config.fix_degen:
        suffix.append("fixed_degen")
    else:
        suffix.append("degen")

    suffix = '_'.join(suffix)

    # Loop over solar transits
    for date, timestamps, files, time_indices in transits:

        nfiles = len(files)

        mlog.info("%s (%d files) " % (date, nfiles))

        output_file = os.path.join(config.output_dir,
                                   "%s_SUN_%s_%s.h5" % (prefix, date, suffix))

        mlog.info("Saving to:  %s" % output_file)

        # Get info about this set of files
        data = andata.CorrData.from_acq_h5(files,
                                           datasets=['flags/inputs'],
                                           apply_gain=False,
                                           renormalize=False)

        coord = sun_coord(timestamps, deg=True)

        fstart = config.freq_start if config.freq_start is not None else 0
        fstop = config.freq_stop if config.freq_stop is not None else data.freq.size
        freq_index = range(fstart, fstop)

        freq = data.freq[freq_index]

        ntime = timestamps.size
        nfreq = freq.size

        # Determine bad inputs
        if config.bad_input_file is None or not os.path.isfile(
                config.bad_input_file):
            bad_input = np.flatnonzero(
                ~np.all(data.flags['inputs'][:], axis=-1))
        else:
            with open(config.bad_input_file, 'r') as handler:
                bad_input = pickle.load(handler)

        mlog.info("%d inputs flagged as bad." % bad_input.size)

        nant = data.ninput

        # Determine polarization product maps
        dbinputs = tools.get_correlator_inputs(ephemeris.unix_to_datetime(
            timestamps[0]),
                                               correlator='chime')

        dbinputs = tools.reorder_correlator_inputs(data.input, dbinputs)

        feedpos = tools.get_feed_positions(dbinputs)

        prod = defaultdict(list)
        dist = defaultdict(list)

        for pp, this_prod in enumerate(data.prod):

            aa, bb = this_prod
            inp_aa = dbinputs[aa]
            inp_bb = dbinputs[bb]

            if (aa in bad_input) or (bb in bad_input):
                continue

            if not tools.is_chime(inp_aa) or not tools.is_chime(inp_bb):
                continue

            if not config.include_intracyl and (inp_aa.cyl == inp_bb.cyl):
                continue

            if not config.include_auto and (aa == bb):
                continue

            this_dist = list(feedpos[aa, :] - feedpos[bb, :])

            if tools.is_array_x(inp_aa) and tools.is_array_x(inp_bb):
                key = 'XX'

            elif tools.is_array_y(inp_aa) and tools.is_array_y(inp_bb):
                key = 'YY'

            elif not config.include_crosspol:
                continue

            elif tools.is_array_x(inp_aa) and tools.is_array_y(inp_bb):
                key = 'XY'

            elif tools.is_array_y(inp_aa) and tools.is_array_x(inp_bb):
                key = 'YX'

            else:
                raise RuntimeError("CHIME feeds not polarized.")

            prod[key].append(pp)
            dist[key].append(this_dist)

        polstr = sorted(prod.keys())
        polcnt = 0
        pol_sky_id = []
        bmap = {}
        for key in polstr:
            prod[key] = np.array(prod[key])
            dist[key] = np.array(dist[key])

            p_bmap, p_ubaseline = generate_mapping(dist[key])
            nubase = p_ubaseline.shape[0]

            bmap[key] = p_bmap + polcnt

            if polcnt > 0:

                ubaseline = np.concatenate((ubaseline, p_ubaseline), axis=0)
                pol_sky_id += [key] * nubase

            else:

                ubaseline = p_ubaseline.copy()
                pol_sky_id = [key] * nubase

            polcnt += nubase
            mlog.info("%d unique baselines" % polcnt)

        nsky = ubaseline.shape[0]

        # Create arrays to hold the results
        ores = {}
        ores['freq'] = freq
        ores['input'] = data.input
        ores['time'] = timestamps
        ores['coord'] = coord
        ores['pol'] = np.array(pol_sky_id)
        ores['baseline'] = ubaseline

        # Create array to hold gain results
        ores['gain'] = np.zeros((nfreq, nant, ntime), dtype=np.complex)
        ores['sky'] = np.zeros((nfreq, nsky, ntime), dtype=np.complex)
        ores['err'] = np.zeros((nfreq, nant + nsky, ntime, 2), dtype=np.float)

        # Loop over polarisations
        for key in polstr:

            reverse_map = bmap[key]
            p_prod = prod[key]

            isort = np.argsort(reverse_map)

            p_prod = p_prod[isort]

            p_ant1 = data.prod['input_a'][p_prod]
            p_ant2 = data.prod['input_b'][p_prod]
            p_vismap = reverse_map[isort]

            # Find the redundant groups
            tmp = np.where(np.diff(p_vismap) != 0)[0]
            edges = np.zeros(2 + tmp.size, dtype='int')
            edges[0] = 0
            edges[1:-1] = tmp + 1
            edges[-1] = p_vismap.size

            kept_base = np.unique(p_vismap)

            # Determine the unique antennas
            kept_ants = np.unique(np.concatenate([p_ant1, p_ant2]))
            antmap = np.zeros(kept_ants.max() + 1, dtype='int') - 1

            p_nant = kept_ants.size
            for i in range(p_nant):
                antmap[kept_ants[i]] = i

            p_ant1_use = antmap[p_ant1].copy()
            p_ant2_use = antmap[p_ant2].copy()

            # Create matrix
            p_nvis = p_prod.size
            nred = edges.size - 1

            npar = p_nant + nred

            A = np.zeros((p_nvis, npar), dtype=np.float32)
            B = np.zeros((p_nvis, npar), dtype=np.float32)

            for kk in range(p_nant):

                flag_ant1 = p_ant1_use == kk
                if np.any(flag_ant1):
                    A[flag_ant1, kk] = 1.0
                    B[flag_ant1, kk] = 1.0

                flag_ant2 = p_ant2_use == kk
                if np.any(flag_ant2):
                    A[flag_ant2, kk] = 1.0
                    B[flag_ant2, kk] = -1.0

            for ee in range(nred):

                A[edges[ee]:edges[ee + 1], p_nant + ee] = 1.0

                B[edges[ee]:edges[ee + 1], p_nant + ee] = 1.0

            # Add equations to break degeneracy
            if config.fix_degen:
                A = np.concatenate((A, np.zeros((1, npar), dtype=np.float32)))
                A[-1, 0:p_nant] = 1.0

                B = np.concatenate((B, np.zeros((3, npar), dtype=np.float32)))
                B[-3, 0:p_nant] = 1.0
                B[-2, 0:p_nant] = feedpos[kept_ants, 0]
                B[-1, 0:p_nant] = feedpos[kept_ants, 1]
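                # (These extra rows pin the degenerate modes of the redundant
                #  solution: one amplitude constraint, sum of log-gains = 0, and
                #  three phase constraints, sum of phases = 0 plus the two phase
                #  gradients across the feed positions, none of which the
                #  per-baseline equations can determine on their own.)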

            # Loop over frequencies
            for ff, find in enumerate(freq_index):

                mlog.info("Freq %d of %d.  %0.2f MHz." %
                          (ff + 1, nfreq, freq[ff]))

                cnt = 0

                # Loop over files
                for ii, (filename, tind) in enumerate(zip(files,
                                                          time_indices)):

                    ntind = len(tind)
                    mlog.info("Processing file %s (%d time samples)" %
                              (filename, ntind))

                    # Compute noise weight
                    with h5py.File(filename, 'r') as hf:
                        wnoise = np.median(hf['flags/vis_weight'][find, :, :],
                                           axis=-1)

                    # Loop over times
                    for tt in tind:

                        t0 = time.time()

                        mlog.info("Time %d of %d.  %d index of current file." %
                                  (cnt + 1, ntime, tt))

                        # Load visibilities
                        with h5py.File(filename, 'r') as hf:

                            snap = hf['vis'][find, :, tt]
                            wsnap = wnoise * (
                                (hf['flags/vis_weight'][find, :, tt] > 0.0) &
                                (np.abs(snap) > 0.0)).astype(np.float32)

                        # Extract relevant products for this polarization
                        snap = snap[p_prod]
                        wsnap = wsnap[p_prod]

                        # Turn into amplitude and phase, avoiding NaN
                        mask = (wsnap > 0.0)

                        amp = np.where(mask, np.log(np.abs(snap)), 0.0)
                        phi = np.where(mask, np.angle(snap), 0.0)

                        # Deal with phase wrapping
                        for aa, bb in zip(edges[:-1], edges[1:]):
                            dphi = phi[aa:bb] - np.sort(phi[aa:bb])[int(
                                (bb - aa) / 2)]
                            phi[aa:bb] += (2.0 * np.pi * (dphi < -np.pi) -
                                           2.0 * np.pi * (dphi > np.pi))
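                        # (Example: phases [3.1, -3.1, 3.0] within one group have
                        #  median 3.0; the -3.1 sample is shifted up by 2*pi to
                        #  about +3.18 so the group is contiguous before fitting.)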

                        # Add elements to fix degeneracy
                        if config.fix_degen:
                            amp = np.concatenate((amp, np.zeros(1)))
                            phi = np.concatenate((phi, np.zeros(3)))

                        # Determine noise matrix
                        inv_diagC = wsnap * np.abs(snap)**2 * 2.0

                        if config.fix_degen:
                            inv_diagC = np.concatenate((inv_diagC, np.ones(1)))

                        # Amplitude estimate and covariance
                        amp_param_cov = np.linalg.inv(
                            np.dot(A.T, inv_diagC[:, np.newaxis] * A))
                        amp_param = np.dot(amp_param_cov,
                                           np.dot(A.T, inv_diagC * amp))

                        # Phase estimate and covariance
                        if config.fix_degen:
                            inv_diagC = np.concatenate((inv_diagC, np.ones(2)))

                        phi_param_cov = np.linalg.inv(
                            np.dot(B.T, inv_diagC[:, np.newaxis] * B))
                        phi_param = np.dot(phi_param_cov,
                                           np.dot(B.T, inv_diagC * phi))
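                        # (Both solves are standard weighted least squares:
                        #  cov = inv(M.T @ Cinv @ M), params = cov @ M.T @ Cinv @ d,
                        #  with M = A for log-amplitudes, M = B for phases, and
                        #  Cinv the diagonal inverse-noise weights.)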

                        # Save to large array
                        ores['gain'][ff, kept_ants,
                                     cnt] = np.exp(amp_param[0:p_nant] +
                                                   1.0J * phi_param[0:p_nant])

                        ores['sky'][ff, kept_base,
                                    cnt] = np.exp(amp_param[p_nant:] +
                                                  1.0J * phi_param[p_nant:])

                        ores['err'][ff, kept_ants, cnt,
                                    0] = np.diag(amp_param_cov[0:p_nant,
                                                               0:p_nant])
                        ores['err'][ff, nant + kept_base, cnt,
                                    0] = np.diag(amp_param_cov[p_nant:,
                                                               p_nant:])

                        ores['err'][ff, kept_ants, cnt,
                                    1] = np.diag(phi_param_cov[0:p_nant,
                                                               0:p_nant])
                        ores['err'][ff, nant + kept_base, cnt,
                                    1] = np.diag(phi_param_cov[p_nant:,
                                                               p_nant:])

                        # Increment time counter
                        cnt += 1

                        # Print time elapsed
                        mlog.info("Took %0.1f seconds." % (time.time() - t0, ))

        # Save results to HDF5 file
        with h5py.File(output_file, 'w') as handler:

            handler.attrs['date'] = date

            for key, val in ores.items():
                handler.create_dataset(key, data=val)
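
The design matrices A and B built above encode the linearised redundant-calibration model: for a product between inputs a and b in redundant group k, log|V_ab| ~ g_a + g_b + s_k and arg V_ab ~ phi_a - phi_b + psi_k, which is why each row of A carries two +1 antenna entries plus a +1 for its group, while B carries a +1/-1 antenna pair. A minimal sketch of the amplitude design matrix for three antennas in a single redundant group (editor's illustration, not part of the pipeline):

import numpy as np

# Three antennas, three products, all assumed to lie in one redundant group.
products = [(0, 1), (0, 2), (1, 2)]
nant, nsky = 3, 1

A = np.zeros((len(products), nant + nsky))
for row, (a, b) in enumerate(products):
    A[row, a] = 1.0      # log-gain of antenna a
    A[row, b] = 1.0      # log-gain of antenna b
    A[row, nant] = 1.0   # log-amplitude of the common sky visibility

print(A)
# [[1. 1. 0. 1.]
#  [1. 0. 1. 1.]
#  [0. 1. 1. 1.]]
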
Example #25
def main(config_file=None, logging_params=DEFAULT_LOGGING):

    # Setup logging
    log.setup_logging(logging_params)
    mlog = log.get_logger(__name__)

    # Set config
    config = DEFAULTS.deepcopy()
    if config_file is not None:
        config.merge(NameSpace(load_yaml_config(config_file)))

    # Create transit tracker
    source_list = FluxCatalog.sort(
    ) if not config.source_list else config.source_list

    cal_list = [
        name for name, obj in FluxCatalog.iteritems()
        if (obj.dec >= config.min_dec) and (
            obj.predict_flux(config.freq_nominal) >= config.min_flux) and (
                name in source_list)
    ]

    if not cal_list:
        raise RuntimeError("No calibrators found.")

    # Sort list by flux at nominal frequency
    cal_list.sort(
        key=lambda name: FluxCatalog[name].predict_flux(config.freq_nominal))

    # Add to transit tracker
    transit_tracker = containers.TransitTrackerOffline(
        nsigma=config.nsigma_source, extend_night=config.extend_night)
    for name in cal_list:
        transit_tracker[name] = FluxCatalog[name].skyfield

    mlog.info("Initializing offline point source processing.")

    search_time = config.start_time or 0

    # Find all calibration files
    all_files = sorted(
        glob.glob(
            os.path.join(config.acq_dir,
                         '*' + config.correlator + config.acq_suffix, '*.h5')))
    if not all_files:
        return

    # Remove files whose last modified time is before the search start time
    all_files = [
        ff for ff in all_files if (os.path.getmtime(ff) > search_time)
    ]
    if not all_files:
        return

    # Remove files that are currently locked
    all_files = [
        ff for ff in all_files
        if not os.path.isfile(os.path.splitext(ff)[0] + '.lock')
    ]
    if not all_files:
        return

    # Add files to transit tracker
    for ff in all_files:
        transit_tracker.add_file(ff)

    # Extract point source transits ready for analysis
    all_transits = transit_tracker.get_transits()

    # Create dictionary to hold results
    h5_psrc_fit = {}
    inputmap = None

    # Loop over transits
    for transit in all_transits:

        src, csd, is_day, files, start, stop = transit
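        # (Each transit is a tuple of source name, CHIME sidereal day, a flag or
        #  fraction indicating a daytime transit, the list of files spanning the
        #  transit, and the start/stop sample indices within those files.)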

        # Discard any point sources with unusual csd value
        if (csd < config.min_csd) or (csd > config.max_csd):
            continue

        # Discard any point sources transiting during the day
        if is_day > config.process_daytime:
            continue

        mlog.info(
            'Processing %s transit on CSD %d (%d files, %d time samples)' %
            (src, csd, len(files), stop - start + 1))

        # Load inputmap
        if inputmap is None:
            if config.inputmap is None:
                inputmap = tools.get_correlator_inputs(
                    ephemeris.unix_to_datetime(ephemeris.csd_to_unix(csd)),
                    correlator=config.correlator)
            else:
                with open(config.inputmap, 'r') as handler:
                    inputmap = pickle.load(handler)

        # Grab the timing correction for this transit
        tcorr = None
        if config.apply_timing:

            if config.timing_glob is not None:

                mlog.info(
                    "Loading timing correction from extended timing solutions."
                )

                timing_files = sorted(glob.glob(config.timing_glob))

                if timing_files:

                    try:
                        tcorr = search_extended_timing_solutions(
                            timing_files, ephemeris.csd_to_unix(csd))

                    except Exception as e:
                        mlog.error(
                            'search_extended_timing_solutions failed with error: %s'
                            % e)

                    else:
                        mlog.info(str(tcorr))

            if tcorr is None:

                mlog.info(
                    "Loading timing correction from chimetiming acquisitions.")

                try:
                    tcorr = timing.load_timing_correction(
                        files,
                        start=start,
                        stop=stop,
                        window=config.timing_window,
                        instrument=config.correlator)
                except Exception as e:
                    mlog.error(
                        'timing.load_timing_correction failed with error: %s' %
                        e)
                    mlog.warning(
                        'No timing correction applied to %s transit on CSD %d.'
                        % (src, csd))
                else:
                    mlog.info(str(tcorr))

        # Call the main routine to process data
        try:
            outdct = offline_cal.offline_point_source_calibration(
                files,
                src,
                start=start,
                stop=stop,
                inputmap=inputmap,
                tcorr=tcorr,
                logging_params=logging_params,
                **config.analysis.as_dict())

        except Exception as e:
            msg = 'offline_cal.offline_point_source_calibration failed with error:  %s' % e
            mlog.error(msg)
            continue
            #raise RuntimeError(msg)

        # Find existing gain files for this particular point source
        if src not in h5_psrc_fit:

            output_files = find_files(config, psrc=src)
            if output_files is not None:
                output_files = output_files[-1]
                mlog.info('Writing %s transit on CSD %d to existing file %s.' %
                          (src, csd, output_files))

            h5_psrc_fit[src] = containers.PointSourceWriter(
                src,
                output_file=output_files,
                output_dir=config.output_dir,
                output_suffix=point_source_name_to_file_suffix(src),
                instrument=config.correlator,
                max_file_size=config.max_file_size,
                max_num=config.max_num_time,
                memory_size=0)

        # Associate this gain calibration to the transit time
        this_time = ephemeris.transit_times(FluxCatalog[src].skyfield,
                                            ephemeris.csd_to_unix(csd))[0]

        outdct['csd'] = csd
        outdct['is_daytime'] = is_day
        outdct['acquisition'] = os.path.basename(os.path.dirname(files[0]))

        # Write to output file
        mlog.info('Writing to disk results from %s transit on CSD %d.' %
                  (src, csd))
        h5_psrc_fit[src].write(this_time, **outdct)

        # Dump an individual file for this point source transit
        mlog.info('Dumping to disk single file for %s transit on CSD %d.' %
                  (src, csd))
        dump_dir = os.path.join(config.output_dir, 'point_source_gains')
        containers.mkdir(dump_dir)

        dump_file = os.path.join(dump_dir, '%s_csd_%d.h5' % (src.lower(), csd))
        h5_psrc_fit[src].dump(dump_file,
                              datasets=[
                                  'csd', 'acquisition', 'is_daytime', 'gain',
                                  'weight', 'timing', 'model'
                              ])

        mlog.info('Finished analysis of %s transit on CSD %d.' % (src, csd))
Example #26
def main(config_file=None, logging_params=DEFAULT_LOGGING):

    # Load config
    config = DEFAULTS.deepcopy()
    if config_file is not None:
        config.merge(NameSpace(load_yaml_config(config_file)))

    # Setup logging
    log.setup_logging(logging_params)
    logger = log.get_logger(__name__)

    ## Load data for flagging
    # Load fpga restarts
    time_fpga_restart = []
    if config.fpga_restart_file is not None:

        with open(config.fpga_restart_file, 'r') as handler:
            for line in handler:
                time_fpga_restart.append(
                    ephemeris.datetime_to_unix(
                        ephemeris.timestr_to_datetime(line.split('_')[0])))

    time_fpga_restart = np.array(time_fpga_restart)
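    # (Each line of the restart file is assumed to begin with an acquisition-style
    #  timestamp, e.g. "20180101T000000Z", which is parsed into a Unix time.)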

    # Load housekeeping flag
    if config.housekeeping_file is not None:
        ftemp = TempData.from_acq_h5(config.housekeeping_file,
                                     datasets=["time_flag"])
    else:
        ftemp = None

    # Load jump data
    if config.jump_file is not None:
        with h5py.File(config.jump_file, 'r') as handler:
            jump_time = handler["time"][:]
            jump_size = handler["jump_size"][:]
    else:
        jump_time = None
        jump_size = None

    # Load rain data
    if config.rain_file is not None:
        with h5py.File(config.rain_file, 'r') as handler:
            rain_ranges = handler["time_range_conservative"][:]
    else:
        rain_ranges = []

    # Load data flags
    data_flags = {}
    if config.data_flags:
        finder.connect_database()
        flag_types = finder.DataFlagType.select()
        possible_data_flags = []
        for ft in flag_types:
            possible_data_flags.append(ft.name)
            if ft.name in config.data_flags:
                new_data_flags = finder.DataFlag.select().where(
                    finder.DataFlag.type == ft)
                data_flags[ft.name] = list(new_data_flags)

    # Set desired range of time
    start_time = (ephemeris.datetime_to_unix(
        datetime.datetime(
            *config.start_date)) if config.start_date is not None else None)
    end_time = (ephemeris.datetime_to_unix(datetime.datetime(
        *config.end_date)) if config.end_date is not None else None)

    ## Find gain files
    files = {}
    for src in config.sources:
        files[src] = sorted(
            glob.glob(
                os.path.join(config.directory, src.lower(),
                             "%s_%s_lsd_*.h5" % (
                                 config.prefix,
                                 src.lower(),
                             ))))
    csd = {}
    for src in config.sources:
        csd[src] = np.array(
            [int(os.path.splitext(ff)[0][-4:]) for ff in files[src]])

    for src in config.sources:
        logger.info("%s:  %d files" % (src, len(csd[src])))

    ## Remove files whose transits occur during flagged periods
    csd_flag = {}
    for src in config.sources:

        body = ephemeris.source_dictionary[src]

        csd_flag[src] = np.ones(csd[src].size, dtype=np.bool)

        for ii, cc in enumerate(csd[src][:]):

            ttrans = ephemeris.transit_times(body,
                                             ephemeris.csd_to_unix(cc))[0]

            if (start_time is not None) and (ttrans < start_time):
                csd_flag[src][ii] = False
                continue

            if (end_time is not None) and (ttrans > end_time):
                csd_flag[src][ii] = False
                continue

            # If requested, remove daytime transits
            if not config.include_daytime.get(
                    src, config.include_daytime.default) and daytime_flag(
                        ttrans)[0]:
                logger.info("%s CSD %d:  daytime transit" % (src, cc))
                csd_flag[src][ii] = False
                continue

            # Remove transits during HKP drop out
            if ftemp is not None:
                itemp = np.flatnonzero(
                    (ftemp.time[:] >= (ttrans - config.transit_window))
                    & (ftemp.time[:] <= (ttrans + config.transit_window)))
                tempflg = ftemp['time_flag'][itemp]
                if (tempflg.size == 0) or ((np.sum(tempflg, dtype=np.float32) /
                                            float(tempflg.size)) < 0.50):
                    logger.info("%s CSD %d:  no housekeeping" % (src, cc))
                    csd_flag[src][ii] = False
                    continue

            # Remove transits near jumps
            if jump_time is not None:
                njump = np.sum((jump_size > config.min_jump_size)
                               & (jump_time > (ttrans - config.jump_window))
                               & (jump_time < ttrans))
                if njump > config.max_njump:
                    logger.info("%s CSD %d:  %d jumps before" %
                                (src, cc, njump))
                    csd_flag[src][ii] = False
                    continue

            # Remove transits near rain
            for rng in rain_ranges:
                if (((ttrans - config.transit_window) <= rng[1])
                        and ((ttrans + config.transit_window) >= rng[0])):

                    logger.info("%s CSD %d:  during rain" % (src, cc))
                    csd_flag[src][ii] = False
                    break

            # Remove transits during data flag
            for name, flag_list in data_flags.items():

                if csd_flag[src][ii]:

                    for flg in flag_list:

                        if (((ttrans - config.transit_window) <=
                             flg.finish_time)
                                and ((ttrans + config.transit_window) >=
                                     flg.start_time)):

                            logger.info("%s CSD %d:  %s flag" %
                                        (src, cc, name))
                            csd_flag[src][ii] = False
                            break

    # Print number of files left after flagging
    for src in config.sources:
        logger.info("%s:  %d files (after flagging)" %
                    (src, np.sum(csd_flag[src])))

    ## Construct pairwise differences
    npair = len(config.diff_pair)
    shift = [nd * 24.0 * 3600.0 for nd in config.nday_shift]
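    # (config.diff_pair is a list of (test source, calibrator source) name pairs
    #  and config.nday_shift the matching offsets in days: each test transit is
    #  differenced against a calibrator transit shifted by that many days.)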

    calmap = []
    calpair = []

    for (tsrc, csrc), sh in zip(config.diff_pair, shift):

        body_test = ephemeris.source_dictionary[tsrc]
        body_cal = ephemeris.source_dictionary[csrc]

        for ii, cc in enumerate(csd[tsrc]):

            if csd_flag[tsrc][ii]:

                test_transit = ephemeris.transit_times(
                    body_test, ephemeris.csd_to_unix(cc))[0]
                cal_transit = ephemeris.transit_times(body_cal,
                                                      test_transit + sh)[0]
                cal_csd = int(np.fix(ephemeris.unix_to_csd(cal_transit)))

                ttrans = np.sort([test_transit, cal_transit])

                if cal_csd in csd[csrc]:
                    jj = list(csd[csrc]).index(cal_csd)

                    if csd_flag[csrc][jj] and not np.any(
                        (time_fpga_restart >= ttrans[0])
                            & (time_fpga_restart <= ttrans[1])):
                        calmap.append([ii, jj])
                        calpair.append([tsrc, csrc])

    calmap = np.array(calmap)
    calpair = np.array(calpair)

    ntransit = calmap.shape[0]

    logger.info("%d total transit pairs" % ntransit)
    for ii in range(ntransit):

        t1 = ephemeris.transit_times(
            ephemeris.source_dictionary[calpair[ii, 0]],
            ephemeris.csd_to_unix(csd[calpair[ii, 0]][calmap[ii, 0]]))[0]
        t2 = ephemeris.transit_times(
            ephemeris.source_dictionary[calpair[ii, 1]],
            ephemeris.csd_to_unix(csd[calpair[ii, 1]][calmap[ii, 1]]))[0]

        logger.info("%s (%d) - %s (%d):  %0.1f hr" %
                    (calpair[ii, 0], csd_flag[calpair[ii, 0]][calmap[ii, 0]],
                     calpair[ii, 1], csd_flag[calpair[ii, 1]][calmap[ii, 1]],
                     (t1 - t2) / 3600.0))

    # Determine unique diff pairs
    diff_name = np.array(['%s/%s' % tuple(cp) for cp in calpair])
    uniq_diff, lbl_diff, cnt_diff = np.unique(diff_name,
                                              return_inverse=True,
                                              return_counts=True)
    ndiff = uniq_diff.size

    for ud, udcnt in zip(uniq_diff, cnt_diff):
        logger.info("%s:  %d transit pairs" % (ud, udcnt))

    ## Load gains
    inputmap = tools.get_correlator_inputs(datetime.datetime.utcnow(),
                                           correlator='chime')
    ninput = len(inputmap)
    nfreq = 1024

    # Set up gain arrays
    gain = np.zeros((2, nfreq, ninput, ntransit), dtype=np.complex64)
    weight = np.zeros((2, nfreq, ninput, ntransit), dtype=np.float32)
    input_sort = np.zeros((2, ninput, ntransit), dtype=np.int)

    kcsd = np.zeros((2, ntransit), dtype=np.float32)
    timestamp = np.zeros((2, ntransit), dtype=np.float64)
    is_daytime = np.zeros((2, ntransit), dtype=np.bool)

    for tt in range(ntransit):

        for kk, (src, ind) in enumerate(zip(calpair[tt], calmap[tt])):

            body = ephemeris.source_dictionary[src]
            filename = files[src][ind]

            logger.info("%s:  %s" % (src, filename))

            temp = containers.StaticGainData.from_file(filename)

            freq = temp.freq[:]
            inputs = temp.input[:]

            isort = reorder_inputs(inputmap, inputs)
            inputs = inputs[isort]

            gain[kk, :, :, tt] = temp.gain[:, isort]
            weight[kk, :, :, tt] = temp.weight[:, isort]
            input_sort[kk, :, tt] = isort

            kcsd[kk, tt] = temp.attrs['lsd']
            timestamp[kk, tt] = ephemeris.transit_times(
                body, ephemeris.csd_to_unix(kcsd[kk, tt]))[0]
            is_daytime[kk, tt] = daytime_flag(timestamp[kk, tt])[0]

            if np.any(isort != np.arange(isort.size)):
                logger.info("Input ordering has changed: %s" %
                            ephemeris.unix_to_datetime(
                                timestamp[kk, tt]).strftime("%Y-%m-%d"))

        logger.info("")

    inputs = np.array([(inp.id, inp.input_sn) for inp in inputmap],
                      dtype=[('chan_id', 'u2'), ('correlator_input', 'S32')])

    ## Load input flags
    inpflg = np.ones((2, ninput, ntransit), dtype=np.bool)

    min_flag_time = np.min(timestamp) - 7.0 * 24.0 * 60.0 * 60.0
    max_flag_time = np.max(timestamp) + 7.0 * 24.0 * 60.0 * 60.0

    flaginput_files = sorted(
        glob.glob(
            os.path.join(config.flaginput_dir, "*" + config.flaginput_suffix,
                         "*.h5")))

    if flaginput_files:
        logger.info("Found %d flaginput files." % len(flaginput_files))
        tmp = andata.FlagInputData.from_acq_h5(flaginput_files, datasets=())
        start, stop = [
            int(yy) for yy in np.percentile(
                np.flatnonzero((tmp.time[:] >= min_flag_time)
                               & (tmp.time[:] <= max_flag_time)), [0, 100])
        ]

        cont = andata.FlagInputData.from_acq_h5(flaginput_files,
                                                start=start,
                                                stop=stop,
                                                datasets=['flag'])

        for kk in range(2):
            inpflg[kk, :, :] = cont.resample('flag',
                                             timestamp[kk],
                                             transpose=True)

            logger.info("Flaginput time offsets in minutes (pair %d):" % kk)
            logger.info(
                str(
                    np.fix((cont.time[cont.search_update_time(timestamp[kk])] -
                            timestamp[kk]) / 60.0).astype(np.int)))

    # Sort flags so they are in the same order as the gains
    for tt in range(ntransit):
        for kk in range(2):
            inpflg[kk, :, tt] = inpflg[kk, input_sort[kk, :, tt], tt]

    # Do not apply input flag to phase reference
    for ii in config.index_phase_ref:
        inpflg[:, ii, :] = True

    ## Flag out gains with high uncertainty and frequencies with large fraction of data flagged
    frac_err = tools.invert_no_zero(np.sqrt(weight) * np.abs(gain))

    flag = np.all((weight > 0.0) & (np.abs(gain) > 0.0) &
                  (frac_err < config.max_uncertainty),
                  axis=0)

    freq_flag = ((np.sum(flag, axis=(1, 2), dtype=np.float32) /
                  float(np.prod(flag.shape[1:]))) > config.freq_threshold)

    if config.apply_rfi_mask:
        freq_flag &= np.logical_not(rfi.frequency_mask(freq))

    flag = flag & freq_flag[:, np.newaxis, np.newaxis]

    good_freq = np.flatnonzero(freq_flag)

    logger.info("Number good frequencies %d" % good_freq.size)

    ## Generate flags with more conservative cuts on frequency
    c_flag = flag & np.all(frac_err < config.conservative.max_uncertainty,
                           axis=0)

    c_freq_flag = ((np.sum(c_flag, axis=(1, 2), dtype=np.float32) /
                    float(np.prod(c_flag.shape[1:]))) >
                   config.conservative.freq_threshold)

    if config.conservative.apply_rfi_mask:
        c_freq_flag &= np.logical_not(rfi.frequency_mask(freq))

    c_flag = c_flag & c_freq_flag[:, np.newaxis, np.newaxis]

    c_good_freq = np.flatnonzero(c_freq_flag)

    logger.info("Number good frequencies (conservative thresholds) %d" %
                c_good_freq.size)

    ## Apply input flags
    flag &= np.all(inpflg[:, np.newaxis, :, :], axis=0)

    ## Update flags based on beam flag
    if config.beam_flag_file is not None:

        dbeam = andata.BaseData.from_acq_h5(config.beam_flag_file)

        db_csd = np.floor(ephemeris.unix_to_csd(dbeam.index_map['time'][:]))

        for ii, name in enumerate(config.beam_flag_datasets):
            logger.info("Applying %s beam flag." % name)
            if not ii:
                db_flag = dbeam.flags[name][:]
            else:
                db_flag &= dbeam.flags[name][:]

        cnt = 0
        for ii, dbc in enumerate(db_csd):

            this_csd = np.flatnonzero(np.any(kcsd == dbc, axis=0))

            if this_csd.size > 0:

                logger.info("Beam flag for %d matches %s." %
                            (dbc, str(kcsd[:, this_csd])))

                flag[:, :, this_csd] &= db_flag[np.newaxis, :, ii, np.newaxis]

                cnt += 1

        logger.info("Applied %0.1f percent of the beam flags" %
                    (100.0 * cnt / float(db_csd.size), ))

    ## Flag inputs with large amount of missing data
    input_frac_flagged = (
        np.sum(flag[good_freq, :, :], axis=(0, 2), dtype=np.float32) /
        float(good_freq.size * ntransit))
    input_flag = input_frac_flagged > config.input_threshold

    for ii in config.index_phase_ref:
        logger.info("Phase reference %d has %0.3f fraction of data flagged." %
                    (ii, input_frac_flagged[ii]))
        input_flag[ii] = True

    good_input = np.flatnonzero(input_flag)

    flag = flag & input_flag[np.newaxis, :, np.newaxis]

    logger.info("Number good inputs %d" % good_input.size)

    ## Calibrate
    gaincal = gain[0] * tools.invert_no_zero(gain[1])

    frac_err_cal = np.sqrt(frac_err[0]**2 + frac_err[1]**2)

    count = np.sum(flag, axis=-1, dtype=np.int)
    stat_flag = count > config.min_num_transit

    ## Calculate phase
    amp = np.abs(gaincal)
    phi = np.angle(gaincal)

    ## Calculate polarisation groups
    pol_dict = {'E': 'X', 'S': 'Y'}
    cyl_dict = {2: 'A', 3: 'B', 4: 'C', 5: 'D'}

    if config.group_by_cyl:
        group_id = [
            (inp.pol,
             inp.cyl) if tools.is_chime(inp) and (ii in good_input) else None
            for ii, inp in enumerate(inputmap)
        ]
    else:
        group_id = [
            inp.pol if tools.is_chime(inp) and (ii in good_input) else None
            for ii, inp in enumerate(inputmap)
        ]

    ugroup_id = sorted([uidd for uidd in set(group_id) if uidd is not None])
    ngroup = len(ugroup_id)

    group_list_noref = [
        np.array([
            gg for gg, gid in enumerate(group_id)
            if (gid == ugid) and gg not in config.index_phase_ref
        ]) for ugid in ugroup_id
    ]

    group_list = [
        np.array([gg for gg, gid in enumerate(group_id) if gid == ugid])
        for ugid in ugroup_id
    ]

    if config.group_by_cyl:
        group_str = [
            "%s-%s" % (pol_dict[pol], cyl_dict[cyl]) for pol, cyl in ugroup_id
        ]
    else:
        group_str = [pol_dict[pol] for pol in ugroup_id]

    index_phase_ref = []
    for gstr, igroup in zip(group_str, group_list):
        candidate = [ii for ii in config.index_phase_ref if ii in igroup]
        if len(candidate) != 1:
            index_phase_ref.append(None)
        else:
            index_phase_ref.append(candidate[0])

    logger.info(
        "Phase reference: %s" %
        ', '.join(['%s = %s' % tpl
                   for tpl in zip(group_str, index_phase_ref)]))

    ## Apply thermal correction to amplitude
    if config.amp_thermal.enabled:

        logger.info("Applying thermal correction.")

        # Load the temperatures
        tdata = TempData.from_acq_h5(config.amp_thermal.filename)

        index = tdata.search_sensors(config.amp_thermal.sensor)[0]

        temp = tdata.datasets[config.amp_thermal.field][index]
        temp_func = scipy.interpolate.interp1d(tdata.time, temp,
                                               **config.amp_thermal.interp)

        itemp = temp_func(timestamp)
        dtemp = itemp[0] - itemp[1]

        flag_func = scipy.interpolate.interp1d(
            tdata.time, tdata.datasets['flag'][index].astype(np.float32),
            **config.amp_thermal.interp)

        dtemp_flag = np.all(flag_func(timestamp) == 1.0, axis=0)

        flag &= dtemp_flag[np.newaxis, np.newaxis, :]

        for gstr, igroup in zip(group_str, group_list):
            pstr = gstr[0]
            thermal_coeff = np.polyval(config.amp_thermal.coeff[pstr], freq)
            gthermal = 1.0 + thermal_coeff[:, np.newaxis, np.newaxis] * dtemp[
                np.newaxis, np.newaxis, :]

            amp[:, igroup, :] *= tools.invert_no_zero(gthermal)

    ## Compute common mode
    if config.subtract_common_mode_before:
        logger.info("Calculating common mode amplitude and phase.")
        cmn_amp, flag_cmn_amp = compute_common_mode(amp,
                                                    flag,
                                                    group_list_noref,
                                                    median=False)
        cmn_phi, flag_cmn_phi = compute_common_mode(phi,
                                                    flag,
                                                    group_list_noref,
                                                    median=False)

        # Subtract common mode (from phase only)
        logger.info("Subtracting common mode phase.")
        group_flag = np.zeros((ngroup, ninput), dtype=np.bool)
        for gg, igroup in enumerate(group_list):
            group_flag[gg, igroup] = True
            phi[:,
                igroup, :] = phi[:, igroup, :] - cmn_phi[:, gg, np.newaxis, :]

            for iref in index_phase_ref:
                if (iref is not None) and (iref in igroup):
                    flag[:, iref, :] = flag_cmn_phi[:, gg, :]

    ## If requested, determine and subtract a delay template
    if config.fit_delay_before:
        logger.info("Fitting delay template.")
        omega = timing.FREQ_TO_OMEGA * freq
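        # (A relative instrumental delay tau appears as a phase slope across
        #  frequency, phi = omega * tau; construct_delay_template fits and
        #  removes this slope per input so the residual phase can be examined.)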

        tau, tau_flag, _ = construct_delay_template(
            omega,
            phi,
            c_flag & flag,
            min_num_freq_for_delay_fit=config.min_num_freq_for_delay_fit)

        # Compute residuals
        logger.info("Subtracting delay template.")
        phi = phi - tau[np.newaxis, :, :] * omega[:, np.newaxis, np.newaxis]

    ## Normalize by median over time
    logger.info("Calculating median amplitude and phase.")
    med_amp = np.zeros((nfreq, ninput, ndiff), dtype=amp.dtype)
    med_phi = np.zeros((nfreq, ninput, ndiff), dtype=phi.dtype)

    count_by_diff = np.zeros((nfreq, ninput, ndiff), dtype=np.int)
    stat_flag_by_diff = np.zeros((nfreq, ninput, ndiff), dtype=np.bool)

    def weighted_mean(yy, ww, axis=-1):
        return np.sum(ww * yy, axis=axis) * tools.invert_no_zero(
            np.sum(ww, axis=axis))

    for dd in range(ndiff):

        this_diff = np.flatnonzero(lbl_diff == dd)

        this_flag = flag[:, :, this_diff]

        this_amp = amp[:, :, this_diff]
        this_amp_err = this_amp * frac_err_cal[:, :,
                                               this_diff] * this_flag.astype(
                                                   np.float32)

        this_phi = phi[:, :, this_diff]
        this_phi_err = frac_err_cal[:, :, this_diff] * this_flag.astype(
            np.float32)

        count_by_diff[:, :, dd] = np.sum(this_flag, axis=-1, dtype=np.int)
        stat_flag_by_diff[:, :,
                          dd] = count_by_diff[:, :,
                                              dd] > config.min_num_transit

        if config.weighted_mean == 2:
            logger.info("Calculating inverse variance weighted mean.")
            med_amp[:, :,
                    dd] = weighted_mean(this_amp,
                                        tools.invert_no_zero(this_amp_err**2),
                                        axis=-1)
            med_phi[:, :,
                    dd] = weighted_mean(this_phi,
                                        tools.invert_no_zero(this_phi_err**2),
                                        axis=-1)

        elif config.weighted_mean == 1:
            logger.info("Calculating uniform weighted mean.")
            med_amp[:, :, dd] = weighted_mean(this_amp,
                                              this_flag.astype(np.float32),
                                              axis=-1)
            med_phi[:, :, dd] = weighted_mean(this_phi,
                                              this_flag.astype(np.float32),
                                              axis=-1)

        else:
            logger.info("Calculating median value.")
            for ff in range(nfreq):
                for ii in range(ninput):
                    if np.any(this_flag[ff, ii, :]):
                        med_amp[ff, ii, dd] = wq.median(
                            this_amp[ff, ii, :],
                            this_flag[ff, ii, :].astype(np.float32))
                        med_phi[ff, ii, dd] = wq.median(
                            this_phi[ff, ii, :],
                            this_flag[ff, ii, :].astype(np.float32))

    damp = np.zeros_like(amp)
    dphi = np.zeros_like(phi)
    for dd in range(ndiff):
        this_diff = np.flatnonzero(lbl_diff == dd)
        damp[:, :, this_diff] = amp[:, :, this_diff] * tools.invert_no_zero(
            med_amp[:, :, dd, np.newaxis]) - 1.0
        dphi[:, :,
             this_diff] = phi[:, :, this_diff] - med_phi[:, :, dd, np.newaxis]

    # Compute common mode
    if not config.subtract_common_mode_before:
        logger.info("Calculating common mode amplitude and phase.")
        cmn_amp, flag_cmn_amp = compute_common_mode(damp,
                                                    flag,
                                                    group_list_noref,
                                                    median=True)
        cmn_phi, flag_cmn_phi = compute_common_mode(dphi,
                                                    flag,
                                                    group_list_noref,
                                                    median=True)

        # Subtract common mode (from phase only)
        logger.info("Subtracting common mode phase.")
        group_flag = np.zeros((ngroup, ninput), dtype=np.bool)
        for gg, igroup in enumerate(group_list):
            group_flag[gg, igroup] = True
            dphi[:, igroup, :] = dphi[:, igroup, :] - cmn_phi[:, gg,
                                                              np.newaxis, :]

            for iref in index_phase_ref:
                if (iref is not None) and (iref in igroup):
                    flag[:, iref, :] = flag_cmn_phi[:, gg, :]

    ## Compute RMS
    logger.info("Calculating RMS of amplitude and phase.")
    mad_amp = np.zeros((nfreq, ninput), dtype=amp.dtype)
    std_amp = np.zeros((nfreq, ninput), dtype=amp.dtype)

    mad_phi = np.zeros((nfreq, ninput), dtype=phi.dtype)
    std_phi = np.zeros((nfreq, ninput), dtype=phi.dtype)

    mad_amp_by_diff = np.zeros((nfreq, ninput, ndiff), dtype=amp.dtype)
    std_amp_by_diff = np.zeros((nfreq, ninput, ndiff), dtype=amp.dtype)

    mad_phi_by_diff = np.zeros((nfreq, ninput, ndiff), dtype=phi.dtype)
    std_phi_by_diff = np.zeros((nfreq, ninput, ndiff), dtype=phi.dtype)

    for ff in range(nfreq):
        for ii in range(ninput):
            this_flag = flag[ff, ii, :]
            if np.any(this_flag):
                std_amp[ff, ii] = np.std(damp[ff, ii, this_flag])
                std_phi[ff, ii] = np.std(dphi[ff, ii, this_flag])
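                # (The 1.48625 factor converts a median absolute deviation into an
                #  equivalent Gaussian standard deviation, roughly 1 / Phi^-1(0.75).)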

                mad_amp[ff, ii] = 1.48625 * wq.median(
                    np.abs(damp[ff, ii, :]), this_flag.astype(np.float32))
                mad_phi[ff, ii] = 1.48625 * wq.median(
                    np.abs(dphi[ff, ii, :]), this_flag.astype(np.float32))

                for dd in range(ndiff):
                    this_diff = this_flag & (lbl_diff == dd)
                    if np.any(this_diff):

                        std_amp_by_diff[ff, ii, dd] = np.std(damp[ff, ii,
                                                                  this_diff])
                        std_phi_by_diff[ff, ii, dd] = np.std(dphi[ff, ii,
                                                                  this_diff])

                        mad_amp_by_diff[ff, ii, dd] = 1.48625 * wq.median(
                            np.abs(damp[ff, ii, :]),
                            this_diff.astype(np.float32))
                        mad_phi_by_diff[ff, ii, dd] = 1.48625 * wq.median(
                            np.abs(dphi[ff, ii, :]),
                            this_diff.astype(np.float32))

    ## Construct delay template
    if not config.fit_delay_before:
        logger.info("Fitting delay template.")
        omega = timing.FREQ_TO_OMEGA * freq

        tau, tau_flag, _ = construct_delay_template(
            omega,
            dphi,
            c_flag & flag,
            min_num_freq_for_delay_fit=config.min_num_freq_for_delay_fit)

        # Compute residuals
        logger.info("Subtracting delay template from phase.")
        resid = (dphi - tau[np.newaxis, :, :] *
                 omega[:, np.newaxis, np.newaxis]) * flag.astype(np.float32)

    else:
        resid = dphi

    tau_count = np.sum(tau_flag, axis=-1, dtype=np.int)
    tau_stat_flag = tau_count > config.min_num_transit

    tau_count_by_diff = np.zeros((ninput, ndiff), dtype=np.int)
    tau_stat_flag_by_diff = np.zeros((ninput, ndiff), dtype=np.bool)
    for dd in range(ndiff):
        this_diff = np.flatnonzero(lbl_diff == dd)
        tau_count_by_diff[:, dd] = np.sum(tau_flag[:, this_diff],
                                          axis=-1,
                                          dtype=np.int)
        tau_stat_flag_by_diff[:,
                              dd] = tau_count_by_diff[:,
                                                      dd] > config.min_num_transit

    ## Calculate statistics of residuals
    std_resid = np.zeros((nfreq, ninput), dtype=phi.dtype)
    mad_resid = np.zeros((nfreq, ninput), dtype=phi.dtype)

    std_resid_by_diff = np.zeros((nfreq, ninput, ndiff), dtype=phi.dtype)
    mad_resid_by_diff = np.zeros((nfreq, ninput, ndiff), dtype=phi.dtype)

    for ff in range(nfreq):
        for ii in range(ninput):
            this_flag = flag[ff, ii, :]
            if np.any(this_flag):
                std_resid[ff, ii] = np.std(resid[ff, ii, this_flag])
                mad_resid[ff, ii] = 1.48625 * wq.median(
                    np.abs(resid[ff, ii, :]), this_flag.astype(np.float32))

                for dd in range(ndiff):
                    this_diff = this_flag & (lbl_diff == dd)
                    if np.any(this_diff):
                        std_resid_by_diff[ff, ii,
                                          dd] = np.std(resid[ff, ii,
                                                             this_diff])
                        mad_resid_by_diff[ff, ii, dd] = 1.48625 * wq.median(
                            np.abs(resid[ff, ii, :]),
                            this_diff.astype(np.float32))

    ## Calculate statistics of delay template
    mad_tau = np.zeros((ninput, ), dtype=phi.dtype)
    std_tau = np.zeros((ninput, ), dtype=phi.dtype)

    mad_tau_by_diff = np.zeros((ninput, ndiff), dtype=phi.dtype)
    std_tau_by_diff = np.zeros((ninput, ndiff), dtype=phi.dtype)

    for ii in range(ninput):
        this_flag = tau_flag[ii]
        if np.any(this_flag):
            std_tau[ii] = np.std(tau[ii, this_flag])
            mad_tau[ii] = 1.48625 * wq.median(np.abs(tau[ii]),
                                              this_flag.astype(np.float32))

            for dd in range(ndiff):
                this_diff = this_flag & (lbl_diff == dd)
                if np.any(this_diff):
                    std_tau_by_diff[ii, dd] = np.std(tau[ii, this_diff])
                    mad_tau_by_diff[ii, dd] = 1.48625 * wq.median(
                        np.abs(tau[ii]), this_diff.astype(np.float32))

    ## Define output
    res = {
        "timestamp": {
            "data": timestamp,
            "axis": ["div", "time"]
        },
        "is_daytime": {
            "data": is_daytime,
            "axis": ["div", "time"]
        },
        "csd": {
            "data": kcsd,
            "axis": ["div", "time"]
        },
        "pair_map": {
            "data": lbl_diff,
            "axis": ["time"]
        },
        "pair_count": {
            "data": cnt_diff,
            "axis": ["pair"]
        },
        "gain": {
            "data": gaincal,
            "axis": ["freq", "input", "time"]
        },
        "frac_err": {
            "data": frac_err_cal,
            "axis": ["freq", "input", "time"]
        },
        "flags/gain": {
            "data": flag,
            "axis": ["freq", "input", "time"],
            "flag": True
        },
        "flags/gain_conservative": {
            "data": c_flag,
            "axis": ["freq", "input", "time"],
            "flag": True
        },
        "flags/count": {
            "data": count,
            "axis": ["freq", "input"],
            "flag": True
        },
        "flags/stat": {
            "data": stat_flag,
            "axis": ["freq", "input"],
            "flag": True
        },
        "flags/count_by_pair": {
            "data": count_by_diff,
            "axis": ["freq", "input", "pair"],
            "flag": True
        },
        "flags/stat_by_pair": {
            "data": stat_flag_by_diff,
            "axis": ["freq", "input", "pair"],
            "flag": True
        },
        "med_amp": {
            "data": med_amp,
            "axis": ["freq", "input", "pair"]
        },
        "med_phi": {
            "data": med_phi,
            "axis": ["freq", "input", "pair"]
        },
        "flags/group_flag": {
            "data": group_flag,
            "axis": ["group", "input"],
            "flag": True
        },
        "cmn_amp": {
            "data": cmn_amp,
            "axis": ["freq", "group", "time"]
        },
        "cmn_phi": {
            "data": cmn_phi,
            "axis": ["freq", "group", "time"]
        },
        "amp": {
            "data": damp,
            "axis": ["freq", "input", "time"]
        },
        "phi": {
            "data": dphi,
            "axis": ["freq", "input", "time"]
        },
        "std_amp": {
            "data": std_amp,
            "axis": ["freq", "input"]
        },
        "std_amp_by_pair": {
            "data": std_amp_by_diff,
            "axis": ["freq", "input", "pair"]
        },
        "mad_amp": {
            "data": mad_amp,
            "axis": ["freq", "input"]
        },
        "mad_amp_by_pair": {
            "data": mad_amp_by_diff,
            "axis": ["freq", "input", "pair"]
        },
        "std_phi": {
            "data": std_phi,
            "axis": ["freq", "input"]
        },
        "std_phi_by_pair": {
            "data": std_phi_by_diff,
            "axis": ["freq", "input", "pair"]
        },
        "mad_phi": {
            "data": mad_phi,
            "axis": ["freq", "input"]
        },
        "mad_phi_by_pair": {
            "data": mad_phi_by_diff,
            "axis": ["freq", "input", "pair"]
        },
        "tau": {
            "data": tau,
            "axis": ["input", "time"]
        },
        "flags/tau": {
            "data": tau_flag,
            "axis": ["input", "time"],
            "flag": True
        },
        "flags/tau_count": {
            "data": tau_count,
            "axis": ["input"],
            "flag": True
        },
        "flags/tau_stat": {
            "data": tau_stat_flag,
            "axis": ["input"],
            "flag": True
        },
        "flags/tau_count_by_pair": {
            "data": tau_count_by_diff,
            "axis": ["input", "pair"],
            "flag": True
        },
        "flags/tau_stat_by_pair": {
            "data": tau_stat_flag_by_diff,
            "axis": ["input", "pair"],
            "flag": True
        },
        "std_tau": {
            "data": std_tau,
            "axis": ["input"]
        },
        "std_tau_by_pair": {
            "data": std_tau_by_diff,
            "axis": ["input", "pair"]
        },
        "mad_tau": {
            "data": mad_tau,
            "axis": ["input"]
        },
        "mad_tau_by_pair": {
            "data": mad_tau_by_diff,
            "axis": ["input", "pair"]
        },
        "resid_phi": {
            "data": resid,
            "axis": ["freq", "input", "time"]
        },
        "std_resid_phi": {
            "data": std_resid,
            "axis": ["freq", "input"]
        },
        "std_resid_phi_by_pair": {
            "data": std_resid_by_diff,
            "axis": ["freq", "input", "pair"]
        },
        "mad_resid_phi": {
            "data": mad_resid,
            "axis": ["freq", "input"]
        },
        "mad_resid_phi_by_pair": {
            "data": mad_resid_by_diff,
            "axis": ["freq", "input", "pair"]
        },
    }

    ## Create the output container
    logger.info("Creating StabilityData container.")
    data = StabilityData()

    data.create_index_map(
        "div", np.array(["numerator", "denominator"], dtype=np.string_))
    data.create_index_map("pair", np.array(uniq_diff, dtype=np.string_))
    data.create_index_map("group", np.array(group_str, dtype=np.string_))

    data.create_index_map("freq", freq)
    data.create_index_map("input", inputs)
    data.create_index_map("time", timestamp[0, :])

    logger.info("Writing datsets to container.")
    for name, dct in res.iteritems():
        is_flag = dct.get('flag', False)
        if is_flag:
            dset = data.create_flag(name.split('/')[-1], data=dct['data'])
        else:
            dset = data.create_dataset(name, data=dct['data'])

        dset.attrs['axis'] = np.array(dct['axis'], dtype=np.string_)

    data.attrs['phase_ref'] = np.array(
        [iref for iref in index_phase_ref if iref is not None])

    # Determine the output filename and save results
    start_time, end_time = ephemeris.unix_to_datetime(
        np.percentile(timestamp, [0, 100]))
    tfmt = "%Y%m%d"
    night_str = 'night_' if not np.any(is_daytime) else ''
    output_file = os.path.join(
        config.output_dir, "%s_%s_%sraw_stability_data.h5" %
        (start_time.strftime(tfmt), end_time.strftime(tfmt), night_str))

    logger.info("Saving results to %s." % output_file)
    data.save(output_file)
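

# A minimal read-back sketch (not part of the original pipeline; the default
# path below is hypothetical): because the loop above tags every dataset with
# an "axis" attribute, the saved HDF5 file can be walked generically.
def _print_stability_axes(path="raw_stability_data.h5"):
    import h5py
    import numpy as np

    with h5py.File(path, "r") as fh:

        def _show(name, obj):
            # Print the name, shape, and axis labels of every tagged dataset.
            if isinstance(obj, h5py.Dataset) and "axis" in obj.attrs:
                axes = [ax.decode() if isinstance(ax, bytes) else str(ax)
                        for ax in np.atleast_1d(obj.attrs["axis"])]
                print("%-35s %-25s %s" % (name, obj.shape, axes))

        fh.visititems(_show)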
def solve_ps_transit(filename, corrs, feeds, inp, 
          src, nfreq=1024, transposed=False, nfeed=128):
    """ Function that fringestops time slice 
    where point source is in the beam, takes 
    all correlations for a given polarization, and then 
    eigendecomposes the correlation matrix freq by freq
    after removing the fpga phases. It will also 
    plot intermediate steps to verify the phase solution.

    Parameters
    ----------
    filename : np.str
         Full-path filename 
    corrs : list
         List of correlations to use in solver
    feeds : list
         List of feeds to use
    inp   : 
         Correlator inputs (output of ch_util.tools.get_correlator_inputs)
    src   : ephem.FixedBody
         Source to calibrate off of. e.g. ch_util.ephemeris.TauA
    
    Returns
    -------
    Gains : np.array
         Complex gain array (nfreq, nfeed) 
    """

    nsplit = 32 # Number of freq chunks to divide nfreq into
    del_t = 800

    f = h5py.File(filename, 'r')

    # Add half an integration time to each. Hack. 
    times = f['index_map']['time'].value['ctime'] + 10.50
    src_trans = eph.transit_times(src, times[0])
    
    # try to account for differential arrival time from 
    # cylinder rotation. 
    del_phi = (src._dec - np.radians(eph.CHIMELATITUDE)) \
                 * np.sin(np.radians(1.988))
    del_phi *= (24 * 3600.0) / (2 * np.pi)

    # Adjust the transit time accordingly
    src_trans += del_phi
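
    # Rough scale of this correction (an illustrative estimate, not from the
    # original code): for Tau A at dec ~ +22 deg with CHIMELATITUDE ~ 49.3 deg,
    # del_phi ~ (-0.48 rad) * sin(1.988 deg) * 86400 / (2*pi) ~ -230 s, i.e.
    # the apparent transit arrives a few minutes early.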

    # Select +- del_t of transit, accounting for the mispointing 
    t_range = np.where((times < src_trans + 
                  del_t) & (times > src_trans - del_t))[0]
 
    print "\n...... This data is from %s starting at RA: %f ...... \n" \
        % (eph.unix_to_datetime(times[0]), eph.transit_RA(times[0]))

    assert (len(t_range) > 0), "Source is not in this acq"

    # Create gains array to fill in solution
    Gains = np.zeros([nfreq, nfeed], np.complex128)
    
    print "Starting the solver"
    
    times = times[t_range[0]:t_range[-1]]
    
    k=0
    
    # Start at a strong freq channel that can be plotted
    # and from which we can find the noise source on-sample
    for i in range(12, nsplit) + range(0, 12):

        k+=1

        # Divides the arrays up into nfreq / nsplit freq chunks and solves those
        frq = range(i * nfreq // nsplit, (i+1) * nfreq // nsplit)
        
        print "      %d:%d \n" % (frq[0], frq[-1])

        # Read in time and freq slice if data has already been transposed
        if transposed is True:
            v = f['vis'][frq[0]:frq[-1]+1, corrs, :]
            v = v[..., t_range[0]:t_range[-1]]
            vis = v['r'] + 1j * v['i']

            if k==1:
                autos = auto_corrs(nfeed)
                offp = (abs(vis[:, autos, 0::2]).mean() > \
                        (abs(vis[:, autos, 1::2]).mean())).astype(int)
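                # Presumably the noise source fires on alternate frames:
                # whichever parity has the larger mean autocorrelation is the
                # "on" parity, so offp indexes the noise-off frames kept below.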

                times = times[offp::2]
            
            vis = vis[..., offp::2]

            gg = f['gain_coeff'][frq[0]:frq[-1]+1, 
                    feeds, t_range[0]:t_range[-1]][..., offp::2]

            gain_coeff = gg['r'] + 1j * gg['i']
            
            del gg
            


        # Read in time and freq slice if data has not yet been transposed
        if transposed is False:
            print "TRANSPOSED V OF CODE DOESN'T WORK YET!"
            v = f['vis'][t_range[0]:t_range[-1]:2, frq[0]:frq[-1]+1, corrs]
            vis = v['r'][:] + 1j * v['i'][:]
            del v

            gg = f['gain_coeff'][0, frq[0]:frq[-1]+1, feeds]
            gain_coeff = gg['r'][:] + 1j * gg['i'][:]

            vis = vis[..., offp::2]

            vis = np.transpose(vis, (1, 2, 0))


        # Remove fpga gains from data
        vis = remove_fpga_gains(vis, gain_coeff, nfeed=nfeed, triu=False)

        # Remove offset from galaxy
        vis -= 0.5 * (vis[..., 0] + vis[..., -1])[..., np.newaxis]
   
        # Get physical freq for fringestopper
        freq_MHZ = 800.0 - np.array(frq) / 1024.0 * 400.
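        # (channel 0 -> 800 MHz, channel 1023 -> ~400.4 MHz; the band spans
        # 400-800 MHz over 1024 channels, highest frequency first)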
    
        baddies = np.where(np.isnan(tools.get_feed_positions(inp)[:, 0]))[0]
        a, b, c = select_corrs(baddies, nfeed=128)

        vis[:, a + b] = 0.0

        # Fringestop to location of "src"
        data_fs = tools.fringestop_pathfinder(vis, eph.transit_RA(times), freq_MHZ, inp, src)

        del vis

        dr, sol_arr = solve_gain(data_fs)

        # Find index of point source transit
        drlist = np.argmax(dr, axis=-1)
        
        # If multiple freq channels are zeroed, the trans_pix
        # will end up being 0. This is bad, so ensure that 
        # you are only looking for non-zero transit pixels.
        drlist = [x for x in drlist if x != 0]
        trans_pix = np.argmax(np.bincount(drlist))

        assert trans_pix != 0.0

        Gains[frq] = sol_arr[..., trans_pix-3:trans_pix+4].mean(-1)

        zz = h5py.File('data' + str(i) + '.hdf5','w')
        zz.create_dataset('data', data=dr)
        zz.close()

        print "%f, %d Nans out of %d" % (np.isnan(sol_arr).sum(), np.isnan(Gains[frq]).sum(), np.isnan(Gains[frq]).sum())
        print trans_pix, sol_arr[..., trans_pix-3:trans_pix+4].mean(-1).sum(), sol_arr.mean(-1).sum()

        # Plot up post-fs phases to see if everything has been fixed
        if frq[0] == 12 * nsplit:
            print "======================"
            print "   Plotting up freq: %d" % frq[0]
            print "======================"
            img_nm = './phs_plots/dfs' + np.str(frq[17]) + np.str(np.int(time.time())) +'.png'
            img_nmcorr = './phs_plots/dfs' + np.str(frq[17]) + np.str(np.int(time.time())) +'corr.png'

            plt_gains(data_fs, 0, img_name=img_nm, bad_chans=baddies)
            dfs_corr = correct_dfs(data_fs, np.angle(Gains[frq])[..., np.newaxis], nfeed=128)

            plt_gains(dfs_corr, 0, img_name=img_nmcorr, bad_chans=baddies)

            del dfs_corr

        del data_fs, a

    return Gains
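

# A hedged usage sketch, not from the original repository: the acquisition
# file name is hypothetical and the correlation-product list is left
# schematic, but the call pattern follows the docstring above (inputs from
# tools.get_correlator_inputs, a bright source from ch_util.ephemeris).
def _example_solve_ps_transit(fname="00000000_0000.h5"):
    with h5py.File(fname, 'r') as fh:
        t0 = fh['index_map']['time'][0]['ctime']
    inp = tools.get_correlator_inputs(eph.unix_to_datetime(t0))
    feeds = list(range(128))      # e.g. one polarisation of one cylinder
    corrs = auto_corrs(128)       # schematic placeholder; in practice the full
                                  # product list for these feeds
    return solve_ps_transit(fname, corrs, feeds, inp, eph.TauA,
                            transposed=True)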
Example #28
0
def fs_from_file(filename,
                 frq,
                 src,
                 del_t=900,
                 transposed=True,
                 subtract_avg=False):

    f = h5py.File(filename, 'r')

    times = f['index_map']['time'].value['ctime'] + 10.6
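
    # Shift timestamps by roughly half an integration period (the same hack
    # as in the solve_ps_transit examples, which add 10.50 s).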

    src_trans = eph.transit_times(src, times[0])

    # try to account for differential arrival time from cylinder rotation.

    del_phi = (src._dec - np.radians(eph.CHIMELATITUDE)) * np.sin(
        np.radians(1.988))
    del_phi *= (24 * 3600.0) / (2 * np.pi)

    # Adjust the transit time accordingly
    src_trans += del_phi

    # Select +- del_t of transit, accounting for the mispointing
    t_range = np.where((times < src_trans + del_t)
                       & (times > src_trans - del_t))[0]

    times = times[t_range[0]:t_range[-1]]  #[offp::2] test

    print "Time range:", times[0], times[-1]

    print "\n...... This data is from %s starting at RA: %f ...... \n" \
        % (eph.unix_to_datetime(times[0]), eph.transit_RA(times[0]))

    if transposed is True:
        v = f['vis'][frq[0]:frq[-1] + 1, :]
        v = v[..., t_range[0]:t_range[-1]]
        vis = v['r'] + 1j * v['i']

        del v

    # Read in time and freq slice if data has not yet been transposed
    if transposed is False:
        v = f['vis'][t_range[0]:t_range[-1], frq[0]:frq[-1] + 1, :]
        vis = v['r'][:] + 1j * v['i'][:]
        del v
        vis = np.transpose(vis, (1, 2, 0))

    inp = gen_inp()[0]

    # Remove offset from galaxy
    if subtract_avg is True:
        vis -= 0.5 * (vis[..., 0] + vis[..., -1])[..., np.newaxis]

    freq_MHZ = 800.0 - np.array(frq) / 1024.0 * 400.
    print len(inp)

    baddies = np.where(np.isnan(tools.get_feed_positions(inp)[:, 0]))[0]

    # Fringestop to location of "src"

    data_fs = tools.fringestop_pathfinder(vis, eph.transit_RA(times), freq_MHZ,
                                          inp, src)
    #    data_fs = fringestop_pathfinder(vis, eph.transit_RA(times), freq_MHZ, inp, src)

    return data_fs
Example #29
0
def solve_ps_transit(filename,
                     corrs,
                     feeds,
                     inp,
                     src,
                     nfreq=1024,
                     transposed=False,
                     nfeed=128):
    """ Function that fringestops time slice 
    where point source is in the beam, takes 
    all correlations for a given polarization, and then 
    eigendecomposes the correlation matrix freq by freq
    after removing the fpga phases. It will also 
    plot intermediate steps to verify the phase solution.

    Parameters
    ----------
    filename : np.str
         Full-path filename 
    corrs : list
         List of correlations to use in solver
    feeds : list
         List of feeds to use
    inp   : 
         Correlator inputs (output of ch_util.tools.get_correlator_inputs)
    src   : ephem.FixedBody
         Source to calibrate off of. e.g. ch_util.ephemeris.TauA
    
    Returns
    -------
    Gains : np.array
         Complex gain array (nfreq, nfeed) 
    """

    nsplit = 32  # Number of freq chunks to divide nfreq into
    del_t = 800

    f = h5py.File(filename, 'r')

    # Add half an integration time to each. Hack.
    times = f['index_map']['time'].value['ctime'] + 10.50
    src_trans = eph.transit_times(src, times[0])

    # try to account for differential arrival time from
    # cylinder rotation.
    del_phi = (src._dec - np.radians(eph.CHIMELATITUDE)) \
                 * np.sin(np.radians(1.988))
    del_phi *= (24 * 3600.0) / (2 * np.pi)

    # Adjust the transit time accordingly
    src_trans += del_phi

    # Select +- del_t of transit, accounting for the mispointing
    t_range = np.where((times < src_trans + del_t)
                       & (times > src_trans - del_t))[0]

    print "\n...... This data is from %s starting at RA: %f ...... \n" \
        % (eph.unix_to_datetime(times[0]), eph.transit_RA(times[0]))

    assert (len(t_range) > 0), "Source is not in this acq"

    # Create gains array to fill in solution
    Gains = np.zeros([nfreq, nfeed], np.complex128)

    print "Starting the solver"

    times = times[t_range[0]:t_range[-1]]

    k = 0

    # Start at a strong freq channel that can be plotted
    # and from which we can find the noise source on-sample
    for i in range(12, nsplit) + range(0, 12):

        k += 1

        # Divides the arrays up into nfreq / nsplit freq chunks and solves those
        frq = range(i * nfreq // nsplit, (i + 1) * nfreq // nsplit)

        print "      %d:%d \n" % (frq[0], frq[-1])

        # Read in time and freq slice if data has already been transposed
        if transposed is True:
            v = f['vis'][frq[0]:frq[-1] + 1, corrs, :]
            v = v[..., t_range[0]:t_range[-1]]
            vis = v['r'] + 1j * v['i']

            if k == 1:
                autos = auto_corrs(nfeed)
                offp = (abs(vis[:, autos, 0::2]).mean() > \
                        (abs(vis[:, autos, 1::2]).mean())).astype(int)

                times = times[offp::2]

            vis = vis[..., offp::2]

            gg = f['gain_coeff'][frq[0]:frq[-1] + 1, feeds,
                                 t_range[0]:t_range[-1]][..., offp::2]

            gain_coeff = gg['r'] + 1j * gg['i']

            del gg

        # Read in time and freq slice if data has not yet been transposed
        if transposed is False:
            print "TRANSPOSED V OF CODE DOESN'T WORK YET!"
            v = f['vis'][t_range[0]:t_range[-1]:2, frq[0]:frq[-1] + 1, corrs]
            vis = v['r'][:] + 1j * v['i'][:]
            del v

            gg = f['gain_coeff'][0, frq[0]:frq[-1] + 1, feeds]
            gain_coeff = gg['r'][:] + 1j * gg['i'][:]

            vis = vis[..., offp::2]

            vis = np.transpose(vis, (1, 2, 0))

        # Remove fpga gains from data
        vis = remove_fpga_gains(vis, gain_coeff, nfeed=nfeed, triu=False)

        # Remove offset from galaxy
        vis -= 0.5 * (vis[..., 0] + vis[..., -1])[..., np.newaxis]

        # Get physical freq for fringestopper
        freq_MHZ = 800.0 - np.array(frq) / 1024.0 * 400.

        baddies = np.where(np.isnan(tools.get_feed_positions(inp)[:, 0]))[0]
        a, b, c = select_corrs(baddies, nfeed=128)

        vis[:, a + b] = 0.0

        # Fringestop to location of "src"
        data_fs = tools.fringestop_pathfinder(vis, eph.transit_RA(times),
                                              freq_MHZ, inp, src)

        del vis

        dr, sol_arr = solve_gain(data_fs)

        # Find index of point source transit
        drlist = np.argmax(dr, axis=-1)

        # If multiple freq channels are zeroed, the trans_pix
        # will end up being 0. This is bad, so ensure that
        # you are only looking for non-zero transit pixels.
        drlist = [x for x in drlist if x != 0]
        trans_pix = np.argmax(np.bincount(drlist))

        assert trans_pix != 0.0

        Gains[frq] = sol_arr[..., trans_pix - 3:trans_pix + 4].mean(-1)

        zz = h5py.File('data' + str(i) + '.hdf5', 'w')
        zz.create_dataset('data', data=dr)
        zz.close()

        print "%f, %d Nans out of %d" % (np.isnan(sol_arr).sum(),
                                         np.isnan(Gains[frq]).sum(),
                                         np.isnan(Gains[frq]).sum())
        print trans_pix, sol_arr[..., trans_pix - 3:trans_pix +
                                 4].mean(-1).sum(), sol_arr.mean(-1).sum()

        # Plot up post-fs phases to see if everything has been fixed
        if frq[0] == 12 * nsplit:
            print "======================"
            print "   Plotting up freq: %d" % frq[0]
            print "======================"
            img_nm = './phs_plots/dfs' + np.str(frq[17]) + np.str(
                np.int(time.time())) + '.png'
            img_nmcorr = './phs_plots/dfs' + np.str(frq[17]) + np.str(
                np.int(time.time())) + 'corr.png'

            plt_gains(data_fs, 0, img_name=img_nm, bad_chans=baddies)
            dfs_corr = correct_dfs(data_fs,
                                   np.angle(Gains[frq])[..., np.newaxis],
                                   nfeed=128)

            plt_gains(dfs_corr, 0, img_name=img_nmcorr, bad_chans=baddies)

            del dfs_corr

        del data_fs, a

    return Gains
Example #30
0
def main(config_file=None, logging_params=DEFAULT_LOGGING):

    # Setup logging
    log.setup_logging(logging_params)
    mlog = log.get_logger(__name__)

    # Set config
    config = DEFAULTS.deepcopy()
    if config_file is not None:
        config.merge(NameSpace(load_yaml_config(config_file)))

    # Set niceness
    current_niceness = os.nice(0)
    os.nice(config.niceness - current_niceness)
    mlog.info('Changing process niceness from %d to %d.  Confirm:  %d' %
                  (current_niceness, config.niceness, os.nice(0)))

    # Create output suffix
    output_suffix = config.output_suffix if config.output_suffix is not None else "jumps"

    # Calculate the wavelet transform for the following scales
    nwin = 2 * config.max_scale + 1
    nhwin = nwin // 2

    if config.log_scale:
        mlog.info("Using log scale.")
        scale = np.logspace(np.log10(config.min_scale), np.log10(nwin), num=config.num_points, dtype=np.int)
    else:
        mlog.info("Using linear scale.")
        scale = np.arange(config.min_scale, nwin, dtype=np.int)

    # Loop over acquisitions
    for acq in config.acq:

        # Find acquisition files
        all_data_files = sorted(glob(os.path.join(config.data_dir, acq, "*.h5")))
        nfiles = len(all_data_files)

        if nfiles == 0:
            continue

        mlog.info("Now processing acquisition %s (%d files)" % (acq, nfiles))

        # Determine list of feeds to examine
        dset = ['flags/inputs'] if config.use_input_flag else ()

        rdr = andata.CorrData.from_acq_h5(all_data_files, datasets=dset,
                                          apply_gain=False, renormalize=False)

        inputmap = tools.get_correlator_inputs(ephemeris.unix_to_datetime(rdr.time[0]),
                                               correlator='chime')

        # Extract good inputs
        if config.use_input_flag:
            ifeed = np.flatnonzero((np.sum(rdr.flags['inputs'][:], axis=-1, dtype=np.int) /
                                     float(rdr.flags['inputs'].shape[-1])) > config.input_threshold)
        else:
            ifeed = np.array([ii for ii, inp in enumerate(inputmap) if tools.is_chime(inp)])

        ninp = len(ifeed)

        mlog.info("Processing %d feeds." % ninp)

        # Create list of candidates
        cfreq, cinput, ctime, cindex = [], [], [], []
        jump_flag, jump_time, jump_auto = [], [], []
        ncandidate = 0

        # Determine number of files to process at once
        if config.max_num_file is None:
            chunk_size = nfiles
        else:
            chunk_size = min(config.max_num_file, nfiles)

        # Loop over chunks of files
        for chnk, data_files in enumerate(chunks(all_data_files, chunk_size)):

            mlog.info("Now processing chunk %d (%d files)" % (chnk, len(data_files)))

            # Determine selections along the various axes
            rdr = andata.CorrData.from_acq_h5(data_files, datasets=())

            auto_sel = np.array([ii for ii, pp in enumerate(rdr.prod) if pp[0] == pp[1]])
            auto_sel = andata._convert_to_slice(auto_sel)

            if config.time_start is None:
                ind_start = 0
            else:
                time_start = ephemeris.datetime_to_unix(datetime.datetime(*config.time_start))
                ind_start = int(np.argmin(np.abs(rdr.time - time_start)))

            if config.time_stop is None:
                ind_stop = rdr.ntime
            else:
                time_stop = ephemeris.datetime_to_unix(datetime.datetime(*config.time_stop))
                ind_stop = int(np.argmin(np.abs(rdr.time - time_stop)))

            if config.freq_physical is not None:

                if hasattr(config.freq_physical, '__iter__'):
                    freq_physical = config.freq_physical
                else:
                    freq_physical = [config.freq_physical]

                freq_sel = [np.argmin(np.abs(ff - rdr.freq)) for ff in freq_physical]
                freq_sel = andata._convert_to_slice(freq_sel)

            else:
                fstart = config.freq_start if config.freq_start is not None else 0
                fstop = config.freq_stop if config.freq_stop is not None else rdr.freq.size
                freq_sel = slice(fstart, fstop)

            # Load autocorrelations
            t0 = time.time()
            data = andata.CorrData.from_acq_h5(data_files, datasets=['vis'], start=ind_start, stop=ind_stop,
                                                           freq_sel=freq_sel, prod_sel=auto_sel,
                                                           apply_gain=False, renormalize=False)

            mlog.info("Took %0.1f seconds to load autocorrelations." % (time.time() - t0,))

            # If first chunk, save the frequencies that are being used
            if not chnk:
                all_freq = data.freq.copy()

            # If requested do not consider data during day or near bright source transits
            flag_quiet = np.ones(data.ntime, dtype=np.bool)
            if config.ignore_sun:
                flag_quiet &= ~transit_flag('sun', data.time, freq=np.min(data.freq), pol='X', nsig=1.0)

            if config.only_quiet:
                flag_quiet &= ~daytime_flag(data.time)
                for ss in ["CYG_A", "CAS_A", "TAU_A", "VIR_A"]:
                    flag_quiet &= ~transit_flag(ss, data.time, freq=np.min(data.freq), pol='X', nsig=1.0)

            # Loop over frequencies
            for ff, freq in enumerate(data.freq):

                print_cnt = 0
                mlog.info("FREQ %d (%0.2f MHz)" % (ff, freq))

                auto = data.vis[ff, :, :].real

                fractional_auto = auto * tools.invert_no_zero(np.median(auto, axis=-1, keepdims=True)) - 1.0

                # Loop over inputs
                for ii in ifeed:

                    print_cnt += 1
                    do_print = not (print_cnt % 100)

                    if do_print:
                        mlog.info("INPUT %d" % ii)
                    t0 = time.time()

                    signal = fractional_auto[ii, :]

                    # Perform wavelet transform
                    coef, freqs = pywt.cwt(signal, scale, config.wavelet_name)

                    if do_print:
                        mlog.info("Took %0.1f seconds to perform wavelet transform." % (time.time() - t0,))
                    t0 = time.time()

                    # Find local modulus maxima
                    flg_mod_max, mod_max = mod_max_finder(scale, coef, threshold=config.thresh, search_span=config.search_span)

                    if do_print:
                        mlog.info("Took %0.1f seconds to find modulus maxima." % (time.time() - t0,))
                    t0 = time.time()

                    # Find persistent modulus maxima across scales
                    candidates, cmm, pdrift, start, stop, lbl = finger_finder(scale, flg_mod_max, mod_max,
                                                                              istart=max(config.min_rise - config.min_scale, 0),
                                                                              do_fill=False)

                    if do_print:
                        mlog.info("Took %0.1f seconds to find fingers." % (time.time() - t0,))
                    t0 = time.time()

                    if candidates is None:
                        continue

                    # Cut bad candidates
                    index_good_candidates = np.flatnonzero((scale[stop] >= config.max_scale) &
                                                            flag_quiet[candidates[start, np.arange(start.size)]] &
                                                            (pdrift <= config.psigma_max))

                    ngood = index_good_candidates.size

                    if ngood == 0:
                        continue

                    mlog.info("Input %d has %d jumps" % (ii, ngood))

                    # Add remaining candidates to list
                    ncandidate += ngood

                    cfreq += [freq] * ngood
                    cinput += [ii] * ngood

                    for igc in index_good_candidates:

                        icenter = candidates[start[igc], igc]

                        cindex.append(icenter)
                        ctime.append(data.time[icenter])

                        aa = max(0, icenter - nhwin)
                        bb = min(data.ntime, icenter + nhwin + 1)

                        ncut = bb - aa

                        temp_var = np.zeros(nwin, dtype=np.bool)
                        temp_var[0:ncut] = True
                        jump_flag.append(temp_var)

                        temp_var = np.zeros(nwin, dtype=data.time.dtype)
                        temp_var[0:ncut] = data.time[aa:bb].copy()
                        jump_time.append(temp_var)

                        temp_var = np.zeros(nwin, dtype=auto.dtype)
                        temp_var[0:ncut] = auto[ii, aa:bb].copy()
                        jump_auto.append(temp_var)


            # Garbage collect
            del data
            gc.collect()

        # If we found any jumps, write them to a file.
        if ncandidate > 0:

            output_file = os.path.join(config.output_dir, "%s_%s.h5" % (acq, output_suffix))

            mlog.info("Writing %d jumps to: %s" % (ncandidate, output_file))

            # Write to output file
            with h5py.File(output_file, 'w') as handler:

                handler.attrs['files'] = all_data_files
                handler.attrs['chan_id'] = ifeed
                handler.attrs['freq'] = all_freq

                index_map = handler.create_group('index_map')
                index_map.create_dataset('jump', data=np.arange(ncandidate))
                index_map.create_dataset('window', data=np.arange(nwin))

                ax = np.array(['jump'])

                dset = handler.create_dataset('freq', data=np.array(cfreq))
                dset.attrs['axis'] = ax

                dset = handler.create_dataset('input', data=np.array(cinput))
                dset.attrs['axis'] = ax

                dset = handler.create_dataset('time', data=np.array(ctime))
                dset.attrs['axis'] = ax

                dset = handler.create_dataset('time_index', data=np.array(cindex))
                dset.attrs['axis'] = ax


                ax = np.array(['jump', 'window'])

                dset = handler.create_dataset('jump_flag', data=np.array(jump_flag))
                dset.attrs['axis'] = ax

                dset = handler.create_dataset('jump_time', data=np.array(jump_time))
                dset.attrs['axis'] = ax

                dset = handler.create_dataset('jump_auto', data=np.array(jump_auto))
                dset.attrs['axis'] = ax

        else:
            mlog.info("No jumps found for %s acquisition." % acq)
Example #31
0
    def process(self, sstream, inputmap, inputmask):
        """Determine calibration from a timestream.

        Parameters
        ----------
        sstream : andata.CorrData or containers.SiderealStream
            Timestream collected during the day.
        inputmap : list of :class:`CorrInput`
            A list describing the inputs as they are in the file.
        inputmask : containers.CorrInputMask
            Mask indicating which correlator inputs to use in the
            eigenvalue decomposition.

        Returns
        -------
        suntrans : containers.SunTransit
            Response to the sun.
        """

        from operator import itemgetter
        from itertools import groupby
        from .calibration import _extract_diagonal, solve_gain

        # Ensure that we are distributed over frequency
        sstream.redistribute("freq")

        # Find the local frequencies
        nfreq = sstream.vis.local_shape[0]
        sfreq = sstream.vis.local_offset[0]
        efreq = sfreq + nfreq

        # Get the local frequency axis
        freq = sstream.freq["centre"][sfreq:efreq]
        wv = 3e2 / freq
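        # Wavelength in metres: freq is in MHz, so lambda = c / nu ~ 300 / freq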

        # Get times
        if hasattr(sstream, "time"):
            time = sstream.time
            ra = ephemeris.transit_RA(time)
        else:
            ra = sstream.index_map["ra"][:]
            csd = (sstream.attrs["lsd"]
                   if "lsd" in sstream.attrs else sstream.attrs["csd"])
            csd = csd + ra / 360.0
            time = ephemeris.csd_to_unix(csd)

        # Only examine data between sunrise and sunset
        time_flag = np.zeros(len(time), dtype=np.bool)
        rise = ephemeris.solar_rising(time[0] - 24.0 * 3600.0,
                                      end_time=time[-1])
        for rr in rise:
            ss = ephemeris.solar_setting(rr)[0]
            time_flag |= (time >= rr) & (time <= ss)

        if not np.any(time_flag):
            self.log.debug(
                "No daytime data between %s and %s.",
                ephemeris.unix_to_datetime(time[0]).strftime("%b %d %H:%M"),
                ephemeris.unix_to_datetime(time[-1]).strftime("%b %d %H:%M"),
            )
            return None

        # Convert boolean flag to slices
        time_index = np.where(time_flag)[0]

        time_slice = []
        ntime = 0
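        # Group consecutive time indices into contiguous runs: enumerate pairs
        # each index with its position, and (position - index) is constant
        # within a run. For example, indices [3, 4, 5, 9, 10] give keys
        # [-3, -3, -3, -6, -6], i.e. the runs [3, 4, 5] and [9, 10].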
        for key, group in groupby(
                enumerate(time_index),
                lambda index_item: index_item[0] - index_item[1]):
            group = list(map(itemgetter(1), group))
            ngroup = len(group)
            time_slice.append(
                (slice(group[0], group[-1] + 1), slice(ntime, ntime + ngroup)))
            ntime += ngroup

        time = np.concatenate([time[slc[0]] for slc in time_slice])
        ra = np.concatenate([ra[slc[0]] for slc in time_slice])

        # Get ra, dec, alt of sun
        sun_pos = np.array([
            ra_dec_of(ephemeris.skyfield_wrapper.ephemeris["sun"], t)
            for t in time
        ])

        # Convert from ra to hour angle
        sun_pos[:, 0] = np.radians(ra) - sun_pos[:, 0]

        # Determine good inputs
        nfeed = len(inputmap)
        good_input = np.arange(
            nfeed, dtype=np.int)[inputmask.datasets["input_mask"][:]]

        # Use input map to figure out which are the X and Y feeds
        xfeeds = np.array([
            idx for idx, inp in enumerate(inputmap)
            if tools.is_chime_x(inp) and (idx in good_input)
        ])
        yfeeds = np.array([
            idx for idx, inp in enumerate(inputmap)
            if tools.is_chime_y(inp) and (idx in good_input)
        ])

        self.log.debug(
            "Performing sun calibration with %d/%d good feeds (%d xpol, %d ypol).",
            len(good_input),
            nfeed,
            len(xfeeds),
            len(yfeeds),
        )

        # Construct baseline vector for each visibility
        feed_pos = tools.get_feed_positions(inputmap)
        vis_pos = np.array([
            feed_pos[ii] - feed_pos[ij]
            for ii, ij in sstream.index_map["prod"][:]
        ])
        vis_pos = np.where(np.isnan(vis_pos), np.zeros_like(vis_pos), vis_pos)

        u = (vis_pos[np.newaxis, :, 0] / wv[:, np.newaxis])[:, :, np.newaxis]
        v = (vis_pos[np.newaxis, :, 1] / wv[:, np.newaxis])[:, :, np.newaxis]

        # Create container to hold results of fit
        suntrans = containers.SunTransit(time=time,
                                         pol_x=xfeeds,
                                         pol_y=yfeeds,
                                         axes_from=sstream)
        for key in suntrans.datasets.keys():
            suntrans.datasets[key][:] = 0.0

        # Set coordinates
        suntrans.coord[:] = sun_pos

        # Loop over time slices
        for slc_in, slc_out in time_slice:

            # Extract visibility slice
            vis_slice = sstream.vis[..., slc_in].copy()

            ha = (sun_pos[slc_out, 0])[np.newaxis, np.newaxis, :]
            dec = (sun_pos[slc_out, 1])[np.newaxis, np.newaxis, :]

            # Extract the diagonal (to be used for weighting)
            norm = (_extract_diagonal(vis_slice, axis=1).real)**0.5
            norm = tools.invert_no_zero(norm)

            # Fringestop
            if self.fringestop:
                vis_slice *= tools.fringestop_phase(
                    ha, np.radians(ephemeris.CHIMELATITUDE), dec, u, v)

            # Solve for the point source response of each set of polarisations
            ev_x, resp_x, err_resp_x = solve_gain(vis_slice,
                                                  feeds=xfeeds,
                                                  norm=norm[:, xfeeds])
            ev_y, resp_y, err_resp_y = solve_gain(vis_slice,
                                                  feeds=yfeeds,
                                                  norm=norm[:, yfeeds])

            # Save to container
            suntrans.evalue_x[..., slc_out] = ev_x
            suntrans.evalue_y[..., slc_out] = ev_y

            suntrans.response[:, xfeeds, slc_out] = resp_x
            suntrans.response[:, yfeeds, slc_out] = resp_y

            suntrans.response_error[:, xfeeds, slc_out] = err_resp_x
            suntrans.response_error[:, yfeeds, slc_out] = err_resp_y

        # If requested, fit a model to the primary beam of the sun transit
        if self.model_fit:

            # Estimate peak RA
            i_transit = np.argmin(np.abs(sun_pos[:, 0]))

            body = ephemeris.skyfield_wrapper.ephemeris["sun"]
            obs = ephemeris._get_chime()
            obs.date = ephemeris.unix_to_ephem_time(time[i_transit])
            body.compute(obs)

            peak_ra = ephemeris.peak_RA(body)
            dra = ra - peak_ra
            dra = np.abs(dra - (dra > np.pi) * 2.0 * np.pi)[np.newaxis,
                                                            np.newaxis, :]

            # Estimate FWHM
            sig_x = cal_utils.guess_fwhm(freq,
                                         pol="X",
                                         dec=body.dec,
                                         sigma=True)[:, np.newaxis, np.newaxis]
            sig_y = cal_utils.guess_fwhm(freq,
                                         pol="Y",
                                         dec=body.dec,
                                         sigma=True)[:, np.newaxis, np.newaxis]

            # Only fit ra values above the specified dynamic range threshold
            fit_flag = np.zeros([nfreq, nfeed, ntime], dtype=np.bool)
            fit_flag[:, xfeeds, :] = dra < (self.nsig * sig_x)
            fit_flag[:, yfeeds, :] = dra < (self.nsig * sig_y)

            # Fit model for the complex response of each feed to the point source
            param, param_cov = cal_utils.fit_point_source_transit(
                ra,
                suntrans.response[:],
                suntrans.response_error[:],
                flag=fit_flag)

            # Save to container
            suntrans.add_dataset("flag")
            suntrans.flag[:] = fit_flag

            suntrans.add_dataset("parameter")
            suntrans.parameter[:] = param

            suntrans.add_dataset("parameter_cov")
            suntrans.parameter_cov[:] = param_cov

        # Update attributes
        units = "sqrt(" + sstream.vis.attrs.get("units",
                                                "correlator-units") + ")"
        suntrans.response.attrs["units"] = units
        suntrans.response_error.attrs["units"] = units

        suntrans.attrs["source"] = "Sun"

        # Return sun transit
        return suntrans
Example #32
0
    def create_from_dict(
        cls,
        dict,
        notes=None,
        start_tol=60.0,
        dryrun=True,
        replace_dup=False,
        verbose=False,
    ):
        """
        Create a holography database entry from a dictionary

        This routine checks for duplicates and overwrites duplicates if and
        only if `replace_dup = True`

        Parameters
        ----------
        dict : dict
            src : :py:class:`HolographySource`
                A HolographySource object for the source
            start_time
                Start time as a Skyfield Time object
            finish_time
                Finish time as a Skyfield Time object
        """
        DATE_FMT_STR = "%Y-%m-%d %H:%M:%S %Z"

        def check_for_duplicates(t, src, start_tol, ignore_src_mismatch=False):
            """
            Check for duplicate holography observations, comparing the given
            observation to the existing database

            Inputs
            ------
            t: Skyfield Time object
                beginning time of observation
            src: HolographySource
                target source
            start_tol: float
                Tolerance in seconds within which to search for duplicates
            ignore_src_mismatch: bool (default: False)
                If True, consider observations a match if the time matches
                but the source does not

            Outputs
            -------
            If a duplicate is found: :py:class:`HolographyObservation` object for the
            existing entry in the database

            If no duplicate is found: None
            """
            ts = ephemeris.skyfield_wrapper.timescale

            unixt = ephemeris.ensure_unix(t)

            dup_found = False

            existing_db_entry = cls.select().where(
                cls.start_time.between(unixt - start_tol, unixt + start_tol))
            if len(existing_db_entry) > 0:
                if len(existing_db_entry) > 1:
                    print("Multiple entries found.")
                for entry in existing_db_entry:
                    tt = ts.utc(ephemeris.unix_to_datetime(entry.start_time))
                    # LST = GST + east longitude
                    ttlst = np.mod(tt.gmst + DRAO_lon, 24.0)
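                    # DRAO_lon is the east longitude in hours (negative,
                    # roughly -8 h, for CHIME's western longitude), so the
                    # mod keeps the LST in [0, 24).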

                    # Check if source name matches. If not, print a warning
                    # but proceed anyway.
                    if src.name.upper() == entry.source.name.upper():
                        dup_found = True
                        if verbose:
                            print("Observation is already in database.")
                    else:
                        if ignore_src_mismatch:
                            dup_found = True
                        print(
                            "** Observation at same time but with different " +
                            "sources in database: ",
                            src.name,
                            entry.source.name,
                            tt.utc_datetime().isoformat(),
                        )
                        # if the observations match in start time and source,
                        # call them the same observation. Not the most strict
                        # check possible.

                    if dup_found:
                        tf = ts.utc(
                            ephemeris.unix_to_datetime(entry.finish_time))
                        print("Tried to add  :  {} {}; LST={:.3f}".format(
                            src.name,
                            t.utc_datetime().strftime(DATE_FMT_STR), ttlst))
                        print("Existing entry:  {} {}; LST={:.3f}".format(
                            entry.source.name,
                            tt.utc_datetime().strftime(DATE_FMT_STR),
                            ttlst,
                        ))
            if dup_found:
                return existing_db_entry
            else:
                return None

        # DRAO longitude in hours
        DRAO_lon = ephemeris.chime.longitude * 24.0 / 360.0

        if verbose:
            print(" ")
        addtodb = True

        dup_entries = check_for_duplicates(dict["start_time"], dict["src"],
                                           start_tol)

        if dup_entries is not None:
            if replace_dup:
                if not dryrun:
                    for entry in dup_entries:
                        cls.delete_instance(entry)
                        if verbose:
                            print(
                                "Deleted observation from database and replacing."
                            )
                elif verbose:
                    print(
                        "Would have deleted observation and replaced (dry run)."
                    )
                addtodb = True
            else:
                addtodb = False
                for entry in dup_entries:
                    print("Not replacing duplicate {} observation {}".format(
                        entry.source.name,
                        ephemeris.unix_to_datetime(
                            entry.start_time).strftime(DATE_FMT_STR),
                    ))

        # Duplicate checking is done.
        # Now add to the database, if we're supposed to.
        if addtodb:
            string = "Adding to database: {} {} to {}"
            print(
                string.format(
                    dict["src"].name,
                    dict["start_time"].utc_datetime().strftime(DATE_FMT_STR),
                    dict["finish_time"].utc_datetime().strftime(DATE_FMT_STR),
                ))
            if dryrun:
                print("Dry run; doing nothing")
            else:
                cls.create(
                    source=dict["src"],
                    start_time=ephemeris.ensure_unix(dict["start_time"]),
                    finish_time=ephemeris.ensure_unix(dict["finish_time"]),
                    quality_flag=dict["quality_flag"],
                    notes=notes,
                )
Example #33
0
def main(config_file=None, logging_params=DEFAULT_LOGGING):

    # Setup logging
    log.setup_logging(logging_params)
    mlog = log.get_logger(__name__)

    # Set config
    config = DEFAULTS.deepcopy()
    if config_file is not None:
        config.merge(NameSpace(load_yaml_config(config_file)))

    # Set niceness
    current_niceness = os.nice(0)
    os.nice(config.niceness - current_niceness)
    mlog.info('Changing process niceness from %d to %d.  Confirm:  %d' %
              (current_niceness, config.niceness, os.nice(0)))

    # Find acquisition files
    acq_files = sorted(glob(os.path.join(config.data_dir, config.acq, "*.h5")))
    nfiles = len(acq_files)

    # Determine time range of each file
    findex = []
    tindex = []
    for ii, filename in enumerate(acq_files):
        subdata = andata.CorrData.from_acq_h5(filename, datasets=())

        findex += [ii] * subdata.ntime
        tindex += range(subdata.ntime)

    findex = np.array(findex)
    tindex = np.array(tindex)

    # Determine transits within these files
    transits = []

    data = andata.CorrData.from_acq_h5(acq_files, datasets=())

    solar_rise = ephemeris.solar_rising(data.time[0] - 24.0 * 3600.0,
                                        end_time=data.time[-1])

    for rr in solar_rise:

        ss = ephemeris.solar_setting(rr)[0]

        solar_flag = np.flatnonzero((data.time >= rr) & (data.time <= ss))

        if solar_flag.size > 0:

            solar_flag = solar_flag[::config.downsample]

            tval = data.time[solar_flag]

            this_findex = findex[solar_flag]
            this_tindex = tindex[solar_flag]

            file_list, tindices = [], []

            for ii in range(nfiles):

                this_file = np.flatnonzero(this_findex == ii)

                if this_file.size > 0:

                    file_list.append(acq_files[ii])
                    tindices.append(this_tindex[this_file])

            date = ephemeris.unix_to_datetime(rr).strftime('%Y%m%dT%H%M%SZ')
            transits.append((date, tval, file_list, tindices))

    # Specify some parameters for algorithm
    N = 2048

    noffset = len(config.offsets)

    if config.sep_pol:
        rank = 1
        cross_pol = False
        pol = np.array(['S', 'E'])
        pol_s = np.array(
            [rr + 256 * xx for xx in range(0, 8, 2) for rr in range(256)])
        pol_e = np.array(
            [rr + 256 * xx for xx in range(1, 8, 2) for rr in range(256)])
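        # S-pol inputs occupy the even blocks of 256 inputs
        # (0-255, 512-767, 1024-1279, 1536-1791); E-pol inputs the odd blocks.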
        prod_ss = []
        prod_ee = []
    else:
        rank = 8
        cross_pol = config.cross_pol
        pol = np.array(['all'])

    npol = pol.size

    # Create file prefix and suffix
    prefix = []

    prefix.append("gain_solutions")

    if config.output_prefix is not None:
        prefix.append(config.output_prefix)

    prefix = '_'.join(prefix)

    suffix = []

    suffix.append("pol_%s" % '_'.join(pol))

    suffix.append("niter_%d" % config.niter)

    if cross_pol:
        suffix.append("zerocross")
    else:
        suffix.append("keepcross")

    if config.normalize:
        suffix.append("normed")
    else:
        suffix.append("notnormed")

    suffix = '_'.join(suffix)

    # Loop over solar transits
    for date, timestamps, files, time_indices in transits:

        nfiles = len(files)

        mlog.info("%s (%d files) " % (date, nfiles))

        output_file = os.path.join(
            config.output_dir, "%s_SUN_%s_%s.pickle" % (prefix, date, suffix))

        mlog.info("Saving to:  %s" % output_file)

        # Get info about this set of files
        data = andata.CorrData.from_acq_h5(files, datasets=['flags/inputs'])

        prod = data.prod

        coord = sun_coord(timestamps, deg=True)

        fstart = config.freq_start if config.freq_start is not None else 0
        fstop = config.freq_stop if config.freq_stop is not None else data.freq.size
        freq_index = range(fstart, fstop)

        freq = data.freq[freq_index]

        ntime = timestamps.size
        nfreq = freq.size

        # Determine bad inputs
        if config.bad_input_file is None or not os.path.isfile(
                config.bad_input_file):
            bad_input = np.flatnonzero(
                ~np.all(data.flags['inputs'][:], axis=-1))
        else:
            with open(config.bad_input_file, 'r') as handler:
                bad_input = pickle.load(handler)

        mlog.info("%d inputs flagged as bad." % bad_input.size)
        bad_prod = np.array([
            ii for ii, pp in enumerate(prod)
            if (pp[0] in bad_input) or (pp[1] in bad_input)
        ])

        # Create arrays to hold the results
        ores = {}
        ores['date'] = date
        ores['coord'] = coord
        ores['time'] = timestamps
        ores['freq'] = freq
        ores['offsets'] = config.offsets
        ores['pol'] = pol

        ores['evalue'] = np.zeros((noffset, nfreq, ntime, N), dtype=np.float32)
        ores['resp'] = np.zeros((noffset, nfreq, ntime, N, config.neigen),
                                dtype=np.complex64)
        ores['resp_err'] = np.zeros((noffset, nfreq, ntime, N, config.neigen),
                                    dtype=np.float32)

        # Loop over frequencies
        for ff, find in enumerate(freq_index):

            mlog.info("Freq %d of %d.  %0.2f MHz." % (ff + 1, nfreq, freq[ff]))

            cnt = 0

            # Loop over files
            for ii, (filename, tind) in enumerate(zip(files, time_indices)):

                ntind = len(tind)
                mlog.info("Processing file %s (%d time samples)" %
                          (filename, ntind))

                # Loop over times
                for tt in tind:

                    t0 = time.time()

                    mlog.info("Time %d of %d.  %d index of current file." %
                              (cnt + 1, ntime, tt))

                    # Load visibilities
                    with h5py.File(filename, 'r') as hf:

                        vis = hf['vis'][find, :, tt]

                    # Set bad products equal to zero
                    vis[bad_prod] = 0.0

                    # Different code if we are separating polarisations
                    if config.sep_pol:

                        if not any(prod_ss):

                            for pind, pp in enumerate(prod):
                                if (pp[0] in pol_s) and (pp[1] in pol_s):
                                    prod_ss.append(pind)

                                elif (pp[0] in pol_e) and (pp[1] in pol_e):
                                    prod_ee.append(pind)

                            prod_ss = np.array(prod_ss)
                            prod_ee = np.array(prod_ee)

                            mlog.info("Product sizes: %d, %d" %
                                      (prod_ss.size, prod_ee.size))

                        # Loop over polarisations
                        for pp, (input_pol,
                                 prod_pol) in enumerate([(pol_s, prod_ss),
                                                         (pol_e, prod_ee)]):

                            visp = vis[prod_pol]

                            mlog.info("pol %s, visibility size:  %d" %
                                      (pol[pp], visp.size))

                            # Loop over offsets
                            for oo, off in enumerate(config.offsets):

                                mlog.info(
                                    "pol %s, rank %d, niter %d, offset %d, cross_pol %s, neigen %d"
                                    % (pol[pp], rank, config.niter, off,
                                       cross_pol, config.neigen))

                                ev, rr, rre = solve_gain(
                                    visp,
                                    cutoff=off,
                                    cross_pol=cross_pol,
                                    normalize=config.normalize,
                                    rank=rank,
                                    niter=config.niter,
                                    neigen=config.neigen)

                                ores['evalue'][oo, ff, cnt, input_pol] = ev
                                ores['resp'][oo, ff, cnt, input_pol, :] = rr
                                ores['resp_err'][oo, ff, cnt,
                                                 input_pol, :] = rre

                    else:

                        # Loop over offsets
                        for oo, off in enumerate(config.offsets):

                            mlog.info(
                                "rank %d, niter %d, offset %d, cross_pol %s, neigen %d"
                                % (rank, config.niter, off, cross_pol,
                                   config.neigen))

                            ev, rr, rre = solve_gain(
                                vis,
                                cutoff=off,
                                cross_pol=cross_pol,
                                normalize=config.normalize,
                                rank=rank,
                                niter=config.niter,
                                neigen=config.neigen)

                            ores['evalue'][oo, ff, cnt, :] = ev
                            ores['resp'][oo, ff, cnt, :, :] = rr
                            ores['resp_err'][oo, ff, cnt, :, :] = rre

                    # Increment time counter
                    cnt += 1

                    # Print time elapsed
                    mlog.info("Took %0.1f seconds." % (time.time() - t0, ))

        # Save to pickle file
        with open(output_file, 'w') as handle:

            pickle.dump(ores, handle)
Example #34
0
def _create_plot(visi, tmstp, cut_tmstp, sky, popt, test_chans, good_gains,
                 good_noise, good_fit):
    """Creates plot of the visibilities and the fits
    with labels for those that fail the tests
    """
    import matplotlib

    matplotlib.use("PDF")
    import matplotlib.pyplot as plt
    import time

    # Visibilities to plot:
    visi1 = visi  # Raw data
    tmstp1 = tmstp  # Raw data
    visi2 = np.array([[
        sky.fit_func(tt, popt[ii][0], popt[ii][1])
        for tt in range(len(cut_tmstp))
    ] for ii in range(len(popt))])
    tmstp2 = cut_tmstp

    # For title, use start time stamp:
    title = "Good channels result for {0}".format(
        ch_eph.unix_to_datetime(tmstp1[0]).date())

    # I need to know the slot for each channel:
    def get_slot(channel):
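        # Each group of 16 consecutive channels shares a slot; slot_array
        # maps the group index (channel // 16) to the physical slot number.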
        slot_array = [4, 2, 16, 14, 3, 1, 15, 13, 8, 6, 12, 10, 7, 5, 11, 9]
        return slot_array[int(channel) // 16]

    fig = plt.figure(figsize=(8, 64))
    fig.suptitle(title, fontsize=16)

    if (tmstp1[-1] - tmstp1[0]) / (24.0 * 3600.0) > 3.0:
        # Days since starting time
        # Notice: same starting time for both
        time_pl1 = (tmstp1 - tmstp1[0]) / (3600 * 24)
        time_pl2 = (tmstp2 - tmstp1[0]) / (3600 * 24)
        time_unit = "days"
    else:
        # Hours since starting time
        time_pl1 = (tmstp1 - tmstp1[0]) / (3600)
        time_pl2 = (tmstp2 - tmstp1[0]) / (3600)
        time_unit = "hours"

    for ii in range(len(visi1)):
        chan = test_chans[ii]

        # Determine position in subplot:
        if chan < 64:
            pos = chan * 4 + 1
        elif chan < 128:
            pos = (chan - 64) * 4 + 2
        elif chan < 192:
            pos = (chan - 128) * 4 + 3
        elif chan < 256:
            pos = (chan - 192) * 4 + 4

        # Create subplot:
        plt.subplot(64, 4, pos)

        lab = ""
        # Build a label listing any failed tests:
        if good_gains is not None:
            if not good_gains[ii]:
                lab = lab + "bad gains | "
        if good_noise is not None:
            if not good_noise[ii]:
                lab = lab + "noisy | "
        if not good_fit[ii]:
            lab = lab + "bad fit"

        if lab != "":
            plt.plot([], [], "1.0", label=lab)
            plt.legend(loc="best", prop={"size": 6})

        trace_pl1 = visi1[ii, :].real
        plt.plot(time_pl1, trace_pl1, "b-")

        trace_pl2 = visi2[ii, :].real
        plt.plot(time_pl2, trace_pl2, "r-")

        tm_brd = (time_pl1[-1] - time_pl1[0]) / 10.0
        plt.xlim(time_pl1[0] - tm_brd, time_pl1[-1] + tm_brd)

        # Determine limits of plots:
        med = np.median(trace_pl1)
        mad = np.median([abs(entry - med) for entry in trace_pl1])
        plt.ylim(med - 7.0 * mad, med + 7.0 * mad)
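        # (+-7 raw MAD is roughly +-4.7 sigma for Gaussian noise, so the
        # y-range hugs the bulk of the trace while clipping outliers)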

        # labels:
        plt.ylabel("Ch{0} (Sl.{1})".format(chan, get_slot(chan)), fontsize=8)

        # Hide numbering:
        frame = plt.gca()
        frame.axes.get_yaxis().set_ticks([])
        if (chan != 63) and (chan != 127) and (chan != 191) and (chan != 255):
            # Remove x-axis, except on bottom plots:
            frame.axes.get_xaxis().set_ticks([])
        else:
            # Change size of numbers in x axis:
            frame.tick_params(axis="both", which="major", labelsize=10)
            frame.tick_params(axis="both", which="minor", labelsize=8)
            if chan == 127:
                # Put x-labels on bottom plots:
                if time_unit == "days":
                    plt.xlabel(
                        "Time (days since {0} UTC)".format(
                            ch_eph.unix_to_datetime(tmstp1[0])),
                        fontsize=10,
                    )
                else:
                    plt.xlabel(
                        "Time (hours since {0} UTC)".format(
                            ch_eph.unix_to_datetime(tmstp1[0])),
                        fontsize=10,
                    )

        if chan == 0:
            plt.title("West cyl. P1(N-S)", fontsize=12)
        elif chan == 64:
            plt.title("West cyl. P2(E-W)", fontsize=12)
        elif chan == 128:
            plt.title("East cyl. P1(N-S)", fontsize=12)
        elif chan == 192:
            plt.title("East cyl. P2(E-W)", fontsize=12)

    filename = "plot_fit_{0}.pdf".format(int(time.time()))
    plt.savefig(filename)
    plt.close()
    print("Finished creating plot. File name: {0}".format(filename))