def read_measurement_set(file_name):
    """Load the readable main-table columns of a MeasurementSet as a CASAData.

    The complex ``DATA`` column is split into ``REAL_DATA`` and ``IMAG_DATA``
    components. Every other column is loaded as a ``MeasurementSetComponent``
    broadcast to the shape of the real data. Columns whose load raises
    ``RuntimeError`` are skipped silently; columns whose load raises
    ``ValueError`` are reported on stdout and skipped.

    Parameters
    ----------
    file_name : str
        Path to the MeasurementSet.

    Returns
    -------
    CASAData
        Built from the collected components passed as keyword arguments.
    """
    # Only the column names are needed from this handle; close it straight
    # away so it is not left open if one of the column loads below raises.
    tb = table()
    tb.open(file_name)
    try:
        columns = tb.colnames()
    finally:
        tb.close()
    columns.remove('DATA')

    fulldata = MeasurementSetComponent(file_name, 'DATA')
    realdata = copy.copy(fulldata)
    realdata._data = fulldata._data.real
    imagdata = copy.copy(fulldata)
    imagdata._data = fulldata._data.imag
    datadict = {'REAL_DATA': realdata, 'IMAG_DATA': imagdata}
    datashape = realdata.data.shape

    for cn in columns:
        try:
            datadict[cn] = MeasurementSetComponent(file_name, cn,
                                                   shape=datashape)
        except RuntimeError:
            # data do not exist or can't be read for some reason
            pass
        except ValueError:
            print(f"FAILED to load {cn}")

    return CASAData(**datadict)
def write_flag(msfile, elevation_limit, elevation, baseline_dict):
    """Flag baselines on which either antenna is below an elevation limit.

    Parameters
    ----------
    msfile : str
        Path to the MeasurementSet; its FLAG column is rewritten in place.
    elevation_limit : float
        Threshold compared against ``elevation`` (same units).
    elevation : sequence of ndarray
        Per-antenna elevations indexed by antenna number.
        NOTE(review): each entry is assumed to hold one value per MS row of
        the corresponding baseline in ``baseline_dict`` — confirm with the
        caller.
    baseline_dict : dict
        Maps ``(ant1, ant2)`` to the MS row indices of that baseline, e.g.
        the output of ``make_baseline_dictionary``. Only pairs with
        ``ant2 > ant1`` are processed.

    Notes
    -----
    The FLAG rows of each processed baseline are *replaced* by the
    elevation mask: True unless both antennas are above the limit. Existing
    flags on those rows are not preserved. If stations JB and M2 are both
    present, their baseline is additionally flagged with ``flagdata``
    *after* FLAG has been written, so ``putcol`` cannot overwrite that flag.
    """
    tb = casatools.table()

    tb.open(msfile, nomodify=True)
    try:
        # Transpose so the row axis comes first and can be indexed with the
        # row lists stored in baseline_dict.
        flag = tb.getcol('FLAG').T
    finally:
        tb.close()

    tb.open(msfile + '/ANTENNA', nomodify=True)
    try:
        station_names = tb.getcol('NAME')
    finally:
        tb.close()

    # Iterate over the baselines present in the MS rather than a guessed
    # antenna count (POSITION.shape[0] is the coordinate axis, not nant).
    for (a0, a1), rows in baseline_dict.items():
        if a1 <= a0:
            continue
        both_above = ((elevation[a1] > elevation_limit)
                      & (elevation[a0] > elevation_limit))
        flag_mask = np.invert(both_above)
        # Broadcast the per-row mask over the remaining two axes.
        flag[rows] = flag_mask.reshape((flag_mask.shape[0], 1, 1))

    tb.open(msfile, nomodify=False)
    try:
        tb.putcol("FLAG", flag.T)
    finally:
        tb.close()

    # Must run after putcol above, otherwise the stale FLAG array would
    # overwrite the flags written by flagdata.
    if ("JB" in station_names) and ("M2" in station_names):
        flagdata(vis=msfile, mode='manual', antenna="JB&M2")
def load_antenna_delays(ant_delay_table, nant, npol=2):
    """Load antenna delays from a CASA calibration table.

    Parameters
    ----------
    ant_delay_table : str
        The full path to the calibration table.
    nant : int
        The number of antennas.
    npol : int
        Ignored: the number of polarizations is taken from the first axis of
        the table's FPARAM column. Kept for backward compatibility.

    Returns
    -------
    ndarray
        The relative delay per baseline in nanoseconds. Baselines are in
        anti-casa order including autocorrelations: for i = 0..nant-1 and
        j = 0..i the pairs (0,0), (1,0), (1,1), (2,0), ... Dimensions
        (nbaselines, npol) with nbaselines = nant * (nant + 1) // 2.
    """
    tb = cc.table()
    tb.open(ant_delay_table)
    try:
        antenna_delays = tb.getcol('FPARAM')
    finally:
        tb.close()
    npol = antenna_delays.shape[0]
    antenna_delays = antenna_delays.reshape(npol, -1, nant)

    # Row-major lower-triangle indices give exactly the (i, j <= i) order of
    # a nested `for i: for j in range(i + 1)` loop, one row per baseline.
    ant_i, ant_j = np.tril_indices(nant)
    bl_delays = np.zeros((ant_i.size, npol))
    # NOTE(review): sign convention (delay_j - delay_i vs. the reverse) is
    # unverified.
    bl_delays[:] = (antenna_delays[:, 0, ant_j]
                    - antenna_delays[:, 0, ant_i]).T
    return bl_delays
def data(self):
    """Return the column values, reading them from the table on first use.

    The values are cached on ``self._data``; later calls return the cached
    object without touching the table again.
    """
    if not hasattr(self, '_data'):
        tb = table()
        tb.open(self.filename)
        try:
            self._data = tb.getcol(self.colname)
        finally:
            # Release the table even if the column cannot be read.
            tb.close()
    return self._data
def get_mosaic_centre(ms_name, return_string=True, field_name='M33'):
    '''
    Assuming a fully sampled mosaic, take the median as the phase centre.

    Parameters
    ----------
    ms_name : str
        MeasurementSet whose FIELD subtable is read.
    return_string : bool, optional
        If True, return an ``"ICRS <hms> <dms>"`` string; otherwise return
        the median (RA, Dec) as degree quantities.
    field_name : str or None, optional
        Keep only fields whose NAME contains this substring; ``None`` uses
        every field.

    Raises
    ------
    ValueError
        If no field name contains ``field_name``.
    '''
    try:
        # CASA 6
        import casatools
        tb = casatools.table()
    except ImportError:
        try:
            from taskinit import tbtool
            tb = tbtool()
        except ImportError as err:
            raise ImportError("Could not import CASA (casac).") from err

    tb.open(ms_name + "/FIELD")
    try:
        # Use the first polynomial term explicitly instead of squeeze() so
        # a single-field MS still yields 1-D RA/Dec arrays.
        ptgs = tb.getcol("PHASE_DIR")[:, 0, :]
        field_names = tb.getcol('NAME') if field_name is not None else None
    finally:
        tb.close()

    ras, decs = (ptgs * u.rad).to(u.deg)

    if field_name is not None:
        valids = np.array([field_name in name for name in field_names],
                          dtype=bool)
        ras = ras[valids]
        decs = decs[valids]
        if ras.size == 0:
            raise ValueError("No fields with given sourceid.")

    med_ra = np.median(ras)
    med_dec = np.median(decs)

    if return_string:
        med_ptg = SkyCoord(med_ra, med_dec, frame='icrs')
        ptg_str = "ICRS " + med_ptg.to_string('hmsdms')
        # tclean was rejecting this b/c of a string type
        # change? Anyways this seems to fix it.
        return str(ptg_str)

    return med_ra, med_dec
def __init__(self, filename, ia_kwargs=None):
    """Inspect a CASA image and record its shape, dtype and tile shape.

    Parameters
    ----------
    filename : str
        Path to the CASA image.
    ia_kwargs : dict, optional
        Stored as ``self.ia_kwargs``. Defaults to a fresh empty dict per
        instance; a shared ``{}`` default would leak mutations between
        objects.

    Raises
    ------
    ImportError
        If neither ``casatools`` nor ``taskinit`` is available.
    IOError
        If CASA reports that the image path does not exist.
    """
    try:
        import casatools
        self.iatool = casatools.image
        tb = casatools.table()
    except ImportError:
        try:
            from taskinit import iatool, tbtool
            self.iatool = iatool
            tb = tbtool()
        except ImportError as err:
            raise ImportError(
                "Could not import CASA (casac) and therefore cannot read CASA .image files"
            ) from err

    self.ia_kwargs = {} if ia_kwargs is None else ia_kwargs
    self.filename = filename

    self._cache = {}

    log.debug("Creating ArrayLikeCasa object")

    # try to trick CASA into destroying the ia object
    def getshape():
        ia = self.iatool()
        # use the ia tool to get the file contents
        try:
            ia.open(self.filename, cache=False)
        except AssertionError as ex:
            if 'must be of cReqPath type' in str(ex):
                raise IOError("File {0} not found. Error was: {1}".format(
                    self.filename, str(ex))) from ex
            raise
        try:
            self.shape = tuple(ia.shape()[::-1])
            self.dtype = np.dtype(ia.pixeltype())
        finally:
            # Release the image tool even if querying it fails.
            ia.done()
            ia.close()

    getshape()

    self.ndim = len(self.shape)

    tb.open(self.filename)
    try:
        dminfo = tb.getdminfo()
    finally:
        tb.done()

    # unclear if this is always the correct callspec!!!
    # (transpose requires this be backwards)
    self.chunksize = dminfo['*1']['SPEC']['DEFAULTTILESHAPE'][::-1]

    log.debug("Finished with initialization of ArrayLikeCasa object")
