def get_scan_start_stop(scn):
    """
    Given a scan, get the start and stop times (returned as a two-
    element tuple).
    """
    
    # UNIX timestamp for the start
    tStart = utcjd_to_unix(scn.mjd + MJD_OFFSET)
    tStart += scn.mpm / 1000.0
    
    # UNIX timestamp for the stop
    tStop = tStart + scn.dur / 1000.0
    
    # Conversion to a timezone-aware datetime instance
    tStart = _UTC.localize( datetime.utcfromtimestamp(tStart) )
    tStop  = _UTC.localize( datetime.utcfromtimestamp(tStop ) )
    
    # Return
    return tStart, tStop
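
For reference, these scripts express a start time as an integer MJD plus an MPM value, the number of milliseconds past UTC midnight. A small self-contained sketch of that convention (the helper names are illustrative and do not use LSL; 40587 is the standard MJD of the UNIX epoch, 1970-01-01):

from datetime import datetime, timezone

def mjd_mpm_to_datetime(mjd, mpm):
    # 40587 is the MJD of 1970-01-01 (the UNIX epoch)
    t = (mjd - 40587)*86400 + mpm/1000.0
    return datetime.fromtimestamp(t, tz=timezone.utc)

def datetime_to_mjd_mpm(dt):
    # dt must be a timezone-aware UTC datetime
    midnight = dt.replace(hour=0, minute=0, second=0, microsecond=0)
    mpm = int(round((dt - midnight).total_seconds()*1000.0))
    mjd = int(midnight.timestamp()) // 86400 + 40587
    return mjd, mpm

print(datetime_to_mjd_mpm(datetime(2024, 1, 1, 12, 30, 0, tzinfo=timezone.utc)))  # (60310, 45000000)
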
def main(args):
    # Filenames in an easier format - input
    inputIDF  = args.filename
    if args.date is not None:
        y, m, d = args.date.split('/', 2)
        args.date = date(int(y,10), int(m,10), int(d,10))
    if args.time is not None:
        h, m, s = args.time.split(':', 2)
        us = int((float(s) - int(float(s)))*1e6)
        s = int(float(s))
        if us >= 1000000:
            us -= 1000000
            s += 1
        args.time = time(int(h,10), int(m,10), s, us)
        
    # Parse the input file and get the dates of the scans
    station = stations.lwa1
    project = idf.parse_idf(inputIDF)
    
    # Load the station and objects to find the Sun and Jupiter
    observer = station.get_observer()
    Sun = ephem.Sun()
    Jupiter = ephem.Jupiter()
    
    nObs = len(project.runs[0].scans)
    tStart = [None,]*nObs
    for i in range(nObs):
        tStart[i]  = utcjd_to_unix(project.runs[0].scans[i].mjd + MJD_OFFSET)
        tStart[i] += project.runs[0].scans[i].mpm / 1000.0
        tStart[i]  = datetime.utcfromtimestamp(tStart[i])
        tStart[i]  = _UTC.localize(tStart[i])
        
    # Get the LST at the start
    observer.date = (min(tStart)).strftime('%Y/%m/%d %H:%M:%S')
    lst = observer.sidereal_time()
    
    # Report on the file
    print("Filename: %s" % inputIDF)
    print(" Project ID: %s" % project.id)
    print(" Run ID: %i" % project.runs[0].id)
    print(" Scans appear to start at %s" % (min(tStart)).strftime(formatString))
    print(" -> LST at %s for this date/time is %s" % (station.name, lst))
    
    # Filenames in an easier format - output
    if not args.query:
        if args.outname is not None:
            outputIDF = args.outname
        else:
            outputIDF  = None
            
    # Query only mode starts here...
    if args.query:
        lastDur = project.runs[0].scans[nObs-1].dur
        lastDur = timedelta(seconds=int(lastDur/1000), microseconds=(lastDur*1000) % 1000000)
        runDur = max(tStart) - min(tStart) + lastDur
        
        print(" ")
        print(" Total Run Duration: %s" % runDur)
        print(" -> First scan starts at %s" % min(tStart).strftime(formatString))
        print(" -> Last scan ends at %s" % (max(tStart) + lastDur).strftime(formatString))
        print(" Correlator Setup:")
        print(" -> %i channels" % project.runs[0].corr_channels)
        print(" -> %.3f s integration time" % project.runs[0].corr_inttime)
        print(" -> %s output polarization basis" % project.runs[0].corr_basis)
        
        print(" ")
        print(" Number of scans: %i" % nObs)
        print(" Scan Detail:")
        for i in range(nObs):
            currDur = project.runs[0].scans[i].dur
            currDur = timedelta(seconds=int(currDur/1000), microseconds=(currDur*1000) % 1000000)
            
            print("  Scan #%i" % (i+1,))
            
            ## Basic setup
            print("   Target: %s" % project.runs[0].scans[i].target)
            print("   Intent: %s" % project.runs[0].scans[i].intent)
            print("   Start:")
            print("    MJD: %i" % project.runs[0].scans[i].mjd)
            print("    MPM: %i" % project.runs[0].scans[i].mpm)
            print("    -> %s" % get_scan_start_stop(project.runs[0].scans[i])[0].strftime(formatString))
            print("   Duration: %s" % currDur)
            
            ## DP setup
            print("   Tuning 1: %.3f MHz" % (project.runs[0].scans[i].frequency1/1e6,))
            print("   Tuning 2: %.3f MHz" % (project.runs[0].scans[i].frequency2/1e6,))
            print("   Filter code: %i" % project.runs[0].scans[i].filter)
            
            ## Comments/notes
            print("   Observer Comments: %s" % project.runs[0].scans[i].comments)
            
        # Valid?
        print(" ")
        try:
            if project.validate():
                print(" Valid?  Yes")
            else:
                print(" Valid?  No")
        except Exception:
            print(" Valid?  No")
            
        # And then exits
        sys.exit()
        
    #
    # Query the time and compute the time shifts
    #
    if (not args.no_update):
        # Get the new start date/time in UTC and report on the difference
        if args.lst:
            if args.date is None:
                print(" ")
                print("Enter the new UTC start date:")
                tNewStart = input('YYYY/MM/DD-> ')
                try:
                    fields = tNewStart.split('/', 2)
                    fields = [int(f) for f in fields]
                    tNewStart = date(fields[0], fields[1], fields[2])
                    tNewStart = datetime.combine(tNewStart, min(tStart).time())
                except Exception as e:
                    print("Error: %s" % str(e))
                    sys.exit(1)
                    
            else:
                tNewStart = datetime.combine(args.date, min(tStart).time())
                
            tNewStart = _UTC.localize(tNewStart)
            
            # Figure out a new start time on the correct day
            diff = ((tNewStart - min(tStart)).days) * siderealRegression
            ## Try to make sure that the timedelta object is less than 1 day
            while diff.days > 0:
                diff -= siderealDay
            while diff.days < -1:
                diff += siderealDay
            ## Come up with the new start time
            siderealShift = tNewStart - diff
            ## Another check to make sure we are on the right day
            if siderealShift.date() < tNewStart.date():
                siderealShift += siderealDay
            if siderealShift.date() > tNewStart.date():
                siderealShift -= siderealDay
            ## And yet another one to deal with the corner case where the scan starts at ~UT 00:00
            if min(tStart) == siderealShift:
                newSiderealShift1 = siderealShift + siderealDay
                newSiderealShift2 = siderealShift - siderealDay
                if newSiderealShift1.date() == tNewStart.date():
                    siderealShift = newSiderealShift1
                elif newSiderealShift2.date() == tNewStart.date():
                    siderealShift = newSiderealShift2
            tNewStart = siderealShift
            
        else:
            if args.date is None or args.time is None:
                print(" ")
                print("Enter the new UTC start date/time:")
                tNewStart = input('YYYY/MM/DD HH:MM:SS.SSS -> ')
                try:
                    tNewStart = datetime.strptime(tNewStart, '%Y/%m/%d %H:%M:%S.%f')
                except ValueError:
                    try:
                        tNewStart = datetime.strptime(tNewStart, '%Y/%m/%d %H:%M:%S')
                    except Exception as e:
                        print("Error: %s" % str(e))
                        sys.exit(1)
                        
            else:
                tNewStart = datetime.combine(args.date, args.time)
            
            tNewStart = _UTC.localize(tNewStart)
            
        # Get the new shift needed to translate the old times to the new times
        tShift = tNewStart - min(tStart)
        
        # Get the LST at the new start
        observer.date = (tNewStart).strftime('%Y/%m/%d %H:%M:%S')
        lst = observer.sidereal_time()
        
        print(" ")
        print("Shifting scans to start at %s" % tNewStart.strftime(formatString))
        print("-> Difference of %i days, %.3f seconds" % (tShift.days, (tShift.seconds + tShift.microseconds/1000000.0),))
        print("-> LST at %s for this date/time is %s" % (station.name, lst))
        if tShift.days == 0 and tShift.seconds == 0 and tShift.microseconds == 0:
            print(" ")
            print("The current shift is zero.  Do you want to continue anyways?")
            yesNo = input("-> [y/N] ")
            if yesNo not in ('y', 'Y'):
                sys.exit()
                
    else:
        tShift = timedelta(seconds=0)
        
    # Shift the start times and recompute the MJD and MPM values
    for i in range(nObs):
        tStart[i] += tShift
        
    #
    # Query and set the new run ID
    #
    print(" ")
    if args.rid is None:
        print("Enter the new run ID or return to keep current:")
        sid = input('-> ')
        if len(sid) > 0:
            sid = int(sid)
        else:
            sid = project.runs[0].id
    else:
        sid = args.rid
    print("Shifting run ID from %i to %i" % (project.runs[0].id, sid))
    project.runs[0].id = sid
    
    #
    # Go! (apply the changes to the scans)
    #
    print(" ")
    newPOOC = []
    for i in range(nObs):
        print("Working on Scan #%i" % (i+1,))
        newPOOC.append("")
        
        #
        # Start MJD,MPM Shifting
        #
        if (not args.no_update) and tShift != timedelta(seconds=0):
            if len(newPOOC[-1]) != 0:
                newPOOC[-1] += ';;'
            newPOOC[-1] += 'Original MJD:%i,MPM:%i' % (project.runs[0].scans[i].mjd, project.runs[0].scans[i].mpm)
            
            start = tStart[i].strftime("%Z %Y %m %d %H:%M:%S.%f")
            start = start[:-3]

            utc = Time(tStart[i], format=Time.FORMAT_PY_DATE)
            mjd = int(utc.utc_mjd)
            
            utcMidnight = datetime(tStart[i].year, tStart[i].month, tStart[i].day, 0, 0, 0, tzinfo=_UTC)
            diff = tStart[i] - utcMidnight
            mpm = int(round((diff.seconds + diff.microseconds/1000000.0)*1000.0))
            
            print(" Time shifting")
            print("  MJD: %8i -> %8i" % (project.runs[0].scans[i].mjd, mjd))
            print("  MPM: %8i -> %8i" % (project.runs[0].scans[i].mpm, mpm))
            
            project.runs[0].scans[i].mjd = mjd
            project.runs[0].scans[i].mpm = mpm
            project.runs[0].scans[i].start = start
            
    #
    # Project office comments
    #
    # Update the project office comments with this change
    newPOSC = "Shifted IDF with shiftIDF.py (v%s);;Time Shift? %s" % (__version__, 'Yes' if (not args.no_update) else 'No')
    
    if project.project_office.runs[0] is None:
        project.project_office.runs[0] = newPOSC
    else:
        project.project_office.runs[0] += ';;%s' % newPOSC
        
    for i in range(nObs):
        try:
            project.project_office.scans[0][i] += ';;%s' % newPOOC[i]
        except Exception as e:
            print(e)
            project.project_office.scans[0][i] = '%s' % newPOOC[i]
            
    #
    # Save
    #
    if outputIDF is None:
        pID = project.id
        rID = project.runs[0].id
        foStart = min(tStart)
        outputIDF = '%s_%s_%s_%04i.idf' % (pID, foStart.strftime('%y%m%d'), foStart.strftime('%H%M'), rID)
        
    print(" ")
    print("Saving to: %s" % outputIDF)
    fh = open(outputIDF, 'w')
    if not project.validate():
        # Make sure we are about to be valid
        project.validate(verbose=True)
        raise RuntimeError("Cannot validate IDF file")
        
    fh.write( project.render() )
    fh.close()
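
The LST-preserving branch above leans on two module-level constants that are not shown in this snippet, siderealDay and siderealRegression. A minimal sketch of plausible definitions, assuming both are plain timedelta objects (the exact values in the real script may differ slightly):

from datetime import timedelta

# One mean sidereal day is roughly 23h 56m 4.0905s of solar (UT) time
siderealDay = timedelta(hours=23, minutes=56, seconds=4, microseconds=90500)
# A fixed LST therefore occurs about 3m 55.9s earlier in UT on each successive solar day
siderealRegression = timedelta(days=1) - siderealDay
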
Example #3
def main(args):
    # Parse the command line
    ## Baseline list
    if args.baseline is not None:
        ## Fill the baseline list with the conjugates, if needed
        newBaselines = []
        for pair in args.baseline:
            newBaselines.append((pair[1], pair[0]))
        args.baseline.extend(newBaselines)
    ## Polarization
    plot_pols = []
    if args.xx:
        plot_pols.append('XX')
    if args.xy:
        plot_pols.append('XY')
    if args.yx:
        plot_pols.append('YX')
    if args.yy:
        plot_pols.append('YY')
    filename = args.filename

    figs = {}
    first = True
    for filename in args.filename:
        print("Working on '%s'" % os.path.basename(filename))
        # Open the FITS IDI file and access the UV_DATA extension
        hdulist = astrofits.open(filename, mode='readonly')
        andata = hdulist['ANTENNA']
        fqdata = hdulist['FREQUENCY']
        fgdata = None
        for hdu in hdulist[1:]:
            if hdu.header['EXTNAME'] == 'FLAG':
                fgdata = hdu
        uvdata = hdulist['UV_DATA']

        # Pull out various bits of information we need to flag the file
        ## Antenna look-up table
        antLookup = {}
        for an, ai in zip(andata.data['ANNAME'], andata.data['ANTENNA_NO']):
            antLookup[an] = ai
        ## Frequency and polarization setup
        nBand, nFreq, nStk = uvdata.header['NO_BAND'], uvdata.header[
            'NO_CHAN'], uvdata.header['NO_STKD']
        stk0 = uvdata.header['STK_1']
        ## Baseline list
        bls = uvdata.data['BASELINE']
        ## Time of each integration
        obsdates = uvdata.data['DATE']
        obstimes = uvdata.data['TIME']
        inttimes = uvdata.data['INTTIM']
        ## Source list
        srcs = uvdata.data['SOURCE']
        ## Band information
        fqoffsets = fqdata.data['BANDFREQ'].ravel()
        ## Frequency channels
        freq = (numpy.arange(nFreq) -
                (uvdata.header['CRPIX3'] - 1)) * uvdata.header['CDELT3']
        freq += uvdata.header['CRVAL3']
        ## UVW coordinates
        try:
            u, v, w = uvdata.data['UU'], uvdata.data['VV'], uvdata.data['WW']
        except KeyError:
            u, v, w = uvdata.data['UU---SIN'], uvdata.data[
                'VV---SIN'], uvdata.data['WW---SIN']
        uvw = numpy.array([u, v, w]).T
        ## The actual visibility data
        flux = uvdata.data['FLUX'].astype(numpy.float32)

        # Convert the visibilities to something that we can easily work with
        nComp = flux.shape[1] // nBand // nFreq // nStk
        if nComp == 2:
            ## Case 1) - Just real and imaginary data
            flux = flux.view(numpy.complex64)
        else:
            ## Case 2) - Real, imaginary data + weights (drop the weights)
            flux = flux[:, 0::nComp] + 1j * flux[:, 1::nComp]
        flux.shape = (flux.shape[0], nBand, nFreq, nStk)

        # Find unique baselines, times, and sources to work with
        ubls = numpy.unique(bls)
        utimes = numpy.unique(obstimes)
        usrc = numpy.unique(srcs)

        # Convert times to real times
        times = utcjd_to_unix(obsdates + obstimes)
        times = numpy.unique(times)

        # Build a mask
        mask = numpy.zeros(flux.shape, dtype=numpy.bool)
        if fgdata is not None and not args.drop:
            reltimes = obsdates - obsdates[0] + obstimes
            maxtimes = reltimes + inttimes / 2.0 / 86400.0
            mintimes = reltimes - inttimes / 2.0 / 86400.0

            bls_ant1 = bls // 256
            bls_ant2 = bls % 256

            for row in fgdata.data:
                ant1, ant2 = row['ANTS']

                ## Only deal with flags that we need for the plots
                process_flag = False
                if args.include_auto or ant1 != ant2 or ant1 == 0 or ant2 == 0:
                    if ant1 == 0 and ant2 == 0:
                        process_flag = True
                    elif args.baseline is not None:
                        if ant2 == 0 and ant1 in [
                                a0 for a0, a1 in args.baseline
                        ]:
                            process_flag = True
                        elif (ant1, ant2) in args.baseline:
                            process_flag = True
                    elif args.ref_ant is not None:
                        if ant1 == args.ref_ant or ant2 == args.ref_ant:
                            process_flag = True
                    else:
                        process_flag = True
                if not process_flag:
                    continue

                tStart, tStop = row['TIMERANG']
                band = row['BANDS']
                try:
                    len(band)
                except TypeError:
                    band = [
                        band,
                    ]
                cStart, cStop = row['CHANS']
                if cStop == 0:
                    cStop = -1
                pol = row['PFLAGS'].astype(numpy.bool)

                if ant1 == 0 and ant2 == 0:
                    btmask = numpy.where(
                        ((maxtimes >= tStart) & (mintimes <= tStop)))[0]
                elif ant1 == 0 or ant2 == 0:
                    ant1 = max([ant1, ant2])
                    btmask = numpy.where( ( (bls_ant1 == ant1) | (bls_ant2 == ant1) ) \
                                          & ( (maxtimes >= tStart) & (mintimes <= tStop) ) )[0]
                else:
                    btmask = numpy.where( ( (bls_ant1 == ant1) & (bls_ant2 == ant2) ) \
                                          & ( (maxtimes >= tStart) & (mintimes <= tStop) ) )[0]
                for b, v in enumerate(band):
                    if not v:
                        continue
                    mask[btmask, b, cStart - 1:cStop, :] |= pol

        plot_bls = []
        cross = []
        for i in xrange(len(ubls)):
            bl = ubls[i]
            ant1, ant2 = (bl >> 8) & 0xFF, bl & 0xFF
            if args.include_auto or ant1 != ant2:
                if args.baseline is not None:
                    if (ant1, ant2) in args.baseline:
                        plot_bls.append(bl)
                        cross.append(i)
                elif args.ref_ant is not None:
                    if ant1 == args.ref_ant or ant2 == args.ref_ant:
                        plot_bls.append(bl)
                        cross.append(i)
                else:
                    plot_bls.append(bl)
                    cross.append(i)
        nBL = len(cross)

        # Decimation, if needed
        if args.decimate > 1:
            if nFreq % args.decimate != 0:
                raise RuntimeError(
                    "Invalid freqeunce decimation factor:  %i %% %i = %i" %
                    (nFreq, args.decimate, nFreq % args.decimate))

            nFreq //= args.decimate
            freq.shape = (freq.size // args.decimate, args.decimate)
            freq = freq.mean(axis=1)

            flux.shape = (flux.shape[0], flux.shape[1],
                          flux.shape[2] // args.decimate, args.decimate,
                          flux.shape[3])
            flux = flux.mean(axis=3)

            mask.shape = (mask.shape[0], mask.shape[1],
                          mask.shape[2] // args.decimate, args.decimate,
                          mask.shape[3])
            mask = mask.mean(axis=3)

        good = numpy.arange(freq.size // 8,
                            freq.size * 7 // 8)  # Inner 75% of the band

        if first:
            ref_time = obsdates[0] + obstimes[0]

        # NOTE: Assumes that the Stokes parameters increment by -1
        namMapper = {}
        for i in xrange(nStk):
            stk = stk0 - i
            namMapper[i] = NUMERIC_STOKES[stk]
        polMapper = {'XX': 0, 'YY': 1, 'XY': 2, 'YX': 3}

        for b in xrange(len(plot_bls)):
            bl = plot_bls[b]
            valid = numpy.where(bls == bl)[0]
            i, j = (bl >> 8) & 0xFF, bl & 0xFF
            dTimes = obsdates[valid] + obstimes[valid]
            dTimes -= ref_time
            dTimes *= 86400.0

            for p in plot_pols:
                blName = (i, j)
                blName = '%s-%s - %s' % (
                    'EA%02i' % blName[0] if blName[0] < 51 else 'LWA%i' %
                    (blName[0] - 50),
                    'EA%02i' % blName[1] if blName[1] < 51 else 'LWA%i' %
                    (blName[1] - 50), namMapper[polMapper[p]])

                if first or blName not in figs:
                    fig = plt.figure()
                    fig.suptitle('%s' % blName)
                    fig.subplots_adjust(hspace=0.001)
                    axA = fig.add_subplot(1, 2, 1)
                    axP = fig.add_subplot(1, 2, 2)
                    figs[blName] = (fig, axA, axP)
                fig, axA, axP = figs[blName]

                for band, offset in enumerate(fqoffsets):
                    frq = freq + offset
                    vis = numpy.ma.array(flux[valid, band, :, polMapper[p]],
                                         mask=mask[valid, band, :,
                                                   polMapper[p]])

                    amp = numpy.ma.abs(vis)
                    vmin, vmax = percentile(amp, 1), percentile(amp, 99)
                    axA.imshow(amp,
                               extent=(frq[0] / 1e6, frq[-1] / 1e6, dTimes[0],
                                       dTimes[-1]),
                               origin='lower',
                               interpolation='nearest',
                               vmin=vmin,
                               vmax=vmax)

                    axP.imshow(numpy.ma.angle(vis),
                               extent=(frq[0] / 1e6, frq[-1] / 1e6, dTimes[0],
                                       dTimes[-1]),
                               origin='lower',
                               vmin=-numpy.pi,
                               vmax=numpy.pi,
                               interpolation='nearest')

        first = False

    for blName in figs:
        fig, axA, axP = figs[blName]

        fig.suptitle("%s UTC\n%s" % (datetime.utcfromtimestamp(
            times[0]).strftime("%Y/%m/%d %H:%M"), blName))

        axA.axis('auto')
        axA.set_title('Amp.')
        axA.set_xlabel('Frequency [MHz]')
        axA.set_ylabel('Amp. - Elapsed Time [s]')

        axP.axis('auto')
        axP.set_title('Phase')
        axP.set_xlabel('Frequency [MHz]')

        if args.save_images:
            fig.savefig('fringes-%s.png' % (blName.replace(' ', ''), ))

    if not args.save_images:
        plt.show()
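
This example and the other FITS-IDI examples decode the BASELINE column with the (bl >> 8) & 0xFF, bl & 0xFF idiom. For arrays with fewer than 256 antennas that column packs the two antenna numbers as ant1*256 + ant2; a tiny sketch of the encode/decode pair (helper names are illustrative):

def encode_baseline(ant1, ant2):
    # FITS-IDI packing for antenna numbers below 256
    return ant1*256 + ant2

def decode_baseline(bl):
    # Equivalent to bl // 256 and bl % 256
    return (bl >> 8) & 0xFF, bl & 0xFF

assert decode_baseline(encode_baseline(5, 12)) == (5, 12)
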
Example #4
def main(args):
    # Parse the command line
    filenames = args.filename

    for filename in filenames:
        t0 = time.time()
        print("Working on '%s'" % os.path.basename(filename))
        # Open the FITS IDI file and access the UV_DATA extension
        hdulist = astrofits.open(filename, mode='readonly')
        andata = hdulist['ANTENNA']
        fqdata = hdulist['FREQUENCY']
        fgdata = None
        for hdu in hdulist[1:]:
            if hdu.header['EXTNAME'] == 'FLAG':
                fgdata = hdu
        uvdata = hdulist['UV_DATA']

        # Verify we can flag this data
        if uvdata.header['STK_1'] > 0:
            raise RuntimeError("Cannot flag data with STK_1 = %i" %
                               uvdata.header['STK_1'])
        if uvdata.header['NO_STKD'] < 4:
            raise RuntimeError("Cannot flag data with NO_STKD = %i" %
                               uvdata.header['NO_STKD'])

        # NOTE: Assumes that the Stokes parameters increment by -1
        polMapper = {}
        for i in xrange(uvdata.header['NO_STKD']):
            stk = uvdata.header['STK_1'] - i
            polMapper[i] = NUMERIC_STOKES[stk]

        # Pull out various bits of information we need to flag the file
        ## Antenna look-up table
        antLookup = {}
        for an, ai in zip(andata.data['ANNAME'], andata.data['ANTENNA_NO']):
            antLookup[an] = ai
        ## Frequency and polarization setup
        nBand, nFreq, nStk = uvdata.header['NO_BAND'], uvdata.header[
            'NO_CHAN'], uvdata.header['NO_STKD']
        ## Baseline list
        bls = uvdata.data['BASELINE']
        ## Time of each integration
        obsdates = uvdata.data['DATE']
        obstimes = uvdata.data['TIME']
        inttimes = uvdata.data['INTTIM']
        ## Source list
        srcs = uvdata.data['SOURCE']
        ## Band information
        fqoffsets = fqdata.data['BANDFREQ'].ravel()
        ## Frequency channels
        freq = (numpy.arange(nFreq) -
                (uvdata.header['CRPIX3'] - 1)) * uvdata.header['CDELT3']
        freq += uvdata.header['CRVAL3']
        ## UVW coordinates
        try:
            u, v, w = uvdata.data['UU'], uvdata.data['VV'], uvdata.data['WW']
        except KeyError:
            u, v, w = uvdata.data['UU---SIN'], uvdata.data[
                'VV---SIN'], uvdata.data['WW---SIN']
        uvw = numpy.array([u, v, w]).T
        ## The actual visibility data
        flux = uvdata.data['FLUX'].astype(numpy.float32)

        # Convert the visibilities to something that we can easily work with
        nComp = flux.shape[1] // nBand // nFreq // nStk
        if nComp == 2:
            ## Case 1) - Just real and imaginary data
            flux = flux.view(numpy.complex64)
        else:
            ## Case 2) - Real, imaginary data + weights (drop the weights)
            flux = flux[:, 0::nComp] + 1j * flux[:, 1::nComp]
        flux.shape = (flux.shape[0], nBand, nFreq, nStk)

        # Find unique baselines, times, and sources to work with
        ubls = numpy.unique(bls)
        utimes = numpy.unique(obstimes)
        usrc = numpy.unique(srcs)

        # Find unique scans to work on, making sure that there are no large gaps
        blocks = []
        for src in usrc:
            valid = numpy.where(src == srcs)[0]

            blocks.append([valid[0], valid[0]])
            for v in valid[1:]:
                if v == blocks[-1][1] + 1 \
                   and (obsdates[v] - obsdates[blocks[-1][1]] + obstimes[v] - obstimes[blocks[-1][1]])*86400 < 10*inttimes[v]:
                    blocks[-1][1] = v
                else:
                    blocks.append([v, v])
        blocks.sort()

        # Build up the mask
        mask = numpy.zeros(flux.shape, dtype=numpy.bool)
        for i, block in enumerate(blocks):
            tS = time.time()
            print('  Working on scan %i of %i' % (i + 1, len(blocks)))
            match = range(block[0], block[1] + 1)

            bbls = numpy.unique(bls[match])
            times = obstimes[match] * 86400.0
            scanStart = datetime.utcfromtimestamp(
                utcjd_to_unix(obsdates[match[0]] + obstimes[match[0]]))
            scanStop = datetime.utcfromtimestamp(
                utcjd_to_unix(obsdates[match[-1]] + obstimes[match[-1]]))
            print('    Scan spans %s to %s UTC' %
                  (scanStart.strftime('%Y/%m/%d %H:%M:%S'),
                   scanStop.strftime('%Y/%m/%d %H:%M:%S')))

            for b, offset in enumerate(fqoffsets):
                print('    IF #%i' % (b + 1, ))
                crd = uvw[match, :]
                visXX = flux[match, b, :, 0]
                visYY = flux[match, b, :, 1]

                nBL = len(bbls)
                if b == 0:
                    times = times[0::nBL]
                crd.shape = (crd.shape[0] // nBL, nBL, 1, 3)
                visXX.shape = (visXX.shape[0] // nBL, nBL, visXX.shape[1])
                visYY.shape = (visYY.shape[0] // nBL, nBL, visYY.shape[1])
                print(
                    '      Scan/IF contains %i times, %i baselines, %i channels'
                    % visXX.shape)

                if visXX.shape[0] < 5:
                    print('        Too few integrations, skipping')
                    continue

                antennas = []
                for j in xrange(nBL):
                    ant1, ant2 = (bbls[j] >> 8) & 0xFF, bbls[j] & 0xFF
                    if ant1 not in antennas:
                        antennas.append(ant1)
                    if ant2 not in antennas:
                        antennas.append(ant2)

                print('      Flagging baselines')
                maskXX = mask_bandpass(antennas,
                                       times,
                                       freq + offset,
                                       visXX,
                                       freq_range=args.freq_range)
                maskYY = mask_bandpass(antennas,
                                       times,
                                       freq + offset,
                                       visYY,
                                       freq_range=args.freq_range)

                visXX = numpy.ma.array(visXX, mask=maskXX)
                visYY = numpy.ma.array(visYY, mask=maskYY)

                if args.scf_passes > 0:
                    print('      Flagging spurious correlations')
                    for p in xrange(args.scf_passes):
                        print('        Pass #%i' % (p + 1, ))
                        visXX.mask = mask_spurious(antennas, times, crd,
                                                   freq + offset, visXX)
                        visYY.mask = mask_spurious(antennas, times, crd,
                                                   freq + offset, visYY)

                print('      Cleaning masks')
                visXX.mask = cleanup_mask(visXX.mask)
                visYY.mask = cleanup_mask(visYY.mask)

                print('      Saving polarization masks')
                submask = visXX.mask
                submask.shape = (len(match), flux.shape[2])
                mask[match, b, :, 0] = submask
                submask = visYY.mask
                submask.shape = (len(match), flux.shape[2])
                mask[match, b, :, 1] = submask

                print('      Statistics for this scan/IF')
                print('      -> %s      - %.1f%% flagged' % (
                    polMapper[0],
                    100.0 * mask[match, b, :, 0].sum() /
                    mask[match, b, :, 0].size,
                ))
                print('      -> %s      - %.1f%% flagged' % (
                    polMapper[1],
                    100.0 * mask[match, b, :, 1].sum() /
                    mask[match, b, :, 1].size,
                ))
                print('      -> Elapsed - %.3f s' % (time.time() - tS, ))

        # Convert the masks into a format suitable for writing to a FLAG table
        print("  Building FLAG table")
        ants, times, bands, chans, pols, reas, sevs = [], [], [], [], [], [], []
        if not args.drop:
            ## Old flags
            if fgdata is not None:
                for row in fgdata.data:
                    ants.append(row['ANTS'])
                    times.append(row['TIMERANG'])
                    try:
                        len(row['BANDS'])
                        bands.append(row['BANDS'])
                    except TypeError:
                        bands.append([
                            row['BANDS'],
                        ])
                    chans.append(row['CHANS'])
                    pols.append(row['PFLAGS'])
                    reas.append(row['REASON'])
                    sevs.append(row['SEVERITY'])
        ## New Flags
        nBL = len(ubls)
        for i in xrange(nBL):
            blset = numpy.where(bls == ubls[i])[0]
            ant1, ant2 = (ubls[i] >> 8) & 0xFF, ubls[i] & 0xFF
            if i % 100 == 0 or i + 1 == nBL:
                print("    Baseline %i of %i" % (i + 1, nBL))

            if len(blset) == 0:
                continue

            for b, offset in enumerate(fqoffsets):
                maskXX = mask[blset, b, :, 0]
                maskYY = mask[blset, b, :, 1]

                flagsXX, _ = create_flag_groups(obstimes[blset], freq + offset,
                                                maskXX)
                flagsYY, _ = create_flag_groups(obstimes[blset], freq + offset,
                                                maskYY)

                for flag in flagsXX:
                    ants.append((ant1, ant2))
                    times.append(
                        (obsdates[blset[flag[0]]] + obstimes[blset[flag[0]]] -
                         obsdates[0], obsdates[blset[flag[1]]] +
                         obstimes[blset[flag[1]]] - obsdates[0]))
                    bands.append([1 if j == b else 0 for j in xrange(nBand)])
                    chans.append((flag[2] + 1, flag[3] + 1))
                    pols.append((1, 0, 1, 1))
                    reas.append('FLAGIDI.PY')
                    sevs.append(-1)
                for flag in flagsYY:
                    ants.append((ant1, ant2))
                    times.append(
                        (obsdates[blset[flag[0]]] + obstimes[blset[flag[0]]] -
                         obsdates[0], obsdates[blset[flag[1]]] +
                         obstimes[blset[flag[1]]] - obsdates[0]))
                    bands.append([1 if j == b else 0 for j in xrange(nBand)])
                    chans.append((flag[2] + 1, flag[3] + 1))
                    pols.append((0, 1, 1, 1))
                    reas.append('FLAGIDI.PY')
                    sevs.append(-1)

        ## Figure out our revision
        try:
            repo = git.Repo(os.path.dirname(os.path.abspath(__file__)))
            try:
                branch = repo.active_branch.name
                hexsha = repo.active_branch.commit.hexsha
            except TypeError:
                branch = '<detached>'
                hexsha = repo.head.commit.hexsha
            shortsha = hexsha[-7:]
            dirty = ' (dirty)' if repo.is_dirty() else ''
        except git.exc.GitError:
            branch = 'unknown'
            hexsha = 'unknown'
            shortsha = 'unknown'
            dirty = ''

        ## Build the FLAG table
        print('    FITS HDU')
        ### Columns
        nFlags = len(ants)
        c1 = astrofits.Column(name='SOURCE_ID',
                              format='1J',
                              array=numpy.zeros((nFlags, ), dtype=numpy.int32))
        c2 = astrofits.Column(name='ARRAY',
                              format='1J',
                              array=numpy.zeros((nFlags, ), dtype=numpy.int32))
        c3 = astrofits.Column(name='ANTS',
                              format='2J',
                              array=numpy.array(ants, dtype=numpy.int32))
        c4 = astrofits.Column(name='FREQID',
                              format='1J',
                              array=numpy.zeros((nFlags, ), dtype=numpy.int32))
        c5 = astrofits.Column(name='TIMERANG',
                              format='2E',
                              array=numpy.array(times, dtype=numpy.float32))
        c6 = astrofits.Column(name='BANDS',
                              format='%iJ' % nBand,
                              array=numpy.array(bands,
                                                dtype=numpy.int32).squeeze())
        c7 = astrofits.Column(name='CHANS',
                              format='2J',
                              array=numpy.array(chans, dtype=numpy.int32))
        c8 = astrofits.Column(name='PFLAGS',
                              format='4J',
                              array=numpy.array(pols, dtype=numpy.int32))
        c9 = astrofits.Column(name='REASON',
                              format='A40',
                              array=numpy.array(reas))
        c10 = astrofits.Column(name='SEVERITY',
                               format='1J',
                               array=numpy.array(sevs, dtype=numpy.int32))
        colDefs = astrofits.ColDefs([c1, c2, c3, c4, c5, c6, c7, c8, c9, c10])
        ### The table itself
        flags = astrofits.BinTableHDU.from_columns(colDefs)
        ### The header
        flags.header['EXTNAME'] = ('FLAG', 'FITS-IDI table name')
        flags.header['EXTVER'] = (1 if fgdata is None else
                                  fgdata.header['EXTVER'] + 1,
                                  'table instance number')
        flags.header['TABREV'] = (2, 'table format revision number')
        for key in ('NO_STKD', 'STK_1', 'NO_BAND', 'NO_CHAN', 'REF_FREQ',
                    'CHAN_BW', 'REF_PIXL', 'OBSCODE', 'ARRNAM', 'RDATE'):
            try:
                flags.header[key] = (uvdata.header[key],
                                     uvdata.header.comments[key])
            except KeyError:
                pass
        flags.header['HISTORY'] = 'Flagged with %s, revision %s.%s%s' % (
            os.path.basename(__file__), branch, shortsha, dirty)

        # Clean up the old FLAG tables, if any, and then insert the new table where it needs to be
        if args.drop:
            ## Reset the EXTVER on the new FLAG table
            flags.header['EXTVER'] = (1, 'table instance number')
            ## Find old tables
            toRemove = []
            for hdu in hdulist:
                try:
                    if hdu.header['EXTNAME'] == 'FLAG':
                        toRemove.append(hdu)
                except KeyError:
                    pass
            ## Remove old tables
            for hdu in toRemove:
                ver = hdu.header['EXTVER']
                del hdulist[hdulist.index(hdu)]
                print("  WARNING: removing old FLAG table - version %i" % ver)
        ## Insert the new table right before UV_DATA
        hdulist.insert(-1, flags)

        # Save
        print("  Saving to disk")
        ## What to call it
        outname = os.path.basename(filename)
        outname, outext = os.path.splitext(outname)
        outname = '%s_flagged%s' % (outname, outext)
        ## Does it already exist or not
        if os.path.exists(outname):
            if not args.force:
                yn = raw_input("WARNING: '%s' exists, overwrite? [Y/n] " %
                               outname)
            else:
                yn = 'y'

            if yn not in ('n', 'N'):
                os.unlink(outname)
            else:
                raise RuntimeError("Output file '%s' already exists" % outname)
        ## Open and create a new primary HDU
        hdulist2 = astrofits.open(outname, mode='append')
        primary = astrofits.PrimaryHDU()
        processed = []
        for key in hdulist[0].header:
            if key in ('COMMENT', 'HISTORY'):
                if key not in processed:
                    parts = str(hdulist[0].header[key]).split('\n')
                    for part in parts:
                        primary.header[key] = part
                    processed.append(key)
            else:
                primary.header[key] = (hdulist[0].header[key],
                                       hdulist[0].header.comments[key])
        hdulist2.append(primary)
        hdulist2.flush()
        ## Copy the extensions over to the new file
        for hdu in hdulist[1:]:
            hdulist2.append(hdu)
            hdulist2.flush()
        hdulist2.close()
        hdulist.close()
        print("  -> Flagged FITS IDI file is '%s'" % outname)
        print("  Finished in %.3f s" % (time.time() - t0, ))
Example #5
def main(args):
    # Filenames in an easier format - input
    inputSDF  = args.filename
    if args.date is not None:
        y, m, d = args.date.split('/', 2)
        args.date = date(int(y,10), int(m,10), int(d,10))
    if args.time is not None:
        h, m, s = args.time.split(':', 2)
        us = int((float(s) - int(float(s)))*1e6)
        s = int(float(s))
        if us >= 1000000:
            us -= 1000000
            s += 1
        args.time = time(int(h,10), int(m,10), s, us)
        
    # Parse the input file and get the dates of the observations
    try:
        ## LWA-1
        station = stations.lwa1
        project = sdf.parse_sdf(inputSDF)
        adp = False
    except Exception as e:
        ## LWA-SV
        ### Try again
        station = stations.lwasv
        project = sdfADP.parse_sdf(inputSDF)
        adp = True
        
    # Load the station and objects to find the Sun and Jupiter
    observer = station.get_observer()
    Sun = ephem.Sun()
    Jupiter = ephem.Jupiter()
    
    nObs = len(project.sessions[0].observations)
    tStart = [None,]*nObs
    for i in range(nObs):
        tStart[i]  = utcjd_to_unix(project.sessions[0].observations[i].mjd + MJD_OFFSET)
        tStart[i] += project.sessions[0].observations[i].mpm / 1000.0
        tStart[i]  = datetime.utcfromtimestamp(tStart[i])
        tStart[i]  = _UTC.localize(tStart[i])
        
    # Get the LST at the start
    observer.date = (min(tStart)).strftime('%Y/%m/%d %H:%M:%S')
    lst = observer.sidereal_time()
    
    # Report on the file
    print("Filename: %s" % inputSDF)
    print(" Project ID: %s" % project.id)
    print(" Session ID: %i" % project.sessions[0].id)
    print(" Observations appear to start at %s" % (min(tStart)).strftime(formatString))
    print(" -> LST at %s for this date/time is %s" % (station.name, lst))
    
    # Filenames in an easier format - output
    if not args.query:
        if args.outname is not None:
            outputSDF = args.outname
        else:
            outputSDF  = None
            
    # Query only mode starts here...
    if args.query:
        lastDur = project.sessions[0].observations[nObs-1].dur
        lastDur = timedelta(seconds=int(lastDur/1000), microseconds=(lastDur*1000) % 1000000)
        sessionDur = max(tStart) - min(tStart) + lastDur
        
        print(" ")
        print(" Total Session Duration: %s" % sessionDur)
        print(" -> First observation starts at %s" % min(tStart).strftime(formatString))
        print(" -> Last observation ends at %s" % (max(tStart) + lastDur).strftime(formatString))
        if project.sessions[0].observations[0].mode not in ('TBW', 'TBN'):
            drspec = 'No'
            if project.sessions[0].spcSetup[0] != 0 and project.sessions[0].spcSetup[1] != 0:
                drspec = 'Yes'
            drxBeam = project.sessions[0].drx_beam
            if drxBeam < 1:
                drxBeam = "MCS decides"
            else:
                drxBeam = "%i" % drxBeam
            print(" DRX Beam: %s" % drxBeam)
            print(" DR Spectrometer used? %s" % drspec)
            if drspec == 'Yes':
                mt = project.sessions[0].spcMetatag
                if mt is None:
                    mt = '{Stokes=XXYY}'
                junk, mt = mt.split('=', 1)
                mt = mt.replace('}', '')
                
                if mt in ('XX', 'YY', 'XY', 'YX', 'XXYY', 'XXXYYXYY'):
                    products = len(mt)//2
                    mt = [mt[2*i:2*i+2] for i in range(products)]
                else:
                    products = len(mt)
                    mt = [mt[1*i:1*i+1] for i in range(products)]
                    
                print(" -> %i channels, %i windows/integration" % tuple(project.sessions[0].spcSetup))
                print(" -> %i data products (%s)" % (products, ','.join(mt)))
        else:
            print(" Transient Buffer: %s\n" % ('Wide band' if project.sessions[0].observations[0].mode == 'TBW' else 'Narrow band',))
            
        print(" ")
        print(" Number of observations: %i" % nObs)
        print(" Observation Detail:")
        for i in range(nObs):
            currDur = project.sessions[0].observations[i].dur
            currDur = timedelta(seconds=int(currDur/1000), microseconds=(currDur*1000) % 1000000)
            
            print("  Observation #%i" % (i+1,))
            
            ## Basic setup
            print("   Target: %s" % project.sessions[0].observations[i].target)
            print("   Mode: %s" % project.sessions[0].observations[i].mode)
            print("   Start:")
            print("    MJD: %i" % project.sessions[0].observations[i].mjd)
            print("    MPM: %i" % project.sessions[0].observations[i].mpm)
            print("    -> %s" % getObsStartStop(project.sessions[0].observations[i])[0].strftime(formatString))
            print("   Duration: %s" % currDur)
            
            ## DP setup
            if project.sessions[0].observations[i].mode not in ('TBW',):
                print("   Tuning 1: %.3f MHz" % (project.sessions[0].observations[i].frequency1/1e6,))
            if project.sessions[0].observations[i].mode not in ('TBW', 'TBN'):
                print("   Tuning 2: %.3f MHz" % (project.sessions[0].observations[i].frequency2/1e6,))
            if project.sessions[0].observations[i].mode not in ('TBW',):
                print("   Filter code: %i" % project.sessions[0].observations[i].filter)
                
            ## Comments/notes
            print("   Observer Comments: %s" % project.sessions[0].observations[i].comments)
            
        # Valid?
        print(" ")
        try:
            if project.validate():
                print(" Valid?  Yes")
            else:
                print(" Valid?  No")
        except Exception:
            print(" Valid?  No")
            
        # And then exits
        sys.exit()
        
    #
    # Query the time and compute the time shifts
    #
    if (not args.no_update):
        # Get the new start date/time in UTC and report on the difference
        if args.lst:
            if args.date is None:
                print(" ")
                print("Enter the new UTC start date:")
                tNewStart = input('YYYY/MM/DD-> ')
                try:
                    fields = tNewStart.split('/', 2)
                    fields = [int(f) for f in fields]
                    tNewStart = date(fields[0], fields[1], fields[2])
                    tNewStart = datetime.combine(tNewStart, min(tStart).time())
                except Exception as e:
                    print("Error: %s" % str(e))
                    sys.exit(1)
                    
            else:
                tNewStart = datetime.combine(args.date, min(tStart).time())
                
            tNewStart = _UTC.localize(tNewStart)
            
            # Figure out a new start time on the correct day
            diff = ((tNewStart - min(tStart)).days) * siderealRegression
            ## Try to make sure that the timedelta object is less than 1 day
            while diff.days > 0:
                diff -= siderealDay
            while diff.days < -1:
                diff += siderealDay
            ## Come up with the new start time
            siderealShift = tNewStart - diff
            ## Another check to make sure we are on the right day
            if siderealShift.date() < tNewStart.date():
                siderealShift += siderealDay
            if siderealShift.date() > tNewStart.date():
                siderealShift -= siderealDay
            ## And yet another one to deal with the corner case where the observation starts at ~UT 00:00
            if min(tStart) == siderealShift:
                newSiderealShift1 = siderealShift + siderealDay
                newSiderealShift2 = siderealShift - siderealDay
                if newSiderealShift1.date() == tNewStart.date():
                    siderealShift = newSiderealShift1
                elif newSiderealShift2.date() == tNewStart.date():
                    siderealShift = newSiderealShift2
            tNewStart = siderealShift
            
        else:
            if args.date is None or args.time is None:
                print(" ")
                print("Enter the new UTC start date/time:")
                tNewStart = input('YYYY/MM/DD HH:MM:SS.SSS -> ')
                try:
                    tNewStart = datetime.strptime(tNewStart, '%Y/%m/%d %H:%M:%S.%f')
                except ValueError:
                    try:
                        tNewStart = datetime.strptime(tNewStart, '%Y/%m/%d %H:%M:%S')
                    except Exception as e:
                        print("Error: %s" % str(e))
                        sys.exit(1)
                        
            else:
                tNewStart = datetime.combine(args.date, args.time)
            
            tNewStart = _UTC.localize(tNewStart)
            
        # Get the new shift needed to translate the old times to the new times
        tShift = tNewStart - min(tStart)
        
        # Get the LST at the new start
        observer.date = (tNewStart).strftime('%Y/%m/%d %H:%M:%S')
        lst = observer.sidereal_time()
        
        print(" ")
        print("Shifting observations to start at %s" % tNewStart.strftime(formatString))
        print("-> Difference of %i days, %.3f seconds" % (tShift.days, (tShift.seconds + tShift.microseconds/1000000.0),))
        print("-> LST at %s for this date/time is %s" % (station.name, lst))
        if tShift.days == 0 and tShift.seconds == 0 and tShift.microseconds == 0:
            print(" ")
            print("The current shift is zero.  Do you want to continue anyways?")
            yesNo = input("-> [y/N] ")
            if yesNo not in ('y', 'Y'):
                sys.exit()
                
    else:
        tShift = timedelta(seconds=0)
        
    # Shift the start times and recompute the MJD and MPM values
    for i in range(nObs):
        tStart[i] += tShift
        
    #
    # Query and set the new session ID
    #
    print(" ")
    if args.sid is None:
        print("Enter the new session ID or return to keep current:")
        sid = input('-> ')
        if len(sid) > 0:
            sid = int(sid)
        else:
            sid = project.sessions[0].id
    else:
        sid = args.sid
    print("Shifting session ID from %i to %i" % (project.sessions[0].id, sid))
    project.sessions[0].id = sid
    
    #
    # Go! (apply the changes to the observations)
    #
    print(" ")
    newPOOC = []
    for i in range(nObs):
        print("Working on Observation #%i" % (i+1,))
        newPOOC.append("")
        
        #
        # Start MJD,MPM Shifting
        #
        if (not args.no_update) and tShift != timedelta(seconds=0):
            if len(newPOOC[-1]) != 0:
                newPOOC[-1] += ';;'
            newPOOC[-1] += 'Original MJD:%i,MPM:%i' % (project.sessions[0].observations[i].mjd, project.sessions[0].observations[i].mpm)
            
            start = tStart[i].strftime("%Z %Y %m %d %H:%M:%S.%f")
            start = start[:-3]

            utc = Time(tStart[i], format=Time.FORMAT_PY_DATE)
            mjd = int(utc.utc_mjd)
            
            utcMidnight = datetime(tStart[i].year, tStart[i].month, tStart[i].day, 0, 0, 0, tzinfo=_UTC)
            diff = tStart[i] - utcMidnight
            mpm = int(round((diff.seconds + diff.microseconds/1000000.0)*1000.0))
            
            print(" Time shifting")
            print("  MJD: %8i -> %8i" % (project.sessions[0].observations[i].mjd, mjd))
            print("  MPM: %8i -> %8i" % (project.sessions[0].observations[i].mpm, mpm))
            
            project.sessions[0].observations[i].mjd = mjd
            project.sessions[0].observations[i].mpm = mpm
            project.sessions[0].observations[i].start = start
            
    #
    # Project office comments
    #
    # Update the project office comments with this change
    newPOSC = "Shifted SDF with shiftSDF.py (v%s);;Time Shift? %s" % (__version__, 'Yes' if (not args.no_update) else 'No')
    
    if project.project_office.sessions[0] is None:
        project.project_office.sessions[0] = newPOSC
    else:
        project.project_office.sessions[0] += ';;%s' % newPOSC
        
    for i in range(nObs):
        try:
            project.project_office.observations[0][i] += ';;%s' % newPOOC[i]
        except Exception as e:
            print(e)
            project.project_office.observations[0][i] = '%s' % newPOOC[i]
            
    #
    # Save
    #
    if outputSDF is None:
        pID = project.id
        sID = project.sessions[0].id
        beam = project.sessions[0].drx_beam
        foStart = min(tStart)
        
        if project.sessions[0].observations[0].mode not in ('TBW', 'TBN'):
            if beam == -1:
                print(" ")
                print("Enter the DRX beam to use:")
                newBeam = input('[1 through 4]-> ')
                try:
                    newBeam = int(newBeam)
                except Exception as e:
                    print("Error: %s" % str(e))
                    sys.exit(1)
                if adp:
                    if newBeam not in (1,):
                        print("Error: beam '%i' is out of range" % newBeam)
                        sys.exit(1)
                        
                else:
                    if newBeam not in (1, 2, 3, 4):
                        print("Error: beam '%i' is out of range" % newBeam)
                        sys.exit(1)
                        
                print("Shifting DRX beam from %i to %i" % (beam, newBeam))
                beam = newBeam
                project.sessions[0].drx_beam = beam
                
            outputSDF = '%s_%s_%s_%04i_B%i.sdf' % (pID, foStart.strftime('%y%m%d'), foStart.strftime('%H%M'), sID, beam)
        else:
            outputSDF = '%s_%s_%s_%04i_%s.sdf' % (pID, foStart.strftime('%y%m%d'), foStart.strftime('%H%M'), sID, project.sessions[0].observations[0].mode)
            
    print(" ")
    print("Saving to: %s" % outputSDF)
    fh = open(outputSDF, 'w')
    if not project.validate():
        # Make sure we are about to be valid
        project.validate(verbose=True)
        raise RuntimeError("Cannot validate SDF file")
        
    fh.write( project.render() )
    fh.close()
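
The DR spectrometer block in this example turns a metatag such as '{Stokes=XXYY}' into a list of data products. A compact, self-contained restatement of that parsing (the helper name is illustrative):

def parse_spc_metatag(mt):
    # Default used above when no metatag is stored
    if mt is None:
        mt = '{Stokes=XXYY}'
    _, mt = mt.split('=', 1)
    mt = mt.replace('}', '')
    if mt in ('XX', 'YY', 'XY', 'YX', 'XXYY', 'XXXYYXYY'):
        # Linear products come as two-character pairs
        return [mt[2*i:2*i+2] for i in range(len(mt)//2)]
    # Otherwise treat each character as a single (e.g. Stokes) product
    return list(mt)

print(parse_spc_metatag('{Stokes=XXXYYXYY}'))  # ['XX', 'XY', 'YX', 'YY']
print(parse_spc_metatag('{Stokes=IV}'))        # ['I', 'V']
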
Example #6
def main(args):
    # Parse the command line
    ## Baseline list
    if args.baseline is not None:
        ## Fill the baseline list with the conjugates, if needed
        newBaselines = []
        for pair in args.baseline:
            newBaselines.append( (pair[1],pair[0]) )
        args.baseline.extend(newBaselines)
    ## Search limits
    args.delay_window = [float(v) for v in args.delay_window.split(',', 1)]
    args.rate_window = [float(v) for v in args.rate_window.split(',', 1)]
    
    figs = {}
    first = True
    for filename in args.filename:
        print("Working on '%s'" % os.path.basename(filename))
        # Open the FITS IDI file and access the UV_DATA extension
        hdulist = astrofits.open(filename, mode='readonly')
        andata = hdulist['ANTENNA']
        fqdata = hdulist['FREQUENCY']
        fgdata = None
        for hdu in hdulist[1:]:
            if hdu.header['EXTNAME'] == 'FLAG':
                fgdata = hdu
        uvdata = hdulist['UV_DATA']
        
        # Pull out various bits of information we need to flag the file
        ## Antenna look-up table
        antLookup = {}
        for an, ai in zip(andata.data['ANNAME'], andata.data['ANTENNA_NO']):
            antLookup[an] = ai
        ## Frequency and polarization setup
        nBand, nFreq, nStk = uvdata.header['NO_BAND'], uvdata.header['NO_CHAN'], uvdata.header['NO_STKD']
        stk0 = uvdata.header['STK_1']
        ## Baseline list
        bls = uvdata.data['BASELINE']
        ## Time of each integration
        obsdates = uvdata.data['DATE']
        obstimes = uvdata.data['TIME']
        inttimes = uvdata.data['INTTIM']
        ## Source list
        srcs = uvdata.data['SOURCE']
        ## Band information
        fqoffsets = fqdata.data['BANDFREQ'].ravel()
        ## Frequency channels
        freq = (numpy.arange(nFreq)-(uvdata.header['CRPIX3']-1))*uvdata.header['CDELT3']
        freq += uvdata.header['CRVAL3']
        ## UVW coordinates
        try:
            u, v, w = uvdata.data['UU'], uvdata.data['VV'], uvdata.data['WW']
        except KeyError:
            u, v, w = uvdata.data['UU---SIN'], uvdata.data['VV---SIN'], uvdata.data['WW---SIN']
        uvw = numpy.array([u, v, w]).T
        ## The actual visibility data
        flux = uvdata.data['FLUX'].astype(numpy.float32)
        
        # Convert the visibilities to something that we can easily work with
        nComp = flux.shape[1] // nBand // nFreq // nStk
        if nComp == 2:
            ## Case 1) - Just real and imaginary data
            flux = flux.view(numpy.complex64)
        else:
            ## Case 2) - Real, imaginary data + weights (drop the weights)
            flux = flux[:,0::nComp] + 1j*flux[:,1::nComp]
        flux.shape = (flux.shape[0], nBand, nFreq, nStk)
        
        # Find unique baselines, times, and sources to work with
        ubls = numpy.unique(bls)
        utimes = numpy.unique(obstimes)
        usrc = numpy.unique(srcs)
        
        # Convert times to real times
        times = utcjd_to_unix(obsdates + obstimes)
        times = numpy.unique(times)
        
        # Build a mask
        mask = numpy.zeros(flux.shape, dtype=numpy.bool)
        if fgdata is not None and not args.drop:
            reltimes = obsdates - obsdates[0] + obstimes
            maxtimes = reltimes + inttimes / 2.0 / 86400.0
            mintimes = reltimes - inttimes / 2.0 / 86400.0
            
            bls_ant1 = bls//256
            bls_ant2 = bls%256
            
            for row in fgdata.data:
                ant1, ant2 = row['ANTS']
                
                ## Only deal with flags that we need for the plots
                process_flag = False
                if ant1 != ant2 or ant1 == 0 or ant2 == 0:
                    if ant1 == 0 and ant2 == 0:
                        process_flag = True
                    elif args.baseline is not None:
                        if ant2 == 0 and ant1 in [a0 for a0,a1 in args.baseline]:
                            process_flag = True
                        elif (ant1,ant2) in args.baseline:
                            process_flag = True
                    elif args.ref_ant is not None:
                        if ant1 == args.ref_ant or ant2 == args.ref_ant:
                            process_flag = True
                    else:
                        process_flag = True
                if not process_flag:
                    continue
                    
                tStart, tStop = row['TIMERANG']
                band = row['BANDS']
                try:
                    len(band)
                except TypeError:
                    band = [band,]
                cStart, cStop = row['CHANS']
                if cStop == 0:
                    cStop = -1
                pol = row['PFLAGS'].astype(numpy.bool)
                
                if ant1 == 0 and ant2 == 0:
                    btmask = numpy.where( ( (maxtimes >= tStart) & (mintimes <= tStop) ) )[0]
                elif ant1 == 0 or ant2 == 0:
                    ant1 = max([ant1, ant2])
                    btmask = numpy.where( ( (bls_ant1 == ant1) | (bls_ant2 == ant1) ) \
                                          & ( (maxtimes >= tStart) & (mintimes <= tStop) ) )[0]
                else:
                    btmask = numpy.where( ( (bls_ant1 == ant1) & (bls_ant2 == ant2) ) \
                                          & ( (maxtimes >= tStart) & (mintimes <= tStop) ) )[0]
                for b,v in enumerate(band):
                    if not v:
                        continue
                    mask[btmask,b,cStart-1:cStop,:] |= pol
                    
        # Make sure the reference antenna is in there
        if first:
            if args.ref_ant is None:
                bl = bls[0]
                i,j = (bl>>8)&0xFF, bl&0xFF
                args.ref_ant = i
            else:
                found = False
                for bl in bls:
                    i,j = (bl>>8)&0xFF, bl&0xFF
                    if i == args.ref_ant:
                        found = True
                        break
                if not found:
                    raise RuntimeError("Cannot file reference antenna %i in the data" % args.ref_ant)
                    
        plot_bls = []
        cross = []
        for i in xrange(len(ubls)):
            bl = ubls[i]
            ant1, ant2 = (bl>>8)&0xFF, bl&0xFF 
            if ant1 != ant2:
                if args.baseline is not None:
                    if (ant1,ant2) in args.baseline:
                        plot_bls.append( bl )
                        cross.append( i )
                elif args.ref_ant is not None:
                    if ant1 == args.ref_ant or ant2 == args.ref_ant:
                        plot_bls.append( bl )
                        cross.append( i )
                else:
                    plot_bls.append( bl )
                    cross.append( i )
        nBL = len(cross)
        
        # Decimation, if needed
        if args.decimate > 1:
            if nFreq % args.decimate != 0:
                raise RuntimeError("Invalid freqeunce decimation factor:  %i %% %i = %i" % (nFreq, args.decimate, nFreq%args.decimate))

            nFreq //= args.decimate
            freq.shape = (freq.size//args.decimate, args.decimate)
            freq = freq.mean(axis=1)
            
            flux.shape = (flux.shape[0], flux.shape[1], flux.shape[2]//args.decimate, args.decimate, flux.shape[3])
            flux = flux.mean(axis=3)
            
            mask.shape = (mask.shape[0], mask.shape[1], mask.shape[2]//args.decimate, args.decimate, mask.shape[3])
            mask = mask.mean(axis=3)
            
        good = numpy.arange(freq.size//8, freq.size*7//8)		# Inner 75% of the band
        
        iSize = int(round(args.interval/robust.mean(inttimes)))
        print(" -> Chunk size is %i intervals (%.3f seconds)" % (iSize, iSize*robust.mean(inttimes)))
        iCount = times.size//iSize
        print(" -> Working with %i chunks of data" % iCount)
        
        print("Number of frequency channels: %i (~%.1f Hz/channel)" % (len(freq), freq[1]-freq[0]))

        dTimes = times - times[0]
        if first:
            ref_time = (int(times[0]) / 60) * 60
            
        dMax = 1.0/(freq[1]-freq[0])/4
        dMax = int(dMax*1e6)*1e-6
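        # NOTE: the delay window is capped at a quarter of 1/(channel width) and,
        # below, the rate window at a quarter of 1/(integration time), keeping the
        # search well inside the limits set by the frequency/time sampling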
        if -dMax*1e6 > args.delay_window[0]:
            args.delay_window[0] = -dMax*1e6
        if dMax*1e6 < args.delay_window[1]:
            args.delay_window[1] = dMax*1e6
        rMax = 1.0/robust.mean(inttimes)/4
        rMax = int(rMax*1e2)*1e-2
        if -rMax*1e3 > args.rate_window[0]:
            args.rate_window[0] = -rMax*1e3
        if rMax*1e3 < args.rate_window[1]:
            args.rate_window[1] = rMax*1e3
            
        dres = 1.0
        nDelays = int((args.delay_window[1]-args.delay_window[0])/dres)
        while nDelays < 50:
            dres /= 10
            nDelays = int((args.delay_window[1]-args.delay_window[0])/dres)
        while nDelays > 5000:
            dres *= 10
            nDelays = int((args.delay_window[1]-args.delay_window[0])/dres)
        nDelays += (nDelays + 1) % 2
        
        rres = 10.0
        nRates = int((args.rate_window[1]-args.rate_window[0])/rres)
        while nRates < 50:
            rres /= 10
            nRates = int((args.rate_window[1]-args.rate_window[0])/rres)
        while nRates > 5000:
            rres *= 10
            nRates = int((args.rate_window[1]-args.rate_window[0])/rres)
        nRates += (nRates + 1) % 2
        
        print("Searching delays %.1f to %.1f us in steps of %.2f us" % (args.delay_window[0], args.delay_window[1], dres))
        print("           rates %.1f to %.1f mHz in steps of %.2f mHz" % (args.rate_window[0], args.rate_window[1], rres))
        print(" ")
        
        delay = numpy.linspace(args.delay_window[0]*1e-6, args.delay_window[1]*1e-6, nDelays)		# s
        drate = numpy.linspace(args.rate_window[0]*1e-3,  args.rate_window[1]*1e-3,  nRates )		# Hz
        
        # Find RFI and trim it out.  This is done by computing average visibility
        # amplitudes (a "spectrum") and running a median filter in frequency to extract
        # the bandpass.  After the spectrum has been corrected for the bandpass, 3-sigma
        # features are trimmed.  Additionally, areas where the bandpass falls below 10%
        # of its mean value are also masked.
        spec  = numpy.median(numpy.abs(flux[:,0,:,0]), axis=0)
        spec += numpy.median(numpy.abs(flux[:,0,:,1]), axis=0)
        smth = spec*0.0
        winSize = int(250e3/(freq[1]-freq[0]))
        winSize += ((winSize+1)%2)
        for i in xrange(smth.size):
            mn = max([0, i-winSize//2])
            mx = min([i+winSize, smth.size])
            smth[i] = numpy.median(spec[mn:mx])
        smth /= robust.mean(smth)
        bp = spec / smth
        good = numpy.where( (smth > 0.1) & (numpy.abs(bp-robust.mean(bp)) < 3*robust.std(bp)) )[0]
        nBad = nFreq - len(good)
        print("Masking %i of %i channels (%.1f%%)" % (nBad, nFreq, 100.0*nBad/nFreq))
        
        freq2 = freq*1.0
        freq2.shape += (1,)
        
        dirName = os.path.basename( filename )
        for b in xrange(len(plot_bls)):
            bl = plot_bls[b]
            valid = numpy.where( bls == bl )[0]
            i,j = (bl>>8)&0xFF, bl&0xFF
            dTimes = obsdates[valid] + obstimes[valid]
            dTimes = numpy.array([utcjd_to_unix(v) for v in dTimes])
            
            ## Skip over baselines that are not in the baseline list (if provided)
            if args.baseline is not None:
                if (i,j) not in args.baseline:
                    continue
            ## Skip over baselines that don't include the reference antenna
            elif i != args.ref_ant and j != args.ref_ant:
                continue
                
            ## Check and see if we need to conjugate the visibility, i.e., switch from
            ## baseline (*,ref) to baseline (ref,*)
            doConj = False
            if j == args.ref_ant:
                doConj = True
                
            ## Figure out which polarizations to process
            if args.cross_hands:
                polToUse = ('XX', 'XY', 'YY')
                visToUse = (0, 2, 1)
            else:
                polToUse = ('XX', 'YY')
                visToUse = (0, 1)
                
            blName = (i, j)
            if doConj:
                blName = (j, i)
            blName = '%s-%s' % ('EA%02i' % blName[0] if blName[0] < 51 else 'LWA%i' % (blName[0]-50), 
                                'EA%02i' % blName[1] if blName[1] < 51 else 'LWA%i' % (blName[1]-50))
                            
            if first or blName not in figs:
                fig = plt.figure()
                fig.suptitle('%s' % blName)
                fig.subplots_adjust(hspace=0.001)
                axR = fig.add_subplot(2, 1, 1)
                axD = fig.add_subplot(2, 1, 2, sharex=axR)
                figs[blName] = (fig, axR, axD)
            fig, axR, axD = figs[blName]
            
            markers = {'XX':'s', 'YY':'o', 'XY':'v', 'YX':'^'}
            
            for pol,vis in zip(polToUse, visToUse):
                for i in xrange(iCount):
                    subStart, subStop = dTimes[iSize*i], dTimes[iSize*(i+1)-1]
                    if (subStop - subStart) > 1.1*args.interval:
                        continue
                        
                    subTime = dTimes[iSize*i:iSize*(i+1)].mean()
                    dTimes2 = dTimes[iSize*i:iSize*(i+1)]*1.0
                    dTimes2 -= dTimes2[0]
                    dTimes2.shape += (1,)
                    subData = flux[valid,...][iSize*i:iSize*(i+1),0,good,vis]*1.0
                    subPhase = flux[valid,...][iSize*i:iSize*(i+1),0,good,vis]*1.0
                    if doConj:
                        subData = subData.conj()
                        subPhase = subPhase.conj()
                    subData = numpy.dot(subData, numpy.exp(-2j*numpy.pi*freq2[good,:]*delay))
                    subData /= freq2[good,:].size
                    amp = numpy.dot(subData.T, numpy.exp(-2j*numpy.pi*dTimes2*drate))
                    amp = numpy.abs(amp / dTimes2.size)
                    
                    subPhase = numpy.angle(subPhase.mean()) * 180/numpy.pi
                    subPhase %= 360
                    if subPhase > 180:
                        subPhase -= 360
                        
                    best = numpy.where( amp == amp.max() )
                    if amp.max() > 0:
                        bsnr = (amp[best]-amp.mean())[0]/amp.std()
                        bdly = delay[best[0][0]]*1e6
                        brat = drate[best[1][0]]*1e3
                        
                        c = axR.scatter(subTime-ref_time, brat, c=bsnr, marker=markers[pol],
                                        cmap='gist_yarg', vmin=3, vmax=40)
                        c = axD.scatter(subTime-ref_time, bdly, c=bsnr, marker=markers[pol],
                                        cmap='gist_yarg', vmin=3, vmax=40)
                        
        first = False
        
    for blName in figs:
        fig, axR, axD = figs[blName]
        
        # Colorbar
        cb = fig.colorbar(c, ax=axR, orientation='horizontal')
        cb.set_label('SNR')
        # Legend and reference marks
        handles = []
        for pol in polToUse:
            handles.append(Line2D([0,], [0,], linestyle='', marker=markers[pol], color='k', label=pol))
        axR.legend(handles=handles, loc=0)
        oldLim = axR.get_xlim()
        for ax in (axR, axD):
            ax.hlines(0, oldLim[0], oldLim[1], linestyle=':', alpha=0.5)
        axR.set_xlim(oldLim)
        # Set the labels
        axR.set_ylabel('Rate [mHz]')
        axD.set_ylabel('Delay [$\\mu$s]')
        for ax in (axR, axD):
            ax.set_xlabel('Elapsed Time [s since %s]' % datetime.utcfromtimestamp(ref_time).strftime('%Y%b%d %H:%M'))
        # Set the y ranges
        axR.set_ylim((-max([100, max([abs(v) for v in axR.get_ylim()])]), max([100, max([abs(v) for v in axR.get_ylim()])])))
        axD.set_ylim((-max([0.5, max([abs(v) for v in axD.get_ylim()])]), max([0.5, max([abs(v) for v in axD.get_ylim()])])))
        # No-go regions for the delays
        xlim, ylim = axD.get_xlim(), axD.get_ylim()
        axD.add_patch(Box(xy=(xlim[0],ylim[0]), width=xlim[1]-xlim[0], height=-0.5001-ylim[0],
                          fill=True, color='red', alpha=0.2))
        axD.add_patch(Box(xy=(xlim[0],0.5001), width=xlim[1]-xlim[0], height=ylim[1]-0.5001,
                          fill=True, color='red', alpha=0.2))
        axD.set_xlim(xlim)
        axD.set_ylim(ylim)
        
        fig.tight_layout()
        plt.draw()
        
        if args.save_images:
            fig.savefig('sniffer-%s.png' % blName)
            
    if not args.save_images:
        plt.show()
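The core of the sniffer above is a two-step delay/rate fringe search: the visibilities are first collapsed in frequency against a grid of trial delays, then in time against a grid of trial rates, and the peak of the resulting amplitude surface gives the fringe solution. Below is a minimal, self-contained sketch of that search on synthetic data; the grid ranges and the injected 10 us delay / 50 mHz rate are illustrative assumptions, not values taken from the script.

import numpy

# Synthetic single-baseline data with a known delay and fringe rate
nChan, nInt = 256, 64
freq  = 60e6 + 25e3*numpy.arange(nChan)                    # Hz
times = 1.0*numpy.arange(nInt)                             # s
trueDelay, trueRate = 10e-6, 0.05                          # s, Hz
vis = numpy.exp(2j*numpy.pi*(trueRate*times[:,None] + trueDelay*freq[None,:]))

# Trial grids (analogous to the 'delay' and 'drate' arrays above)
delays = numpy.linspace(-20e-6, 20e-6, 801)                # s
rates  = numpy.linspace(-0.2, 0.2, 401)                    # Hz

# Collapse frequency against the delay trials, then time against the rate trials
step1 = numpy.dot(vis, numpy.exp(-2j*numpy.pi*freq[:,None]*delays)) / nChan
amp   = numpy.abs(numpy.dot(step1.T, numpy.exp(-2j*numpy.pi*times[:,None]*rates)) / nInt)

best = numpy.unravel_index(amp.argmax(), amp.shape)
print("Peak at %.2f us, %.1f mHz" % (delays[best[0]]*1e6, rates[best[1]]*1e3))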
Example #7
def vla_to_unix(timetag):
    """
    Convert a VLA timetag in MJD nanoseconds to a UNIX timestamp.
    """

    return utcjd_to_unix(vla_to_utcmjd(timetag) + MJD_OFFSET)
def main(args):
    reference = args.ref_source
    filenames = args.filename
    
    #
    # Gather the station meta-data from its various sources
    #
    dataDict = numpy.load(filenames[0])
    ssmifContents = dataDict['ssmifContents']
    if ssmifContents.shape == ():
        site = lwa1
    else:
        fd, tempSSMIF = tempfile.mkstemp(suffix='.txt', prefix='ssmif-')
        fh = os.fdopen(fd, 'w')
        for line in ssmifContents:
            fh.write('%s\n' % line)
        fh.close()
        
        site = parse_ssmif(tempSSMIF)
        os.unlink(tempSSMIF)
    print(site.name)
    observer = site.get_observer()
    antennas = site.antennas
    nAnts = len(antennas)
    
    #
    # Find the reference source
    #
    srcs = [ephem.Sun(),]
    for line in _srcs:
        srcs.append( ephem.readdb(line) )
        
    refSrc = None
    for i in xrange(len(srcs)):
        if srcs[i].name == reference:
            refSrc = srcs[i]
            
    if refSrc is None:
        print("Cannot find reference source '%s' in source list, aborting." % reference)
        sys.exit(1)
        
    #
    # Parse the input files
    #
    data = []
    time = []
    freq = []
    oldRef = None
    oldMD5 = None
    maxTime = -1
    for filename in filenames:
        dataDict = numpy.load(filename)
        
        ref_ant = dataDict['ref'].item()
        refX   = dataDict['refX'].item()
        refY   = dataDict['refY'].item()
        tInt = dataDict['tInt'].item()
        
        times = dataDict['times']
        phase = dataDict['simpleVis']
        
        central_freq = dataDict['central_freq'].item()
        
        ssmifContents = dataDict['ssmifContents']
        
        beginDate = datetime.utcfromtimestamp(times[0])
        observer.date = beginDate.strftime("%Y/%m/%d %H:%M:%S")
        
        # Make sure we aren't mixing reference antennas
        if oldRef is None:
            oldRef = ref_ant
        if ref_ant != oldRef:
            raise RuntimeError("Dataset has different reference antennas than previous (%i != %i)" % (ref_ant, oldRef))
            
        # Make sure we aren't mixing SSMIFs
        ssmifMD5 = md5sum(ssmifContents)
        if oldMD5 is None:
            oldMD5 = ssmifMD5
        if ssmifMD5 != oldMD5:
            raise RuntimeError("Dataset has different SSMIF than previous (%s != %s)" % (ssmifMD5, oldMD5))
            
        print("Central Frequency: %.3f Hz" % central_freq)
        print("Start date/time: %s" % beginDate.strftime("%Y/%m/%d %H:%M:%S"))
        print("Integration Time: %.3f s" % tInt)
        print("Number of time samples: %i (%.3f s)" % (phase.shape[0], phase.shape[0]*tInt))
        
        allRates = {}
        for src in srcs:
            src.compute(observer)
            if src.alt > 0:
                fRate = getFringeRate(antennas[0], antennas[refX], observer, src, central_freq)
                allRates[src.name] = fRate
        # Calculate the fringe rates of all sources - for display purposes only
        print("Starting Fringe Rates:")
        for name in allRates.keys():
            fRate = allRates[name]
            print(" %-4s: %+6.3f mHz" % (name, fRate*1e3))
            
        freq.append( central_freq )
        time.append( numpy.array([unix_to_utcjd(t) for t in times]) )
        data.append( phase )
        
        ## Keep track of the longest `time` entry so that we can pad them
        ## all out to the same size later
        if time[-1].size > maxTime:
            maxTime = time[-1].size
            
    # Pad with NaNs to the same length
    for i in xrange(len(filenames)):
        nTimes = time[i].size
        
        if nTimes < maxTime:
            ## Pad 'time'
            newTime = numpy.zeros(maxTime, dtype=time[i].dtype)
            newTime += numpy.nan
            newTime[0:nTimes] = time[i][:]
            time[i] = newTime
            
            ## Pad 'data'
            newData = numpy.zeros((maxTime, data[i].shape[1]), dtype=data[i].dtype)
            newData += numpy.nan
            newData[0:nTimes,:] = data[i][:,:]
            data[i] = newData
            
    # Convert to 2-D and 3-D numpy arrays
    freq = numpy.array(freq)
    time = numpy.array(time)
    data = numpy.array(data)
    
    #
    # Sort the data by frequency
    #
    order = numpy.argsort(freq)
    freq = numpy.take(freq, order)
    time = numpy.take(time, order, axis=0)
    data = numpy.take(data, order, axis=0)
    
    # 
    # Find the fringe stopping averaging times
    #
    ls = {}
    for fStart in xrange(20, 90, 5):
        fStop = fStart + 5
        l = numpy.where( (freq >= fStart*1e6) & (freq < fStop*1e6) )[0]
        if len(l) > 0:
            ls[fStart] = l
            
    ms = {}
    for fStart in ls.keys():
        m = 1e6
        for l in ls[fStart]:
            good = numpy.where( numpy.isfinite(time[l,:]) == 1 )[0]
            if len(good) < m:
                m = len(good)
        ms[fStart] = m
        
    print("Minimum fringe stopping times:")
    for fStart in sorted(ls.keys()):
        fStop = fStart + 5
        m = ms[fStart]
        print("  >=%i Mhz and <%i MHz: %.3f s" % (fStart, fStop, m*tInt,))
        
    #
    # Report on progress and data coverage
    #
    nFreq = len(freq)
    
    print("Reference stand #%i (X: %i, Y: %i)" % (ref_ant, refX, refY))
    print("-> X: %s" % str(antennas[refX]))
    print("-> Y: %s" % str(antennas[refY]))
    
    print("Using a set of %i frequencies" % nFreq)
    print("->", freq/1e6)
    
    #
    # Compute source positions/fringe stop and remove the source
    #
    print("Fringe stopping on '%s':" % refSrc.name)
    pbar = ProgressBar(max=freq.size*520)
    
    for i in xrange(freq.size):
        fq = freq[i]
        
        for j in xrange(data.shape[2]):
            # Compute the times in seconds relative to the beginning
            times  = time[i,:] - time[i,0]
            times *= 24.0
            times *= 3600.0
            
            # Compute the fringe rates across all time
            fRate = [None,]*data.shape[1]
            for k in xrange(data.shape[1]):
                jd = time[i,k]
                
                try:
                    currDate = datetime.utcfromtimestamp(utcjd_to_unix(jd))
                except ValueError:
                    pass
                observer.date = currDate.strftime("%Y/%m/%d %H:%M:%S")
                refSrc.compute(observer)
        
                if j % 2 == 0:
                    fRate[k] = getFringeRate(antennas[j], antennas[refX], observer, refSrc, fq)
                else:
                    fRate[k] = getFringeRate(antennas[j], antennas[refY], observer, refSrc, fq)
                    
            # Create the basis rate and the residual rates
            baseRate = fRate[0]
            residRate = numpy.array(fRate) - baseRate
        
            # Fringe stop to move the source of interest to the DC component
            data[i,:,j] *= numpy.exp(-2j*numpy.pi* baseRate*(times - times[0]))
            data[i,:,j] *= numpy.exp(-2j*numpy.pi*residRate*(times - times[0]))
            
            # Calculate the geometric delay term across all time
            gDelay = [None,]*data.shape[1]
            for k in xrange(data.shape[1]):
                jd = time[i,k]
                
                try:
                    currDate = datetime.utcfromtimestamp(utcjd_to_unix(jd))
                except ValueError:
                    pass
                observer.date = currDate.strftime("%Y/%m/%d %H:%M:%S")
                refSrc.compute(observer)
                
                az = refSrc.az
                el = refSrc.alt
                if j % 2 == 0:
                    gDelay[k] = getGeoDelay(antennas[j], antennas[refX], az, el, Degrees=False)
                else:
                    gDelay[k] = getGeoDelay(antennas[j], antennas[refY], az, el, Degrees=False)
                    
            # Create the basis delay and the residual delays
            baseDelay = gDelay[0]
            residDelay = numpy.array(gDelay) - baseDelay
            
            # Remove the array geometry
            data[i,:,j] *= numpy.exp(-2j*numpy.pi*fq* baseDelay)
            data[i,:,j] *= numpy.exp(-2j*numpy.pi*fq*residDelay)
            
            pbar.inc()
            sys.stdout.write("%s\r" % pbar.show())
            sys.stdout.flush()
    sys.stdout.write('\n')
    
    # Average down to remove other sources/the correlated sky
    print("Input (pre-averaging) data shapes:")
    print("  time:", time.shape)
    print("  data:", data.shape)
    time = time[:,0]
    
    data2 = numpy.zeros((data.shape[0], data.shape[2]), dtype=data.dtype)
    for j in xrange(data2.shape[1]):
        for fStart in ls.keys():
            l = ls[fStart]
            m = ms[fStart]
            data2[l,j] = data[l,:m,j].mean(axis=1)
    data = data2
    print("Output (post-averaging) data shapes:")
    print("  time:", time.shape)
    print("  data:", data.shape)

    #
    # Save
    #
    outname = args.output
    outname, ext = os.path.splitext(outname)
    outname = "%s-ref%03i%s" % (outname, ref_ant, ext)
    numpy.savez(outname, ref_ant=ref_ant, refX=refX, refY=refY, freq=freq, time=time, data=data, ssmifContents=ssmifContents)
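The key step in this example is fringe stopping: multiplying each visibility series by exp(-2j*pi*rate*t) (plus the geometric-delay phase) moves the reference source to zero fringe frequency, so a plain time average becomes coherent. A toy illustration of that effect, with an assumed 25 mHz fringe rate, is sketched below.

import numpy

# A single-baseline phase stream fringing at an assumed 25 mHz
tInt  = 1.0
times = tInt*numpy.arange(600)                     # s
rate  = 0.025                                      # Hz
vis   = numpy.exp(2j*numpy.pi*rate*times)

# Stopping the fringe moves the signal to DC, so averaging no longer decorrelates
stopped = vis * numpy.exp(-2j*numpy.pi*rate*times)
print("un-stopped |mean|: %.3f" % numpy.abs(vis.mean()))      # ~0
print("   stopped |mean|: %.3f" % numpy.abs(stopped.mean()))  # 1.000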
Example #9
def main(args):
    # Get the site and observer
    site = stations.lwa1
    observer = site.get_observer()

    # Filenames in an easier format
    inputTGZ = args.filename

    # Parse the input file and get the dates of the observations.  By default
    # this is for LWA1 but we switch over to LWA-SV if an error occurs.
    try:
        # LWA1
        project = metabundle.get_sdf(inputTGZ)
        obsImpl = metabundle.get_observation_spec(inputTGZ)
        fileInfo = metabundle.get_session_metadata(inputTGZ)
        aspConfigB = metabundle.get_asp_configuration_summary(
            inputTGZ, which='Beginning')
        aspConfigE = metabundle.get_asp_configuration_summary(inputTGZ,
                                                              which='End')
    except:
        # LWA-SV
        ## Site changes
        site = stations.lwasv
        observer = site.get_observer()
        ## Try again
        project = metabundleADP.get_sdf(inputTGZ)
        obsImpl = metabundleADP.get_observation_spec(inputTGZ)
        fileInfo = metabundleADP.get_session_metadata(inputTGZ)
        aspConfigB = metabundleADP.get_asp_configuration_summary(
            inputTGZ, which='Beginning')
        aspConfigE = metabundleADP.get_asp_configuration_summary(inputTGZ,
                                                                 which='End')

    nObs = len(project.sessions[0].observations)
    tStart = [
        None,
    ] * nObs
    for i in range(nObs):
        tStart[i] = utcjd_to_unix(project.sessions[0].observations[i].mjd +
                                  MJD_OFFSET)
        tStart[i] += project.sessions[0].observations[i].mpm / 1000.0
        tStart[i] = datetime.utcfromtimestamp(tStart[i])
        tStart[i] = _UTC.localize(tStart[i])

    # Get the LST at the start
    observer.date = (min(tStart)).strftime('%Y/%m/%d %H:%M:%S')
    lst = observer.sidereal_time()

    # Report on the file
    print("Filename: %s" % inputTGZ)
    print(" Project ID: %s" % project.id)
    print(" Session ID: %i" % project.sessions[0].id)
    print(" Observations appear to start at %s" %
          (min(tStart)).strftime(_FORMAT_STRING))
    print(" -> LST at %s for this date/time is %s" % (site.name, lst))

    lastDur = project.sessions[0].observations[nObs - 1].dur
    lastDur = timedelta(seconds=int(lastDur / 1000),
                        microseconds=(lastDur * 1000) % 1000000)
    sessionDur = max(tStart) - min(tStart) + lastDur

    print(" ")
    print(" Total Session Duration: %s" % sessionDur)
    print(" -> First observation starts at %s" %
          min(tStart).strftime(_FORMAT_STRING))
    print(" -> Last observation ends at %s" %
          (max(tStart) + lastDur).strftime(_FORMAT_STRING))
    if project.sessions[0].observations[0].mode not in ('TBW', 'TBN'):
        drspec = 'No'
        if project.sessions[0].spcSetup[0] != 0 and project.sessions[
                0].spcSetup[1] != 0:
            drspec = 'Yes'
        drxBeam = project.sessions[0].drx_beam
        if drxBeam < 1:
            drxBeam = "MCS decides"
        else:
            drxBeam = "%i" % drxBeam
        print(" DRX Beam: %s" % drxBeam)
        print(" DR Spectrometer used? %s" % drspec)
        if drspec == 'Yes':
            print(" -> %i channels, %i windows/integration" %
                  tuple(project.sessions[0].spcSetup))
    else:
        tbnCount = 0
        tbwCount = 0
        for obs in project.sessions[0].observations:
            if obs.mode == 'TBW':
                tbwCount += 1
            else:
                tbnCount += 1
        if tbwCount > 0 and tbnCount == 0:
            print(" Transient Buffer Mode: TBW")
        elif tbwCount == 0 and tbnCount > 0:
            print(" Transient Buffer Mode: TBN")
        else:
            print(" Transient Buffer Mode: both TBW and TBN")
    print(" ")
    print("File Information:")
    for obsID in fileInfo.keys():
        print(" Obs. #%i: %s" % (obsID, fileInfo[obsID]['tag']))

    print(" ")
    print("ASP Configuration:")
    print('  Beginning')
    for k, v in aspConfigB.items():
        print('    %s: %i' % (k, v))
    print('  End')
    for k, v in aspConfigE.items():
        print('    %s: %i' % (k, v))

    print(" ")
    print(" Number of observations: %i" % nObs)
    print(" Observation Detail:")
    for i in range(nObs):
        currDur = project.sessions[0].observations[i].dur
        currDur = timedelta(seconds=int(currDur / 1000),
                            microseconds=(currDur * 1000) % 1000000)

        print("  Observation #%i" % (i + 1, ))
        currObs = None
        for j in range(len(obsImpl)):
            if obsImpl[j]['obsID'] == i + 1:
                currObs = obsImpl[j]
                break

        ## Basic setup
        print("   Target: %s" % project.sessions[0].observations[i].target)
        print("   Mode: %s" % project.sessions[0].observations[i].mode)
        print("   Start:")
        print("    MJD: %i" % project.sessions[0].observations[i].mjd)
        print("    MPM: %i" % project.sessions[0].observations[i].mpm)
        print("    -> %s" % get_observation_start_stop(
            project.sessions[0].observations[i])[0].strftime(_FORMAT_STRING))
        print("   Duration: %s" % currDur)

        ## DP setup
        if project.sessions[0].observations[i].mode not in ('TBW', ):
            print("   Tuning 1: %.3f MHz" %
                  (project.sessions[0].observations[i].frequency1 / 1e6, ))
        if project.sessions[0].observations[i].mode not in ('TBW', 'TBN'):
            print("   Tuning 2: %.3f MHz" %
                  (project.sessions[0].observations[i].frequency2 / 1e6, ))
        if project.sessions[0].observations[i].mode not in ('TBW', ):
            print("   Filter code: %i" %
                  project.sessions[0].observations[i].filter)
        if currObs is not None:
            if project.sessions[0].observations[i].mode not in ('TBW', ):
                if project.sessions[0].observations[i].mode == 'TBN':
                    print("   Gain setting: %i" % currObs['tbnGain'])
                else:
                    print("   Gain setting: %i" % currObs['drxGain'])
        else:
            print(
                "   WARNING: observation specification not found for this observation"
            )

        ## Comments/notes
        print("   Observer Comments: %s" %
              project.sessions[0].observations[i].comments)
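The durations reported above are stored as integer milliseconds and converted to timedelta objects with the seconds/microseconds split used in the code. A worked example of that conversion:

from datetime import timedelta

# e.g. a duration of 12345 ms splits into 12 s and 345000 us
dur = 12345
td = timedelta(seconds=int(dur/1000), microseconds=(dur*1000) % 1000000)
print(td)    # 0:00:12.345000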
Example #10
def main(args):
    # Parse the command line
    filenames = args.filename

    for filename in filenames:
        t0 = time.time()
        print("Working on '%s'" % os.path.basename(filename))
        # Open the FITS IDI file and access the UV_DATA extension
        hdulist = astrofits.open(filename, mode='readonly')
        andata = hdulist['ANTENNA']
        fqdata = hdulist['FREQUENCY']
        srdata = hdulist['SOURCE']
        fgdata = None
        for hdu in hdulist[1:]:
            if hdu.header['EXTNAME'] == 'FLAG':
                fgdata = hdu
        uvdata = hdulist['UV_DATA']

        # Verify we can flag this data
        if uvdata.header['STK_1'] > 0:
            raise RuntimeError("Cannot flag data with STK_1 = %i" %
                               uvdata.header['STK_1'])
        if uvdata.header['NO_STKD'] < 4:
            raise RuntimeError("Cannot flag data with NO_STKD = %i" %
                               uvdata.header['NO_STKD'])

        # NOTE: Assumes that the Stokes parameters increment by -1
        polMapper = {}
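        # e.g. STK_1 = -5 maps i = 0, 1, 2, 3 to numeric Stokes codes -5, -6, -7, -8,
        # which are 'XX', 'YY', 'XY', 'YX' in the usual FITS-IDI convention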
        for i in xrange(uvdata.header['NO_STKD']):
            stk = uvdata.header['STK_1'] - i
            polMapper[i] = NUMERIC_STOKES[stk]

        # Pull out various bits of information we need to flag the file
        ## Antenna look-up table
        antLookup = {}
        for an, ai in zip(andata.data['ANNAME'], andata.data['ANTENNA_NO']):
            antLookup[an] = ai
        ## Frequency and polarization setup
        nBand, nFreq, nStk = uvdata.header['NO_BAND'], uvdata.header[
            'NO_CHAN'], uvdata.header['NO_STKD']
        ## Baseline list
        bls = uvdata.data['BASELINE']
        ## Time of each integration
        obsdates = uvdata.data['DATE']
        obstimes = uvdata.data['TIME']
        inttimes = uvdata.data['INTTIM']
        ## Source list
        srcs = uvdata.data['SOURCE']
        ## Band information
        fqoffsets = fqdata.data['BANDFREQ'].ravel()
        ## Frequency channels
        freq = (numpy.arange(nFreq) -
                (uvdata.header['CRPIX3'] - 1)) * uvdata.header['CDELT3']
        freq += uvdata.header['CRVAL3']
        ## The actual visibility data
        flux = uvdata.data['FLUX'].astype(numpy.float32)

        # Convert the visibilities to something that we can easily work with
        nComp = flux.shape[1] // nBand // nFreq // nStk
        if nComp == 2:
            ## Case 1) - Just real and imaginary data
            flux = flux.view(numpy.complex64)
        else:
            ## Case 2) - Real, imaginary data + weights (drop the weights)
            flux = flux[:, 0::nComp] + 1j * flux[:, 1::nComp]
        flux.shape = (flux.shape[0], nBand, nFreq, nStk)

        # Find unique baselines and scans to work on
        ubls = numpy.unique(bls)
        blocks = get_source_blocks(hdulist)

        # Create a mask of the old flags, if needed
        old_flag_mask = get_flags_as_mask(hdulist, version=0 - args.drop)

        # Dedisperse
        mask = numpy.zeros(flux.shape, dtype=numpy.bool)
        for i, block in enumerate(blocks):
            tS = time.time()
            print('  Working on scan %i of %i' % (i + 1, len(blocks)))
            match = range(block[0], block[1] + 1)

            bbls = numpy.unique(bls[match])
            times = obstimes[match] * 86400.0
            ints = inttimes[match]
            scanStart = datetime.utcfromtimestamp(
                utcjd_to_unix(obsdates[match[0]] + obstimes[match[0]]))
            scanStop = datetime.utcfromtimestamp(
                utcjd_to_unix(obsdates[match[-1]] + obstimes[match[-1]]))
            print('    Scan spans %s to %s UTC' %
                  (scanStart.strftime('%Y/%m/%d %H:%M:%S'),
                   scanStop.strftime('%Y/%m/%d %H:%M:%S')))

            freq_comb = []
            for b, offset in enumerate(fqoffsets):
                freq_comb.append(freq + offset)
            freq_comb = numpy.concatenate(freq_comb)

            nBL = len(bbls)
            vis = flux[match, :, :, :]
            ofm = old_flag_mask[match, :, :, :]

            ## If this is the last block, check and see if there is anything that
            ## we can pull out next file in the sequence so that we don't have a
            ## dedispersion gap
            to_trim = -1
            if i == len(blocks) - 1:
                src_name = srdata.data['SOURCE'][srcs[match[0]] - 1]
                nextmask, nextflux = get_trailing_scan(
                    filename, src_name, max(delay(freq_comb, args.DM)))
                if nextmask is not None and nextflux is not None:
                    to_trim = ofm.shape[0]
                    vis = numpy.concatenate([vis, nextflux])
                    ofm = numpy.concatenate([ofm, nextmask])
                    print(
                        '      Appended %i times from the next file in the sequence'
                        % (nextflux.shape[0] // nBL))

            vis.shape = (vis.shape[0] // nBL, nBL, vis.shape[1] * vis.shape[2],
                         vis.shape[3])
            ofm.shape = (ofm.shape[0] // nBL, nBL, ofm.shape[1] * ofm.shape[2],
                         ofm.shape[3])
            print(
                '      Scan contains %i times, %i baselines, %i bands/channels, %i polarizations'
                % vis.shape)

            if vis.shape[0] < 5:
                print('        Too few integrations, skipping')
                vis[:, :, :, :] = numpy.nan
                ofm[:, :, :, :] = True
            else:
                for j in xrange(nBL):
                    for k in xrange(nStk):
                        vis[:, j, :, k] = incoherent(freq_comb,
                                                     vis[:, j, :, k],
                                                     ints[0],
                                                     args.DM,
                                                     boundary='fill',
                                                     fill_value=numpy.nan)
                        ofm[:, j, :, k] = incoherent(freq_comb,
                                                     ofm[:, j, :, k],
                                                     ints[0],
                                                     args.DM,
                                                     boundary='fill',
                                                     fill_value=True)
            vis.shape = (vis.shape[0] * vis.shape[1], len(fqoffsets),
                         vis.shape[2] // len(fqoffsets), vis.shape[3])
            ofm.shape = (ofm.shape[0] * ofm.shape[1], len(fqoffsets),
                         ofm.shape[2] // len(fqoffsets), ofm.shape[3])

            if to_trim != -1:
                print('      Removing the appended times')
                vis = vis[:to_trim, ...]
                ofm = ofm[:to_trim, ...]
            flux[match, :, :, :] = vis

            print('      Saving polarization masks')
            submask = numpy.where(numpy.isfinite(vis), False, True)
            submask.shape = (len(match), flux.shape[1], flux.shape[2],
                             flux.shape[3])
            mask[match, :, :, :] = submask

            print('      Statistics for this scan')
            print('      -> %s      - %.1f%% flagged' % (
                polMapper[0],
                100.0 * mask[match, :, :, 0].sum() / mask[match, :, :, 0].size,
            ))
            print('      -> %s      - %.1f%% flagged' % (
                polMapper[1],
                100.0 * mask[match, :, :, 1].sum() / mask[match, :, :, 1].size,
            ))
            print('      -> Elapsed - %.3f s' % (time.time() - tS, ))

            # Add in the original flag mask
            mask[match, :, :, :] |= ofm

        # Convert the masks into a format suitable for writing to a FLAG table
        print("  Building FLAG table")
        ants, times, bands, chans, pols, reas, sevs = [], [], [], [], [], [], []
        ## New Flags
        nBL = len(ubls)
        for i in xrange(nBL):
            blset = numpy.where(bls == ubls[i])[0]
            ant1, ant2 = (ubls[i] >> 8) & 0xFF, ubls[i] & 0xFF
            if i % 100 == 0 or i + 1 == nBL:
                print("    Baseline %i of %i" % (i + 1, nBL))

            if len(blset) == 0:
                continue

            for b, offset in enumerate(fqoffsets):
                maskXX = mask[blset, b, :, 0]
                maskYY = mask[blset, b, :, 1]

                flagsXX, _ = create_flag_groups(obstimes[blset], freq + offset,
                                                maskXX)
                flagsYY, _ = create_flag_groups(obstimes[blset], freq + offset,
                                                maskYY)

                for flag in flagsXX:
                    ants.append((ant1, ant2))
                    times.append(
                        (obsdates[blset[flag[0]]] + obstimes[blset[flag[0]]] -
                         obsdates[0], obsdates[blset[flag[1]]] +
                         obstimes[blset[flag[1]]] - obsdates[0]))
                    bands.append([1 if j == b else 0 for j in xrange(nBand)])
                    chans.append((flag[2] + 1, flag[3] + 1))
                    pols.append((1, 0, 1, 1))
                    reas.append('DEDISPERSEIDI.PY')
                    sevs.append(-1)
                for flag in flagsYY:
                    ants.append((ant1, ant2))
                    times.append(
                        (obsdates[blset[flag[0]]] + obstimes[blset[flag[0]]] -
                         obsdates[0], obsdates[blset[flag[1]]] +
                         obstimes[blset[flag[1]]] - obsdates[0]))
                    bands.append([1 if j == b else 0 for j in xrange(nBand)])
                    chans.append((flag[2] + 1, flag[3] + 1))
                    pols.append((0, 1, 1, 1))
                    reas.append('DEDISPERSEIDI.PY')
                    sevs.append(-1)

        ## Figure out our revision
        try:
            repo = git.Repo(os.path.dirname(os.path.abspath(__file__)))
            try:
                branch = repo.active_branch.name
                hexsha = repo.active_branch.commit.hexsha
            except TypeError:
                branch = '<detached>'
                hexsha = repo.head.commit.hexsha
            shortsha = hexsha[-7:]
            dirty = ' (dirty)' if repo.is_dirty() else ''
        except git.exc.GitError:
            branch = 'unknown'
            hexsha = 'unknown'
            shortsha = 'unknown'
            dirty = ''

        ## Build the FLAG table
        print('    FITS HDU')
        ### Columns
        nFlags = len(ants)
        c1 = astrofits.Column(name='SOURCE_ID',
                              format='1J',
                              array=numpy.zeros((nFlags, ), dtype=numpy.int32))
        c2 = astrofits.Column(name='ARRAY',
                              format='1J',
                              array=numpy.zeros((nFlags, ), dtype=numpy.int32))
        c3 = astrofits.Column(name='ANTS',
                              format='2J',
                              array=numpy.array(ants, dtype=numpy.int32))
        c4 = astrofits.Column(name='FREQID',
                              format='1J',
                              array=numpy.zeros((nFlags, ), dtype=numpy.int32))
        c5 = astrofits.Column(name='TIMERANG',
                              format='2E',
                              array=numpy.array(times, dtype=numpy.float32))
        c6 = astrofits.Column(name='BANDS',
                              format='%iJ' % nBand,
                              array=numpy.array(bands,
                                                dtype=numpy.int32).squeeze())
        c7 = astrofits.Column(name='CHANS',
                              format='2J',
                              array=numpy.array(chans, dtype=numpy.int32))
        c8 = astrofits.Column(name='PFLAGS',
                              format='4J',
                              array=numpy.array(pols, dtype=numpy.int32))
        c9 = astrofits.Column(name='REASON',
                              format='A40',
                              array=numpy.array(reas))
        c10 = astrofits.Column(name='SEVERITY',
                               format='1J',
                               array=numpy.array(sevs, dtype=numpy.int32))
        colDefs = astrofits.ColDefs([c1, c2, c3, c4, c5, c6, c7, c8, c9, c10])
        ### The table itself
        flags = astrofits.BinTableHDU.from_columns(colDefs)
        ### The header
        flags.header['EXTNAME'] = ('FLAG', 'FITS-IDI table name')
        flags.header['EXTVER'] = (1 if fgdata is None else
                                  fgdata.header['EXTVER'] + 1,
                                  'table instance number')
        flags.header['TABREV'] = (2, 'table format revision number')
        for key in ('NO_STKD', 'STK_1', 'NO_BAND', 'NO_CHAN', 'REF_FREQ',
                    'CHAN_BW', 'REF_PIXL', 'OBSCODE', 'ARRNAM', 'RDATE'):
            try:
                flags.header[key] = (uvdata.header[key],
                                     uvdata.header.comments[key])
            except KeyError:
                pass
        flags.header['HISTORY'] = 'Flagged with %s, revision %s.%s%s' % (
            os.path.basename(__file__), branch, shortsha, dirty)
        flags.header['HISTORY'] = 'Dedispersed at %.6f pc / cm^3' % args.DM

        # Clean up the old FLAG tables, if any, and then insert the new table where it needs to be
        if args.drop:
            ## Reset the EXTVER on the new FLAG table
            flags.header['EXTVER'] = (1, 'table instance number')
            ## Find old tables
            toRemove = []
            for hdu in hdulist:
                try:
                    if hdu.header['EXTNAME'] == 'FLAG':
                        toRemove.append(hdu)
                except KeyError:
                    pass
            ## Remove old tables
            for hdu in toRemove:
                ver = hdu.header['EXTVER']
                del hdulist[hdulist.index(hdu)]
                print("  WARNING: removing old FLAG table - version %i" % ver)
        ## Insert the new table right before UV_DATA
        hdulist.insert(-1, flags)

        # Save
        print("  Saving to disk")
        ## What to call it
        outname = os.path.basename(filename)
        outname, outext = os.path.splitext(outname)
        outname = '%s_DM%.4f%s' % (outname, args.DM, outext)
        ## Does it already exist or not
        if os.path.exists(outname):
            if not args.force:
                yn = raw_input("WARNING: '%s' exists, overwrite? [Y/n] " %
                               outname)
            else:
                yn = 'y'

            if yn not in ('n', 'N'):
                os.unlink(outname)
            else:
                raise RuntimeError("Output file '%s' already exists" % outname)
        ## Open and create a new primary HDU
        hdulist2 = astrofits.open(outname, mode='append')
        primary = astrofits.PrimaryHDU()
        processed = []
        for key in hdulist[0].header:
            if key in ('COMMENT', 'HISTORY'):
                if key not in processed:
                    parts = str(hdulist[0].header[key]).split('\n')
                    for part in parts:
                        primary.header[key] = part
                    processed.append(key)
            else:
                primary.header[key] = (hdulist[0].header[key],
                                       hdulist[0].header.comments[key])
        primary.header['HISTORY'] = 'Dedispersed with %s, revision %s.%s%s' % (
            os.path.basename(__file__), branch, shortsha, dirty)
        primary.header['HISTORY'] = 'Dedispersed at %.6f pc / cm^3' % args.DM
        hdulist2.append(primary)
        hdulist2.flush()
        ## Copy the extensions over to the new file
        for hdu in hdulist[1:]:
            if hdu.header['EXTNAME'] == 'UV_DATA':
                ### Updated the UV_DATA table with the dedispersed data
                flux = numpy.where(numpy.isfinite(flux), flux, 0.0)
                flux = flux.view(numpy.float32)  # pylint: disable=no-member
                flux = flux.astype(hdu.data['FLUX'].dtype)
                flux.shape = hdu.data['FLUX'].shape
                hdu.data['FLUX'][...] = flux

            hdulist2.append(hdu)
            hdulist2.flush()
        hdulist2.close()
        hdulist.close()
        print("  -> Dedispersed FITS IDI file is '%s'" % outname)
        print("  Finished in %.3f s" % (time.time() - t0, ))
Example #11
def main(args):
    # Parse the command line
    ## Baseline list
    if args.baseline is not None:
        ## Fill the baseline list with the conjugates, if needed
        newBaselines = []
        for pair in args.baseline:
            newBaselines.append((pair[1], pair[0]))
        args.baseline.extend(newBaselines)
    ## Polarization
    args.polToPlot = 'XX'
    if args.xy:
        args.polToPlot = 'XY'
    elif args.yx:
        args.polToPlot = 'YX'
    elif args.yy:
        args.polToPlot = 'YY'
    filename = args.filename

    print("Working on '%s'" % os.path.basename(filename))
    # Open the FITS IDI file and access the UV_DATA extension
    hdulist = astrofits.open(filename, mode='readonly')
    andata = hdulist['ANTENNA']
    fqdata = hdulist['FREQUENCY']
    fgdata = None
    for hdu in hdulist[1:]:
        if hdu.header['EXTNAME'] == 'FLAG':
            fgdata = hdu
    uvdata = hdulist['UV_DATA']

    # Pull out various bits of information we need to flag the file
    ## Antenna look-up table
    antLookup = {}
    for an, ai in zip(andata.data['ANNAME'], andata.data['ANTENNA_NO']):
        antLookup[an] = ai
    ## Frequency and polarization setup
    nBand, nFreq, nStk = uvdata.header['NO_BAND'], uvdata.header[
        'NO_CHAN'], uvdata.header['NO_STKD']
    stk0 = uvdata.header['STK_1']
    ## Baseline list
    bls = uvdata.data['BASELINE']
    ## Time of each integration
    obsdates = uvdata.data['DATE']
    obstimes = uvdata.data['TIME']
    inttimes = uvdata.data['INTTIM']
    ## Source list
    srcs = uvdata.data['SOURCE']
    ## Band information
    fqoffsets = fqdata.data['BANDFREQ'].ravel()
    ## Frequency channels
    freq = (numpy.arange(nFreq) -
            (uvdata.header['CRPIX3'] - 1)) * uvdata.header['CDELT3']
    freq += uvdata.header['CRVAL3']
    ## UVW coordinates
    try:
        u, v, w = uvdata.data['UU'], uvdata.data['VV'], uvdata.data['WW']
    except KeyError:
        u, v, w = uvdata.data['UU---SIN'], uvdata.data[
            'VV---SIN'], uvdata.data['WW---SIN']
    uvw = numpy.array([u, v, w]).T
    ## The actual visibility data
    flux = uvdata.data['FLUX'].astype(numpy.float32)

    # Convert the visibilities to something that we can easily work with
    nComp = flux.shape[1] // nBand // nFreq // nStk
    if nComp == 2:
        ## Case 1) - Just real and imaginary data
        flux = flux.view(numpy.complex64)
    else:
        ## Case 2) - Real, imaginary data + weights (drop the weights)
        flux = flux[:, 0::nComp] + 1j * flux[:, 1::nComp]
    flux.shape = (flux.shape[0], nBand, nFreq, nStk)

    # Find unique baselines, times, and sources to work with
    ubls = numpy.unique(bls)
    utimes = numpy.unique(obstimes)
    usrc = numpy.unique(srcs)

    # Convert times to real times
    times = utcjd_to_unix(obsdates + obstimes)
    times = numpy.unique(times)

    # Build a mask
    mask = numpy.zeros(flux.shape, dtype=numpy.bool)
    if fgdata is not None and not args.drop:
        reltimes = obsdates - obsdates[0] + obstimes
        maxtimes = reltimes + inttimes / 2.0 / 86400.0
        mintimes = reltimes - inttimes / 2.0 / 86400.0

        bls_ant1 = bls // 256
        bls_ant2 = bls % 256

        for row in fgdata.data:
            ant1, ant2 = row['ANTS']

            ## Only deal with flags that we need for the plots
            process_flag = False
            if args.include_auto or ant1 != ant2 or ant1 == 0 or ant2 == 0:
                if ant1 == 0 and ant2 == 0:
                    process_flag = True
                elif args.baseline is not None:
                    if ant2 == 0 and ant1 in [a0 for a0, a1 in args.baseline]:
                        process_flag = True
                    elif (ant1, ant2) in args.baseline:
                        process_flag = True
                elif args.ref_ant is not None:
                    if ant1 == args.ref_ant or ant2 == args.ref_ant:
                        process_flag = True
                else:
                    process_flag = True
            if not process_flag:
                continue

            tStart, tStop = row['TIMERANG']
            band = row['BANDS']
            try:
                len(band)
            except TypeError:
                band = [
                    band,
                ]
            cStart, cStop = row['CHANS']
            if cStop == 0:
                cStop = -1
            pol = row['PFLAGS'].astype(numpy.bool)

            if ant1 == 0 and ant2 == 0:
                btmask = numpy.where(
                    ((maxtimes >= tStart) & (mintimes <= tStop)))[0]
            elif ant1 == 0 or ant2 == 0:
                ant1 = max([ant1, ant2])
                btmask = numpy.where( ( (bls_ant1 == ant1) | (bls_ant2 == ant1) ) \
                                      & ( (maxtimes >= tStart) & (mintimes <= tStop) ) )[0]
            else:
                btmask = numpy.where( ( (bls_ant1 == ant1) & (bls_ant2 == ant2) ) \
                                      & ( (maxtimes >= tStart) & (mintimes <= tStop) ) )[0]
            for b, v in enumerate(band):
                if not v:
                    continue
                mask[btmask, b, cStart - 1:cStop, :] |= pol

    plot_bls = []
    cross = []
    for i in xrange(len(ubls)):
        bl = ubls[i]
        ant1, ant2 = (bl >> 8) & 0xFF, bl & 0xFF
        if args.include_auto or ant1 != ant2:
            if args.baseline is not None:
                if (ant1, ant2) in args.baseline:
                    plot_bls.append(bl)
                    cross.append(i)
            elif args.ref_ant is not None:
                if ant1 == args.ref_ant or ant2 == args.ref_ant:
                    plot_bls.append(bl)
                    cross.append(i)
            else:
                plot_bls.append(bl)
                cross.append(i)
    nBL = len(cross)

    # Decimation, if needed
    if args.decimate > 1:
        if nFreq % args.decimate != 0:
            raise RuntimeError(
                "Invalid frequency decimation factor:  %i %% %i = %i" %
                (nFreq, args.decimate, nFreq % args.decimate))

        nFreq //= args.decimate
        freq.shape = (freq.size // args.decimate, args.decimate)
        freq = freq.mean(axis=1)

        flux.shape = (flux.shape[0], flux.shape[1],
                      flux.shape[2] // args.decimate, args.decimate,
                      flux.shape[3])
        flux = flux.mean(axis=3)

        mask.shape = (mask.shape[0], mask.shape[1],
                      mask.shape[2] // args.decimate, args.decimate,
                      mask.shape[3])
        mask = mask.mean(axis=3)

    good = numpy.arange(freq.size // 8,
                        freq.size * 7 // 8)  # Inner 75% of the band

    # NOTE: Assumes that the Stokes parameters increment by -1
    namMapper = {}
    for i in xrange(nStk):
        stk = stk0 - i
        namMapper[i] = NUMERIC_STOKES[stk]
    polMapper = {'XX': 0, 'YY': 1, 'XY': 2, 'YX': 3}

    fig1 = plt.figure()
    fig2 = plt.figure()
    fig3 = plt.figure()
    fig4 = plt.figure()
    fig5 = plt.figure()

    k = 0
    nRow = int(numpy.sqrt(len(plot_bls)))
    nCol = int(numpy.ceil(len(plot_bls) * 1.0 / nRow))
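    # e.g. 7 baselines -> nRow = 2, nCol = 4 (any extra panels are simply left empty)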
    for b in xrange(len(plot_bls)):
        bl = plot_bls[b]
        valid = numpy.where(bls == bl)[0]
        i, j = (bl >> 8) & 0xFF, bl & 0xFF
        dTimes = obsdates[valid] + obstimes[valid]
        dTimes -= dTimes[0]
        dTimes *= 86400.0

        ax1, ax2, ax3, ax4, ax5 = None, None, None, None, None
        for band, offset in enumerate(fqoffsets):
            frq = freq + offset
            vis = numpy.ma.array(flux[valid, band, :,
                                      polMapper[args.polToPlot]],
                                 mask=mask[valid, band, :,
                                           polMapper[args.polToPlot]])

            ax1 = fig1.add_subplot(nRow,
                                   nCol * nBand,
                                   nBand * k + 1 + band,
                                   sharey=ax1)
            ax1.imshow(numpy.ma.angle(vis),
                       extent=(frq[0] / 1e6, frq[-1] / 1e6, dTimes[0],
                               dTimes[-1]),
                       origin='lower',
                       vmin=-numpy.pi,
                       vmax=numpy.pi,
                       interpolation='nearest')
            ax1.axis('auto')
            ax1.set_xlabel('Frequency [MHz]')
            if band == 0:
                ax1.set_ylabel('Elapsed Time [s]')
            ax1.set_title("%i,%i - %s" %
                          (i, j, namMapper[polMapper[args.polToPlot]]))
            ax1.set_xlim((frq[0] / 1e6, frq[-1] / 1e6))
            ax1.set_ylim((dTimes[0], dTimes[-1]))

            ax2 = fig2.add_subplot(nRow,
                                   nCol * nBand,
                                   nBand * k + 1 + band,
                                   sharey=ax2)
            amp = numpy.ma.abs(vis)
            vmin, vmax = percentile(amp, 1), percentile(amp, 99)
            ax2.imshow(amp,
                       extent=(frq[0] / 1e6, frq[-1] / 1e6, dTimes[0],
                               dTimes[-1]),
                       origin='lower',
                       interpolation='nearest',
                       vmin=vmin,
                       vmax=vmax)
            ax2.axis('auto')
            ax2.set_xlabel('Frequency [MHz]')
            if band == 0:
                ax2.set_ylabel('Elapsed Time [s]')
            ax2.set_title("%i,%i - %s" %
                          (i, j, namMapper[polMapper[args.polToPlot]]))
            ax2.set_xlim((frq[0] / 1e6, frq[-1] / 1e6))
            ax2.set_ylim((dTimes[0], dTimes[-1]))

            ax3 = fig3.add_subplot(nRow,
                                   nCol * nBand,
                                   nBand * k + 1 + band,
                                   sharey=ax3)
            ax3.plot(frq / 1e6, numpy.ma.abs(vis.mean(axis=0)))
            ax3.set_xlabel('Frequency [MHz]')
            if band == 0:
                ax3.set_ylabel('Mean Vis. Amp. [lin.]')
            ax3.set_title("%i,%i - %s" %
                          (i, j, namMapper[polMapper[args.polToPlot]]))
            ax3.set_xlim((frq[0] / 1e6, frq[-1] / 1e6))

            ax4 = fig4.add_subplot(nRow,
                                   nCol * nBand,
                                   nBand * k + 1 + band,
                                   sharey=ax4)
            ax4.plot(numpy.ma.angle(vis[:, good].mean(axis=1)) * 180 /
                     numpy.pi,
                     dTimes,
                     linestyle='',
                     marker='+')
            ax4.set_xlim((-180, 180))
            ax4.set_xlabel('Mean Vis. Phase [deg]')
            if band == 0:
                ax4.set_ylabel('Elapsed Time [s]')
            ax4.set_title("%i,%i - %s" %
                          (i, j, namMapper[polMapper[args.polToPlot]]))
            ax4.set_ylim((dTimes[0], dTimes[-1]))

            ax5 = fig5.add_subplot(nRow,
                                   nCol * nBand,
                                   nBand * k + 1 + band,
                                   sharey=ax5)
            ax5.plot(numpy.ma.abs(vis[:, good].mean(axis=1)),
                     dTimes,
                     linestyle='',
                     marker='+')
            ax5.set_xlabel('Mean Vis. Amp. [lin.]')
            if band == 0:
                ax5.set_ylabel('Elapsed Time [s]')
            ax5.set_title("%i,%i - %s" %
                          (i, j, namMapper[polMapper[args.polToPlot]]))
            ax5.set_ylim((dTimes[0], dTimes[-1]))

            if band > 0:
                for ax in (ax1, ax2, ax3, ax4, ax5):
                    plt.setp(ax.get_yticklabels(), visible=False)
            if band < nBand - 1:
                for ax in (ax1, ax2, ax3, ax4, ax5):
                    xticks = ax.xaxis.get_major_ticks()
                    xticks[-1].label1.set_visible(False)

        k += 1

    for f in (fig1, fig2, fig3, fig4, fig5):
        f.suptitle(
            "%s to %s UTC" %
            (datetime.utcfromtimestamp(times[0]).strftime("%Y/%m/%d %H:%M"),
             datetime.utcfromtimestamp(times[-1]).strftime("%Y/%m/%d %H:%M")))
        if nBand > 1:
            f.subplots_adjust(wspace=0.0)

    plt.show()
Example #12
def main(args):
    # Parse the command line
    ## Baseline list
    if args.baseline is not None:
        ## Fill the baseline list with the conjugates, if needed
        newBaselines = []
        for pair in args.baseline:
            newBaselines.append( (pair[1],pair[0]) )
        args.baseline.extend(newBaselines)
    ## Search limits
    args.delay_window = [float(v) for v in args.delay_window.split(',', 1)]
    args.rate_window = [float(v) for v in args.rate_window.split(',', 1)]
    
    print("Working on '%s'" % os.path.basename(args.filename))
    # Open the FITS IDI file and access the UV_DATA extension
    hdulist = astrofits.open(args.filename, mode='readonly')
    andata = hdulist['ANTENNA']
    fqdata = hdulist['FREQUENCY']
    uvdata = hdulist['UV_DATA']
    
    # Verify that we can work with this data
    if uvdata.header['STK_1'] > 0:
        raise RuntimeError("Cannot process data with STK_1 = %i" % uvdata.header['STK_1'])
    if uvdata.header['NO_STKD'] < 4:
        raise RuntimeError("Cannot process data with NO_STKD = %i" % uvdata.header['NO_STKD'])
        
    # Pull out various bits of information we need to process the file
    ## Antenna look-up table
    antLookup = {}
    for an, ai in zip(andata.data['ANNAME'], andata.data['ANTENNA_NO']):
        antLookup[an] = ai
    ## Frequency and polarization setup
    nBand, nFreq, nStk = uvdata.header['NO_BAND'], uvdata.header['NO_CHAN'], uvdata.header['NO_STKD']
    ## Baseline list
    bls = uvdata.data['BASELINE']
    ## Time of each integration
    obsdates = uvdata.data['DATE']
    obstimes = uvdata.data['TIME']
    inttimes = uvdata.data['INTTIM']
    ## Source list
    srcs = uvdata.data['SOURCE']
    ## Band information
    fqoffsets = fqdata.data['BANDFREQ'].ravel()
    ## Frequency channels
    freq = (numpy.arange(nFreq)-(uvdata.header['CRPIX3']-1))*uvdata.header['CDELT3']
    freq += uvdata.header['CRVAL3']
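    ## (i.e., channel i sits at CRVAL3 + (i - (CRPIX3-1))*CDELT3 Hz)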
    ## UVW coordinates
    try:
        u, v, w = uvdata.data['UU'], uvdata.data['VV'], uvdata.data['WW']
    except KeyError:
        u, v, w = uvdata.data['UU---SIN'], uvdata.data['VV---SIN'], uvdata.data['WW---SIN']
    uvw = numpy.array([u, v, w]).T
    ## The actual visibility data
    flux = uvdata.data['FLUX'].astype(numpy.float32)
    
    # Convert the visibilities to something that we can easily work with
    nComp = flux.shape[1] // nBand // nFreq // nStk
    if nComp == 2:
        ## Case 1) - Just real and imaginary data
        flux = flux.view(numpy.complex64)
    else:
        ## Case 2) - Real, imaginary data + weights (drop the weights)
        flux = flux[:,0::nComp] + 1j*flux[:,1::nComp]
    flux.shape = (flux.shape[0], nBand, nFreq, nStk)
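    ## flux now holds complex visibilities with shape (rows, bands, channels, polarizations)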
    
    # Find unique baselines, times, and sources to work with
    ubls = numpy.unique(bls)
    utimes = numpy.unique(obstimes)
    usrc = numpy.unique(srcs)
    
    # Convert times to real times
    times = utcjd_to_unix(obsdates + obstimes)
    times = numpy.unique(times)
    
    # Find unique scans to work on, making sure that there are no large gaps
    blocks = []
    for src in usrc:
        valid = numpy.where( src == srcs )[0]
        
        blocks.append( [valid[0],valid[0]] )
        for v in valid[1:]:
            if v == blocks[-1][1] + 1 \
                and (obsdates[v] - obsdates[blocks[-1][1]] + obstimes[v] - obstimes[blocks[-1][1]])*86400 < 10*inttimes[v]:
                blocks[-1][1] = v
            else:
                blocks.append( [v,v] )
    blocks.sort()
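    ## Each entry of blocks is a [first, last] row-index pair for a contiguous run on a
    ## single source, split wherever the gap between rows exceeds ten integration times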
    
    # Make sure the reference antenna is in there
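    # (FITS-IDI packs each baseline as 256*ant1 + ant2, so the antenna numbers are
    # recovered below with (bl>>8)&0xFF and bl&0xFF.)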
    if args.ref_ant is None:
        bl = ubls[0]
        ant1, ant2 = (bl>>8)&0xFF, bl&0xFF 
        args.ref_ant = ant1
    else:
        found = False
        for bl in ubls:
            ant1, ant2 = (bl>>8)&0xFF, bl&0xFF
            if ant1 == args.ref_ant:
                found = True
                break
        if not found:
            raise RuntimeError("Cannot file reference antenna %i in the data" % args.ref_ant)
            
    search_bls = []
    cross = []
    for i in xrange(len(ubls)):
        bl = ubls[i]
        ant1, ant2 = (bl>>8)&0xFF, bl&0xFF 
        if ant1 != ant2:
            search_bls.append( bl )
            cross.append( i )
    nBL = len(cross)
    
    iTimes = numpy.zeros(times.size-1, dtype=times.dtype)
    for i in xrange(1, len(times)):
        iTimes[i-1] = times[i] - times[i-1]
    print(" -> Interval: %.3f +/- %.3f seconds (%.3f to %.3f seconds)" % (iTimes.mean(), iTimes.std(), iTimes.min(), iTimes.max()))
    
    print("Number of frequency channels: %i (~%.1f Hz/channel)" % (len(freq), freq[1]-freq[0]))

    dTimes = times - times[0]
    
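    # Clamp the requested search windows to what the data can constrain: delays to
    # +/- 1/(4 * channel width) and rates to +/- 1/(4 * mean integration interval)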
    dMax = 1.0/(freq[1]-freq[0])/4
    dMax = int(dMax*1e6)*1e-6
    if -dMax*1e6 > args.delay_window[0]:
        args.delay_window[0] = -dMax*1e6
    if dMax*1e6 < args.delay_window[1]:
        args.delay_window[1] = dMax*1e6
    rMax = 1.0/iTimes.mean()/4
    rMax = int(rMax*1e2)*1e-2
    if -rMax*1e3 > args.rate_window[0]:
        args.rate_window[0] = -rMax*1e3
    if rMax*1e3 < args.rate_window[1]:
        args.rate_window[1] = rMax*1e3
        
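    # Pick grid resolutions that give roughly 50 to 5000 steps across each window,
    # keeping the step counts odd so the window midpoint lands on a grid point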
    dres = 1.0
    nDelays = int((args.delay_window[1]-args.delay_window[0])/dres)
    while nDelays < 50:
        dres /= 10
        nDelays = int((args.delay_window[1]-args.delay_window[0])/dres)
    while nDelays > 5000:
        dres *= 10
        nDelays = int((args.delay_window[1]-args.delay_window[0])/dres)
    nDelays += (nDelays + 1) % 2
    
    rres = 10.0
    nRates = int((args.rate_window[1]-args.rate_window[0])/rres)
    while nRates < 50:
        rres /= 10
        nRates = int((args.rate_window[1]-args.rate_window[0])/rres)
    while nRates > 5000:
        rres *= 10
        nRates = int((args.rate_window[1]-args.rate_window[0])/rres)
    nRates += (nRates + 1) % 2
    
    print("Searching delays %.1f to %.1f us in steps of %.2f us" % (args.delay_window[0], args.delay_window[1], dres))
    print("           rates %.1f to %.1f mHz in steps of %.2f mHz" % (args.rate_window[0], args.rate_window[1], rres))
    print(" ")
    
    delay = numpy.linspace(args.delay_window[0]*1e-6, args.delay_window[1]*1e-6, nDelays)		# s
    drate = numpy.linspace(args.rate_window[0]*1e-3,  args.rate_window[1]*1e-3,  nRates )		# Hz
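    # The fringe search below evaluates |sum over t,nu of V(t,nu) exp(-2j pi (nu*tau + r*t))|
    # on this (delay, rate) grid using two matrix multiplies per baseline and polarization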
    
    # Find RFI and trim it out.  This is done by computing average visibility 
    # amplitudes (a "spectrum") and running a median filter in frequency to extract
    # the bandpass.  After the spectrum has been corrected for the bandpass, 3-sigma
    # outliers are trimmed.  Additionally, regions where the bandpass falls below 10%
    # of its mean value are also masked.
    spec  = numpy.median(numpy.abs(flux[:,0,:,0]), axis=0)
    spec += numpy.median(numpy.abs(flux[:,0,:,1]), axis=0)
    smth = spec*0.0
    winSize = int(250e3/(freq[1]-freq[0]))
    winSize += ((winSize+1)%2)
    for i in xrange(smth.size):
        mn = max([0, i-winSize//2])
        mx = min([i+winSize//2+1, smth.size])
        smth[i] = numpy.median(spec[mn:mx])
    smth /= robust.mean(smth)
    bp = spec / smth
    good = numpy.where( (smth > 0.1) & (numpy.abs(bp-robust.mean(bp)) < 3*robust.std(bp)) )[0]
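    ## 'good' lists the channels that pass both the 10% bandpass cut and the 3-sigma
    ## outlier cut; only these channels are used in the fringe search below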
    nBad = nFreq - len(good)
    print("Masking %i of %i channels (%.1f%%)" % (nBad, nFreq, 100.0*nBad/nFreq))
    if args.plot:
        fig = plt.figure()
        ax = fig.gca()
        ax.plot(freq/1e6, numpy.log10(spec)*10)
        ax.plot(freq[good]/1e6, numpy.log10(spec[good])*10)
        ax.set_title('Mean Visibility Amplitude')
        ax.set_xlabel('Frequency [MHz]')
        ax.set_ylabel('PSD [arb. dB]')
        plt.draw()
    
    freq2 = freq*1.0
    freq2.shape += (1,)
    dTimes2 = dTimes*1.0
    dTimes2.shape += (1,)
    
    # NOTE: Assumed linear data
    polMapper = {'XX':0, 'YY':1, 'XY':2, 'YX':3}
    
    print("%3s  %9s  %2s  %6s  %9s  %11s" % ('#', 'BL', 'Pl', 'S/N', 'Delay', 'Rate'))
    for b in xrange(len(search_bls)):
        bl = search_bls[b]
        ant1, ant2 = (bl>>8)&0xFF, bl&0xFF
        
        ## Skip over baselines that are not in the baseline list (if provided)
        if args.baseline is not None:
            if (ant1, ant2) not in args.baseline:
                continue
        ## Skip over baselines that don't include the reference antenna
        elif ant1 != args.ref_ant and ant2 != args.ref_ant:
            continue
            
        ## Check and see if we need to conjugate the visibility, i.e., switch from
        ## baseline (*,ref) to baseline (ref,*)
        doConj = False
        if ant2 == args.ref_ant:
            doConj = True
            
        ## Figure out which polarizations to process
        if ant1 not in (51, 52) and ant2 not in (51, 52):
            ### Standard VLA-VLA baseline
            polToUse = ('XX', 'YY')
        else:
            ### LWA-LWA or LWA-VLA baseline
            if args.y_only:
                polToUse = ('YX', 'YY')
            else:
                polToUse = ('XX', 'XY', 'YX', 'YY')
                
        if args.plot:
            fig = plt.figure()
            axs = {}
            axs['XX'] = fig.add_subplot(2, 2, 1)
            axs['YY'] = fig.add_subplot(2, 2, 2)
            axs['XY'] = fig.add_subplot(2, 2, 3)
            axs['YX'] = fig.add_subplot(2, 2, 4)
            
        valid = numpy.where( bls == bl )[0]
        for pol in polToUse:
            subData = flux[valid,0,:,polMapper[pol]]*1.0
            subData = subData[:,good]
            if doConj:
                subData = subData.conj()
            subData = numpy.dot(subData, numpy.exp(-2j*numpy.pi*freq2[good,:]*delay))
            subData /= freq2[good,:].size
            amp = numpy.dot(subData.T, numpy.exp(-2j*numpy.pi*dTimes2*drate))
            amp = numpy.abs(amp / dTimes2.size)
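            ## amp is a (delay, rate) grid of fringe amplitudes; its peak pixel gives
            ## the best-fit delay and rate reported for this baseline/polarization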
            
            blName = (ant1, ant2)
            if doConj:
                blName = (ant2, ant1)
            blName = '%s-%s' % ('EA%02i' % blName[0] if blName[0] < 51 else 'LWA%i' % (blName[0]-50), 
                        'EA%02i' % blName[1] if blName[1] < 51 else 'LWA%i' % (blName[1]-50))
                        
            best = numpy.where( amp == amp.max() )
            if amp.max() > 0:
                bsnr = (amp[best]-amp.mean())[0]/amp.std()
                bdly = delay[best[0][0]]*1e6
                brat = drate[best[1][0]]*1e3
                print("%3i  %9s  %2s  %6.2f  %6.2f us  %7.2f mHz" % (b, blName, pol, bsnr, bdly, brat))
            else:
                print("%3i  %9s  %2s  %6s  %9s  %11s" % (b, blName, pol, '----', '----', '----'))
                
            if args.plot:
                axs[pol].imshow(amp, origin='lower', interpolation='nearest', 
                            extent=(drate[0]*1e3, drate[-1]*1e3, delay[0]*1e6, delay[-1]*1e6), 
                            cmap='gray_r')
                axs[pol].plot(drate[best[1][0]]*1e3, delay[best[0][0]]*1e6, linestyle='', marker='x', color='r', ms=15, alpha=0.75)
                
        if args.plot:
            fig.suptitle(os.path.basename(args.filename))
            for pol in axs.keys():
                ax = axs[pol]
                ax.axis('auto')
                ax.set_title(pol)
                ax.set_xlabel('Rate [mHz]')
                ax.set_ylabel('Delay [$\\mu$s]')
            fig.suptitle("%s" % blName)
            plt.draw()
            
    if args.plot:
        plt.show()
Example #13
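    # Note (assumption): fS is presumably the DP sample-clock rate from LSL, so this
    # converts the UTC Julian date into an integer count of DP clock ticks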
    def utc_dp(self):
        return int(astro.utcjd_to_unix(self.utc_jd) * fS)