forked from revoltek/LiLF
-
Notifications
You must be signed in to change notification settings - Fork 0
/
lib_ms.py
342 lines (265 loc) · 11.8 KB
/
lib_ms.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
#!/usr/bin/python
import os, sys, shutil
from casacore import tables
import numpy as np
import pyregion
from pyregion.parser_helper import Shape
import lib_util
from lib_log import logger
class AllMSs(object):
    """
    A set of measurement sets (MSs) that are handled together.

    Keeps the sorted list of MS paths, one MS object per path, and the
    scheduler object used to run commands on every MS of the set.
    """

    def __init__(self, pathsMS, scheduler):
        """
        pathsMS:   list of MS paths
        scheduler: scheduler object
        """
        self.scheduler = scheduler

        # sort them, useful for some concatenating steps
        self.mssListStr = sorted(pathsMS)

        # one MS object per path, in the same (sorted) order as mssListStr
        self.mssListObj = [MS(pathMS) for pathMS in self.mssListStr]

    def getListObj(self):
        """
        Return the list of MS objects (same order as the sorted path list).
        """
        return self.mssListObj

    def getListStr(self):
        """
        Return the sorted list of MS paths.
        """
        return self.mssListStr

    def getNThreads(self):
        """
        Return the max number of threads in one machine assuming all MSs run at the same time

        The result is 'scheduler.max_processors' divided by the number of MSs, rounded
        to the nearest integer. It is 1 when there are more MSs than processors, and also
        when the set is empty (there is nothing to divide the processors among, and the
        division would otherwise raise ZeroDivisionError).
        """
        nMSs = len(self.mssListStr)
        maxProcessors = self.scheduler.max_processors

        if nMSs == 0 or maxProcessors < nMSs:
            return 1

        return int(np.rint(maxProcessors / nMSs))

    def getStrWsclean(self):
        """
        Return a string with all MS paths, useful for wsclean
        """
        return ' '.join(self.mssListStr)

    def getFreqs(self):
        """
        Return a list of freqs per chan per SB

        The per-MS results of 'MS.getFreqs()' are concatenated into one flat list,
        following the order of the MS list.
        """
        return [freq for ms in self.mssListObj for freq in ms.getFreqs()]

    def getBandwidth(self):
        """
        Return the total span of frequency covered by this MS set

        This is the difference between the highest and the lowest value returned by
        'getFreqs()'; with an empty MS set max()/min() raise ValueError.
        """
        freqs = self.getFreqs()
        return max(freqs) - min(freqs)

    def run(self, command, log, commandType='', maxThreads=None):
        """
        Run command 'command' of type 'commandType', and use 'log' for logger,
        for each MS of AllMSs.
        The command and log file path can be customised for each MS using keywords (see: 'MS.concretiseString()').
        Beware: depending on the value of 'Scheduler.max_threads' (see: lib_util.py), the commands are run in parallel.

        command:     command template, one concrete command is scheduled per MS
        log:         log file path template, concretised per MS like the command
        commandType: passed on to the scheduler; 'DPPP' additionally gets a 'numthreads' argument
        maxThreads:  passed on to 'scheduler.run()'
        """
        # add max num of threads given the total jobs to run
        # e.g. in a 64 processors machine running on 16 MSs, would result in numthreads=4
        if commandType == 'DPPP':
            command += ' numthreads=' + str(self.getNThreads())

        # queue one job per MS, with the keywords of command and log filled in for that MS
        for MSObject in self.mssListObj:
            commandCurrent = MSObject.concretiseString(command)
            logCurrent = MSObject.concretiseString(log)
            self.scheduler.add(cmd = commandCurrent, log = logCurrent, commandType = commandType)

        # all jobs are queued: execute them
        self.scheduler.run(check = True, maxThreads = maxThreads)
