def convert_image(infile, outfile=None, artifacts=[], compressor=None, chunk_shape=(-1, -1, 1, 1)):
    """
    Convert legacy CASA or FITS format Image to xarray Image Dataset and zarr storage format

    This function requires CASA6 casatools module.

    Parameters
    ----------
    infile : str
        Input image filename (.image or .fits format). If taylor terms are present, they should be in the form of filename.image.tt0 and
        this infile string should be filename.image
    outfile : str
        Output zarr filename. If None, will use infile name with .img.zarr extension
    artifacts : list of str
        List of other image artifacts to include if present with infile. Use None for just the specified infile.
        Default [] uses ``['image','pb','psf','residual','mask','model','sumwt','weight','image.pbcor']``
    compressor : numcodecs.blosc.Blosc
        The blosc compressor to use when saving the converted data to disk using zarr.
        If None the zstd compression algorithm used with compression level 2.
    chunk_shape: 4-D tuple of ints
        Shape of desired chunking in the form of (l, m, channels, polarization), use -1 for entire axis in one chunk. Default is (-1, -1, 1, 1)
        Note: chunk size is the product of the four numbers (up to the actual size of the dimension)

    Returns
    -------
    xarray.core.dataset.Dataset
        new xarray Datasets of Image data contents
    """
    from casatools import image as ia
    from casatools import quanta as qa
    from cngi._utils._table_conversion import convert_simple_table, convert_time
    import numpy as np
    from itertools import cycle
    import importlib_metadata
    import xarray
    from numcodecs import Blosc
    import time, os, shutil, warnings
    warnings.simplefilter("ignore", category=FutureWarning)  # suppress noisy warnings about bool types

    # TODO - find and save projection type

    # Expand '~' and keep absolute paths intact; only a bare filename gets a './' prefix
    # so that the directory part (srcdir) can always be sliced off below.
    infile = os.path.expanduser(infile.rstrip('/'))
    if '/' not in infile: infile = './' + infile
    prefix = infile[:infile.rindex('.')]
    suffix = infile[infile.rindex('.') + 1:]
    srcdir = infile[:infile.rindex('/')+1]
    if outfile is None: outfile = prefix + '.img.zarr'
    outfile = os.path.expanduser(outfile)

    if compressor is None:
        compressor = Blosc(cname='zstd', clevel=2, shuffle=0)

    # the chunk dictionary below is built by tuple concatenation, so accept lists as well
    chunk_shape = tuple(chunk_shape)

    # start from an empty output directory (no shell involved, so odd characters in the path are safe)
    shutil.rmtree(outfile, ignore_errors=True)
    os.makedirs(outfile, exist_ok=True)

    def list_artifact_files(imtype):
        # sorted files in srcdir named <prefix>.<imtype>, optionally followed by a taylor term extension (.ttN)
        found = sorted([srcdir + ff for ff in os.listdir(srcdir) if (srcdir + ff).startswith('%s.%s' % (prefix, imtype))])
        return [ff for ff in found if ff.endswith(imtype) or ff[ff.rindex('.') + 1:].startswith('tt')]

    IA = ia()
    QA = qa()
    begin = time.time()
    # all image artifacts will go in same zarr file and share common dimensions if possible
    # check for meta data compatibility
    # store necessary coordinate conversion data
    # (the [] default is only ever rebound here, never mutated in place)
    if artifacts is None: artifacts = [suffix]
    elif len(artifacts) == 0: artifacts = ['image', 'pb', 'psf', 'residual', 'mask', 'model', 'sumwt', 'weight', 'image.pbcor']
    if suffix not in artifacts: artifacts = [suffix] + list(artifacts)
    diftypes, mxds, artifact_dims, artifact_masks = [], xarray.Dataset(), {}, {}
    ttcount = 0

    # for each image artifact, determine what image files in the source directory are compatible with each other
    # extract the metadata from each
    # if taylor terms are present for the artifact, process metadata for first one only
    print("converting Image...")
    for imtype in artifacts:
        imagelist = list_artifact_files(imtype)
        if len(imagelist) == 0: continue

        # find number of taylor terms for this artifact and update count for total set if necessary
        ttcount = len([ff for ff in imagelist if ff[ff.rindex('.') + 1:].startswith('tt')]) if ttcount == 0 else ttcount

        IA.open(imagelist[0])
        csys = IA.coordsys()
        summary = IA.summary(list=False)  # imhead would be better but chokes on big images
        ims = IA.shape()  # image shape
        coord_names = [ss.replace(' ', '_').lower().replace('stokes', 'pol').replace('frequency', 'chan') for ss in summary['axisnames']]
        # compute world coordinates for spherical dimensions
        sphr_dims = [dd for dd in range(len(ims)) if QA.isangle(summary['axisunits'][dd])]
        coord_idxs = np.mgrid[[range(ims[dd]) if dd in sphr_dims else range(1) for dd in range(len(ims))]].reshape(len(ims), -1)
        coord_world = csys.toworldmany(coord_idxs.astype(float))['numeric'][sphr_dims].reshape((-1,)+tuple(ims[sphr_dims]))
        coords = dict([(coord_names[dd], (['l','m'], coord_world[di])) for di, dd in enumerate(sphr_dims)])
        if imtype == 'sumwt': coords = {}   # special case, force sumwt to only cartesian coords (chan, pol)
        # compute world coordinates for cartesian dimensions
        cart_dims = [dd for dd in range(len(ims)) if dd not in sphr_dims]
        coord_idxs = np.mgrid[[range(ims[dd]) if dd in cart_dims else range(1) for dd in range(len(ims))]].reshape(len(ims), -1)
        coord_world = csys.toworldmany(coord_idxs.astype(float))['numeric'][cart_dims].reshape((-1,)+tuple(ims[cart_dims]))
        for dd, cs in enumerate(list(coord_world)):
            # keep the full extent along this axis and only the first element along every other axis
            spi = tuple([slice(None) if di == dd else slice(1) for di in range(cs.ndim)])
            coords.update(dict([(coord_names[cart_dims[dd]], cs[spi].reshape(-1))]))

        # compute the time coordinate
        obsdate = csys.torecord()['obsdate']['m0']
        dtime = obsdate['value']
        if obsdate['unit'] == 'd': dtime = dtime * 86400
        coords['time'] = convert_time([dtime])
        # check to see if this image artifact is of a compatible shape to be part of the image artifact dataset
        try:  # easiest to try to merge and let xarray figure it out
            mxds = mxds.merge(xarray.Dataset(coords=coords), compat='equals')
        except Exception:
            # not fatal: the artifact is only reported as incompatible, it is still converted below
            diftypes += [imtype]

        # store rest of image metadata as attributes (if not already in the xds)
        omits = list(mxds.attrs.keys())
        omits += ['hasmask', 'masks', 'defaultmask', 'ndim', 'refpix', 'refval', 'shape', 'tileshape', 'messages']
        nested = [kk for kk in summary.keys() if isinstance(summary[kk], dict)]
        mxds = mxds.assign_attrs(dict([(kk.lower(), summary[kk]) for kk in summary.keys() if kk not in omits + nested]))
        artifact_dims[imtype] = [ss.replace('right_ascension', 'l').replace('declination', 'm') for ss in coord_names]
        artifact_masks[imtype] = summary['masks']
        # check for common and restoring beams
        rb = IA.restoringbeam()
        if (len(rb) > 0) and ('restoringbeam' not in mxds.attrs):
            # if there is a restoring beam, this should work
            cb = IA.commonbeam()
            mxds = mxds.assign_attrs({'commonbeam': [cb['major']['value'], cb['minor']['value'], cb['pa']['value']]})
            mxds = mxds.assign_attrs({'commonbeam_units': [cb['major']['unit'], cb['minor']['unit'], cb['pa']['unit']]})
            mxds = mxds.assign_attrs({'restoringbeam': [cb['major']['value'], cb['minor']['value'], cb['pa']['value']]})
            if 'beams' in rb:
                beams = np.array([[rbs['major']['value'], rbs['minor']['value'], rbs['positionangle']['value']]
                                                          for rbc in rb['beams'].values() for rbs in rbc.values()])
                mxds = mxds.assign_attrs({'perplanebeams':beams.reshape(len(rb['beams']),-1,3)})
        # parse messages for additional keys, drop duplicate info
        omits = list(mxds.attrs.keys()) + ['image_name', 'image_type', 'image_quantity', 'pixel_mask(s)', 'region(s)', 'image_units']
        for msg in summary['messages']:
            line = [tuple(kk.split(':')) for kk in msg.lower().split('\n') if ': ' in kk]
            line = [(kk[0].strip().replace(' ', '_'), kk[1].strip()) for kk in line]
            line = [ll for ll in line if ll[0] not in omits]
            mxds = mxds.assign_attrs(dict(line))
        csys.done()
        IA.close()

    print('incompatible components: ', diftypes)

    # if taylor terms are present, the chan axis must be expanded to the length of the terms
    if ttcount > len(mxds.chan): mxds = mxds.pad({'chan': (0, ttcount-len(mxds.chan))}, mode='edge')
    chunk_dict = dict(zip(['l','m','time','chan','pol'], chunk_shape[:2]+(1,)+chunk_shape[2:]))
    mxds = mxds.chunk(chunk_dict)

    # for each artifact, convert the legacy format and add to the new image set
    # masks may be stored within each image, so they will need to be handled like subtables
    for ac, imtype in enumerate(list(artifact_dims.keys())):
        imagelist = list_artifact_files(imtype)
        if len(imagelist) == 0: continue

        dimorder = ['time'] + list(reversed(artifact_dims[imtype]))
        chunkorder = [chunk_dict[vv] for vv in dimorder]

        # ext '' is the image itself, '/<maskname>' is each mask stored inside the image
        for ec, ext in enumerate([''] + ['/'+ff for ff in list(artifact_masks[imtype])]):
            ixds = convert_simple_table(imagelist[0]+ext, outfile+'.temp', dimnames=dimorder, compressor=compressor, chunk_shape=chunkorder)

            # if the image set has taylor terms, loop through any for this artifact and concat together
            # pad the chan dim as necessary to fill remaining elements if not enough taylor terms in this artifact
            for ii in range(1, ttcount):
                if ii < len(imagelist):
                    txds = convert_simple_table(imagelist[ii]+ext, outfile + '.temp', dimnames=dimorder, chunk_shape=chunkorder, nofile=True)
                    ixds = xarray.concat([ixds, txds], dim='chan')
                else:
                    ixds = ixds.pad({'chan': (0, 1)}, constant_values=np.nan)

            ixds = ixds.rename({list(ixds.data_vars)[0]:(imtype+ext.replace('/','_')).upper()}).transpose('l','m','time','chan','pol')
            if imtype == 'sumwt': ixds = ixds.squeeze(['l','m'], drop=True)
            if imtype == 'mask': ixds = ixds.rename({'MASK':'AUTOMASK'})  # rename mask

            encoding = dict(zip(list(ixds.data_vars), cycle([{'compressor': compressor}])))
            # the very first write creates the store, every later one appends to it
            ixds.to_zarr(outfile, mode='w' if (ac==0) and (ec==0) else 'a', encoding=encoding, compute=True, consolidated=True)

    shutil.rmtree(outfile+'.temp', ignore_errors=True)
    print('processed image in %s seconds' % str(np.float32(time.time() - begin)))

    # add attributes from metadata and version tag file
    mxds.to_zarr(outfile, mode='a', compute=True, consolidated=True)
    with open(outfile + '/.version', 'w') as fid:   # write sw version that did this conversion to zarr directory
        fid.write('cngi-protoype ' + importlib_metadata.version('cngi-prototype') + '\n')

    return xarray.open_zarr(outfile)
