Ejemplo n.º 1
0
def _comply_vispol(pol):
    """Map a visibility polarization string onto a string compliant with
    pyuvdata and hera_cal.

    The string is round-tripped through ``polstr2num`` and ``polnum2str``.
    When ``_is_cardinal(pol)`` is true both conversions are given
    ``x_orientation='north'``; otherwise no x_orientation is passed.
    """
    orientation_kwargs = {'x_orientation': 'north'} if _is_cardinal(pol) else {}
    pol_number = polstr2num(pol, **orientation_kwargs)
    return polnum2str(pol_number, **orientation_kwargs)
Ejemplo n.º 2
0
    def get_Omegas(self, polpairs):
        """
        Compute the beam integrals OmegaP and OmegaPP for the requested
        polarization pairs.

        Parameters
        ----------
        polpairs : list
            Polarization pairs, each given as a length-2 tuple or as a
            polarization-pair integer. A single tuple or integer is also
            accepted and is treated as a one-element list.

        Returns
        -------
        OmegaP : array_like
            Array containing power_beam_int, shape: (Nbeam_freqs, Npols).

        OmegaPP : array_like
            Array containing power_sq_beam_int, shape: (Nbeam_freqs, Npols).

        Raises
        ------
        TypeError
            If polpairs is not a list, array, tuple or integer.
        NotImplementedError
            If the two polarizations of any pair differ.
        """
        int_types = (int, np.integer)

        # Promote a lone tuple / integer to a one-element list
        if not isinstance(polpairs, (list, np.ndarray)):
            if not isinstance(polpairs, (tuple,) + int_types):
                raise TypeError("polpairs is not a list of integers or tuples")
            polpairs = [polpairs]

        # Expand pol-pair integers into (pol1, pol2) tuples before any
        # beam integral is evaluated
        pair_tuples = []
        for entry in polpairs:
            if isinstance(entry, int_types):
                entry = uvputils.polpair_int2tuple(entry)
            pair_tuples.append(entry)

        beam_ints = []
        beam_sq_ints = []
        for first, second in pair_tuples:
            # Polarization numbers are compared through their string names
            if isinstance(first, int_types):
                first = uvutils.polnum2str(first)
            if isinstance(second, int_types):
                second = uvutils.polnum2str(second)

            # Only same-pol pairs are supported at the moment
            if first != second:
                raise NotImplementedError(
                    "get_Omegas does not support cross-correlation between "
                    "two different visibility polarizations yet. "
                    "Could not calculate Omegas for (%s, %s)" % (first, second))

            beam_ints.append(self.power_beam_int(pol=first))
            beam_sq_ints.append(self.power_beam_sq_int(pol=first))

        # Transpose so that the polarization axis comes last
        return np.array(beam_ints).T, np.array(beam_sq_ints).T
Ejemplo n.º 3
0
def test_pol_funcs_x_orientation():
    """Check polarization string <-> number conversion with x_orientation set."""
    pol_nums = [-8, -7, -6, -5, -4, -3, -2, -1, 1, 2, 3, 4]
    # Names for -4..4 are the same for both orientations exercised here
    orientation_independent = ['lr', 'rl', 'll', 'rr', 'pI', 'pQ', 'pU', 'pV']

    # Each row: x_orientation, names for -8..-5, an upper-case name expected
    # to map to -6, and (input, expected) pairs for parse_polstr
    cases = [
        ('e', ['ne', 'en', 'nn', 'ee'], 'NN',
         [("eE", 'ee'), ("xx", 'ee'), ("NN", 'nn'), ("yy", 'nn'),
          ('i', 'pI')]),
        ('n', ['en', 'ne', 'ee', 'nn'], 'EE',
         [("nN", 'nn'), ("xx", 'nn'), ("EE", 'ee'), ("yy", 'ee'),
          ('i', 'pI')]),
    ]
    bad_inputs = [
        (KeyError, uvutils.polstr2num, 'foo'),
        (ValueError, uvutils.polstr2num, 1),
        (ValueError, uvutils.polnum2str, 7.3),
    ]

    for x_orient, linear_names, minus_six_name, parse_cases in cases:
        pol_str = linear_names + orientation_independent

        # Whole-list conversions
        assert pol_nums == uvutils.polstr2num(pol_str, x_orientation=x_orient)
        assert pol_str == uvutils.polnum2str(pol_nums, x_orientation=x_orient)

        # Scalar conversions
        assert -6 == uvutils.polstr2num(minus_six_name, x_orientation=x_orient)
        assert 'pV' == uvutils.polnum2str(4)

        # Invalid inputs
        for exc_type, func, bad_value in bad_inputs:
            pytest.raises(exc_type, func, bad_value, x_orientation=x_orient)

        # String parsing
        for raw, parsed in parse_cases:
            assert uvutils.parse_polstr(raw, x_orientation=x_orient) == parsed

    # An unrecognized x_orientation should warn and still return a result
    unknown = {'x_orientation': 'foo'}
    warning_text = 'x_orientation not recognized'
    assert uvtest.checkWarnings(uvutils.polstr2num, ['xx'], unknown,
                                message=warning_text) == -5
    assert uvtest.checkWarnings(uvutils.polnum2str, [-6], unknown,
                                message=warning_text) == 'yy'
Ejemplo n.º 4
0
def test_pol_funcs():
    """Check that uvutils converts polarization numbers and strings both ways."""
    number_to_name = [
        (-8, 'YX'), (-7, 'XY'), (-6, 'YY'), (-5, 'XX'),
        (-4, 'LR'), (-3, 'RL'), (-2, 'LL'), (-1, 'RR'),
        (1, 'pI'), (2, 'pQ'), (3, 'pU'), (4, 'pV'),
    ]
    pol_nums = [number for number, _ in number_to_name]
    pol_str = [name for _, name in number_to_name]

    # Whole-list conversions
    nt.assert_equal(pol_nums, uvutils.polstr2num(pol_str))
    nt.assert_equal(pol_str, uvutils.polnum2str(pol_nums))

    # Scalar conversions
    nt.assert_equal(-6, uvutils.polstr2num('YY'))
    nt.assert_equal('pV', uvutils.polnum2str(4))

    # Invalid inputs
    bad_inputs = (
        (KeyError, uvutils.polstr2num, 'foo'),
        (ValueError, uvutils.polstr2num, 1),
        (ValueError, uvutils.polnum2str, 7.3),
    )
    for exc_type, func, bad_value in bad_inputs:
        nt.assert_raises(exc_type, func, bad_value)
Ejemplo n.º 5
0
def test_pol_funcs():
    """Check polarization string/number conversion and string parsing."""
    number_to_name = [
        (-8, 'yx'), (-7, 'xy'), (-6, 'yy'), (-5, 'xx'),
        (-4, 'lr'), (-3, 'rl'), (-2, 'll'), (-1, 'rr'),
        (1, 'pI'), (2, 'pQ'), (3, 'pU'), (4, 'pV'),
    ]
    pol_nums = [number for number, _ in number_to_name]
    pol_str = [name for _, name in number_to_name]

    # Whole-list conversions
    assert pol_nums == uvutils.polstr2num(pol_str)
    assert pol_str == uvutils.polnum2str(pol_nums)

    # Scalar conversions
    assert -6 == uvutils.polstr2num('YY')
    assert 'pV' == uvutils.polnum2str(4)

    # Invalid inputs
    bad_inputs = (
        (KeyError, uvutils.polstr2num, 'foo'),
        (ValueError, uvutils.polstr2num, 1),
        (ValueError, uvutils.polnum2str, 7.3),
    )
    for exc_type, func, bad_value in bad_inputs:
        pytest.raises(exc_type, func, bad_value)

    # String parsing
    for raw, parsed in (("xX", 'xx'), ("XX", 'xx'), ('i', 'pI')):
        assert uvutils.parse_polstr(raw) == parsed
