Example #1
def describe_ms(infile):
    """
    Summarize the contents of an MS directory in casacore table format

    Parameters
    ----------
    infile : str
        input filename of MS

    Returns
    -------
    pandas.core.frame.DataFrame
        Summary information
    """
    import os
    import pandas as pd
    import numpy as np
    import cngi._utils._table_conversion as tblconv
    from casatools import table as tb
    
    infile = os.path.expanduser(infile)  # does nothing if $HOME is unknown
    if not infile.endswith('/'): infile = infile + '/'

    # as part of MSv3 conversion, these columns in the main table are no longer needed
    ignorecols = ['FLAG_CATEGORY', 'FLAG_ROW', 'SIGMA', 'WEIGHT_SPECTRUM', 'DATA_DESC_ID']

    # figure out characteristics of main table from select subtables (must all be present)
    spw_xds = tblconv.convert_simple_table(infile, outfile='', subtable='SPECTRAL_WINDOW', ignore=ignorecols, nofile=True)
    pol_xds = tblconv.convert_simple_table(infile, outfile='', subtable='POLARIZATION', ignore=ignorecols, nofile=True)
    ddi_xds = tblconv.convert_simple_table(infile, outfile='', subtable='DATA_DESCRIPTION', ignore=ignorecols, nofile=True)
    ddis = list(ddi_xds['d0'].values)

    summary = pd.DataFrame([])
    spw_ids = ddi_xds.spectral_window_id.values
    pol_ids = ddi_xds.polarization_id.values
    chans = spw_xds.NUM_CHAN.values
    pols = pol_xds.NUM_CORR.values
    tb_tool = tb()
    tb_tool.open(infile, nomodify=True, lockoptions={'option': 'usernoread'})  # allow concurrent reads
    for ddi in ddis:
        print('processing ddi %i of %i' % (ddi+1, len(ddis)), end='\r')
        sorted_table = tb_tool.taql('select * from %s where DATA_DESC_ID = %i' % (infile, ddi))
        sdf = {'ddi': ddi, 'spw_id': spw_ids[ddi], 'pol_id': pol_ids[ddi], 'rows': sorted_table.nrows(),
               'times': len(np.unique(sorted_table.getcol('TIME'))),
               'baselines': len(np.unique(np.hstack([sorted_table.getcol(rr)[:,None] for rr in ['ANTENNA1', 'ANTENNA2']]), axis=0)),
               'chans': chans[spw_ids[ddi]],
               'pols': pols[pol_ids[ddi]]}
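        # size estimate assumes roughly 17 bytes per visibility sample (e.g. 8-byte complex data plus weight and flag storage)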
        sdf['size_MB'] = np.ceil((sdf['times']*sdf['baselines']*sdf['chans']*sdf['pols']*17) / 1024**2).astype(int)
        summary = pd.concat([summary, pd.DataFrame(sdf, index=[str(ddi)])], axis=0, sort=False)
        sorted_table.close()
    print(' '*50, end='\r')
    tb_tool.close()
    
    return summary.set_index('ddi').sort_index()
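
A minimal usage sketch (the MS path is hypothetical); describe_ms returns one row per DDI, indexed by ddi:

summary_df = describe_ms('~/data/myobs.ms')
print(summary_df[['spw_id', 'pol_id', 'rows', 'chans', 'pols', 'size_MB']])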
Example #2
def convert_ms(infile,
               outfile=None,
               ddis=None,
               ignore=['HISTORY'],
               compressor=None,
               chunks=(100, 400, 32, 1),
               sub_chunks=10000,
               append=False):
    """
    Convert legacy format MS to xarray Visibility Dataset and zarr storage format

    This function requires CASA6 casatools module. The CASA MSv2 format is converted to the MSv3 schema per the
    specified definition here: https://drive.google.com/file/d/10TZ4dsFw9CconBc-GFxSeb2caT6wkmza/view?usp=sharing
    
    The MS is partitioned by DDI, which guarantees a fixed data shape per partition. This results in different subdirectories
    under the main vis.zarr folder.  There is no DDI in MSv3, so this simply serves as a partition id in the zarr directory.

    Parameters
    ----------
    infile : str
        Input MS filename
    outfile : str
        Output zarr filename. If None, will use infile name with .vis.zarr extension
    ddis : list
        List of specific DDIs to convert. DDIs are integer values; use the 'global' string for subtables. Leave as None to convert the entire MS
    ignore : list
        List of subtables to ignore (case sensitive and generally all uppercase). This is useful if a particular subtable is causing errors.
        Default is ['HISTORY']: temporarily set to ignore the HISTORY table due to a CASA6 issue in the table tool affecting a small
        set of test cases (set back to None if HISTORY is needed)
    compressor : numcodecs.blosc.Blosc
        The blosc compressor to use when saving the converted data to disk using zarr.
        If None, the zstd compression algorithm is used with compression level 2.
    chunks: 4-D tuple of ints
        Shape of desired chunking in the form of (time, baseline, channel, polarization), use -1 for entire axis in one chunk. Default is (100, 400, 32, 1)
        Note: chunk size is the product of the four numbers, and data is batch processed by time axis, so that will drive memory needed for conversion.
    sub_chunks: int
        Chunking used for subtable conversion (except for POINTING which will use time/baseline dims from chunks parameter). This is a single integer
        used for the row-axis (d0) chunking only, no other dims in the subtables will be chunked.
    append : bool
        Keep destination zarr store intact and add new DDIs to it. Note that duplicate DDIs will still be overwritten. Default False deletes and replaces
        entire directory.
    Returns
    -------
    xarray.core.dataset.Dataset
      Master xarray dataset of datasets for this visibility set
    """
    import itertools
    import os
    import xarray
    import dask.array as da
    import numpy as np
    import time
    import cngi._utils._table_conversion as tblconv
    import cngi._utils._io as xdsio
    import warnings
    import importlib_metadata
    warnings.filterwarnings('ignore', category=FutureWarning)

    # parse filename to use
    infile = os.path.expanduser(infile)
    prefix = infile[:infile.rindex('.')]
    if outfile is None: outfile = prefix + '.vis.zarr'
    outfile = os.path.expanduser(outfile)

    # need to manually remove existing zarr file (if any)
    if not append:
        os.system("rm -fr " + outfile)
        os.system("mkdir " + outfile)

    # as part of MSv3 conversion, these columns in the main table are no longer needed
    ignorecols = ['FLAG_CATEGORY', 'FLAG_ROW', 'DATA_DESC_ID']
    if ignore is None: ignore = []

    # we need to assume an explicit ordering of dims
    dimorder = ['time', 'baseline', 'chan', 'pol']

    # we need the spectral window, polarization, and data description tables for processing the main table
    spw_xds = tblconv.convert_simple_table(infile,
                                           outfile='',
                                           subtable='SPECTRAL_WINDOW',
                                           ignore=ignorecols,
                                           nofile=True,
                                           add_row_id=True)
    pol_xds = tblconv.convert_simple_table(infile,
                                           outfile='',
                                           subtable='POLARIZATION',
                                           ignore=ignorecols,
                                           nofile=True)
    ddi_xds = tblconv.convert_simple_table(infile,
                                           outfile='',
                                           subtable='DATA_DESCRIPTION',
                                           ignore=ignorecols,
                                           nofile=True)

    # let's assume that each DATA_DESC_ID (ddi) is a fixed shape that may differ from others
    # form a list of ddis to process; each will be placed in its own xarray dataset and partition
    if ddis is None: ddis = list(ddi_xds['d0'].values) + ['global']
    else: ddis = np.atleast_1d(ddis)
    xds_list = []

    # extra data selection to split autocorr and crosscorr into separate xds
    # extrasels[0] is for autocorrelation
    # extrasels[1] is for the others (cross-correlations, correlations between feeds)
    extrasels = [
        'ANTENNA1 == ANTENNA2 && FEED1 == FEED2',
        'ANTENNA1 != ANTENNA2 || FEED1 != FEED2'
    ]
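    # e.g. combined with the DDI selection below, the effective query for ddi 0 would take the (assumed) form:
    #   select * from <MS> where DATA_DESC_ID = 0 && ANTENNA1 == ANTENNA2 && FEED1 == FEED2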

    ####################################################################
    # process each selected DDI from the input MS, assume a fixed shape within the ddi (should always be true)
    # each DDI is written to its own subdirectory under the parent folder
    for extrasel, ddi in itertools.product(extrasels, ddis):
        if ddi == 'global': continue  # handled afterwards

        extra_sel_index = extrasels.index(extrasel)
        if extra_sel_index == 0:
            xds_prefix = 'xdsa'
        else:
            xds_prefix = 'xds'
        xds_name = f'{xds_prefix}{ddi}'

        ddi = int(ddi)
        print('Processing ddi', ddi, f'xds name is {xds_name}', end='\r')
        start_ddi = time.time()

        # these columns are different / absent in MSv3 or need to be handled as special cases
        msv2 = [
            'WEIGHT', 'WEIGHT_SPECTRUM', 'SIGMA', 'SIGMA_SPECTRUM', 'ANTENNA1',
            'ANTENNA2', 'UVW'
        ]

        # convert columns that are common to MSv2 and MSv3
        xds = tblconv.convert_expanded_table(infile,
                                             os.path.join(outfile, xds_name),
                                             keys={
                                                 'TIME':
                                                 'time',
                                                 ('ANTENNA1', 'ANTENNA2'):
                                                 'baseline'
                                             },
                                             subsel={'DATA_DESC_ID': ddi},
                                             timecols=['time'],
                                             dimnames={
                                                 'd2': 'chan',
                                                 'd3': 'pol'
                                             },
                                             ignore=ignorecols + msv2,
                                             compressor=compressor,
                                             chunks=chunks,
                                             nofile=False,
                                             extraselstr=extrasel)
        if len(xds.dims) == 0: continue

        # convert and append UVW separately so we can handle its special dimension
        uvw_chunks = (chunks[0], chunks[1], 3)  #No chunking over uvw_index
        uvw_xds = tblconv.convert_expanded_table(
            infile,
            os.path.join(outfile, 'tmp'),
            keys={
                'TIME': 'time',
                ('ANTENNA1', 'ANTENNA2'): 'baseline'
            },
            subsel={'DATA_DESC_ID': ddi},
            timecols=['time'],
            dimnames={'d2': 'uvw_index'},
            ignore=ignorecols + list(xds.data_vars) + msv2[:-1],
            compressor=compressor,
            chunks=uvw_chunks,
            nofile=False,
            extraselstr=extrasel)
        uvw_xds.to_zarr(os.path.join(outfile, xds_name),
                        mode='a',
                        compute=True,
                        consolidated=True)

        # convert and append the ANTENNA1 and ANTENNA2 columns separately so we can squash the unnecessary time dimension
        ant_xds = tblconv.convert_expanded_table(
            infile,
            os.path.join(outfile, 'tmp'),
            keys={
                'TIME': 'time',
                ('ANTENNA1', 'ANTENNA2'): 'baseline'
            },
            subsel={'DATA_DESC_ID': ddi},
            timecols=['time'],
            ignore=ignorecols + list(xds.data_vars) + msv2[:4] + ['UVW'],
            compressor=compressor,
            chunks=chunks[:2],
            nofile=False,
            extraselstr=extrasel)
        ant_xds = ant_xds.assign({
            'ANTENNA1': ant_xds.ANTENNA1.max(axis=0),
            'ANTENNA2': ant_xds.ANTENNA2.max(axis=0)
        }).drop_dims('time')
        ant_xds.to_zarr(os.path.join(outfile, xds_name),
                        mode='a',
                        compute=True,
                        consolidated=True)

        # now convert just the WEIGHT and WEIGHT_SPECTRUM (if present)
        # WEIGHT needs to be expanded to full dimensionality (time, baseline, chan, pol)
        wt_xds = tblconv.convert_expanded_table(
            infile,
            os.path.join(outfile, 'tmp'),
            keys={
                'TIME': 'time',
                ('ANTENNA1', 'ANTENNA2'): 'baseline'
            },
            subsel={'DATA_DESC_ID': ddi},
            timecols=['time'],
            dimnames={},
            ignore=ignorecols + list(xds.data_vars) + msv2[-3:],
            compressor=compressor,
            chunks=chunks,
            nofile=False,
            extraselstr=extrasel)

        # MSv3 changes to weight/sigma column handling
        # 1. DATA_WEIGHT = 1/SIGMA**2
        # 2. CORRECTED_DATA_WEIGHT = WEIGHT
        # 3. if SIGMA_SPECTRUM or WEIGHT_SPECTRUM present, use them instead of SIGMA and WEIGHT
        # 4. discard SIGMA, WEIGHT, SIGMA_SPECTRUM and WEIGHT_SPECTRUM from converted ms
        # 5. set shape of DATA_WEIGHT / CORRECTED_DATA_WEIGHT to (time, baseline, chan, pol) padding as necessary
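        # e.g. (assumed MSv2 shapes) SIGMA/WEIGHT come in as (time, baseline, pol); below they are reshaped to
        # (time, baseline, 1, pol) and tiled along chan to reach the full (time, baseline, chan, pol) shape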
        if 'DATA' in xds.data_vars:
            if 'SIGMA_SPECTRUM' in wt_xds.data_vars:
                wt_xds = wt_xds.rename(
                    dict(zip(wt_xds.SIGMA_SPECTRUM.dims, dimorder))).assign(
                        {'DATA_WEIGHT': 1 / wt_xds.SIGMA_SPECTRUM**2})
            elif 'SIGMA' in wt_xds.data_vars:
                wts = wt_xds.SIGMA.shape[:2] + (1, ) + (
                    wt_xds.SIGMA.shape[-1], )
                wt_da = da.tile(da.reshape(wt_xds.SIGMA.data, wts),
                                (1, 1, len(xds.chan), 1)).rechunk(chunks)
                wt_xds = wt_xds.assign({
                    'DATA_WEIGHT':
                    xarray.DataArray(1 / wt_da**2, dims=dimorder)
                })
        if 'CORRECTED_DATA' in xds.data_vars:
            if 'WEIGHT_SPECTRUM' in wt_xds.data_vars:
                wt_xds = wt_xds.rename(
                    dict(zip(wt_xds.WEIGHT_SPECTRUM.dims, dimorder))).assign(
                        {'CORRECTED_DATA_WEIGHT': wt_xds.WEIGHT_SPECTRUM})
            elif 'WEIGHT' in wt_xds.data_vars:
                wts = wt_xds.WEIGHT.shape[:2] + (1, ) + (
                    wt_xds.WEIGHT.shape[-1], )
                wt_da = da.tile(da.reshape(wt_xds.WEIGHT.data, wts),
                                (1, 1, len(xds.chan), 1)).rechunk(chunks)
                wt_xds = wt_xds.assign({
                    'CORRECTED_DATA_WEIGHT':
                    xarray.DataArray(wt_da, dims=dimorder)
                })

        wt_xds = wt_xds.drop([cc for cc in msv2 if cc in wt_xds.data_vars])
        wt_xds.to_zarr(os.path.join(outfile, xds_name),
                       mode='a',
                       compute=True,
                       consolidated=True)

        # add in relevant data grouping, spw and polarization attributes
        attrs = {'data_groups': [{}]}
        if ('DATA' in xds.data_vars) and ('DATA_WEIGHT' in wt_xds.data_vars):
            attrs['data_groups'][0][str(len(attrs['data_groups'][0]))] = {
                'id': str(len(attrs['data_groups'][0])),
                'data': 'DATA',
                'uvw': 'UVW',
                'flag': 'FLAG',
                'weight': 'DATA_WEIGHT'
            }
        if ('CORRECTED_DATA' in xds.data_vars) and ('CORRECTED_DATA_WEIGHT'
                                                    in wt_xds.data_vars):
            attrs['data_groups'][0][str(len(attrs['data_groups'][0]))] = {
                'id': str(len(attrs['data_groups'][0])),
                'data': 'CORRECTED_DATA',
                'uvw': 'UVW',
                'flag': 'FLAG',
                'weight': 'CORRECTED_DATA_WEIGHT'
            }
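        # e.g. with both DATA and CORRECTED_DATA present, data_groups would be (assumed layout):
        # [{'0': {'id': '0', 'data': 'DATA', ...}, '1': {'id': '1', 'data': 'CORRECTED_DATA', ...}}]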

        for dv in spw_xds.data_vars:
            attrs[dv.lower()] = spw_xds[dv].values[ddi_xds['spectral_window_id'].values[ddi]]
            if isinstance(attrs[dv.lower()], np.bool_):
                attrs[dv.lower()] = int(attrs[dv.lower()])  # convert bools
        for dv in pol_xds.data_vars:
            attrs[dv.lower()] = pol_xds[dv].values[ddi_xds['polarization_id'].values[ddi]]
            if isinstance(attrs[dv.lower()], np.bool_):
                attrs[dv.lower()] = int(attrs[dv.lower()])  # convert bools

        # grab the channel frequency values from the spw table data and pol idxs from the polarization table, add spw and pol ids
        chan = attrs.pop('chan_freq')[:len(xds.chan)]
        pol = attrs.pop('corr_type')[:len(xds.pol)]

        # truncate per-chan values to the actual number of channels and move to coordinates
        chan_width = xarray.DataArray(da.from_array(
            attrs.pop('chan_width')[:len(xds.chan)], chunks=chunks[2]),
                                      dims=['chan'])
        effective_bw = xarray.DataArray(da.from_array(
            attrs.pop('effective_bw')[:len(xds.chan)], chunks=chunks[2]),
                                        dims=['chan'])
        resolution = xarray.DataArray(da.from_array(
            attrs.pop('resolution')[:len(xds.chan)], chunks=chunks[2]),
                                      dims=['chan'])

        coords = {
            'chan': chan,
            'pol': pol,
            'spw_id': [ddi_xds['spectral_window_id'].values[ddi]],
            'pol_id': [ddi_xds['polarization_id'].values[ddi]],
            'chan_width': chan_width,
            'effective_bw': effective_bw,
            'resolution': resolution
        }
        aux_xds = xarray.Dataset(coords=coords, attrs=attrs)

        aux_xds.to_zarr(os.path.join(outfile, xds_name),
                        mode='a',
                        compute=True,
                        consolidated=True)
        xds = xarray.open_zarr(os.path.join(outfile, xds_name))

        xds_list += [(xds_name, xds)]
        print('Completed ddi %i  process time %0.2f s' % (ddi, time.time() - start_ddi))

    # clean up the tmp directory created by the weight conversion to MSv3
    os.system("rm -fr " + os.path.join(outfile, 'tmp'))

    # convert other subtables to their own partitions, denoted by 'global_' prefix
    skip_tables = ['DATA_DESCRIPTION', 'SORTED_TABLE'] + ignore
    subtables = sorted([
        tt for tt in os.listdir(infile)
        if os.path.isdir(os.path.join(infile, tt)) and tt not in skip_tables
    ])
    if 'global' in ddis:
        start_ddi = time.time()
        for ii, subtable in enumerate(subtables):
            print('processing subtable %i of %i : %s' %
                  (ii + 1, len(subtables), subtable),
                  end='\r')
            if subtable == 'POINTING':  # expand the dimensions of the pointing table
                xds_sub_list = [(subtable,
                                 tblconv.convert_expanded_table(
                                     infile,
                                     os.path.join(outfile, 'global'),
                                     subtable=subtable,
                                     keys={
                                         'TIME': 'time',
                                         'ANTENNA_ID': 'antenna_id'
                                     },
                                     timecols=['time'],
                                     chunks=chunks))]
            else:
                add_row_id = (subtable in [
                    'ANTENNA', 'FIELD', 'OBSERVATION', 'SCAN',
                    'SPECTRAL_WINDOW', 'STATE'
                ])
                xds_sub_list = [(subtable,
                                 tblconv.convert_simple_table(
                                     infile,
                                     os.path.join(outfile, 'global'),
                                     subtable,
                                     timecols=['TIME'],
                                     ignore=ignorecols,
                                     compressor=compressor,
                                     nofile=False,
                                     chunks=(sub_chunks, -1),
                                     add_row_id=add_row_id))]

            if len(xds_sub_list[-1][1].dims) != 0:
                xds_list += xds_sub_list
            #else:
            #    print('Empty Subtable:',subtable)

        print(
            'Completed subtables  process time {:0.2f} s'.format(time.time() -
                                                                 start_ddi))

    # write sw version that did this conversion to zarr directory
    try:
        version = importlib_metadata.version('cngi-prototype')
    except Exception:
        version = '0.0.0'

    with open(outfile + '/.version', 'w') as fid:
        fid.write('cngi-prototype ' + version + '\n')

    # build the master xds to return
    mxds = xdsio.vis_xds_packager(xds_list)
    print(' ' * 50)

    return mxds
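
A minimal usage sketch (paths and chunking are hypothetical); with the split selections above, autocorrelations land in 'xdsa<ddi>' partitions and cross-correlations in 'xds<ddi>' partitions:

mxds = convert_ms('~/data/myobs.ms', outfile='~/data/myobs.vis.zarr', chunks=(100, 400, 32, 1))
print(mxds)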
Example #3
def convert_image(infile, outfile=None, artifacts=[], compressor=None, chunk_shape=(-1, -1, 1, 1)):
    """
    Convert legacy CASA or FITS format Image to xarray Image Dataset and zarr storage format

    This function requires CASA6 casatools module.

    Parameters
    ----------
    infile : str
        Input image filename (.image or .fits format). If taylor terms are present, they should be in the form of filename.image.tt0 and
        this infile string should be filename.image
    outfile : str
        Output zarr filename. If None, will use infile name with .img.zarr extension
    artifacts : list of str
        List of other image artifacts to include if present with infile. Use None for just the specified infile.
        Default [] uses ``['image','pb','psf','residual','mask','model','sumwt','weight','image.pbcor']``
    compressor : numcodecs.blosc.Blosc
        The blosc compressor to use when saving the converted data to disk using zarr.
        If None, the zstd compression algorithm is used with compression level 2.
    chunk_shape: 4-D tuple of ints
        Shape of desired chunking in the form of (l, m, channels, polarization), use -1 for entire axis in one chunk. Default is (-1, -1, 1, 1)
        Note: chunk size is the product of the four numbers (up to the actual size of the dimension)

    Returns
    -------
    xarray.core.dataset.Dataset
        new xarray Datasets of Image data contents
    """
    from casatools import image as ia
    from casatools import quanta as qa
    from cngi._utils._table_conversion import convert_simple_table, convert_time
    import numpy as np
    from itertools import cycle
    import importlib_metadata
    import xarray
    from numcodecs import Blosc
    import time, os, warnings
    warnings.simplefilter("ignore", category=FutureWarning)  # suppress noisy warnings about bool types

    # TODO - find and save projection type

    infile = os.path.expanduser('./'+infile[:-1]) if infile.endswith('/') else os.path.expanduser('./'+infile)
    prefix = infile[:infile.rindex('.')]
    suffix = infile[infile.rindex('.') + 1:]
    srcdir = infile[:infile.rindex('/')+1]
    if outfile is None: outfile = prefix + '.img.zarr'
    outfile = os.path.expanduser(outfile)

    if compressor is None:
        compressor = Blosc(cname='zstd', clevel=2, shuffle=0)

    tmp = os.system("rm -fr " + outfile)
    tmp = os.system("mkdir " + outfile)

    IA = ia()
    QA = qa()
    begin = time.time()
    
    # all image artifacts will go in same zarr file and share common dimensions if possible
    # check for meta data compatibility
    # store necessary coordinate conversion data
    if artifacts is None: artifacts = [suffix]
    elif len(artifacts) == 0: artifacts = ['image', 'pb', 'psf', 'residual', 'mask', 'model', 'sumwt', 'weight', 'image.pbcor']
    if suffix not in artifacts: artifacts = [suffix] + artifacts
    diftypes, mxds, artifact_dims, artifact_masks = [], xarray.Dataset(), {}, {}
    ttcount = 0

    # for each image artifact, determine what image files in the source directory are compatible with each other
    # extract the metadata from each
    # if taylor terms are present for the artifact, process metadata for first one only
    print("converting Image...")
    for imtype in artifacts:
        imagelist = sorted([srcdir+ff for ff in os.listdir(srcdir) if (srcdir+ff).startswith('%s.%s'%(prefix,imtype))])
        imagelist = [ff for ff in imagelist if ff.endswith(imtype) or ff[ff.rindex('.') + 1:].startswith('tt')]
        if len(imagelist) == 0: continue

        # find number of taylor terms for this artifact and update count for total set if necessary
        ttcount = len([ff for ff in imagelist if ff[ff.rindex('.') + 1:].startswith('tt')]) if ttcount == 0 else ttcount

        rc = IA.open(imagelist[0])
        csys = IA.coordsys()
        summary = IA.summary(list=False)  # imhead would be better but chokes on big images
        ims = IA.shape()  # image shape
        coord_names = [ss.replace(' ', '_').lower().replace('stokes', 'pol').replace('frequency', 'chan') for ss in summary['axisnames']]
        
        # compute world coordinates for spherical dimensions
        sphr_dims = [dd for dd in range(len(ims)) if QA.isangle(summary['axisunits'][dd])]
        coord_idxs = np.mgrid[[range(ims[dd]) if dd in sphr_dims else range(1) for dd in range(len(ims))]].reshape(len(ims), -1)
        coord_world = csys.toworldmany(coord_idxs.astype(float))['numeric'][sphr_dims].reshape((-1,)+tuple(ims[sphr_dims]))
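        # e.g. for a 512x512 image with two spherical axes, coord_world here would have (assumed) shape (2, 512, 512)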
        coords = dict([(coord_names[dd], (['l','m'], coord_world[di])) for di, dd in enumerate(sphr_dims)])
        if imtype == 'sumwt': coords = {}   # special case, force sumwt to only cartesian coords (chan, pol)
        
        # compute world coordinates for cartesian dimensions
        cart_dims = [dd for dd in range(len(ims)) if dd not in sphr_dims]
        coord_idxs = np.mgrid[[range(ims[dd]) if dd in cart_dims else range(1) for dd in range(len(ims))]].reshape(len(ims), -1)
        coord_world = csys.toworldmany(coord_idxs.astype(float))['numeric'][cart_dims].reshape((-1,)+tuple(ims[cart_dims]))
        for dd, cs in enumerate(list(coord_world)):
            spi = tuple([slice(None) if di == dd else slice(1) for di in range(cs.ndim)])
            coords.update(dict([(coord_names[cart_dims[dd]], cs[spi].reshape(-1))]))

        # compute the time coordinate
        dtime = csys.torecord()['obsdate']['m0']['value']
        if csys.torecord()['obsdate']['m0']['unit'] == 'd': dtime = dtime * 86400
        coords['time'] = convert_time([dtime])
        
        # check to see if this image artifact is of a compatible shape to be part of the image artifact dataset
        try:  # easiest to try to merge and let xarray figure it out
            mxds = mxds.merge(xarray.Dataset(coords=coords), compat='equals')
        except Exception:
            diftypes += [imtype]
            continue

        # store the rest of the image metadata as attributes (if not already in the xds)
        omits = list(mxds.attrs.keys())
        omits += ['hasmask', 'masks', 'defaultmask', 'ndim', 'refpix', 'refval', 'shape', 'tileshape', 'messages']
        nested = [kk for kk in summary.keys() if isinstance(summary[kk], dict)]
        mxds = mxds.assign_attrs(dict([(kk.lower(), summary[kk]) for kk in summary.keys() if kk not in omits + nested]))
        artifact_dims[imtype] = [ss.replace('right_ascension', 'l').replace('declination', 'm') for ss in coord_names]
        artifact_masks[imtype] = summary['masks']
            
        # check for common and restoring beams
        rb = IA.restoringbeam()
        if (len(rb) > 0) and ('restoringbeam' not in mxds.attrs):
            # if there is a restoring beam, this should work
            cb = IA.commonbeam()
            mxds = mxds.assign_attrs({'commonbeam': [cb['major']['value'], cb['minor']['value'], cb['pa']['value']]})
            mxds = mxds.assign_attrs({'commonbeam_units': [cb['major']['unit'], cb['minor']['unit'], cb['pa']['unit']]})
            mxds = mxds.assign_attrs({'restoringbeam': [cb['major']['value'], cb['minor']['value'], cb['pa']['value']]})
            if 'beams' in rb:
                beams = np.array([[rbs['major']['value'], rbs['minor']['value'], rbs['positionangle']['value']]
                                                          for rbc in rb['beams'].values() for rbs in rbc.values()])
                mxds = mxds.assign_attrs({'perplanebeams':beams.reshape(len(rb['beams']),-1,3)})
                
        # parse messages for additional keys, drop duplicate info
        omits = list(mxds.attrs.keys()) + ['image_name', 'image_type', 'image_quantity', 'pixel_mask(s)', 'region(s)', 'image_units']
        for msg in summary['messages']:
            line = [tuple(kk.split(':')) for kk in msg.lower().split('\n') if ': ' in kk]
            line = [(kk[0].strip().replace(' ', '_'), kk[1].strip()) for kk in line]
            line = [ll for ll in line if ll[0] not in omits]
            mxds = mxds.assign_attrs(dict(line))
                
        rc = csys.done()
        rc = IA.close()

    print('incompatible components: ', diftypes)

    # if taylor terms are present, the chan axis must be expanded to the length of the terms
    if ttcount > len(mxds.chan): mxds = mxds.pad({'chan': (0, ttcount-len(mxds.chan))}, mode='edge')
    chunk_dict = dict(zip(['l','m','time','chan','pol'], chunk_shape[:2]+(1,)+chunk_shape[2:]))
    mxds = mxds.chunk(chunk_dict)

    # for each artifact, convert the legacy format and add to the new image set
    # masks may be stored within each image, so they will need to be handled like subtables
    for ac, imtype in enumerate(list(artifact_dims.keys())):
        for ec, ext in enumerate([''] + ['/'+ff for ff in list(artifact_masks[imtype])]):
            imagelist = sorted([srcdir + ff for ff in os.listdir(srcdir) if (srcdir + ff).startswith('%s.%s' % (prefix, imtype))])
            imagelist = [ff for ff in imagelist if ff.endswith(imtype) or ff[ff.rindex('.') + 1:].startswith('tt')]
            if len(imagelist) == 0: continue

            dimorder = ['time'] + list(reversed(artifact_dims[imtype]))
            chunkorder = [chunk_dict[vv] for vv in dimorder]
            ixds = convert_simple_table(imagelist[0]+ext, outfile+'.temp', dimnames=dimorder, compressor=compressor, chunk_shape=chunkorder)

            # if the image set has taylor terms, loop through any for this artifact and concat together
            # pad the chan dim as necessary to fill remaining elements if not enough taylor terms in this artifact
            for ii in range(1, ttcount):
                if ii < len(imagelist):
                    txds = convert_simple_table(imagelist[ii]+ext, outfile + '.temp', dimnames=dimorder, chunk_shape=chunkorder, nofile=True)
                    ixds = xarray.concat([ixds, txds], dim='chan')
                else:
                    ixds = ixds.pad({'chan': (0, 1)}, constant_values=np.nan)

            ixds = ixds.rename({list(ixds.data_vars)[0]:(imtype+ext.replace('/','_')).upper()}).transpose('l','m','time','chan','pol')
            if imtype == 'sumwt': ixds = ixds.squeeze(['l','m'], drop=True)
            if imtype == 'mask': ixds = ixds.rename({'MASK':'AUTOMASK'})  # rename mask

            encoding = dict(zip(list(ixds.data_vars), cycle([{'compressor': compressor}])))
            ixds.to_zarr(outfile, mode='w' if (ac==0) and (ec==0) else 'a', encoding=encoding, compute=True, consolidated=True)

    tmp = os.system("rm -fr " + outfile+'.temp')
    print('processed image in %0.2f seconds' % (time.time() - begin))

    # add attributes from metadata and version tag file
    mxds.to_zarr(outfile, mode='a', compute=True, consolidated=True)
    with open(outfile + '/.version', 'w') as fid:   # write sw version that did this conversion to zarr directory
        fid.write('cngi-prototype ' + importlib_metadata.version('cngi-prototype') + '\n')

    return xarray.open_zarr(outfile)
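
A minimal usage sketch (the image path is hypothetical); compatible artifacts such as .psf and .residual found alongside the infile are folded into the same zarr store:

img_xds = convert_image('~/data/myimage.image', chunk_shape=(-1, -1, 1, 1))
print(list(img_xds.data_vars))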
Example #4
def convert_ms(infile,
               outfile=None,
               ddis=None,
               ignore=['HISTORY'],
               compressor=None,
               chunk_shape=(100, 400, 32, 1),
               append=False):
    """
    Convert legacy format MS to xarray Visibility Dataset and zarr storage format

    This function requires CASA6 casatools module. The CASA MSv2 format is converted to the MSv3 schema per the
    specified definition here: https://drive.google.com/file/d/10TZ4dsFw9CconBc-GFxSeb2caT6wkmza/view?usp=sharing
    
    The MS is partitioned by DDI, which guarantees a fixed data shape per partition. This results in different subdirectories
    under the main vis.zarr folder.  There is no DDI in MSv3, so this simply serves as a partition id in the zarr directory.

    Parameters
    ----------
    infile : str
        Input MS filename
    outfile : str
        Output zarr filename. If None, will use infile name with .vis.zarr extension
    ddis : list
        List of specific DDIs to convert. DDIs are integer values; use the 'global' string for subtables. Leave as None to convert the entire MS
    ignore : list
        List of subtables to ignore (case sensitive and generally all uppercase). This is useful if a particular subtable is causing errors.
        Default is ['HISTORY']: temporarily set to ignore the HISTORY table due to a CASA6 issue in the table tool affecting a small
        set of test cases (set back to None if HISTORY is needed)
    compressor : numcodecs.blosc.Blosc
        The blosc compressor to use when saving the converted data to disk using zarr.
        If None, the zstd compression algorithm is used with compression level 2.
    chunk_shape: 4-D tuple of ints
        Shape of desired chunking in the form of (time, baseline, channel, polarization), use -1 for entire axis in one chunk. Default is (100, 400, 32, 1)
        Note: chunk size is the product of the four numbers, and data is batch processed by time axis, so that will drive memory needed for conversion.
    append : bool
        Keep destination zarr store intact and add new DDIs to it. Note that duplicate DDIs will still be overwritten. Default False deletes and replaces
        entire directory.
    Returns
    -------
    xarray.core.dataset.Dataset
      Master xarray dataset of datasets for this visibility set
    """
    import os
    import xarray
    import dask.array as da
    import numpy as np
    import time
    import cngi._utils._table_conversion as tblconv
    import cngi._utils._io as xdsio
    import warnings
    import importlib_metadata
    warnings.filterwarnings('ignore', category=FutureWarning)

    # parse filename to use
    infile = os.path.expanduser(infile)
    prefix = infile[:infile.rindex('.')]
    if outfile is None: outfile = prefix + '.vis.zarr'
    outfile = os.path.expanduser(outfile)

    # need to manually remove existing zarr file (if any)
    if not append:
        os.system("rm -fr " + outfile)
        os.system("mkdir " + outfile)

    # as part of MSv3 conversion, these columns in the main table are no longer needed
    ignorecols = ['FLAG_CATEGORY', 'FLAG_ROW', 'DATA_DESC_ID']
    if ignore is None: ignore = []

    # we need the spectral window, polarization, and data description tables for processing the main table
    spw_xds = tblconv.convert_simple_table(infile,
                                           outfile='',
                                           subtable='SPECTRAL_WINDOW',
                                           ignore=ignorecols,
                                           nofile=True)
    pol_xds = tblconv.convert_simple_table(infile,
                                           outfile='',
                                           subtable='POLARIZATION',
                                           ignore=ignorecols,
                                           nofile=True)
    ddi_xds = tblconv.convert_simple_table(infile,
                                           outfile='',
                                           subtable='DATA_DESCRIPTION',
                                           ignore=ignorecols,
                                           nofile=True)

    # let's assume that each DATA_DESC_ID (ddi) is a fixed shape that may differ from others
    # form a list of ddis to process; each will be placed in its own xarray dataset and partition
    if ddis is None: ddis = list(ddi_xds['d0'].values) + ['global']
    else: ddis = np.atleast_1d(ddis)
    xds_list = []

    ####################################################################
    # process each selected DDI from the input MS, assume a fixed shape within the ddi (should always be true)
    # each DDI is written to its own subdirectory under the parent folder
    for ddi in ddis:
        if ddi == 'global': continue  # handled afterwards
        ddi = int(ddi)
        print('Processing ddi', ddi, end='\r')
        start_ddi = time.time()

        # these columns are different / absent in MSv3 or need to be handled as special cases
        msv2 = ['WEIGHT', 'WEIGHT_SPECTRUM', 'SIGMA', 'SIGMA_SPECTRUM', 'UVW']

        # convert columns that are common to MSv2 and MSv3
        xds = tblconv.convert_expanded_table(infile,
                                             os.path.join(
                                                 outfile, 'xds' + str(ddi)),
                                             keys={
                                                 'TIME':
                                                 'time',
                                                 ('ANTENNA1', 'ANTENNA2'):
                                                 'baseline'
                                             },
                                             subsel={'DATA_DESC_ID': ddi},
                                             timecols=['time'],
                                             dimnames={
                                                 'd2': 'chan',
                                                 'd3': 'pol'
                                             },
                                             ignore=ignorecols + msv2,
                                             compressor=compressor,
                                             chunk_shape=chunk_shape,
                                             nofile=False)

        # convert and append UVW separately so we can handle its special dimension
        uvw_xds = tblconv.convert_expanded_table(
            infile,
            os.path.join(outfile, 'tmp'),
            keys={
                'TIME': 'time',
                ('ANTENNA1', 'ANTENNA2'): 'baseline'
            },
            subsel={'DATA_DESC_ID': ddi},
            timecols=['time'],
            dimnames={'d2': 'uvw_index'},
            ignore=ignorecols + list(xds.data_vars) + msv2[:-1],
            compressor=compressor,
            chunk_shape=chunk_shape,
            nofile=False)
        uvw_xds.to_zarr(os.path.join(outfile, 'xds' + str(ddi)),
                        mode='a',
                        compute=True,
                        consolidated=True)

        # now convert just the WEIGHT and WEIGHT_SPECTRUM (if present)
        # WEIGHT needs to be expanded to full dimensionality (time, baseline, chan, pol)
        wt_xds = tblconv.convert_expanded_table(infile,
                                                os.path.join(outfile, 'tmp'),
                                                keys={
                                                    'TIME':
                                                    'time',
                                                    ('ANTENNA1', 'ANTENNA2'):
                                                    'baseline'
                                                },
                                                subsel={'DATA_DESC_ID': ddi},
                                                timecols=['time'],
                                                dimnames={},
                                                ignore=ignorecols +
                                                list(xds.data_vars) + msv2[2:],
                                                compressor=compressor,
                                                chunk_shape=chunk_shape,
                                                nofile=False)

        # if WEIGHT_SPECTRUM is present, append it to the main xds as the new WEIGHT column
        # otherwise expand the dimensionality of WEIGHT and add it to the xds
        if 'WEIGHT_SPECTRUM' in wt_xds.data_vars:
            wt_xds = wt_xds.drop_vars('WEIGHT').rename(
                dict(
                    zip(wt_xds.WEIGHT_SPECTRUM.dims,
                        ['time', 'baseline', 'chan', 'pol'])))
            wt_xds.to_zarr(os.path.join(outfile, 'xds' + str(ddi)),
                           mode='a',
                           compute=True,
                           consolidated=True)
        else:
            wts = wt_xds.WEIGHT.shape[:2] + (1, ) + (wt_xds.WEIGHT.shape[-1], )
            wt_da = da.tile(da.reshape(wt_xds.WEIGHT.data, wts),
                            (1, 1, len(xds.chan), 1)).rechunk(chunk_shape)
            wt_xds = wt_xds.drop_vars('WEIGHT').assign({
                'WEIGHT':
                xarray.DataArray(wt_da,
                                 dims=['time', 'baseline', 'chan', 'pol'])
            })
            wt_xds.to_zarr(os.path.join(outfile, 'xds' + str(ddi)),
                           mode='a',
                           compute=True,
                           consolidated=True)

        # add in relevant spw and polarization attributes
        attrs = {}
        for dv in spw_xds.data_vars:
            attrs[dv.lower()] = spw_xds[dv].values[ddi_xds['spectral_window_id'].values[ddi]]
            if isinstance(attrs[dv.lower()], np.bool_):
                attrs[dv.lower()] = int(attrs[dv.lower()])  # convert bools
        for dv in pol_xds.data_vars:
            attrs[dv.lower()] = pol_xds[dv].values[ddi_xds['polarization_id'].values[ddi]]
            if isinstance(attrs[dv.lower()], np.bool_):
                attrs[dv.lower()] = int(attrs[dv.lower()])  # convert bools

        # grab the channel frequency values from the spw table data and pol idxs from the polarization table, add spw and pol ids
        chan = attrs.pop('chan_freq')[:len(xds.chan)]
        pol = attrs.pop('corr_type')[:len(xds.pol)]

        # truncate per-chan values to the actual number of channels and move to coordinates
        chan_width = xarray.DataArray(attrs.pop('chan_width')[:len(xds.chan)],
                                      dims=['chan'])
        effective_bw = xarray.DataArray(
            attrs.pop('effective_bw')[:len(xds.chan)], dims=['chan'])
        resolution = xarray.DataArray(attrs.pop('resolution')[:len(xds.chan)],
                                      dims=['chan'])

        coords = {
            'chan': chan,
            'pol': pol,
            'spw_id': [ddi_xds['spectral_window_id'].values[ddi]],
            'pol_id': [ddi_xds['polarization_id'].values[ddi]],
            'chan_width': chan_width,
            'effective_bw': effective_bw,
            'resolution': resolution
        }
        aux_xds = xarray.Dataset(coords=coords, attrs=attrs)

        aux_xds.to_zarr(os.path.join(outfile, 'xds' + str(ddi)),
                        mode='a',
                        compute=True,
                        consolidated=True)
        xds = xarray.open_zarr(os.path.join(outfile, 'xds' + str(ddi)))

        xds_list += [('xds' + str(ddi), xds)]
        print('Completed ddi %i  process time %0.2f s' % (ddi, time.time() - start_ddi))

    # clean up the tmp directory created by the weight conversion to MSv3
    os.system("rm -fr " + os.path.join(outfile, 'tmp'))

    # convert other subtables to their own partitions, denoted by 'global_' prefix
    skip_tables = ['DATA_DESCRIPTION', 'SORTED_TABLE'] + ignore
    subtables = sorted([
        tt for tt in os.listdir(infile)
        if os.path.isdir(os.path.join(infile, tt)) and tt not in skip_tables
    ])
    if 'global' in ddis:
        start_ddi = time.time()
        for ii, subtable in enumerate(subtables):
            print('processing subtable %i of %i : %s' %
                  (ii + 1, len(subtables), subtable),
                  end='\r')
            if subtable == 'POINTING':  # expand the dimensions of the pointing table
                xds_sub_list = [(subtable,
                                 tblconv.convert_expanded_table(
                                     infile,
                                     os.path.join(outfile, 'global'),
                                     subtable=subtable,
                                     keys={
                                         'TIME': 'time',
                                         'ANTENNA_ID': 'antenna_id'
                                     },
                                     timecols=['time'],
                                     chunk_shape=chunk_shape))]
            else:
                xds_sub_list = [(subtable,
                                 tblconv.convert_simple_table(
                                     infile,
                                     os.path.join(outfile, 'global'),
                                     subtable,
                                     timecols=['TIME'],
                                     ignore=ignorecols,
                                     compressor=compressor,
                                     nofile=False))]

            if len(xds_sub_list[-1][1].dims) != 0:
                # to conform to MSv3, we need to add explicit ID fields to certain tables
                if subtable in [
                        'ANTENNA', 'FIELD', 'OBSERVATION', 'SCAN',
                        'SPECTRAL_WINDOW', 'STATE'
                ]:
                    #if 'd0' in xds_sub_list[-1][1].dims:
                    aux_xds = xarray.Dataset(
                        coords={
                            subtable.lower() + '_id':
                            xarray.DataArray(xds_sub_list[-1][1].d0.values,
                                             dims=['d0'])
                        })
                    aux_xds.to_zarr(os.path.join(outfile,
                                                 'global/' + subtable),
                                    mode='a',
                                    compute=True,
                                    consolidated=True)
                    xds_sub_list[-1] = (subtable,
                                        xarray.open_zarr(
                                            os.path.join(
                                                outfile,
                                                'global/' + subtable)))

                xds_list += xds_sub_list
            #else:
            #    print('Empty Subtable:',subtable)

        print(
            'Completed subtables  process time {:0.2f} s'.format(time.time() -
                                                                 start_ddi))

    # write sw version that did this conversion to zarr directory
    with open(outfile + '/.version', 'w') as fid:
        fid.write('cngi-prototype ' +
                  importlib_metadata.version('cngi-prototype') + '\n')

    # build the master xds to return
    mxds = xdsio.vis_xds_packager(xds_list)
    print(' ' * 50)

    return mxds
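
A sketch of incremental conversion with append (paths and DDI lists are hypothetical); each call adds partitions to the same zarr store, and duplicate DDIs are overwritten:

mxds = convert_ms('~/data/myobs.ms', ddis=[0])
mxds = convert_ms('~/data/myobs.ms', ddis=[1, 'global'], append=True)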
Example #5
def convert_table(infile,
                  outfile=None,
                  subtable=None,
                  keys=None,
                  timecols=None,
                  ignorecols=None,
                  compressor=None,
                  chunk_shape=(40000, 20, 1),
                  append=False,
                  nofile=False):
    """
    Convert casacore table format to xarray Dataset and zarr storage format.

    This function requires CASA6 casatools module. Table rows may be renamed or expanded to n-dim arrays based on column values specified in keys.

    Parameters
    ----------
    infile : str
        Input table filename
    outfile : str
        Output zarr filename. If None, will use infile name with .tbl.zarr extension
    subtable : str
        Name of the subtable to process. If None, main table will be used
    keys : dict or str
        Source column mappings to dimensions. Can be a dict mapping source columns to target dims (use a tuple when combining columns,
        e.g. {('ANTENNA1','ANTENNA2'): 'baseline'}), or a string to rename the row axis dimension to the specified value. Default is None
    timecols : list
        list of strings specifying column names to convert to datetime format from casacore time.  Default is None
    ignorecols : list
        list of column names to ignore. This is useful if a particular column is causing errors.  Default is None
    compressor : numcodecs.blosc.Blosc
        The blosc compressor to use when saving the converted data to disk using zarr.
        If None, the zstd compression algorithm is used with compression level 2.
    chunk_shape : tuple of ints
        Shape of desired chunking in the form of (dim0, dim1, ..., dimN), use -1 for entire axis in one chunk. Default is (40000, 20, 1).
        Chunking is applied per column / data variable.  If too few dimensions are specified, the last chunk size is reused as necessary.
        Note: chunk size is the product of the numbers, and data is batch processed by the first axis, so that will drive memory needed for conversion.
    append : bool
        Append an xarray dataset as a new partition to an existing zarr directory.  False will overwrite zarr directory with a single new partition
    nofile : bool
        Allows legacy table to be directly read without file conversion. If set to True, no output file will be written and the entire table will be held in memory.
        Requires ~4x the memory of the table size.  Default is False
    Returns
    -------
    xarray.core.dataset.Dataset
        New xarray Dataset of table data contents
    """
    import os
    from numcodecs import Blosc
    import importlib_metadata
    import cngi._utils._table_conversion as tblconv

    # parse filename to use
    infile = os.path.expanduser(infile)
    prefix = infile[:infile.rindex('.')]
    if outfile is None: outfile = prefix + '.tbl.zarr'
    outfile = os.path.expanduser(outfile)
    if not infile.endswith('/'): infile = infile + '/'
    if not outfile.endswith('/'): outfile = outfile + '/'
    if subtable is None: subtable = ''
    if compressor is None:
        compressor = Blosc(cname='zstd', clevel=2, shuffle=0)

    print('processing %s to %s' % (infile + subtable, outfile + subtable))

    # need to manually remove existing zarr file (if any)
    if (not nofile) and (not append):
        os.system("rm -fr " + outfile)
        os.system("mkdir " + outfile)

    if (keys is None) or (type(keys) is str):
        xds = tblconv.convert_simple_table(
            infile,
            outfile,
            subtable=subtable,
            #rowdim='d0' if keys is None else keys,
            timecols=[] if timecols is None else timecols,
            ignore=[] if ignorecols is None else ignorecols,
            compressor=compressor,
            chunk_shape=chunk_shape,
            nofile=nofile)
    else:
        xds = tblconv.convert_expanded_table(
            infile,
            outfile,
            keys=keys,
            subtable=subtable,
            subsel=None,
            timecols=[] if timecols is None else timecols,
            dimnames={},
            ignore=[] if ignorecols is None else ignorecols,
            compressor=compressor,
            chunk_shape=chunk_shape,
            nofile=nofile)

    # write sw version that did this conversion to zarr directory
    with open(outfile + '/.version', 'w') as fid:
        fid.write('cngi-prototype ' +
                  importlib_metadata.version('cngi-prototype') + '\n')

    return xds
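
A minimal usage sketch (the table path is hypothetical); passing a keys dict expands rows into n-dim arrays, here mirroring the POINTING handling in convert_ms:

pointing_xds = convert_table('~/data/myobs.ms', subtable='POINTING',
                             keys={'TIME': 'time', 'ANTENNA_ID': 'antenna_id'}, timecols=['time'])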