def make_baseline_dictionary(msfile):
    """Map every cross-correlation baseline to its row indices in an MS.

    Parameters
    ----------
    msfile : str
        Path to the MeasurementSet.

    Returns
    -------
    dict
        ``{(ant_a, ant_b): row_indices}`` for every pair of antenna numbers
        seen in ANTENNA1/ANTENNA2 with ``ant_b > ant_a``. ``row_indices`` is
        the integer array of main-table rows with ANTENNA1 == ant_a and
        ANTENNA2 == ant_b; it may be empty.
    """
    tb = casatools.table()
    tb.open(msfile, nomodify=True)
    try:
        A0 = tb.getcol('ANTENNA1')
        A1 = tb.getcol("ANTENNA2")
    finally:
        tb.close()
    ant_unique = np.unique(np.hstack((A0, A1)))
    return {
        (x, y): np.where((A0 == x) & (A1 == y))[0]
        for x in ant_unique
        for y in ant_unique
        if y > x
    }
def vel_to_chan(msfile, field, obsid, spw, restfreq, vel):
    """
    Identifies the channel(s) corresponding to input LSRK velocities.
    Useful for choosing which channels to split out or flag if a line is
    expected to be present.

    Args:
        msfile (string): name of measurement set
        field (string): field name (currently unused)
        obsid (int): Observation ID corresponding to the selected spectral
            window; its TIME_RANGE start is used as the frequency-conversion
            epoch
        spw (int): Spectral window number
        restfreq (float): Rest frequency [Hz]
        vel (float or array of floats): input velocity in LSRK frame [km/s]

    Returns:
        (array of int) or (int): channel number(s) most closely
        corresponding to the input LSRK velocity; an array input yields an
        integer array of the same shape.
    """
    tb = table()

    tb.open(msfile + "/SPECTRAL_WINDOW")
    try:
        chanfreqs = tb.getcol("CHAN_FREQ", startrow=spw, nrow=1)
    finally:
        tb.close()

    tb.open(msfile + "/OBSERVATION")
    try:
        obstime = np.squeeze(tb.getcol("TIME_RANGE", startrow=obsid,
                                       nrow=1))[0]
    finally:
        tb.close()

    nchan = len(chanfreqs)

    mstool = ms()
    mstool.open(msfile)
    try:
        lsrkfreqs = mstool.cvelfreqs(
            spwids=[spw],
            mode="channel",
            nchan=nchan,
            obstime=str(obstime) + "s",
            start=0,
            outframe="LSRK",
        )
    finally:
        mstool.close()

    # convert to LSRK velocities [km/s]
    chanvelocities = (restfreq - lsrkfreqs) / restfreq * cc_kms

    if isinstance(vel, np.ndarray):
        # Nearest channel for every requested velocity at once; the result
        # is an integer index array with the shape of `vel`.
        diffs = np.abs(chanvelocities[np.newaxis, :] - vel.reshape(-1, 1))
        return np.argmin(diffs, axis=1).reshape(vel.shape)
    return np.argmin(np.abs(chanvelocities - vel))
def __init__(self, filename, colname, shape=None):
    """Read one table column into ``self._data``.

    Parameters
    ----------
    filename : str
        Path to the table (e.g. a MeasurementSet).
    colname : str
        Column to read.
    shape : tuple, optional
        If given, the column values are broadcast (read-only view) to this
        shape; ``np.broadcast_to`` raises ValueError if that is impossible.

    Raises
    ------
    RuntimeError, ValueError
        Propagated from the table read / broadcast. The table is closed in
        either case.
    """
    self.filename = filename
    self.colname = colname
    # gotta be a better way, right?
    # (skips the read if `_data` already exists, e.g. as a class attribute)
    if not hasattr(self, '_data'):
        tb = table()
        tb.open(self.filename)
        try:
            self._data = tb.getcol(self.colname)
            if shape is not None:
                self._data = np.broadcast_to(self._data, shape)
        finally:
            tb.close()
def parallacticAngle(msfile, times):
    """Compute per-antenna parallactic angles at the given times.

    Parameters
    ----------
    msfile : str
        MeasurementSet. PHASE_DIR is read from its FIELD subtable and
        POSITION/MOUNT from its ANTENNA subtable.
        NOTE(review): PHASE_DIR is squeezed and indexed as (ra, dec), which
        assumes a single field — confirm for multi-field MSs.
    times : ndarray
        Times in seconds; ``times[0]`` is used as the UTC reference epoch.

    Returns
    -------
    ndarray
        Shape ``(nant, len(times))``, in degrees. Antennas whose MOUNT is
        "equatorial" (compared case-insensitively) get zeros.
    """
    tb = casatools.table()
    qa = casatools.quanta()
    me = casatools.measures()
    time_unique = times

    tb.open(msfile + '/FIELD', nomodify=True)
    try:
        direction = np.squeeze(tb.getcol('PHASE_DIR'))
    finally:
        tb.close()

    tb.open(msfile + '/ANTENNA', nomodify=True)
    try:
        pos = tb.getcol('POSITION').T  # (nant, 3)
        mount = tb.getcol('MOUNT')
    finally:
        tb.close()
    Nant = pos.shape[0]

    ra = qa.quantity(direction[0], 'rad')
    dec = qa.quantity(direction[1], 'rad')
    pointing = me.direction('j2000', ra, dec)
    start_time = me.epoch('utc', qa.quantity(time_unique[0], 's'))
    me.doframe(start_time)

    parallactic_ant_matrix = np.zeros((Nant, time_unique.shape[0]))

    def antenna_para(antenna):
        x = qa.quantity(pos[antenna, 0], 'm')
        y = qa.quantity(pos[antenna, 1], 'm')
        z = qa.quantity(pos[antenna, 2], 'm')
        position = me.position('wgs84', x, y, z)
        me.doframe(position)
        sec2rad = 2 * np.pi / (24 * 3600.)
        # Hour angle at the reference epoch, advanced linearly with elapsed
        # time at 2*pi per 86400 s.
        hour_angle = me.measure(pointing, 'HADEC')['m0']['value'] +\
            (time_unique - time_unique.min()) * sec2rad
        # Spherical-Earth approximation of latitude from the z coordinate.
        earth_radius = 6371000.0
        latitude = np.arcsin(pos[antenna, 2] / earth_radius)
        return np.arctan2(
            np.sin(hour_angle) * np.cos(latitude),
            (np.cos(direction[1]) * np.sin(latitude) -
             np.cos(hour_angle) * np.cos(latitude) * np.sin(direction[1])))

    for i in range(Nant):
        # MOUNT strings may be "equatorial" or "EQUATORIAL" depending on how
        # the MS was written; compare case-insensitively.
        if str(mount[i]).strip().lower() == 'equatorial':
            parallactic_ant_matrix[i] = np.zeros(time_unique.shape)
        else:
            parallactic_ant_matrix[i] = antenna_para(i) * (180. / np.pi)
    return parallactic_ant_matrix