# Compare world coordinates reported by the CASA image tool against the
# coordinates stored in the converted CNGI image dataset, for a few sample pixels.
import numpy as np
from casatools import image as ia
from cngi.dio import read_image

IA = ia()
rc = IA.open('~/dev/data/ALMA_smallcube.image.fits')
xds = read_image('~/dev/data/ALMA_smallcube.image.zarr')

# (x, y) pixel positions to spot check
points = [(1, 1), (24, 112), (11, 500), (340, 223), (503, 101), (511, 511)]

# position
for pt in points:
    casa_coords = IA.toworld(np.array(pt))['numeric'][:2]
    cngi_coords = [xds.right_ascension.values[pt], xds.declination.values[pt]]
    percent_dev = (casa_coords - cngi_coords) / casa_coords * 100
    print('ra/dec deviation % : ', percent_dev)

# stokes
for pt in points:
    casa_coords = IA.toworld(np.array(pt))['numeric'][2]
    cngi_coords = xds.image[pt].stokes.values[0]
    percent_dev = (casa_coords - cngi_coords) / casa_coords * 100
    print('stokes deviation % : ', percent_dev)

# frequency
for pt in points:
    # one CASA lookup per channel, at stokes index 0
    casa_coords = []
    for ch in range(xds.frequency.shape[0]):
        casa_coords += [IA.toworld(np.array(pt + (0, ch)))['numeric'][3]]
    cngi_coords = xds.image[pt].frequency.values
    percent_dev = (np.array(casa_coords) -
                   cngi_coords) / np.array(casa_coords) * 100
    print('frequency deviation % : ', percent_dev)

