Example #1
0
def read_measurement_set(file_name):
    """Load every column of a measurement set into a ``CASAData`` object.

    The complex ``DATA`` column is split into ``REAL_DATA`` and ``IMAG_DATA``
    components; every other column is loaded as a ``MeasurementSetComponent``
    broadcast to the shape of the data.

    Parameters
    ----------
    file_name : str
        Path to the measurement set.

    Returns
    -------
    CASAData
        Built from ``REAL_DATA``/``IMAG_DATA`` plus one keyword per remaining
        column that could be read.
    """
    # Only the column names are needed from this handle: close it right away
    # (also on error) instead of holding it open while every component below
    # re-opens the table with its own tool.
    tb = table()
    tb.open(file_name)
    try:
        columns = tb.colnames()
    finally:
        tb.close()
    columns.remove('DATA')

    fulldata = MeasurementSetComponent(file_name, 'DATA')
    realdata = copy.copy(fulldata)
    realdata._data = fulldata._data.real
    imagdata = copy.copy(fulldata)
    imagdata._data = fulldata._data.imag
    datadict = {'REAL_DATA': realdata, 'IMAG_DATA': imagdata}
    datashape = realdata.data.shape

    for cn in columns:
        try:
            datadict[cn] = MeasurementSetComponent(file_name, cn, shape=datashape)
        except RuntimeError:
            # data do not exist or can't be read for some reason
            pass
        except ValueError:
            # presumably the column cannot be broadcast to ``datashape``
            print(f"FAILED to load {cn}")
    return CASAData(**datadict)
Example #2
0
def write_flag(msfile, elevation_limit, elevation, baseline_dict):
    """Flag data if below user-specified elevation limit.

    For every antenna pair present in ``baseline_dict``, the FLAG column rows
    of that baseline are overwritten: a row is flagged unless both antennas
    are above ``elevation_limit``. If stations "JB" and "M2" are both in the
    array, the JB&M2 baseline is additionally flagged with ``flagdata``.

    Parameters
    ----------
    msfile : str
        Path to the measurement set (modified in place).
    elevation_limit : float
        Threshold in the same units as ``elevation``.
    elevation : sequence of ndarray
        ``elevation[ant]`` holds one elevation per row of each baseline of
        ``ant`` -- assumed aligned with ``baseline_dict`` rows; TODO confirm.
    baseline_dict : dict
        Maps ``(ant1, ant2)`` with ``ant1 < ant2`` to the main-table row
        indices of that baseline.
    """
    tb = casatools.table()
    tb.open(msfile, nomodify=True)
    try:
        # transpose so rows are the first axis: (row, chan, corr)
        flag = tb.getcol('FLAG').T
    finally:
        tb.close()

    tb.open(msfile + '/ANTENNA', nomodify=True)
    try:
        station_names = tb.getcol('NAME')
    finally:
        tb.close()
    # One NAME per antenna row. (POSITION.shape[0] is the XYZ axis, i.e. 3,
    # because getcol puts the row axis last.)
    nant = len(station_names)

    for a0 in range(nant):
        for a1 in range(a0 + 1, nant):
            rows = baseline_dict.get((a0, a1))
            if rows is None:
                # this antenna pair never appears in the main table
                continue
            flag_mask = np.invert(
                ((elevation[a1] > elevation_limit)
                 & (elevation[a0] > elevation_limit)) > 0)
            # broadcast the per-row mask over the (chan, corr) axes
            flag[rows] = flag_mask.reshape((flag_mask.shape[0], 1, 1))

    tb.open(msfile, nomodify=False)
    try:
        tb.putcol("FLAG", flag.T)
    finally:
        tb.close()

    # Must run after our FLAG column is written: running flagdata first let
    # putcol overwrite the JB&M2 flags with the previously read values.
    if "JB" in station_names and "M2" in station_names:
        flagdata(vis=msfile, mode='manual', antenna="JB&M2")