def get_ms_info(self):
    """Recover antenna, spectral-window and source metadata from the MS.

    Sets ``self._antennas``, ``self._nsubbands``, ``self._bandwidth``,
    ``self._nchannels``, ``self._frequency`` and ``self._sources``. It then
    checks that every configured target/calibrator was observed and that
    every reference antenna is part of the array.

    Raises
    ------
    AssertionError
        If a configured source was not observed, or a reference antenna is
        not in the array.
    """
    tb = casatools.table()

    tb.open(self.msfile)
    try:
        # Subtable paths are stored as keywords of the main table.
        tb_ant = tb.getkeyword('ANTENNA')
        tb_spw = tb.getkeyword('SPECTRAL_WINDOW')
        tb_src = tb.getkeyword('FIELD')
    finally:
        tb.close()

    tb.open(tb_ant.replace('Table: ', ''))
    try:
        self._antennas = list(tb.getcol('NAME'))
    finally:
        tb.close()

    tb.open(tb_spw.replace('Table: ', ''))
    try:
        total_bw = tb.getcol('TOTAL_BANDWIDTH')
        self._nsubbands = len(total_bw)
        self._bandwidth = np.sum(total_bw) * u.Hz
        self._nchannels = int(tb.getcol('NUM_CHAN')[0])
        self._frequency = np.mean(
            tb.getcol('CHAN_FREQ').reshape(
                self.nchannels * self.nsubbands)) * u.Hz
    finally:
        tb.close()

    tb.open(tb_src.replace('Table: ', ''))
    try:
        src_coords = tb.getcol('PHASE_DIR')
        src_names = tb.getcol('NAME')
    finally:
        tb.close()

    tmp_src = {}
    for i, src_name in enumerate(src_names):
        if src_name in self.targets:
            src_type = SourceType.target
        elif src_name in self.phase_calibrators:
            src_type = SourceType.calibrator
        elif src_name in self.bandpass_calibrators:
            src_type = SourceType.bandpass
        else:
            src_type = SourceType.other

        tmp_src[src_name] = Source(name=src_name,
                                   coordinates=coord.SkyCoord(
                                       *src_coords[:, 0, i], unit=u.rad),
                                   source_type=src_type)

    self._sources = tmp_src

    # Sanity checks on the user configuration. Raised explicitly (not with
    # `assert`) so they still run under `python -O`; the exception type is
    # unchanged for existing callers.
    for source_group in (self.targets, self.phase_calibrators,
                         self.bandpass_calibrators):
        for a_source in source_group:
            if a_source not in self.sources:
                raise AssertionError(
                    f"The source {a_source} defined in the input file was not observed. "
                    f"The available sources are {', '.join(list(self.sources.keys()))}.")

    for an_antenna in self.refants:
        if an_antenna not in self.antennas:
            raise AssertionError(
                f"The ref. antenna {an_antenna} defined in the input file is not "
                f"in the array for {self.project_name}. "
                f"The available antennas are {', '.join(self.antennas)}.")
def get_spw_global_fringe(caltable: str) -> int:
    """Return the subband (spw) holding the solutions of a global fringe run.

    When spws are combined in fringe, the solutions are written to the
    first subband that has data (usually spw 0, but not always). This reads
    the calibration table's SPECTRAL_WINDOW subtable and returns the value
    found there.

    NOTE(review): the value returned is the first entry of MEAS_FREQ_REF,
    which in the MS definition is a frequency reference-frame code rather
    than a spw index. Confirm this is intended; SPECTRAL_WINDOW_ID in the
    caltable main table may be what is wanted.
    """
    table_tool = casatools.table()
    table_tool.open(caltable)
    spw_subtable = table_tool.getkeyword('SPECTRAL_WINDOW').replace('Table: ', '')
    table_tool.open(spw_subtable)
    first_entry = table_tool.getcol('MEAS_FREQ_REF')[0]
    spw_index = int(first_entry)
    table_tool.close()
    table_tool.close()
    return spw_index
def match_to_antenna_nos(evn_SEFD, msfile):
    """Re-key per-station SEFD and diameter values by MS antenna index.

    Parameters
    ----------
    evn_SEFD : dict
        Station name -> sequence whose element 0 is the SEFD and element 1
        the dish diameter.
    msfile : str
        MeasurementSet whose ANTENNA/NAME column defines the index order.

    Returns
    -------
    tuple of dict
        ``(sefd_by_index, diameter_by_index)``, both keyed by antenna row.

    Raises
    ------
    KeyError
        If an antenna name in the MS has no entry in ``evn_SEFD``.
    """
    tb = casatools.table()
    tb.open('%s/ANTENNA' % msfile)
    try:
        names = tb.getcol('NAME')
    finally:
        tb.close()
    print(evn_SEFD)
    evn_SEFD_2 = {i: evn_SEFD[name][0] for i, name in enumerate(names)}
    evn_diams = {i: evn_SEFD[name][1] for i, name in enumerate(names)}
    return evn_SEFD_2, evn_diams
def test_getdesc(tmp_path, filename):
    casa_filename = str(tmp_path / 'casa.image')
    make_casa_testimage(filename, casa_filename)

    # Reference table description straight from CASA's table tool.
    reader = table()
    reader.open(casa_filename)
    expected = reader.getdesc()
    reader.close()

    # Our implementation must render identically (pformat sidesteps
    # equality issues with NumPy arrays nested in the dicts).
    assert pformat(getdesc(casa_filename)) == pformat(expected)
def add_hera_obs_pos():
    """Adding HERA observatory position (at some point as PAPER_SA)

    Only needed for the older versions of CASA. Clones the PAPER_SA row of
    CASA's geodetic Observatories table and renames the copy to HERA. Does
    nothing if a HERA row already exists.

    Raises
    ------
    ValueError
        If HERA is absent and there is no PAPER_SA row to copy.
    """
    obstablename = os.path.dirname(casadata.__file__) + \
        '/__data__/geodetic/Observatories/'
    tbl = table()
    tbl.open(obstablename, nomodify=False)
    try:
        names = tbl.getcol('Name')
        if not (names == 'HERA').any():
            paper_rows = (names == 'PAPER_SA').nonzero()[0]
            if paper_rows.size == 0:
                raise ValueError(
                    "Neither HERA nor PAPER_SA found in %s" % obstablename)
            # copyrows expects a scalar start row.
            tbl.copyrows(obstablename, startrowin=int(paper_rows[0]),
                         startrowout=-1, nrow=1)
            tbl.putcell('Name', tbl.nrows() - 1, 'HERA')
    finally:
        tbl.close()
def test_generic_table_read(tmp_path):
    # NOTE: only the table metadata (column descriptions and keywords) is
    # compared here; the cell values themselves are not checked yet.
    filename_fits = str(tmp_path / 'generic.fits')
    filename_casa = str(tmp_path / 'generic.image')

    columns = [
        ('short', np.arange(3, dtype=np.int16)),
        ('ushort', np.arange(3, dtype=np.uint16)),
        ('int', np.arange(3, dtype=np.int32)),
        ('uint', np.arange(3, dtype=np.uint32)),
        ('float', np.arange(3, dtype=np.float32)),
        ('double', np.arange(3, dtype=np.float64)),
        ('complex', np.array([1 + 2j, 3.3 + 8.2j, -1.2 - 4.2j],
                             dtype=np.complex64)),
        ('dcomplex', np.array([3.33 + 4.22j, 3.3 + 8.2j, -1.2 - 4.2j],
                              dtype=np.complex128)),
        ('str', np.array(['reading', 'casa', 'images'])),
        # A final int column checks that parsing resumes correctly after
        # the complex-column metadata.
        ('int2', np.arange(3, dtype=np.int32)),
    ]
    t = Table()
    for colname, values in columns:
        t[colname] = values
    t.write(filename_fits)

    tb = table()
    tb.fromfits(filename_casa, filename_fits)
    tb.close()

    # Reuse the column data as keywords of every type: the first element as
    # a scalar keyword and the whole column as an array keyword.
    keywords = {
        'scalars': {'s_' + name: t[name][0] for name in t.colnames},
        'arrays': {'a_' + name: t[name] for name in t.colnames},
    }

    tb.open(filename_casa)
    tb.putkeywords(keywords)
    tb.flush()
    tb.close()

    desc_actual = getdesc(filename_casa)

    tb.open(filename_casa)
    desc_reference = tb.getdesc()
    tb.close()

    assert pformat(desc_actual) == pformat(desc_reference)
def flag_quack_integrations(myvis, num_ints=2.5):
    """Quack-flag the beginning of each scan by a multiple of the dump time.

    Parameters
    ----------
    myvis : str
        Path to the MeasurementSet.
    num_ints : float, optional
        Number of integrations to flag. The quack interval is
        ``num_ints * INTERVAL`` of the first main-table row.
    """
    tb = table()
    tb.open(myvis)
    try:
        # Only the first row's INTERVAL is used, so read just that row
        # rather than the whole column.
        int_time = tb.getcol('INTERVAL', startrow=0, nrow=1)[0]
    finally:
        tb.close()

    this_quackinterval = num_ints * int_time

    flagdata(vis=myvis, flagbackup=False, mode='quack', quackmode='beg',
             quackincrement=False, quackinterval=this_quackinterval)
def add_pt_src(msfile, pt_flux):
    """Fill the model column with a point source at the first SOURCE position.

    A temporary component list ``<msfile>.cl`` holding one point source of
    ``pt_flux`` Jy is created, Fourier-transformed into the MS with ``ft``
    (``usescratch=True``), and then deleted.

    Parameters
    ----------
    msfile : str
        Path to the MeasurementSet.
    pt_flux : float
        Flux density of the point source in Jy.
    """
    import shutil

    tb = casatools.table()
    cl = casatools.componentlist()

    tb.open(msfile + '/SOURCE')
    try:
        # Direction (RA, Dec in radians) of the first SOURCE row.
        direc = tb.getcol('DIRECTION').T[0]
    finally:
        tb.close()

    direction = 'J2000 %srad %srad' % (direc[0], direc[1])
    complist = '%s.cl' % msfile
    print(direction)
    cl.addcomponent(flux=pt_flux, fluxunit='Jy', shape='point', dir=direction)
    # Remove any stale component list without shelling out (paths with
    # spaces or shell metacharacters would break `rm -r`).
    shutil.rmtree(complist, ignore_errors=True)
    cl.rename(complist)
    cl.close()
    ft(vis=msfile, complist=complist, usescratch=True)
    #uvsub(vis=msfile,reverse=True)
    shutil.rmtree(complist, ignore_errors=True)
