def test_select_field():
    datasets = []

    for field_name in ["DEEP2"]:
        name = np.asarray([field_name], dtype=np.object)
        ds = Dataset({"NAME": (("row", ), da.from_array(name, chunks=1))})
        datasets.append(ds)

    # No field selection, single field id returned
    assert select_field_id(datasets, None) == 0

    datasets = []

    for field_name in ["PKS-1934", "3C286", "DEEP2"]:
        name = np.asarray([field_name], dtype=np.object)
        ds = Dataset({"NAME": (("row", ), da.from_array(name, chunks=1))})
        datasets.append(ds)

    # No field selection, ValueError raised
    with pytest.raises(ValueError):
        assert select_field_id(datasets, None) == [0, 1, 2]

    assert select_field_id(datasets, "PKS-1934") == 0
    assert select_field_id(datasets, "0") == 0
    assert select_field_id(datasets, "3C286") == 1
    assert select_field_id(datasets, "1") == 1
    assert select_field_id(datasets, "DEEP2") == 2
    assert select_field_id(datasets, "2") == 2
Example #2
def test_ms_create_and_update(Dataset, tmp_path, chunks):
    """ Test that we can update and append at the same time """
    filename = str(tmp_path / "create-and-update.ms")

    rs = np.random.RandomState(42)

    # Create a dataset of 10 rows with DATA and DATA_DESC_ID
    dims = ("row", "chan", "corr")
    row, chan, corr = tuple(sum(chunks[d]) for d in dims)
    ms_datasets = []
    np_data = (rs.normal(size=(row, chan, corr)) +
               1j * rs.normal(size=(row, chan, corr))).astype(np.complex64)

    data_chunks = tuple((chunks['row'], chan, corr))
    dask_data = da.from_array(np_data, chunks=data_chunks)
    # Create dask ddid column
    dask_ddid = da.full(row, 0, chunks=chunks['row'], dtype=np.int32)
    dataset = Dataset({
        'DATA': (dims, dask_data),
        'DATA_DESC_ID': (("row", ), dask_ddid),
    })
    ms_datasets.append(dataset)

    # Write it
    writes = xds_to_table(ms_datasets, filename, ["DATA", "DATA_DESC_ID"])
    dask.compute(writes)

    ms_datasets = xds_from_ms(filename)

    # Now add another dataset (different DDID), with no ROWID
    np_data = (rs.normal(size=(row, chan, corr)) +
               1j * rs.normal(size=(row, chan, corr))).astype(np.complex64)
    data_chunks = tuple((chunks['row'], chan, corr))
    dask_data = da.from_array(np_data, chunks=data_chunks)
    # Create dask ddid column
    dask_ddid = da.full(row, 1, chunks=chunks['row'], dtype=np.int32)
    dataset = Dataset({
        'DATA': (dims, dask_data),
        'DATA_DESC_ID': (("row", ), dask_ddid),
    })
    ms_datasets.append(dataset)

    # Write it
    writes = xds_to_table(ms_datasets, filename, ["DATA", "DATA_DESC_ID"])
    dask.compute(writes)

    # Rows have been added and additional data is present
    with pt.table(filename, ack=False, readonly=True) as T:
        first_data_desc_id = da.full(row,
                                     ms_datasets[0].DATA_DESC_ID,
                                     chunks=chunks['row'])
        ds_data = da.concatenate(
            [ms_datasets[0].DATA.data, ms_datasets[1].DATA.data])
        ds_ddid = da.concatenate(
            [first_data_desc_id, ms_datasets[1].DATA_DESC_ID.data])
        assert_array_equal(T.getcol("DATA"), ds_data)
        assert_array_equal(T.getcol("DATA_DESC_ID"), ds_ddid)
Example #3
def average_spw(spw_ds, chan_bin_size):
    """
    Parameters
    ----------
    spw_ds : list of Datasets
        list of Datasets, each describing a single Spectral Window
    chan_bin_size : int
        Number of channels in an averaging bin

    Returns
    -------
    spw_ds : list of Datasets
        list of Datasets, each describing an averaged Spectral Window
    """

    new_spw_ds = []

    for r, spw in enumerate(spw_ds):
        # Get the dataset variables as a mutable dictionary
        dv = dict(spw.data_vars)

        # Extract arrays we wish to average
        chan_freq = dv['CHAN_FREQ'].data[0]
        chan_width = dv['CHAN_WIDTH'].data[0]
        effective_bw = dv['EFFECTIVE_BW'].data[0]
        resolution = dv['RESOLUTION'].data[0]

        # Construct channel metadata
        chan_arrays = (chan_freq, chan_width, effective_bw, resolution)
        chan_meta = chan_metadata((), chan_arrays, chan_bin_size)
        # Average channel based data
        avg = dask_chan_avg(chan_meta,
                            chan_freq=chan_freq,
                            chan_width=chan_width,
                            effective_bw=effective_bw,
                            resolution=resolution,
                            chan_bin_size=chan_bin_size)

        num_chan = da.full((1, ), avg.chan_freq.shape[0], dtype=np.int32)

        # These columns change, re-create them
        dv['NUM_CHAN'] = (("row", ), num_chan)
        dv['CHAN_FREQ'] = (("row", "chan"), avg.chan_freq[None, :])
        dv['CHAN_WIDTH'] = (("row", "chan"), avg.chan_width[None, :])
        dv['EFFECTIVE_BW'] = (("row", "chan"), avg.effective_bw[None, :])
        dv['RESOLUTION'] = (("row", "chan"), avg.resolution[None, :])

        # But re-use all the others
        new_spw_ds.append(Dataset(dv))

    return new_spw_ds
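
# A minimal usage sketch (not part of the original source): read one dataset
# per SPECTRAL_WINDOW row with dask-ms and average each window into bins of
# four channels. The MS path and bin size are hypothetical, and the helpers
# used above (chan_metadata, dask_chan_avg, Dataset) are assumed importable.
from daskms import xds_from_table

spw_ds = xds_from_table("obs.ms::SPECTRAL_WINDOW", group_cols="__row__")
avg_spw_ds = average_spw(spw_ds, chan_bin_size=4)

# Averaged channel frequencies of the first window
print(avg_spw_ds[0].CHAN_FREQ.data.compute())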
Example #4
def test_select_fields():
    datasets = []

    for field_name in ["PKS-1934", "3C286", "DEEP2"]:
        name = np.asarray([field_name], dtype=np.object)
        ds = Dataset({"NAME": (("row",), da.from_array(name, chunks=1))})
        datasets.append(ds)

    # No field selection, all fields returned
    assert valid_field_ids(datasets, None) == [0, 1, 2]

    assert valid_field_ids(datasets, "PKS-1934") == [0]
    assert valid_field_ids(datasets, "3C286, DEEP2") == [1, 2]
    assert valid_field_ids(datasets, "1, DEEP2") == [1, 2]
    assert valid_field_ids(datasets, "0, 1, 2") == [0, 1, 2]
    assert valid_field_ids(datasets, "2, 3") == [2]
Example #5
def _psf(**kw):
    args = OmegaConf.create(kw)
    from omegaconf import ListConfig
    if not isinstance(args.ms, list) and not isinstance(args.ms, ListConfig):
        args.ms = [args.ms]
    OmegaConf.set_struct(args, True)

    import numpy as np
    from pfb.utils.misc import chan_to_band_mapping
    import dask
    # from dask.distributed import performance_report
    from dask.graph_manipulation import clone
    from daskms import xds_from_storage_ms as xds_from_ms
    from daskms import xds_from_storage_table as xds_from_table
    from daskms import Dataset
    from daskms.experimental.zarr import xds_to_zarr
    import dask.array as da
    from africanus.constants import c as lightspeed
    from africanus.gridding.wgridder.dask import dirty as vis2im
    from ducc0.fft import good_size
    from pfb.utils.misc import stitch_images, plan_row_chunk
    from pfb.utils.fits import set_wcs, save_fits

    # chan <-> band mapping
    ms = args.ms
    nband = args.nband
    freqs, freq_bin_idx, freq_bin_counts, freq_out, band_mapping, chan_chunks = chan_to_band_mapping(
        ms, nband=nband)

    # gridder memory budget
    max_chan_chunk = 0
    max_freq = 0
    for ims in args.ms:
        for spw in freqs[ims]:
            counts = freq_bin_counts[ims][spw].compute()
            freq = freqs[ims][spw].compute()
            max_chan_chunk = np.maximum(max_chan_chunk, counts.max())
            max_freq = np.maximum(max_freq, freq.max())

    # assumes measurement sets have the same columns,
    # number of correlations etc.
    xds = xds_from_ms(args.ms[0])
    ncorr = xds[0].dims['corr']
    nrow = xds[0].dims['row']
    # we still have to cater for complex valued data because we cast
    # the weights to complex, but we no longer need to factor the
    # weight column into our memory budget
    data_bytes = getattr(xds[0], args.data_column).data.itemsize
    bytes_per_row = max_chan_chunk * ncorr * data_bytes
    memory_per_row = bytes_per_row

    # flags (uint8 or bool)
    memory_per_row += bytes_per_row / 8

    # UVW
    memory_per_row += xds[0].UVW.data.itemsize * 3

    # ANTENNA1/2
    memory_per_row += xds[0].ANTENNA1.data.itemsize * 2

    # TIME
    memory_per_row += xds[0].TIME.data.itemsize

    # the data column is not actually read into memory; it is just used to
    # infer dtype and chunking
    columns = (args.data_column, args.weight_column, args.flag_column, 'UVW',
               'ANTENNA1', 'ANTENNA2', 'TIME')

    # flag row
    if 'FLAG_ROW' in xds[0]:
        columns += ('FLAG_ROW', )
        memory_per_row += xds[0].FLAG_ROW.data.itemsize

    # imaging weights
    if args.imaging_weight_column is not None:
        columns += (args.imaging_weight_column, )
        memory_per_row += bytes_per_row / 2

    # Mueller term (complex valued)
    if args.mueller_column is not None:
        columns += (args.mueller_column, )
        memory_per_row += bytes_per_row
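
    # Illustrative back-of-envelope numbers for the budget above (values are
    # assumptions, not from the source): with max_chan_chunk = 128, ncorr = 4
    # and complex64 data (itemsize 8), bytes_per_row = 128 * 4 * 8 = 4096.
    # The running total is then roughly 4096 (weights cast to the complex
    # data type) + 512 (flags) + 24 (UVW) + 8 (ANTENNA1/2) + 8 (TIME)
    # ~= 4.6 kB per row, plus bytes_per_row / 2 for imaging weights and
    # bytes_per_row for a Mueller term when those columns are requested.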

    # get max uv coords over all fields
    uvw = []
    u_max = 0.0
    v_max = 0.0
    for ims in args.ms:
        xds = xds_from_ms(ims, columns=('UVW'), chunks={'row': -1})

        for ds in xds:
            uvw = ds.UVW.data
            u_max = da.maximum(u_max, abs(uvw[:, 0]).max())
            v_max = da.maximum(v_max, abs(uvw[:, 1]).max())
            uv_max = da.maximum(u_max, v_max)

    uv_max = uv_max.compute()
    del uvw

    # image size
    cell_N = 1.0 / (2 * uv_max * max_freq / lightspeed)

    if args.cell_size is not None:
        cell_size = args.cell_size
        cell_rad = cell_size * np.pi / 60 / 60 / 180
        if cell_N / cell_rad < 1:
            raise ValueError(
                "Requested cell size too small. "
                "Super resolution factor = ", cell_N / cell_rad)
        print("Super resolution factor = %f" % (cell_N / cell_rad), file=log)
    else:
        cell_rad = cell_N / args.super_resolution_factor
        cell_size = cell_rad * 60 * 60 * 180 / np.pi
        print("Cell size set to %5.5e arcseconds" % cell_size, file=log)

    if args.nx is None:
        fov = args.field_of_view * 3600
        npix = int(args.psf_oversize * fov / cell_size)
        if npix % 2:
            npix += 1
        nx = npix
        ny = npix
    else:
        nx = args.nx
        ny = args.ny if args.ny is not None else nx

    print("PSF size set to (%i, %i, %i)" % (nband, nx, ny), file=log)

    # get approx image size
    # this is not a conservative estimate when multiple SPWs map to a single
    # imaging band
    pixel_bytes = np.dtype(args.output_type).itemsize
    band_size = nx * ny * pixel_bytes

    if args.host_address is None:
        # full image on single node
        row_chunk = plan_row_chunk(args.mem_limit / args.nworkers, band_size,
                                   nrow, memory_per_row,
                                   args.nthreads_per_worker)

    else:
        # single band per node
        row_chunk = plan_row_chunk(args.mem_limit, band_size, nrow,
                                   memory_per_row, args.nthreads_per_worker)

    if args.row_chunks is not None:
        row_chunk = int(args.row_chunks)
        if row_chunk == -1:
            row_chunk = nrow

    print(
        "nrows = %i, row chunks set to %i for a total of %i chunks per node" %
        (nrow, row_chunk, int(np.ceil(nrow / row_chunk))),
        file=log)

    chunks = {}
    for ims in args.ms:
        chunks[ims] = []  # xds_from_ms expects a list per ds
        for spw in freqs[ims]:
            chunks[ims].append({
                'row': row_chunk,
                'chan': chan_chunks[ims][spw]['chan']
            })

    psfs = []
    radec = None  # assumes we are only imaging field 0 of first MS
    out_datasets = []
    for ims in args.ms:
        xds = xds_from_ms(ims, chunks=chunks[ims], columns=columns)

        # subtables
        ddids = xds_from_table(ims + "::DATA_DESCRIPTION")
        fields = xds_from_table(ims + "::FIELD")
        spws = xds_from_table(ims + "::SPECTRAL_WINDOW")
        pols = xds_from_table(ims + "::POLARIZATION")

        # subtable data
        ddids = dask.compute(ddids)[0]
        fields = dask.compute(fields)[0]
        spws = dask.compute(spws)[0]
        pols = dask.compute(pols)[0]

        for ds in xds:
            field = fields[ds.FIELD_ID]

            # check fields match
            if radec is None:
                radec = field.PHASE_DIR.data.squeeze()

            if not np.array_equal(radec, field.PHASE_DIR.data.squeeze()):
                continue

            # this is not correct, need to use spw
            spw = ds.DATA_DESC_ID

            uvw = clone(ds.UVW.data)

            data_type = getattr(ds, args.data_column).data.dtype
            data_shape = getattr(ds, args.data_column).data.shape
            data_chunks = getattr(ds, args.data_column).data.chunks

            weights = getattr(ds, args.weight_column).data
            if len(weights.shape) < 3:
                weights = da.broadcast_to(weights[:, None, :],
                                          data_shape,
                                          chunks=data_chunks)

            if args.imaging_weight_column is not None:
                imaging_weights = getattr(ds, args.imaging_weight_column).data
                if len(imaging_weights.shape) < 3:
                    imaging_weights = da.broadcast_to(imaging_weights[:,
                                                                      None, :],
                                                      data_shape,
                                                      chunks=data_chunks)

                weightsxx = imaging_weights[:, :, 0] * weights[:, :, 0]
                weightsyy = imaging_weights[:, :, -1] * weights[:, :, -1]
            else:
                weightsxx = weights[:, :, 0]
                weightsyy = weights[:, :, -1]

            # apply mueller term
            if args.mueller_column is not None:
                mueller = getattr(ds, args.mueller_column).data
                weightsxx *= da.absolute(mueller[:, :, 0])**2
                weightsyy *= da.absolute(mueller[:, :, -1])**2

            # weighted sum corr to Stokes I
            weights = weightsxx + weightsyy

            # MS may contain auto-correlations
            if 'FLAG_ROW' in xds[0]:
                frow = ds.FLAG_ROW.data | (ds.ANTENNA1.data
                                           == ds.ANTENNA2.data)
            else:
                frow = (ds.ANTENNA1.data == ds.ANTENNA2.data)

            # only keep data where both corrs are unflagged
            flag = getattr(ds, args.flag_column).data
            flagxx = flag[:, :, 0]
            flagyy = flag[:, :, -1]
            # ducc0 uses uint8 mask not flag
            mask = ~da.logical_or((flagxx | flagyy), frow[:, None])

            psf = vis2im(uvw,
                         freqs[ims][spw],
                         weights.astype(data_type),
                         freq_bin_idx[ims][spw],
                         freq_bin_counts[ims][spw],
                         nx,
                         ny,
                         cell_rad,
                         flag=mask.astype(np.uint8),
                         nthreads=args.nvthreads,
                         epsilon=args.epsilon,
                         do_wstacking=args.wstack,
                         double_accum=args.double_accum)

            psfs.append(psf)

            data_vars = {
                'FIELD_ID': (('row', ),
                             da.full_like(ds.TIME.data,
                                          ds.FIELD_ID,
                                          chunks=args.row_out_chunk)),
                'DATA_DESC_ID': (('row', ),
                                 da.full_like(ds.TIME.data,
                                              ds.DATA_DESC_ID,
                                              chunks=args.row_out_chunk)),
                'WEIGHT':
                (('row', 'chan'), weights.rechunk({0: args.row_out_chunk
                                                   })),  # why no 'f4'?
                'UVW': (('row', 'uvw'), uvw.rechunk({0: args.row_out_chunk}))
            }

            coords = {'chan': (('chan', ), freqs[ims][spw])}

            out_ds = Dataset(data_vars, coords)

            out_datasets.append(out_ds)

    writes = xds_to_zarr(out_datasets,
                         args.output_filename + '.zarr',
                         columns='ALL')

    # dask.visualize(writes, filename=args.output_filename + '_psf_writes_graph.pdf', optimize_graph=False)
    # dask.visualize(psfs, filename=args.output_filename + '_psf_graph.pdf', optimize_graph=False)

    if not args.mock:
        # psfs = dask.compute(psfs, writes, optimize_graph=False)[0]
        # with performance_report(filename=args.output_filename + '_psf_per.html'):
        psfs = dask.compute(psfs, writes, optimize_graph=False)[0]

        psf = stitch_images(psfs, nband, band_mapping)

        hdr = set_wcs(cell_size / 3600, cell_size / 3600, nx, ny, radec,
                      freq_out)
        save_fits(args.output_filename + '_psf.fits',
                  psf,
                  hdr,
                  dtype=args.output_type)

        psf_mfs = np.sum(psf, axis=0)
        wsum = psf_mfs.max()
        psf_mfs /= wsum

        hdr_mfs = set_wcs(cell_size / 3600, cell_size / 3600, nx, ny, radec,
                          np.mean(freq_out))
        save_fits(args.output_filename + '_psf_mfs.fits',
                  psf_mfs,
                  hdr_mfs,
                  dtype=args.output_type)

    print("All done here.", file=log)
Example #6
def ms_create(ms_table_name, info, ant_pos, vis_array, baselines, timestamps, pol_feeds, sources):
    ''' Create a Measurement Set from some TART observations
    
    Parameters
    ----------
    
    ms_table_name : string
        The name of the MS top level directory. I think this only works in
        the local directory.
    
    info : JSON
        "info": {
            "info": {
                "L0_frequency": 1571328000.0,
                "bandwidth": 2500000.0,
                "baseband_frequency": 4092000.0,
                "location": {
                    "alt": 270.0,
                    "lat": -45.85177,
                    "lon": 170.5456
                },
                "name": "Signal Hill - Dunedin",
                "num_antenna": 24,
                "operating_frequency": 1575420000.0,
                "sampling_frequency": 16368000.0
            }
        },

    Returns
    -------
    None
      
    '''

    epoch_s = timestamp_to_ms_epoch(timestamps)
    LOGGER.info("Time {}".format(epoch_s))

    try:
        loc = info['location']
    except:
        loc = info
    # Sort out the coordinate frames using astropy
    # https://casa.nrao.edu/casadocs/casa-5.4.1/reference-material/coordinate-frames
    iers.conf.iers_auto_url = 'https://astroconda.org/aux/astropy_mirror/iers_a_1/finals2000A.all' 
    iers.conf.auto_max_age = None 

    location = EarthLocation.from_geodetic(lon=loc['lon']*u.deg,
                                           lat=loc['lat']*u.deg,
                                           height=loc['alt']*u.m,
                                           ellipsoid='WGS84')
    obstime = Time(timestamps)
    local_frame = AltAz(obstime=obstime, location=location)

    phase_altaz = SkyCoord(alt=90.0*u.deg, az=0.0*u.deg, obstime = obstime, frame = 'altaz', location = location)
    phase_j2000 = phase_altaz.transform_to('fk5')

    # Get the stokes enums for the polarization types
    corr_types = [[MS_STOKES_ENUMS[p_f] for p_f in pol_feeds]]

    LOGGER.info("Pol Feeds {}".format(pol_feeds))
    LOGGER.info("Correlation Types {}".format(corr_types))
    num_freq_channels = [1]

    ant_table = MSTable(ms_table_name, 'ANTENNA')
    feed_table = MSTable(ms_table_name, 'FEED')
    field_table = MSTable(ms_table_name, 'FIELD')
    pol_table = MSTable(ms_table_name, 'POLARIZATION')
    obs_table = MSTable(ms_table_name, 'OBSERVATION')
    # SOURCE is an optional MS sub-table
    src_table = MSTable(ms_table_name, 'SOURCE')
    
    ddid_table_name = "::".join((ms_table_name, "DATA_DESCRIPTION"))
    spw_table_name = "::".join((ms_table_name, "SPECTRAL_WINDOW"))

    ms_datasets = []
    ddid_datasets = []
    spw_datasets = []

    # Create ANTENNA dataset
    # Each column in the ANTENNA has a fixed shape so we
    # can represent all rows with one dataset
    num_ant = len(ant_pos)
    position = da.asarray(ant_pos)
    diameter = da.ones(num_ant) * 0.025
    offset = da.zeros((num_ant, 3))
    names = np.array(['ANTENNA-%d' % i for i in range(num_ant)], dtype=np.object)
    stations = np.array([info['name'] for i in range(num_ant)], dtype=np.object)

    dataset = Dataset({
        'POSITION': (("row", "xyz"), position),
        'OFFSET': (("row", "xyz"), offset),
        'DISH_DIAMETER': (("row",), diameter),
        'NAME': (("row",), da.from_array(names, chunks=num_ant)),
        'STATION': (("row",), da.from_array(stations, chunks=num_ant)),
    })
    ant_table.append(dataset)

    ###################  Create a FEED dataset. ###################################
    # There is one feed per antenna, so this should be quite similar to the ANTENNA table.
    num_pols = len(pol_feeds)
    pol_types = pol_feeds
    pol_responses = [POL_RESPONSES[ct] for ct in pol_feeds]

    LOGGER.info("Pol Types {}".format(pol_types))
    LOGGER.info("Pol Responses {}".format(pol_responses))

    antenna_ids = da.asarray(range(num_ant))
    feed_ids = da.zeros(num_ant)
    num_receptors = da.zeros(num_ant) + num_pols
    polarization_types = np.array([pol_types for i in range(num_ant)], dtype=np.object)
    receptor_angles = np.array([[0.0] for i in range(num_ant)])
    pol_response = np.array([pol_responses for i in range(num_ant)])

    beam_offset = np.array([[[0.0, 0.0]] for i in range(num_ant)])

    dataset = Dataset({
        'ANTENNA_ID': (("row",), antenna_ids),
        'FEED_ID': (("row",), feed_ids),
        'NUM_RECEPTORS': (("row",), num_receptors),
        'POLARIZATION_TYPE': (("row", "receptors",),
                              da.from_array(polarization_types, chunks=num_ant)),
        'RECEPTOR_ANGLE': (("row", "receptors",),
                           da.from_array(receptor_angles, chunks=num_ant)),
        'POL_RESPONSE': (("row", "receptors", "receptors-2"),
                         da.from_array(pol_response, chunks=num_ant)),
        'BEAM_OFFSET': (("row", "receptors", "radec"),
                        da.from_array(beam_offset, chunks=num_ant)),
    })
    feed_table.append(dataset)


    ####################### FIELD dataset #########################################
    
    direction = [[phase_j2000.ra.radian, phase_j2000.dec.radian]]
    field_direction = da.asarray(direction)[None, :]
    field_name = da.asarray(np.asarray(['up'], dtype=np.object), chunks=1)
    field_num_poly = da.zeros(1) # Zero order polynomial in time for phase center.

    dir_dims = ("row", 'field-poly', 'field-dir',)

    dataset = Dataset({
        'PHASE_DIR': (dir_dims, field_direction),
        'DELAY_DIR': (dir_dims, field_direction),
        'REFERENCE_DIR': (dir_dims, field_direction),
        'NUM_POLY': (("row", ), field_num_poly),
        'NAME': (("row", ), field_name),
    })
    field_table.append(dataset)

   ######################### OBSERVATION dataset #####################################

    dataset = Dataset({
        'TELESCOPE_NAME': (("row",), da.asarray(np.asarray(['TART'], dtype=np.object), chunks=1)),
        'OBSERVER': (("row",), da.asarray(np.asarray(['Tim'], dtype=np.object), chunks=1)),
        "TIME_RANGE": (("row","obs-exts"), da.asarray(np.array([[epoch_s, epoch_s+1]]), chunks=1)),
    })
    obs_table.append(dataset)

    ######################## SOURCE datasets ########################################
    for src in sources:
        name = src['name']
        # Convert to J2000 
        dir_altaz = SkyCoord(alt=src['el']*u.deg, az=src['az']*u.deg, obstime = obstime,
                             frame = 'altaz', location = location)
        dir_j2000 = dir_altaz.transform_to('fk5')
        direction = [dir_j2000.ra.radian, dir_j2000.dec.radian]
        #LOGGER.info("SOURCE: {}, timestamp: {}".format(name, timestamps))
        dask_num_lines = da.full((1,), 1, dtype=np.int32)
        dask_direction = da.asarray(direction)[None, :]
        dask_name = da.asarray(np.asarray([name], dtype=np.object), chunks=1)
        dask_time = da.asarray(np.array([epoch_s]))
        dataset = Dataset({
            "NUM_LINES": (("row",), dask_num_lines),
            "NAME": (("row",), dask_name),
            "TIME": (("row",), dask_time),
            "DIRECTION": (("row", "dir"), dask_direction),
            })
        src_table.append(dataset)

    # Create POLARISATION datasets.
    # Dataset per output row required because column shapes are variable

    for corr_type in corr_types:
        corr_prod = [[i, i] for i in range(len(corr_type))]

        corr_prod = np.array(corr_prod)
        LOGGER.info("Corr Prod {}".format(corr_prod))
        LOGGER.info("Corr Type {}".format(corr_type))

        dask_num_corr = da.full((1,), len(corr_type), dtype=np.int32)
        LOGGER.info("NUM_CORR {}".format(dask_num_corr))
        dask_corr_type = da.from_array(corr_type,
                                       chunks=len(corr_type))[None, :]
        dask_corr_product = da.asarray(corr_prod)[None, :]
        LOGGER.info("Dask Corr Prod {}".format(dask_corr_product.shape))
        LOGGER.info("Dask Corr Type {}".format(dask_corr_type.shape))
        dataset = Dataset({
            "NUM_CORR": (("row",), dask_num_corr),
            "CORR_TYPE": (("row", "corr"), dask_corr_type),
            "CORR_PRODUCT": (("row", "corr", "corrprod_idx"), dask_corr_product),
        })

        pol_table.append(dataset)

    # Create multiple SPECTRAL_WINDOW datasets
    # Dataset per output row required because column shapes are variable

    for num_chan in num_freq_channels:
        dask_num_chan = da.full((1,), num_chan, dtype=np.int32)
        dask_chan_freq = da.asarray([[info['operating_frequency']]])
        dask_chan_width = da.full((1, num_chan), 2.5e6/num_chan)

        dataset = Dataset({
            "NUM_CHAN": (("row",), dask_num_chan),
            "CHAN_FREQ": (("row", "chan"), dask_chan_freq),
            "CHAN_WIDTH": (("row", "chan"), dask_chan_width),
            "EFFECTIVE_BW": (("row", "chan"), dask_chan_width),
            "RESOLUTION": (("row", "chan"), dask_chan_width),
        })

        spw_datasets.append(dataset)

    # For each cartesian product of SPECTRAL_WINDOW and POLARIZATION
    # create a corresponding DATA_DESCRIPTION.
    # Each column has fixed shape so we handle all rows at once
    spw_ids, pol_ids = zip(*product(range(len(num_freq_channels)),
                                    range(len(corr_types))))
    dask_spw_ids = da.asarray(np.asarray(spw_ids, dtype=np.int32))
    dask_pol_ids = da.asarray(np.asarray(pol_ids, dtype=np.int32))
    ddid_datasets.append(Dataset({
        "SPECTRAL_WINDOW_ID": (("row",), dask_spw_ids),
        "POLARIZATION_ID": (("row",), dask_pol_ids),
    }))
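
    # For example (illustrative, not from the source): with two spectral
    # windows and two polarisation setups, product(range(2), range(2)) yields
    # (0, 0), (0, 1), (1, 0), (1, 1), i.e. spw_ids = (0, 0, 1, 1) and
    # pol_ids = (0, 1, 0, 1), one DATA_DESCRIPTION row per combination.
    # Here there is a single spectral window and polarisation setup, so
    # only the (0, 0) row is created.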

    # Now create the associated MS dataset

    #vis_data, baselines = cal_vis.get_all_visibility()
    #vis_array = np.array(vis_data, dtype=np.complex64)
    chunks = {
        "row": (vis_array.shape[0],),
    }
    baselines = np.array(baselines)
    #LOGGER.info(f"baselines {baselines}")
    bl_pos = np.array(ant_pos)[baselines]
    uu_a, vv_a, ww_a = -(bl_pos[:, 1] - bl_pos[:, 0]).T #/constants.L1_WAVELENGTH
    # Use the - sign to get the same orientation as our tart projections.

    uvw_array = np.array([uu_a, vv_a, ww_a]).T

    for ddid, (spw_id, pol_id) in enumerate(zip(spw_ids, pol_ids)):
        # Infer row, chan and correlation shape
        #LOGGER.info("ddid:{} ({}, {})".format(ddid, spw_id, pol_id))
        row = sum(chunks['row'])
        chan = spw_datasets[spw_id].CHAN_FREQ.shape[1]
        corr = pol_table.datasets[pol_id].CORR_TYPE.shape[1]

        # Create some dask vis data
        dims = ("row", "chan", "corr")
        LOGGER.info("Data size %s %s %s" % (row, chan, corr))

        #np_data = vis_array.reshape((row, chan, corr))
        np_data = np.zeros((row, chan, corr), dtype=np.complex128)
        for i in range(corr):
            np_data[:, :, i] = vis_array.reshape((row, chan))
        #np_data = np.array([vis_array.reshape((row, chan, 1)) for i in range(corr)])
        np_uvw = uvw_array.reshape((row, 3))

        data_chunks = tuple((chunks['row'], chan, corr))
        dask_data = da.from_array(np_data, chunks=data_chunks)
        
        flag_categories = da.from_array(0.05*np.ones((row, chan, corr, 1)))
        flag_data = np.zeros((row, chan, corr), dtype=np.bool_)

        uvw_data = da.from_array(np_uvw)
        # Create dask ddid column
        dask_ddid = da.full(row, ddid, chunks=chunks['row'], dtype=np.int32)
        dataset = Dataset({
            'DATA': (dims, dask_data),
            'FLAG': (dims, da.from_array(flag_data)),
            'TIME': (("row", "corr"), da.from_array(epoch_s*np.ones((row, corr)))),
            'TIME_CENTROID': (("row", "corr"), da.from_array(epoch_s*np.ones((row, corr)))),
            'WEIGHT': (("row", "corr"), da.from_array(0.95*np.ones((row, corr)))),
            'WEIGHT_SPECTRUM': (dims, da.from_array(0.95*np.ones_like(np_data, dtype=np.float64))),
            'SIGMA_SPECTRUM': (dims, da.from_array(np.ones_like(np_data, dtype=np.float64)*0.05)),
            'SIGMA': (("row", "corr"), da.from_array(0.05*np.ones((row, corr)))),
            'UVW': (("row", "uvw",), uvw_data),
            'FLAG_CATEGORY': (('row', 'flagcat', 'chan', 'corr'), flag_categories), # {'dims': ('flagcat', 'chan', 'corr')}
            'ANTENNA1': (("row",), da.from_array(baselines[:, 0])),
            'ANTENNA2': (("row",), da.from_array(baselines[:, 1])),
            'FEED1': (("row",), da.from_array(baselines[:, 0])),
            'FEED2': (("row",), da.from_array(baselines[:, 1])),
            'DATA_DESC_ID': (("row",), dask_ddid)
        })
        ms_datasets.append(dataset)

    ms_writes = xds_to_table(ms_datasets, ms_table_name, columns="ALL")
    spw_writes = xds_to_table(spw_datasets, spw_table_name, columns="ALL")
    ddid_writes = xds_to_table(ddid_datasets, ddid_table_name, columns="ALL")

    dask.compute(ms_writes)

    ant_table.write()
    feed_table.write()
    field_table.write()
    pol_table.write()
    obs_table.write()
    src_table.write()

    dask.compute(spw_writes)
    dask.compute(ddid_writes)
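
# A hypothetical invocation sketch (argument values are illustrative and not
# taken from the source; MSTable, the enums and the helper functions used
# above are assumed to be defined elsewhere in this module):
#
#   ms_create("tart_obs.ms",
#             info,                    # telescope metadata dict (see docstring)
#             ant_pos,                 # (num_antenna, 3) antenna positions
#             vis_array,               # complex visibilities, one per baseline
#             baselines,               # (num_baselines, 2) antenna index pairs
#             timestamps,              # observation timestamp(s)
#             pol_feeds=['XX'],        # feed polarisation types, e.g. ['XX']
#             sources=[{'name': 'GPS-BIIR-2', 'el': 45.0, 'az': 120.0}])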
Example #7
def output_dataset(avg, field_id, data_desc_id, scan_number, group_row_chunks):
    """
    Parameters
    ----------
    avg : namedtuple
        Result of :func:`average`
    field_id : int
        FIELD_ID for this averaged data
    data_desc_id : int
        DATA_DESC_ID for this averaged data
    scan_number : int
        SCAN_NUMBER for this averaged data
    group_row_chunks : int
        Number of row chunks to concatenate together in the output

    Returns
    -------
    Dataset
        Dataset containing averaged data
    """
    # Create ID columns
    field_id = id_full_like(avg.time, fill_value=field_id)
    data_desc_id = id_full_like(avg.time, fill_value=data_desc_id)
    scan_number = id_full_like(avg.time, fill_value=scan_number)

    # Single flag category, equal to flags
    flag_cats = avg.flag[:, None, :, :]

    out_ds = {
        # Explicitly zero these columns? But this happens anyway
        # "ARRAY_ID": (("row",), zeros),
        # "OBSERVATION_ID": (("row",), zeros),
        # "PROCESSOR_ID": (("row",), zeros),
        # "STATE_ID": (("row",), zeros),
        "ANTENNA1": (("row", ), avg.antenna1),
        "ANTENNA2": (("row", ), avg.antenna2),
        "DATA_DESC_ID": (("row", ), data_desc_id),
        "FIELD_ID": (("row", ), field_id),
        "SCAN_NUMBER": (("row", ), scan_number),
        "FLAG_ROW": (("row", ), avg.flag_row),
        "FLAG_CATEGORY": (("row", "flagcat", "chan", "corr"), flag_cats),
        "TIME": (("row", ), avg.time),
        "INTERVAL": (("row", ), avg.interval),
        "TIME_CENTROID": (("row", ), avg.time_centroid),
        "EXPOSURE": (("row", ), avg.exposure),
        "UVW": (("row", "[uvw]"), avg.uvw),
        "WEIGHT": (("row", "corr"), avg.weight),
        "SIGMA": (("row", "corr"), avg.sigma),
        "DATA": (("row", "chan", "corr"), avg.vis),
        "FLAG": (("row", "chan", "corr"), avg.flag),
    }

    # Add optionally averaged columns
    if avg.weight_spectrum is not None:
        out_ds['WEIGHT_SPECTRUM'] = (("row", "chan", "corr"),
                                     avg.weight_spectrum)

    if avg.sigma_spectrum is not None:
        out_ds['SIGMA_SPECTRUM'] = (("row", "chan", "corr"),
                                    avg.sigma_spectrum)

    # Concatenate row chunks together
    if group_row_chunks > 1:
        grc = group_row_chunks
        out_ds = {
            k: (dims, concatenate_row_chunks(data, group_every=grc))
            for k, (dims, data) in out_ds.items()
        }

    return Dataset(out_ds)
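
# Small, self-contained illustration (shapes are arbitrary) of how the single
# flag category above is formed: a length-1 "flagcat" axis is inserted into
# the (row, chan, corr) FLAG array via None (np.newaxis) indexing.
import dask.array as da

flag = da.zeros((10, 4, 2), dtype=bool, chunks=(5, 4, 2))
flag_cats = flag[:, None, :, :]
assert flag_cats.shape == (10, 1, 4, 2)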
Example #8
def test_ms_create(Dataset, tmp_path, chunks, num_chans, corr_types, sources):
    # Set up
    rs = np.random.RandomState(42)

    ms_path = tmp_path / "create.ms"

    ms_table_name = str(ms_path)
    ant_table_name = "::".join((ms_table_name, "ANTENNA"))
    ddid_table_name = "::".join((ms_table_name, "DATA_DESCRIPTION"))
    pol_table_name = "::".join((ms_table_name, "POLARIZATION"))
    spw_table_name = "::".join((ms_table_name, "SPECTRAL_WINDOW"))
    # SOURCE is an optional MS sub-table
    src_table_name = "::".join((ms_table_name, "SOURCE"))

    ms_datasets = []
    ant_datasets = []
    ddid_datasets = []
    pol_datasets = []
    spw_datasets = []
    src_datasets = []

    # For comparison
    all_data_desc_id = []
    all_data = []

    # Create ANTENNA dataset of 64 antennas
    # Each column in the ANTENNA has a fixed shape so we
    # can represent all rows with one dataset
    na = 64
    position = da.random.random((na, 3)) * 10000
    offset = da.random.random((na, 3))
    names = np.array(['ANTENNA-%d' % i for i in range(na)], dtype=np.object)
    ds = Dataset({
        'POSITION': (("row", "xyz"), position),
        'OFFSET': (("row", "xyz"), offset),
        'NAME': (("row", ), da.from_array(names, chunks=na)),
    })
    ant_datasets.append(ds)

    # Create SOURCE datasets
    for s, (name, direction, rest_freq) in enumerate(sources):
        dask_num_lines = da.full((1, ), len(rest_freq), dtype=np.int32)
        dask_direction = da.asarray(direction)[None, :]
        dask_rest_freq = da.asarray(rest_freq)[None, :]
        dask_name = da.asarray(np.asarray([name], dtype=np.object))
        ds = Dataset({
            "NUM_LINES": (("row", ), dask_num_lines),
            "NAME": (("row", ), dask_name),
            "REST_FREQUENCY": (("row", "line"), dask_rest_freq),
            "DIRECTION": (("row", "dir"), dask_direction),
        })
        src_datasets.append(ds)

    # Create POLARISATION datasets.
    # Dataset per output row required because column shapes are variable
    for r, corr_type in enumerate(corr_types):
        dask_num_corr = da.full((1, ), len(corr_type), dtype=np.int32)
        dask_corr_type = da.from_array(corr_type,
                                       chunks=len(corr_type))[None, :]
        ds = Dataset({
            "NUM_CORR": (("row", ), dask_num_corr),
            "CORR_TYPE": (("row", "corr"), dask_corr_type),
        })

        pol_datasets.append(ds)

    # Create multiple MeerKAT L-band SPECTRAL_WINDOW datasets
    # Dataset per output row required because column shapes are variable
    for num_chan in num_chans:
        dask_num_chan = da.full((1, ), num_chan, dtype=np.int32)
        dask_chan_freq = da.linspace(.856e9,
                                     2 * .856e9,
                                     num_chan,
                                     chunks=num_chan)[None, :]
        dask_chan_width = da.full((1, num_chan), .856e9 / num_chan)

        ds = Dataset({
            "NUM_CHAN": (("row", ), dask_num_chan),
            "CHAN_FREQ": (("row", "chan"), dask_chan_freq),
            "CHAN_WIDTH": (("row", "chan"), dask_chan_width),
        })

        spw_datasets.append(ds)

    # For each cartesian product of SPECTRAL_WINDOW and POLARIZATION
    # create a corresponding DATA_DESCRIPTION.
    # Each column has fixed shape so we handle all rows at once
    spw_ids, pol_ids = zip(
        *product(range(len(num_chans)), range(len(corr_types))))
    dask_spw_ids = da.asarray(np.asarray(spw_ids, dtype=np.int32))
    dask_pol_ids = da.asarray(np.asarray(pol_ids, dtype=np.int32))
    ddid_datasets.append(
        Dataset({
            "SPECTRAL_WINDOW_ID": (("row", ), dask_spw_ids),
            "POLARIZATION_ID": (("row", ), dask_pol_ids),
        }))

    # Now create the associated MS dataset
    for ddid, (spw_id, pol_id) in enumerate(zip(spw_ids, pol_ids)):
        # Infer row, chan and correlation shape
        row = sum(chunks['row'])
        chan = spw_datasets[spw_id].CHAN_FREQ.shape[1]
        corr = pol_datasets[pol_id].CORR_TYPE.shape[1]

        # Create some dask vis data
        dims = ("row", "chan", "corr")
        np_data = (rs.normal(size=(row, chan, corr)) +
                   1j * rs.normal(size=(row, chan, corr))).astype(np.complex64)

        data_chunks = tuple((chunks['row'], chan, corr))
        dask_data = da.from_array(np_data, chunks=data_chunks)
        # Create dask ddid column
        dask_ddid = da.full(row, ddid, chunks=chunks['row'], dtype=np.int32)
        dataset = Dataset({
            'DATA': (dims, dask_data),
            'DATA_DESC_ID': (("row", ), dask_ddid)
        })
        ms_datasets.append(dataset)
        all_data.append(dask_data)
        all_data_desc_id.append(dask_ddid)

    ms_writes = xds_to_table(ms_datasets, ms_table_name, columns="ALL")
    ant_writes = xds_to_table(ant_datasets, ant_table_name, columns="ALL")
    pol_writes = xds_to_table(pol_datasets, pol_table_name, columns="ALL")
    spw_writes = xds_to_table(spw_datasets, spw_table_name, columns="ALL")
    ddid_writes = xds_to_table(ddid_datasets, ddid_table_name, columns="ALL")
    source_writes = xds_to_table(src_datasets, src_table_name, columns="ALL")

    dask.compute(ms_writes, ant_writes, pol_writes, spw_writes, ddid_writes,
                 source_writes)

    # Check ANTENNA table correctly created
    with pt.table(ant_table_name, ack=False) as A:
        assert_array_equal(A.getcol("NAME"), names)
        assert_array_equal(A.getcol("POSITION"), position)
        assert_array_equal(A.getcol("OFFSET"), offset)

        required_desc = pt.required_ms_desc("ANTENNA")
        required_columns = set(k for k in required_desc.keys()
                               if not k.startswith("_"))

        assert set(A.colnames()) == set(required_columns)

    # Check POLARIZATION table correctly created
    with pt.table(pol_table_name, ack=False) as P:
        for r, corr_type in enumerate(corr_types):
            assert_array_equal(P.getcol("CORR_TYPE", startrow=r, nrow=1),
                               [corr_type])
            assert_array_equal(P.getcol("NUM_CORR", startrow=r, nrow=1),
                               [len(corr_type)])

        required_desc = pt.required_ms_desc("POLARIZATION")
        required_columns = set(k for k in required_desc.keys()
                               if not k.startswith("_"))

        assert set(P.colnames()) == set(required_columns)

    # Check SPECTRAL_WINDOW table correctly created
    with pt.table(spw_table_name, ack=False) as S:
        for r, num_chan in enumerate(num_chans):
            assert_array_equal(
                S.getcol("NUM_CHAN", startrow=r, nrow=1)[0], num_chan)
            assert_array_equal(
                S.getcol("CHAN_FREQ", startrow=r, nrow=1)[0],
                np.linspace(.856e9, 2 * .856e9, num_chan))
            assert_array_equal(
                S.getcol("CHAN_WIDTH", startrow=r, nrow=1)[0],
                np.full(num_chan, .856e9 / num_chan))

        required_desc = pt.required_ms_desc("SPECTRAL_WINDOW")
        required_columns = set(k for k in required_desc.keys()
                               if not k.startswith("_"))

        assert set(S.colnames()) == set(required_columns)

    # We should get a cartesian product out
    with pt.table(ddid_table_name, ack=False) as D:
        spw_id, pol_id = zip(
            *product(range(len(num_chans)), range(len(corr_types))))
        assert_array_equal(pol_id, D.getcol("POLARIZATION_ID"))
        assert_array_equal(spw_id, D.getcol("SPECTRAL_WINDOW_ID"))

        required_desc = pt.required_ms_desc("DATA_DESCRIPTION")
        required_columns = set(k for k in required_desc.keys()
                               if not k.startswith("_"))

        assert set(D.colnames()) == set(required_columns)

    with pt.table(src_table_name, ack=False) as S:
        for r, (name, direction, rest_freq) in enumerate(sources):
            assert_array_equal(S.getcol("NAME", startrow=r, nrow=1)[0], [name])
            assert_array_equal(S.getcol("REST_FREQUENCY", startrow=r, nrow=1),
                               [rest_freq])
            assert_array_equal(S.getcol("DIRECTION", startrow=r, nrow=1),
                               [direction])

    with pt.table(ms_table_name, ack=False) as T:
        # DATA_DESC_ID's are all the same shape
        assert_array_equal(T.getcol("DATA_DESC_ID"),
                           da.concatenate(all_data_desc_id))

        # DATA is variably shaped (on DATA_DESC_ID) so we
        # compare each one separately.
        for ddid, data in enumerate(all_data):
            ms_data = T.getcol("DATA", startrow=ddid * row, nrow=row)
            assert_array_equal(ms_data, data)

        required_desc = pt.required_ms_desc()
        required_columns = set(k for k in required_desc.keys()
                               if not k.startswith("_"))

        # Check we have the required columns
        assert set(T.colnames()) == required_columns.union(
            ["DATA", "DATA_DESC_ID"])
Example #9
def output_dataset(avg,
                   field_id,
                   data_desc_id,
                   scan_number,
                   group_row_chunks,
                   viscolumn=""):
    """
    Parameters
    ----------
    avg : namedtuple
        Result of :func:`average`
    field_id : int
        FIELD_ID for this averaged data
    data_desc_id : int
        DATA_DESC_ID for this averaged data
    scan_number : int
        SCAN_NUMBER for this averaged data
    group_row_chunks : int
        Number of row chunks to concatenate together in the output
    viscolumn : str
        Name of the column in which to store the averaged visibilities

    Returns
    -------
    Dataset
        Dataset containing averaged data
    """
    # Create ID columns
    fid = field_id
    ddid = data_desc_id
    scn = scan_number
    field_id = id_full_like(avg.time, fill_value=field_id)
    data_desc_id = id_full_like(avg.time, fill_value=data_desc_id)
    scan_number = id_full_like(avg.time, fill_value=scan_number)

    # Single flag category, equal to flags
    flag_cats = avg.flag[:, None, :, :]

    out_ds = {
        "ANTENNA1": (("row", ), avg.antenna1),
        "ANTENNA2": (("row", ), avg.antenna2),
        "DATA_DESC_ID": (("row", ), data_desc_id),
        "FIELD_ID": (("row", ), field_id),
        "SCAN_NUMBER": (("row", ), scan_number),
        "FLAG_ROW": (("row", ), avg.flag_row),
        # "FLAG_CATEGORY": (("row", "flagcat", "chan", "corr"), flag_cats),
        "TIME": (("row", ), avg.time),
        "INTERVAL": (("row", ), avg.interval),
        "TIME_CENTROID": (("row", ), avg.time_centroid),
        "EXPOSURE": (("row", ), avg.exposure),
        "UVW": (("row", "uvw"), avg.uvw),
        "WEIGHT": (("row", "corr"), avg.weight),
        "SIGMA": (("row", "corr"), avg.sigma),
        viscolumn: (("row", "chan", "corr"), avg.vis),
        "FLAG": (("row", "chan", "corr"), avg.flag),
    }

    # Add optionally averaged columns
    if avg.weight_spectrum is not None:
        out_ds["WEIGHT_SPECTRUM"] = (("row", "chan", "corr"),
                                     avg.weight_spectrum)

    if avg.sigma_spectrum is not None:
        out_ds["SIGMA_SPECTRUM"] = (("row", "chan", "corr"),
                                    avg.sigma_spectrum)

    # Concatenate row chunks together
    if group_row_chunks > 1:
        grc = group_row_chunks
        # Remove items whose values are None
        out_ds = {
            k: (dims, concatenate_row_chunks(data, group_every=grc))
            for k, (dims, data) in out_ds.items() if data is not None
        }

    return Dataset(out_ds,
                   attrs={
                       "DATA_DESC_ID": ddid,
                       "FIELD_ID": fid,
                       "SCAN_NUMBER": scn
                   })
Example #10
def ms_create(ms_table_name, info, ant_pos, cal_vis, timestamps, corr_types,
              sources):
    '''    "info": {
        "info": {
            "L0_frequency": 1571328000.0,
            "bandwidth": 2500000.0,
            "baseband_frequency": 4092000.0,
            "location": {
                "alt": 270.0,
                "lat": -45.85177,
                "lon": 170.5456
            },
            "name": "Signal Hill - Dunedin",
            "num_antenna": 24,
            "operating_frequency": 1575420000.0,
            "sampling_frequency": 16368000.0
        }
    },
    '''

    num_chans = [1]

    rs = np.random.RandomState(42)

    ant_table_name = "::".join((ms_table_name, "ANTENNA"))
    feed_table_name = "::".join((ms_table_name, "FEED"))
    field_table_name = "::".join((ms_table_name, "FIELD"))
    obs_table_name = "::".join((ms_table_name, "OBSERVATION"))
    ddid_table_name = "::".join((ms_table_name, "DATA_DESCRIPTION"))
    pol_table_name = "::".join((ms_table_name, "POLARIZATION"))
    spw_table_name = "::".join((ms_table_name, "SPECTRAL_WINDOW"))
    # SOURCE is an optional MS sub-table
    src_table_name = "::".join((ms_table_name, "SOURCE"))

    ms_datasets = []
    ant_datasets = []
    feed_datasets = []
    field_datasets = []
    obs_datasets = []
    ddid_datasets = []
    pol_datasets = []
    spw_datasets = []
    src_datasets = []

    # Create ANTENNA dataset
    # Each column in the ANTENNA has a fixed shape so we
    # can represent all rows with one dataset
    na = len(ant_pos)
    position = da.asarray(ant_pos)  # da.zeros((na, 3))
    diameter = da.ones(na) * 0.025
    offset = da.zeros((na, 3))
    names = np.array(['ANTENNA-%d' % i for i in range(na)], dtype=np.object)
    stations = np.array([info['name'] for i in range(na)], dtype=np.object)

    ds = Dataset({
        'POSITION': (("row", "xyz"), position),
        'OFFSET': (("row", "xyz"), offset),
        'DISH_DIAMETER': (("row", ), diameter),
        'NAME': (("row", ), da.from_array(names, chunks=na)),
        'STATION': (("row", ), da.from_array(stations, chunks=na)),
    })
    ant_datasets.append(ds)

    #########################################  Create a FEED dataset. #######################################
    # There is one feed per antenna, so this should be quite similar to the ANTENNA table.
    antenna_ids = da.asarray(range(na))
    feed_ids = da.zeros(na)
    num_receptors = da.ones(na)
    polarization_types = np.array([['XX'] for i in range(na)], dtype=np.object)
    receptor_angles = np.array([[0.0] for i in range(na)])
    pol_response = np.array([[[0.0 + 1.0j, 1.0 - 1.0j]] for i in range(na)])

    beam_offset = np.array([[[0.0, 0.0]] for i in range(na)])

    ds = Dataset({
        'ANTENNA_ID': (("row", ), antenna_ids),
        'FEED_ID': (("row", ), feed_ids),
        'NUM_RECEPTORS': (("row", ), num_receptors),
        'POLARIZATION_TYPE': ((
            "row",
            "receptors",
        ), da.from_array(polarization_types, chunks=na)),
        'RECEPTOR_ANGLE': ((
            "row",
            "receptors",
        ), da.from_array(receptor_angles, chunks=na)),
        'POL_RESPONSE': (("row", "receptors", "receptors-2"),
                         da.from_array(pol_response, chunks=na)),
        'BEAM_OFFSET':
        (("row", "receptors", "radec"), da.from_array(beam_offset, chunks=na)),
    })
    feed_datasets.append(ds)

    ########################################### FIELD dataset ################################################
    direction = [[np.radians(90.0), np.radians(0.0)]]  ## Phase Center in J2000
    field_direction = da.asarray(direction)[None, :]
    field_name = da.asarray(np.asarray(['up'], dtype=np.object))
    field_num_poly = da.zeros(
        1)  # Zero order polynomial in time for phase center.

    dir_dims = (
        "row",
        'field-poly',
        'field-dir',
    )

    ds = Dataset({
        'PHASE_DIR': (dir_dims, field_direction),
        'DELAY_DIR': (dir_dims, field_direction),
        'REFERENCE_DIR': (dir_dims, field_direction),
        'NUM_POLY': (("row", ), field_num_poly),
        'NAME': (("row", ), field_name),
    })
    field_datasets.append(ds)

    ########################################### OBSERVATION dataset ################################################

    ds = Dataset({
        'TELESCOPE_NAME':
        (("row", ), da.asarray(np.asarray(['TART'], dtype=np.object))),
        'OBSERVER': (("row", ), da.asarray(np.asarray(['Tim'],
                                                      dtype=np.object))),
    })
    obs_datasets.append(ds)

    #################################### Create SOURCE datasets #############################################
    for s, src in enumerate(sources):
        name = src['name']
        rest_freq = [info['operating_frequency']]
        direction = [
            np.radians(src['el']),
            np.radians(src['az'])
        ]  ## FIXME these are in elevation and azimuth. Not in J2000.

        #logger.info("SOURCE: {}, timestamp: {}".format(name, timestamps))
        dask_num_lines = da.full((1, ), len(rest_freq), dtype=np.int32)
        dask_direction = da.asarray(direction)[None, :]
        dask_rest_freq = da.asarray(rest_freq)[None, :]
        dask_name = da.asarray(np.asarray([name], dtype=np.object))
        dask_time = da.asarray(np.asarray([timestamps], dtype=np.object))
        ds = Dataset({
            "NUM_LINES": (("row", ), dask_num_lines),
            "NAME": (("row", ), dask_name),
            #"TIME": (("row",), dask_time), # FIXME. Causes an error. Need to sort out TIME data fields
            #"REST_FREQUENCY": (("row", "line"), dask_rest_freq),
            "DIRECTION": (("row", "dir"), dask_direction),
        })
        src_datasets.append(ds)

    # Create POLARISATION datasets.
    # Dataset per output row required because column shapes are variable
    for r, corr_type in enumerate(corr_types):
        dask_num_corr = da.full((1, ), len(corr_type), dtype=np.int32)
        dask_corr_type = da.from_array(corr_type,
                                       chunks=len(corr_type))[None, :]
        ds = Dataset({
            "NUM_CORR": (("row", ), dask_num_corr),
            #"CORR_PRODUCT": (("row",), dask_num_corr),
            "CORR_TYPE": (("row", "corr"), dask_corr_type),
        })

        pol_datasets.append(ds)

    # Create multiple SPECTRAL_WINDOW datasets
    # Dataset per output row required because column shapes are variable

    for num_chan in num_chans:
        dask_num_chan = da.full((1, ), num_chan, dtype=np.int32)
        dask_chan_freq = da.asarray([[info['operating_frequency']]])
        dask_chan_width = da.full((1, num_chan), 2.5e6 / num_chan)

        ds = Dataset({
            "NUM_CHAN": (("row", ), dask_num_chan),
            "CHAN_FREQ": (("row", "chan"), dask_chan_freq),
            "CHAN_WIDTH": (("row", "chan"), dask_chan_width),
        })

        spw_datasets.append(ds)

    # For each cartesian product of SPECTRAL_WINDOW and POLARIZATION
    # create a corresponding DATA_DESCRIPTION.
    # Each column has fixed shape so we handle all rows at once
    spw_ids, pol_ids = zip(
        *product(range(len(num_chans)), range(len(corr_types))))
    dask_spw_ids = da.asarray(np.asarray(spw_ids, dtype=np.int32))
    dask_pol_ids = da.asarray(np.asarray(pol_ids, dtype=np.int32))
    ddid_datasets.append(
        Dataset({
            "SPECTRAL_WINDOW_ID": (("row", ), dask_spw_ids),
            "POLARIZATION_ID": (("row", ), dask_pol_ids),
        }))

    # Now create the associated MS dataset

    vis_data, baselines = cal_vis.get_all_visibility()
    vis_array = np.array(vis_data, dtype=np.complex64)
    chunks = {
        "row": (vis_array.shape[0], ),
    }
    baselines = np.array(baselines)
    bl_pos = np.array(ant_pos)[baselines]
    uu_a, vv_a, ww_a = -(
        bl_pos[:, 1] - bl_pos[:, 0]
    ).T / constants.L1_WAVELENGTH  # Use the - sign to get the same orientation as our tart projections.

    uvw_array = np.array([uu_a, vv_a, ww_a]).T

    for ddid, (spw_id, pol_id) in enumerate(zip(spw_ids, pol_ids)):
        # Infer row, chan and correlation shape
        logger.info("ddid:{} ({}, {})".format(ddid, spw_id, pol_id))
        row = sum(chunks['row'])
        chan = spw_datasets[spw_id].CHAN_FREQ.shape[1]
        corr = pol_datasets[pol_id].CORR_TYPE.shape[1]

        # Create some dask vis data
        dims = ("row", "chan", "corr")
        logger.info("Data size {}".format((row, chan, corr)))

        np_data = vis_array.reshape((row, chan, corr))
        np_uvw = uvw_array.reshape((row, 3))

        data_chunks = tuple((chunks['row'], chan, corr))
        dask_data = da.from_array(np_data, chunks=data_chunks)

        uvw_data = da.from_array(np_uvw)
        # Create dask ddid column
        dask_ddid = da.full(row, ddid, chunks=chunks['row'], dtype=np.int32)
        dataset = Dataset({
            'DATA': (dims, dask_data),
            'WEIGHT':
            (("row", "corr"), da.from_array(0.95 * np.ones((row, corr)))),
            'WEIGHT_SPECTRUM':
            (dims,
             da.from_array(0.95 * np.ones_like(np_data, dtype=np.float64))),
            'SIGMA_SPECTRUM':
            (dims,
             da.from_array(np.ones_like(np_data, dtype=np.float64) * 0.05)),
            'UVW': ((
                "row",
                "uvw",
            ), uvw_data),
            'ANTENNA1': (("row", ), da.from_array(baselines[:, 0])),
            'ANTENNA2': (("row", ), da.from_array(baselines[:, 1])),
            'FEED1': (("row", ), da.from_array(baselines[:, 0])),
            'FEED2': (("row", ), da.from_array(baselines[:, 1])),
            'DATA_DESC_ID': (("row", ), dask_ddid)
        })
        ms_datasets.append(dataset)

    ms_writes = xds_to_table(ms_datasets, ms_table_name, columns="ALL")
    ant_writes = xds_to_table(ant_datasets, ant_table_name, columns="ALL")
    feed_writes = xds_to_table(feed_datasets, feed_table_name, columns="ALL")
    field_writes = xds_to_table(field_datasets,
                                field_table_name,
                                columns="ALL")
    obs_writes = xds_to_table(obs_datasets, obs_table_name, columns="ALL")
    pol_writes = xds_to_table(pol_datasets, pol_table_name, columns="ALL")
    spw_writes = xds_to_table(spw_datasets, spw_table_name, columns="ALL")
    ddid_writes = xds_to_table(ddid_datasets, ddid_table_name, columns="ALL")
    source_writes = xds_to_table(src_datasets, src_table_name, columns="ALL")

    dask.compute(ms_writes)
    dask.compute(ant_writes)
    dask.compute(feed_writes)
    dask.compute(field_writes)
    dask.compute(obs_writes)
    dask.compute(pol_writes)
    dask.compute(spw_writes)
    dask.compute(ddid_writes)
    dask.compute(source_writes)
Example #11
def output_dataset(avg, field_id, data_desc_id, scan_number,
                   group_row_chunks):
    """
    Parameters
    ----------
    avg : namedtuple
        Result of :func:`average`
    field_id : int
        FIELD_ID for this averaged data
    data_desc_id : int
        DATA_DESC_ID for this averaged data
    scan_number : int
        SCAN_NUMBER for this averaged data
    group_row_chunks : int
        Number of row chunks to concatenate together in the output

    Returns
    -------
    Dataset
        Dataset containing averaged data
    """
    # Create ID columns
    field_id = id_full_like(avg.time, fill_value=field_id)
    data_desc_id = id_full_like(avg.time, fill_value=data_desc_id)
    scan_number = id_full_like(avg.time, fill_value=scan_number)

    flag_cats = flag_categories(avg.flag)

    zeros = id_full_like(avg.antenna1, fill_value=0)

    out_ds = {
        # Explicitly zero these columns? But this happens anyway
        "ARRAY_ID": (("row",), zeros),
        "OBSERVATION_ID": (("row",), zeros),
        "PROCESSOR_ID": (("row",), zeros),
        "STATE_ID": (("row",), zeros),

        "ANTENNA1": (("row",), avg.antenna1),
        "ANTENNA2": (("row",), avg.antenna2),
        "DATA_DESC_ID": (("row",), data_desc_id),
        "FIELD_ID": (("row",), field_id),
        "SCAN_NUMBER": (("row",), scan_number),

        "FLAG_ROW": (("row",), avg.flag_row),
        "FLAG_CATEGORY": (("row", "flagcat", "chan", "corr"), flag_cats),
        "TIME": (("row",), avg.time),
        "INTERVAL": (("row",), avg.interval),
        "TIME_CENTROID": (("row",), avg.time_centroid),
        "EXPOSURE": (("row",), avg.exposure),
        "UVW": (("row", "[uvw]"), avg.uvw),
        "WEIGHT": (("row", "corr"), avg.weight),
        "SIGMA": (("row", "corr"), avg.sigma),

#        "DATA": (("row", "chan", "corr"), avg.vis),
        "FLAG": (("row", "chan", "corr"), avg.flag),
    }

    if avg.visibilities is not None:
        if isinstance(avg.visibilities, dict):
            for column, data in avg.visibilities.items():
                out_ds[column] = (("row", "chan", "corr"), data)
        elif isinstance(avg.visibilities, da.Array):
            out_ds["DATA"] = (("row", "chan", "corr"), avg.visibilities)
        else:
            raise TypeError(f"Unknown visibility type {type(avg.visibilities)}")

    if hasattr(avg, "offsets"):
        num_chan = da.map_blocks(np.diff, avg.offsets)
        out_ds['NUM_CHAN'] = (("row",), num_chan)

    if hasattr(avg, "decorr_chan_width"):
        out_ds['DECORR_CHAN_WIDTH'] = (("row",), avg.decorr_chan_width)

    # Add optionally averaged columns
    if avg.weight_spectrum is not None:
        out_ds['WEIGHT_SPECTRUM'] = (("row", "chan", "corr"),
                                     avg.weight_spectrum)

    if avg.sigma_spectrum is not None:
        out_ds['SIGMA_SPECTRUM'] = (("row", "chan", "corr"),
                                    avg.sigma_spectrum)

    # Concatenate row chunks together
    if group_row_chunks > 1:
        grc = group_row_chunks
        out_ds = {k: (dims, row_concatenate(data, group_every=grc))
                  for k, (dims, data) in out_ds.items()}

    return Dataset(out_ds)
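

# NOTE: id_full_like and flag_categories are helpers defined elsewhere in the
# module and are not shown in this example. A minimal sketch of what a helper
# like id_full_like might look like (an assumption based on its usage above,
# not the original implementation): it builds a constant ID column with the
# same shape and row chunking as a template column such as avg.time.
import dask.array as da
import numpy as np


def id_full_like(template, fill_value, dtype=np.int32):
    # Constant per-row ID column (e.g. FIELD_ID or SCAN_NUMBER) that
    # preserves the template dask array's chunking
    return da.full_like(template, fill_value, dtype=dtype)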
Example #12
def bda_average_spw(out_datasets, ddid_ds, spw_ds):
    """
    Parameters
    ----------
    out_datasets : list of Datasets
        list of Datasets
    ddid_ds : Dataset
        DATA_DESCRIPTION dataset
    spw_ds : list of Datasets
        list of Datasets, each describing a single Spectral Window

    Returns
    -------
    output_ds : list of Datasets
        list of Datasets with remapped DATA_DESC_ID's
    spw_ds : list of Datasets
        list of Datasets, each describing an averaged Spectral Window
    ddid_ds : Dataset
        DATA_DESCRIPTION dataset mapping the averaged Spectral Windows
        to Polarizations
    """

    channelisations = []

    # Over the entire set of datasets, determine the complete
    # set of channelisations per input DDID and
    # reduce them down to a single object
    for out_ds in out_datasets:
        transform = da.blockwise(_channelisations, ("row",),
                                 out_ds.DATA_DESC_ID.data, ("row",),
                                 out_ds.NUM_CHAN.data, ("row",),
                                 ddid_ds.SPECTRAL_WINDOW_ID.data, ("ddid",),
                                 ddid_ds.POLARIZATION_ID.data, ("ddid",),
                                 meta=np.empty((0,), dtype=np.object))

        result = da.reduction(transform,
                              chunk=_noop,
                              combine=combine,
                              aggregate=combine,
                              concatenate=False,
                              keepdims=True,
                              meta=np.empty((0,), dtype=np.object),
                              dtype=np.object)

        channelisations.append(result)

    # Final reduction object, note the aggregate method
    # which generates the mapping
    ddid_chan_map = da.reduction(da.concatenate(channelisations),
                                 chunk=_noop,
                                 combine=combine,
                                 aggregate=aggregate,
                                 concatenate=False,
                                 keepdims=False,
                                 meta=np.empty((), dtype=np.object),
                                 dtype=np.object)

    def _squeeze_tuplify(*args):
        return tuple(a.squeeze() for a in args)

    chan_freqs = da.blockwise(_squeeze_tuplify, ("row", "chan"),
                              *(a for spw in spw_ds for a
                                in (spw.CHAN_FREQ.data, ("row", "chan"))),
                              concatenate=False,
                              align_arrays=False,
                              adjust_chunks={"chan": lambda c: np.nan},
                              meta=np.empty((0, 0), dtype=np.object))

    chan_widths = da.blockwise(_squeeze_tuplify, ("row", "chan"),
                               *(a for spw in spw_ds for a
                                 in (spw.CHAN_WIDTH.data, ("row", "chan"))),
                               concatenate=False,
                               align_arrays=False,
                               adjust_chunks={"chan": lambda c: np.nan},
                               meta=np.empty((0, 0), dtype=np.object))

    ref_freqs = da.blockwise(_squeeze_tuplify, ("row",),
                             *(a for spw in spw_ds for a
                               in (spw.REF_FREQUENCY.data, ("row",))),
                             concatenate=False,
                             align_arrays=False,
                             meta=np.empty((0,), dtype=np.object))

    meas_freq_refs = da.blockwise(_squeeze_tuplify, ("row",),
                                  *(a for spw in spw_ds for a
                                    in (spw.MEAS_FREQ_REF.data, ("row",))),
                                  concatenate=False,
                                  align_arrays=False,
                                  meta=np.empty((0,), dtype=np.object))

    result = da.blockwise(ddid_and_spw_factory, ("row", "chan"),
                          chan_freqs, ("row", "chan"),
                          chan_widths, ("row", "chan"),
                          ref_freqs, ("row",),
                          meas_freq_refs, ("row",),
                          ddid_chan_map, (),
                          meta=np.empty((0, 0), dtype=np.object))

    # There should only be one chunk
    assert result.npartitions == 1

    chan_freq = da.blockwise(getitem, ("row", "chan"),
                             result, ("row", "chan"),
                             0, None,
                             dtype=np.float64)

    chan_width = da.blockwise(getitem, ("row", "chan"),
                              result, ("row", "chan"),
                              1, None,
                              dtype=np.float64)

    num_chan = da.blockwise(lambda d, i: d[0][i], ("row",),
                            result, ("row", "chan"),
                            2, None,
                            dtype=np.int32)

    ref_freq = da.blockwise(lambda d, i: d[0][i], ("row",),
                            result, ("row", "chan"),
                            3, None,
                            dtype=np.float64)

    meas_freq_refs = da.blockwise(lambda d, i: d[0][i], ("row",),
                                  result, ("row", "chan"),
                                  4, None,
                                  dtype=np.float64)

    total_bw = da.blockwise(lambda d, i: d[0][i], ("row",),
                            result, ("row", "chan"),
                            5, None,
                            dtype=np.float64)

    spectral_window_id = da.blockwise(lambda d, i: d[0][i], ("row",),
                                      result, ("row", "chan"),
                                      6, None,
                                      dtype=np.int32)

    polarization_id = da.blockwise(lambda d, i: d[0][i], ("row",),
                                   result, ("row", "chan"),
                                   7, None,
                                   dtype=np.int32)

    ddid_map = da.blockwise(lambda d, i: d[0][i], ("row",),
                            result, ("row", "chan"),
                            8, None,
                            dtype=np.int32)

    for o, out_ds in enumerate(out_datasets):
        data_desc_id = da.blockwise(_new_ddids, ("row",),
                                    out_ds.DATA_DESC_ID.data, ("row",),
                                    out_ds.NUM_CHAN.data, ("row",),
                                    ddid_map, ("ddid",),
                                    dtype=out_ds.DATA_DESC_ID.dtype)

        dv = dict(out_ds.data_vars)
        dv["DATA_DESC_ID"] = (("row",), data_desc_id)
        del dv["NUM_CHAN"]
        del dv["DECORR_CHAN_WIDTH"]

        out_datasets[o] = Dataset(dv, out_ds.coords, out_ds.attrs)

    out_spw_ds = Dataset({
        "CHAN_FREQ": (("row", "chan"), chan_freq),
        "CHAN_WIDTH": (("row", "chan"), chan_width),
        "EFFECTIVE_BW": (("row", "chan"), chan_width),
        "RESOLUTION": (("row", "chan"), chan_width),
        "NUM_CHAN": (("row",), num_chan),
        "REF_FREQUENCY": (("row",), ref_freq),
        "TOTAL_BANDWIDTH": (("row",), total_bw)
    })

    out_ddid_ds = Dataset({
        "SPECTRAL_WINDOW_ID": (("row",), spectral_window_id),
        "POLARIZATION_ID": (("row",), polarization_id),
    })

    return out_datasets, [out_spw_ds], out_ddid_ds
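

# NOTE: _channelisations, _noop, combine and aggregate are module-level
# helpers that are not shown in this example. The pattern they rely on,
# reducing object-dtype blocks with da.reduction and concatenate=False, can
# be illustrated with the following self-contained sketch (an illustration
# of the same pattern, not the original helpers), which collects the unique
# values of an array into a single Python set:
import numpy as np
import dask.array as da


def _to_set(block, axis=None, keepdims=None):
    # Per-block step: wrap the block's unique values in a 0-d object array
    # so they can be carried between the reduction stages
    out = np.empty((), dtype=object)
    out[()] = set(block.ravel().tolist())
    return out


def _merge_sets(blocks, axis=None, keepdims=None):
    # With concatenate=False the previous stage's results arrive as a
    # (possibly nested) list of 0-d object arrays; merge their sets
    def flatten(obj):
        if isinstance(obj, list):
            for item in obj:
                yield from flatten(item)
        else:
            yield obj

    merged = set()
    for b in flatten(blocks):
        merged |= b[()]

    out = np.empty((), dtype=object)
    out[()] = merged
    return out


values = da.from_array(np.array([0, 1, 1, 2, 0, 3]), chunks=2)
unique_values = da.reduction(values,
                             chunk=_to_set,
                             combine=_merge_sets,
                             aggregate=_merge_sets,
                             concatenate=False,
                             keepdims=False,
                             meta=np.empty((), dtype=object),
                             dtype=object)
print(unique_values.compute())  # prints the merged set, e.g. {0, 1, 2, 3}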