class MS(object):
    """
    A single measurement set (MS) on disk.

    Stores the path of the MS (plus the derived parent directory and bare name)
    and offers accessors that read metadata from the MS tables via casacore.
    """

    def __init__(self, pathMS):
        """
        pathMS:        path of the MS, without '/' at the end!
        pathDirectory: path of the parent directory of the MS
        nameMS:        name of the MS, without parent directories and extension (which is assumed to be ".MS" always)
        """
        self.setPathVariables(pathMS)

        # If the field name is not a recognised calibrator name, one of two scenarios is true:
        # 1. The field is not a calibrator field;
        # 2. The field is a calibrator field, but the name was not properly set.
        # The following lines correct the field name if scenario 2 is the case.
        calibratorDistanceThreshold = 0.5 # in degrees

        if not self.isCalibrator():
            if self.getCalibratorDistancesSorted()[0] < calibratorDistanceThreshold:
                # The phase centre coordinates suggest that this scan is a calibrator scan:
                # rename the field after the nearest known calibrator (done silently, no warning is logged).
                nameFieldNew = self.getCalibratorNamesSorted()[0]
                self.setNameField(nameFieldNew)

    def setPathVariables(self, pathMS):
        """
        Set logistical variables.

        pathMS is stored as given. pathDirectory is everything before the last '/'
        ('.' when pathMS contains no '/', '' when the only '/' is the leading one).
        nameMS is what follows the last '/', minus the last three characters
        (the extension, assumed to be ".MS").
        """
        self.pathMS = pathMS

        indexLastSlash = self.pathMS.rfind('/')

        if indexLastSlash == -1:
            # No directory part: slicing with [:-1] would chop the last character off the
            # MS name instead of yielding a directory, so point to the current directory.
            self.pathDirectory = '.'
        else:
            self.pathDirectory = self.pathMS[ : indexLastSlash]

        # With indexLastSlash == -1 the slice starts at 0, i.e. at the beginning of the name.
        self.nameMS = self.pathMS[indexLastSlash + 1 : -3]

    def move(self, pathMSNew, overwrite=False, keepOrig=False):
        """
        Move (or rename) the MS to another locus in the file system.

        pathMSNew: destination path
        overwrite: if set, the destination is removed first (via lib_util.check_rm)
        keepOrig:  if set, the MS is copied instead of moved

        If the destination still exists nothing is moved or copied and this object
        keeps pointing to the old path.
        """
        logger.debug('Move: '+self.pathMS+' -> '+pathMSNew)

        if overwrite:
            lib_util.check_rm(pathMSNew)

        if not os.path.exists(pathMSNew):
            if keepOrig:
                shutil.copytree(self.pathMS, pathMSNew)
            else:
                shutil.move(self.pathMS, pathMSNew)

            # NOTE(review): the path variables are only updated when the move/copy happened;
            # confirm this nesting against the upstream file (indentation was lost in this copy).
            self.setPathVariables(pathMSNew)
        else:
            logger.warning('Move: '+pathMSNew+' already exists, '+self.pathMS+' was not moved.')

    def setNameField(self, nameField):
        """
        Set field name.

        Writes 'nameField' into the NAME column of the FIELD subtable.
        """
        # The TaQL command refers to the local variables below as $pathFieldTable and
        # $nameField (python-casacore substitutes them), so these names must not change.
        pathFieldTable = self.pathMS + "/FIELD"
        tables.taql("update $pathFieldTable set NAME=$nameField")

    def getNameField(self):
        """
        Retrieve field name.

        Returns the first entry of the NAME column of the FIELD subtable.
        """
        # $pathFieldTable in the TaQL command refers to this local variable: keep its name.
        pathFieldTable = self.pathMS + "/FIELD"
        nameField = (tables.taql("select NAME from $pathFieldTable")).getcol("NAME")[0]
        return nameField

    def _getCalibratorDistances(self):
        """
        Return (calibratorNames, calibratorDistances): the names of the known calibrators
        and the distance from the phase centre of this MS to each of them, in the order
        given by lib_util.getCalibratorProperties().
        """
        myRA, myDec = self.getPhaseCentre()
        calibratorRAs, calibratorDecs, calibratorNames = lib_util.getCalibratorProperties()
        calibratorDistances = lib_util.distanceOnSphere(myRA, myDec, calibratorRAs, calibratorDecs)
        return calibratorNames, calibratorDistances

    def getCalibratorDistancesSorted(self):
        """
        Returns a list of distances (in degrees) to known calibrators, sorted by distance from small to large.
        """
        _, calibratorDistances = self._getCalibratorDistances()
        return np.sort(calibratorDistances)

    def getCalibratorNamesSorted(self):
        """
        Returns a list of names of known calibrators, sorted by distance from small to large.
        """
        calibratorNames, calibratorDistances = self._getCalibratorDistances()
        # index with the argsort result, so calibratorNames must support array indexing
        return calibratorNames[np.argsort(calibratorDistances)]

    def isCalibrator(self):
        """
        Returns whether the field is a known calibrator field or not.

        The test is purely by name: the field name must be one of the names returned
        by lib_util.getCalibratorProperties().
        """
        calibratorRAs, calibratorDecs, calibratorNames = lib_util.getCalibratorProperties()
        return (self.getNameField() in calibratorNames)

    def concretiseString(self, stringOriginal):
        """
        Returns a concretised version of the string 'stringOriginal', with keywords filled in.
        More keywords (which start with '$') and their conversions can be added below.

        Keywords: $pathMS, $pathDirectory, $nameMS, $nameField.
        """
        stringCurrent = stringOriginal.replace("$pathMS", self.pathMS)
        stringCurrent = stringCurrent.replace("$pathDirectory", self.pathDirectory)
        stringCurrent = stringCurrent.replace("$nameMS", self.nameMS)

        # Looking up the field name reads the FIELD table: only do so when it is needed.
        if "$nameField" in stringCurrent:
            stringCurrent = stringCurrent.replace("$nameField", self.getNameField())

        return stringCurrent

    def getFreqs(self):
        """
        Get chan frequency

        Returns the first row of the CHAN_FREQ column of the SPECTRAL_WINDOW subtable.
        """
        with tables.table(self.pathMS + "/SPECTRAL_WINDOW", ack = False) as t:
            freqs = t.getcol("CHAN_FREQ")

        return freqs[0]

    def getNchan(self):
        """
        Find number of channels

        Asserts that every row of NUM_CHAN holds the same value and returns it.
        """
        with tables.table(self.pathMS + "/SPECTRAL_WINDOW", ack = False) as t:
            nchan = t.getcol("NUM_CHAN")

        assert (nchan[0] == nchan).all() # all SpWs have same channels?

        logger.debug("%s: channel number: %i", self.pathMS, nchan[0])
        return nchan[0]

    def getChanband(self):
        """
        Find bandwidth of a channel in Hz

        Asserts that all channels of the first row of CHAN_WIDTH have the same width
        and returns that width.
        """
        with tables.table(self.pathMS + "/SPECTRAL_WINDOW", ack = False) as t:
            chan_w = t.getcol("CHAN_WIDTH")[0]

        assert all(x == chan_w[0] for x in chan_w) # all chans have same width

        logger.debug("%s: channel width (MHz): %f", self.pathMS, chan_w[0] / 1.e6)
        return chan_w[0]

    def getTimeInt(self):
        """
        Get time interval in seconds

        Computed as the length of the first TIME_RANGE of the OBSERVATION subtable
        divided by the number of distinct values in the TIME column of the main table.
        """
        with tables.table(self.pathMS, ack = False) as t:
            nTimes = len(set(t.getcol("TIME")))

        with tables.table(self.pathMS + "/OBSERVATION", ack = False) as t:
            timeRange = t.getcol("TIME_RANGE")[0]

        deltat = (timeRange[1] - timeRange[0]) / nTimes

        logger.debug("%s: time interval (seconds): %f", self.pathMS, deltat)
        return deltat

    def getPhaseCentre(self):
        """
        Get the phase centre (in degrees) of the first source (is it a problem?) of an MS.

        Returns the tuple (RA, Dec); a negative RA is wrapped by adding a full turn.
        """
        field_no = 0
        ant_no = 0

        with tables.table(self.pathMS + "/FIELD", ack = False) as field_table:
            direction = field_table.getcol("PHASE_DIR")

        # PHASE_DIR values are converted with np.degrees below, i.e. treated as radians
        RA = direction[ant_no, field_no, 0]
        Dec = direction[ant_no, field_no, 1]

        if (RA < 0):
            RA += 2 * np.pi

        return (np.degrees(RA), np.degrees(Dec))

    def getObsMode(self):
        """
        If LBA observation, return obs mode: INNER, OUTER, SPARSE_EVEN, SPARSE_ODD

        The value is the first entry of the LOFAR_ANTENNA_SET column of the OBSERVATION subtable.
        """
        with tables.table(self.pathMS+'/OBSERVATION', ack = False) as t:
            return t.getcol("LOFAR_ANTENNA_SET")[0]

    def makeBeamReg(self, outfile, pb_cut=None, to_null=False):
        """
        Create a ds9 region of the beam
        outfile : str
            output file
        pb_cut : float, optional
            diameter of the beam
        to_null : bool, optional
            arrive to the first null, not the FWHM

        If pb_cut is not given the diameter depends on the obs mode:
        8 (OUTER), 12 (SPARSE), 16 (INNER), and 8 with an error logged otherwise.
        An existing 'outfile' is removed before writing.
        """
        logger.debug('Making PB region: '+outfile)
        ra, dec = self.getPhaseCentre()

        if pb_cut is None:
            # read the obs mode once: every call opens the OBSERVATION table
            obsMode = self.getObsMode()
            if 'OUTER' in obsMode:
                size = 8./2.
            elif 'SPARSE' in obsMode:
                size = 12./2.
            elif 'INNER' in obsMode:
                size = 16./2.
            else:
                logger.error('Cannot find beam size, only LBA_OUTER, LBA_SPARSE_* or LBA_INNER are implemented. Assuming beam diameter = 8 deg.')
                size = 8./2.
        else:
            size = pb_cut/2.

        if to_null: size *= 1.7 # rough estimation

        # a single circle centred on the phase centre
        s = Shape('circle', None)
        s.coord_format = 'fk5'
        s.coord_list = [ ra, dec, size ] # ra, dec, radius
        s.attr = ([], {'width': '2', 'point': 'cross',
                       'font': '"helvetica 16 normal roman"'})
        s.comment = 'color=red text="beam"'

        regions = pyregion.ShapeList([s])
        lib_util.check_rm(outfile)
        regions.write(outfile)

    def getResolution(self):
        """
        Return the expected resolution (in arcsec) of the MS
        Completely flagged lines are removed

        Computed as wavelength / longest projected baseline, with the wavelength taken
        from the first REF_FREQUENCY and the baseline length from the first two UVW
        components of the rows that are not completely flagged.
        """
        c = 299792458. # in metres per second

        # open the main table explicitly so that it is closed again, not just the query result
        with tables.table(self.pathMS, ack = False) as t_main:
            with t_main.query('not all(FLAG)') as t:
                col = t.getcol('UVW')

        with tables.table(self.pathMS+'/SPECTRAL_WINDOW', ack = False) as t:
            wavelength = c / t.getcol('REF_FREQUENCY')[0] # in metres

        maxdist = np.nanmax( np.sqrt(col[:,0] ** 2 + col[:,1] ** 2) )

        # radians -> degrees -> arcseconds
        return int(round(wavelength / maxdist * (180 / np.pi) * 3600)) # in arcseconds