def calc_pb_corr(msfile, diam_ants, single_freq):
    """Scale MODEL_DATA amplitudes by per-antenna primary-beam factors.

    The pointing offset is parsed from the MS name ``<x>_<a>_<b>.ms`` as
    ``a + b / 60`` (presumably arcmin + arcsec / 60 — confirm the naming
    scheme). For each antenna in ``diam_ants`` and each spectral window a
    factor ``sqrt(calc_hpbw(offset, diameter, freq))`` is computed. Each
    visibility amplitude is multiplied by the factors of its two antennas;
    phases are unchanged. MODEL_DATA is written back in place.

    Parameters
    ----------
    msfile : str
        Path to the MeasurementSet (its name encodes the offset, see above).
    diam_ants : dict
        Antenna index -> dish diameter, as passed to ``calc_hpbw``.
    single_freq : bool
        If not False, use the mean over all channel frequencies for every
        spectral window.
    """
    tb = casatools.table()

    name_parts = msfile.split('_')
    arcmin_off = float(name_parts[1]) + (
        float(name_parts[2].split('.ms')[0]) / 60.)

    tb.open('%s/SPECTRAL_WINDOW' % msfile)
    try:
        nspw = len(tb.getcol("MEAS_FREQ_REF"))
        chan_freqs = tb.getcol('CHAN_FREQ').T
    finally:
        tb.close()
    if single_freq != False:  # noqa: E712 -- keep the original comparison
        freq = np.mean(chan_freqs)
        chan_freqs = [freq] * nspw

    diams_ants2 = {}
    for i in diam_ants.keys():
        pb_freq = {}
        for j in range(nspw):
            pb_freq[str(j)] = np.sqrt(
                calc_hpbw(x=arcmin_off, diam=diam_ants[i],
                          freq=chan_freqs[j]))
        print(pb_freq)
        diams_ants2[i] = pb_freq

    datacolumn = 'MODEL_DATA'
    tb.open('%s' % msfile, nomodify=True)
    try:
        data = tb.getcol('%s' % datacolumn)
        antenna1 = tb.getcol('ANTENNA1')
        antenna2 = tb.getcol('ANTENNA2')
        # NOTE(review): DATA_DESC_ID is used directly as the spw index,
        # which assumes a 1:1 DATA_DESCRIPTION -> spw mapping.
        spw_id = tb.getcol('DATA_DESC_ID')
    finally:
        tb.close()

    for i in range(len(antenna1)):
        spw_key = str(spw_id[i])
        amps, phase = R2P(data[:, :, i])
        amps = amps * (diams_ants2[antenna1[i]][spw_key] *
                       diams_ants2[antenna2[i]][spw_key])
        data[:, :, i] = P2R(amps, phase)

    tb.open('%s' % msfile, nomodify=False)
    try:
        tb.putcol('%s' % datacolumn, data)
    finally:
        tb.close()
def copy_pols(msfile, antenna, pol, newpol):
    """Overwrite one correlation of DATA with another on an antenna's baselines.

    For every row where ANTENNA1 or ANTENNA2 equals ``antenna``, the
    correlation whose CORR_TYPE equals ``newpol`` is replaced by the
    correlation whose CORR_TYPE equals ``pol``. The MS is processed in
    chunks of 100000 rows to bound memory use.

    Parameters
    ----------
    msfile : str
        Path to the MeasurementSet (modified in place).
    antenna : int
        Antenna index as stored in ANTENNA1/ANTENNA2.
    pol, newpol : int
        Source and destination correlation-type codes, looked up in the
        POLARIZATION table's CORR_TYPE column.
    """
    tb = casatools.table()
    tb.open(msfile + '/POLARIZATION')
    try:
        pol_code = tb.getcol('CORR_TYPE')
    finally:
        tb.close()
    # Translate correlation-type codes into positions on the correlation axis.
    pol = np.where(pol_code == pol)[0][0]
    newpol = np.where(pol_code == newpol)[0][0]

    chunk = 100000
    tb.open(msfile, nomodify=False)
    try:
        nrows = tb.nrows()
        for start in progressbar(list(range(0, nrows, chunk)), '', 50):
            # The final chunk may be shorter; this is also correct when
            # nrows is an exact multiple of the chunk size.
            nrow = min(chunk, nrows - start)
            vis = tb.getcol('DATA', startrow=start, nrow=nrow, rowincr=1)
            ant1 = tb.getcol('ANTENNA1', startrow=start, nrow=nrow, rowincr=1)
            ant2 = tb.getcol('ANTENNA2', startrow=start, nrow=nrow, rowincr=1)
            rows = (ant1 == antenna) | (ant2 == antenna)
            vis[newpol, :, rows] = vis[pol, :, rows]
            tb.putcol('DATA', vis, startrow=start, nrow=nrow, rowincr=1)
    finally:
        tb.close()
def add_noise(msfile, datacolumn, evn_SEFD, adjust_time=1.0):
    """Replace a data column with simulated noise and write matching weights.

    For every row, ``calc_sefd`` gives a noise scale from the two antennas'
    ``evn_SEFD`` entries, the mean EXPOSURE (times ``adjust_time``), the
    mean channel width and an efficiency of 0.88. Visibilities are set to
    Gaussian amplitudes with that scale and uniformly random phases.

    Parameters
    ----------
    msfile : str
        Path to the MeasurementSet (modified in place).
    datacolumn : {'CORRECTED_DATA', 'DATA'}
        Column to overwrite. CORRECTED_DATA is paired with WEIGHT
        (set to 1 / scale**2); DATA is paired with SIGMA (set to scale).
    evn_SEFD : dict
        Antenna index -> SEFD, as passed to ``calc_sefd``.
    adjust_time : float, optional
        Factor applied to the mean integration time.

    Raises
    ------
    TypeError
        If ``datacolumn`` is not one of the supported names.
    """
    if datacolumn == 'CORRECTED_DATA':
        weightnames = 'WEIGHT'
    elif datacolumn == 'DATA':
        weightnames = 'SIGMA'
    else:
        raise TypeError(
            "datacolumn must be 'DATA' or 'CORRECTED_DATA', got %r"
            % (datacolumn,))

    tb = casatools.table()
    tb.open('%s' % msfile, nomodify=True)
    try:
        data = tb.getcol('%s' % datacolumn)
        weights = tb.getcol('%s' % weightnames)
        antenna1 = tb.getcol('ANTENNA1')
        antenna2 = tb.getcol('ANTENNA2')
        tint = np.average(tb.getcol('EXPOSURE'))
    finally:
        tb.close()

    if adjust_time != 1.0:
        tint = tint * adjust_time

    tb.open('%s/SPECTRAL_WINDOW' % msfile, nomodify=True)
    try:
        chan_width = np.average(tb.getcol('CHAN_WIDTH'))
    finally:
        tb.close()
    print(chan_width, tint)

    for i in range(len(antenna1)):
        sefd = calc_sefd(evn_SEFD[antenna1[i]], evn_SEFD[antenna2[i]], tint,
                         chan_width, 0.88)
        vis_shape = np.shape(data[:, :, i])
        amps = np.random.normal(0., sefd, vis_shape)
        phase = ((np.pi + np.pi) * np.random.random_sample(vis_shape)) - np.pi
        data[:, :, i] = P2R(amps, phase)
        # `sefd` is the scale of the simulated noise above. SIGMA stores
        # that rms directly, WEIGHT stores 1 / rms**2 (MS convention
        # WEIGHT = 1 / SIGMA**2).
        if weightnames == 'SIGMA':
            weights[:, i] = np.ones(weights[:, i].shape) * sefd
        else:
            weights[:, i] = np.ones(weights[:, i].shape) / (sefd**2)

    tb.open('%s' % msfile, nomodify=False)
    try:
        tb.putcol('%s' % datacolumn, data)
        tb.putcol('%s' % weightnames, weights)
    finally:
        tb.close()
def test_getdminfo(tmp_path, shape):
    filename = str(tmp_path / 'test.image')

    # Write a random-valued CASA image of the requested shape.
    pixels = np.random.random(shape)
    imtool = image()
    imtool.fromarray(outfile=filename, pixels=pixels, log=False)
    imtool.close()

    # Reference data-manager info as reported by CASA itself.
    reader = table()
    reader.open(filename)
    expected = reader.getdminfo()
    reader.close()

    result = getdminfo(filename)

    # Our output records endianness, which CASA's does not.
    del result['*1']['BIGENDIAN']

    # Compare rendered text: dict equality is ambiguous when NumPy arrays
    # are nested inside.
    assert pformat(result) == pformat(expected)
