def convert_image(infile, outfile=None, artifacts=[], compressor=None, chunk_shape=(-1, -1, 1, 1)):
    """
    Convert legacy CASA or FITS format Image to xarray Image Dataset and zarr storage format

    This function requires CASA6 casatools module.

    Parameters
    ----------
    infile : str
        Input image filename (.image or .fits format). If Taylor terms are present, they should be in the form of filename.image.tt0 and
        this infile string should be filename.image
    outfile : str
        Output zarr filename. If None, will use infile name with .img.zarr extension
    artifacts : list of str
        List of other image artifacts to include if present with infile. Use None for just the specified infile.
        Default [] uses ``['image', 'pb', 'psf', 'residual', 'mask', 'model', 'sumwt', 'weight', 'image.pbcor']``
    compressor : numcodecs.blosc.Blosc
        The blosc compressor to use when saving the converted data to disk using zarr.
        If None, the zstd compression algorithm is used with compression level 2.
    chunk_shape: 4-D tuple of ints
        Shape of desired chunking in the form of (l, m, channels, polarization), use -1 for entire axis in one chunk. Default is (-1, -1, 1, 1)
        Note: chunk size is the product of the four numbers (up to the actual size of the dimension)

    Returns
    -------
    xarray.core.dataset.Dataset
        new xarray Datasets of Image data contents
    """
    from casatools import image as ia
    from casatools import quanta as qa
    from cngi._utils._table_conversion import convert_simple_table, convert_time
    import numpy as np
    from itertools import cycle
    import importlib_metadata
    import xarray
    from numcodecs import Blosc
    import time, os, warnings
    warnings.simplefilter("ignore", category=FutureWarning)  # suppress noisy warnings about bool types

    # TODO - find and save projection type

    infile = os.path.expanduser(infile[:-1] if infile.endswith('/') else infile)  # strip trailing slash, expand ~
    if '/' not in infile: infile = './' + infile  # guarantee a directory component for srcdir below
    prefix = infile[:infile.rindex('.')]
    suffix = infile[infile.rindex('.') + 1:]
    srcdir = infile[:infile.rindex('/')+1]
    if outfile is None: outfile = prefix + '.img.zarr'
    outfile = os.path.expanduser(outfile)

    if compressor is None:
        compressor = Blosc(cname='zstd', clevel=2, shuffle=0)

    tmp = os.system("rm -fr " + outfile)
    tmp = os.system("mkdir " + outfile)

    IA = ia()
    QA = qa()
    begin = time.time()
    
    # all image artifacts will go in same zarr file and share common dimensions if possible
    # check for meta data compatibility
    # store necessary coordinate conversion data
    if artifacts is None: artifacts = [suffix]
    elif len(artifacts) == 0: artifacts = ['image', 'pb', 'psf', 'residual', 'mask', 'model', 'sumwt', 'weight', 'image.pbcor']
    if suffix not in artifacts: artifacts = [suffix] + artifacts
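    # e.g. infile='ngc5921.image' (a hypothetical name) with the default artifacts=[] resolves to
    # ['image', 'pb', 'psf', 'residual', 'mask', 'model', 'sumwt', 'weight', 'image.pbcor'], while an
    # unlisted suffix such as 'alpha' would be prepended: ['alpha', 'image', 'pb', ...]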
    diftypes, mxds, artifact_dims, artifact_masks = [], xarray.Dataset(), {}, {}
    ttcount = 0

    # for each image artifact, determine what image files in the source directory are compatible with each other
    # extract the metadata from each
    # if taylor terms are present for the artifact, process metadata for first one only
    print("converting Image...")
    for imtype in artifacts:
        imagelist = sorted([srcdir+ff for ff in os.listdir(srcdir) if (srcdir+ff).startswith('%s.%s'%(prefix,imtype))])
        imagelist = [ff for ff in imagelist if ff.endswith(imtype) or ff[ff.rindex('.') + 1:].startswith('tt')]
        if len(imagelist) == 0: continue

        # find number of taylor terms for this artifact and update count for total set if necessary
        ttcount = len([ff for ff in imagelist if ff[ff.rindex('.') + 1:].startswith('tt')]) if ttcount == 0 else ttcount

        rc = IA.open(imagelist[0])
        csys = IA.coordsys()
        summary = IA.summary(list=False)  # imhead would be better but chokes on big images
        ims = np.array(IA.shape())  # image shape (as ndarray so the fancy indexing below works)
        coord_names = [ss.replace(' ', '_').lower().replace('stokes', 'pol').replace('frequency', 'chan') for ss in summary['axisnames']]
        
        # compute world coordinates for spherical dimensions
        sphr_dims = [dd for dd in range(len(ims)) if QA.isangle(summary['axisunits'][dd])]
        coord_idxs = np.mgrid[[range(ims[dd]) if dd in sphr_dims else range(1) for dd in range(len(ims))]].reshape(len(ims), -1)
        coord_world = csys.toworldmany(coord_idxs.astype(float))['numeric'][sphr_dims].reshape((-1,)+tuple(ims[sphr_dims]))
        coords = dict([(coord_names[dd], (['l','m'], coord_world[di])) for di, dd in enumerate(sphr_dims)])
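        # shape sketch (assuming the usual direction/spectral/stokes axis order): for a
        # 512x512x16x1 cube, coord_idxs is (4, 512*512) and coord_world is (2, 512, 512),
        # i.e. full 2-D right_ascension/declination planes dimensioned by ('l', 'm')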
        if imtype == 'sumwt': coords = {}   # special case, force sumwt to only cartesian coords (chan, pol)
        
        # compute world coordinates for cartesian dimensions
        cart_dims = [dd for dd in range(len(ims)) if dd not in sphr_dims]
        coord_idxs = np.mgrid[[range(ims[dd]) if dd in cart_dims else range(1) for dd in range(len(ims))]].reshape(len(ims), -1)
        coord_world = csys.toworldmany(coord_idxs.astype(float))['numeric'][cart_dims].reshape((-1,)+tuple(ims[cart_dims]))
        for dd, cs in enumerate(list(coord_world)):
            spi = tuple([slice(None) if di == dd else slice(1) for di in range(cs.ndim)])
            coords.update(dict([(coord_names[cart_dims[dd]], cs[spi].reshape(-1))]))

        # compute the time coordinate
        dtime = csys.torecord()['obsdate']['m0']['value']
        if csys.torecord()['obsdate']['m0']['unit'] == 'd': dtime = dtime * 86400
        coords['time'] = convert_time([dtime])
        
        # check to see if this image artifact is of a compatible shape to be part of the image artifact dataset
        try:  # easiest to try to merge and let xarray figure it out
            mxds = mxds.merge(xarray.Dataset(coords=coords), compat='equals')
        except Exception:
            diftypes += [imtype]
            continue
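        # note: merge(compat='equals') only succeeds when every shared coordinate matches exactly,
        # so artifacts on a different grid land in diftypes and are skipped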

        # store rest of image metadata as attributes (if not already in the xds)
        omits = list(mxds.attrs.keys())
        omits += ['hasmask', 'masks', 'defaultmask', 'ndim', 'refpix', 'refval', 'shape', 'tileshape', 'messages']
        nested = [kk for kk in summary.keys() if isinstance(summary[kk], dict)]
        mxds = mxds.assign_attrs(dict([(kk.lower(), summary[kk]) for kk in summary.keys() if kk not in omits + nested]))
        artifact_dims[imtype] = [ss.replace('right_ascension', 'l').replace('declination', 'm') for ss in coord_names]
        artifact_masks[imtype] = summary['masks']
            
        # check for common and restoring beams
        rb = IA.restoringbeam()
        if (len(rb) > 0) and ('restoringbeam' not in mxds.attrs):
            # if there is a restoring beam, this should work
            cb = IA.commonbeam()
            mxds = mxds.assign_attrs({'commonbeam': [cb['major']['value'], cb['minor']['value'], cb['pa']['value']]})
            mxds = mxds.assign_attrs({'commonbeam_units': [cb['major']['unit'], cb['minor']['unit'], cb['pa']['unit']]})
            mxds = mxds.assign_attrs({'restoringbeam': [cb['major']['value'], cb['minor']['value'], cb['pa']['value']]})
            if 'beams' in rb:
                beams = np.array([[rbs['major']['value'], rbs['minor']['value'], rbs['positionangle']['value']]
                                                          for rbc in rb['beams'].values() for rbs in rbc.values()])
                mxds = mxds.assign_attrs({'perplanebeams':beams.reshape(len(rb['beams']),-1,3)})
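                # beams is reshaped to (nchan, npol, 3) holding (major, minor, positionangle) per plane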
                
        # parse messages for additional keys, drop duplicate info
        omits = list(mxds.attrs.keys()) + ['image_name', 'image_type', 'image_quantity', 'pixel_mask(s)', 'region(s)', 'image_units']
        for msg in summary['messages']:
            line = [tuple(kk.split(':')) for kk in msg.lower().split('\n') if ': ' in kk]
            line = [(kk[0].strip().replace(' ', '_'), kk[1].strip()) for kk in line]
            line = [ll for ll in line if ll[0] not in omits]
            mxds = mxds.assign_attrs(dict(line))
                
        rc = csys.done()
        rc = IA.close()

    print('incompatible components: ', diftypes)

    # if taylor terms are present, the chan axis must be expanded to the length of the terms
    if ttcount > len(mxds.chan): mxds = mxds.pad({'chan': (0, ttcount-len(mxds.chan))}, mode='edge')
    chunk_dict = dict(zip(['l','m','time','chan','pol'], chunk_shape[:2]+(1,)+chunk_shape[2:]))
    mxds = mxds.chunk(chunk_dict)
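    # e.g. the default chunk_shape=(-1, -1, 1, 1) yields chunk_dict={'l': -1, 'm': -1, 'time': 1,
    # 'chan': 1, 'pol': 1}, i.e. one full image plane per (time, chan, pol) combination per chunk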

    # for each artifact, convert the legacy format and add to the new image set
    # masks may be stored within each image, so they will need to be handled like subtables
    for ac, imtype in enumerate(list(artifact_dims.keys())):
        for ec, ext in enumerate([''] + ['/'+ff for ff in list(artifact_masks[imtype])]):
            imagelist = sorted([srcdir + ff for ff in os.listdir(srcdir) if (srcdir + ff).startswith('%s.%s' % (prefix, imtype))])
            imagelist = [ff for ff in imagelist if ff.endswith(imtype) or ff[ff.rindex('.') + 1:].startswith('tt')]
            if len(imagelist) == 0: continue

            dimorder = ['time'] + list(reversed(artifact_dims[imtype]))
            chunkorder = [chunk_dict[vv] for vv in dimorder]
            ixds = convert_simple_table(imagelist[0]+ext, outfile+'.temp', dimnames=dimorder, compressor=compressor, chunk_shape=chunkorder)

            # if the image set has taylor terms, loop through any for this artifact and concat together
            # pad the chan dim as necessary to fill remaining elements if not enough taylor terms in this artifact
            for ii in range(1, ttcount):
                if ii < len(imagelist):
                    txds = convert_simple_table(imagelist[ii]+ext, outfile + '.temp', dimnames=dimorder, chunk_shape=chunkorder, nofile=True)
                    ixds = xarray.concat([ixds, txds], dim='chan')
                else:
                    ixds = ixds.pad({'chan': (0, 1)}, constant_values=np.nan)

            ixds = ixds.rename({list(ixds.data_vars)[0]:(imtype+ext.replace('/','_')).upper()}).transpose('l','m','time','chan','pol')
            if imtype == 'sumwt': ixds = ixds.squeeze(['l','m'], drop=True)
            if imtype == 'mask': ixds = ixds.rename({'MASK':'AUTOMASK'})  # rename mask

            encoding = dict(zip(list(ixds.data_vars), cycle([{'compressor': compressor}])))
            ixds.to_zarr(outfile, mode='w' if (ac==0) and (ec==0) else 'a', encoding=encoding, compute=True, consolidated=True)

    tmp = os.system("rm -fr " + outfile+'.temp')
    print('processed image in %s seconds' % str(np.float32(time.time() - begin)))

    # add attributes from metadata and version tag file
    mxds.to_zarr(outfile, mode='a', compute=True, consolidated=True)
    with open(outfile + '/.version', 'w') as fid:   # write sw version that did this conversion to zarr directory
        fid.write('cngi-prototype ' + importlib_metadata.version('cngi-prototype') + '\n')

    return xarray.open_zarr(outfile)
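
# Usage sketch for the converter above (the file paths follow the verification example
# below; the explicit outfile name is an assumption):
# xds = convert_image('~/dev/data/ALMA_smallcube.image.fits',
#                     outfile='~/dev/data/ALMA_smallcube.image.zarr')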
from casatools import image as ia
from cngi.dio import read_image
import numpy as np

IA = ia()
rc = IA.open('~/dev/data/ALMA_smallcube.image.fits')
xds = read_image('~/dev/data/ALMA_smallcube.image.zarr')

points = [(1, 1), (24, 112), (11, 500), (340, 223), (503, 101), (511, 511)]

# position
for pt in points:
    casa_coords = IA.toworld(np.array(pt))['numeric'][:2]
    cngi_coords = [xds.right_ascension.values[pt], xds.declination.values[pt]]
    percent_dev = (casa_coords - cngi_coords) / casa_coords * 100
    print('ra/dec deviation % : ', percent_dev)

# stokes
for pt in points:
    casa_coords = IA.toworld(np.array(pt))['numeric'][2]
    cngi_coords = xds.image[pt].stokes.values[0]
    percent_dev = (casa_coords - cngi_coords) / casa_coords * 100
    print('stokes deviation % : ', percent_dev)

# frequency
for pt in points:
    casa_coords = []
    for ch in range(xds.frequency.shape[0]):
        casa_coords += [IA.toworld(np.array(pt + (0, ch)))['numeric'][3]]
    cngi_coords = xds.image[pt].frequency.values
    percent_dev = (np.array(casa_coords) - cngi_coords) / np.array(casa_coords) * 100
    print('frequency deviation % : ', percent_dev)
Example #3
def convert_image(infile,
                  outfile=None,
                  artifacts=None,
                  compressor=None,
                  chunk_shape=(-1, -1, 1, 1),
                  nofile=False):
    """
    Convert legacy CASA or FITS format Image to xarray Image Dataset and zarr storage format

    This function requires CASA6 casatools module.

    Parameters
    ----------
    infile : str
        Input image filename (.image or .fits format)
    outfile : str
        Output zarr filename. If None, will use infile name with .img.zarr extension
    artifacts : list of str
        List of other image artifacts to include if present with infile. Default None uses ``['image.pbcor', 'mask', 'model', 'pb', 'psf', 'residual', 'sumwt', 'weight']``
    compressor : numcodecs.blosc.Blosc
        The blosc compressor to use when saving the converted data to disk using zarr.
        If None, the zstd compression algorithm is used with compression level 2.
    chunk_shape: 4-D tuple of ints
        Shape of desired chunking in the form of (x, y, channels, polarization), use -1 for entire axis in one chunk. Default is (-1, -1, 1, 1)
        Note: chunk size is the product of the four numbers (up to the actual size of the dimension)
    nofile : bool
        Allows legacy Image to be directly read without file conversion. If set to True, no output file will be written and the entire Image will be held in memory.
        Requires ~4x the memory of the Image size.  Default is False

    Returns
    -------
    xarray.core.dataset.Dataset
        new xarray Datasets of Image data contents
    """
    from casatools import image as ia
    import numpy as np
    from itertools import cycle
    from pandas.io.json._normalize import nested_to_record
    import xarray
    from xarray import Dataset as xd
    from xarray import DataArray as xa
    from numcodecs import Blosc
    import time, os, warnings
    warnings.simplefilter(
        "ignore",
        category=FutureWarning)  # suppress noisy warnings about bool types

    print("converting Image...")

    infile = os.path.expanduser(infile)
    prefix = infile[:infile.rindex('.')]
    suffix = infile[infile.rindex('.') + 1:]

    # sanitize to avoid KeyError when calling imtypes later
    while suffix.endswith('/'):
        suffix = suffix[:-1]

    if outfile is None:
        outfile = prefix + '.img.zarr'
    else:
        outfile = os.path.expanduser(outfile)

    if not nofile:
        tmp = os.system("rm -fr " + outfile)

    begin = time.time()

    if compressor is None:
        compressor = Blosc(cname='zstd', clevel=2, shuffle=0)

    IA = ia()

    # all image artifacts will go in same zarr file and share common dimensions if possible
    # check for meta data compatibility
    # store necessary coordinate conversion data
    if artifacts is None:
        imtypes = [
            'image.pbcor', 'mask', 'model', 'pb', 'psf', 'residual', 'sumwt',
            'weight'
        ]
        if suffix not in imtypes: imtypes = [suffix] + imtypes
    else:
        imtypes = [suffix] + artifacts
    meta, tm, diftypes, difmeta, xds = {}, {}, [], [], []
    for imtype in imtypes:
        if os.path.exists(prefix + '.' + imtype):
            rc = IA.open(prefix + '.' + imtype)
            summary = IA.summary(
                list=False)  # imhead would be better but chokes on big images
            ims = tuple(IA.shape())  # image shape
            coord_names = [
                ss.replace(' ', '_').lower().replace('stokes', 'pol').replace(
                    'frequency', 'chan') for ss in summary['axisnames']
            ]

            # compute world coordinates for spherical dimensions
            # the only way to know is to check the units for angular types (i.e. radians)
            sphr_dims = [
                dd for dd in range(len(ims))
                if summary['axisunits'][dd] == 'rad'
            ]
            coord_idxs = np.mgrid[[
                range(ims[dd]) if dd in sphr_dims else range(1)
                for dd in range(len(ims))
            ]]
            coord_idxs = coord_idxs.reshape(len(ims), -1)
            coord_world = IA.coordsys().toworldmany(
                coord_idxs.astype(float))['numeric']
            coord_world = coord_world[sphr_dims].reshape(
                (len(sphr_dims), ) + tuple(np.array(ims)[sphr_dims]))
            spi = ['d' + str(dd) for dd in sphr_dims]
            coords = dict([(coord_names[dd], (spi, coord_world[di]))
                           for di, dd in enumerate(sphr_dims)])
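            # the spherical axes keep generic dim names ('d0', 'd1') because their world
            # coordinates are full 2-D arrays rather than separable 1-D index vectors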

            # compute world coordinates for cartesian dimensions
            cart_dims = [dd for dd in range(len(ims)) if dd not in sphr_dims]
            coord_idxs = np.mgrid[[
                range(ims[dd]) if dd in cart_dims else range(1)
                for dd in range(len(ims))
            ]]
            coord_idxs = coord_idxs.reshape(len(ims), -1)
            coord_world = IA.coordsys().toworldmany(
                coord_idxs.astype(float))['numeric']
            coord_world = coord_world[cart_dims].reshape(
                (len(cart_dims), ) + tuple(np.array(ims)[cart_dims]))
            for dd, cs in enumerate(list(coord_world)):
                spi = tuple([
                    slice(None) if di == dd else slice(1)
                    for di in range(cs.ndim)
                ])
                coords.update(dict([(coord_names[cart_dims[dd]], cs[spi][0])]))

            # store metadata for later
            tm['coords'] = coords
            tm['dsize'] = np.array(summary['shape'])
            tm['dims'] = [
                coord_names[di] if di in cart_dims else 'd' + str(di)
                for di in range(len(ims))
            ]

            # store rest of image meta data as attributes
            omits = [
                'axisnames', 'hasmask', 'masks', 'defaultmask', 'ndim',
                'refpix', 'refval', 'shape', 'tileshape', 'messages'
            ]
            nested = [
                kk for kk in summary.keys() if isinstance(summary[kk], dict)
            ]
            tm['attrs'] = dict([(kk.lower(), summary[kk])
                                for kk in summary.keys()
                                if kk not in omits + nested])
            tm['attrs'].update(
                dict([(kk,
                       list(nested_to_record(summary[kk], sep='.').items()))
                      for kk in nested]))

            # parse messages for additional keys, drop duplicate info
            omits = [
                'image_name', 'image_type', 'image_quantity', 'pixel_mask(s)',
                'region(s)', 'image_units'
            ]
            for msg in summary['messages']:
                line = [
                    tuple(kk.split(':')) for kk in msg.lower().split('\n')
                    if ': ' in kk
                ]
                line = [(kk[0].strip().replace(' ', '_'), kk[1].strip())
                        for kk in line]
                line = [ll for ll in line if ll[0] not in omits]
                tm['attrs'].update(dict(line))

            # save metadata from first image product (the image itself)
            # compare later image products to see if dimensions match up
            # Note: only checking image dimensions, NOT COORDINATE VALUES!!
            if meta == {}:
                meta = dict(tm)
            elif (np.any(meta['dsize'] != np.array(summary['shape'])) &
                  (imtype != 'sumwt')):
                diftypes += [imtype]
                difmeta += [tm]
                imtypes = [_ for _ in imtypes if _ != imtype]

            rc = IA.close()
        else:
            imtypes = [_ for _ in imtypes if _ != imtype]

    print('compatible components: ', imtypes)
    print('separate components: ', diftypes)

    # process all image artifacts with compatible metadata to same zarr file
    # partition by channel, read each image artifact for each channel
    dsize, chan_dim = meta['dsize'], meta['dims'].index('chan')
    pt1, pt2 = [-1 for _ in range(len(dsize))], [-1 for _ in range(len(dsize))]
    if chunk_shape[2] <= 0:  # rebuild rather than assign in place since chunk_shape may be an (immutable) tuple
        chunk_shape = tuple(chunk_shape[:2]) + (dsize[chan_dim], ) + tuple(chunk_shape[3:])
    chan_batch = dsize[chan_dim] if nofile else chunk_shape[2]
    for chan in range(0, dsize[chan_dim], chan_batch):
        print('processing channel ' + str(chan + 1) + ' of ' +
              str(dsize[chan_dim]),
              end='\r')
        pt1[chan_dim], pt2[chan_dim] = chan, chan + chan_batch - 1
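        # e.g. chan=0 with chan_batch=4 sets the chan entries of pt1/pt2 to 0 and 3; the
        # remaining -1 entries ask IA.getchunk below for the default (full) extent of those axes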
        chunk_coords = dict(meta['coords'])  # only want one freq channel coord
        chunk_coords['chan'] = np.asarray(meta['coords']['chan'])[np.arange(
            chan, min(chan + chan_batch, dsize[chan_dim]))]  # use the stored meta, not the stale loop variable
        xdas = {}
        for imtype in imtypes:
            rc = IA.open(prefix + '.' + imtype)

            # extract pixel data
            imchunk = IA.getchunk(pt1, pt2)
            if imtype == 'fits': imtype = 'image'
            if imtype == 'mask':
                xdas['DECONVOLVE'] = xa(imchunk.astype(bool),
                                        dims=meta['dims'])
            elif imtype == 'sumwt':
                xdas[imtype.upper()] = xa(imchunk.reshape(imchunk.shape[2], 1),
                                          dims=['pol', 'chan'])
            else:
                xdas[imtype.upper()] = xa(imchunk, dims=meta['dims'])

            # extract mask
            summary = IA.summary(list=False)
            if len(summary['masks']) > 0:
                imchunk = IA.getchunk(pt1, pt2, getmask=True)
                xdas['MASK'] = xa(imchunk.astype(bool), dims=meta['dims'])

            rc = IA.close()

        chunking = dict([(dd, chunk_shape[ii])
                         for ii, dd in enumerate(['d0', 'd1', 'chan', 'pol'])
                         if chunk_shape[ii] > 0])
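        # e.g. the default chunk_shape=(-1, -1, 1, 1) keeps only {'chan': 1, 'pol': 1} here,
        # leaving the spatial axes unchunked within each channel batch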
        xds = xd(xdas, coords=chunk_coords,
                 attrs=meta['attrs']).chunk(chunking)

        # for everyone's sanity, lets make sure the dimensions are ordered the same way as visibility data
        if ('pol' in xds.dims):
            xds = xds.transpose(xds.IMAGE.dims[0], xds.IMAGE.dims[1], 'chan',
                                'pol')

        if (chan == 0) and (not nofile):
            # xds = xd(xdas, coords=chunk_coords, attrs=nested_to_record(meta['attrs'], sep='_'))
            encoding = dict(
                zip(list(xds.data_vars), cycle([{
                    'compressor': compressor
                }])))
            xds.to_zarr(outfile, mode='w', encoding=encoding)
        elif not nofile:
            xds.to_zarr(outfile, mode='a', append_dim='chan')

    print("processed image size " + str(dsize) + " in " +
          str(np.float32(time.time() - begin)) + " seconds")

    if not nofile:
        xds = xarray.open_zarr(outfile)

    return xds
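
# Usage sketch for this variant (hypothetical filename): with nofile=True no zarr store is
# written and the full Dataset is returned in memory (~4x the image size in RAM)
# xds = convert_image('~/dev/data/ngc5921.image', nofile=True)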