Ejemplo n.º 6
0
def polpair_int2tuple(polpair, pol_strings=False):
    """
    Convert a pol-pair integer into a tuple pair of polarization
    integers. See polpair_tuple2int for more details.

    The last two decimal digits of the integer hold ``pol2 + 20`` and the
    remaining leading digits hold ``pol1 + 20``, e.g. 1515 -> (-5, -5).

    Parameters
    ----------
    polpair : int or list of int
        Integer representation of polarization pair. Lists and numpy arrays
        are converted element by element.

    pol_strings : bool, optional
        If True, return polarization pair tuples with polarization strings.
        Otherwise, use polarization integers. Default: False.

    Returns
    -------
    polpair : tuple, length 2
        A length-2 tuple containing a pair of polarization
        integers, e.g. (-5, -5). A list of such tuples is returned for
        list/array input.

    Raises
    ------
    AssertionError
        If polpair (or an element of it) is not an integer.
    ValueError
        If either decoded polarization lies outside the range [-8, 4].
    """
    # Recursive evaluation
    if isinstance(polpair, (list, np.ndarray)):
        return [polpair_int2tuple(p, pol_strings=pol_strings) for p in polpair]

    # Check for integer type. Raised explicitly (rather than via an assert
    # statement) so the check is not stripped when running under ``-O``.
    if not isinstance(polpair, (int, np.integer)):
        raise AssertionError("polpair must be integer: %s" % type(polpair))

    # Split into pol1 and pol2 integers. Arithmetic decoding keeps working
    # for integers with fewer than three digits, which then fail the range
    # check below with a meaningful message.
    leading, trailing = divmod(int(polpair), 100)
    pol1 = leading - 20
    pol2 = trailing - 20

    # Check that pol1 and pol2 are in the allowed range (-8, 4)
    if (pol1 < -8 or pol1 > 4) or (pol2 < -8 or pol2 > 4):
        raise ValueError("polpair integer evaluates to an invalid "
                         "polarization pair: (%d, %d)" % (pol1, pol2))

    # Convert to strings if requested
    if pol_strings:
        return (polnum2str(pol1), polnum2str(pol2))
    return (pol1, pol2)
Ejemplo n.º 7
0
def test_pol_funcs():
    """Check polarization conversions plus vis/jones polarization string parsing."""
    number_to_name = [
        (-8, 'yx'), (-7, 'xy'), (-6, 'yy'), (-5, 'xx'),
        (-4, 'lr'), (-3, 'rl'), (-2, 'll'), (-1, 'rr'),
        (1, 'pI'), (2, 'pQ'), (3, 'pU'), (4, 'pV'),
    ]
    pol_nums = [number for number, _ in number_to_name]
    pol_str = [name for _, name in number_to_name]

    # Whole-list conversions
    nt.assert_equal(pol_nums, uvutils.polstr2num(pol_str))
    nt.assert_equal(pol_str, uvutils.polnum2str(pol_nums))

    # Scalar conversions
    nt.assert_equal(-6, uvutils.polstr2num('YY'))
    nt.assert_equal('pV', uvutils.polnum2str(4))

    # Invalid inputs
    bad_inputs = (
        (KeyError, uvutils.polstr2num, 'foo'),
        (ValueError, uvutils.polstr2num, 1),
        (ValueError, uvutils.polnum2str, 7.3),
    )
    for exc_type, func, bad_value in bad_inputs:
        nt.assert_raises(exc_type, func, bad_value)

    # String parsing: visibility polarizations first, then jones
    parse_cases = (
        (uvutils.parse_polstr, "xX", 'xx'),
        (uvutils.parse_polstr, "XX", 'xx'),
        (uvutils.parse_polstr, 'i', 'pI'),
        (uvutils.parse_jpolstr, 'x', 'Jxx'),
        (uvutils.parse_jpolstr, 'xy', 'Jxy'),
        (uvutils.parse_jpolstr, 'XY', 'Jxy'),
    )
    for parser_func, raw, parsed in parse_cases:
        nt.assert_equal(parser_func(raw), parsed)
Ejemplo n.º 8
0
# Script fragment: read a Miriad file, compute a per-antenna autocorrelation
# amplitude and start a scatter plot of the array coloured by that amplitude.
args = parser.parse_args()

# Load the visibilities and build antenna positions: the stored positions are
# offset by the telescope location and then passed through ENU_from_ECEF.
uv = pyuvdata.UVData()
uv.read_miriad(args.file)
antpos = uv.antenna_positions + uv.telescope_location
antpos = uvutils.ENU_from_ECEF(antpos.T, *uv.telescope_location_lat_lon_alt).T

# Median absolute value of each antenna's (ant, ant) data
amps = np.zeros(uv.Nants_telescope)
for ant in range(uv.Nants_telescope):
    d = uv.get_data((uv.antenna_numbers[ant], uv.antenna_numbers[ant]))
    amps[ant] = np.median(np.abs(d))

# The 'obsid' keyword is interpreted as a GPS time
at_time = time.Time(uv.extra_keywords['obsid'], format='gps')
h = sys_handling.Handling()
# Use the first character of the first polarization string to pick the
# 'e' / 'n' label used below.
# NOTE(review): the comparison is case-sensitive; if polnum2str returns
# lower-case strings the else branch is always taken -- confirm against the
# pyuvdata version in use.
pol = uvutils.polnum2str(uv.polarization_array[0])[0]
if pol == 'X':
    pol = 'e'
else:
    pol = 'n'