def _export_qa_table(plotms, ms_name, field, filename, **plot_kwargs):
    """Export one plotms QA table for `field`, skipping existing files.

    The arguments shared by every QA export (selection, averaging switch,
    output options) are filled in here. `plot_kwargs` supplies the axes,
    averaging intervals and labels, and may override any shared default
    (e.g. ``ydatacolumn``).
    """
    if os.path.exists(filename):
        casalog.post(message="File {} already exists. Skipping".format(
            filename), origin='make_qa_tables')
        return

    kwargs = dict(vis=ms_name, ydatacolumn='corrected', selectdata=True,
                  field=field, scan="", spw="", correlation="",
                  averagedata=True, transform=False, extendflag=False,
                  plotrange=[], showmajorgrid=False, showminorgrid=False,
                  plotfile=filename, overwrite=True, showgui=False)
    kwargs.update(plot_kwargs)
    plotms(**kwargs)


def make_qa_tables(
    ms_name,
    output_folder='scan_plots_txt',
    outtype='txt',
    overwrite=True,
    chanavg=4096,
):
    '''
    Specifically for saving txt tables. Replace the scan loop in
    `make_qa_scan_figures` to make fewer but larger tables: one set per
    field, combining all scans and SPWs.

    Parameters
    ----------
    ms_name : str
        MS name.
    output_folder : str, optional
        Folder for the exported tables; created if missing.
    outtype : str, optional
        Extension of the exported files (passed to plotms via ``plotfile``).
    overwrite : bool, optional
        If True, remove the existing contents of ``output_folder`` first.
        Otherwise existing files are kept and skipped.
    chanavg : int, optional
        Channel averaging used for the time, uv-distance, amp-phase,
        residual and per-antenna tables.
    '''
    import glob
    import shutil

    # Will need to updated for CASA 6
    # from taskinit import tb
    from casatools import table
    from casaplotms import plotms

    tb = table()

    casalog.post("Running make_qa_tables to export txt files for QA.")
    print("Running make_qa_tables to export txt files for QA.")

    # Make folder for the tables
    if not os.path.exists(output_folder):
        os.mkdir(output_folder)
    elif overwrite:
        casalog.post(
            message="Removing plot tables in {}".format(output_folder),
            origin='make_qa_tables')
        print("Removing plot tables in {}".format(output_folder))
        # Same effect as `rm -r <folder>/*`, without a shell (folder names
        # with spaces or shell metacharacters would otherwise break).
        pattern = os.path.join(glob.escape(output_folder), '*')
        for path in glob.glob(pattern):
            if os.path.isdir(path) and not os.path.islink(path):
                shutil.rmtree(path)
            else:
                os.remove(path)
    else:
        casalog.post("{} already exists. Will skip existing files.".format(
            output_folder))

    # Read the field names
    tb.open(os.path.join(ms_name, "FIELD"))
    try:
        names = tb.getcol('NAME')
        numFields = tb.nrows()
    finally:
        tb.close()

    # Intent names
    tb.open(os.path.join(ms_name, 'STATE'))
    try:
        intentcol = tb.getcol('OBS_MODE')
    finally:
        tb.close()

    # Determine the fields that are calibrators: any scan intent of the
    # field containing "CALIBRATE".
    is_calibrator = np.empty((numFields, ), dtype='bool')
    tb.open(ms_name)
    try:
        for ii in range(numFields):
            subtable = tb.query('FIELD_ID==%s' % ii)
            try:
                state_ids = np.unique(subtable.getcol("STATE_ID"))
            finally:
                subtable.close()
            is_calibrator[ii] = any("CALIBRATE" in intent
                                    for intent in intentcol[state_ids])
    finally:
        tb.close()

    casalog.post(message="Fields are: {}".format(names),
                 origin='make_qa_tables')
    casalog.post(message="Calibrator fields are: {}".format(
        names[is_calibrator]), origin='make_qa_tables')
    print("Fields are: {}".format(names))
    print("Calibrator fields are: {}".format(names[is_calibrator]))

    avg = str(chanavg)

    # Tables exported for every field: (file-name suffix, plotms arguments).
    all_field_tables = [
        ('amp_time', dict(xaxis='time', yaxis='amp', avgchannel=avg,
                          avgbaseline=True, xlabel='Time', ylabel='Amp')),
        ('amp_chan', dict(xaxis='chan', yaxis='amp', avgchannel="1",
                          avgtime="1e8", avgbaseline=True,
                          xlabel='Channel', ylabel='Amp')),
        ('amp_uvdist', dict(xaxis='uvdist', yaxis='amp', avgchannel=avg,
                            avgtime='1e8', avgbaseline=False,
                            xlabel='uv-dist', ylabel='Amp')),
    ]

    # Tables exported only for calibrator fields.
    # NOTE(review): confirm that the model-residual and per-antenna tables
    # are intended for calibrators only.
    calibrator_tables = [
        ('phase_time', dict(xaxis='time', yaxis='phase', avgchannel=avg,
                            avgbaseline=True, xlabel='Time', ylabel='Phase')),
        ('phase_chan', dict(xaxis='chan', yaxis='phase', avgchannel="1",
                            avgtime="1e8", avgbaseline=True,
                            xlabel='Chan', ylabel='Phase')),
        ('phase_uvdist', dict(xaxis='uvdist', yaxis='phase', avgchannel=avg,
                              avgtime='1e8', avgbaseline=False,
                              xlabel='uv-dist', ylabel='Phase')),
        ('amp_phase', dict(xaxis='amp', yaxis='phase', avgchannel=avg,
                           avgtime='1e8', avgbaseline=False,
                           xlabel='Amp', ylabel='Phase')),
        # Residual against the point-source calibrator model: checks how
        # good the model is.
        ('ampresid_uvwave', dict(xaxis='uvwave', yaxis='amp',
                                 ydatacolumn='corrected-model_scalar',
                                 avgchannel=avg, avgtime='1e8',
                                 avgbaseline=False,
                                 xlabel='uv-wave', ylabel='Amp')),
        # Per-antenna views to spot outlier antennas.
        ('amp_ant1', dict(xaxis='antenna1', yaxis='amp', avgchannel=avg,
                          avgtime='1e8', avgbaseline=False,
                          xlabel='antenna 1', ylabel='Amp')),
        ('phase_ant1', dict(xaxis='antenna1', yaxis='phase', avgchannel=avg,
                            avgtime='1e8', avgbaseline=False,
                            xlabel='antenna 1', ylabel='Phase')),
    ]

    def export(field, specs):
        for suffix, plot_kwargs in specs:
            filename = os.path.join(
                output_folder,
                'field_{0}_{1}.{2}'.format(field, suffix, outtype))
            _export_qa_table(plotms, ms_name, field, filename, **plot_kwargs)

    # Loop through fields. Make separate tables only for different targets.
    for ii in range(numFields):
        field = names[ii]
        casalog.post(message="On field {}".format(field),
                     origin='make_qa_tables')
        print("On field {}".format(field))

        export(field, all_field_tables)

        if is_calibrator[ii]:
            casalog.post("This is a calibrator. Exporting phase info, too.")
            print("This is a calibrator. Exporting phase info, too.")
            export(field, calibrator_tables)
def _plot_scan_qa(plotms, ms_name, spw_folder, outtype, field, scan, spw_num,
                  title, suffix, **plot_kwargs):
    """Make one per-scan plotms QA figure.

    Fills in the selection/output arguments shared by every figure of
    `make_qa_scan_figures`. The title becomes
    ``'<title>: Field <field> Scan <scan>'`` and the file is
    ``field_<field>_<suffix>_scan_<scan>.<outtype>`` in `spw_folder`.
    `plot_kwargs` supplies axes, averaging and labels.
    """
    kwargs = dict(
        vis=ms_name, ydatacolumn='corrected', selectdata=True, field=field,
        scan=str(scan), spw=str(spw_num), correlation="", averagedata=True,
        transform=False, extendflag=False, plotrange=[],
        title='{0}: Field {1} Scan {2}'.format(title, field, scan),
        showmajorgrid=False, showminorgrid=False,
        plotfile=os.path.join(
            spw_folder,
            'field_{0}_{1}_scan_{2}.{3}'.format(field, suffix, scan, outtype)),
        overwrite=True, showgui=False)
    kwargs.update(plot_kwargs)
    plotms(**kwargs)