rc = IA.close()
# Example no. 3
def convert_image(infile,
                  outfile=None,
                  artifacts=None,
                  compressor=None,
                  chunk_shape=(-1, -1, 1, 1),
                  nofile=False):
    """
    Convert legacy CASA or FITS format Image to xarray Image Dataset and zarr storage format

    This function requires CASA6 casatools module.

    Parameters
    ----------
    infile : str
        Input image filename (.image or .fits format)
    outfile : str
        Output zarr filename. If None, will use infile name with .img.zarr extension
    artifacts : list of str
        List of other image artifacts to include if present with infile. Default None uses
        ``['image.pbcor','mask','model','pb','psf','residual','sumwt','weight']``
    compressor : numcodecs.blosc.Blosc
        The blosc compressor to use when saving the converted data to disk using zarr.
        If None the zstd compression algorithm used with compression level 2.
    chunk_shape: 4-D tuple of ints
        Shape of desired chunking in the form of (x, y, channels, polarization), use -1 for entire axis in one chunk. Default is (-1, -1, 1, 1)
        Note: chunk size is the product of the four numbers (up to the actual size of the dimension)
    nofile : bool
        Allows legacy Image to be directly read without file conversion. If set to true, no output file will be written and entire Image will be held in memory.
        Requires ~4x the memory of the Image size.  Default is False

    Returns
    -------
    xarray.core.dataset.Dataset
        new xarray Datasets of Image data contents
    """
    from casatools import image as ia
    import numpy as np
    from itertools import cycle
    # NOTE(review): module path reconstructed; nested_to_record is a private pandas helper - confirm
    from pandas.io.json._normalize import nested_to_record
    import xarray
    from xarray import Dataset as xd
    from xarray import DataArray as xa
    from numcodecs import Blosc
    import time, os, shutil, warnings
    warnings.simplefilter(
        "ignore",
        category=FutureWarning)  # suppress noisy warnings about bool types

    print("converting Image...")

    infile = os.path.expanduser(infile)
    prefix = infile[:infile.rindex('.')]
    suffix = infile[infile.rindex('.') + 1:]

    # sanitize to avoid KeyError when calling imtypes later
    while suffix.endswith('/'):
        suffix = suffix[:-1]

    if outfile is None:
        outfile = prefix + '.img.zarr'
    # expand '~' for user supplied names too, not only for the default name
    outfile = os.path.expanduser(outfile)

    if not nofile:
        # remove any previous output without going through a shell
        shutil.rmtree(outfile, ignore_errors=True)

    begin = time.time()

    if compressor is None:
        compressor = Blosc(cname='zstd', clevel=2, shuffle=0)

    # the channel entry may be overwritten below, which a tuple would not allow
    chunk_shape = list(chunk_shape)

    IA = ia()

    # all image artifacts will go in same zarr file and share common dimensions if possible
    # check for meta data compatibility
    # store necessary coordinate conversion data
    if artifacts is None:
        imtypes = [
            'image.pbcor', 'mask', 'model', 'pb', 'psf', 'residual', 'sumwt',
            'weight'
        ]
        if suffix not in imtypes: imtypes = [suffix] + imtypes
    else:
        imtypes = [suffix] + list(artifacts)
    meta, diftypes, difmeta, xds = {}, [], [], []
    # iterate over a snapshot because imtypes is pruned inside the loop
    for imtype in list(imtypes):
        if os.path.exists(prefix + '.' + imtype):
            tm = {}  # fresh per artifact so stored metadata is never overwritten by a later one
            rc = IA.open(prefix + '.' + imtype)
            summary = IA.summary(
                list=False)  # imhead would be better but chokes on big images
            ims = tuple(IA.shape())  # image shape
            coord_names = [
                ss.replace(' ', '_').lower().replace('stokes', 'pol').replace(
                    'frequency', 'chan') for ss in summary['axisnames']
            ]

            # compute world coordinates for spherical dimensions
            # the only way to know is to check the units for angular types (i.e. radians)
            sphr_dims = [
                dd for dd in range(len(ims))
                if summary['axisunits'][dd] == 'rad'
            ]
            coord_idxs = np.mgrid[[
                range(ims[dd]) if dd in sphr_dims else range(1)
                for dd in range(len(ims))
            ]]
            coord_idxs = coord_idxs.reshape(len(ims), -1)
            coord_world = IA.coordsys().toworldmany(
                coord_idxs.astype(float))['numeric']
            coord_world = coord_world[sphr_dims].reshape(
                (len(sphr_dims), ) + tuple(np.array(ims)[sphr_dims]))
            spi = ['d' + str(dd) for dd in sphr_dims]
            coords = dict([(coord_names[dd], (spi, coord_world[di]))
                           for di, dd in enumerate(sphr_dims)])

            # compute world coordinates for cartesian dimensions
            cart_dims = [dd for dd in range(len(ims)) if dd not in sphr_dims]
            coord_idxs = np.mgrid[[
                range(ims[dd]) if dd in cart_dims else range(1)
                for dd in range(len(ims))
            ]]
            coord_idxs = coord_idxs.reshape(len(ims), -1)
            coord_world = IA.coordsys().toworldmany(
                coord_idxs.astype(float))['numeric']
            coord_world = coord_world[cart_dims].reshape(
                (len(cart_dims), ) + tuple(np.array(ims)[cart_dims]))
            for dd, cs in enumerate(list(coord_world)):
                # full extent along this axis, first element along the others
                spi = tuple([
                    slice(None) if di == dd else slice(1)
                    for di in range(cs.ndim)
                ])
                # flatten instead of taking row 0, which would keep a single
                # value for any axis other than the last one
                coords.update(
                    dict([(coord_names[cart_dims[dd]], cs[spi].reshape(-1))]))

            # store metadata for later
            tm['coords'] = coords
            tm['dsize'] = np.array(summary['shape'])
            tm['dims'] = [
                coord_names[di] if di in cart_dims else 'd' + str(di)
                for di in range(len(ims))
            ]

            # store rest of image meta data as attributes
            omits = [
                'axisnames', 'hasmask', 'masks', 'defaultmask', 'ndim',
                'refpix', 'refval', 'shape', 'tileshape', 'messages'
            ]
            nested = [
                kk for kk in summary.keys() if isinstance(summary[kk], dict)
            ]
            tm['attrs'] = dict([(kk.lower(), summary[kk])
                                for kk in summary.keys()
                                if kk not in omits + nested])
            # NOTE(review): key case for nested entries reconstructed to match the line above - confirm
            tm['attrs'].update(
                dict([(kk.lower(),
                       list(nested_to_record(summary[kk], sep='.').items()))
                      for kk in nested]))

            # parse messages for additional keys, drop duplicate info
            omits = [
                'image_name', 'image_type', 'image_quantity', 'pixel_mask(s)',
                'region(s)', 'image_units'
            ]
            for msg in summary['messages']:
                line = [
                    tuple(kk.split(':')) for kk in msg.lower().split('\n')
                    if ': ' in kk
                ]
                line = [(kk[0].strip().replace(' ', '_'), kk[1].strip())
                        for kk in line]
                line = [ll for ll in line if ll[0] not in omits]
                tm['attrs'].update(dict(line))

            # save metadata from first image product (the image itself)
            # compare later image products to see if dimensions match up
            # Note: only checking image dimensions, NOT COORDINATE VALUES!!
            if not meta:
                meta = dict(tm)
            elif np.any(meta['dsize'] != np.array(summary['shape'])) and (
                    imtype != 'sumwt'):
                diftypes += [imtype]
                difmeta += [tm]
                imtypes = [_ for _ in imtypes if _ != imtype]

            rc = IA.close()
        else:
            # artifact not present on disk
            imtypes = [_ for _ in imtypes if _ != imtype]

    print('compatible components: ', imtypes)
    print('separate components: ', diftypes)

    # process all image artifacts with compatible metadata to same zarr file
    # partition by channel, read each image artifact for each channel
    dsize, chan_dim = meta['dsize'], meta['dims'].index('chan')
    pt1, pt2 = [-1 for _ in range(len(dsize))], [-1 for _ in range(len(dsize))]
    if chunk_shape[2] <= 0: chunk_shape[2] = int(dsize[chan_dim])
    chan_batch = int(dsize[chan_dim]) if nofile else chunk_shape[2]
    for chan in range(0, dsize[chan_dim], chan_batch):
        print('processing channel ' + str(chan + 1) + ' of ' +
              str(dsize[chan_dim]))
        pt1[chan_dim], pt2[chan_dim] = chan, chan + chan_batch - 1
        chunk_coords = dict(meta['coords'])  # only want one freq channel coord
        # slice the channel axis of the reference image, not of whichever artifact was read last
        chunk_coords['chan'] = meta['coords']['chan'][np.arange(
            chan, min(chan + chan_batch, dsize[chan_dim]))]
        xdas = {}
        for imtype in imtypes:
            rc = IA.open(prefix + '.' + imtype)

            # extract pixel data
            imchunk = IA.getchunk(pt1, pt2)
            if imtype == 'fits': imtype = 'image'
            if imtype == 'mask':
                xdas['DECONVOLVE'] = xa(imchunk.astype(bool),
                                        dims=meta['dims'])
            elif imtype == 'sumwt':
                xdas[imtype.upper()] = xa(imchunk.reshape(imchunk.shape[2], 1),
                                          dims=['pol', 'chan'])
            else:
                xdas[imtype.upper()] = xa(imchunk, dims=meta['dims'])

            # extract mask
            summary = IA.summary(list=False)
            if len(summary['masks']) > 0:
                imchunk = IA.getchunk(pt1, pt2, getmask=True)
                xdas['MASK'] = xa(imchunk.astype(bool), dims=meta['dims'])

            rc = IA.close()

        # only axes with a positive chunk size are chunked explicitly
        chunking = dict([(dd, chunk_shape[ii])
                         for ii, dd in enumerate(['d0', 'd1', 'chan', 'pol'])
                         if chunk_shape[ii] > 0])
        xds = xd(xdas, coords=chunk_coords,
                 attrs=nested_to_record(meta['attrs'],
                                        sep='_')).chunk(chunking)

        # for everyone's sanity, lets make sure the dimensions are ordered the same way as visibility data
        if ('pol' in xds.dims):
            xds = xds.transpose(xds.IMAGE.dims[0], xds.IMAGE.dims[1], 'chan',
                                'pol')

        if (chan == 0) and (not nofile):
            # first batch creates the store, with the compressor applied to every data variable
            encoding = dict(
                zip(list(xds.data_vars), cycle([{
                    'compressor': compressor
                }])))
            xds.to_zarr(outfile, mode='w', encoding=encoding)
        elif not nofile:
            xds.to_zarr(outfile, mode='a', append_dim='chan')

    print("processed image size " + str(dsize) + " in " +
          str(np.float32(time.time() - begin)) + " seconds")

    if not nofile:
        xds = xarray.open_zarr(outfile)

    return xds