Example #3
0
def load_antenna_delays(ant_delay_table, nant, npol=2):
    """Load antenna delays from a CASA calibration table.

    Parameters
    ----------
    ant_delay_table : str
        The full path to the calibration table.
    nant : int
        The number of antennas.
    npol : int
        The number of polarizations. Kept for backward compatibility: the
        value actually used is the first axis of the table's FPARAM column.

    Returns
    -------
    ndarray
        The relative delay per baseline in nanoseconds. Baselines are in
        anti-casa order (autocorrelations included), i.e. row
        ``i * (i + 1) // 2 + j`` holds baseline ``(j, i)`` with ``j <= i``.
        Dimensions (nant * (nant + 1) // 2, npol).
    """
    tb = cc.table()
    tb.open(ant_delay_table)
    try:
        antenna_delays = tb.getcol('FPARAM')
    finally:
        tb.close()
    npol = antenna_delays.shape[0]
    antenna_delays = antenna_delays.reshape(npol, -1, nant)

    bl_delays = np.zeros(((nant * (nant + 1)) // 2, npol))
    idx = 0
    for i in range(nant):
        for j in range(i + 1):
            # NOTE(review): sign convention (j - i vs i - j) is unverified.
            bl_delays[idx, :] = antenna_delays[:, 0, j] - antenna_delays[:, 0, i]
            # Advance to the next baseline row; without this every baseline
            # overwrote row 0 and all other rows stayed zero.
            idx += 1

    return bl_delays
Example #4
0
    def data(self):
        """Return the column values, reading them from the table on first use.

        The values are cached on ``self._data``; later calls return the cached
        object without opening the table again.
        """
        if not hasattr(self, '_data'):
            tb = table()
            tb.open(self.filename)
            try:
                self._data = tb.getcol(self.colname)
            finally:
                # release the table even if getcol raises
                tb.close()

        return self._data
Example #5
0
def get_mosaic_centre(ms_name, return_string=True,
                      field_name='M33'):
    '''
    Assuming a fully sampled mosaic, take the median
    as the phase centre.

    Parameters
    ----------
    ms_name : str
        Path to the measurement set.
    return_string : bool, optional
        If True, return an ``"ICRS <hmsdms>"`` string; otherwise return the
        median ``(ra, dec)`` as quantities in degrees.
    field_name : str or None, optional
        Only fields whose NAME contains this substring are used. ``None``
        uses every field.

    Raises
    ------
    ImportError
        If neither CASA 6 ``casatools`` nor CASA 5 ``taskinit`` is available.
    ValueError
        If no field name contains ``field_name``.
    '''

    try:
        # CASA 6
        import casatools
        tb = casatools.table()
    except ImportError:
        try:
            # CASA 5
            from taskinit import tbtool
            tb = tbtool()
        except ImportError:
            raise ImportError("Could not import CASA (casac).")

    tb.open(ms_name + "/FIELD")
    try:
        # PHASE_DIR comes back as (2, npoly + 1, nfield); keep the zeroth
        # term. Indexing instead of squeeze() keeps the field axis when the
        # MS has a single field.
        ptgs = tb.getcol("PHASE_DIR")[:, 0, :]
        field_names = tb.getcol('NAME') if field_name is not None else None
    finally:
        # close even if a read fails
        tb.close()

    ras, decs = (ptgs * u.rad).to(u.deg)

    if field_name is not None:
        valids = np.array([field_name in name for name in field_names],
                          dtype=bool)

        ras = ras[valids]
        decs = decs[valids]

        if ras.size == 0:
            raise ValueError(
                "No fields with a name containing {!r}.".format(field_name))

    med_ra = np.median(ras)
    med_dec = np.median(decs)

    if not return_string:
        return med_ra, med_dec

    med_ptg = SkyCoord(med_ra, med_dec, frame='icrs')

    # tclean was rejecting this b/c of a string type
    # change? Anyways this seems to fix it.
    return str("ICRS " + med_ptg.to_string('hmsdms'))
Example #6
0
    def __init__(self, filename, ia_kwargs=None):
        """Inspect a CASA image and record its shape, dtype and tile shape.

        Parameters
        ----------
        filename : str
            Path to the CASA image.
        ia_kwargs : dict, optional
            Stored on ``self.ia_kwargs`` (presumably forwarded to the image
            tool elsewhere -- TODO confirm). Defaults to a fresh empty dict
            per instance; the old ``{}`` default was one dict shared by
            every instance.

        Raises
        ------
        ImportError
            If neither ``casatools`` nor ``taskinit`` can be imported.
        IOError
            If CASA reports that the file is not found.
        """

        try:
            import casatools
            self.iatool = casatools.image
            tb = casatools.table()
        except ImportError:
            try:
                from taskinit import iatool, tbtool
                self.iatool = iatool
                tb = tbtool()
            except ImportError:
                raise ImportError(
                    "Could not import CASA (casac) and therefore cannot read CASA .image files"
                )

        self.ia_kwargs = {} if ia_kwargs is None else ia_kwargs

        self.filename = filename

        self._cache = {}

        log.debug("Creating ArrayLikeCasa object")

        # try to trick CASA into destroying the ia object
        def getshape():
            ia = self.iatool()
            # use the ia tool to get the file contents
            try:
                ia.open(self.filename, cache=False)
            except AssertionError as ex:
                if 'must be of cReqPath type' in str(ex):
                    raise IOError("File {0} not found.  Error was: {1}".format(
                        self.filename, str(ex)))
                raise

            try:
                # reverse CASA's axis order to match numpy's
                self.shape = tuple(ia.shape()[::-1])
                self.dtype = np.dtype(ia.pixeltype())
            finally:
                ia.done()
                ia.close()

        getshape()

        self.ndim = len(self.shape)

        tb.open(self.filename)
        try:
            dminfo = tb.getdminfo()
        finally:
            tb.done()

        # unclear if this is always the correct callspec!!!
        # (transpose requires this be backwards)
        self.chunksize = dminfo['*1']['SPEC']['DEFAULTTILESHAPE'][::-1]

        log.debug("Finished with initialization of ArrayLikeCasa object")
Example #7
0
def make_baseline_dictionary(msfile):
    """Map each antenna pair to the main-table rows of that baseline.

    Parameters
    ----------
    msfile : str
        Path to the measurement set.

    Returns
    -------
    dict
        ``{(x, y): rows}`` for every pair ``x < y`` of antenna IDs that occur
        in ANTENNA1 or ANTENNA2. ``rows`` is the integer ndarray of row
        indices where ANTENNA1 == x and ANTENNA2 == y (possibly empty).
    """
    tb = casatools.table()
    tb.open(msfile, nomodify=True)
    try:
        ant1 = tb.getcol('ANTENNA1')
        ant2 = tb.getcol('ANTENNA2')
    finally:
        tb.close()
    ant_unique = np.unique(np.concatenate((ant1, ant2)))
    return {(x, y): np.where((ant1 == x) & (ant2 == y))[0]
            for x in ant_unique for y in ant_unique if y > x}
Example #8
0
def vel_to_chan(msfile, field, obsid, spw, restfreq, vel):
    """
    Identifies the channel(s) corresponding to input LSRK velocities.
    Useful for choosing which channels to split out or flag if a line is
    expected to be present.

    Args:
        msfile (string): name of measurement set
        field (string): field name (currently unused; kept for API
            compatibility)
        obsid (int): Observation ID corresponding to the selected
            spectral window
        spw (int): Spectral window number
        restfreq (float): Rest frequency [Hz]
        vel (float or array of floats): input velocity in LSRK frame
            [km/s]

    Returns:
        (array) or (int) channel number most closely corresponding to
            input LSRK velocity. An ndarray ``vel`` gives an integer ndarray
            of the same shape; any other input gives a single index.
    """
    tb = table()
    tb.open(msfile + "/SPECTRAL_WINDOW")
    try:
        chanfreqs = tb.getcol("CHAN_FREQ", startrow=spw, nrow=1)
    finally:
        tb.close()
    tb.open(msfile + "/OBSERVATION")
    try:
        # start of the observation's TIME_RANGE is the conversion epoch
        obstime = np.squeeze(tb.getcol("TIME_RANGE", startrow=obsid, nrow=1))[0]
    finally:
        tb.close()
    nchan = len(chanfreqs)

    mstool = ms()
    mstool.open(msfile)
    try:
        lsrkfreqs = mstool.cvelfreqs(
            spwids=[spw],
            mode="channel",
            nchan=nchan,
            obstime=str(obstime) + "s",
            start=0,
            outframe="LSRK",
        )
    finally:
        mstool.close()
    # convert to LSRK velocities [km/s]
    chanvelocities = (restfreq - lsrkfreqs) / restfreq * cc_kms

    if isinstance(vel, np.ndarray):
        # nearest channel for every requested velocity at once
        return np.argmin(np.abs(chanvelocities - vel[..., np.newaxis]), axis=-1)
    return np.argmin(np.abs(chanvelocities - vel))
Example #9
0
    def __init__(self, filename, colname, shape=None):
        """Read column ``colname`` of table ``filename`` into ``self._data``.

        Parameters
        ----------
        filename : str
            Path to the CASA table / measurement set.
        colname : str
            Name of the column to read.
        shape : tuple of int, optional
            If given, the column is broadcast to this shape with
            ``np.broadcast_to`` (a read-only view; raises ``ValueError`` when
            the shapes are incompatible).
        """
        self.filename = filename
        self.colname = colname

        # NOTE(review): _data is normally unset at this point; the check only
        # matters if a class attribute or subclass pre-populates it.
        if not hasattr(self, '_data'):
            tb = table()
            tb.open(self.filename)
            try:
                values = tb.getcol(self.colname)
            finally:
                # release the table even if getcol raises
                tb.close()
            if shape is not None:
                values = np.broadcast_to(values, shape)
            self._data = values
Example #10
0
def parallacticAngle(msfile, times):
    """Compute the parallactic angle of every antenna at every time.

    Parameters
    ----------
    msfile : str
        Path to the measurement set. The pointing is the squeezed PHASE_DIR
        of the FIELD table, so a single field is assumed.
    times : array-like
        1-D times in seconds, on the scale given to ``me.epoch('utc', ...)``
        (presumably MS TIME values -- TODO confirm). The measures frame is
        set at ``times[0]``.

    Returns
    -------
    ndarray
        Shape ``(nant, len(times))``, parallactic angle in degrees. Rows of
        antennas with an equatorial MOUNT are zero.
    """
    tb = casatools.table()
    qa = casatools.quanta()
    me = casatools.measures()
    time_unique = np.asarray(times)

    tb.open(msfile + '/FIELD', nomodify=True)
    try:
        direction = np.squeeze(tb.getcol('PHASE_DIR'))
    finally:
        tb.close()

    tb.open(msfile + '/ANTENNA', nomodify=True)
    try:
        # getcol returns (3, nant); transpose to one XYZ row per antenna
        pos = tb.getcol('POSITION').T
        mount = tb.getcol('MOUNT')
    finally:
        tb.close()
    nant = pos.shape[0]

    pointing = me.direction('j2000', qa.quantity(direction[0], 'rad'),
                            qa.quantity(direction[1], 'rad'))
    t_ref = time_unique[0]
    me.doframe(me.epoch('utc', qa.quantity(t_ref, 's')))

    # The hour angle advances 2*pi per sidereal day (~86164.0905 s), not per
    # 86400 s solar day. Offsets are taken from t_ref, the epoch the frame
    # was set to (previously min(times), wrong if times is unsorted).
    sec2rad = 2 * np.pi / 86164.0905
    dha = (time_unique - t_ref) * sec2rad
    dec = direction[1]

    def antenna_para(antenna):
        x, y, z = (qa.quantity(c, 'm') for c in pos[antenna])
        me.doframe(me.position('wgs84', x, y, z))
        hour_angle = me.measure(pointing, 'HADEC')['m0']['value'] + dha
        # geocentric latitude from the antenna's own distance to the
        # geocentre rather than a fixed mean Earth radius
        latitude = np.arcsin(pos[antenna, 2] / np.linalg.norm(pos[antenna]))
        return np.arctan2(
            np.sin(hour_angle) * np.cos(latitude),
            np.cos(dec) * np.sin(latitude)
            - np.cos(hour_angle) * np.cos(latitude) * np.sin(dec))

    parallactic_ant_matrix = np.zeros((nant, time_unique.shape[0]))
    for i in range(nant):
        # MS MOUNT values are conventionally lowercase ("equatorial"), so
        # compare case-insensitively; those rows stay zero.
        if str(mount[i]).lower() == 'equatorial':
            continue
        parallactic_ant_matrix[i] = np.degrees(antenna_para(i))
    return parallactic_ant_matrix
Example #11
0
    def get_ms_info(self):
        """Recovers all the metadata from the MS.

        Reads the ANTENNA, SPECTRAL_WINDOW and FIELD subtables of
        ``self.msfile`` and fills ``self._antennas``, ``self._nsubbands``,
        ``self._bandwidth``, ``self._nchannels``, ``self._frequency`` and
        ``self._sources``. It then checks that every configured source was
        observed and every reference antenna is in the array.

        Raises
        ------
        AssertionError
            If a configured source or reference antenna is not in the MS.
        """
        tb = casatools.table()

        def subtable_path(keyword):
            # table keywords look like 'Table: /path/to/ms/SUBTABLE'
            return tb.getkeyword(keyword).replace('Table: ', '')

        # getting the required keywords, otherwise I need to open the MS again
        tb.open(self.msfile)
        try:
            tb_ant = subtable_path('ANTENNA')
            tb_spw = subtable_path('SPECTRAL_WINDOW')
            tb_src = subtable_path('FIELD')
        finally:
            tb.close()

        tb.open(tb_ant)
        try:
            self._antennas = list(tb.getcol('NAME'))
        finally:
            tb.close()

        tb.open(tb_spw)
        try:
            total_bandwidth = tb.getcol('TOTAL_BANDWIDTH')
            self._nsubbands = len(total_bandwidth)
            self._bandwidth = np.sum(total_bandwidth) * u.Hz
            self._nchannels = int(tb.getcol('NUM_CHAN')[0])
            self._frequency = np.mean(
                tb.getcol('CHAN_FREQ').reshape(
                    self.nchannels * self.nsubbands)) * u.Hz
        finally:
            tb.close()

        tb.open(tb_src)
        try:
            src_coords = tb.getcol('PHASE_DIR')
            src_names = tb.getcol('NAME')
        finally:
            # previously the last subtable was never closed
            tb.close()

        tmp_src = {}
        for i, src_name in enumerate(src_names):
            if src_name in self.targets:
                src_type = SourceType.target
            elif src_name in self.phase_calibrators:
                src_type = SourceType.calibrator
            elif src_name in self.bandpass_calibrators:
                src_type = SourceType.bandpass
            else:
                src_type = SourceType.other

            tmp_src[src_name] = Source(name=src_name,
                                       coordinates=coord.SkyCoord(
                                           *src_coords[:, 0, i], unit=u.rad),
                                       source_type=src_type)

        self._sources = tmp_src

        # Save checks...
        for source_group in (self.targets, self.phase_calibrators,
                             self.bandpass_calibrators):
            for a_source in source_group:
                assert a_source in self.sources, f"The source {a_source} defined in the input file was not observed. " \
                                                 f"The available sources are {', '.join(list(self.sources.keys()))}."

        for an_antenna in self.refants:
            assert an_antenna in self.antennas, f"The ref. antenna {an_antenna} defined in the input file is not " \
                                                f"in the array for {self.project_name}. " \
                                                f"The available antennas are {', '.join(self.antennas)}."
Example #12
0
def get_spw_global_fringe(caltable: str) -> int:
    """Returns the subband (spw) where the solutions from a global fringe run have been stored.
    Note that when combining spw in fringe, it will write the solutions in the first subband with
    data (so in general spw = 0; but not always).
    This function reads the generated calibration table and returns the number of the spw where
    the solutions where actually stored.

    Raises
    ------
    ValueError
        If the calibration table contains no solution rows.
    """
    tb = casatools.table()
    tb.open(caltable)
    try:
        # Each solution row of the caltable main table records its spw.
        # (MEAS_FREQ_REF, read previously, is a spw's frequency
        # reference-frame code, not a spw index.)
        spw_ids = tb.getcol('SPECTRAL_WINDOW_ID')
    finally:
        tb.close()
    if len(spw_ids) == 0:
        raise ValueError(f"Calibration table {caltable} contains no solutions.")
    # solutions land in the first subband with data
    return int(np.min(spw_ids))
Example #13
0
def match_to_antenna_nos(evn_SEFD, msfile):
    """Re-key per-station SEFD and diameter information by antenna index.

    Parameters
    ----------
    evn_SEFD : dict
        Maps station name (as in the ANTENNA table NAME column) to a
        sequence whose items 0 and 1 are stored as SEFD and diameter.
    msfile : str
        Path to the measurement set.

    Returns
    -------
    tuple of dict
        ``(sefd_by_index, diameter_by_index)`` keyed by ANTENNA-table row.

    Raises
    ------
    KeyError
        If a station of the ANTENNA table is missing from ``evn_SEFD``.
    """
    tb = casatools.table()
    tb.open('%s/ANTENNA' % msfile)
    try:
        names = tb.getcol('NAME')
    finally:
        tb.close()

    evn_SEFD_2 = {}
    evn_diams = {}
    for i, name in enumerate(names):
        try:
            info = evn_SEFD[name]
        except KeyError:
            raise KeyError(
                "Station {!r} of {} has no SEFD entry".format(name, msfile))
        evn_SEFD_2[i] = info[0]
        evn_diams[i] = info[1]
    return evn_SEFD_2, evn_diams
def test_getdesc(tmp_path, filename):
    # Convert the fixture to a CASA image, then check that getdesc() matches
    # the table description CASA itself reports for that image.
    casa_filename = str(tmp_path / 'casa.image')
    make_casa_testimage(filename, casa_filename)

    casa_table = table()
    casa_table.open(casa_filename)
    expected = casa_table.getdesc()
    casa_table.close()

    # pformat gives comparable text even when the dicts hold numpy arrays
    assert pformat(getdesc(casa_filename)) == pformat(expected)
Example #15
0
def add_hera_obs_pos():
    """Adding HERA observatory position (at some point as PAPER_SA)

    Only needed for the older versions of CASA.

    Copies the PAPER_SA row of the Observatories table shipped with
    ``casadata`` to a new last row named HERA. Does nothing when a HERA
    entry already exists.

    Raises
    ------
    ValueError
        If the table has neither a HERA nor a PAPER_SA entry.
    """
    obstablename = os.path.dirname(casadata.__file__) + \
                   '/__data__/geodetic/Observatories/'
    tbl = table()
    tbl.open(obstablename, nomodify=False)
    try:
        names = tbl.getcol('Name')
        if (names == 'HERA').any():
            return
        paper_rows = (names == 'PAPER_SA').nonzero()[0]
        if paper_rows.size == 0:
            raise ValueError(
                "Neither HERA nor PAPER_SA found in {}".format(obstablename))
        tbl.copyrows(obstablename, startrowin=int(paper_rows[0]),
                     startrowout=-1, nrow=1)
        tbl.putcell('Name', tbl.nrows() - 1, 'HERA')
    finally:
        # previously the writable table stayed open when HERA already existed
        tbl.close()
def test_generic_table_read(tmp_path):
    # Checks only the table metadata (description), not the data values.

    filename_fits = str(tmp_path / 'generic.fits')
    filename_casa = str(tmp_path / 'generic.image')

    t = Table()
    numeric_columns = (
        ('short', np.int16),
        ('ushort', np.uint16),
        ('int', np.int32),
        ('uint', np.uint32),
        ('float', np.float32),
        ('double', np.float64),
    )
    for col_name, col_dtype in numeric_columns:
        t[col_name] = np.arange(3, dtype=col_dtype)
    t['complex'] = np.array([1 + 2j, 3.3 + 8.2j, -1.2 - 4.2j],
                            dtype=np.complex64)
    t['dcomplex'] = np.array([3.33 + 4.22j, 3.3 + 8.2j, -1.2 - 4.2j],
                             dtype=np.complex128)
    t['str'] = np.array(['reading', 'casa', 'images'])
    # A trailing column proves the complex column metadata was fully consumed
    t['int2'] = np.arange(3, dtype=np.int32)

    t.write(filename_fits)

    casa_table = table()
    casa_table.fromfits(filename_casa, filename_fits)
    casa_table.close()

    # Reuse the columns as keywords: first element as scalar, whole as array
    keywords = {
        'scalars': {'s_' + col: t[col][0] for col in t.colnames},
        'arrays': {'a_' + col: t[col] for col in t.colnames},
    }

    casa_table.open(filename_casa)
    casa_table.putkeywords(keywords)
    casa_table.flush()
    casa_table.close()

    actual = getdesc(filename_casa)

    casa_table.open(filename_casa)
    expected = casa_table.getdesc()
    casa_table.close()

    assert pformat(actual) == pformat(expected)
Example #17
0
def flag_quack_integrations(myvis, num_ints=2.5):
    """Quack-flag the beginning of scans in ``myvis``.

    The flagged interval is ``num_ints`` times the INTERVAL of the first
    main-table row, passed to ``flagdata(mode='quack', quackmode='beg')``.
    No flag backup is created.

    Parameters
    ----------
    myvis : str
        Path to the measurement set.
    num_ints : float, optional
        Number of integrations to flag.
    """
    tb = table()
    tb.open(myvis)
    try:
        # only one value is needed; don't read the whole INTERVAL column
        int_time = tb.getcol('INTERVAL', startrow=0, nrow=1)[0]
    finally:
        tb.close()

    this_quackinterval = num_ints * int_time

    flagdata(vis=myvis,
             flagbackup=False,
             mode='quack',
             quackmode='beg',
             quackincrement=False,
             quackinterval=this_quackinterval)
Example #18
0
def add_pt_src(msfile, pt_flux):
    """Fourier-transform a point source into ``msfile`` with ``ft``.

    A point component of ``pt_flux`` Jy is placed at the first DIRECTION of
    the SOURCE subtable. It is written to a temporary component list
    ``<msfile>.cl`` that is removed before use and again afterwards.
    """
    import shutil

    tb = casatools.table()
    cl = casatools.componentlist()
    tb.open(msfile + '/SOURCE')
    try:
        # row axis is last; take the first row's (ra, dec) pair in radians
        direc = tb.getcol('DIRECTION').T[0]
    finally:
        tb.close()
    direction = 'J2000 %srad %srad' % (direc[0], direc[1])
    print(direction)
    cl.addcomponent(flux=pt_flux,
                    fluxunit='Jy',
                    shape='point',
                    dir=direction)

    complist = '%s.cl' % msfile
    # Remove without a shell: os.system('rm -r ...') broke on paths with
    # spaces or shell metacharacters.
    shutil.rmtree(complist, ignore_errors=True)
    cl.rename(complist)
    cl.close()
    ft(vis=msfile, complist=complist, usescratch=True)
    shutil.rmtree(complist, ignore_errors=True)
Example #19
0
def calc_pb_corr(msfile, diam_ants, single_freq):
    """Scale MODEL_DATA amplitudes by each baseline's primary-beam factor.

    The pointing offset (in arcmin) is parsed from the file name, expected
    to look like ``<prefix>_<arcmin>_<arcsec>.ms`` -- presumably; TODO
    confirm. For every antenna and spw the factor is
    ``sqrt(calc_hpbw(...))``. Each visibility amplitude is multiplied by the
    product of its two antennas' factors; phases are kept.

    Parameters
    ----------
    msfile : str
        Path to the measurement set (MODEL_DATA is rewritten in place).
    diam_ants : dict
        Maps antenna ID to dish diameter.
    single_freq : object
        Unless ``False``/``0``, the mean of all channel frequencies is used
        for every spw.
    """
    tb = casatools.table()

    name_parts = msfile.split('_')
    arcmin_off = float(name_parts[1]) + (
        float(name_parts[2].split('.ms')[0]) / 60.)

    tb.open('%s/SPECTRAL_WINDOW' % msfile)
    try:
        nspw = len(tb.getcol("MEAS_FREQ_REF"))
        chan_freqs = tb.getcol('CHAN_FREQ').T
    finally:
        tb.close()
    # kept as "!= False" (not truthiness) so None still enables it
    if single_freq != False:  # noqa: E712
        chan_freqs = [np.mean(chan_freqs)] * nspw

    diams_ants2 = {}
    # iterate the parameter (the old code referenced undefined `diams_ants`)
    for ant, diam in diam_ants.items():
        pb_freq = {}
        for j in range(nspw):
            pb_freq[str(j)] = np.sqrt(
                calc_hpbw(x=arcmin_off, diam=diam, freq=chan_freqs[j]))
        print(pb_freq)
        diams_ants2[ant] = pb_freq

    datacolumn = 'MODEL_DATA'
    tb.open(msfile, nomodify=True)
    try:
        data = tb.getcol(datacolumn)
        antenna1 = tb.getcol('ANTENNA1')
        antenna2 = tb.getcol('ANTENNA2')
        # NOTE(review): DATA_DESC_ID is used as the spw index
        spw_id = tb.getcol('DATA_DESC_ID')
    finally:
        tb.close()

    for row, (ant1, ant2, spw) in enumerate(zip(antenna1, antenna2, spw_id)):
        amps, phase = R2P(data[:, :, row])
        key = str(spw)
        amps = amps * (diams_ants2[ant1][key] * diams_ants2[ant2][key])
        data[:, :, row] = P2R(amps, phase)

    tb.open(msfile, nomodify=False)
    try:
        tb.putcol(datacolumn, data)
    finally:
        tb.close()
Example #20
0
def copy_pols(msfile, antenna, pol, newpol):
    """Copy one correlation of DATA over another for one antenna's baselines.

    Parameters
    ----------
    msfile : str
        Path to the measurement set (DATA is rewritten in place).
    antenna : int
        Antenna ID; rows where it is ANTENNA1 or ANTENNA2 are modified.
    pol, newpol : int
        Correlation type codes as stored in POLARIZATION/CORR_TYPE. DATA of
        correlation ``pol`` is copied onto correlation ``newpol``.
    """
    tb = casatools.table()
    tb.open(msfile + '/POLARIZATION')
    try:
        pol_code = tb.getcol('CORR_TYPE')
    finally:
        tb.close()
    # first axis of CORR_TYPE is the correlation index
    src_idx = np.where(pol_code == pol)[0][0]
    dst_idx = np.where(pol_code == newpol)[0][0]

    chunk = 100000  # rows per read, bounding memory use
    tb.open(msfile, nomodify=False)
    try:
        nrows = tb.nrows()
        for start in progressbar(list(range(0, nrows, chunk)), '', 50):
            # Last chunk may be shorter. The old "nrows % chunk" gave 0 rows
            # (skipping the tail) when nrows was a multiple of the chunk.
            nrow = min(chunk, nrows - start)
            gain = tb.getcol('DATA', startrow=start, nrow=nrow, rowincr=1)
            ant1 = tb.getcol('ANTENNA1', startrow=start, nrow=nrow, rowincr=1)
            ant2 = tb.getcol('ANTENNA2', startrow=start, nrow=nrow, rowincr=1)
            # uses the `antenna` parameter (old code used undefined `ant`)
            rows = (ant1 == antenna) | (ant2 == antenna)
            gain[dst_idx, :, rows] = gain[src_idx, :, rows]
            tb.putcol('DATA', gain, startrow=start, nrow=nrow, rowincr=1)
    finally:
        tb.close()
Example #21
0
def add_noise(msfile, datacolumn, evn_SEFD, adjust_time=1.0):
    """Replace ``datacolumn`` with simulated noise and set matching weights.

    For each row, ``calc_sefd`` gives a noise level from the two antennas'
    entries in ``evn_SEFD``, the mean EXPOSURE (times ``adjust_time``), the
    mean channel width and an efficiency of 0.88. Amplitudes are drawn from
    N(0, level) with uniform phases in [-pi, pi).

    Parameters
    ----------
    msfile : str
        Path to the measurement set (modified in place).
    datacolumn : {'CORRECTED_DATA', 'DATA'}
        Column to overwrite. CORRECTED_DATA gets WEIGHT = 1/level**2;
        DATA gets SIGMA = level.
    evn_SEFD : dict
        Maps antenna ID to the value passed to ``calc_sefd``.
    adjust_time : float, optional
        Factor applied to the mean integration time.

    Raises
    ------
    TypeError
        If ``datacolumn`` is not one of the supported columns.
    """
    if datacolumn == 'CORRECTED_DATA':
        weightnames = 'WEIGHT'
    elif datacolumn == 'DATA':
        weightnames = 'SIGMA'
    else:
        # checked before opening the table so no handle is left open
        raise TypeError(
            "datacolumn must be 'CORRECTED_DATA' or 'DATA', got {!r}".format(
                datacolumn))

    tb = casatools.table()
    tb.open(msfile, nomodify=True)
    try:
        data = tb.getcol(datacolumn)
        weights = tb.getcol(weightnames)
        antenna1 = tb.getcol('ANTENNA1')
        antenna2 = tb.getcol('ANTENNA2')
        tint = np.average(tb.getcol('EXPOSURE'))
    finally:
        tb.close()

    if adjust_time != 1.0:
        tint = tint * adjust_time

    tb.open('%s/SPECTRAL_WINDOW' % msfile, nomodify=True)
    try:
        chan_width = np.average(tb.getcol('CHAN_WIDTH'))
    finally:
        tb.close()
    print(chan_width, tint)

    for row, (ant1, ant2) in enumerate(zip(antenna1, antenna2)):
        sefd = calc_sefd(evn_SEFD[ant1], evn_SEFD[ant2], tint,
                         chan_width, 0.88)
        vis_shape = np.shape(data[:, :, row])
        amps = np.random.normal(0., sefd, vis_shape)
        phase = ((np.pi + np.pi) *
                 np.random.random_sample(vis_shape)) - np.pi
        data[:, :, row] = P2R(amps, phase)
        # WEIGHT is 1/rms**2, whereas SIGMA is the rms itself (previously
        # SIGMA was also set to 1/rms**2).
        weights[:, row] = 1. / sefd**2 if weightnames == 'WEIGHT' else sefd

    tb.open(msfile, nomodify=False)
    try:
        tb.putcol(datacolumn, data)
        tb.putcol(weightnames, weights)
    finally:
        tb.close()
def test_getdminfo(tmp_path, shape):
    # Write a random image with CASA, then check getdminfo() reproduces the
    # data-manager info CASA reports for it.
    filename = str(tmp_path / 'test.image')

    pixels = np.random.random(shape)
    img = image()
    img.fromarray(outfile=filename, pixels=pixels, log=False)
    img.close()

    casa_table = table()
    casa_table.open(filename)
    expected = casa_table.getdminfo()
    casa_table.close()

    result = getdminfo(filename)

    # Our dminfo records endian-ness but CASA's does not; drop it first
    result['*1'].pop('BIGENDIAN')

    # Compare printed forms: dict equality fails on the numpy arrays inside
    assert pformat(result) == pformat(expected)
Example #23
0
def make_qa_tables(
    ms_name,
    output_folder='scan_plots_txt',
    outtype='txt',
    overwrite=True,
    chanavg=4096,
):
    '''
    Specifically for saving txt tables. Replace the scan loop in
    `make_qa_scan_figures` to make fewer but larger tables.

    '''

    # Will need to updated for CASA 6
    # from taskinit import tb
    from casatools import table

    tb = table()

    from casaplotms import plotms

    casalog.post("Running make_qa_tables to export txt files for QA.")
    print("Running make_qa_tables to export txt files for QA.")

    # Make folder for scan plots
    if not os.path.exists(output_folder):
        os.mkdir(output_folder)
    else:
        if overwrite:
            casalog.post(
                message="Removing plot tables in {}".format(output_folder),
                origin='make_qa_tables')
            print("Removing plot tables in {}".format(output_folder))
            os.system("rm -r {}/*".format(output_folder))
        else:
            casalog.post("{} already exists. Will skip existing files.".format(
                output_folder))
            # raise ValueError("{} already exists. Enable overwrite=True to rerun.".format(output_folder))

    # Read the field names
    tb.open(os.path.join(ms_name, "FIELD"))
    names = tb.getcol('NAME')
    numFields = tb.nrows()
    tb.close()

    # Intent names
    tb.open(os.path.join(ms_name, 'STATE'))
    intentcol = tb.getcol('OBS_MODE')
    tb.close()

    # Determine the fields that are calibrators.
    tb.open(ms_name)
    is_calibrator = np.empty((numFields, ), dtype='bool')

    for ii in range(numFields):
        subtable = tb.query('FIELD_ID==%s' % ii)

        # Is the intent for calibration?
        scan_intents = intentcol[np.unique(subtable.getcol("STATE_ID"))]
        is_calib = False
        for intent in scan_intents:
            if "CALIBRATE" in intent:
                is_calib = True
                break

        is_calibrator[ii] = is_calib

    tb.close()

    casalog.post(message="Fields are: {}".format(names),
                 origin='make_qa_tables')
    casalog.post(message="Calibrator fields are: {}".format(
        names[is_calibrator]),
                 origin='make_qa_tables')

    print("Fields are: {}".format(names))
    print("Calibrator fields are: {}".format(names[is_calibrator]))

    # Loop through fields. Make separate tables only for different targets.

    for ii in range(numFields):
        casalog.post(message="On field {}".format(names[ii]),
                     origin='make_qa_plots')
        print("On field {}".format(names[ii]))

        # Amp vs. time
        amptime_filename = os.path.join(
            output_folder, 'field_{0}_amp_time.{1}'.format(names[ii], outtype))

        if not os.path.exists(amptime_filename):

            plotms(
                vis=ms_name,
                xaxis='time',
                yaxis='amp',
                ydatacolumn='corrected',
                selectdata=True,
                field=names[ii],
                scan="",
                spw="",
                avgchannel=str(chanavg),
                correlation="",
                averagedata=True,
                avgbaseline=True,
                transform=False,
                extendflag=False,
                plotrange=[],
                # title='Amp vs Time: Field {0} Scan {1}'.format(names[ii], jj),
                xlabel='Time',
                ylabel='Amp',
                showmajorgrid=False,
                showminorgrid=False,
                plotfile=amptime_filename,
                overwrite=True,
                showgui=False)
        else:
            casalog.post(message="File {} already exists. Skipping".format(
                amptime_filename),
                         origin='make_qa_tables')

        # Amp vs. channel
        ampchan_filename = os.path.join(
            output_folder, 'field_{0}_amp_chan.{1}'.format(names[ii], outtype))

        if not os.path.exists(ampchan_filename):

            plotms(
                vis=ms_name,
                xaxis='chan',
                yaxis='amp',
                ydatacolumn='corrected',
                selectdata=True,
                field=names[ii],
                scan="",
                spw="",
                avgchannel="1",
                avgtime="1e8",
                correlation="",
                averagedata=True,
                avgbaseline=True,
                transform=False,
                extendflag=False,
                plotrange=[],
                # title='Amp vs Chan: Field {0} Scan {1}'.format(names[ii], jj),
                xlabel='Channel',
                ylabel='Amp',
                showmajorgrid=False,
                showminorgrid=False,
                plotfile=ampchan_filename,
                overwrite=True,
                showgui=False)
        else:
            casalog.post(message="File {0} already exists. Skipping".format(
                ampchan_filename),
                         origin='make_qa_tables')

        # Plot amp vs uvdist
        ampuvdist_filename = os.path.join(
            output_folder,
            'field_{0}_amp_uvdist.{1}'.format(names[ii], outtype))

        if not os.path.exists(ampuvdist_filename):

            plotms(
                vis=ms_name,
                xaxis='uvdist',
                yaxis='amp',
                ydatacolumn='corrected',
                selectdata=True,
                field=names[ii],
                scan="",
                spw="",
                avgchannel=str(chanavg),
                avgtime='1e8',
                correlation="",
                averagedata=True,
                avgbaseline=False,
                transform=False,
                extendflag=False,
                plotrange=[],
                # title='Amp vs UVDist: Field {0} Scan {1}'.format(names[ii], jj),
                xlabel='uv-dist',
                ylabel='Amp',
                showmajorgrid=False,
                showminorgrid=False,
                plotfile=ampuvdist_filename,
                overwrite=True,
                showgui=False)
        else:
            casalog.post(message="File {} already exists. Skipping".format(
                ampuvdist_filename),
                         origin='make_qa_tables')

        # Make phase plots if a calibrator source.

        if is_calibrator[ii]:

            casalog.post("This is a calibrator. Exporting phase info, too.")
            print("This is a calibrator. Exporting phase info, too.")

            # Plot phase vs time

            phasetime_filename = os.path.join(
                output_folder,
                'field_{0}_phase_time.{1}'.format(names[ii], outtype))

            if not os.path.exists(phasetime_filename):

                plotms(
                    vis=ms_name,
                    xaxis='time',
                    yaxis='phase',
                    ydatacolumn='corrected',
                    selectdata=True,
                    field=names[ii],
                    scan="",
                    spw="",
                    correlation="",
                    avgchannel=str(chanavg),
                    averagedata=True,
                    avgbaseline=True,
                    transform=False,
                    extendflag=False,
                    plotrange=[],
                    # title='Phase vs Time: Field {0} Scan {1}'.format(names[ii], jj),
                    xlabel='Time',
                    ylabel='Phase',
                    showmajorgrid=False,
                    showminorgrid=False,
                    plotfile=phasetime_filename,
                    overwrite=True,
                    showgui=False)
            else:
                casalog.post(message="File {} already exists. Skipping".format(
                    phasetime_filename),
                             origin='make_qa_tables')

            # Plot phase vs channel
            phasechan_filename = os.path.join(
                output_folder,
                'field_{0}_phase_chan.{1}'.format(names[ii], outtype))

            if not os.path.exists(phasechan_filename):

                plotms(
                    vis=ms_name,
                    xaxis='chan',
                    yaxis='phase',
                    ydatacolumn='corrected',
                    selectdata=True,
                    field=names[ii],
                    scan="",
                    spw="",
                    avgchannel="1",
                    avgtime="1e8",
                    correlation="",
                    averagedata=True,
                    avgbaseline=True,
                    transform=False,
                    extendflag=False,
                    plotrange=[],
                    # title='Phase vs Chan: Field {0} Scan {1}'.format(names[ii], jj),
                    xlabel='Chan',
                    ylabel='Phase',
                    showmajorgrid=False,
                    showminorgrid=False,
                    plotfile=phasechan_filename,
                    overwrite=True,
                    showgui=False)
            else:
                casalog.post(message="File {} already exists. Skipping".format(
                    phasechan_filename),
                             origin='make_qa_tables')

            # Plot phase vs uvdist
            phaseuvdist_filename = os.path.join(
                output_folder,
                'field_{0}_phase_uvdist.{1}'.format(names[ii], outtype))

            if not os.path.exists(phaseuvdist_filename):

                plotms(
                    vis=ms_name,
                    xaxis='uvdist',
                    yaxis='phase',
                    ydatacolumn='corrected',
                    selectdata=True,
                    field=names[ii],
                    scan="",
                    spw="",
                    correlation="",
                    avgchannel=str(chanavg),
                    avgtime='1e8',
                    averagedata=True,
                    avgbaseline=False,
                    transform=False,
                    extendflag=False,
                    plotrange=[],
                    # title='Phase vs UVDist: Field {0} Scan {1}'.format(names[ii], jj),
                    xlabel='uv-dist',
                    ylabel='Phase',
                    showmajorgrid=False,
                    showminorgrid=False,
                    plotfile=phaseuvdist_filename,
                    overwrite=True,
                    showgui=False)

            else:
                casalog.post(message="File {} already exists. Skipping".format(
                    phaseuvdist_filename),
                             origin='make_qa_tables')

            # Plot amp vs phase
            ampphase_filename = os.path.join(
                output_folder,
                'field_{0}_amp_phase.{1}'.format(names[ii], outtype))

            if not os.path.exists(ampphase_filename):

                plotms(
                    vis=ms_name,
                    xaxis='amp',
                    yaxis='phase',
                    ydatacolumn='corrected',
                    selectdata=True,
                    field=names[ii],
                    scan="",
                    spw="",
                    correlation="",
                    avgchannel=str(chanavg),
                    avgtime='1e8',
                    averagedata=True,
                    avgbaseline=False,
                    transform=False,
                    extendflag=False,
                    plotrange=[],
                    # title='Amp vs Phase: Field {0} Scan {1}'.format(names[ii], jj),
                    xlabel='Phase',
                    ylabel='Amp',
                    showmajorgrid=False,
                    showminorgrid=False,
                    plotfile=ampphase_filename,
                    overwrite=True,
                    showgui=False)
            else:
                casalog.post(message="File {} already exists. Skipping".format(
                    ampphase_filename),
                             origin='make_qa_tables')

            # Plot uv-wave vs, amp - model residual
            # Check how good the point-source calibrator model is.

            ampresid_filename = os.path.join(
                output_folder,
                'field_{0}_ampresid_uvwave.{1}'.format(names[ii], outtype))

            if not os.path.exists(ampresid_filename):

                plotms(vis=ms_name,
                       xaxis='uvwave',
                       yaxis='amp',
                       ydatacolumn='corrected-model_scalar',
                       selectdata=True,
                       field=names[ii],
                       scan="",
                       spw="",
                       correlation="",
                       avgchannel=str(chanavg),
                       avgtime='1e8',
                       averagedata=True,
                       avgbaseline=False,
                       transform=False,
                       extendflag=False,
                       plotrange=[],
                       xlabel='uv-dist',
                       ylabel='Phase',
                       showmajorgrid=False,
                       showminorgrid=False,
                       plotfile=ampresid_filename,
                       overwrite=True,
                       showgui=False)

            else:
                casalog.post(message="File {} already exists. Skipping".format(
                    ampresid_filename),
                             origin='make_qa_tables')

            # Plot amplitude vs antenna 1.
            # Check for ant outliers

            ampant_filename = os.path.join(
                output_folder,
                'field_{0}_amp_ant1.{1}'.format(names[ii], outtype))

            if not os.path.exists(ampant_filename):

                plotms(vis=ms_name,
                       xaxis='antenna1',
                       yaxis='amp',
                       ydatacolumn='corrected',
                       selectdata=True,
                       field=names[ii],
                       scan="",
                       spw="",
                       correlation="",
                       avgchannel=str(chanavg),
                       avgtime='1e8',
                       averagedata=True,
                       avgbaseline=False,
                       transform=False,
                       extendflag=False,
                       plotrange=[],
                       xlabel='antenna 1',
                       ylabel='Amp',
                       showmajorgrid=False,
                       showminorgrid=False,
                       plotfile=ampant_filename,
                       overwrite=True,
                       showgui=False)

            else:
                casalog.post(message="File {} already exists. Skipping".format(
                    ampant_filename),
                             origin='make_qa_tables')

            # Plot phase vs antenna 1.
            # Check for ant outliers

            phaseant_filename = os.path.join(
                output_folder,
                'field_{0}_phase_ant1.{1}'.format(names[ii], outtype))

            if not os.path.exists(phaseant_filename):

                plotms(vis=ms_name,
                       xaxis='antenna1',
                       yaxis='phase',
                       ydatacolumn='corrected',
                       selectdata=True,
                       field=names[ii],
                       scan="",
                       spw="",
                       correlation="",
                       avgchannel=str(chanavg),
                       avgtime='1e8',
                       averagedata=True,
                       avgbaseline=False,
                       transform=False,
                       extendflag=False,
                       plotrange=[],
                       xlabel='antenna 1',
                       ylabel='Phase',
                       showmajorgrid=False,
                       showminorgrid=False,
                       plotfile=phaseant_filename,
                       overwrite=True,
                       showgui=False)

            else:
                casalog.post(message="File {} already exists. Skipping".format(
                    phaseant_filename),
                             origin='make_qa_tables')
Example #24
0
def make_qa_scan_figures(ms_name, output_folder='scan_plots', outtype='png'):
    '''
    Make a series of plots per scan for QA and
    flagging purposes.

    For every SPW, field and scan of the corrected data this writes amp vs.
    time, amp vs. channel and amp vs. uv-distance plots. For scans whose
    STATE intent contains "CALIBRATE", phase vs. time, phase vs. channel,
    phase vs. uv-distance and amp vs. phase plots are written as well.
    Scans whose data in a given SPW are completely flagged are skipped.

    TODO: Add more settings here for different types of plots, etc.

    Parameters
    ----------
    ms_name : str
        MS name
    output_folder : str, optional
        Output plot folder name. One ``spw_<N>`` sub-folder is created per
        SPW; existing ``*.<outtype>`` files in it are removed first.
    outtype : str, optional
        Plot file extension/format passed to plotms via the file name
        (default 'png').

    '''

    import glob

    # CASA 6 tools
    from casatools import table
    from casaplotms import plotms

    # NOTE(review): ``casalog`` is not imported here; it is expected to be
    # available at module level (e.g. from casatasks) -- confirm.

    tb = table()

    # SPWs to loop through
    tb.open(os.path.join(ms_name, "SPECTRAL_WINDOW"))
    spws = range(len(tb.getcol("NAME")))
    tb.close()

    # Read the field names
    tb.open(os.path.join(ms_name, "FIELD"))
    names = tb.getcol('NAME')
    numFields = tb.nrows()
    tb.close()

    # Intent names, indexed by STATE_ID
    tb.open(os.path.join(ms_name, 'STATE'))
    intentcol = tb.getcol('OBS_MODE')
    tb.close()

    tb.open(ms_name)
    scanNums = np.unique(tb.getcol('SCAN_NUMBER'))
    # Scan numbers need not start at 1 or be contiguous (e.g. after a split),
    # so map each one to its position in scanNums instead of using it as an
    # array index directly.
    scan_index = {int(num): idx for idx, num in enumerate(scanNums)}
    field_scans = []
    is_calibrator = np.zeros(len(scanNums), dtype=bool)
    is_all_flagged = np.zeros((len(spws), len(scanNums)), dtype=bool)
    for ii in range(numFields):
        subtable = tb.query('FIELD_ID==%s' % ii)
        field_scan = np.unique(subtable.getcol('SCAN_NUMBER'))
        state_ids = np.unique(subtable.getcol("STATE_ID"))
        subtable.close()
        field_scans.append(field_scan)

        # Is the intent for calibration? A STATE_ID of -1 means "no state"
        # in the MS, so drop it rather than indexing the last intent.
        state_ids = state_ids[state_ids >= 0]
        is_calib = any("CALIBRATE" in intent
                       for intent in intentcol[state_ids])

        for scan in field_scan:
            is_calibrator[scan_index[int(scan)]] = is_calib

        # Are any of the scans completely flagged?
        for spw_id in spws:
            for scan in field_scan:
                scantable = tb.query(
                    "SCAN_NUMBER=={0} AND DATA_DESC_ID=={1}".format(scan,
                                                                    spw_id))
                # An empty selection has no unflagged data either.
                all_flagged = (scantable.nrows() == 0
                               or bool(scantable.getcol("FLAG").all()))
                scantable.close()
                is_all_flagged[spw_id, scan_index[int(scan)]] = all_flagged

    tb.close()

    # Make folder for scan plots
    if not os.path.exists(output_folder):
        os.mkdir(output_folder)

    # Loop through SPWs and create plots.
    for spw_num in spws:
        casalog.post("On SPW {}".format(spw_num))

        # TODO: change appropriately for line SPWs with many channels
        avg_chan = "1"

        spw_folder = os.path.join(output_folder, "spw_{}".format(spw_num))
        if not os.path.exists(spw_folder):
            os.mkdir(spw_folder)
        else:
            # Make sure any old plots are removed first (no shell needed).
            for old_plot in glob.glob(
                    os.path.join(spw_folder, "*.{}".format(outtype))):
                os.remove(old_plot)

        for field_name, field_scan in zip(names, field_scans):
            casalog.post("On field {}".format(field_name))
            for scan_num in field_scan:
                idx = scan_index[int(scan_num)]

                # Check if all of the data is flagged.
                if is_all_flagged[spw_num, idx]:
                    casalog.post("All data flagged in SPW {0} scan {1}".format(
                        spw_num, scan_num))
                    continue

                casalog.post("On scan {}".format(scan_num))

                # Amp vs. time
                plotms(vis=ms_name,
                       xaxis='time',
                       yaxis='amp',
                       ydatacolumn='corrected',
                       selectdata=True,
                       field=field_name,
                       scan=str(scan_num),
                       spw=str(spw_num),
                       avgchannel=str(avg_chan),
                       correlation="",
                       averagedata=True,
                       avgbaseline=True,
                       transform=False,
                       extendflag=False,
                       plotrange=[],
                       title='Amp vs Time: Field {0} Scan {1}'.format(
                           field_name, scan_num),
                       xlabel='Time',
                       ylabel='Amp',
                       showmajorgrid=False,
                       showminorgrid=False,
                       plotfile=os.path.join(
                           spw_folder, 'field_{0}_amp_scan_{1}.{2}'.format(
                               field_name, scan_num, outtype)),
                       overwrite=True,
                       showgui=False)

                # Amp vs. channel
                plotms(vis=ms_name,
                       xaxis='chan',
                       yaxis='amp',
                       ydatacolumn='corrected',
                       selectdata=True,
                       field=field_name,
                       scan=str(scan_num),
                       spw=str(spw_num),
                       avgchannel=str(avg_chan),
                       avgtime="1e8",
                       correlation="",
                       averagedata=True,
                       avgbaseline=True,
                       transform=False,
                       extendflag=False,
                       plotrange=[],
                       title='Amp vs Chan: Field {0} Scan {1}'.format(
                           field_name, scan_num),
                       xlabel='Channel',
                       ylabel='Amp',
                       showmajorgrid=False,
                       showminorgrid=False,
                       plotfile=os.path.join(
                           spw_folder,
                           'field_{0}_amp_chan_scan_{1}.{2}'.format(
                               field_name, scan_num, outtype)),
                       overwrite=True,
                       showgui=False)

                # Plot amp vs uvdist
                plotms(vis=ms_name,
                       xaxis='uvdist',
                       yaxis='amp',
                       ydatacolumn='corrected',
                       selectdata=True,
                       field=field_name,
                       scan=str(scan_num),
                       spw=str(spw_num),
                       avgchannel="4096",
                       avgtime='1e8',
                       correlation="",
                       averagedata=True,
                       avgbaseline=False,
                       transform=False,
                       extendflag=False,
                       plotrange=[],
                       title='Amp vs UVDist: Field {0} Scan {1}'.format(
                           field_name, scan_num),
                       xlabel='uv-dist',
                       ylabel='Amp',
                       showmajorgrid=False,
                       showminorgrid=False,
                       plotfile=os.path.join(
                           spw_folder,
                           'field_{0}_amp_uvdist_scan_{1}.{2}'.format(
                               field_name, scan_num, outtype)),
                       overwrite=True,
                       showgui=False)

                # Phase plots are only made for calibrator scans.
                if is_calibrator[idx]:
                    # Plot phase vs time
                    plotms(vis=ms_name,
                           xaxis='time',
                           yaxis='phase',
                           ydatacolumn='corrected',
                           selectdata=True,
                           field=field_name,
                           scan=str(scan_num),
                           spw=str(spw_num),
                           correlation="",
                           averagedata=True,
                           avgbaseline=True,
                           transform=False,
                           extendflag=False,
                           plotrange=[],
                           title='Phase vs Time: Field {0} Scan {1}'.format(
                               field_name, scan_num),
                           xlabel='Time',
                           ylabel='Phase',
                           showmajorgrid=False,
                           showminorgrid=False,
                           plotfile=os.path.join(
                               spw_folder,
                               'field_{0}_phase_time_scan_{1}.{2}'.format(
                                   field_name, scan_num, outtype)),
                           overwrite=True,
                           showgui=False)

                    # Plot phase vs channel
                    plotms(vis=ms_name,
                           xaxis='chan',
                           yaxis='phase',
                           ydatacolumn='corrected',
                           selectdata=True,
                           field=field_name,
                           scan=str(scan_num),
                           spw=str(spw_num),
                           avgchannel=str(avg_chan),
                           avgtime="1e8",
                           correlation="",
                           averagedata=True,
                           avgbaseline=True,
                           transform=False,
                           extendflag=False,
                           plotrange=[],
                           title='Phase vs Chan: Field {0} Scan {1}'.format(
                               field_name, scan_num),
                           xlabel='Chan',
                           ylabel='Phase',
                           showmajorgrid=False,
                           showminorgrid=False,
                           plotfile=os.path.join(
                               spw_folder,
                               'field_{0}_phase_chan_scan_{1}.{2}'.format(
                                   field_name, scan_num, outtype)),
                           overwrite=True,
                           showgui=False)

                    # Plot phase vs uvdist
                    plotms(vis=ms_name,
                           xaxis='uvdist',
                           yaxis='phase',
                           ydatacolumn='corrected',
                           selectdata=True,
                           field=field_name,
                           scan=str(scan_num),
                           spw=str(spw_num),
                           correlation="",
                           avgchannel="4096",
                           avgtime='1e8',
                           averagedata=True,
                           avgbaseline=False,
                           transform=False,
                           extendflag=False,
                           plotrange=[],
                           title='Phase vs UVDist: Field {0} Scan {1}'.format(
                               field_name, scan_num),
                           xlabel='uv-dist',
                           ylabel='Phase',
                           showmajorgrid=False,
                           showminorgrid=False,
                           plotfile=os.path.join(
                               spw_folder,
                               'field_{0}_phase_uvdist_scan_{1}.{2}'.format(
                                   field_name, scan_num, outtype)),
                           overwrite=True,
                           showgui=False)

                    # Plot amp vs phase
                    plotms(
                        vis=ms_name,
                        xaxis='amp',
                        yaxis='phase',
                        ydatacolumn='corrected',
                        selectdata=True,
                        field=field_name,
                        scan=str(scan_num),
                        spw=str(spw_num),
                        correlation="",
                        avgchannel="4096",
                        averagedata=True,
                        avgbaseline=False,
                        transform=False,
                        extendflag=False,
                        plotrange=[],
                        title='Amp vs Phase: Field {0} Scan {1}'.format(
                            field_name, scan_num),
                        xlabel='Phase',
                        ylabel='Amp',
                        showmajorgrid=False,
                        showminorgrid=False,
                        plotfile=os.path.join(
                            spw_folder,
                            'field_{0}_amp_phase_scan_{1}.{2}'.format(
                                field_name, scan_num, outtype)),
                        overwrite=True,
                        showgui=False)
Example #25
0
import os

import matplotlib
# Agg doesn't need X - matplotlib doesn't work with xvfb.
# The ``warn`` keyword of matplotlib.use() was removed in Matplotlib 3.3
# (passing it raises TypeError); the plain call works on old and new versions.
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import numpy as np

from config_parser import validate_args as va
import bookkeeping
import glob
PLOT_DIR = 'plots'
EXTN = 'png'

from casatasks import *
from casatools import table,msmetadata
tb = table()
msmd = msmetadata()
logfile=casalog.logfile()
# NOTE: requires SLURM_JOB_NAME and SLURM_JOB_ID in the environment;
# str.format(**os.environ) raises KeyError outside a SLURM job.
casalog.setlogfile('logs/{SLURM_JOB_NAME}-{SLURM_JOB_ID}.casa'.format(**os.environ))

import logging
from time import gmtime
# Log timestamps in UTC.
logging.Formatter.converter = gmtime
logger = logging.getLogger(__name__)
logging.basicConfig(format="%(asctime)-15s %(levelname)s: %(message)s", level=logging.INFO)

def avg_ants(arrlist):
    """Return the mean of each array in ``arrlist`` taken along its last axis.

    Parameters
    ----------
    arrlist : iterable of array_like
        Arrays to reduce.

    Returns
    -------
    list
        One averaged array (a scalar for 1-D input) per input, in input order.
    """
    def mean_over_last_axis(values):
        return np.mean(values, axis=-1)

    return list(map(mean_over_last_axis, arrlist))


def lengthen(edat, inpdat):
Example #26
0
from casatools import table
import numpy as np

calK = table()
# Alternative input table: calK.open("cal.b.K")
calK.open("test_cal")
# NOTE(review): after squeezing, FPARAM is expected to hold one row per
# polarisation (two rows: X and Y) to match the "%i %.6f %.6f" format
# below -- confirm for single-pol tables.
delays = np.squeeze(calK.getcol("FPARAM"))
antnames = calK.getcol("ANTENNA1")
# All needed columns are in memory; release the table tool.
calK.close()

# One output row per table row: antenna number, then one delay per pol.
tb = np.column_stack((antnames, *delays))

np.savetxt("./delays_residuals.txt",
           tb,
           fmt="%i %.6f %.6f",
           header="antnum X Y")
Example #27
0
def weight_multichan(base_ms, npix, cell_size, robust=np.array([0.]), chans=np.array([2]), method='briggs', perchanweight=False, mod_pcwd=False, npixels=0):
    """Estimate image rms and beam parameters for robust weighting of an MS.

    For every channel in ``chans`` and every value in ``robust``, the
    visibility weights of ``base_ms`` are gridded, re-weighted with the chosen
    scheme, gridded again to form a dirty beam, and the beam is fitted with
    ``fit_beam_CASA``.

    Parameters
    ----------
    base_ms : str
        Measurement set to read (opened read-only).
    npix : int
        Size of the (npix x npix) uv grid / image.
    cell_size : float
        Image cell size; multiplied by the module-level ``arcsec`` constant
        (presumably arcseconds -> radians; confirm).
    robust : array_like, optional
        Robust parameter(s) to evaluate. A scalar or list is also accepted.
    chans : array_like of int, optional
        Channel indices to evaluate. They index the channels that remain
        after fully flagged channels are dropped.
    method : {'briggs', 'briggsabs'}, optional
        Weighting scheme.
    perchanweight : bool, optional
        If true, grid the natural weights separately for each channel;
        otherwise grid all channels once up front.
    mod_pcwd : bool, optional
        Divide the gridded weights by a fractional-bandwidth correction
        factor (``corr_fac``, clipped to >= 1) in the 'briggs' scheme.
    npixels : int, optional
        Unused.

    Returns
    -------
    rms : ndarray, shape (len(chans), len(robust))
        Noise estimate ``2*C*sqrt(sum(w_robust_sq))`` with
        ``C = 1/(2*sum(w_robust))`` (Briggs et al. 1995).
    beam_params : ndarray, shape (len(chans), 3, len(robust))
        ``fit_beam_CASA`` result for each channel / robust value.

    Raises
    ------
    ValueError
        If ``method`` is not 'briggs' or 'briggsabs'.
    """
    # Accept scalars and lists as well as arrays (``.shape`` is used below).
    robust = np.atleast_1d(robust)
    chans = np.atleast_1d(chans)
    # Fail early: any other method would leave the robust weights undefined.
    if method not in ('briggs', 'briggsabs'):
        raise ValueError(
            "method must be 'briggs' or 'briggsabs', got {!r}".format(method))

    tb = casatools.table()

    # Use CASA table tools to get frequencies
    tb.open(base_ms+"/SPECTRAL_WINDOW")
    chan_freqs = tb.getcol("CHAN_FREQ")
    rfreq = tb.getcol("REF_FREQUENCY")
    tb.close()

    # Read UVW, WEIGHT, FLAG and antenna columns. The MS is only read, so
    # open it read-only instead of taking a write lock.
    tb.open(base_ms, nomodify=True)
    flag = tb.getcol("FLAG")
    uvw = tb.getcol("UVW")
    weight = tb.getcol("WEIGHT")
    ant1 = tb.getcol("ANTENNA1")
    ant2 = tb.getcol("ANTENNA2")
    tb.close()

    # Keep channels that are not flagged everywhere along FLAG axes 0 and 2
    # (presumably polarisation and row).
    good_chan = np.logical_not(np.all(flag, axis=(0, 2)))

    # break out the u, v spatial frequencies, convert from m to lambda
    uu = uvw[0, :][:, np.newaxis]*chan_freqs[:, 0]/(cc/100)
    vv = uvw[1, :][:, np.newaxis]*chan_freqs[:, 0]/(cc/100)

    # toss out the autocorrelation placeholders
    xc = np.where(ant1 != ant2)[0]

    # Sum of the first two polarisation weights (assumes at least 2 pols).
    wgts = weight[0, :] + weight[1, :]

    uu_xc = uu[xc][:, good_chan]
    vv_xc = vv[xc][:, good_chan]
    wgts_xc = wgts[xc]

    dl = cell_size*arcsec
    dm = cell_size*arcsec

    du = 1./((npix)*dl)
    dv = 1./((npix)*dm)

    # create arrays to dump values
    rms = np.zeros((chans.shape[0], robust.shape[0]))
    beam_params = np.zeros((chans.shape[0], 3, robust.shape[0]))

    # grid the weights outside of loop if not perchanweight, only need to do this once...
    if not perchanweight:
        gwgts_init = np.zeros((npix, npix))
        gwgts_init = grid_wgts(gwgts_init, np.ravel(uu_xc), np.ravel(vv_xc), du, dv, npix, np.ravel(np.broadcast_to(wgts_xc, (uu_xc.shape[1], uu_xc.shape[0])).T))

    if mod_pcwd:
        # TODO CHECK THIS FOR HALF PIXEL OFFSET
        uvdist_grid = np.sqrt(np.add.outer(np.arange(-(npix/2.)*du, (npix/2.)*du, du)**2, np.arange(-(npix/2.)*dv, (npix/2.)*dv, dv)**2))
        frac_bw = (np.max(chan_freqs) - np.min(chan_freqs)) / rfreq
        corr_fac = frac_bw*uvdist_grid/du
        corr_fac[corr_fac < 1] = 1.

    for i, chan in enumerate(chans):
        print(chan)
        # grid the weights (with complex conjugates)
        if perchanweight:
            gwgts_init = np.zeros((npix, npix))
            gwgts_init = grid_wgts(gwgts_init, uu_xc[:, chan], vv_xc[:, chan], du, dv, npix, wgts_xc)

        gwgts_init_sq = gwgts_init**2

        for j, r in enumerate(robust):
            if method == 'briggs':
                # calculate robust parameters
                # normalize differently if only using single channel; note that we assume the weights are not channelized and are uniform across channel
                if perchanweight:
                    if mod_pcwd:
                        f_sq = ((5*10**(-r))**2)/(np.sum(gwgts_init_sq)/(np.sum(wgts_xc)*2))
                    else:
                        f_sq = ((5*10**(-r))**2)/(np.sum(gwgts_init_sq)/(np.sum(wgts_xc)))
                else:
                    f_sq = ((5*10**(-r))**2)/(np.sum(gwgts_init_sq)/(np.sum(wgts_xc*uu_xc.shape[1])*2))

                if mod_pcwd:
                    gr_wgts = 1/(1+gwgts_init/corr_fac*f_sq)
                else:
                    gr_wgts = 1/(1+gwgts_init*f_sq)

                # robust factor at each visibility's uv cell
                indexed_gr_wgts = ungrid_wgts(gr_wgts, uu_xc[:, chan], vv_xc[:, chan], du, dv, npix)
            else:  # 'briggsabs'
                # Gridded natural weight at each visibility's uv cell.
                # NOTE(review): assumes ungrid_wgts samples the grid at each
                # visibility, as in the 'briggs' branch -- confirm.
                S_sq = ungrid_wgts(gwgts_init, uu_xc[:, chan], vv_xc[:, chan], du, dv, npix)*r**2
                indexed_gr_wgts = 1/(S_sq + 2*wgts_xc)

            # multiply to get robust weights
            wgts_robust = wgts_xc*indexed_gr_wgts
            wgts_robust_sq = wgts_xc*(indexed_gr_wgts)**2

            # get the total gridded weights (to make dirty beam)
            gwgts_final = np.zeros((npix, npix))
            gwgts_final = grid_wgts(gwgts_final, uu_xc[:, chan], vv_xc[:, chan], du, dv, npix, wgts_robust)

            # create the dirty beam and calculate the beam parameters
            robust_beam = np.real(fftshift(fft2(fftshift(gwgts_final))))
            robust_beam /= np.max(robust_beam)
            beam_params[i, :, j] = fit_beam_CASA(robust_beam, cell_size)

            # calculate rms (formula from Briggs et al. 1995)
            C = 1/(2*np.sum(wgts_robust))
            rms[i, j] = 2*C*np.sqrt(np.sum(wgts_robust_sq))
            print(r, beam_params[i, :, j], rms[i, j]*1000.)

    return rms, beam_params
Example #28
0
def import_data_ms(filename):
    """Imports data from a CASA measurement set and returns visibility object.

    Reads DATA, UVW, WEIGHT, ANTENNA1/2, FLAG and TIME from the main table
    and the channel frequencies/widths from SPECTRAL_WINDOW. Only the
    channel/correlation counts of spectral window "0" are used. If more than
    one correlation is present, the first two are combined with a
    WEIGHT-weighted average. Autocorrelations are dropped.

    Parameters
    ----------
    filename : str
        Path to the measurement set.

    Returns
    -------
    Visibility
        Built from the cross-correlation visibilities, u and v (as stored in
        UVW), summed weights, channel frequencies, times, channel widths,
        antenna indices and flags. For multi-channel data, a visibility is
        flagged if any of its channels is flagged.
    """

    tb = table()
    ms_ = ms()

    # Antenna information
    tb.open(filename)
    data = tb.getcol("DATA")
    uvw = tb.getcol("UVW")
    weight = tb.getcol("WEIGHT")
    ant1 = tb.getcol("ANTENNA1")
    ant2 = tb.getcol("ANTENNA2")
    flags = tb.getcol("FLAG")
    time = tb.getcol("TIME")
    tb.close()

    # Spectral window information
    ms_.open(filename)
    spw_info = ms_.getspectralwindowinfo()
    nchan = spw_info["0"]["NumChan"]
    npol = spw_info["0"]["NumCorr"]
    ms_.close()

    # Frequency information
    tb.open(filename+"/SPECTRAL_WINDOW")
    freqs = tb.getcol("CHAN_FREQ")
    rfreq = tb.getcol("REF_FREQUENCY")
    resolution = tb.getcol("CHAN_WIDTH")
    tb.close()

    uu = uvw[0, :]
    vv = uvw[1, :]

    # Check if pols are already averaged
    data = np.squeeze(data)
    weight = np.squeeze(weight)
    flags = np.squeeze(flags)

    if npol == 1:
        Re = data.real
        Im = data.imag
        wgts = weight

    else:
        # Polarization averaging
        Re_xx = data[0, :].real
        Re_yy = data[1, :].real
        Im_xx = data[0, :].imag
        Im_yy = data[1, :].imag
        weight_xx = weight[0, :]
        weight_yy = weight[1, :]
        # a point stays flagged only if both correlations are flagged
        flags = flags[0, :]*flags[1, :]

        # Weighted averages (0 where both weights are zero)
        with np.errstate(divide='ignore', invalid='ignore'):
            Re = np.where((weight_xx + weight_yy) != 0, (Re_xx * weight_xx + \
                 Re_yy*weight_yy) / (weight_xx + weight_yy), 0.)
            Im = np.where((weight_xx + weight_yy) != 0, (Im_xx * weight_xx + \
                 Im_yy*weight_yy) / (weight_xx + weight_yy), 0.)
        wgts = (weight_xx + weight_yy)

    # Toss out the autocorrelations
    xc = np.where(ant1 != ant2)[0]
    # TIME is per row, like the other arrays; select cross-correlations in
    # both the single- and multi-channel cases so they stay aligned.
    time = time[xc]

    # Check if there's only a single channel
    if nchan == 1:
        data_real = Re[np.newaxis, xc]
        data_imag = Im[np.newaxis, xc]
        flags = flags[xc]
    else:
        data_real = Re[:, xc]
        data_imag = Im[:, xc]
        flags = flags[:, xc]

        # flags is (channel, visibility) here. If the majority of points in
        # any channel are flagged, it probably means an entire channel is
        # flagged - spit warning (the collapse below would then flag almost
        # every visibility).
        if np.any(np.mean(flags, axis=1) > 0.5):
            print('WARNING: Over half of the (u,v) points in at least one '\
                  'channel are marked as flagged. If you did not expect this, it is '\
                  'likely due to having an entire channel flagged in the ms. Please '\
                  'double check this and be careful if model fitting or using diff mode.')

        # Collapse flags to single channel, because weights are not currently channelized
        flags = flags.any(axis=0)

    data_wgts = wgts[xc]
    data_uu = uu[xc]
    data_vv = vv[xc]

    ant1 = ant1[xc]
    ant2 = ant2[xc]

    data_VV = data_real + 1j*data_imag

    # Warning that flagged data was imported
    if np.any(flags):
        print('WARNING: Flagged data was imported. Visibility interpolation can '\
              'proceed normally, but be careful with chi^2 calculations.')

    return Visibility(data_VV.T, data_uu, data_vv, data_wgts, freqs, time, \
                      resolution, ant1, ant2, flags)
# Sky-model FITS image: first command-line argument, or the default file.
if len(sys.argv) > 1:
    fitsimage = sys.argv[1]
else:
    fitsimage = 'SKAMid_B2_8h_v3.fits'

# Validate with explicit exceptions rather than ``assert`` (asserts are
# stripped when Python runs with -O).
if not os.path.exists(fitsimage):
    raise FileNotFoundError(fitsimage)
if '.fits' not in fitsimage:
    # The CASA image name is derived by replacing '.fits'; without it the
    # "image" path would just be the FITS path itself.
    raise ValueError("Expected a '.fits' file name, got {}".format(fitsimage))
image = fitsimage.replace('.fits', '.image')
if not os.path.exists(image):
    print('creating ms image to be used as sky model')
    im = casatools.image()
    im.fromfits(infile=fitsimage, outfile=image)
    # Release the image tool (and its hold on the new image) once written.
    im.done()

# get antenna positions
tabname = 'antenna_positions_' + conf_file.split('.cfg')[0] + '.tab'
tb = casatools.table()
tb.fromascii(tabname,
             conf_file,
             firstline=3,
             sep=' ',
             columnnames=['X', 'Y', 'Z', 'DIAM', 'NAME'],
             datatypes=['D', 'D', 'D', 'D', 'A'])
xx = tb.getcol('X')
yy = tb.getcol('Y')
zz = tb.getcol('Z')
diam = tb.getcol('DIAM')
anames = tb.getcol('NAME')
tb.close()

# simulate setup
sm = casatools.simulator()
Example #30
0
def interpolate_bandpass(tablename,
                         spw_ids=None,
                         window_size_factor=2.5,
                         poly_order=2,
                         add_residuals=True,
                         backup_table=True,
                         test_output_nowrite=False,
                         test_print=False):
    '''
    Use Savitzky-Golay smoothing across flagged channels in the bandpass.

    The smoothing window size is set by `window_size_factor`. This will create
    a smoothing length (by default) 2.5 times larger than the gap that will be
    interpolated across. Smaller window sizes will produce artifacts over the
    interpolated region.

    Flagged runs of channels that touch either end of the SPW are treated as
    edge flagging and are never interpolated; only interior gaps are filled
    (and then unflagged).

    Parameters
    ----------
    tablename : str
        Path to the bandpass calibration table. Its ``CPARAM`` and ``FLAG``
        columns are overwritten in place unless ``test_output_nowrite`` is set.
    spw_ids : iterable of int, optional
        SPWs to process. Defaults to every ``SPECTRAL_WINDOW_ID`` in the table.
    window_size_factor : float
        Smoothing window length as a multiple of the widest interior gap of
        each spectrum. The resulting window size is forced to be odd.
    poly_order : int
        Order of the polynomial fitted in each sliding window.
    add_residuals : bool
        If True, residuals (data minus smoothed model), resampled at random
        from the unflagged channels, are added to the interpolated values to
        keep a consistent noise level.
    backup_table : bool
        If True, copy the table to ``<tablename>.bak_from_interpbandpass``
        first (only if that backup directory does not already exist).
    test_output_nowrite : bool
        If True, nothing is written back to the table; the corrected arrays
        are returned instead.
    test_print : bool
        Print diagnostic slices of the corrected and smoothed data.

    Returns
    -------
    dict or None
        If ``test_output_nowrite`` is True, a dict mapping each processed SPW
        id to its corrected masked ``CPARAM`` array (SPWs skipped by the
        pre-check are absent); otherwise None.

    Raises
    ------
    ValueError
        If a requested SPW does not exist in the table.
    '''

    from casatools import table

    # NOTE(review): ``casalog`` is not imported in this function (the old
    # taskinit import was removed); it is assumed to be available at module
    # scope, e.g. from the CASA environment -- confirm.

    if backup_table:
        original_table_backup = tablename + '.bak_from_interpbandpass'
        if not os.path.isdir(original_table_backup):
            shutil.copytree(tablename, original_table_backup)

    tb = table()

    tb.open(tablename)
    try:
        all_spw_ids = np.unique(tb.getcol("SPECTRAL_WINDOW_ID"))
    finally:
        tb.close()

    if spw_ids is None:
        spw_ids = all_spw_ids
    else:
        for spw in spw_ids:
            if spw not in all_spw_ids:
                raise ValueError(
                    "SPW {} specified does not exist in the table.".format(
                        spw))

    # We'll just output the corrected data as numpy arrays instead of writing
    # back to the table.
    if test_output_nowrite:
        bp_pass_dict = dict()

    for spw in spw_ids:
        casalog.post(message='processing SPW {0}'.format(spw),
                     origin='interpolate_bandpass')

        tb.open(tablename)
        try:
            stb = tb.query('SPECTRAL_WINDOW_ID == {0}'.format(spw))
            try:
                # Axes are indexed below as [pol, channel, ant].
                dat = np.ma.array(stb.getcol('CPARAM'))
                dat.mask = stb.getcol('FLAG')
            finally:
                stb.close()
        finally:
            tb.close()

        # Identify if there are gaps to interpolate across.
        # The sum is over corr and ants, so a channel is masked here only if
        # it is flagged everywhere. If there are only 2 such runs, they are
        # the SPW edge flagging and no interpolation is needed.
        common_gaps = nd.find_objects(*nd.label(dat.sum(2).sum(0).mask))

        if len(common_gaps) == 2:
            casalog.post(message="no interpolation needed for {0}".format(spw),
                         origin='interpolate_bandpass')
            continue

        dat_shape = dat.shape
        nchan = dat_shape[1]

        smooth_dat = deepcopy(dat)

        for ant in range(dat_shape[2]):
            casalog.post(message='processing antenna {0}'.format(ant),
                         origin='interpolate_bandpass')
            for pol in range(dat_shape[0]):

                # Skip if all flagged.
                if np.all(dat.mask[pol, :, ant]):
                    continue

                # Flagged runs in this spectrum, keeping only interior gaps.
                # Runs touching channel 0 or the last channel are edge
                # flagging. Filtering by position (rather than dropping the
                # first/last run) keeps real gaps when the edges are not
                # flagged, and never unflags an edge run.
                blank_slices = [
                    slc for slc in nd.find_objects(
                        *nd.label(dat[pol, :, ant].mask))
                    if slc[0].start > 0 and slc[0].stop < nchan
                ]

                # Nothing (or only edges) flagged: nothing to interpolate.
                # This also avoids max() of an empty sequence below.
                if not blank_slices:
                    continue

                nchans_in_gap = max(slc[0].stop - slc[0].start
                                    for slc in blank_slices)

                # Window size scales with the widest gap in this spectrum.
                window_size = int(np.floor(window_size_factor * nchans_in_gap))

                # Force odd window size so the window is centred on a channel.
                if window_size % 2 == 0:
                    window_size += 1

                casalog.post(
                    message="Using window size of {0} for SPW {1}".format(
                        window_size, spw),
                    origin='interpolate_bandpass')

                # Print a warning if the window is >50% of the whole SPW
                if window_size / dat.shape[1] > 0.5:
                    casalog.post(
                        message="Warning: the window size is >50% of the SPW",
                        origin='interpolate_bandpass')

                # Abscissa centred on the window; the constant term ([-1]) of
                # the fit is therefore the smoothed value at the centre.
                x_polyfit = np.arange(window_size) - np.floor(window_size / 2.)

                # Mask out the gaps in the copy that gets smoothed.
                for slicer in blank_slices:
                    smooth_dat.mask[(slice(pol, pol + 1), slicer[0],
                                     slice(ant, ant + 1))] = True

                rolled_array = rolling_window(smooth_dat[pol, :, ant],
                                              window_size)

                for i in range(dat_shape[1]):
                    # NOTE(review): behaviour when a window is entirely
                    # masked depends on np.ma.polyfit / rolling_window --
                    # confirm whether that can occur.
                    smooth_dat[pol, i, ant] = np.ma.polyfit(
                        x_polyfit, rolled_array[i], poly_order)[-1]

                casalog.post(message="replacing values with smoothed",
                             origin='interpolate_bandpass')

                if add_residuals:
                    resids = dat[pol, :, ant] - smooth_dat[pol, :, ant]
                    # Unmasked residuals to resample noise from.
                    resid_pool = resids.compressed()

                # Add the interpolated values back to the original array
                for slicer in blank_slices:
                    region = (slice(pol, pol + 1), slicer[0],
                              slice(ant, ant + 1))

                    dat[region] = smooth_dat[region]

                    # Optionally sample residuals from the difference and add
                    # to the interpolated region to keep a consistent noise
                    # level. Skip if there is nothing to sample from.
                    if add_residuals and resid_pool.size > 0:
                        gap_size = slicer[0].stop - slicer[0].start
                        resid_samps = np.random.choice(resid_pool,
                                                       size=gap_size)

                        # Pad some axes on to match the region's shape.
                        resid_samps = resid_samps[np.newaxis, :, np.newaxis]

                        dat[region] += resid_samps

                    # Reset the mask across the interp region
                    dat.mask[region] = False

                # Keeping just for diagnosing issues (shows the last gap).
                if test_print:
                    print(region)
                    print(dat[region].shape)
                    print(dat[region][:10])
                    print(smooth_dat[region][:10])

        if test_output_nowrite:
            bp_pass_dict[spw] = dat
            continue

        casalog.post(
            message="writing out smoothed gaps to table for spw {}".format(
                spw),
            origin='interpolate_bandpass')

        tb.open(tablename, nomodify=False)
        try:
            stb = tb.query('SPECTRAL_WINDOW_ID == {0}'.format(spw))
            try:
                # Set new data
                stb.putcol('CPARAM', dat.data)
                # Set new flags
                stb.putcol('FLAG', dat.mask)
            finally:
                stb.close()
        finally:
            tb.clearlocks()
            tb.close()

    if test_output_nowrite:
        return bp_pass_dict