def make_qa_scan_figures(ms_name, output_folder='scan_plots', outtype='png'):
    '''
    Make a series of plots per scan for QA and flagging purposes.

    For every SPW, field and scan (skipping SPW/scan combinations whose data
    are entirely flagged) amplitude plots are made. Scans whose field has a
    "CALIBRATE" intent also get phase plots.

    TODO: Add more settings here for different types of plots, etc.

    Parameters
    ----------
    ms_name : str
        MS name
    output_folder : str, optional
        Output plot folder name.
    outtype : str, optional
        Plot file extension. Existing ``*.<outtype>`` files in each SPW
        sub-folder are removed before plotting.
    '''
    import glob

    # Will need to updated for CASA 6
    from casatools import table
    from casaplotms import plotms

    tb = table()

    # SPWs to loop through
    tb.open(os.path.join(ms_name, "SPECTRAL_WINDOW"))
    try:
        spws = range(len(tb.getcol("NAME")))
    finally:
        tb.close()

    # Read the field names
    tb.open(os.path.join(ms_name, "FIELD"))
    try:
        names = tb.getcol('NAME')
        numFields = tb.nrows()
    finally:
        tb.close()

    # Intent names
    tb.open(os.path.join(ms_name, 'STATE'))
    try:
        intentcol = tb.getcol('OBS_MODE')
    finally:
        tb.close()

    # Per-scan properties keyed by scan number (and spw), so scan numbers do
    # not need to be contiguous or start at 1.
    field_scans = []
    scan_is_calibrator = {}
    scan_all_flagged = {}

    tb.open(ms_name)
    try:
        for ii in range(numFields):
            subtable = tb.query('FIELD_ID==%s' % ii)
            try:
                field_scan = np.unique(subtable.getcol('SCAN_NUMBER'))
                state_ids = np.unique(subtable.getcol("STATE_ID"))
            finally:
                subtable.close()
            field_scans.append(field_scan)

            # Is the intent for calibration?
            is_calib = any("CALIBRATE" in intent
                           for intent in intentcol[state_ids])
            for scan in field_scan:
                scan_is_calibrator[scan] = is_calib

            # Are any of the scans completely flagged?
            # NOTE(review): DATA_DESC_ID is assumed equal to the spw index.
            for spw in spws:
                for scan in field_scan:
                    scantable = tb.query(
                        "SCAN_NUMBER=={0} AND DATA_DESC_ID=={1}".format(
                            scan, spw))
                    try:
                        scan_all_flagged[(spw, scan)] = bool(
                            scantable.getcol("FLAG").all())
                    finally:
                        scantable.close()
    finally:
        tb.close()

    # Make folder for scan plots
    if not os.path.exists(output_folder):
        os.mkdir(output_folder)

    # TODO: change appropriately for line SPWs with many channels
    avg_chan = "1"

    # (title, file suffix, plotms arguments) for every scan.
    amp_plots = [
        ('Amp vs Time', 'amp',
         dict(xaxis='time', yaxis='amp', avgchannel=avg_chan,
              avgbaseline=True, xlabel='Time', ylabel='Amp')),
        ('Amp vs Chan', 'amp_chan',
         dict(xaxis='chan', yaxis='amp', avgchannel=avg_chan, avgtime="1e8",
              avgbaseline=True, xlabel='Channel', ylabel='Amp')),
        ('Amp vs UVDist', 'amp_uvdist',
         dict(xaxis='uvdist', yaxis='amp', avgchannel="4096", avgtime='1e8',
              avgbaseline=False, xlabel='uv-dist', ylabel='Amp')),
    ]
    # Extra plots for calibrator scans.
    phase_plots = [
        ('Phase vs Time', 'phase_time',
         dict(xaxis='time', yaxis='phase', avgbaseline=True,
              xlabel='Time', ylabel='Phase')),
        ('Phase vs Chan', 'phase_chan',
         dict(xaxis='chan', yaxis='phase', avgchannel=avg_chan,
              avgtime="1e8", avgbaseline=True, xlabel='Chan',
              ylabel='Phase')),
        ('Phase vs UVDist', 'phase_uvdist',
         dict(xaxis='uvdist', yaxis='phase', avgchannel="4096",
              avgtime='1e8', avgbaseline=False, xlabel='uv-dist',
              ylabel='Phase')),
        ('Amp vs Phase', 'amp_phase',
         dict(xaxis='amp', yaxis='phase', avgchannel="4096",
              avgbaseline=False, xlabel='Amp', ylabel='Phase')),
    ]

    # Loop through SPWs and create plots.
    for spw_num in spws:
        casalog.post("On SPW {}".format(spw_num))

        spw_folder = os.path.join(output_folder, "spw_{}".format(spw_num))
        if not os.path.exists(spw_folder):
            os.mkdir(spw_folder)
        else:
            # Make sure any old plots are removed first.
            pattern = os.path.join(glob.escape(spw_folder),
                                   '*.{}'.format(outtype))
            for old_plot in glob.glob(pattern):
                if os.path.isfile(old_plot):
                    os.remove(old_plot)

        for ii, field_scan in enumerate(field_scans):
            field = names[ii]
            casalog.post("On field {}".format(field))

            for jj in field_scan:
                # Check if all of the data is flagged.
                if scan_all_flagged[(spw_num, jj)]:
                    casalog.post("All data flagged in SPW {0} scan {1}".format(
                        spw_num, jj))
                    continue

                casalog.post("On scan {}".format(jj))

                plots = list(amp_plots)
                if scan_is_calibrator[jj]:
                    plots += phase_plots

                for title, suffix, plot_kwargs in plots:
                    _plot_scan_qa(plotms, ms_name, spw_folder, outtype, field,
                                  jj, spw_num, title, suffix, **plot_kwargs)
import matplotlib # Agg doesn't need X - matplotlib doesn't work with xvfb matplotlib.use('Agg', warn=False) import matplotlib.pyplot as plt import numpy as np from config_parser import validate_args as va import bookkeeping import glob PLOT_DIR = 'plots' EXTN = 'png' from casatasks import * from casatools import table,msmetadata tb = table() msmd = msmetadata() logfile=casalog.logfile() casalog.setlogfile('logs/{SLURM_JOB_NAME}-{SLURM_JOB_ID}.casa'.format(**os.environ)) import logging from time import gmtime logging.Formatter.converter = gmtime logger = logging.getLogger(__name__) logging.basicConfig(format="%(asctime)-15s %(levelname)s: %(message)s", level=logging.INFO) def avg_ants(arrlist): return [np.mean(arr, axis=-1) for arr in arrlist] def lengthen(edat, inpdat):
from casatools import table
import numpy as np

# Dump the per-antenna delays (FPARAM) of a CASA calibration table to a
# plain-text file, one row per table row: antenna index followed by the
# squeezed FPARAM values.
calK = table()
# Alternative input table: "cal.b.K"
calK.open("test_cal")
try:
    # np.squeeze drops singleton axes of FPARAM. The fmt/header below assume
    # exactly two values per row remain after squeezing (X and Y).
    # NOTE(review): confirm the table has two correlations and one channel.
    delays = np.squeeze(calK.getcol("FPARAM"))
    antnames = calK.getcol("ANTENNA1")
finally:
    # Always release the table (and its lock), even if a column read fails.
    calK.close()