# Scatter plot of antenna positions coloured by amplitude, with the colour
# scale running from 0 to the largest amplitude
f = plt.figure(figsize=(10, 8))
plt.scatter(antpos[:, 0], antpos[:, 1], c=amps)
plt.clim([0, amps.max()])
plt.colorbar()
# Accumulators filled by the per-antenna loop that follows
receiverators = []
pams = []
texts = []
for ant in range(uv.Nants_telescope):
    pam = h.get_pam_info(uv.antenna_names[ant], at_time)
    text = (str(uv.antenna_numbers[ant]) + pol + '\n' + pam[pol][0] + '\n' +
Ejemplo n.º 9
0
def coherent_average_vis(uvd_in, wgt_by_nsample=True, bl_error_tol=1.,
                         inplace=False):
    """
    Coherently average together visibilities in redundant groups.

    Parameters
    ----------
    uvd_in : UVData
        Visibility data (should already be calibrated).

    wgt_by_nsample : bool, optional
        Whether to weight the average by the number of samples (nsamples)
        array. If False, uses an unweighted average. Flagged samples get
        zero weight in both cases. Default: True.

    bl_error_tol : float, optional
        Tolerance in baseline length (in meters) to use when grouping
        baselines into redundant groups. Default: 1.

    inplace : bool, optional
        Whether to do the averaging in-place, or on a new copy of the UVData
        object. Default: False.

    Returns
    -------
    uvd_avg : UVData
        UVData object containing averaged visibilities. The averages are
        assigned to the first baseline in each redundant group (the other
        baselines in the group are removed).
    """
    # Whether to work in-place or not
    if inplace:
        uvd = uvd_in
    else:
        uvd = copy.deepcopy(uvd_in)

    # Get antenna positions and polarizations
    antpos, ants = uvd.get_ENU_antpos()
    antposd = dict(zip(ants, antpos))
    pols = [uvutils.polnum2str(pol) for pol in uvd.polarization_array]

    # Get redundant groups
    reds = hc.redcal.get_pos_reds(antposd, bl_error_tol=bl_error_tol)

    # Eliminate baselines not in data
    antpairs = uvd.get_antpairs()
    reds = [[bl for bl in blg if bl in antpairs] for blg in reds]
    reds = [blg for blg in reds if len(blg) > 0]

    # Iterate over redundant groups and polarizations and perform average
    for pol in pols:
        for blg in reds:
            # Get data and weight arrays for this pol-blgroup.
            # The builtin float is used because the np.float alias was
            # removed from NumPy (1.24+).
            d = np.asarray([uvd.get_data(bl + (pol,)) for bl in blg])
            f = np.asarray([(~uvd.get_flags(bl + (pol,))).astype(float)
                            for bl in blg])
            n = np.asarray([uvd.get_nsamples(bl + (pol,)) for bl in blg])
            if wgt_by_nsample:
                w = f * n
            else:
                w = f

            # Take the weighted average. The weight sum is clipped away from
            # zero to avoid dividing by zero; pixels whose clipped weight sum
            # is still ~0 are marked as flagged in the output.
            wsum = np.sum(w, axis=0).clip(1e-10, np.inf)
            davg = np.sum(d * w, axis=0) / wsum
            navg = np.sum(n, axis=0)
            favg = np.isclose(wsum, 0.0)

            # Replace in UVData with first bl of blg
            bl_inds = uvd.antpair2ind(blg[0])
            polind = pols.index(pol)
            uvd.data_array[bl_inds, 0, :, polind] = davg
            uvd.flag_array[bl_inds, 0, :, polind] = favg
            uvd.nsample_array[bl_inds, 0, :, polind] = navg

    # Select out averaged bls
    bls = hp.utils.flatten([[blg[0] + (pol,) for pol in pols] for blg in reds])
    uvd.select(bls=bls)
    return uvd
Ejemplo n.º 10
0
def red_average(data, reds=None, bl_tol=1.0, inplace=False,
                wgts=None, flags=None, nsamples=None):
    """
    Redundantly average visibilities in a DataContainer, HERAData or UVData object.
    Average is weighted by integration_time * nsamples * ~flags unless wgts are fed.

    Args:
        data : DataContainer, HERAData or UVData object
            Object to redundantly average
        reds : list, optional
            Nested lists of antpair tuples to redundantly average.
            E.g. [ [(1, 2), (2, 3)], [(1, 3), (2, 4)], ...]
            If None, will calculate these from the metadata
        bl_tol : float
            Baseline redundancy tolerance in meters. Only used if reds is None.
        inplace : bool
            Perform average and downselect inplace, otherwise returns a deepcopy.
            The first baseline in each reds sublist is kept.
        wgts : DataContainer
            Manual weights to use in redundant average. This supersedes flags and nsamples
            if provided, and will also be used if input data is a UVData or a subclass of it.
        flags : DataContainer
            If data is a DataContainer, these are its flags. Default (None) is no flags.
        nsamples : DataContainer
            If data is a DataContainer, these are its nsamples. Default (None) is 1.0 for all pixels.
            Furthermore, if data is a DataContainer, integration_time is 1.0 for all pixels.

    Returns:
        Nothing is returned when inplace is True (the inputs are modified).
        Otherwise:
        if fed a DataContainer:
            DataContainer, averaged data
            DataContainer, averaged flags
            DataContainer, summed nsamples
        elif fed a HERAData or UVData:
            HERAData or UVData object, averaged data

    Raises:
        ValueError: if data is neither a DataContainer nor a UVData (subclass),
            or if reds must be computed from a DataContainer without antpos.

    Notes:
        1. Different polarizations are assumed to be non-redundant.
        2. Default weighting is nsamples * integration_time * ~flags.
        3. If wgts Container is fed then they supersede flag and nsample weighting.
    """
    from hera_cal import redcal, datacontainer

    # type checks
    if not (isinstance(data, datacontainer.DataContainer) or isinstance(data, UVData)):
        raise ValueError("data must be a DataContainer or a UVData or its subclass")
    fed_container = isinstance(data, datacontainer.DataContainer)

    # fill DataContainers if necessary
    # (builtin bool / float are used: the np.bool and np.float aliases were
    # removed from NumPy 1.24+)
    if fed_container:
        if not inplace:
            flags = copy.deepcopy(flags)
            nsamples = copy.deepcopy(nsamples)
        if flags is None:
            flags = datacontainer.DataContainer({k: np.zeros_like(data[k], bool) for k in data})
        if nsamples is None:
            nsamples = datacontainer.DataContainer({k: np.ones_like(data[k], float) for k in data})

    # get weights: if wgts are not fed, then use flags and nsamples
    if wgts is None:
        if fed_container:
            wgts = datacontainer.DataContainer({k: nsamples[k] * ~flags[k] for k in data})
        else:
            wgts = datacontainer.DataContainer({k: data.get_nsamples(k) * ~data.get_flags(k) for k in data.get_antpairpols()})

    # deepcopy
    if not inplace:
        data = copy.deepcopy(data)

    # get metadata
    if fed_container:
        pols = sorted(data.pols())
    else:
        pols = [polnum2str(pol, x_orientation=data.x_orientation) for pol in data.polarization_array]

    # get redundant groups
    if reds is None:
        # if DataContainer, check for antpos
        if fed_container:
            if not hasattr(data, 'antpos') or data.antpos is None:
                raise ValueError("DataContainer must have antpos dictionary to calculate reds")
            antposd = data.antpos
        else:
            antpos, ants = data.get_ENU_antpos()
            antposd = dict(zip(ants, antpos))
        reds = redcal.get_pos_reds(antposd, bl_error_tol=bl_tol)

    # eliminate baselines not in data
    if fed_container:
        antpairs = sorted(data.antpairs())
    else:
        antpairs = data.get_antpairs()
    reds = [[bl for bl in blg if bl in antpairs] for blg in reds]
    reds = [blg for blg in reds if len(blg) > 0]

    # iterate over redundant groups and polarizations
    for pol in pols:
        for blg in reds:
            # get data and weighting for this pol-blgroup
            if fed_container:
                d = np.asarray([data[bl + (pol,)] for bl in blg])
                f = np.asarray([(~flags[bl + (pol,)]).astype(float) for bl in blg])
                n = np.asarray([nsamples[bl + (pol,)] for bl in blg])
                # DataContainer can't track integration time, so no tint here
                tint = np.array([1.0])
                w = np.asarray([wgts[bl + (pol,)] for bl in blg])

            else:
                d = np.asarray([data.get_data(bl + (pol,)) for bl in blg])
                f = np.asarray([(~data.get_flags(bl + (pol,))).astype(float) for bl in blg])
                n = np.asarray([data.get_nsamples(bl + (pol,)) for bl in blg])
                tint = np.asarray([data.integration_time[data.antpair2ind(bl + (pol,))] for bl in blg])[:, :, None]
                w = np.asarray([wgts[bl + (pol,)] for bl in blg]) * tint

            # take the weighted average
            wsum = np.sum(w, axis=0).clip(1e-10, np.inf)  # this is the normalization
            davg = np.sum(d * w, axis=0) / wsum  # weighted average
            navg = np.sum(n * f, axis=0)         # this is the new total nsample (without flagged elements)
            fmax = np.max(f, axis=2)             # collapse along freq: marks any fully flagged integrations
            iavg = np.sum(tint.squeeze() * fmax, axis=0) / np.sum(fmax, axis=0).clip(1e-10, np.inf)
            favg = np.isclose(wsum, 0.0)         # this is getting any fully flagged pixels

            # replace with new data
            if fed_container:
                blkey = blg[0] + (pol,)
                data[blkey] = davg
                flags[blkey] = favg
                nsamples[blkey] = navg

            else:
                blinds = data.antpair2ind(blg[0])
                polind = pols.index(pol)
                data.data_array[blinds, 0, :, polind] = davg
                data.flag_array[blinds, 0, :, polind] = favg
                data.nsample_array[blinds, 0, :, polind] = navg
                data.integration_time[blinds] = iavg

    # select out averaged bls
    bls = [blg[0] + (pol,) for pol in pols for blg in reds]
    if fed_container:
        for bl in list(data.keys()):
            if bl not in bls:
                del data[bl]
    else:
        data.select(bls=bls)

    # when working in place the caller's objects already hold the result
    if not inplace:
        if fed_container:
            return data, flags, nsamples
        else:
            return data
Ejemplo n.º 11
0
def pbcorr(modelname):
    """
    Multiply a primary-beam model into the FITS image ``<modelname>.fits``.

    The beam is read from a beamfits file, evaluated at the direction of
    every image pixel, interpolated in frequency onto the image channels and
    multiplied into the image. Two files are written: the beam-weighted
    image (``<name>.pbcorr.fits``) and the interpolated beam itself
    (``<name>.pb.fits``). Existing outputs are overwritten.

    Parameters
    ----------
    modelname : str
        Path of the input FITS image without its ``.fits`` extension.

    Notes
    -----
    Observer longitude/latitude, observation time, the beamfits path and the
    output directory are read from a module-level object ``p`` (attributes
    ``longitude``, ``latitude``, ``time``, ``beamfile``, ``out_dir``).
    NOTE(review): ``p`` is not defined in this function; it must be provided
    by the enclosing module.
    """
    import numpy as np
    import astropy.io.fits as fits
    from astropy import wcs
    from pyuvdata import UVBeam, utils as uvutils
    import os
    import sys
    import glob
    import argparse
    import shutil
    import copy
    import healpy
    import scipy.stats as stats
    from casa_imaging import casa_utils
    from scipy import interpolate
    from astropy.time import Time
    from astropy import coordinates as crd
    from astropy import units as u

    _fitsfiles = ["{}.fits".format(modelname)]

    # PB args
    _multiply = True
    _lon = p.longitude
    _lat = p.latitude
    _time = p.time

    # beam args
    _beamfile = p.beamfile
    _pols = -5, -6
    _freq_interp_kind = 'cubic'

    # IO args
    _ext = ''
    _outdir = p.out_dir
    _overwrite = True
    _silence = False
    _spec_cube = False

    verbose = not _silence

    def echo(message, level=0):
        # level 0: plain message; level 1: message followed by a rule
        if verbose:
            if level == 0:
                print(message)
            elif level == 1:
                print('\n{}\n{}'.format(message, '-' * 40))

    # load pb
    echo("...loading beamfile {}".format(_beamfile))
    # load beam
    uvb = UVBeam()
    uvb.read_beamfits(_beamfile)
    if uvb.pixel_coordinate_system == 'healpix':
        uvb.interpolation_function = 'healpix_simple'
    else:
        uvb.interpolation_function = 'az_za_simple'
    uvb.freq_interp_kind = _freq_interp_kind

    # get beam models and beam parameters (frequencies scaled by 1e-6)
    beam_freqs = uvb.freq_array.squeeze() / 1e6
    Nbeam_freqs = len(beam_freqs)

    # iterate over FITS files
    for i, ffile in enumerate(_fitsfiles):

        # create output filename
        if _outdir is None:
            output_dir = os.path.dirname(ffile)
        else:
            output_dir = _outdir

        output_fname = os.path.basename(ffile)
        output_fname = os.path.splitext(output_fname)
        if _ext is not None:
            output_fname = output_fname[0] + '.pbcorr{}'.format(
                _ext) + output_fname[1]
        else:
            output_fname = output_fname[0] + '.pbcorr' + output_fname[1]
        output_fname = os.path.join(output_dir, output_fname)

        # check for overwrite
        if os.path.exists(output_fname) and _overwrite is False:
            raise IOError("{} exists, not overwriting".format(output_fname))

        # load hdu; the context manager closes the file handle once the
        # header, data and coordinate info have been read
        echo("...loading {}".format(ffile))
        with fits.open(ffile) as hdu:
            # get header and data
            head = hdu[0].header
            data = hdu[0].data

            # get polarization info
            ra, dec, pol_arr, data_freqs, stok_ax, freq_ax = \
                casa_utils.get_hdu_info(hdu)

        # replace with forced polarization if provided
        # (builtin int: the np.int alias was removed from NumPy 1.24+)
        if _pols is not None:
            pol_arr = np.asarray(_pols, dtype=int)

        pols = [uvutils.polnum2str(pol) for pol in pol_arr]

        # make sure required pols exist in maps
        if not np.all([pol_num in uvb.polarization_array
                       for pol_num in pol_arr]):
            raise ValueError(
                "Required polarizationns {} not found in Beam polarization array"
                .format(pol_arr))

        # convert from equatorial to spherical coordinates
        loc = crd.EarthLocation(lat=_lat * u.degree, lon=_lon * u.degree)
        time = Time(_time, format='jd', scale='utc')
        equatorial = crd.SkyCoord(ra=ra * u.degree,
                                  dec=dec * u.degree,
                                  frame='fk5',
                                  location=loc,
                                  obstime=time)
        altaz = equatorial.transform_to('altaz')
        # zenith angle from altitude
        theta = np.abs(altaz.alt.value - 90.0)
        phi = altaz.az.value

        # convert to radians
        theta *= np.pi / 180
        phi *= np.pi / 180

        if i == 0 or _spec_cube is False:
            # evaluate primary beam
            echo("...evaluating PB")
            pb, _ = uvb.interp(phi.ravel(),
                               theta.ravel(),
                               polarizations=pols,
                               reuse_spline=True)
            pb = np.abs(pb.reshape((len(pols), Nbeam_freqs) + phi.shape))

        # interpolate primary beam onto data frequencies
        echo("...interpolating PB")
        pb_interp = interpolate.interp1d(beam_freqs,
                                         pb,
                                         axis=1,
                                         kind=_freq_interp_kind,
                                         fill_value='extrapolate')(data_freqs /
                                                                   1e6)

        # data shape is [naxis4, naxis3, naxis2, naxis1]
        if freq_ax == 4:
            pb_interp = np.moveaxis(pb_interp, 0, 1)

        # divide or multiply by primary beam
        if _multiply is True:
            echo("...multiplying PB into image")
            data_pbcorr = data * pb_interp
        else:
            echo("...dividing PB into image")
            data_pbcorr = data / pb_interp

        # change polarization to interpolated beam pols
        head["CRVAL{}".format(stok_ax)] = pol_arr[0]
        if len(pol_arr) == 1:
            step = 1
        else:
            step = np.diff(pol_arr)[0]
        head["CDELT{}".format(stok_ax)] = step
        head["NAXIS{}".format(stok_ax)] = len(pol_arr)

        echo("...saving {}".format(output_fname))
        fits.writeto(output_fname, data_pbcorr, head, overwrite=True)

        output_pb = output_fname.replace(".pbcorr.", ".pb.")
        echo("...saving {}".format(output_pb))
        fits.writeto(output_pb, pb_interp, head, overwrite=True)

        # only a single input file is ever listed, so return after it
        return
Ejemplo n.º 12
0
def run_simulation_partial_freq(
    freq_chans,
    uvh5_file,
    skymod_file,
    fov=180,
    beam=None,
    beam_kwargs=None,
    beam_freq_interp="linear",
    smooth_beam=True,
    smooth_scale=2.0,
    Nprocs=1,
    add_to_history=None,
):
    """
    Run a healvis simulation on a selected range of frequency channels.

    Requires a pyuvdata.UVH5 file and SkyModel file (HDF5 format) to exist
    on disk with matching frequencies.

    Args:
        freq_chans : integer 1D array
            Frequency channel indices of uvh5_file to simulate
        uvh5_file : str
            Filepath to a UVH5 file. Its metadata is read and the simulated
            visibilities are written back into it, so a path is required.
        skymod_file : str or SkyModel
            Filepath to a SkyModel file, or an already-loaded SkyModel whose
            ``freqs`` match the selected channels.
        fov : float
            Field of view passed to setup_observatory_from_uvdata.
        beam : str, UVbeam, PowerBeam or AnalyticBeam
            Filepath to beamfits, a UVBeam object, or PowerBeam or AnalyticBeam object
        beam_kwargs : dictionary
            If beam is a viable input to AnalyticBeam, these are its keyword
            arguments. Default (None) is an empty dictionary.
        beam_freq_interp : str
            Interpolation method if beam is PowerBeam. See scipy.interpolate.interp1d for details.
        smooth_beam : bool
            If True, and beam is PowerBeam, smooth it across frequency with a Gaussian Process
        smooth_scale : float
            If smoothing the beam, smooth it at this frequency scale [MHz]
        Nprocs : int
            Number of processes for this task
        add_to_history : str
            History string to append to file history. Default is no append to history.

    Result:
        Writes simulation result into uvh5_file

    Raises:
        TypeError: if uvh5_file is not a filepath string.
        AssertionError: if the SkyModel frequencies do not match the selected
            channels of the UVH5 file.
    """
    # Fail early with a clear message instead of an unbound-variable error:
    # the output is written to this path at the end of the (long) simulation.
    if not isinstance(uvh5_file, str):
        raise TypeError(
            "uvh5_file must be a filepath string because the simulation "
            "result is written into that file; got {}".format(type(uvh5_file)))

    # Use a fresh dict per call rather than a shared mutable default
    if beam_kwargs is None:
        beam_kwargs = {}

    # load UVH5 metadata
    uvd = UVData()
    uvd.read_uvh5(uvh5_file, read_data=False)
    pols = [uvutils.polnum2str(pol) for pol in uvd.polarization_array]

    # load SkyModel (or use the object that was passed in)
    if isinstance(skymod_file, str):
        sky = sky_model.SkyModel()
        sky.read_hdf5(skymod_file, freq_chans=freq_chans, shared_memory=False)
    else:
        sky = skymod_file

    # Check that chosen freqs are a subset of the skymodel frequencies.
    assert np.isclose(sky.freqs, uvd.freq_array[0, freq_chans]).all(
    ), "Frequency arrays in UHV5 file {} and SkyModel file {} don't agree".format(
        uvh5_file, skymod_file)

    # setup observatory
    obs = setup_observatory_from_uvdata(
        uvd,
        fov=fov,
        set_pointings=True,
        beam=beam,
        beam_kwargs=beam_kwargs,
        freq_chans=freq_chans,
        beam_freq_interp=beam_freq_interp,
        smooth_beam=smooth_beam,
        smooth_scale=smooth_scale,
    )

    # run simulation: one set of visibilities per polarization
    visibility = []
    for pol in pols:
        # calculate visibility
        visibs, time_array, baseline_inds = obs.make_visibilities(
            sky, Nprocs=Nprocs, beam_pol=pol)
        visibility.append(visibs)

    # move the polarization axis to the end; nothing flagged, unit nsamples
    visibility = np.moveaxis(visibility, 0, -1)
    flags = np.zeros_like(visibility, bool)
    nsamples = np.ones_like(visibility, float)

    # write to disk
    print("...writing to {}".format(uvh5_file))
    uvd.write_uvh5_part(
        uvh5_file,
        visibility,
        flags,
        nsamples,
        freq_chans=freq_chans,
        add_to_history=add_to_history,
    )
Ejemplo n.º 13
0
    for key in keys:
        ant = int(re.findall(r'visdata://(\d+)/', key)[0])
        pol = key[-2:]
        autos[(ant, pol)] = np.fromstring(redis.hgetall(key).get('data'),
                                          dtype=np.float32)
        amps[(ant, pol)] = np.median(autos[(ant, pol)])
        if args.log:
            autos_raw[(ant, pol)] = autos[(ant, pol)]
            autos[(ant, pol)] = 10.0 * np.log10(autos[(ant, pol)])
            amps[(ant, pol)] = 10.0 * np.log10(amps[(ant, pol)])
        times[(ant, pol)] = float(redis.hgetall(key).get('time', 0))
else:
    counts = {}
    for fname in args.files:
        uvd = apm.UV(fname)
        pol = uvutils.polnum2str(uvd['pol']).lower()
        uvd.select('auto', 1, 1)
        for (uvw, this_t, (ant1, ant2)), auto, fname in uvd.all(raw=True):
            try:
                counts[(ant1, pol)] += 1
                autos[(ant1, pol)] += auto
                times[(ant1, pol)] += this_t
            except KeyError:
                counts[(ant1, pol)] = 1
                autos[(ant1, pol)] = auto
                times[(ant1, pol)] = this_t
    for key in autos.keys():
        autos[key] /= counts[key]
        times[key] /= counts[key]
        amps[key] = np.median(autos[key])
        if args.log:
Ejemplo n.º 14
0
    def __init__(self, calfits_files, use_gp=True):
        """Initialize the object.

        Reads the calfits file(s) into a UVCal object, derives per-antenna
        delays (and, for gain-type files, gains, phase offsets and rotated
        antennas), and computes the delay fluctuations about the time median.

        Parameters
        ----------
        calfits_files : str or list
            Filename for a *.first.calfits file or a list of (time-ordered)
            .first.calfits files of the same polarization.
        use_gp : bool, optional
            If True, use a Gaussian process model to subtract underlying smooth
            delay solution behavior over time from fluctuations. Default is True.
            The GP step only runs when ``self.sklearn_import`` is also True.

        Raises
        ------
        ValueError
            If the calibration type of the files is neither 'gain' nor 'delay'.

        Attributes
        ----------
        UVC : pyuvdata.UVCal() object
            The resulting UVCal object from reading in the calfits files.
        pols : ndarray of str, shape=(Npols,)
            Polarization strings of the Jones terms in the UVCal object.
        Npols : int
            The number of polarizations in the UVCal object.
        Nants : int
            The number of antennas in the UVCal object.
        Ntimes : int
            The number of unique times in the UVCal object.
        delays : ndarray, shape=(Nants, Ntimes, Npols)
            The firstcal delay solutions in nanoseconds. Fully flagged
            entries are set to zero.
        delay_avgs : ndarray, shape=(Nants, 1, Npols)
            The median delay solutions across time in nanoseconds.
        delay_fluctuations : ndarray, shape=(Nants, Ntimes, Npols)
            The firstcal delay solution fluctuations from the time average in
            nanoseconds (with the GP model removed when the GP step runs).
        delay_smooths : ndarray, shape=(Nants, Ntimes, Npols)
            The smooth GP model removed from delay_fluctuations. Only set
            when the GP step runs.
        gains : ndarray or None
            The gains of the first spectral window for gain-type files,
            None for delay-type files.
        offsets : ndarray or None
            Phase offsets extrapolated to 0 Hz for gain-type files, None for
            delay-type files.
        rot_ants : ndarray or list
            Antennas whose phase offset is within 0.1 rad of pi (modulo 2 pi)
            for any time or polarization; empty for delay-type files.
        start_JD : float
            The integer JD of the start of the UVCal object.
        frac_JD : ndarray, shape=(Ntimes,)
            The time-stamps of each integration in units of the fraction of start_JD.
            E.g., 2457966.53433 becomes 0.53433.
        ants : ndarray, shape=(Nants,)
            The antenna numbers contained in the calfits files.
        fc_basename : str
            Basename of the calfits file, or first calfits file in the list.
        fc_filename : str
            Filename of the calfits file, or first calfits file in the list.
        fc_filestem : str
            Filename minus extension of the calfits file, or first calfits file in the list.
        times : ndarray, shape=(Ntimes,)
            The times contained in the UVCal object.
        minutes : ndarray, shape=(Ntimes,)
            Minutes elapsed since the earliest time stamp.
        version_str : str
            The version of the hera_qm module used to generate these metrics.
        history : str
            History to append to the metrics files when writing out files.

        """
        # Instantiate UVCal and read calfits
        self.UVC = UVCal()
        self.UVC.read_calfits(calfits_files)

        # string name of every Jones polarization present in the file(s)
        self.pols = np.array([
            uvutils.polnum2str(jones, x_orientation=self.UVC.x_orientation)
            for jones in self.UVC.jones_array
        ])
        self.Npols = self.pols.size

        # get file prefix
        if isinstance(calfits_files, list):
            calfits_file = calfits_files[0]
        else:
            calfits_file = calfits_files
        self.fc_basename = os.path.basename(calfits_file)
        self.fc_filename = calfits_file
        self.fc_filestem = utils.strip_extension(self.fc_filename)

        # get other relevant arrays
        self.times = self.UVC.time_array
        self.Ntimes = len(set(self.times))
        self.start_JD = np.floor(self.times).min()
        self.frac_JD = self.times - self.start_JD
        self.minutes = 24 * 60 * (self.frac_JD - self.frac_JD.min())
        self.Nants = self.UVC.Nants_data
        self.ants = self.UVC.ant_array
        self.version_str = hera_qm_version_str
        self.history = ''

        # Get the firstcal delays and/or gains and/or rotated antennas
        if self.UVC.cal_type == 'gain':
            # get delays
            freqs = self.UVC.freq_array.squeeze()
            # the unwrap is done over the frequency axis
            fc_gains = self.UVC.gain_array[:, 0, :, :, :]
            fc_phi = np.unwrap(np.angle(fc_gains), axis=1)
            d_nu = np.median(np.diff(freqs))
            d_phi = np.median(fc_phi[:, 1:, :, :] - fc_phi[:, :-1, :, :],
                              axis=1)
            gain_slope = (d_phi / d_nu)
            self.delays = gain_slope / (-2 * np.pi)
            self.gains = fc_gains

            # get delay offsets at nu = 0 Hz, and then get rotated antennas
            self.offsets = fc_phi[:, 0, :, :] - gain_slope * freqs[0]
            # find where the offsets have a difference of pi from 0
            rot_offset_bool = np.isclose(np.pi,
                                         np.mod(np.abs(self.offsets),
                                                2 * np.pi),
                                         atol=0.1).T
            # after the transpose the antenna axis is last, so reduce over
            # the polarization and time axes
            rot_offset_bool = np.any(rot_offset_bool, axis=(0, 1))
            self.rot_ants = np.unique(self.ants[rot_offset_bool])

        elif self.UVC.cal_type == 'delay':
            # Index the two length-one axes explicitly instead of calling
            # squeeze(): squeeze() would also drop the antenna, time or
            # polarization axis whenever it has length one, after which the
            # 3-D flag mask and the [ant, time, pol] indexing below fail.
            # Assumes the 5-D layout (Nants, Nspws, 1, Ntimes, Njones), the
            # same convention the gain branch relies on for gain_array.
            self.delays = self.UVC.delay_array[:, 0, 0, :, :]
            self.gains = None
            self.offsets = None
            self.rot_ants = []

        else:
            # fail early with a clear message rather than with an
            # AttributeError on self.delays further down
            raise ValueError(
                "Unsupported cal_type {!r}; expected 'gain' or "
                "'delay'.".format(self.UVC.cal_type))

        # Calculate avg delay solution and subtract to get delay_fluctuations
        delay_flags = np.all(self.UVC.flag_array, axis=(1, 2))
        self.delays = self.delays * 1e9
        # mask flagged solutions with NaN so they do not bias the median ...
        self.delays[delay_flags] = np.nan
        self.delay_avgs = np.nanmedian(self.delays, axis=1, keepdims=True)
        self.delay_avgs[~np.isfinite(self.delay_avgs)] = 0
        # ... then zero them so downstream arithmetic stays finite
        self.delays[delay_flags] = 0
        self.delay_fluctuations = (self.delays - self.delay_avgs)

        # use gaussian process model to subtract underlying mean function
        if use_gp is True and self.sklearn_import is True:
            # initialize GP kernel.
            # RBF is a squared exponential kernel with a minimum length_scale_bound of 0.01 JD, meaning
            # the GP solution won't have time fluctuations quicker than ~0.01 JD, which will preserve
            # short time fluctuations. WhiteKernel is a Gaussian white noise component with a fiducial
            # noise level of 0.01 nanoseconds. Both of these are hyperparameters that are fit for via
            # a gradient descent algorithm in the GP.fit() routine, so length_scale=0.2 and
            # noise_level=0.01 are just initial conditions and are not the final hyperparameter solution
            kernel = (gp.kernels.RBF(length_scale=0.2,
                                     length_scale_bounds=(0.01, 1.0)) +
                      gp.kernels.WhiteKernel(noise_level=0.01))
            xdata = self.frac_JD.reshape(-1, 1)
            self.delay_smooths = copy.copy(self.delay_fluctuations)
            # iterate over each antenna
            for anti in range(self.Nants):
                # get ydata
                ydata = copy.copy(self.delay_fluctuations[anti, :, :])
                # scale by a robust std computed from unflagged samples only
                ystd = np.sqrt([
                    astats.biweight_midvariance(
                        ydata[~delay_flags[anti, :, ip], ip])
                    for ip in range(self.Npols)
                ])
                ydata /= ystd
                GP = gp.GaussianProcessRegressor(kernel=kernel,
                                                 n_restarts_optimizer=0)
                for pol_cnt in range(self.Npols):
                    # a zero or undefined std makes ydata non-finite; such
                    # polarizations are left untouched
                    if np.all(np.isfinite(ydata[..., pol_cnt])):
                        # fit GP and remove from delay fluctuations but only one polarization at a time
                        GP.fit(xdata, ydata[..., pol_cnt])
                        ymodel = (GP.predict(xdata) * ystd[pol_cnt])
                        self.delay_fluctuations[anti, :, pol_cnt] -= ymodel
                        self.delay_smooths[anti, :, pol_cnt] = ymodel
Ejemplo n.º 15
0
def source_extract(imfile,
                   source,
                   source_ra,
                   source_dec,
                   source_ext='',
                   radius=1,
                   gaussfit_mult=1.5,
                   rms_max_r=None,
                   rms_min_r=None,
                   pols=1,
                   plot_fit=False):
    """Measure peak, RMS and Gaussian-fit fluxes of a source in a FITS image.

    For every requested polarization the brightest pixel within ``radius`` of
    the source position is located, an RMS is estimated away from the source,
    and a 2D Gaussian is fit to the pixels inside an elliptical mask shaped
    like the synthesized beam and centered on that brightest pixel.

    Parameters
    ----------
    imfile : str
        Path to the FITS image.
    source : str
        Source name; used in the plot title and the plot filename.
    source_ra, source_dec : float
        Source position, in the units of the RA / DEC grids returned by
        ``casa_utils.get_hdu_info`` (presumably degrees -- confirm).
    source_ext : str, optional
        Text appended to ``source`` in the plot filename.
    radius : float, optional
        Radius around the source position inside which the peak is searched,
        in the same units as the RA / DEC grids.
    gaussfit_mult : float, optional
        The fit mask extends out to ``gaussfit_mult`` times the beam
        major-axis standard deviation.
    rms_max_r, rms_min_r : float, optional
        Outer and inner radius of the annulus around the source position used
        for the RMS. Both must be given; otherwise the RMS is computed from
        every pixel outside ``radius``.
    pols : int, str or list of int or str, optional
        Polarization(s) to measure, as polarization integers or strings.
    plot_fit : bool, optional
        If True, write a diagnostic PNG next to ``imfile`` (the same filename
        is reused for every polarization).

    Returns
    -------
    peak, peak_err, rms, peak_gauss_flux, int_gauss_flux : ndarray
        One entry per requested polarization.
    freq : float
        The ``CRVAL`` header value of the frequency axis.

    Raises
    ------
    TypeError
        If a requested polarization is neither an integer nor a string.
    ValueError
        If a requested polarization is not in the image, the Stokes axis is
        not axis 3 or 4, or the beam major and minor axes are equal (treated
        as an undefined PSF).
    """
    # the context manager guarantees the file is closed, also on errors
    with fits.open(imfile) as hdu:
        # get header
        head = hdu[0].header

        # get info
        RA, DEC, pol_arr, freqs, stok_ax, freq_ax = casa_utils.get_hdu_info(hdu)
        dra, ddec = head['CDELT1'], head['CDELT2']

        # get axes info
        npix1 = head["NAXIS1"]
        npix2 = head["NAXIS2"]

        # get frequency of image
        freq = head["CRVAL{}".format(freq_ax)]

        # get radius coordinates about the source position: flat-sky approx.
        # This array must stay source-centered for every polarization, so the
        # peak-centered radius computed in the loop gets its own name.
        R = np.sqrt((RA - source_ra)**2 + (DEC - source_dec)**2)

        # select pixels
        select = R < radius

        # polarization check. np.str was only an alias of the builtin str and
        # is gone from current NumPy, so testing against str is sufficient.
        if isinstance(pols, (int, np.integer, str)):
            pols = [pols]

        # iterate over polarizations
        peak, peak_err, rms, peak_gauss_flux, int_gauss_flux = [], [], [], [], []
        for pol in pols:
            # get polstr
            if isinstance(pol, (int, np.integer)):
                polint = pol
                polstr = uvutils.polnum2str(polint)
            elif isinstance(pol, str):
                polstr = pol
                polint = uvutils.polstr2num(polstr)
            else:
                # do not silently reuse the previous iteration's polarization
                raise TypeError(
                    "Each polarization must be an integer or a string, "
                    "got {}".format(type(pol)))

            if polint not in pol_arr:
                raise ValueError(
                    "Requested polarization {} not found in pol_arr {}".format(
                        polint, pol_arr))
            pol_ind = pol_arr.tolist().index(polint)

            # get data
            if stok_ax == 3:
                data = hdu[0].data[0, pol_ind, :, :]
            elif stok_ax == 4:
                data = hdu[0].data[pol_ind, 0, :, :]
            else:
                raise ValueError(
                    "Stokes axis must be axis 3 or 4, got {}".format(stok_ax))

            # get beam info for this polarization
            bmaj, bmin, bpa = casa_utils.get_beam_info(hdu, pol_ind=pol_ind)

            # check for tclean failed PSF
            if np.isclose(bmaj, bmin, 1e-6):
                raise ValueError(
                    "The PSF is not defined for pol {}.".format(polstr))

            # relate FWHM of major and minor axes to standard deviation
            maj_std = bmaj / 2.35
            min_std = bmin / 2.35

            # calculate beam area in degrees^2
            # https://casa.nrao.edu/docs/CasaRef/image.fitcomponents.html
            beam_area = (bmaj * bmin * np.pi / 4 / np.log(2))

            # calculate pixel area in degrees^2
            pixel_area = np.abs(dra * ddec)
            Npix_beam = beam_area / pixel_area

            # get peak brightness within pixel radius
            _peak = np.nanmax(data[select])

            # get rms outside of source radius; the annulus needs both radii
            if rms_max_r is not None and rms_min_r is not None:
                rms_select = (R < rms_max_r) & (R > rms_min_r)
                _rms = np.sqrt(np.mean(data[rms_select]**2))
            else:
                _rms = np.sqrt(np.mean(data[~select]**2))

            ## fit a 2D gaussian and get integrated and peak flux statistics ##
            # recenter the radius array on the peak flux point and get the
            # theta array T
            peak_ind = np.argmax(data[select])
            peak_ra = RA[select][peak_ind]
            peak_dec = DEC[select][peak_ind]
            X = (RA - peak_ra)
            Y = (DEC - peak_dec)
            R_peak = np.sqrt(X**2 + Y**2)
            # avoid a division by zero in Y / X
            X[np.where(np.isclose(X, 0.0))] = 1e-5
            T = np.arctan(Y / X)

            # use synthesized beam as data mask
            ecc = maj_std / min_std
            beam_theta = bpa * np.pi / 180 + np.pi / 2
            EMAJ = R_peak * np.sqrt(
                np.cos(T + beam_theta)**2 + ecc**2 * np.sin(T + beam_theta)**2)
            fit_mask = EMAJ < (maj_std * gaussfit_mult)
            masked_data = data.copy()
            masked_data[~fit_mask] = 0.0

            # fit 2d gaussian
            gauss_init = mod.functional_models.Gaussian2D(_peak,
                                                          peak_ra,
                                                          peak_dec,
                                                          x_stddev=maj_std,
                                                          y_stddev=min_std)
            fitter = mod.fitting.LevMarLSQFitter()
            gauss_fit = fitter(gauss_init, RA[fit_mask], DEC[fit_mask],
                               data[fit_mask])

            # get gaussian fit properties
            _peak_gauss_flux = gauss_fit.amplitude.value
            # NOTE(review): for 2-D X / Y the transpose also reverses the two
            # image axes, so model_gauss is indexed (axis1, axis0) relative
            # to data -- confirm this is intended for non-square images.
            P = np.array([X, Y]).T
            beam_theta -= np.pi / 2  # correct for previous + np.pi/2
            Prot = P.dot(
                np.array([[np.cos(beam_theta), -np.sin(beam_theta)],
                          [np.sin(beam_theta),
                           np.cos(beam_theta)]]))
            gauss_cov = np.array([[gauss_fit.x_stddev.value**2, 0],
                                  [0, gauss_fit.y_stddev.value**2]])
            # try to get integrated flux; this stays best-effort (e.g. a
            # degenerate fit gives an invalid covariance), but no longer
            # swallows KeyboardInterrupt / SystemExit
            try:
                model_gauss = stats.multivariate_normal.pdf(Prot,
                                                            mean=np.array([0, 0]),
                                                            cov=gauss_cov)
                model_gauss *= gauss_fit.amplitude.value / model_gauss.max()
                _int_gauss_flux = np.nansum(model_gauss) / Npix_beam
            except Exception:
                model_gauss = np.zeros_like(data)
                _int_gauss_flux = 0

            # get peak error
            # http://www.gb.nrao.edu/~bmason/pubs/m2mapspeed.pdf
            beam = np.exp(-((X / maj_std)**2 + (Y / min_std)**2))
            _peak_err = _rms / np.sqrt(np.sum(beam**2))

            # append
            peak.append(_peak)
            peak_err.append(_peak_err)
            rms.append(_rms)
            peak_gauss_flux.append(_peak_gauss_flux)
            int_gauss_flux.append(_int_gauss_flux)

            # plot
            if plot_fit:
                # get postage cutout
                ra_axis = RA[npix1 // 2]
                dec_axis = DEC[:, npix2 // 2]
                ra_select = np.where(np.abs(ra_axis - source_ra) < radius)[0]
                dec_select = np.where(np.abs(dec_axis - source_dec) < radius)[0]
                d = data[ra_select[0]:ra_select[-1] + 1,
                         dec_select[0]:dec_select[-1] + 1]
                m = model_gauss[ra_select[0]:ra_select[-1] + 1,
                                dec_select[0]:dec_select[-1] + 1]

                # setup wcs and figure
                wcs = WCS(head, naxis=2)
                fig = plt.figure(figsize=(14, 5))
                fig.subplots_adjust(wspace=0.2)
                fig.suptitle("Source {} from {}\n{:.2f} MHz".format(
                    source, imfile, freq / 1e6),
                             fontsize=10)

                # make 3D plot
                if mplot:
                    ax = fig.add_subplot(131, projection='3d')
                    ax.axis('off')
                    x, y = np.meshgrid(ra_select, dec_select)
                    ax.plot_wireframe(x,
                                      y,
                                      m,
                                      color='steelblue',
                                      lw=2,
                                      rcount=20,
                                      ccount=20,
                                      alpha=0.75)
                    ax.plot_surface(x,
                                    y,
                                    d,
                                    rcount=40,
                                    ccount=40,
                                    cmap='magma',
                                    alpha=0.5)

                # plot cut-out
                ax = fig.add_subplot(132, projection=wcs)
                cax = ax.imshow(data, origin='lower', cmap='magma')
                ax.contour(fit_mask, origin='lower', colors='lime', levels=[0.5])
                ax.contour(model_gauss,
                           origin='lower',
                           colors='snow',
                           levels=np.array([0.5, 0.9]) * np.nanmax(m))
                ax.grid(color='w')
                cbar = fig.colorbar(cax, ax=ax)
                [tl.set_size(8) for tl in cbar.ax.yaxis.get_ticklabels()]
                [tl.set_size(10) for tl in ax.get_xticklabels()]
                [tl.set_size(10) for tl in ax.get_yticklabels()]
                ax.set_xlim(ra_select[0], ra_select[-1] + 1)
                ax.set_ylim(dec_select[0], dec_select[-1] + 1)
                ax.set_xlabel('Right Ascension', fontsize=12)
                ax.set_ylabel('Declination', fontsize=12)
                ax.set_title("Source Flux and Gaussian Fit", fontsize=10)

                # plot residual
                ax = fig.add_subplot(133, projection=wcs)
                resid = data - model_gauss
                vlim = np.abs(resid[fit_mask]).max()
                cax = ax.imshow(resid,
                                origin='lower',
                                cmap='magma',
                                vmin=-vlim,
                                vmax=vlim)
                ax.contour(fit_mask, origin='lower', colors='lime', levels=[0.5])
                ax.grid(color='w')
                ax.set_xlabel('Right Ascension', fontsize=12)
                cbar = fig.colorbar(cax, ax=ax)
                cbar.set_label(head['BUNIT'], fontsize=10)
                [tl.set_size(8) for tl in cbar.ax.yaxis.get_ticklabels()]
                [tl.set_size(10) for tl in ax.get_xticklabels()]
                [tl.set_size(10) for tl in ax.get_yticklabels()]
                ax.set_xlim(ra_select[0], ra_select[-1] + 1)
                ax.set_ylim(dec_select[0], dec_select[-1] + 1)
                ax.set_title("Residual", fontsize=10)

                fig.savefig('{}.{}.png'.format(
                    os.path.splitext(imfile)[0], source + source_ext))
                plt.close()

    peak = np.asarray(peak)
    peak_err = np.asarray(peak_err)
    rms = np.asarray(rms)
    peak_gauss_flux = np.asarray(peak_gauss_flux)
    int_gauss_flux = np.asarray(int_gauss_flux)

    return peak, peak_err, rms, peak_gauss_flux, int_gauss_flux, freq
Ejemplo n.º 16
0

def _comply_vispol(pol):
    '''Return the canonical form of a visibility polarization string.

    The string is converted to its polarization number and back. When
    _is_cardinal(pol) is true, both conversions are done with
    x_orientation='north'; otherwise neither conversion is given an
    x_orientation.'''
    orientation = {'x_orientation': 'north'} if _is_cardinal(pol) else {}
    polnum = polstr2num(pol, **orientation)
    return polnum2str(polnum, **orientation)


# Polarization strings known to pyuvdata whose polarization number is negative
_VISPOLS = {pol for pol in POL_STR2NUM_DICT if polstr2num(pol) < 0}
# Extend _VISPOLS with the strings obtained for x_orientation='north', using
# only pyuvdata's own conversions. Iterate over a snapshot because the set
# grows inside the loop; numbers without such a string (KeyError) are skipped.
for pol in list(_VISPOLS):
    try:
        _VISPOLS.add(polnum2str(polstr2num(pol), x_orientation='north'))
    except KeyError:
        continue
# visibility polarization -> pair of antenna polarizations built from its
# first two characters
SPLIT_POL = {
    vispol: tuple(_comply_antpol(char) for char in (vispol[0], vispol[1]))
    for vispol in _VISPOLS
}
# inverse lookup: pair of antenna polarizations -> visibility polarization
JOIN_POL = {antpols: vispol for vispol, antpols in SPLIT_POL.items()}


def split_pol(pol):
    '''Return the pair of antenna polarization strings (pyuvdata's jstr)
    for a visibility polarization string (pyuvdata's polstr).

    The input is normalized with _comply_vispol before the SPLIT_POL lookup,
    so a KeyError is raised if the normalized string is not in SPLIT_POL.'''
    vispol = _comply_vispol(pol)
    return SPLIT_POL[vispol]


def join_pol(p1, p2):
    '''Joins antenna polarization strings (pyuvdata's jstr) into
    visibility polarization string (pyuvdata's polstr).'''
Ejemplo n.º 17
0
def generate_fullpol_file_list(files, pol_list):
    """Group files into sets that cover all requested polarizations.

    If every input file already contains exactly the polarizations in
    pol_list, each file forms its own group. Otherwise the files are treated
    as single-polarization files: for each file the sibling files of the other
    polarizations are looked up on disk by substituting the polarization
    string in the file name, and the group is kept only if all of them exist.

    Parameters
    ----------
    files : list of str
        The list of files to look for.
    pol_list : list of str
        The list of polarizations to look for, as strings (e.g., ['xx', 'xy', 'yx',
        'yy']).

    Returns
    -------
    file_list : list of list of str
        One inner list per group. For full-polarization inputs each inner
        list holds the input file name as given; otherwise it holds the
        absolute paths of the files for each polarization, in pol_list order.
        Empty if files is empty.

    Raises
    ------
    ValueError
        If a file contains more than one polarization but not exactly the
        polarizations in pol_list.

    """
    # Nothing to group. Returning here also avoids evaluating the full-pol
    # flag below, which is only assigned once a file has been inspected.
    # len() is used rather than truthiness so array inputs work as well.
    if len(files) == 0:
        return []

    # Check if all input files are full-pol files
    # if so return the input files as the full list
    uvd = UVData()

    for filename in files:
        if filename.split('.')[-1] == 'uvh5':
            uvd.read_uvh5(filename, read_data=False)
        else:
            uvd.read(filename)
        # convert the polarization array to strings and compare with the
        # expected input.
        # If any one file is not a full-pol file then this will be false.
        input_pols = uvutils.polnum2str(uvd.polarization_array,
                                        x_orientation=uvd.x_orientation)
        full_pol_check = np.array_equal(np.sort(input_pols), np.sort(pol_list))

        if not full_pol_check:
            # if a file has more than one polarization but not all expected pols
            # raise an error that mixed pols are not allowed.
            if len(input_pols) > 1:
                base_fname = os.path.basename(filename)
                raise ValueError("The file: {fname} contains {npol} "
                                 "polarizations: {pol}. "
                                 "Currently only full lists of all expected "
                                 "polarization files or lists of "
                                 "files with single polarizations in the "
                                 "name of the file (e.g. zen.JD.pol.HH.uv) "
                                 "are allowed.".format(fname=base_fname,
                                                       npol=len(input_pols),
                                                       pol=input_pols))

            # if only one polarization then try the old regex method
            # assumes all files have the same number of polarizations
            break
    del uvd

    if full_pol_check:
        # Output of this function is a list of lists of files
        # We expect all full pol files to be unique JDs so
        # turn the list of files into a list of lists of each file.
        return [[f] for f in files]

    file_list = []
    for filename in files:
        abspath = os.path.abspath(filename)
        # skip files that already belong to a group found earlier
        if any(abspath in jd_list for jd_list in file_list):
            continue

        # try to find the other polarizations
        file_pol = get_pol(filename)
        dirname, basename = os.path.split(abspath)
        # Substitute in the base name only: this guards against directory
        # names that contain something that looks like a pol string, and
        # keeps relative directory components from being joined onto
        # dirname a second time.
        jd_list = [
            os.path.join(dirname, re.sub(file_pol, pol, basename))
            for pol in pol_list
        ]
        missing = next(
            (fn for fn in jd_list if not os.path.exists(fn)), None)
        if missing is not None:
            warnings.warn("Could not find " + missing +
                          "; skipping that JD")
            continue
        # all polarizations are present on disk
        file_list.append(jd_list)

    return file_list