tb = np.column_stack((antnames, *delays))
np.savetxt("./delays_residuals.txt", tb, fmt="%i %.6f %.6f", header="antnum X Y")
def weight_multichan(base_ms, npix, cell_size, robust=np.array([0.]), chans=np.array([2]), method='briggs', perchanweight=False, mod_pcwd=False, npixels=0):
    """Compute robust-weighted beam parameters and a Briggs (1995) rms estimate.

    The visibility weights of ``base_ms`` are gridded on an ``npix x npix``
    uv grid, re-weighted for each robust value, re-gridded to form a dirty
    beam, and the beam is fitted with ``fit_beam_CASA``.

    Parameters
    ----------
    base_ms : str
        Path to the measurement set. It is only read, never modified.
    npix : int
        Number of pixels on a side of the square uv grid.
    cell_size : float
        Image cell size. It is multiplied by the module-level ``arcsec``
        constant (presumably arcsec -> radians) to get the uv cell size.
    robust : ndarray
        Robust parameters to evaluate.
    chans : ndarray
        Channel indices, counted among the channels that are not fully
        flagged (fully-flagged channels are dropped before indexing).
    method : {'briggs', 'briggsabs'}
        Weighting scheme.
    perchanweight : bool
        If True, grid the natural weights separately for each channel in
        ``chans``. Otherwise grid all unflagged channels once.
    mod_pcwd : bool
        Divide the gridded weights by a fractional-bandwidth correction
        factor (at least 1) before computing the robust weights.
    npixels : int
        Unused; kept for interface compatibility.

    Returns
    -------
    rms : ndarray, shape (len(chans), len(robust))
    beam_params : ndarray, shape (len(chans), 3, len(robust))

    Raises
    ------
    ValueError
        If ``method`` is not 'briggs' or 'briggsabs'.
    """
    # Fail fast: any other method would leave the robust weights undefined.
    if method not in ('briggs', 'briggsabs'):
        raise ValueError("method must be 'briggs' or 'briggsabs', got {!r}".format(method))

    tb = casatools.table()

    # Channel and reference frequencies from the SPECTRAL_WINDOW subtable.
    tb.open(base_ms + "/SPECTRAL_WINDOW")
    try:
        chan_freqs = tb.getcol("CHAN_FREQ")
        rfreq = tb.getcol("REF_FREQUENCY")
    finally:
        tb.close()

    # Main-table columns. Read-only access is enough since nothing is written back.
    tb.open(base_ms)
    try:
        flag = tb.getcol("FLAG")
        uvw = tb.getcol("UVW")
        weight = tb.getcol("WEIGHT")
        ant1 = tb.getcol("ANTENNA1")
        ant2 = tb.getcol("ANTENNA2")
    finally:
        tb.close()

    # True for channels where not every (correlation, row) sample is flagged.
    chan_ok = np.logical_not(np.prod(flag, axis=(0, 2)))

    # u, v per channel in wavelengths. ``cc/100`` is presumably the speed of
    # light converted from cm/s to m/s (module-level constant) -- confirm.
    uu = uvw[0, :][:, np.newaxis] * chan_freqs[:, 0] / (cc / 100)
    vv = uvw[1, :][:, np.newaxis] * chan_freqs[:, 0] / (cc / 100)

    # Keep cross-correlations only, and drop fully flagged channels.
    xc = np.where(ant1 != ant2)[0]
    wgts = weight[0, :] + weight[1, :]
    uu_xc = uu[xc][:, chan_ok]
    vv_xc = vv[xc][:, chan_ok]
    wgts_xc = wgts[xc]

    dl = cell_size * arcsec
    dm = cell_size * arcsec
    du = 1. / (npix * dl)
    dv = 1. / (npix * dm)

    rms = np.zeros((chans.shape[0], robust.shape[0]))
    beam_params = np.zeros((chans.shape[0], 3, robust.shape[0]))

    # Without per-channel weighting, the natural weights of all channels are
    # gridded once and reused for every channel.
    if not perchanweight:
        gwgts_init = np.zeros((npix, npix))
        gwgts_init = grid_wgts(gwgts_init, np.ravel(uu_xc), np.ravel(vv_xc), du, dv, npix,
                               np.ravel(np.broadcast_to(wgts_xc, (uu_xc.shape[1], uu_xc.shape[0])).T))

    # The bandwidth correction is used by both weighting paths, so build it
    # whenever mod_pcwd is requested.
    corr_fac = None
    if mod_pcwd:
        # TODO CHECK THIS FOR HALF PIXEL OFFSET
        # Integer-index axes guarantee exactly npix samples. A float-step
        # np.arange can round to npix+1 and break the grid shape.
        u_axis = (np.arange(npix) - npix / 2.) * du
        v_axis = (np.arange(npix) - npix / 2.) * dv
        uvdist_grid = np.sqrt(np.add.outer(u_axis**2, v_axis**2))
        frac_bw = (np.max(chan_freqs) - np.min(chan_freqs)) / rfreq
        corr_fac = frac_bw * uvdist_grid / du
        corr_fac[corr_fac < 1] = 1.

    for i, chan in enumerate(chans):
        print(chan)
        if perchanweight:
            # Grid this channel's weights (with complex conjugates).
            gwgts_init = np.zeros((npix, npix))
            gwgts_init = grid_wgts(gwgts_init, uu_xc[:, chan], vv_xc[:, chan], du, dv, npix, wgts_xc)
        gwgts_init_sq = gwgts_init**2

        for j, r in enumerate(robust):
            if method == 'briggs':
                # Robust normalisation. Weights are assumed uniform across
                # channels, so the single-channel case is normalised differently.
                if perchanweight:
                    if mod_pcwd:
                        f_sq = ((5*10**(-r))**2) / (np.sum(gwgts_init_sq) / (np.sum(wgts_xc) * 2))
                    else:
                        f_sq = ((5*10**(-r))**2) / (np.sum(gwgts_init_sq) / (np.sum(wgts_xc)))
                else:
                    f_sq = ((5*10**(-r))**2) / (np.sum(gwgts_init_sq) / (np.sum(wgts_xc * uu_xc.shape[1]) * 2))

                if mod_pcwd:
                    gr_wgts = 1 / (1 + gwgts_init / corr_fac * f_sq)
                else:
                    gr_wgts = 1 / (1 + gwgts_init * f_sq)

                indexed_gr_wgts = ungrid_wgts(gr_wgts, uu_xc[:, chan], vv_xc[:, chan], du, dv, npix)
            else:  # 'briggsabs'
                # Gridded natural weight at each visibility's cell, using the
                # same ungrid_wgts lookup as the 'briggs' branch.
                # NOTE(review): confirm ungrid_wgts is a plain grid lookup.
                S_sq = (ungrid_wgts(gwgts_init, uu_xc[:, chan], vv_xc[:, chan], du, dv, npix) * r**2).T
                indexed_gr_wgts = 1 / (S_sq + 2 * wgts_xc)

            wgts_robust = wgts_xc * indexed_gr_wgts
            wgts_robust_sq = wgts_xc * indexed_gr_wgts**2

            # Total gridded robust weights give the dirty beam.
            gwgts_final = np.zeros((npix, npix))
            gwgts_final = grid_wgts(gwgts_final, uu_xc[:, chan], vv_xc[:, chan], du, dv, npix, wgts_robust)

            robust_beam = np.real(fftshift(fft2(fftshift(gwgts_final))))
            robust_beam /= np.max(robust_beam)
            beam_params[i, :, j] = fit_beam_CASA(robust_beam, cell_size)

            # rms estimate following Briggs et al. (1995).
            C = 1 / (2 * np.sum(wgts_robust))
            rms[i, j] = 2 * C * np.sqrt(np.sum(wgts_robust_sq))
            print(r, beam_params[i, :, j], rms[i, j] * 1000.)

    return rms, beam_params
def import_data_ms(filename):
    """Import data from a CASA measurement set and return a Visibility object.

    Reads DATA, UVW, WEIGHT, ANTENNA1/2, FLAG and TIME from the main table.
    When more than one correlation is present, the first two are combined by
    a weight-weighted average. Autocorrelations are dropped. Flags are
    collapsed over channels, because the weights are not channelized.

    Channel and correlation counts are taken from spectral window "0".
    NOTE(review): this assumes a single-SPW measurement set -- confirm.

    Parameters
    ----------
    filename : str
        Path to the measurement set.

    Returns
    -------
    Visibility
        Built from (data.T, u, v, weights, freqs, time, channel widths,
        ant1, ant2, per-visibility flags).
    """
    tb = table()
    ms_ = ms()

    # Main-table columns.
    tb.open(filename)
    try:
        data = tb.getcol("DATA")
        uvw = tb.getcol("UVW")
        weight = tb.getcol("WEIGHT")
        ant1 = tb.getcol("ANTENNA1")
        ant2 = tb.getcol("ANTENNA2")
        flags = tb.getcol("FLAG")
        time = tb.getcol("TIME")
    finally:
        tb.close()

    # Spectral window information.
    ms_.open(filename)
    try:
        spw_info = ms_.getspectralwindowinfo()
    finally:
        ms_.close()
    nchan = spw_info["0"]["NumChan"]
    npol = spw_info["0"]["NumCorr"]

    # Frequency information.
    tb.open(filename + "/SPECTRAL_WINDOW")
    try:
        freqs = tb.getcol("CHAN_FREQ")
        resolution = tb.getcol("CHAN_WIDTH")
    finally:
        tb.close()

    uu = uvw[0, :]
    vv = uvw[1, :]

    # Drop singleton axes (a single correlation and/or a single channel).
    data = np.squeeze(data)
    weight = np.squeeze(weight)
    flags = np.squeeze(flags)

    if npol == 1:
        Re = data.real
        Im = data.imag
        wgts = weight
    else:
        # Weighted average of the first two correlations.
        weight_xx = weight[0, :]
        weight_yy = weight[1, :]
        wsum = weight_xx + weight_yy
        # Boolean product: a sample is flagged only if both correlations are.
        flags = flags[0, :] * flags[1, :]
        with np.errstate(divide='ignore', invalid='ignore'):
            Re = np.where(wsum != 0,
                          (data[0, :].real * weight_xx + data[1, :].real * weight_yy) / wsum, 0.)
            Im = np.where(wsum != 0,
                          (data[0, :].imag * weight_xx + data[1, :].imag * weight_yy) / wsum, 0.)
        wgts = wsum

    # Toss out the autocorrelations.
    xc = np.where(ant1 != ant2)[0]

    if nchan == 1:
        # Restore the channel axis removed by np.squeeze, so that data and
        # flags are both (1, nvis) and the channel collapse below stays per-visibility.
        data_real = Re[np.newaxis, xc]
        data_imag = Im[np.newaxis, xc]
        flags = flags[np.newaxis, xc]
    else:
        data_real = Re[:, xc]
        data_imag = Im[:, xc]
        flags = flags[:, xc]
    time = time[xc]

    # If the majority of points in any channel are flagged, it probably means
    # an entire channel is flagged -- compare each channel's flagged fraction.
    if np.any(flags.mean(axis=1) > 0.5):
        print('WARNING: Over half of the (u,v) points in at least one '
              'channel are marked as flagged. If you did not expect this, it is '
              'likely due to having an entire channel flagged in the ms. Please '
              'double check this and be careful if model fitting or using diff mode.')

    # Collapse flags to single channel, because weights are not currently channelized.
    flags = flags.any(axis=0)

    data_wgts = wgts[xc]
    data_uu = uu[xc]
    data_vv = vv[xc]
    ant1 = ant1[xc]
    ant2 = ant2[xc]
    data_VV = data_real + 1j * data_imag

    # Warn that flagged data was imported.
    if np.any(flags):
        print('WARNING: Flagged data was imported. Visibility interpolation can '
              'proceed normally, but be careful with chi^2 calculations.')

    return Visibility(data_VV.T, data_uu, data_vv, data_wgts, freqs, time,
                      resolution, ant1, ant2, flags)
# Sky model: FITS image given on the command line, or the default image.
if len(sys.argv) > 1:
    fitsimage = sys.argv[1]
else:
    fitsimage = 'SKAMid_B2_8h_v3.fits'

# Validate explicitly; ``assert`` statements are removed under ``python -O``.
if not os.path.exists(fitsimage):
    raise FileNotFoundError(fitsimage)
# Require '.fits' (not merely 'fits'). Otherwise the replace below is a no-op,
# the "image" path equals the FITS path, and creation is silently skipped.
if '.fits' not in fitsimage:
    raise ValueError('expected a .fits file, got {!r}'.format(fitsimage))

image = fitsimage.replace('.fits', '.image')
if not os.path.exists(image):
    print('creating ms image to be used as sky model')
    im = casatools.image()
    im.fromfits(infile=fitsimage, outfile=image)
    # Detach the tool from the newly written image so its lock is released.
    im.done()

# Get antenna positions: build a table from the ASCII configuration file
# (``conf_file`` is defined elsewhere in this script).
tabname = 'antenna_positions_' + conf_file.split('.cfg')[0] + '.tab'
tb = casatools.table()
tb.fromascii(tabname, conf_file, firstline=3, sep=' ',
             columnnames=['X', 'Y', 'Z', 'DIAM', 'NAME'],
             datatypes=['D', 'D', 'D', 'D', 'A'])
try:
    xx = tb.getcol('X')
    yy = tb.getcol('Y')
    zz = tb.getcol('Z')
    diam = tb.getcol('DIAM')
    anames = tb.getcol('NAME')
finally:
    tb.close()

# simulate setup
sm = casatools.simulator()
def interpolate_bandpass(tablename, spw_ids=None, window_size_factor=2.5,
                         poly_order=2, add_residuals=True, backup_table=True,
                         test_output_nowrite=False, test_print=False):
    '''
    Use Savitzky-Golay smoothing across flagged channels in the bandpass.

    The smoothing window size is set by `window_size_factor`. This will
    create a smoothing length (by default) 2.5 times larger than the gap
    that will be interpolated across. Smaller window sizes will produce
    artifacts over the interpolated region.

    For every SPW, correlation and row, a local polynomial of order
    `poly_order` is fitted in a rolling window. Its central value replaces
    the data inside each interior flagged gap, and those channels are then
    unflagged. The first and last gaps are treated as SPW-edge flagging and
    left alone.

    Parameters
    ----------
    tablename : str
        Path to the bandpass table (CPARAM and FLAG, grouped by
        SPECTRAL_WINDOW_ID).
    spw_ids : iterable of int, optional
        SPWs to process. Defaults to every SPW in the table.
    window_size_factor : float
        Window length as a multiple of the widest interior gap. It is
        forced to an odd number of channels.
    poly_order : int
        Order of the local polynomial fit.
    add_residuals : bool
        Add residuals, resampled from the unflagged channels, to the
        interpolated values to keep a consistent noise level.
    backup_table : bool
        Copy the table to ``<tablename>.bak_from_interpbandpass`` first,
        unless that backup already exists.
    test_output_nowrite : bool
        Return the results instead of writing them back to the table.
    test_print : bool
        Print diagnostic slices for every interpolated gap.

    Returns
    -------
    dict or None
        ``{spw: masked array}`` for the SPWs that were interpolated when
        `test_output_nowrite` is True, otherwise None.

    Raises
    ------
    ValueError
        If a requested SPW does not exist in the table.
    '''
    from casatools import table

    if backup_table:
        original_table_backup = tablename + '.bak_from_interpbandpass'
        if not os.path.isdir(original_table_backup):
            shutil.copytree(tablename, original_table_backup)

    tb = table()

    tb.open(tablename)
    all_spw_ids = np.unique(tb.getcol("SPECTRAL_WINDOW_ID"))
    tb.close()

    if spw_ids is None:
        spw_ids = all_spw_ids
    else:
        for spw in spw_ids:
            if spw not in all_spw_ids:
                raise ValueError(
                    "SPW {} specified does not exist in the table.".format(
                        spw))

    # In test mode, collect the corrected data as arrays instead of writing
    # back to the table.
    if test_output_nowrite:
        bp_pass_dict = dict()

    for spw in spw_ids:

        casalog.post(message='processing SPW {0}'.format(spw),
                     origin='interpolate_bandpass')

        tb.open(tablename)
        stb = tb.query('SPECTRAL_WINDOW_ID == {0}'.format(spw))
        dat = np.ma.array(stb.getcol('CPARAM'))
        dat.mask = stb.getcol('FLAG')
        stb.close()
        tb.close()

        # Channels flagged for every correlation and row. The sum is over
        # rows (axis 2) and then correlations (axis 0). Exactly two gaps are
        # taken to be the SPW-edge flagging, so nothing needs interpolating.
        blank_slices = nd.find_objects(*nd.label(dat.sum(2).sum(0).mask))
        if len(blank_slices) == 2:
            casalog.post(message="no interpolation needed for {0}".format(spw),
                         origin='interpolate_bandpass')
            continue

        dat_shape = dat.shape
        smooth_dat = deepcopy(dat)

        # NOTE(review): axis 2 is presumably one row per antenna -- confirm.
        for ant in range(dat_shape[2]):
            casalog.post(message='processing antenna {0}'.format(ant),
                         origin='interpolate_bandpass')
            for pol in range(dat_shape[0]):
                # Skip if all flagged.
                if np.all(dat.mask[pol, :, ant]):
                    continue

                # Contiguous flagged channel ranges for this correlation/row.
                blank_slices = nd.find_objects(*nd.label(dat[pol, :, ant].mask))
                # No gaps at all, or only the two SPW-edge gaps: nothing to
                # interpolate. (An empty list would also break max() below.)
                if len(blank_slices) in (0, 2):
                    continue

                # Drop the first and last gaps (assumed SPW-edge flagging);
                # only the interior gaps are interpolated.
                if len(blank_slices) > 1:
                    blank_slices.pop(0)
                    blank_slices.pop(-1)

                # Window size is a multiple of the widest gap to bridge.
                nchans_in_gap = max(thisslc[0].stop - thisslc[0].start
                                    for thisslc in blank_slices)
                window_size = int(np.floor(window_size_factor * nchans_in_gap))
                # Force an odd window so it has a central channel.
                if window_size % 2 == 0:
                    window_size += 1

                casalog.post(
                    message="Using window size of {0} for SPW {1}".format(
                        window_size, spw),
                    origin='interpolate_bandpass')

                if window_size / dat.shape[1] > 0.5:
                    casalog.post(
                        message="Warning: the window size is >50% of the SPW",
                        origin='interpolate_bandpass')

                # Abscissa centred on the window so that the constant term of
                # the fit is the value at the central channel.
                x_polyfit = np.arange(window_size) - np.floor(window_size / 2.)

                # Mask out the gaps so they do not enter the fits.
                for slicer in blank_slices:
                    smooth_dat.mask[(slice(pol, pol + 1), slicer[0],
                                     slice(ant, ant + 1))] = True

                rolled_array = rolling_window(smooth_dat[pol, :, ant], window_size)
                # NOTE(review): np.ma.polyfit may fail on a window with too
                # few unmasked points; such failures propagate to the caller.
                for i in range(dat_shape[1]):
                    smooth_dat[pol, i, ant] = \
                        np.ma.polyfit(x_polyfit, rolled_array[i], poly_order)[-1]

                casalog.post(message="replacing values with smoothed",
                             origin='interpolate_bandpass')

                if add_residuals:
                    resids = dat[pol, :, ant] - smooth_dat[pol, :, ant]
                    # The pool of unmasked residuals does not change across gaps.
                    resid_pool = resids[~resids.mask]

                for slicer in blank_slices:
                    region = (slice(pol, pol + 1), slicer[0], slice(ant, ant + 1))
                    # Copy the interpolated values into the gap.
                    dat[region] = smooth_dat[region]

                    # Optionally add resampled residuals to keep the noise
                    # level consistent with the unflagged channels.
                    if add_residuals:
                        gap_size = slicer[0].stop - slicer[0].start
                        resid_samps = np.random.choice(resid_pool, size=gap_size)
                        dat[region] += resid_samps[np.newaxis, :, np.newaxis]

                    # Reset the mask across the interpolated region.
                    dat.mask[region] = False

                    # Keeping just for diagnosing issues.
                    if test_print:
                        print(region)
                        print(dat[region].shape)
                        print(dat[region][:10])
                        print(smooth_dat[region][:10])

        if test_output_nowrite:
            bp_pass_dict[spw] = dat
            continue

        casalog.post(
            message="writing out smoothed gaps to table for spw {}".format(
                spw),
            origin='interpolate_bandpass')

        tb.open(tablename, nomodify=False)
        stb = tb.query('SPECTRAL_WINDOW_ID == {0}'.format(spw))
        # Set new data and flags.
        stb.putcol('CPARAM', dat.data)
        stb.putcol('FLAG', dat.mask)
        stb.close()
        tb.clearlocks()
        tb.close()

    if test_output_nowrite:
        return bp_pass_dict