Exemple #1
def test_keyword_write(ms):
    """Check that table and column keywords can be added and then removed.

    ``ms`` is passed to ``xds_from_ms``, ``xds_to_table`` and ``pt.table``
    as the table to operate on. After each ``xds_to_table`` call the table
    is reopened read-only and its keywords are inspected directly.
    """
    datasets = xds_from_ms(ms)

    # Add to table keywords
    # NOTE(review): the value returned by xds_to_table is never used in
    # this test; confirm whether a dask.compute(writes) belongs after each
    # call.
    writes = xds_to_table([], ms, [], table_keywords={'bob': 'qux'})

    with pt.table(ms, ack=False, readonly=True) as T:
        assert T.getkeywords()['bob'] == 'qux'

    # Add to column keywords
    writes = xds_to_table(datasets,
                          ms, [],
                          column_keywords={'STATE_ID': {
                              'bob': 'qux'
                          }})

    with pt.table(ms, ack=False, readonly=True) as T:
        assert T.getcolkeywords("STATE_ID")['bob'] == 'qux'

    # Remove from column and table keywords
    # DELKW is the value passed in place of a keyword's content to request
    # its removal.
    from daskms.writes import DELKW
    writes = xds_to_table(datasets,
                          ms, [],
                          table_keywords={'bob': DELKW},
                          column_keywords={'STATE_ID': {
                              'bob': DELKW
                          }})

    with pt.table(ms, ack=False, readonly=True) as T:
        assert 'bob' not in T.getkeywords()
        assert 'bob' not in T.getcolkeywords("STATE_ID")
Exemple #2
    # Entry point of the averaging application: parses the command line,
    # builds averaged main-table and SPECTRAL_WINDOW datasets, turns both
    # into xds_to_table writes, calls copy_subtables and hands the writes
    # to self._execute_graph.
    # NOTE(review): the calls to self._derive_row_chunking and average_main
    # are cut off in this copy (arguments and closing parentheses are
    # missing), so the method does not parse as shown.
    def execute(self):
        """ Execute the application """
        logger.info("xova {args}", args=" ".join(self.cmdline_args))

        # The parsed arguments are kept on the instance and as a local
        self.args = args = parse_args(self.cmdline_args)

        row_chunks, time_chunks, interval_chunks = self._derive_row_chunking(

        (main_ds, spw_ds, ddid_ds, field_ds,
         subtables) = self._input_datasets(args, row_chunks)

        # Set up Main MS data averaging
        main_ds = average_main(main_ds,

        main_writes = xds_to_table(main_ds, args.output, "ALL")

        # Set up SPW data averaging
        spw_ds = average_spw(spw_ds, args.chan_bin_size)
        # The subtable is addressed as "<output>::SPECTRAL_WINDOW"
        spw_table = "::".join((args.output, "SPECTRAL_WINDOW"))
        spw_writes = xds_to_table(spw_ds, spw_table, "ALL")

        copy_subtables(args.ms, args.output, subtables)

        self._execute_graph(main_writes, spw_writes)
Exemple #3
def test_ms_create_and_update(Dataset, tmp_path, chunks):
    """ Test that we can update and append at the same time """
    filename = str(tmp_path / "create-and-update.ms")

    # Fixed seed, so the random visibilities are reproducible
    rs = np.random.RandomState(42)

    # Create a dataset of 10 rows with DATA and DATA_DESC_ID
    # (each extent is the sum of the chunk sizes supplied for that dimension)
    dims = ("row", "chan", "corr")
    row, chan, corr = tuple(sum(chunks[d]) for d in dims)
    ms_datasets = []
    np_data = (rs.normal(size=(row, chan, corr)) +
               1j * rs.normal(size=(row, chan, corr))).astype(np.complex64)

    # Only the row axis keeps the requested chunking; chan and corr are
    # each a single chunk.
    data_chunks = tuple((chunks['row'], chan, corr))
    dask_data = da.from_array(np_data, chunks=data_chunks)
    # Create dask ddid column
    dask_ddid = da.full(row, 0, chunks=chunks['row'], dtype=np.int32)
    dataset = Dataset({
        'DATA': (dims, dask_data),
        'DATA_DESC_ID': (("row", ), dask_ddid),
    # NOTE(review): the Dataset({...}) literal above is left open in this
    # copy, and no visible statement adds the new dataset to ms_datasets.

    # Write it
    writes = xds_to_table(ms_datasets, filename, ["DATA", "DATA_DESC_ID"])

    # Re-read what is on disk
    ms_datasets = xds_from_ms(filename)

    # Now add another dataset (different DDID), with no ROWID
    np_data = (rs.normal(size=(row, chan, corr)) +
               1j * rs.normal(size=(row, chan, corr))).astype(np.complex64)
    data_chunks = tuple((chunks['row'], chan, corr))
    dask_data = da.from_array(np_data, chunks=data_chunks)
    # Create dask ddid column
    dask_ddid = da.full(row, 1, chunks=chunks['row'], dtype=np.int32)
    dataset = Dataset({
        'DATA': (dims, dask_data),
        'DATA_DESC_ID': (("row", ), dask_ddid),
    # NOTE(review): same truncation as for the first dataset — the literal
    # is not closed and the dataset is not appended in the visible lines.

    # Write it
    writes = xds_to_table(ms_datasets, filename, ["DATA", "DATA_DESC_ID"])

    # Rows have been added and additional data is present
    with pt.table(filename, ack=False, readonly=True) as T:
        # NOTE(review): the da.full(...) call below is cut off after its
        # first argument.
        first_data_desc_id = da.full(row,
        ds_data = da.concatenate(
            [ms_datasets[0].DATA.data, ms_datasets[1].DATA.data])
        ds_ddid = da.concatenate(
            [first_data_desc_id, ms_datasets[1].DATA_DESC_ID.data])
        # The table columns must equal both datasets stacked in order
        assert_array_equal(T.getcol("DATA"), ds_data)
        assert_array_equal(T.getcol("DATA_DESC_ID"), ds_ddid)
def test_write_table_proxy_keyword(ms):
    """Exercise the ``table_proxy`` switch of ``xds_to_table``.

    With ``table_proxy=True`` the result is unpacked into a list of
    datasets plus a ``TableProxy``; with ``table_proxy=False`` only the
    list of datasets is expected.
    """
    source_datasets = xds_from_ms(ms)

    # Ask for the proxy alongside the writes.
    with_proxy = xds_to_table(source_datasets, ms, [], table_proxy=True)
    write_list, proxy = with_proxy
    assert isinstance(write_list, list)
    assert isinstance(write_list[0], Dataset)
    assert isinstance(proxy, TableProxy)
    row_count = proxy.nrows().result()
    assert row_count == 10

    # Without the flag only the list of write datasets comes back.
    write_list = xds_to_table(source_datasets, ms, [], table_proxy=False)
    assert isinstance(write_list, list)
    assert isinstance(write_list[0], Dataset)
Exemple #5
# Predict model visibilities from a sky model and set up writes of them to
# the MODEL_DATA column of args.ms, one dataset per
# (FIELD_ID, DATA_DESC_ID) group.
# NOTE(review): this copy is incomplete — the support_tables(...) call and
# the source_vis list are left open, nothing visible appends `write` to
# `writes`, and the final ProgressBar block has no body.
def predict(args):
    # Convert source data into dask arrays
    sky_model = parse_sky_model(args.sky_model, args.model_chunks)

    # Get the support tables
    tables = support_tables(

    field_ds = tables["FIELD"]
    ddid_ds = tables["DATA_DESCRIPTION"]
    spw_ds = tables["SPECTRAL_WINDOW"]
    pol_ds = tables["POLARIZATION"]

    # List of write operations
    writes = []

    # Construct a graph for each DATA_DESC_ID
    for xds in xds_from_ms(args.ms,
                           columns=["UVW", "ANTENNA1", "ANTENNA2", "TIME"],
                           group_cols=["FIELD_ID", "DATA_DESC_ID"],
                           chunks={"row": args.row_chunks}):

        # Extract frequencies from the spectral window associated
        # with this data descriptor id
        field = field_ds[xds.attrs['FIELD_ID']]
        ddid = ddid_ds[xds.attrs['DATA_DESC_ID']]
        spw = spw_ds[ddid.SPECTRAL_WINDOW_ID.data[0]]
        pol = pol_ds[ddid.POLARIZATION_ID.data[0]]

        # Select single dataset row out
        corrs = pol.NUM_CORR.data[0]

        # Only the inverse mapping from da.unique is kept; it is passed on
        # to vis_factory as time_index.
        _, time_index = da.unique(xds.TIME.data, return_inverse=True)

        # Generate visibility expressions for each source type
        source_vis = [
            vis_factory(args, stype, sky_model, time_index, xds, field, spw,
                        pol) for stype in sky_model.keys()

        # Sum visibilities together
        vis = sum(source_vis)

        # Reshape (2, 2) correlation to shape (4,)
        if corrs == 4:
            vis = vis.reshape(vis.shape[:2] + (4, ))

        # Assign visibilities to MODEL_DATA array on the dataset
        xds = xds.assign(MODEL_DATA=(("row", "chan", "corr"), vis))
        # Create a write to the table
        write = xds_to_table(xds, args.ms, ['MODEL_DATA'])
        # Add to the list of writes

    # Submit all graph computations in parallel
    with ProgressBar():
Exemple #6
    # Build a model from the components `comps` at every channel frequency
    # of each spectral window, degrid it with im2vis and set up writes of
    # the result to self.model_column for every measurement set in self.ms.
    # NOTE(review): this copy is incomplete — the phase-centre check has no
    # body, the im2vis(...) and ds.assign(...) calls are left open, and
    # nothing visible appends out_ds to out_data.
    def write_component_model(self, comps, ref_freq, mask, row_chunks, chan_chunks):
        print("Writing model data at full freq resolution")
        order, npix = comps.shape
        # Both arrays are wrapped as single-chunk dask arrays
        comps = da.from_array(comps, chunks=(-1, -1))
        mask = da.from_array(mask.squeeze(), chunks=(-1, -1))
        writes  = []
        for ims in self.ms:
            xds = xds_from_ms(ims, group_cols=('FIELD_ID', 'DATA_DESC_ID'),
                              chunks={'row':(row_chunks,), 'chan':(chan_chunks,)},
                              columns=('MODEL_DATA', 'UVW'))

            # subtables
            ddids = xds_from_table(ims + "::DATA_DESCRIPTION")
            fields = xds_from_table(ims + "::FIELD", group_cols="__row__")
            spws = xds_from_table(ims + "::SPECTRAL_WINDOW", group_cols="__row__")
            pols = xds_from_table(ims + "::POLARIZATION", group_cols="__row__")

            # subtable data
            # (spws is not computed here; freq.chunks is read from it below)
            ddids = dask.compute(ddids)[0]
            fields = dask.compute(fields)[0]
            pols = dask.compute(pols)[0]

            out_data = []
            for ds in xds:
                field = fields[ds.FIELD_ID]
                radec = field.PHASE_DIR.data.squeeze()
                # Compare this field's phase direction with self.radec
                if not np.array_equal(radec, self.radec):

                # NOTE(review): spws is indexed with DATA_DESC_ID directly,
                # without going through the DATA_DESCRIPTION table — confirm
                # this is intended.
                spw = spws[ds.DATA_DESC_ID]
                freq = spw.CHAN_FREQ.data.squeeze()
                # One frequency bin per channel: start indices 0..size-1,
                # each with a count of one
                freq_bin_idx = da.arange(0, freq.size, 1, chunks=freq.chunks, dtype=np.int64)
                freq_bin_counts = da.ones(freq.size, chunks=freq.chunks, dtype=np.int64)

                uvw = ds.UVW.data

                model_vis = getattr(ds, 'MODEL_DATA').data

                model = model_from_comps(comps, freq, mask, ref_freq)
                vis = im2vis(uvw, freq, model, freq_bin_idx, freq_bin_counts,
                             self.cell, nthreads=self.nthreads, epsilon=self.epsilon,

                model_vis = populate_model(vis, model_vis)
                out_ds = ds.assign(**{self.model_column: (("row", "chan", "corr"),
            writes.append(xds_to_table(out_data, ims, columns=[self.model_column]))
        # All writes are executed with dask's single-threaded scheduler
        dask.compute(writes, scheduler='single-threaded')
Exemple #7
    # Degrid the image cube `x` with im2vis and set up writes of the
    # resulting visibilities to self.model_column for every measurement set
    # in self.ms.
    # NOTE(review): this copy is incomplete — the phase-centre check has no
    # body, the im2vis(...) and ds.assign(...) calls are left open, and
    # nothing visible appends out_ds to out_data.
    def write_model(self, x):
        print("Writing model data")
        # Cast to float32 and chunk the cube one plane at a time along its
        # first axis
        x = da.from_array(x.astype(np.float32), chunks=(1, self.nx, self.ny))
        writes  = []
        for ims in self.ms:
            xds = xds_from_ms(ims, group_cols=('FIELD_ID', 'DATA_DESC_ID'),
                              columns=('MODEL_DATA', 'UVW'))

            # subtables
            ddids = xds_from_table(ims + "::DATA_DESCRIPTION")
            fields = xds_from_table(ims + "::FIELD", group_cols="__row__")
            spws = xds_from_table(ims + "::SPECTRAL_WINDOW", group_cols="__row__")
            pols = xds_from_table(ims + "::POLARIZATION", group_cols="__row__")

            # subtable data
            ddids = dask.compute(ddids)[0]
            fields = dask.compute(fields)[0]
            spws = dask.compute(spws)[0]
            pols = dask.compute(pols)[0]

            out_data = []
            for ds in xds:
                field = fields[ds.FIELD_ID]
                radec = field.PHASE_DIR.data.squeeze()
                # Compare this field's phase direction with self.radec
                if not np.array_equal(radec, self.radec):

                spw = ds.DATA_DESC_ID  # this is not correct, need to use spw

                # Per-measurement-set, per-spw lookups held on the instance
                freq_bin_idx = self.freq_bin_idx[ims][spw]
                freq_bin_counts = self.freq_bin_counts[ims][spw]
                freq = self.freq[ims][spw]

                uvw = ds.UVW.data

                model_vis = getattr(ds, 'MODEL_DATA').data

                # Select the planes of x listed in band_mapping for this spw
                bands = self.band_mapping[ims][spw]
                model = x[list(bands), :, :]
                vis = im2vis(uvw, freq, model, freq_bin_idx, freq_bin_counts,
                             self.cell, nthreads=self.nthreads, epsilon=self.epsilon,

                model_vis = populate_model(vis, model_vis)
                out_ds = ds.assign(**{self.model_column: (("row", "chan", "corr"),
            writes.append(xds_to_table(out_data, ims, columns=[self.model_column]))
        # All writes are executed with dask's single-threaded scheduler
        dask.compute(writes, scheduler='single-threaded')
Exemple #8
    # Entry point of the averaging application with sub-commands: "check",
    # "timechannel" and "bda" are tested against args.command below.
    # NOTE(review): this copy is incomplete — several calls are cut off,
    # the `else:` lines that presumably guarded the `raise ValueError` and
    # `ddid_writes = None` statements are missing, and the trailing
    # `if not args.average_uvw_coordinates:` is followed only by the string
    # pieces of a message whose call is not visible.
    def execute(self):
        """ Execute the application """
        logger.info("xova version {v}", v=xova_version)
        logger.info("xova {args}", args=" ".join(self.cmdline_args))

        # The parsed arguments are kept on the instance and as a local
        self.args = args = parse_args(self.cmdline_args)

        if args.command == "check":
            logger.info("{ms} is conformant".format(ms=args.ms))


        row_chunks, time_chunks, interval_chunks = self._derive_row_chunking(

        (main_ds, spw_ds, ddid_ds, field_ds,
         subtables) = self._input_datasets(args, row_chunks)

        # Set up Main MS data averaging
        if args.command == "timechannel":
            output_ds = average_main(main_ds,

            spw_ds = average_spw(spw_ds, args.chan_bin_size)
        elif args.command == "bda":
            output_ds = bda_average_main(main_ds, field_ds, ddid_ds, spw_ds,

            # bda_average_spw also returns a new DATA_DESCRIPTION dataset
            output_ds, spw_ds, out_ddid_ds = bda_average_spw(
                output_ds, ddid_ds, spw_ds)
            raise ValueError("Invalid command %s" % args.command)

        main_writes = xds_to_table(output_ds,

        # The subtable is addressed as "<output>::SPECTRAL_WINDOW"
        spw_table = "::".join((args.output, "SPECTRAL_WINDOW"))
        spw_writes = xds_to_table(spw_ds, spw_table, "ALL")

        # DATA_DESCRIPTION writes are set up for the "bda" command
        if args.command == "bda":
            ddid_table = "::".join((args.output, "DATA_DESCRIPTION"))
            ddid_writes = xds_to_table(out_ddid_ds, ddid_table, "ALL")
            ddid_writes = None

        copy_subtables(args.ms, args.output, subtables)

        self._execute_graph(main_writes, spw_writes, ddid_writes)
        if not args.average_uvw_coordinates:
                "Applying approximation to uvw coordinates as you requested - "
                "the spatial frequencies of your long baseline data may be "
                "serverely affected!")
# Flag outliers against a model and optionally rescale the weights, one
# (FIELD_ID, DATA_DESC_ID) group at a time; when args.report_means is set,
# the mean whitened-residual amplitudes are recomputed from the output
# columns afterwards.
# NOTE(review): this copy has lost lines throughout — the docstring
# delimiters, several closing brackets, the bodies of the field-mismatch
# checks and of the ProgressBar block, and presumably an `else:` before
# `updated_weight = weights`. The comments below describe only what is
# visible.
def main(args):
    Flags outliers in data given a model and rescale weights so that whitened residuals have a
    mean amplitude of sqrt(2). 
    Flags and weights are computed per chunk of data
    radec_ref = None
    writes = []
    for ims in args.ms:
        xds = xds_from_ms(ims,
                          group_cols=('FIELD_ID', 'DATA_DESC_ID'),
                              "row": args.row_chunks,
                              "chan": args.chan_chunks
                          columns=('UVW', args.data_column, args.weight_column,
                                   args.model_column, args.flag_column,

        # subtables
        ddids = xds_from_table(ims + "::DATA_DESCRIPTION")
        fields = xds_from_table(ims + "::FIELD", group_cols="__row__")
        spws = xds_from_table(ims + "::SPECTRAL_WINDOW", group_cols="__row__")
        pols = xds_from_table(ims + "::POLARIZATION", group_cols="__row__")

        # subtable data
        ddids = dask.compute(ddids)[0]
        fields = dask.compute(fields)[0]
        spws = dask.compute(spws)[0]
        pols = dask.compute(pols)[0]

        out_data = []
        for ds in xds:
            field = fields[ds.FIELD_ID]
            radec = field.PHASE_DIR.data.squeeze()

            # check fields match
            # (the first phase direction seen becomes the reference)
            if radec_ref is None:
                radec_ref = radec

            if not np.array_equal(radec, radec_ref):

            # load in data and compute whitened residuals
            data = getattr(ds, args.data_column).data
            model = getattr(ds, args.model_column).data
            flag = getattr(ds, args.flag_column).data
            # A row flagged in FLAG_ROW is flagged for all of its entries
            flag = da.logical_or(flag, ds.FLAG_ROW.data[:, None, None])
            weights = getattr(ds, args.weight_column).data
            # Weights with fewer than three axes are broadcast; the target
            # shape argument is missing in this copy.
            if len(weights.shape) < 3:
                weights = da.broadcast_to(weights[:, None, :],

            if args.trim_channels:
                flag = trim_chans(flag, args.trim_channels)

            # Stokes I vis
            # Flagged entries get zero weight; the first and last entries of
            # the last axis are combined as a weighted mean where the weight
            # sum is non-zero.
            weights = (~flag) * weights
            resid_vis = (data - model) * weights
            wsums = (weights[:, :, 0] + weights[:, :, -1])
            resid_vis_I = da.where(
                wsums, (resid_vis[:, :, 0] + resid_vis[:, :, -1]) / wsums,

            # whiten and take abs
            white_resid = resid_vis_I * da.sqrt(wsums)
            abs_resid_vis_I = (white_resid).__abs__()

            # mean amp
            # (averaged over the entries with a positive weight sum)
            sum_amp = da.sum(abs_resid_vis_I)
            count = da.sum(wsums > 0)
            mean_amp = sum_amp / count

            # Flag where either of the two combined entries was already
            # flagged, or where the whitened residual exceeds
            # sigma_cut * mean_amp (second argument cut off in this copy).
            flag_legacy = flag[:, :, 0] | flag[:, :, -1]
            flag_I = da.logical_or(abs_resid_vis_I > args.sigma_cut * mean_amp,

            # new flags
            updated_flag = da.broadcast_to(flag_I[:, :, None],

            # scale weights (whitened residuals should have mean amplitude of 1/sqrt(2))
            # NOTE(review): the text at the top of the function says
            # sqrt(2), not 1/sqrt(2) — confirm which is intended.
            if args.scale_weights:
                # recompute mean amp with new flags
                weights = (~updated_flag) * weights
                resid_vis = (data - model) * weights
                wsums = (weights[:, :, 0] + weights[:, :, -1])
                resid_vis_I = da.where(
                    wsums, (resid_vis[:, :, 0] + resid_vis[:, :, -1]) / wsums,
                white_resid = resid_vis_I * da.sqrt(wsums)
                abs_resid_vis_I = (white_resid).__abs__()
                sum_amp = da.sum(abs_resid_vis_I)
                count = da.sum(wsums > 0)
                mean_amp = sum_amp / count
                updated_weight = 2**0.5 * weights / mean_amp**2
                updated_weight = weights

            # Attach the new weights and flags under the output column names
            ds = ds.assign(**{
                args.weight_out_column: (("row", "chan", "corr"),
            ds = ds.assign(**{
                args.flag_out_column: (("row", "chan", "corr"), updated_flag)

                columns=[args.flag_out_column, args.weight_out_column]))

    with ProgressBar():

    # report new mean amp
    if args.report_means:
        radec_ref = None
        mean_amps = []
        for ims in args.ms:
            # Same grouping as above, but reading the *output* weight and
            # flag columns
            xds = xds_from_ms(
                group_cols=('FIELD_ID', 'DATA_DESC_ID'),
                    "row": args.row_chunks,
                    "chan": args.chan_chunks
                columns=('UVW', args.data_column, args.weight_out_column,
                         args.model_column, args.flag_out_column, 'FLAG_ROW'))

            # subtables
            ddids = xds_from_table(ims + "::DATA_DESCRIPTION")
            fields = xds_from_table(ims + "::FIELD", group_cols="__row__")
            spws = xds_from_table(ims + "::SPECTRAL_WINDOW",
            pols = xds_from_table(ims + "::POLARIZATION", group_cols="__row__")

            # subtable data
            ddids = dask.compute(ddids)[0]
            fields = dask.compute(fields)[0]
            spws = dask.compute(spws)[0]
            pols = dask.compute(pols)[0]

            for ds in xds:
                field = fields[ds.FIELD_ID]
                radec = field.PHASE_DIR.data.squeeze()

                # check fields match
                if radec_ref is None:
                    radec_ref = radec

                if not np.array_equal(radec, radec_ref):

                # load in data and compute whitened residuals
                data = getattr(ds, args.data_column).data
                model = getattr(ds, args.model_column).data
                flag = getattr(ds, args.flag_out_column).data
                flag = da.logical_or(flag, ds.FLAG_ROW.data[:, None, None])
                weights = getattr(ds, args.weight_out_column).data
                if len(weights.shape) < 3:
                    weights = da.broadcast_to(weights[:, None, :],

                # Stokes I vis
                weights = (~flag) * weights
                resid_vis = (data - model) * weights
                wsums = (weights[:, :, 0] + weights[:, :, -1])
                resid_vis_I = da.where(
                    wsums, (resid_vis[:, :, 0] + resid_vis[:, :, -1]) / wsums,

                # whiten and take abs
                white_resid = resid_vis_I * da.sqrt(wsums)
                abs_resid_vis_I = (white_resid).__abs__()

                # mean amp
                sum_amp = da.sum(abs_resid_vis_I)
                count = da.sum(wsums > 0)
                mean_amps.append(sum_amp / count)

        # Evaluate all the lazily built means in one compute call
        mean_amps = dask.compute(mean_amps)[0]

Exemple #10
# Build the datasets for a Measurement Set and its ANTENNA, FEED, FIELD,
# OBSERVATION, SOURCE, POLARIZATION, SPECTRAL_WINDOW and DATA_DESCRIPTION
# subtables, then set up xds_to_table writes for the main, SPECTRAL_WINDOW
# and DATA_DESCRIPTION tables.
# NOTE(review): this copy has lost lines throughout — the docstring is
# never closed, every Dataset({...}) literal is left open, no visible
# statement adds the datasets to their tables/lists, and several calls are
# cut off. The comments below describe only what is visible.
def ms_create(ms_table_name, info, ant_pos, vis_array, baselines, timestamps, pol_feeds, sources):
    ''' Create a Measurement Set from some TART observations
    ms_table_name : string
        The name of the MS top level directory. I think this only workds in 
        the local directory.
    info : JSON
        "info": {
            "info": {
                "L0_frequency": 1571328000.0,
                "bandwidth": 2500000.0,
                "baseband_frequency": 4092000.0,
                "location": {
                    "alt": 270.0,
                    "lat": -45.85177,
                    "lon": 170.5456
                "name": "Signal Hill - Dunedin",
                "num_antenna": 24,
                "operating_frequency": 1575420000.0,
                "sampling_frequency": 16368000.0


    epoch_s = timestamp_to_ms_epoch(timestamps)
    LOGGER.info("Time {}".format(epoch_s))

    # NOTE(review): the two assignments below look like the two arms of a
    # missing branch (use info['location'] if present, otherwise info
    # itself) — confirm against the full source.
        loc = info['location']
        loc = info
    # Sort out the coordinate frames using astropy
    # https://casa.nrao.edu/casadocs/casa-5.4.1/reference-material/coordinate-frames
    iers.conf.iers_auto_url = 'https://astroconda.org/aux/astropy_mirror/iers_a_1/finals2000A.all' 
    iers.conf.auto_max_age = None 

    location = EarthLocation.from_geodetic(lon=loc['lon']*u.deg,
    obstime = Time(timestamps)
    local_frame = AltAz(obstime=obstime, location=location)

    # The phase centre is the point at alt=90 deg in the alt/az frame,
    # transformed to 'fk5'
    phase_altaz = SkyCoord(alt=90.0*u.deg, az=0.0*u.deg, obstime = obstime, frame = 'altaz', location = location)
    phase_j2000 = phase_altaz.transform_to('fk5')

    # Get the stokes enums for the polarization types
    corr_types = [[MS_STOKES_ENUMS[p_f] for p_f in pol_feeds]]

    LOGGER.info("Pol Feeds {}".format(pol_feeds))
    LOGGER.info("Correlation Types {}".format(corr_types))
    # A single spectral window with one channel
    num_freq_channels = [1]

    ant_table = MSTable(ms_table_name, 'ANTENNA')
    feed_table = MSTable(ms_table_name, 'FEED')
    field_table = MSTable(ms_table_name, 'FIELD')
    pol_table = MSTable(ms_table_name, 'POLARIZATION')
    obs_table = MSTable(ms_table_name, 'OBSERVATION')
    # SOURCE is an optional MS sub-table
    src_table = MSTable(ms_table_name, 'SOURCE')
    ddid_table_name = "::".join((ms_table_name, "DATA_DESCRIPTION"))
    spw_table_name = "::".join((ms_table_name, "SPECTRAL_WINDOW"))

    ms_datasets = []
    ddid_datasets = []
    spw_datasets = []

    # Create ANTENNA dataset
    # Each column in the ANTENNA has a fixed shape so we
    # can represent all rows with one dataset
    num_ant = len(ant_pos)
    position = da.asarray(ant_pos)
    diameter = da.ones(num_ant) * 0.025
    offset = da.zeros((num_ant, 3))
    # NOTE(review): the `np.object` alias used here and further down was
    # removed in NumPy 1.24; these lines raise AttributeError on newer
    # NumPy versions.
    names = np.array(['ANTENNA-%d' % i for i in range(num_ant)], dtype=np.object)
    stations = np.array([info['name'] for i in range(num_ant)], dtype=np.object)

    dataset = Dataset({
        'POSITION': (("row", "xyz"), position),
        'OFFSET': (("row", "xyz"), offset),
        'DISH_DIAMETER': (("row",), diameter),
        'NAME': (("row",), da.from_array(names, chunks=num_ant)),
        'STATION': (("row",), da.from_array(stations, chunks=num_ant)),

    ###################  Create a FEED dataset. ###################################
    # There is one feed per antenna, so this should be quite similar to the ANTENNA
    num_pols = len(pol_feeds)
    pol_types = pol_feeds
    pol_responses = [POL_RESPONSES[ct] for ct in pol_feeds]

    LOGGER.info("Pol Types {}".format(pol_types))
    LOGGER.info("Pol Responses {}".format(pol_responses))

    antenna_ids = da.asarray(range(num_ant))
    feed_ids = da.zeros(num_ant)
    num_receptors = da.zeros(num_ant) + num_pols
    polarization_types = np.array([pol_types for i in range(num_ant)], dtype=np.object)
    receptor_angles = np.array([[0.0] for i in range(num_ant)])
    pol_response = np.array([pol_responses for i in range(num_ant)])

    beam_offset = np.array([[[0.0, 0.0]] for i in range(num_ant)])

    dataset = Dataset({
        'ANTENNA_ID': (("row",), antenna_ids),
        'FEED_ID': (("row",), feed_ids),
        'NUM_RECEPTORS': (("row",), num_receptors),
        'POLARIZATION_TYPE': (("row", "receptors",),
                              da.from_array(polarization_types, chunks=num_ant)),
        'RECEPTOR_ANGLE': (("row", "receptors",),
                           da.from_array(receptor_angles, chunks=num_ant)),
        'POL_RESPONSE': (("row", "receptors", "receptors-2"),
                         da.from_array(pol_response, chunks=num_ant)),
        'BEAM_OFFSET': (("row", "receptors", "radec"),
                        da.from_array(beam_offset, chunks=num_ant)),

    ####################### FIELD dataset #########################################
    # One field, named 'up', whose phase/delay/reference directions are all
    # the phase centre computed above
    direction = [[phase_j2000.ra.radian, phase_j2000.dec.radian]]
    field_direction = da.asarray(direction)[None, :]
    field_name = da.asarray(np.asarray(['up'], dtype=np.object), chunks=1)
    field_num_poly = da.zeros(1) # Zero order polynomial in time for phase center.

    dir_dims = ("row", 'field-poly', 'field-dir',)

    dataset = Dataset({
        'PHASE_DIR': (dir_dims, field_direction),
        'DELAY_DIR': (dir_dims, field_direction),
        'REFERENCE_DIR': (dir_dims, field_direction),
        'NUM_POLY': (("row", ), field_num_poly),
        'NAME': (("row", ), field_name),

   ######################### OBSERVATION dataset #####################################

    # The time range runs from epoch_s to epoch_s + 1
    dataset = Dataset({
        'TELESCOPE_NAME': (("row",), da.asarray(np.asarray(['TART'], dtype=np.object), chunks=1)),
        'OBSERVER': (("row",), da.asarray(np.asarray(['Tim'], dtype=np.object), chunks=1)),
        "TIME_RANGE": (("row","obs-exts"), da.asarray(np.array([[epoch_s, epoch_s+1]]), chunks=1)),

    ######################## SOURCE datasets ########################################
    for src in sources:
        name = src['name']
        # Convert to J2000
        # (src['el'] and src['az'] are taken as degrees in the alt/az frame)
        dir_altaz = SkyCoord(alt=src['el']*u.deg, az=src['az']*u.deg, obstime = obstime,
                             frame = 'altaz', location = location)
        dir_j2000 = dir_altaz.transform_to('fk5')
        direction = [dir_j2000.ra.radian, dir_j2000.dec.radian]
        #LOGGER.info("SOURCE: {}, timestamp: {}".format(name, timestamps))
        dask_num_lines = da.full((1,), 1, dtype=np.int32)
        dask_direction = da.asarray(direction)[None, :]
        dask_name = da.asarray(np.asarray([name], dtype=np.object), chunks=1)
        dask_time = da.asarray(np.array([epoch_s]))
        dataset = Dataset({
            "NUM_LINES": (("row",), dask_num_lines),
            "NAME": (("row",), dask_name),
            "TIME": (("row",), dask_time),
            "DIRECTION": (("row", "dir"), dask_direction),

    # Create POLARISATION datasets.
    # Dataset per output row required because column shapes are variable

    for corr_type in corr_types:
        # Correlation i is paired with itself: [i, i]
        corr_prod = [[i, i] for i in range(len(corr_type))]

        corr_prod = np.array(corr_prod)
        LOGGER.info("Corr Prod {}".format(corr_prod))
        LOGGER.info("Corr Type {}".format(corr_type))

        dask_num_corr = da.full((1,), len(corr_type), dtype=np.int32)
        LOGGER.info("NUM_CORR {}".format(dask_num_corr))
        dask_corr_type = da.from_array(corr_type,
                                       chunks=len(corr_type))[None, :]
        dask_corr_product = da.asarray(corr_prod)[None, :]
        LOGGER.info("Dask Corr Prod {}".format(dask_corr_product.shape))
        LOGGER.info("Dask Corr Type {}".format(dask_corr_type.shape))
        dataset = Dataset({
            "NUM_CORR": (("row",), dask_num_corr),
            "CORR_TYPE": (("row", "corr"), dask_corr_type),
            "CORR_PRODUCT": (("row", "corr", "corrprod_idx"), dask_corr_product),


    # Create multiple SPECTRAL_WINDOW datasets
    # Dataset per output row required because column shapes are variable

    for num_chan in num_freq_channels:
        dask_num_chan = da.full((1,), num_chan, dtype=np.int32)
        dask_chan_freq = da.asarray([[info['operating_frequency']]])
        # NOTE(review): the width is hard-coded as 2.5e6 / num_chan rather
        # than read from info — confirm this is intended.
        dask_chan_width = da.full((1, num_chan), 2.5e6/num_chan)

        dataset = Dataset({
            "NUM_CHAN": (("row",), dask_num_chan),
            "CHAN_FREQ": (("row", "chan"), dask_chan_freq),
            "CHAN_WIDTH": (("row", "chan"), dask_chan_width),
            "EFFECTIVE_BW": (("row", "chan"), dask_chan_width),
            "RESOLUTION": (("row", "chan"), dask_chan_width),


    # For each cartesian product of SPECTRAL_WINDOW and POLARIZATION
    # create a corresponding DATA_DESCRIPTION.
    # Each column has fixed shape so we handle all rows at once
    spw_ids, pol_ids = zip(*product(range(len(num_freq_channels)),
    dask_spw_ids = da.asarray(np.asarray(spw_ids, dtype=np.int32))
    dask_pol_ids = da.asarray(np.asarray(pol_ids, dtype=np.int32))
        "SPECTRAL_WINDOW_ID": (("row",), dask_spw_ids),
        "POLARIZATION_ID": (("row",), dask_pol_ids),

    # Now create the associated MS dataset

    #vis_data, baselines = cal_vis.get_all_visibility()
    #vis_array = np.array(vis_data, dtype=np.complex64)
    # All rows go into a single chunk
    chunks = {
        "row": (vis_array.shape[0],),
    baselines = np.array(baselines)
    #LOGGER.info(f"baselines {baselines}")
    # Index the antenna positions by the antenna pairs in `baselines`; the
    # uvw components are the negated position differences
    bl_pos = np.array(ant_pos)[baselines]
    uu_a, vv_a, ww_a = -(bl_pos[:, 1] - bl_pos[:, 0]).T #/constants.L1_WAVELENGTH
    # Use the - sign to get the same orientation as our tart projections.

    uvw_array = np.array([uu_a, vv_a, ww_a]).T

    for ddid, (spw_id, pol_id) in enumerate(zip(spw_ids, pol_ids)):
        # Infer row, chan and correlation shape
        #LOGGER.info("ddid:{} ({}, {})".format(ddid, spw_id, pol_id))
        row = sum(chunks['row'])
        chan = spw_datasets[spw_id].CHAN_FREQ.shape[1]
        corr = pol_table.datasets[pol_id].CORR_TYPE.shape[1]

        # Create some dask vis data
        dims = ("row", "chan", "corr")
        LOGGER.info("Data size %s %s %s" % (row, chan, corr))

        #np_data = vis_array.reshape((row, chan, corr))
        # The same visibilities are copied into every correlation
        np_data = np.zeros((row, chan, corr), dtype=np.complex128)
        for i in range(corr):
            np_data[:, :, i] = vis_array.reshape((row, chan))
        #np_data = np.array([vis_array.reshape((row, chan, 1)) for i in range(corr)])
        np_uvw = uvw_array.reshape((row, 3))

        data_chunks = tuple((chunks['row'], chan, corr))
        dask_data = da.from_array(np_data, chunks=data_chunks)
        flag_categories = da.from_array(0.05*np.ones((row, chan, corr, 1)))
        # Nothing is flagged
        flag_data = np.zeros((row, chan, corr), dtype=np.bool_)

        uvw_data = da.from_array(np_uvw)
        # Create dask ddid column
        dask_ddid = da.full(row, ddid, chunks=chunks['row'], dtype=np.int32)
        # NOTE(review): TIME and TIME_CENTROID are given ("row", "corr")
        # dimensions here, and flag_categories is built as
        # (row, chan, corr, 1) while FLAG_CATEGORY is labelled
        # ('row', 'flagcat', 'chan', 'corr') — confirm both against the
        # Measurement Set column definitions.
        dataset = Dataset({
            'DATA': (dims, dask_data),
            'FLAG': (dims, da.from_array(flag_data)),
            'TIME': (("row", "corr"), da.from_array(epoch_s*np.ones((row, corr)))),
            'TIME_CENTROID': (("row", "corr"), da.from_array(epoch_s*np.ones((row, corr)))),
            'WEIGHT': (("row", "corr"), da.from_array(0.95*np.ones((row, corr)))),
            'WEIGHT_SPECTRUM': (dims, da.from_array(0.95*np.ones_like(np_data, dtype=np.float64))),
            'SIGMA_SPECTRUM': (dims, da.from_array(np.ones_like(np_data, dtype=np.float64)*0.05)),
            'SIGMA': (("row", "corr"), da.from_array(0.05*np.ones((row, corr)))),
            'UVW': (("row", "uvw",), uvw_data),
            'FLAG_CATEGORY': (('row', 'flagcat', 'chan', 'corr'), flag_categories), # {'dims': ('flagcat', 'chan', 'corr')}
            'ANTENNA1': (("row",), da.from_array(baselines[:, 0])),
            'ANTENNA2': (("row",), da.from_array(baselines[:, 1])),
            'FEED1': (("row",), da.from_array(baselines[:, 0])),
            'FEED2': (("row",), da.from_array(baselines[:, 1])),
            'DATA_DESC_ID': (("row",), dask_ddid)

    # Set up the writes of all columns for the three dataset lists
    ms_writes = xds_to_table(ms_datasets, ms_table_name, columns="ALL")
    spw_writes = xds_to_table(spw_datasets, spw_table_name, columns="ALL")
    ddid_writes = xds_to_table(ddid_datasets, ddid_table_name, columns="ALL")



Exemple #11
 def write(self):
     """Set up the writes of ``self.datasets`` to ``self.table_name``.

     ``xds_to_table`` is called with ``columns="ALL"``.
     """
     # NOTE(review): the result is only bound to a local; it is neither
     # returned nor computed in the lines visible here — confirm the rest
     # of the method has not been lost.
     writes = xds_to_table(self.datasets, self.table_name, columns="ALL")
Exemple #12
# Predict visibilities from the non-zero model components and set up a
# write of them to the column named by args.colname.
# NOTE(review): this copy is incomplete — an `else:` is presumably missing
# before `raise ValueError(...)`, and the final ProgressBar block has no
# body.
ncomps = idx_nz.size
# All components go into a single chunk
model_predict = da.from_array(model_predict, chunks=(ncomps, nchan, ncorr))
lm = da.from_array(lm[idx_nz, :], chunks=(ncomps, 2))
ms_freqs = spw_ds.CHAN_FREQ.data

# Only the first dataset returned by xds_from_ms is used
xds = xds_from_ms(args.ms, columns=["UVW", args.colname],
                  chunks={"row": args.row_chunks})[0]
uvw = xds.UVW.data
vis = im_to_vis(model_predict, uvw, lm, ms_freqs)

data = getattr(xds, args.colname)
if data.shape != vis.shape:
    print("Assuming only Stokes I passed in")
    # Expand a single predicted correlation to match the column: with four
    # correlations the first and last are filled and the middle two are
    # zero; with two correlations both are filled.
    if vis.shape[-1] == 1 and data.shape[-1] == 4:
        tmp_zero = da.zeros(vis.shape, chunks=(args.row_chunks, nchan, 1))
        vis = da.concatenate((vis, tmp_zero, tmp_zero, vis), axis=-1)
    elif vis.shape[-1] == 1 and data.shape[-1] == 2:
        vis = da.concatenate((vis, vis), axis=-1)
        raise ValueError("Incompatible corr axes")
    vis = vis.rechunk((args.row_chunks, nchan, data.shape[-1]))

# Assign visibilities to MODEL_DATA array on the dataset
xds = xds.assign(**{args.colname: (("row", "chan", "corr"), vis)})
# Create a write to the table
write = xds_to_table(xds, args.ms, [args.colname])

# Submit all graph computations in parallel
with ProgressBar():
Exemple #13
def _predict(ms, stack, **kw):
    """Degrid a FITS model image into a model column of each measurement set.

    Args:
        ms: iterable of measurement-set paths; it is materialised with
            ``list(ms)`` and each entry is used both as a path and as a
            dictionary key.
        stack: forwarded unchanged to ``pfb.set_client``.
        **kw: options.  The keys read here are ``output_filename``,
            ``nthreads``, ``host_address``, ``mem_limit``, ``nband``,
            ``nworkers``, ``nthreads_per_worker``, ``ngridder_threads``,
            ``model``, ``output_type``, ``real_type``, ``model_column``,
            ``row_chunks`` and ``mock``.

    Raises:
        ValueError: if ``nthreads`` or ``mem_limit`` is unset while
            ``host_address`` is given, or if the dataset type of the first
            MS is neither ``'casa'`` nor ``'zarr'``.

    The listing this function was recovered from had lost a number of lines
    (``else:`` branches, closing brackets, a few call arguments).  Restored
    lines whose exact text could not be derived from the surrounding code
    are marked ``NOTE(review)``.
    """
    args = OmegaConf.create(kw)
    OmegaConf.set_struct(args, True)
    pyscilog.log_to_file(args.output_filename + '.log')

    # number of threads per worker
    if args.nthreads is None:
        if args.host_address is not None:
            raise ValueError(
                "You have to specify nthreads when using a distributed scheduler"
            )
        import multiprocessing
        nthreads = multiprocessing.cpu_count()
        args.nthreads = nthreads
    else:
        nthreads = args.nthreads

    if args.mem_limit is None:
        if args.host_address is not None:
            raise ValueError(
                "You have to specify mem-limit when using a distributed scheduler"
            )
        import psutil
        mem_limit = int(psutil.virtual_memory()[0] /
                        1e9)  # 100% of memory by default
        args.mem_limit = mem_limit
    else:
        mem_limit = args.mem_limit

    nband = args.nband
    if args.nworkers is None:
        nworkers = nband
        args.nworkers = nworkers
    else:
        nworkers = args.nworkers

    if args.nthreads_per_worker is None:
        nthreads_per_worker = 1
        args.nthreads_per_worker = nthreads_per_worker
    else:
        nthreads_per_worker = args.nthreads_per_worker

    # the number of chunks being read in simultaneously is equal to
    # the number of dask threads
    nthreads_dask = nworkers * nthreads_per_worker

    if args.ngridder_threads is None:
        if args.host_address is not None:
            ngridder_threads = nthreads // nthreads_per_worker
        else:
            ngridder_threads = nthreads // nthreads_dask
        args.ngridder_threads = ngridder_threads
    else:
        ngridder_threads = args.ngridder_threads

    ms = list(ms)
    print('Input Options:', file=log)
    for key in kw.keys():
        print('     %25s = %s' % (key, args[key]), file=log)

    # numpy imports have to happen after this step
    from pfb import set_client
    set_client(nthreads, mem_limit, nworkers, nthreads_per_worker,
               args.host_address, stack, log)

    import numpy as np
    from pfb.utils.misc import chan_to_band_mapping
    import dask
    from dask.distributed import performance_report
    from dask.graph_manipulation import clone
    from daskms import xds_from_storage_ms as xds_from_ms
    from daskms import xds_from_storage_table as xds_from_table
    from daskms.utils import dataset_type
    mstype = dataset_type(ms[0])
    if mstype == 'casa':
        from daskms import xds_to_table
    elif mstype == 'zarr':
        from daskms.experimental.zarr import xds_to_zarr as xds_to_table
    else:
        # Previously this fell through and failed much later with a
        # NameError on ``xds_to_table``.
        raise ValueError("Unsupported dataset type '%s'" % mstype)
    import dask.array as da
    from africanus.constants import c as lightspeed
    from africanus.gridding.wgridder.dask import model as im2vis
    from pfb.utils.fits import load_fits
    from pfb.utils.misc import restore_corrs, plan_row_chunk
    from astropy.io import fits

    # always returns 4D
    # gridder expects freq axis
    model = np.atleast_3d(load_fits(args.model).squeeze())
    nband, nx, ny = model.shape
    hdr = fits.getheader(args.model)
    cell_d = np.abs(hdr['CDELT1'])
    cell_rad = np.deg2rad(cell_d)

    # chan <-> band mapping
    (freqs, freq_bin_idx, freq_bin_counts, freq_out, band_mapping,
     chan_chunks) = chan_to_band_mapping(ms, nband=nband)

    # degridder memory budget
    max_chan_chunk = 0
    for ims in ms:
        for spw in freqs[ims]:
            counts = freq_bin_counts[ims][spw].compute()
            max_chan_chunk = np.maximum(max_chan_chunk, counts.max())

    # assumes number of correlations are the same across MS/SPW
    xds = xds_from_ms(ms[0])
    ncorr = xds[0].dims['corr']
    nrow = xds[0].dims['row']
    if args.output_type is not None:
        output_type = np.dtype(args.output_type)
    else:
        output_type = np.result_type(np.dtype(args.real_type), np.complex64)
    data_bytes = output_type.itemsize
    bytes_per_row = max_chan_chunk * ncorr * data_bytes
    memory_per_row = bytes_per_row  # model
    memory_per_row += 3 * 8  # uvw

    if mstype == 'zarr':
        if args.model_column in xds[0].keys():
            model_chunks = getattr(xds[0], args.model_column).data.chunks
        else:
            model_chunks = xds[0].DATA.data.chunks
            print('Chunking model same as data')

    # get approx image size
    # this is not a conservative estimate when multiple SPW's map to a single
    # imaging band
    pixel_bytes = np.dtype(args.output_type).itemsize
    band_size = nx * ny * pixel_bytes

    if args.host_address is None:
        # full image on single node
        row_chunk = plan_row_chunk(mem_limit / nworkers, band_size, nrow,
                                   memory_per_row, nthreads_per_worker)
    else:
        # single band per node
        row_chunk = plan_row_chunk(mem_limit, band_size, nrow, memory_per_row,
                                   nthreads_per_worker)

    if args.row_chunks is not None:
        row_chunk = int(args.row_chunks)
        if row_chunk == -1:
            row_chunk = nrow

    print(
        "nrows = %i, row chunks set to %i for a total of %i chunks per node" %
        (nrow, row_chunk, int(np.ceil(nrow / row_chunk))),
        file=log)

    chunks = {}
    for ims in ms:
        chunks[ims] = []  # xds_from_ms expects a list per ds
        for spw in freqs[ims]:
            chunks[ims].append({
                'row': row_chunk,
                'chan': chan_chunks[ims][spw]['chan']
            })

    # NOTE(review): the listing was cut after ``chunks=(1, nx, ny),``;
    # ``name=False`` is the assumed trailing argument -- confirm.
    model = da.from_array(model.astype(args.real_type),
                          chunks=(1, nx, ny),
                          name=False)
    writes = []
    radec = None  # assumes we are only imaging field 0 of first MS
    for ims in ms:
        xds = xds_from_ms(ims, chunks=chunks[ims], columns=('UVW'))

        # subtables
        ddids = xds_from_table(ims + "::DATA_DESCRIPTION")
        fields = xds_from_table(ims + "::FIELD")
        spws = xds_from_table(ims + "::SPECTRAL_WINDOW")
        pols = xds_from_table(ims + "::POLARIZATION")

        # subtable data
        ddids = dask.compute(ddids)[0]
        fields = dask.compute(fields)[0]
        spws = dask.compute(spws)[0]
        pols = dask.compute(pols)[0]

        out_data = []
        for ds in xds:
            field = fields[ds.FIELD_ID]
            # NOTE(review): this unconditional assignment makes the
            # ``radec is None`` test and the field comparison below
            # no-ops (every dataset matches itself).  Kept as found;
            # confirm whether non-matching fields should be skipped.
            radec = field.PHASE_DIR.data.squeeze()

            # check fields match
            if radec is None:
                radec = field.PHASE_DIR.data.squeeze()

            if not np.array_equal(radec, field.PHASE_DIR.data.squeeze()):
                # NOTE(review): the branch body was missing from the
                # listing; skipping the dataset is the assumed intent.
                continue

            spw = ds.DATA_DESC_ID  # this is not correct, need to use spw

            uvw = clone(ds.UVW.data)

            bands = band_mapping[ims][spw]
            # Slice into a fresh name: overwriting ``model`` here meant
            # every later dataset indexed an already-reduced cube.
            band_model = model[list(bands), :, :]
            # NOTE(review): only the first argument survived in the
            # listing.  The rest are the otherwise-unused per-SPW values
            # computed above; any further keyword options of the
            # degridder are left at their defaults -- confirm.
            vis = im2vis(uvw,
                         freqs[ims][spw],
                         band_model,
                         freq_bin_idx[ims][spw],
                         freq_bin_counts[ims][spw],
                         cell_rad,
                         nthreads=ngridder_threads)

            model_vis = restore_corrs(vis, ncorr)
            if mstype == 'zarr':
                model_vis = model_vis.rechunk(model_chunks)
                uvw = uvw.rechunk((model_chunks[0], 3))

            out_ds = ds.assign(
                **{
                    args.model_column: (("row", "chan", "corr"), model_vis),
                    'UVW': (("row", "three"), uvw)
                })
            out_data.append(out_ds)

        writes.append(xds_to_table(out_data, ims, columns=[args.model_column]))

    # NOTE(review): only the ``filename=`` line of this call survived;
    # the positional argument and ``optimize_graph`` are assumed.
    dask.visualize(writes,
                   filename=args.output_filename + '_predict_graph.pdf',
                   optimize_graph=False)

    if not args.mock:
        # NOTE(review): the report filename suffix was cut off in the
        # listing; '_predict_per.html' is an assumption.
        with performance_report(filename=args.output_filename +
                                '_predict_per.html'):
            dask.compute(writes, optimize_graph=False)

    print("All done here.", file=log)
def predict(args):
    """Predict visibilities from a sky model and write them to ``args.ms``.

    One write is built per dataset returned by ``xds_from_ms`` (grouped by
    ``FIELD_ID`` and ``DATA_DESC_ID``); all writes are then computed
    together under a progress bar.

    Args:
        args: options object.  Attributes read here: ``sky_model``,
            ``model_chunks``, ``ms``, ``row_chunks`` and ``data_column``.
            It is also passed whole to ``support_tables`` and
            ``vis_factory``.
    """
    # Local import so the compute call below does not depend on a
    # module-level ``dask`` name that is outside this block.
    import dask

    # Convert source data into dask arrays
    sky_model = parse_sky_model(args.sky_model, args.model_chunks)

    # Get the support tables
    tables = support_tables(args)

    ant_ds = tables["ANTENNA"]
    field_ds = tables["FIELD"]
    ddid_ds = tables["DATA_DESCRIPTION"]
    spw_ds = tables["SPECTRAL_WINDOW"]
    pol_ds = tables["POLARIZATION"]

    # List of write operations
    writes = []

    # Construct a graph for each DATA_DESC_ID
    for xds in xds_from_ms(
            args.ms,
            columns=["UVW", "ANTENNA1", "ANTENNA2", "TIME"],
            group_cols=["FIELD_ID", "DATA_DESC_ID"],
            chunks={"row": args.row_chunks}):

        # Perform subtable joins
        ant = ant_ds[0]
        field = field_ds[xds.attrs["FIELD_ID"]]
        ddid = ddid_ds[xds.attrs["DATA_DESC_ID"]]
        spw = spw_ds[ddid.SPECTRAL_WINDOW_ID.data[0]]
        pol = pol_ds[ddid.POLARIZATION_ID.data[0]]

        # Select single dataset row out
        corrs = pol.NUM_CORR.data[0]

        # Generate visibility expressions for each source type
        source_vis = [
            vis_factory(args, stype, sky_model, xds, ant, field, spw, pol)
            for stype in sky_model.keys()
        ]

        # Sum visibilities together
        vis = sum(source_vis)

        # Reshape (2, 2) correlation to shape (4,)
        if corrs == 4:
            vis = vis.reshape(vis.shape[:2] + (4, ))

        # Assign the visibilities under the very column that is written
        # below.  Previously any name other than MODEL_DATA was assigned
        # as CORRECTED_DATA while ``args.data_column`` was the column
        # requested from ``xds_to_table``.
        xds = xds.assign(
            **{args.data_column: (("row", "chan", "corr"), vis)})

        # Create a write to the table
        write = xds_to_table(xds, args.ms, [args.data_column])

        # Add to the list of writes
        writes.append(write)

    # Submit all graph computations in parallel
    with ProgressBar():
        dask.compute(writes)
Exemple #15
def _predict(args):
    """Predict visibilities from a WSClean component list into ``args.ms``.

    For every selected (FIELD_ID, DATA_DESC_ID) dataset the component
    phases and brightnesses are combined with ``da.einsum``, passed through
    ``predict_vis`` and assigned to ``args.output_column``; all writes are
    then computed together.

    Args:
        args: options object.  Attributes read here: ``within``,
            ``sky_model``, ``num_sources``, ``spectra``, ``ms``,
            ``fields``, ``exp_sign_convention`` and ``output_column``.
            ``row_chunks`` and ``model_chunks`` are overwritten with the
            result of ``get_budget``.

    Raises:
        ValueError: if ``args.exp_sign_convention`` is neither ``'casa'``
            nor ``'thompson'``.
    """
    # get inclusion regions
    # NOTE(review): ``include_regions`` is not used anywhere else in this
    # listing; confirm whether it should be forwarded to the importer.
    include_regions = load_regions(args.within) if args.within else []

    # Import source data from WSClean component list
    # See https://sourceforge.net/p/wsclean/wiki/ComponentList
    (comp_type, radec, stokes, spec_coeff, ref_freq, log_spec_ind,
     gaussian_shape) = import_from_wsclean(args.sky_model,
                                           num=args.num_sources or None)

    # Add output column if it isn't present
    ms_rows, ms_datatype = ms_preprocess(args)

    # Get the support tables
    # NOTE(review): the argument list was missing from the listing; the
    # four table names are the ones indexed just below -- confirm the
    # helper's signature.
    tables = support_tables(
        args,
        ["FIELD", "DATA_DESCRIPTION", "SPECTRAL_WINDOW", "POLARIZATION"])

    field_ds = tables["FIELD"]
    ddid_ds = tables["DATA_DESCRIPTION"]
    spw_ds = tables["SPECTRAL_WINDOW"]
    pol_ds = tables["POLARIZATION"]

    max_num_chan = max([ss.NUM_CHAN.data[0] for ss in spw_ds])
    max_num_corr = max([ss.NUM_CORR.data[0] for ss in pol_ds])

    # Perform resource budgeting
    # NOTE(review): the final argument was cut off; ``args`` is assumed.
    args.row_chunks, args.model_chunks = get_budget(comp_type.shape[0],
                                                    ms_rows, max_num_chan,
                                                    max_num_corr, ms_datatype,
                                                    args)

    radec = da.from_array(radec, chunks=(args.model_chunks, 2))
    stokes = da.from_array(stokes, chunks=(args.model_chunks, 4))

    if np.count_nonzero(comp_type == 'GAUSSIAN') > 0:
        gaussian_components = True
        gshape_chunks = (args.model_chunks, 3)
        gaussian_shape = da.from_array(gaussian_shape, chunks=gshape_chunks)
    else:
        gaussian_components = False

    if args.spectra:
        spec_chunks = (args.model_chunks, spec_coeff.shape[1])
        spec_coeff = da.from_array(spec_coeff, chunks=spec_chunks)
        ref_freq = da.from_array(ref_freq, chunks=(args.model_chunks, ))

    # List of write operations
    writes = []

    # Construct a graph for each FIELD and DATA DESCRIPTOR
    datasets = xds_from_ms(args.ms,
                           columns=["UVW", "ANTENNA1", "ANTENNA2", "TIME"],
                           group_cols=["FIELD_ID", "DATA_DESC_ID"],
                           chunks={"row": args.row_chunks})

    select_fields = valid_field_ids(field_ds, args.fields)

    for xds in filter_datasets(datasets, select_fields):
        # Extract frequencies from the spectral window associated
        # with this data descriptor id
        field = field_ds[xds.attrs['FIELD_ID']]
        ddid = ddid_ds[xds.attrs['DATA_DESC_ID']]
        spw = spw_ds[ddid.SPECTRAL_WINDOW_ID.data[0]]
        pol = pol_ds[ddid.POLARIZATION_ID.data[0]]
        frequency = spw.CHAN_FREQ.data[0]

        corrs = pol.NUM_CORR.values

        lm = radec_to_lm(radec, field.PHASE_DIR.data[0][0])

        if args.exp_sign_convention == 'casa':
            uvw = -xds.UVW.data
        elif args.exp_sign_convention == 'thompson':
            uvw = xds.UVW.data
        else:
            # The message used to read ``args.sign`` although the value
            # tested above is ``args.exp_sign_convention``.
            raise ValueError("Invalid sign convention '%s'" %
                             args.exp_sign_convention)

        if args.spectra:
            # flux density at reference frequency ...
            # ... for logarithmic polynomial functions
            if log_spec_ind:
                Is = da.log(stokes[:, 0, None]) * frequency[None, :]**0
            # ... or for ordinary polynomial functions
            else:
                Is = stokes[:, 0, None] * frequency[None, :]**0
            # additional terms of SED ...
            for jj in range(spec_coeff.shape[1]):
                # ... for logarithmic polynomial functions
                if log_spec_ind:
                    Is += spec_coeff[:, jj, None] * \
                        da.log((frequency[None, :]/ref_freq[:, None])**(jj+1))
                # ... or for ordinary polynomial functions
                else:
                    Is += spec_coeff[:, jj, None] * \
                        (frequency[None, :]/ref_freq[:, None]-1)**(jj+1)
            if log_spec_ind:
                Is = da.exp(Is)
            Qs = da.zeros_like(Is)
            Us = da.zeros_like(Is)
            Vs = da.zeros_like(Is)
            # stack along new axis and make it the last axis of the new array
            spectrum = da.stack([Is, Qs, Us, Vs], axis=-1)
            spectrum = spectrum.rechunk(spectrum.chunks[:2] +
                                        (spectrum.shape[2], ))

        print('Nr sources        = {0:d}'.format(stokes.shape[0]))
        print('stokes.shape      = {0:}'.format(stokes.shape))
        print('frequency.shape   = {0:}'.format(frequency.shape))
        if args.spectra:
            print('Is.shape          = {0:}'.format(Is.shape))
            print('spectrum.shape    = {0:}'.format(spectrum.shape))

        # (source, row, frequency)
        phase = phase_delay(lm, uvw, frequency)
        # If at least one Gaussian component is present in the component
        # list then all sources are modelled as Gaussian components
        # (Delta components have zero width)
        if gaussian_components:
            phase *= gaussian(uvw, frequency, gaussian_shape)
        # (source, frequency, corr_products)
        brightness = convert(spectrum if args.spectra else stokes,
                             ["I", "Q", "U", "V"], corr_schema(pol))

        print('brightness.shape  = {0:}'.format(brightness.shape))
        print('phase.shape       = {0:}'.format(phase.shape))
        print('Attempting phase-brightness einsum with "{0:s}"'.format(
            einsum_schema(pol, args.spectra)))

        # (source, row, frequency, corr_products)
        jones = da.einsum(einsum_schema(pol, args.spectra), phase, brightness)
        print('jones.shape       = {0:}'.format(jones.shape))
        if gaussian_components:
            print('Some Gaussian sources found')
        else:
            print('All sources are Delta functions')

        # Identify time indices
        _, time_index = da.unique(xds.TIME.data, return_inverse=True)

        # Predict visibilities
        vis = predict_vis(time_index, xds.ANTENNA1.data, xds.ANTENNA2.data,
                          None, jones, None, None, None, None)

        # Reshape (2, 2) correlation to shape (4,)
        if corrs == 4:
            vis = vis.reshape(vis.shape[:2] + (4, ))

        # Assign visibilities to MODEL_DATA array on the dataset
        xds = xds.assign(
            **{args.output_column: (("row", "chan", "corr"), vis)})
        # Create a write to the table
        write = xds_to_table(xds, args.ms, [args.output_column])
        # Add to the list of writes
        writes.append(write)

    with ExitStack() as stack:
        if sys.stdout.isatty():
            # Default progress bar in user terminal
            stack.enter_context(ProgressBar())
        else:
            # Log progress every 5 minutes
            stack.enter_context(ProgressBar(minimum=2 * 60, dt=5))

        # Submit all graph computations in parallel
        # (``da.compute`` is dask's generic compute, reached through the
        # ``da`` alias already used throughout this function.)
        da.compute(writes)
def simulate(args):
    """Corrupt sky-model visibilities with a gain screen and write them.

    Args:
        args: options object.  Attributes read here: ``ms``,
            ``utimes_per_chunk``, ``sky_model`` and ``out_col``.

    Returns:
        The ``(jones, alphas)`` pair from ``make_screen``, with ``jones``
        cast to ``complex128``.

    Raises:
        RuntimeError: if the FLAG column has a correlation count other
            than 1, 2 or 4.
    """
    # get full time column and compute row chunks
    # The table is opened in a ``with`` block so the handle is released
    # once the columns have been read.
    with table(args.ms) as ms:
        time = ms.getcol('TIME')
        ant1 = ms.getcol('ANTENNA1')
        ant2 = ms.getcol('ANTENNA2')
        flag = ms.getcol("FLAG")

    row_chunks, tbin_idx, tbin_counts = chunkify_rows(time,
                                                      args.utimes_per_chunk)
    # convert to dask arrays
    tbin_idx = da.from_array(tbin_idx, chunks=(args.utimes_per_chunk))
    tbin_counts = da.from_array(tbin_counts, chunks=(args.utimes_per_chunk))
    n_time = tbin_idx.size
    n_ant = np.maximum(ant1.max(), ant2.max()) + 1
    n_row, n_freq, n_corr = flag.shape
    if n_corr == 4:
        model_corr = (2, 2)
        jones_corr = (2, )
    elif n_corr == 2:
        model_corr = (2, )
        jones_corr = (2, )
    elif n_corr == 1:
        model_corr = (1, )
        jones_corr = (1, )
    else:
        raise RuntimeError("Invalid number of correlations")

    # get phase dir
    with table(args.ms + '::FIELD') as field_table:
        radec0 = field_table.getcol('PHASE_DIR').squeeze()

    # get freqs
    # NOTE(review): the dtype argument was cut off in the listing;
    # float64 is assumed.
    with table(args.ms + '::SPECTRAL_WINDOW') as spw_table:
        freq = spw_table.getcol('CHAN_FREQ')[0].astype(np.float64)
    assert freq.size == n_freq

    # get source coordinates from lsm
    lsm = Tigger.load(args.sky_model)
    radec = []
    stokes = []
    spi = []
    ref_freqs = []

    for source in lsm.sources:
        radec.append([source.pos.ra, source.pos.dec])
        # NOTE(review): this append was missing from the listing, yet
        # ``stokes`` defines ``n_dir`` and is indexed per direction
        # below; Stokes I of the source is the assumed value.
        stokes.append([source.flux.I])
        tmp_spec = source.spectrum
        spi.append([tmp_spec.spi if tmp_spec is not None else 0.0])
        ref_freqs.append([tmp_spec.freq0 if tmp_spec is not None else 1.0])

    n_dir = len(stokes)
    radec = np.asarray(radec)
    lm = radec_to_lm(radec, radec0)

    # load in the model file
    model = np.zeros((n_freq, n_dir) + model_corr)
    stokes = np.asarray(stokes)
    ref_freqs = np.asarray(ref_freqs)
    spi = np.asarray(spi)
    for d in range(n_dir):
        Stokes_I = stokes[d] * (freq / ref_freqs[d])**spi[d]
        if n_corr == 4:
            model[:, d, 0, 0] = Stokes_I
            model[:, d, 1, 1] = Stokes_I
        elif n_corr == 2:
            model[:, d, 0] = Stokes_I
            model[:, d, 1] = Stokes_I
        else:
            model[:, d, 0] = Stokes_I

    # append antenna columns
    # (exactly the columns read back from the dataset further down)
    cols = ['ANTENNA1', 'ANTENNA2', 'UVW']

    # load in gains
    jones, alphas = make_screen(lm, freq, n_time, n_ant, jones_corr[0])
    jones = jones.astype(np.complex128)
    jones_shape = jones.shape
    # One chunk per ``utimes_per_chunk`` times; remaining axes unchunked.
    jones_da = da.from_array(jones,
                             chunks=(args.utimes_per_chunk, ) +
                             jones_shape[1:])

    freqs = da.from_array(freq, chunks=(n_freq))
    lm = da.from_array(np.tile(lm[None], (n_time, 1, 1)),
                       chunks=(args.utimes_per_chunk, n_dir, 2))
    # change model to dask array
    # Repeat the model along a new leading time axis only.
    tmp_shape = (n_time, ) + (1, ) * model.ndim
    model = da.from_array(np.tile(model[None], tmp_shape),
                          chunks=(args.utimes_per_chunk, ) + model.shape)

    # load data in in chunks and apply gains to each chunk
    xds = xds_from_ms(args.ms, columns=cols, chunks={"row": row_chunks})[0]
    ant1 = xds.ANTENNA1.data
    ant2 = xds.ANTENNA2.data
    uvw = xds.UVW.data

    # apply gains
    data = compute_and_corrupt_vis(tbin_idx, tbin_counts, ant1, ant2, jones_da,
                                   model, uvw, freqs, lm)

    # Assign visibilities to args.out_col and write to ms
    xds = xds.assign(
        **{
            args.out_col: (("row", "chan", "corr"),
                           data.reshape(n_row, n_freq, n_corr))
        })
    # Create a write to the table
    write = xds_to_table(xds, args.ms, [args.out_col])

    # Submit all graph computations in parallel
    with ProgressBar():
        da.compute(write)

    return jones, alphas
ant2 = xds.ANTENNA2.data

# Stack the per-direction model columns along a new axis 2 and rechunk
# that axis in blocks of 3.
model = da.stack([getattr(xds, col).data for col in model_cols],
                 axis=2).rechunk({2: 3})

# reshape the correlation axis
if model.shape[-1] > 2:
    n_row, n_chan, n_dir, n_corr = model.shape
    model = model.reshape(n_row, n_chan, n_dir, 2, 2)
    reshape_vis = True
else:
    # Without this ``else`` the flag was reset to False unconditionally,
    # so the corrupted data was never flattened back below.
    reshape_vis = False

# apply gains
corrupted_data = corrupt_vis(tbin_idx, tbin_counts, ant1, ant2,
                             jones, model)

if reshape_vis:
    corrupted_data = corrupted_data.reshape(n_row, n_chan, n_corr)

# Assign visibilities to args.out_col and write to ms
xds = xds.assign(**{args.out_col: (("row", "chan", "corr"), corrupted_data)})
# Create a write to the table
write = xds_to_table(xds, args.ms, [args.out_col])

# Submit all graph computations in parallel
with ProgressBar():
    # ``da.compute`` is dask's generic compute, reachable through the
    # ``da`` alias this script already uses.
    da.compute(write)
Exemple #18
def both(args):
    """Generate model data, corrupted visibilities and
    gains (phase-only or normal).

    Args:
        args: options object.  Attributes read here: ``ncpu``, ``ms``,
            ``utimes_per_chunk``, ``phase_convention``, ``sky_model``,
            ``mode``, ``alpha_std``, ``sigma_f``, ``lt``, ``lnu``, ``ls``,
            ``out`` and ``sigma_n``.  ``ncpu``, ``sky_model`` and ``out``
            may be overwritten.

    Raises:
        ValueError: unknown ``phase_convention`` or ``mode``.
        NotImplementedError: correlation count other than 4, or a sky
            model name that is not one of the three bundled ones.

    Side effects visible in this block: the gains are saved with
    ``np.save`` to ``args.out`` and the columns named in ``out_names`` are
    written to ``args.ms``.

    The listing this function was recovered from had lost a number of
    lines; restored lines whose exact text could not be derived from the
    surrounding code are marked ``NOTE(review)``.
    """
    # Set thread count to cpu count
    if args.ncpu:
        from multiprocessing.pool import ThreadPool
        import dask
        # NOTE(review): this statement was missing from the listing; the
        # two imports above are otherwise unused, so configuring dask's
        # pool is the assumed intent.
        dask.config.set(pool=ThreadPool(args.ncpu))
    else:
        import multiprocessing
        args.ncpu = multiprocessing.cpu_count()

    # Fail early on an unknown convention (UVW itself is read from the
    # chunked dataset further down).
    if args.phase_convention not in ('CASA', 'CODEX'):
        raise ValueError("Unknown sign convention for phase")

    # Get full time column and compute row chunks
    ms = xds_from_table(args.ms)[0]
    row_chunks, tbin_idx, tbin_counts = chunkify_rows(
        ms.TIME, args.utimes_per_chunk)
    # Convert time rows to dask arrays
    # NOTE(review): the ``chunks`` argument was cut off on both lines;
    # ``args.utimes_per_chunk`` is assumed.
    tbin_idx = da.from_array(tbin_idx,
                             chunks=(args.utimes_per_chunk))
    tbin_counts = da.from_array(tbin_counts,
                                chunks=(args.utimes_per_chunk))

    # Time axis
    n_time = tbin_idx.size

    # Get antenna columns
    ant1 = ms.ANTENNA1.data
    ant2 = ms.ANTENNA2.data

    # No. of antennas axis
    n_ant = (np.maximum(ant1.max(), ant2.max()) + 1).compute()

    # Get flag column
    flag = ms.FLAG.data

    # Get rest of dimensions
    n_row, n_freq, n_corr = flag.shape

    # Raise error if correlation axis too small
    if n_corr != 4:
        raise NotImplementedError("Only 4 correlations "
                                  + "currently supported")

    # Get phase direction
    radec0_table = xds_from_table(args.ms+'::FIELD')[0]
    radec0 = radec0_table.PHASE_DIR.data.squeeze().compute()

    # Get frequency column
    freq_table = xds_from_table(args.ms+'::SPECTRAL_WINDOW')[0]
    freq = freq_table.CHAN_FREQ.data.astype(np.float64)[0]

    # Check dimension
    assert freq.size == n_freq

    # Check for sky-model
    if args.sky_model == 'MODEL-1.txt':
        args.sky_model = MODEL_1
    elif args.sky_model == 'MODEL-4.txt':
        args.sky_model = MODEL_4
    elif args.sky_model == 'MODEL-50.txt':
        args.sky_model = MODEL_50
    else:
        # ``NotImplemented`` is a constant, not an exception class;
        # calling it raised an unrelated TypeError.
        raise NotImplementedError(f"Sky-model {args.sky_model} not in "
                                  + "kalcal/datasets/sky_model/")

    # Build source model from lsm
    lsm = Tigger.load(args.sky_model)

    # Direction axis
    n_dir = len(lsm.sources)

    # Create initial model array
    model = np.zeros((n_dir, n_freq, n_corr), dtype=np.float64)

    # Create initial coordinate array and source names
    lm = np.zeros((n_dir, 2), dtype=np.float64)
    source_names = []

    # Cycle coordinates creating a source with flux
    for d, source in enumerate(lsm.sources):
        # Extract name
        # NOTE(review): the statement under this comment was missing;
        # ``source_names[d]`` is indexed later, so one name per source is
        # required.  ``source.name`` is the assumed attribute.
        source_names.append(source.name)

        # Extract position
        radec_s = np.array([[source.pos.ra, source.pos.dec]])
        lm[d] = radec_to_lm(radec_s, radec0)

        # Get spectrum (only spi currently supported)
        tmp_spec = source.spectrum
        spi = tmp_spec.spi if tmp_spec is not None else 0.0
        ref_freq = tmp_spec.freq0 if tmp_spec is not None else 1.0
        spectral_shape = (freq / ref_freq)**spi

        # Generate model flux for each Stokes parameter that is set;
        # slot order along the last axis is I, Q, U, V.
        for slot, stokes_name in enumerate("IQUV"):
            flux0 = getattr(source.flux, stokes_name)
            if flux0:
                model[d, :, slot] = flux0 * spectral_shape

    # Dask to NP
    t = tbin_idx.compute()
    nu = freq.compute()

    # Generate gains
    print('==> Both-mode')
    if args.mode == "phase":
        jones = phase_gains(lm, nu, n_time, n_ant, args.alpha_std)

    elif args.mode == "normal":
        jones = normal_gains(t, nu, lm, n_ant, n_corr,
                             args.sigma_f, args.lt, args.lnu, args.ls)
    else:
        raise ValueError("Only normal and phase modes available.")
    # Reduce jones to diagonals only
    jones = jones[:, :, :, :, (0, -1)]

    # Jones to complex
    jones = jones.astype(np.complex128)

    # Jones shape
    jones_shape = jones.shape

    # Generate filename
    if args.out == "":
        args.out = f"{args.mode}.npy"

    # Save gains and settings to file
    with open(args.out, 'wb') as file:
        np.save(file, jones)

    # Build dask graph
    lm = da.from_array(lm, chunks=lm.shape)
    model = da.from_array(model, chunks=model.shape)
    jones_da = da.from_array(jones, chunks=(args.utimes_per_chunk,)
                             + jones_shape[1:])

    # Append antenna columns
    # (exactly the columns read back from the dataset just below)
    cols = ['ANTENNA1', 'ANTENNA2', 'UVW']

    # Load data in in chunks and apply gains to each chunk
    xds = xds_from_ms(args.ms, columns=cols,
                      chunks={"row": row_chunks})[0]
    ant1 = xds.ANTENNA1.data
    ant2 = xds.ANTENNA2.data

    # Adjust UVW based on phase-convention
    if args.phase_convention == 'CASA':
        uvw = -xds.UVW.data.astype(np.float64)
    else:  # 'CODEX' -- validated at the top of the function
        uvw = xds.UVW.data.astype(np.float64)

    # Get model visibilities
    # NOTE(review): the dtype was cut off; complex64 matches the dtype
    # requested from ``im_to_vis`` below.
    model_vis = np.zeros((n_row, n_freq, n_dir, n_corr),
                         dtype=np.complex64)

    for s in range(n_dir):
        # NOTE(review): the ``uvw`` and frequency arguments were missing
        # from the listing; ``uvw`` and ``freq`` are assumed.
        model_vis[:, :, s] = im_to_vis(
            model[s].reshape((1, n_freq, n_corr)),
            uvw,
            lm[s].reshape((1, 2)),
            freq,
            dtype=np.complex64, convention='fourier')

    # NP to Dask
    model_vis = da.from_array(model_vis, chunks=(row_chunks,
                                                 n_freq, n_dir, n_corr))

    # Convert Stokes to corr
    in_schema = ['I', 'Q', 'U', 'V']
    out_schema = [['RR', 'RL'], ['LR', 'LL']]
    model_vis = convert(model_vis, in_schema, out_schema)

    # Apply gains
    data = corrupt_vis(tbin_idx, tbin_counts, ant1, ant2,
                       jones_da, model_vis).reshape(
                           (n_row, n_freq, n_corr))
    # Assign model visibilities
    out_names = []
    for d in range(n_dir):
        xds = xds.assign(**{source_names[d]:
                            (("row", "chan", "corr"),
                             model_vis[:, :, d].reshape(
                                 n_row, n_freq, n_corr).astype(np.complex64))})

        out_names += [source_names[d]]

    # Assign noise free visibilities to 'CLEAN_DATA'
    # NOTE(review): the assigned value was cut off here and for 'NOISE'
    # and 'DATA' below; the complex64 cast mirrors the per-source columns.
    xds = xds.assign(**{'CLEAN_DATA': (("row", "chan", "corr"),
                                       data.astype(np.complex64))})

    out_names += ['CLEAN_DATA']
    # Get noise realisation
    if args.sigma_n > 0.0:

        # Noise matrix
        noise = (da.random.normal(loc=0.0, scale=args.sigma_n,
                                  size=(n_row, n_freq, n_corr),
                                  chunks=(row_chunks, n_freq, n_corr))
                 + 1.0j*da.random.normal(loc=0.0, scale=args.sigma_n,
                                         size=(n_row, n_freq, n_corr),
                                         chunks=(row_chunks, n_freq, n_corr))
                 )/np.sqrt(2.0)

        # Zero matrix for off-diagonals
        zero = da.zeros_like(noise[:, :, 0])

        # Dask to NP
        noise = noise.compute()
        zero = zero.compute()

        # Remove noise on off-diagonals
        noise[:, :, 1] = zero[:, :]
        noise[:, :, 2] = zero[:, :]

        # NP to Dask
        noise = da.from_array(noise, chunks=(row_chunks, n_freq, n_corr))
        # Assign noise to 'NOISE'
        xds = xds.assign(**{'NOISE': (("row", "chan", "corr"),
                                      noise.astype(np.complex64))})

        out_names += ['NOISE']

        # Add noise to data and assign to 'DATA'
        noisy_data = data + noise
        xds = xds.assign(**{'DATA': (("row", "chan", "corr"),
                                     noisy_data.astype(np.complex64))})

        out_names += ['DATA']
    # Create a write to the table
    write = xds_to_table(xds, args.ms, out_names)

    # Submit all graph computations in parallel
    with ProgressBar():
        # ``da.compute`` rather than ``dask.compute``: ``dask`` is only
        # bound inside the ``if args.ncpu`` branch above.
        da.compute(write)

    print(f"==> Applied Jones to MS: {args.ms} <--> {args.out}")
Exemple #19
def _main(args):
    # Flagging entry point: opens the Measurement Set ``args.ms`` grouped by
    # field / data description / scan, builds a dask graph per selected
    # dataset that derives new flags, schedules a write of the FLAG column
    # and finally computes all writes together with before/after statistics.
    #
    # NOTE(review): this copy of the function has lost lines (unbalanced
    # brackets, ``if`` statements without bodies, missing ``else:`` headers),
    # so it does not parse as-is. Every gap spotted is marked with a
    # NOTE(review) comment; restore the lines from the upstream source.
    tic = time.time()


    # NOTE(review): the body of this ``if`` only logs in this copy; the
    # statement that actually disables the debugger appears to be missing.
    if args.disable_post_mortem:
        log.warn("Disabling crash debugging with the "
                 "Interactive Python Debugger, as per user request")

    log.info("Flagging on the {0:s} column".format(args.data_column))
    data_column = args.data_column
    # NOTE(review): the closing "]" of this list comprehension is missing.
    masked_channels = [
        load_mask(fn, dilate=args.dilate_masks) for fn in collect_masks()
    GD = args.config


    # Group datasets by these columns
    group_cols = ["FIELD_ID", "DATA_DESC_ID", "SCAN_NUMBER"]
    # Index datasets by these columns
    index_cols = ['TIME']

    # Reopen the datasets using the aggregated row ordering
    columns = [data_column, "FLAG", "TIME", "ANTENNA1", "ANTENNA2"]

    # NOTE(review): this ``if`` has no body; presumably the model column is
    # added to ``columns`` here -- confirm against upstream.
    if args.subtract_model_column is not None:

    # NOTE(review): the call that opens the datasets is truncated; only its
    # final ``chunks=`` argument survives. ``group_cols``, ``index_cols`` and
    # ``columns`` above are not referenced anywhere else in this copy.
    xds = list(
                    chunks={"row": args.row_chunks}))

    # Get support tables
    st = support_tables(args.ms)
    ddid_ds = st["DATA_DESCRIPTION"]
    field_ds = st["FIELD"]
    pol_ds = st["POLARIZATION"]
    spw_ds = st["SPECTRAL_WINDOW"]
    ant_ds = st["ANTENNA"]

    # Only a single ANTENNA and a single DATA_DESCRIPTION dataset are handled
    assert len(ant_ds) == 1
    assert len(ddid_ds) == 1

    antspos = ant_ds[0].POSITION.data
    antsnames = ant_ds[0].NAME.data
    fieldnames = [fds.NAME.data[0] for fds in field_ds]

    # Restrict the requested scans to those present in the data; when no
    # scans were requested every available scan is kept. After this
    # statement ``args.scan_numbers`` is always a list, never None.
    avail_scans = [ds.SCAN_NUMBER for ds in xds]
    args.scan_numbers = list(
        set(avail_scans).intersection(args.scan_numbers if args.scan_numbers
                                      is not None else avail_scans))

    if args.scan_numbers != []:
        log.info("Only considering scans '{0:s}' as "
                 "per user selection criterion".format(", ".join(
                     map(str, map(int, args.scan_numbers)))))

    if args.field_names != []:
        flatten_field_names = []
        for f in args.field_names:
            # accept comma lists per specification
            flatten_field_names += [x.strip() for x in f.split(",")]
        # NOTE(review): the body of the ``if`` below is missing (presumably it
        # maps a numeric field index to its name), and the ``list(filter(``
        # expression after it is truncated.
        for f in flatten_field_names:
            if re.match(r"^\d+$", f) and int(f) < len(fieldnames):
        flatten_field_names = list(
                filter(lambda x: not re.match(r"^\d+$", x),
        log.info("Only considering fields '{0:s}' for flagging per "
                 "user "
                 "selection criterion.".format(", ".join(flatten_field_names)))
        # NOTE(review): the ``.format(`` call of this error message is cut
        # short; its third argument and closing brackets are missing.
        if not set(flatten_field_names) <= set(fieldnames):
            raise ValueError("One or more fields cannot be "
                             "found in dataset '{0:s}' "
                             "You specified {1:s}, but "
                             "only {2:s} are available".format(
                                 args.ms, ",".join(flatten_field_names),

        # Map FIELD_ID -> field name for the fields selected for flagging.
        # NOTE(review): an ``else:`` (all fields selected) appears to be
        # missing between the next two assignments.
        field_dict = {fieldnames.index(fn): fn for fn in flatten_field_names}
        field_dict = {i: fn for i, fn in enumerate(fieldnames)}

    # List which hold our dask compute graphs for each dataset
    write_computes = []
    original_stats = []
    final_stats = []

    # Iterate through each dataset
    for ds in xds:
        # NOTE(review): both selection checks below have lost their bodies
        # (presumably ``continue``).
        if ds.FIELD_ID not in field_dict:

        if (args.scan_numbers is not None
                and ds.SCAN_NUMBER not in args.scan_numbers):

        # NOTE(review): the ``.format(`` arguments are truncated here.
        log.info("Adding field '{0:s}' scan {1:d} to "
                 "compute graph for processing".format(field_dict[ds.FIELD_ID],

        # Look up the spectral window and polarisation rows referenced by
        # this dataset's data description
        ddid = ddid_ds[ds.attrs['DATA_DESC_ID']]
        spw_info = spw_ds[ddid.SPECTRAL_WINDOW_ID.data[0]]
        pol_info = pol_ds[ddid.POLARIZATION_ID.data[0]]

        nrow, nchan, ncorr = getattr(ds, data_column).data.shape

        # Visibilities from the dataset
        vis = getattr(ds, data_column).data
        if args.subtract_model_column is not None:
            log.info("Forming residual data between '{0:s}' and "
                     "'{1:s}' for flagging.".format(
                         data_column, args.subtract_model_column))
            vismod = getattr(ds, args.subtract_model_column).data
            vis = vis - vismod

        antenna1 = ds.ANTENNA1.data
        antenna2 = ds.ANTENNA2.data
        chan_freq = spw_info.CHAN_FREQ.data[0]
        chan_width = spw_info.CHAN_WIDTH.data[0]

        # Generate unflagged defaults if we should ignore existing flags
        # otherwise take flags from the dataset
        # NOTE(review): the end of the log message and the ``else:`` before
        # ``flags = ds.FLAG.data`` are missing. Also, the ``np.bool`` alias
        # was removed in NumPy 1.24; ``bool`` is the replacement.
        if args.ignore_flags is True:
            flags = da.full_like(vis, False, dtype=np.bool)
            log.critical("Completely ignoring measurement set "
                         "flags as per '-if' request. "
                         "Strategy WILL NOT or with original flags, even if "
            flags = ds.FLAG.data

        # If we're flagging on polarised intensity,
        # we convert visibilities to polarised intensity
        # and any flagged correlation will flag the entire visibility
        if args.flagging_strategy == "polarisation":
            corr_type = pol_info.CORR_TYPE.data[0].tolist()
            stokes_map = stokes_corr_map(corr_type)
            # Every Stokes entry except "I"
            stokes_pol = tuple(v for k, v in stokes_map.items() if k != "I")
            vis = polarised_intensity(vis, stokes_pol)
            flags = da.any(flags, axis=2, keepdims=True)
        elif args.flagging_strategy == "total_power":
            if args.subtract_model_column is None:
                log.critical("You requested to flag total quadrature "
                             "power, but not on residuals. "
                             "This is not advisable and the flagger "
                             "may mistake fringes of "
                             "off-axis sources for broadband RFI.")
            corr_type = pol_info.CORR_TYPE.data[0].tolist()
            stokes_map = stokes_corr_map(corr_type)
            # All Stokes entries, including "I"
            stokes_pol = tuple(v for k, v in stokes_map.items())
            vis = polarised_intensity(vis, stokes_pol)
            flags = da.any(flags, axis=2, keepdims=True)
        elif args.flagging_strategy == "standard":
            if args.subtract_model_column is None:
                log.critical("You requested to flag per correlation, "
                             "but not on residuals. "
                             "This is not advisable and the flagger "
                             "may mistake fringes of off-axis sources "
                             "for broadband RFI.")
            # NOTE(review): an ``else:`` is missing before this ``raise`` and
            # the format argument of the message is truncated.
            raise ValueError("Invalid flagging strategy '%s'" %

        # Unique baselines and unique times are computed eagerly here; the
        # row -> unique-time inverse index ``time_inv`` stays lazy.
        ubl = unique_baselines(antenna1, antenna2)
        utime, time_inv = da.unique(ds.TIME.data, return_inverse=True)
        utime, ubl = dask.compute(utime, ubl)
        ubl = ubl.view(np.int32).reshape(-1, 2)
        # Stack the baseline index with the unique baselines
        bl_range = np.arange(ubl.shape[0], dtype=ubl.dtype)[:, None]
        ubl = np.concatenate([bl_range, ubl], axis=1)
        ubl = da.from_array(ubl, chunks=(args.baseline_chunks, 3))

        # NOTE(review): the ``pack_data(`` argument list is truncated, and the
        # ``window_stats(`` call after it has lost both its leading statement
        # (presumably an append to ``original_stats``) and its tail.
        vis_windows, flag_windows = pack_data(time_inv,

            window_stats(flag_windows, ubl, chan_freq, antsnames,
                         ds.SCAN_NUMBER, field_dict[ds.FIELD_ID],

        with StrategyExecutor(antspos, ubl, chan_freq, chan_width,
                              masked_channels, GD['strategies']) as se:

            flag_windows = se.apply_strategies(flag_windows, vis_windows)

            # NOTE(review): truncated in the same way as the first
            # ``window_stats(`` call (presumably feeds ``final_stats``).
            window_stats(flag_windows, ubl, chan_freq, antsnames,
                         ds.SCAN_NUMBER, field_dict[ds.FIELD_ID],

        # Unpack window data for writing back to the MS
        # NOTE(review): the ``unpack_data(`` argument list is truncated.
        unpacked_flags = unpack_data(antenna1, antenna2, time_inv, ubl,

        # Flag entire visibility if any correlations are flagged
        equalized_flags = da.sum(unpacked_flags, axis=2, keepdims=True) > 0
        corr_flags = da.broadcast_to(equalized_flags, (nrow, nchan, ncorr))

        if corr_flags.chunks != ds.FLAG.data.chunks:
            raise ValueError("Output flag chunking does not "
                             "match input flag chunking")

        # Create new dataset containing new flags
        new_ds = ds.assign(FLAG=(("row", "chan", "corr"), corr_flags))

        # Write back to original dataset
        writes = xds_to_table(new_ds, args.ms, "FLAG")
        # original should also have .compute called because we need stats
        # NOTE(review): ``writes`` is never added to ``write_computes`` in
        # this copy, so the branch below would never run as written.

    if len(write_computes) > 0:
        # Combine stats from all datasets
        original_stats = combine_window_stats(original_stats)
        final_stats = combine_window_stats(final_stats)

        with contextlib.ExitStack() as stack:
            # Create dask profiling contexts
            profilers = []

            # NOTE(review): body missing (presumably profilers are entered on
            # ``stack`` and collected in ``profilers``).
            if can_profile:

            # NOTE(review): the interactive-terminal progress bar and the
            # ``else:`` separating it from the non-interactive one are missing.
            if sys.stdout.isatty():
                # Interactive terminal, default ProgressBar
                # Non-interactive, emit a bar every 5 minutes so
                # as not to spam the log
                stack.enter_context(ProgressBar(minimum=1, dt=5 * 60))

            # Execute the writes and both sets of statistics in one pass
            _, original_stats, final_stats = dask.compute(
                write_computes, original_stats, final_stats)
        # NOTE(review): body missing.
        if can_profile:

        toc = time.time()

        # Log each summary line
        # NOTE(review): loop body missing.
        for line in summarise_stats(final_stats, original_stats):

        # Report the wall-clock time as hours / minutes / seconds
        elapsed = toc - tic
        log.info("Data flagged successfully in "
                 "{0:02.0f}h{1:02.0f}m{2:02.0f}s".format((elapsed // 60) // 60,
                                                         (elapsed // 60) % 60,
                                                         elapsed % 60))
        # NOTE(review): an ``else:`` (no datasets selected) appears to be
        # missing before this final message.
        log.info("User data selection criteria resulted in empty dataset. "
                 "Nothing to be done. Bye!")
Exemple #20
def _jones2col(**kw):
    """Apply net gains to a column (or to unity) and write the result.

    The net gains are read from ``<gain_table>::G``. The row and channel
    chunking are derived from the ``t_chunk`` / ``f_chunk`` arrays of the
    first gain dataset, the measurement set is opened with that chunking,
    and for every (gain, MS) dataset pair the gains are applied with
    ``corrupt_vis`` to ``acol`` (or to an array of ones when ``acol`` is
    ``None``). The product is written to ``mueller_column`` of the first MS.

    Keyword arguments are wrapped in an OmegaConf struct; the fields read
    here are ``ms``, ``gain_table``, ``acol``, ``compareto`` and
    ``mueller_column``. Only ``ms[0]`` is processed.

    Raises
    ------
    ValueError
        If the number of gain datasets differs from the number of MS
        datasets, or if the scan numbers of a pair do not match.

    Notes
    -----
    This copy of the function had lost its ``else:`` / ``try:`` lines, the
    tail of the ``da.ones`` call, the append to ``out_data`` and the final
    compute. They are restored here from the surrounding data flow; the
    ``dtype`` given to ``da.ones`` is an assumption (see inline note).
    """
    args = OmegaConf.create(kw)
    from omegaconf import ListConfig
    if not isinstance(args.ms, (list, ListConfig)):
        args.ms = [args.ms]
    OmegaConf.set_struct(args, True)

    import numpy as np
    from daskms.experimental.zarr import xds_from_zarr
    from daskms import xds_from_ms, xds_to_table
    import dask.array as da
    import dask
    from africanus.calibration.utils import chunkify_rows
    from africanus.calibration.utils.dask import corrupt_vis

    # get net gains
    G = xds_from_zarr(args.gain_table + '::G')

    # chunking info: unique times per chunk
    t_chunks = G[0].t_chunk.data
    if len(t_chunks) > 1:
        # differences between consecutive entries must all be equal
        t_chunks = G[0].t_chunk.data[1:-1] - G[0].t_chunk.data[0:-2]
        assert (t_chunks == t_chunks[0]).all()
        utpc = t_chunks[0]
    else:
        utpc = t_chunks[0]
    times = xds_from_ms(args.ms[0],
                        columns=['TIME'])[0].get('TIME').data.compute()
    row_chunks, tbin_idx, tbin_counts = chunkify_rows(
        times, utimes_per_chunk=utpc, daskify_idx=True)

    # chunking info: channels per chunk (-1 means a single chunk)
    f_chunks = G[0].f_chunk.data
    if len(f_chunks) > 1:
        f_chunks = G[0].f_chunk.data[1:-1] - G[0].f_chunk.data[0:-2]
        assert (f_chunks == f_chunks[0]).all()
        chan_chunks = f_chunks[0]
    else:
        if f_chunks[0]:
            chan_chunks = f_chunks[0]
        else:
            chan_chunks = -1

    columns = ('DATA', 'FLAG', 'FLAG_ROW', 'ANTENNA1', 'ANTENNA2')
    if args.acol is not None:
        columns += (args.acol,)

    # open MS; ``columns`` was built above but not passed in the damaged
    # copy, which would have loaded every column
    xds = xds_from_ms(args.ms[0],
                      chunks={'row': row_chunks, 'chan': chan_chunks},
                      columns=columns,
                      group_cols=('FIELD_ID', 'DATA_DESC_ID', 'SCAN_NUMBER'))

    # Current hack probably only works for single field and DDID
    if len(xds) != len(G):
        raise ValueError("Number of datasets in gains do not "
                         "match those in MS")

    # assuming scans are aligned
    out_data = []
    for g, ds in zip(G, xds):
        try:
            aligned = g.SCAN_NUMBER == ds.SCAN_NUMBER
        except Exception as e:
            # e.g. a dataset without a SCAN_NUMBER attribute
            raise ValueError("Scans not aligned") from e
        if not aligned:
            raise ValueError("Scans not aligned")

        nrow = ds.dims['row']
        nchan = ds.dims['chan']
        ncorr = ds.dims['corr']

        # need to swap axes for africanus
        jones = da.swapaxes(g.gains.data, 1, 2)
        flag = ds.FLAG.data
        frow = ds.FLAG_ROW.data
        ant1 = ds.ANTENNA1.data
        ant2 = ds.ANTENNA2.data

        # flag auto-correlations, then collapse the flags over the first and
        # last correlation and fold in the row flags
        frow = (frow | (ant1 == ant2))
        flag = (flag[:, :, 0] | flag[:, :, -1])
        flag = da.logical_or(flag, frow[:, None])

        if args.acol is not None:
            acol = ds.get(args.acol).data.reshape(nrow, nchan, 1, ncorr)
        else:
            # NOTE(review): the last argument of this call was lost; using
            # the gain dtype is an assumption -- confirm against upstream.
            acol = da.ones((nrow, nchan, 1, ncorr),
                           chunks=(row_chunks, chan_chunks, 1, -1),
                           dtype=jones.dtype)

        cvis = corrupt_vis(tbin_idx, tbin_counts, ant1, ant2, jones, acol)

        # compare where unflagged
        if args.compareto is not None:
            flag = flag.compute()
            vis = ds.get(args.compareto).values[~flag]
            print("Max abs difference = ",
                  np.abs(cvis.compute()[~flag] - vis).max())

        out_ds = ds.assign(
            **{args.mueller_column: (("row", "chan", "corr"), cvis)})
        out_data.append(out_ds)

    # xds_to_table only builds the write graph; it must be computed
    writes = xds_to_table(out_data, args.ms[0], columns=[args.mueller_column])
    dask.compute(writes)
Exemple #21
def new(ms, sky_model, gains, **kwargs):
    """Generate model visibilties per source (as direction axis)
    for stokes I and Q and generate relevant visibilities."""
    # Inputs as used below: ``ms`` is a measurement-set path (subtables are
    # addressed as ``ms + '::NAME'``), ``sky_model`` is loaded with
    # ``Tigger.load`` and ``gains`` is a file read with ``np.load`` holding an
    # array unpacked as (n_time, n_ant, n_chan, n_dir, n_corr).
    #
    # NOTE(review): this copy of the function has lost lines (missing
    # ``else:`` / ``try:`` headers, truncated calls), so it does not parse
    # as-is. Gaps are marked with NOTE(review) comments below.

    # Options to attributed dictionary
    # NOTE(review): an ``else:`` appears to be missing before the second
    # assignment, which otherwise always overrides the YAML options.
    if kwargs["yaml"] is not None:
        options = ocf.load(kwargs["yaml"])
        options = ocf.create(kwargs)

    # Set to struct
    ocf.set_struct(options, True)

    # Change path to sky model if chosen
    # NOTE(review): the header of this indented block (and the branch that
    # goes with the "Own sky model reference" comment) is missing.
        sky_model = sky_models[sky_model.lower()]
        # Own sky model reference

    # Set thread count to cpu count
    # NOTE(review): lines are missing in this block; as written the imports
    # are unused and a truthy ``options.ncpu`` is overwritten by the CPU count.
    if options.ncpu:
        from multiprocessing.pool import ThreadPool
        import dask
        import multiprocessing
        options.ncpu = multiprocessing.cpu_count()

    # Load gains to corrupt with
    with open(gains, "rb") as file:
        jones = np.load(file)

    # Load dimensions
    n_time, n_ant, n_chan, n_dir, n_corr = jones.shape
    # Number of rows assuming every cross-correlation baseline at every time
    n_row = n_time * (n_ant * (n_ant - 1) // 2)

    # Load ms
    MS = xds_from_ms(ms)[0]

    # Get time-bin indices and counts
    row_chunks, tbin_indices, tbin_counts = chunkify_rows(
        MS.TIME, options.utime)

    # Close and reopen with chunked rows
    MS = xds_from_ms(ms, chunks={"row": row_chunks})[0]

    # Get antenna arrays (dask ignored for now)
    ant1 = MS.ANTENNA1.data
    ant2 = MS.ANTENNA2.data

    # Adjust UVW based on phase-convention
    # NOTE(review): an ``else:`` is missing before the ``raise``.
    if options.phase_convention.upper() == 'CASA':
        uvw = -MS.UVW.data.astype(np.float64)
    elif options.phase_convention.upper() == 'CODEX':
        uvw = MS.UVW.data.astype(np.float64)
        raise ValueError("Unknown sign convention for phase.")

    # MS dimensions
    dims = ocf.create(dict(MS.sizes))

    # Close MS
    # NOTE(review): the statement belonging to this comment is missing.

    # Build source model from lsm
    lsm = Tigger.load(sky_model)

    # Check if dimensions match jones
    assert n_time * (n_ant * (n_ant - 1) // 2) == dims.row
    assert n_time == len(tbin_indices)
    assert n_ant == np.max((np.max(ant1), np.max(ant2))) + 1
    assert n_chan == dims.chan
    assert n_corr == dims.corr

    # If gains are DIE
    # NOTE(review): an ``else:`` appears to be missing before the last
    # assert, which is trivially true right after the assignment above it.
    if options.die:
        assert n_dir == 1
        n_dir = len(lsm.sources)
        assert n_dir == len(lsm.sources)

    # Get phase direction
    radec0_table = xds_from_table(ms + '::FIELD')[0]
    radec0 = radec0_table.PHASE_DIR.data.squeeze().compute()

    # Get frequency column
    freq_table = xds_from_table(ms + '::SPECTRAL_WINDOW')[0]
    freq = freq_table.CHAN_FREQ.data.astype(np.float64)[0]

    # Get feed orientation
    feed_table = xds_from_table(ms + '::FEED')[0]
    feeds = feed_table.POLARIZATION_TYPE.data[0].compute()

    # Create initial model array
    # Last axis is filled below as Stokes I, Q, U, V at indices 0..3
    model = np.zeros((n_dir, n_chan, n_corr), dtype=np.float64)

    # Create initial coordinate array and source names
    lm = np.zeros((n_dir, 2), dtype=np.float64)
    source_names = []

    # Cycle coordinates creating a source with flux
    print("==> Building model visibilities")
    for d, source in enumerate(lsm.sources):
        # Extract name
        # NOTE(review): the statement for this comment is missing, so
        # ``source_names`` is never filled here although it is indexed later.

        # Extract position
        radec_s = np.array([[source.pos.ra, source.pos.dec]])
        lm[d] = radec_to_lm(radec_s, radec0)

        # Each Stokes term below is scaled by a power law in frequency,
        # flux * (freq / ref_freq)**spi, falling back to spi = 0.0 and
        # ref_freq = 1.0 when the source has no spectrum.

        # Get flux - Stokes I
        if source.flux.I:
            I0 = source.flux.I

            # Get spectrum (only spi currently supported)
            tmp_spec = source.spectrum
            spi = [tmp_spec.spi if tmp_spec is not None else 0.0]
            ref_freq = [tmp_spec.freq0 if tmp_spec is not None else 1.0]

            # Generate model flux
            model[d, :, 0] = I0 * (freq / ref_freq)**spi

        # Get flux - Stokes Q
        if source.flux.Q:
            Q0 = source.flux.Q

            # Get spectrum
            tmp_spec = source.spectrum
            spi = [tmp_spec.spi if tmp_spec is not None else 0.0]
            ref_freq = [tmp_spec.freq0 if tmp_spec is not None else 1.0]

            # Generate model flux
            model[d, :, 1] = Q0 * (freq / ref_freq)**spi

        # Get flux - Stokes U
        if source.flux.U:
            U0 = source.flux.U

            # Get spectrum
            tmp_spec = source.spectrum
            spi = [tmp_spec.spi if tmp_spec is not None else 0.0]
            ref_freq = [tmp_spec.freq0 if tmp_spec is not None else 1.0]

            # Generate model flux
            model[d, :, 2] = U0 * (freq / ref_freq)**spi

        # Get flux - Stokes V
        if source.flux.V:
            V0 = source.flux.V

            # Get spectrum
            tmp_spec = source.spectrum
            spi = [tmp_spec.spi if tmp_spec is not None else 0.0]
            ref_freq = [tmp_spec.freq0 if tmp_spec is not None else 1.0]

            # Generate model flux
            model[d, :, 3] = V0 * (freq / ref_freq)**spi

    # Close sky-model
    del lsm

    # Build dask graph
    tbin_indices = da.from_array(tbin_indices, chunks=(options.utime))
    tbin_counts = da.from_array(tbin_counts, chunks=(options.utime))
    lm = da.from_array(lm, chunks=lm.shape)
    model = da.from_array(model, chunks=model.shape)
    jones = da.from_array(jones, chunks=(options.utime, ) + jones.shape[1::])

    # Apply image to visibility for each source
    # NOTE(review): the ``im_to_vis(`` call is truncated and nothing is
    # appended to ``sources`` in this copy.
    sources = []
    for s in range(n_dir):
        source_vis = im_to_vis(model[s].reshape((1, n_chan, n_corr)),
                               lm[s].reshape((1, 2)),

    # Stack the per-source visibilities along a new direction axis (axis 2)
    model_vis = da.stack(sources, axis=2)

    # Sum over direction?
    if options.die:
        model_vis = da.sum(model_vis, axis=2, keepdims=True)
        n_dir = 1
        source_names = [options.mname]

    # Select schema based on feed orientation
    # NOTE(review): an ``else:`` is missing before the ``raise``.
    if (feeds == ["X", "Y"]).all():
        out_schema = [["XX", "XY"], ["YX", "YY"]]
    elif (feeds == ["R", "L"]).all():
        out_schema = [['RR', 'RL'], ['LR', 'LL']]
        raise ValueError("Unknown feed orientation implementation.")

    # Convert Stokes to Correlations
    in_schema = ['I', 'Q', 'U', 'V']
    model_vis = convert(model_vis, in_schema, out_schema).reshape(
        (n_row, n_chan, n_dir, n_corr))

    # Apply gains to model_vis
    print("==> Corrupting visibilities")

    data = corrupt_vis(tbin_indices, tbin_counts, ant1, ant2, jones, model_vis)

    # Reopen MS
    MS = xds_from_ms(ms, chunks={"row": row_chunks})[0]

    # Assign model visibilities
    # NOTE(review): the ``MS.assign(`` call below has lost its ``**{``
    # opener and its closing brackets.
    out_names = []
    for d in range(n_dir):
        MS = MS.assign(
                source_names[d]: (("row", "chan", "corr"),
                                  model_vis[:, :, d].astype(np.complex64))

        out_names += [source_names[d]]

    # Assign noise free visibilities to 'CLEAN_DATA'
    # NOTE(review): truncated call; the assigned array is missing.
    MS = MS.assign(
            'CLEAN_' + options.dname: (("row", "chan", "corr"),

    out_names += ['CLEAN_' + options.dname]

    # Get noise realisation
    if options.std > 0.0:

        # Noise matrix
        # Two independent complex noise planes are drawn here; together with
        # the two zero planes inserted below they form the correlation axis.
        # NOTE(review): no ``scale=`` is passed to ``da.random.normal`` in
        # this copy, so ``options.std`` is only used in the message --
        # possibly a lost line; confirm against upstream.
        print(f"==> Applying noise (std={options.std}) to visibilities")
        noise = []
        for i in range(2):
            real = da.random.normal(loc=0.0,
                                    size=(n_row, n_chan),
                                    chunks=(row_chunks, n_chan))
            imag = 1.0j * (da.random.normal(loc=0.0,
                                            size=(n_row, n_chan),
                                            chunks=(row_chunks, n_chan)))
            noise.append(real + imag)

        # Zero matrix for off-diagonals
        zero = da.zeros((n_row, n_chan), chunks=(row_chunks, n_chan))

        # Resulting order along the last axis: noise, zero, zero, noise
        noise.insert(1, zero)
        noise.insert(2, zero)

        # NP to Dask
        noise = da.stack(noise, axis=2).rechunk((row_chunks, n_chan, n_corr))

        # Assign noise to 'NOISE'
        MS = MS.assign(
            **{'NOISE': (("row", "chan", "corr"), noise.astype(np.complex64))})

        out_names += ['NOISE']

        # Add noise to data and assign to 'DATA'
        noisy_data = data + noise

        # NOTE(review): truncated call; the assigned array is missing.
        MS = MS.assign(
                options.dname: (("row", "chan", "corr"),

        out_names += [options.dname]

    # Create a write to the table
    write = xds_to_table(MS, ms, out_names)

    # Submit all graph computations in parallel
    print(f"==> Executing `dask-ms` write to `{ms}` for the following columns: "\
            + f"{', '.join(out_names)}")

    # NOTE(review): the body of this ``with`` (presumably the compute of
    # ``write``) is missing.
    with ProgressBar():

    print(f"==> Completed.")
Exemple #22
def read_ms(ms, num_vis, res_arcmin, chunks=50000, channel=0):
    """Load a random sample of unflagged visibilities from a measurement set.

    Uses dask-ms to load the necessary data to create a telescope operator
    (will use uvw positions, and antenna positions).

    Parameters
    ----------
    ms : str
        Measurement set path; subtables are addressed as ``ms::NAME``.
    num_vis : int
        Maximum number of visibilities to draw.
    res_arcmin :
        Used to calculate the maximum baselines to consider.
        We want two pixels per smallest fringe::

            pix_res > fringe / 2
            u sin(theta) = n (for nth fringe)
            at small angles: theta = 1/u, or u_max = 1 / theta
            d sin(theta) = lambda / 2
            d / lambda = 1 / (2 sin(theta))
            u_max = lambda / 2sin(theta)
    chunks : int
        Row chunk size passed to ``xds_from_ms``.
    channel : int
        Channel index used to pick the frequency, flags and visibilities.

    Returns
    -------
    tuple
        ``(u_arr, v_arr, w_arr, frequency, cv_vis, hdr, timestamp)`` where
        the first three and ``cv_vis`` are the sampled uvw components and
        visibilities, ``hdr`` is a dict of FITS-style header entries and
        ``timestamp`` is a timezone-aware ``datetime``.

    Notes
    -----
    * Each loop below overwrites its variables, so when a table yields
      several datasets only the last one is used after the loop.
    * As a side effect the STATE_ID column of every dataset is written back
      to the measurement set.
    * Sampling uses ``np.random.choice`` with its default of drawing with
      replacement.
    * This copy of the function had lost several lines. Only what syntax and
      data flow require has been restored; the statement that followed the
      "Profile" comment could not be recovered and is left out.
    """
    with scheduler_context():
        # Create a dataset representing the entire antenna table
        ant_table = '::'.join((ms, 'ANTENNA'))

        for ant_ds in xds_from_table(ant_table):
            ant_p = np.array(ant_ds.POSITION.data)
        logger.info("Antenna Positions {}".format(ant_p.shape))

        # Create a dataset representing the field
        field_table = '::'.join((ms, 'FIELD'))
        for field_ds in xds_from_table(field_table):
            phase_dir = np.array(field_ds.PHASE_DIR.data)[0].flatten()
        logger.info("Phase Dir {}".format(np.degrees(phase_dir)))

        # Create datasets representing each row of the spw table
        spw_table = '::'.join((ms, 'SPECTRAL_WINDOW'))

        for spw_ds in xds_from_table(spw_table, group_cols="__row__"):
            # NOTE(review): the argument of this log call was lost; logging
            # the shape is an assumption -- confirm against upstream.
            logger.info("CHAN_FREQ.values: {}".format(
                spw_ds.CHAN_FREQ.values.shape))
            frequencies = dask.compute(spw_ds.CHAN_FREQ.values)[0].flatten()
            frequency = frequencies[channel]
            logger.info("Frequencies = {}".format(frequencies))
            logger.info("Frequency = {}".format(frequency))
            logger.info("NUM_CHAN = %f" % np.array(spw_ds.NUM_CHAN.values)[0])

        # Create datasets from a partioning of the MS
        datasets = list(xds_from_ms(ms, chunks={'row': chunks}))

        pol = 0

        for ds in datasets:
            logger.info("DATA shape: {}".format(ds.DATA.data.shape))
            logger.info("UVW shape: {}".format(ds.UVW.data.shape))

            uvw = np.array(ds.UVW.data)  # UVW is stored in meters!
            ant1 = np.array(ds.ANTENNA1.data)
            ant2 = np.array(ds.ANTENNA2.data)
            flags = np.array(ds.FLAG.data)
            cv_vis = np.array(ds.DATA.data)[:, channel, pol]
            epoch_seconds = np.array(ds.TIME.data)[0]

            # Try write the STATE_ID column back
            write = xds_to_table(ds, ms, 'STATE_ID')
            with ProgressBar(), Profiler() as prof:
                write.compute()

            # Profile
            # NOTE(review): the profiler report call that followed this
            # comment is missing from this copy; restore it if needed.

        u_max = get_resolution_max_baseline(res_arcmin, frequency)

        logger.info("Resolution Max UVW: {:g}".format(u_max))
        logger.info("Flags: {}".format(flags.shape))

        # Now report the recommended resolution from the data.
        # 1.0 / 2*np.sin(theta) = limit_u
        limit_uvw = np.max(np.abs(uvw), 0)
        res_limit = get_baseline_resolution(limit_uvw[0], frequency)
        logger.info("Nyquist resolution: {:g} arcmin".format(
            np.degrees(res_limit) * 60.0))

        # Keep rows that are unflagged for this channel/polarisation and
        # whose largest |u|, |v|, |w| component lies below u_max. (The
        # original guarded an unfiltered variant behind ``if False:``; that
        # dead branch is dropped.)
        good_data = np.array(
            np.where((flags[:, channel, pol] == 0)
                     & (np.max(np.abs(uvw), 1) < u_max))).T.reshape((-1, ))
        logger.info("Good Data {}".format(good_data.shape))

        logger.info("Maximum UVW: {}".format(limit_uvw))
        logger.info("Minimum UVW: {}".format(np.min(np.abs(uvw), 0)))

        good_vis = cv_vis[good_data]

        n_max = len(good_vis)

        indices = np.random.choice(good_data, min(num_vis, n_max))

        hdr = {
            'CTYPE1': ('RA---SIN', "Right ascension angle cosine"),
            'CRVAL1': np.degrees(phase_dir)[0],
            'CUNIT1': 'deg     ',
            'CTYPE2': ('DEC--SIN', "Declination angle cosine "),
            'CRVAL2': np.degrees(phase_dir)[1],
            'CUNIT2': 'deg     ',
            'CTYPE3': 'FREQ    ',  # Central frequency
            'CRPIX3': 1.,
            'CRVAL3': "{}".format(frequency),
            'CDELT3': 10026896.158854,
            'CUNIT3': 'Hz      ',
            'EQUINOX': '2000.',
            'DATE-OBS': "{}".format(epoch_seconds),
            'BTYPE': 'Intensity'
        }

        u_arr = uvw[indices, 0]
        v_arr = uvw[indices, 1]
        w_arr = uvw[indices, 2]

        cv_vis = cv_vis[indices]

        # Convert from Modified Julian Date to timestamp: 1858-11-17 00:00
        # UTC is the MJD epoch.
        # NOTE(review): the timedelta argument was lost; seconds is assumed
        # because the MS TIME column is specified in seconds -- confirm.
        timestamp = datetime.datetime(
            1858, 11, 17, 0, 0, 0,
            tzinfo=datetime.timezone.utc) + datetime.timedelta(
                seconds=epoch_seconds)

        return u_arr, v_arr, w_arr, frequency, cv_vis, hdr, timestamp
Exemple #23
    def compute_weights(self, robust):
        # Two passes over ``self.ms``: the first builds per-dataset counts
        # with ``compute_counts`` and accumulates them per band, the second
        # turns the accumulated counts into per-row weights with
        # ``counts_to_weights`` and assigns them to
        # ``self.imaging_weight_column``.
        #
        # NOTE(review): this copy of the method has lost lines (truncated
        # calls, ``if`` without body) and stops in the middle of the final
        # ``ds.assign(`` call, so it does not parse as-is. ``robust`` is not
        # referenced in the surviving lines.
        from pfb.utils.weighting import compute_counts, counts_to_weights
        # compute counts
        counts = []
        for ims in self.ms:
            # NOTE(review): the tail of this ``xds_from_ms(`` call is missing.
            xds = xds_from_ms(ims,
                              group_cols=('FIELD_ID', 'DATA_DESC_ID'),

            # subtables
            # NOTE(review): the ``spws`` call below is missing its tail.
            ddids = xds_from_table(ims + "::DATA_DESCRIPTION")
            fields = xds_from_table(ims + "::FIELD", group_cols="__row__")
            spws = xds_from_table(ims + "::SPECTRAL_WINDOW",
            pols = xds_from_table(ims + "::POLARIZATION", group_cols="__row__")

            # subtable data
            ddids = dask.compute(ddids)[0]
            fields = dask.compute(fields)[0]
            spws = dask.compute(spws)[0]
            pols = dask.compute(pols)[0]

            for ds in xds:
                field = fields[ds.FIELD_ID]
                radec = field.PHASE_DIR.data.squeeze()
                # NOTE(review): body missing (presumably datasets whose phase
                # direction differs from ``self.radec`` are skipped).
                if not np.array_equal(radec, self.radec):

                spw = ds.DATA_DESC_ID  # not optimal, need to use spw

                freq_bin_idx = self.freq_bin_idx[ims][spw]
                freq_bin_counts = self.freq_bin_counts[ims][spw]
                freq = self.freq[ims][spw]

                uvw = ds.UVW.data

                count = compute_counts(uvw, freq, freq_bin_idx,
                                       freq_bin_counts, self.nx, self.ny,
                                       self.cell, self.cell, np.float32)
                # NOTE(review): ``count`` is never added to ``counts`` in
                # this copy.


        counts = dask.compute(counts)[0]

        counts = accumulate_dirty(counts, self.nband, self.band_mapping)

        counts = da.from_array(counts, chunks=(1, -1, -1))

        # convert counts to weights
        writes = []
        for ims in self.ms:
            # NOTE(review): same truncations as in the first pass.
            xds = xds_from_ms(ims,
                              group_cols=('FIELD_ID', 'DATA_DESC_ID'),

            # subtables
            ddids = xds_from_table(ims + "::DATA_DESCRIPTION")
            fields = xds_from_table(ims + "::FIELD", group_cols="__row__")
            spws = xds_from_table(ims + "::SPECTRAL_WINDOW",
            pols = xds_from_table(ims + "::POLARIZATION", group_cols="__row__")

            # subtable data
            ddids = dask.compute(ddids)[0]
            fields = dask.compute(fields)[0]
            spws = dask.compute(spws)[0]
            pols = dask.compute(pols)[0]

            out_data = []
            for ds in xds:
                field = fields[ds.FIELD_ID]
                radec = field.PHASE_DIR.data.squeeze()
                # NOTE(review): body missing, as in the first pass.
                if not np.array_equal(radec, self.radec):

                spw = ds.DATA_DESC_ID  # this is not correct, need to use spw

                freq_bin_idx = self.freq_bin_idx[ims][spw]
                freq_bin_counts = self.freq_bin_counts[ims][spw]
                freq = self.freq[ims][spw]

                uvw = ds.UVW.data

                # NOTE(review): the last argument(s) of this call are missing.
                weights = counts_to_weights(counts, uvw, freq, freq_bin_idx,
                                            freq_bin_counts, self.nx, self.ny,
                                            self.cell, self.cell, np.float32,

                # hack to get shape and chunking info
                data = getattr(ds, self.data_column).data

                # NOTE(review): the ``broadcast_to(`` call is truncated and
                # the method ends inside the ``ds.assign(`` call below; the
                # remainder (append to ``out_data`` / ``writes``, compute) is
                # not present in this copy.
                weights = da.broadcast_to(weights[:, :, None],
                out_ds = ds.assign(**{
                    self.imaging_weight_column: (("row", "chan", "corr"),
Exemple #24
def _predict(args):
    """Predict visibilities from a WSClean component list into the MS.

    Imports the sky model from ``args.sky_model``, budgets the row and
    source chunking, and for every selected (FIELD_ID, DATA_DESC_ID) dataset
    predicts visibilities with ``wsclean_predict``, expands them to the
    dataset's correlations and schedules a write to ``args.output_column``.
    All writes are then executed together under a progress bar.

    Parameters
    ----------
    args :
        Parsed application arguments. Fields read here: ``within``,
        ``sky_model``, ``num_sources``, ``ms``, ``field`` and
        ``output_column``; ``row_chunks`` and ``model_chunks`` are set by
        this function.

    Notes
    -----
    This copy of the function had lost several lines. Restored here: the
    ``support_tables`` arguments, the append to ``writes``, the ``else:``
    between the two progress bars and the final compute. The
    ``support_tables`` arguments are an assumption based on the four keys
    read from its result (see inline note).
    """
    # Local imports: ``dask`` is needed for the final compute and is not
    # referenced elsewhere in the surviving lines of this block.
    import dask
    import pkg_resources
    version = pkg_resources.get_distribution("crystalball").version
    log.info("Crystalball version {0}", version)

    # get inclusion regions
    # NOTE(review): ``include_regions`` is not used in the surviving lines;
    # it was presumably passed to ``import_from_wsclean`` -- confirm.
    include_regions = load_regions(args.within) if args.within else []

    # Import source data from WSClean component list
    # See https://sourceforge.net/p/wsclean/wiki/ComponentList
    source_model = import_from_wsclean(args.sky_model,
                                       num=args.num_sources or None)

    # Add output column if it isn't present
    ms_rows, ms_datatype = ms_preprocess(args)

    # Get the support tables
    # NOTE(review): the argument list was lost; reconstructed from the keys
    # used below -- verify against the ``support_tables`` signature.
    tables = support_tables(
        args, ["FIELD", "DATA_DESCRIPTION", "SPECTRAL_WINDOW", "POLARIZATION"])

    field_ds = tables["FIELD"]
    ddid_ds = tables["DATA_DESCRIPTION"]
    spw_ds = tables["SPECTRAL_WINDOW"]
    pol_ds = tables["POLARIZATION"]

    max_num_chan = max(ss.NUM_CHAN.data[0] for ss in spw_ds)
    max_num_corr = max(ss.NUM_CORR.data[0] for ss in pol_ds)

    # Perform resource budgeting
    nsources = source_model.source_type.shape[0]
    args.row_chunks, args.model_chunks = get_budget(nsources, ms_rows,
                                                    max_num_chan, max_num_corr,
                                                    ms_datatype, args)

    source_model = source_model_to_dask(source_model, args.model_chunks)

    # List of write operations
    writes = []

    datasets = xds_from_ms(args.ms,
                           columns=["UVW", "ANTENNA1", "ANTENNA2", "TIME"],
                           group_cols=["FIELD_ID", "DATA_DESC_ID"],
                           chunks={"row": args.row_chunks})

    field_id = select_field_id(field_ds, args.field)

    for xds in filter_datasets(datasets, field_id):
        # Extract frequencies from the spectral window associated
        # with this data descriptor id
        field = field_ds[xds.attrs['FIELD_ID']]
        ddid = ddid_ds[xds.attrs['DATA_DESC_ID']]
        spw = spw_ds[ddid.SPECTRAL_WINDOW_ID.data[0]]
        pol = pol_ds[ddid.POLARIZATION_ID.data[0]]
        frequency = spw.CHAN_FREQ.data[0]

        lm = radec_to_lm(source_model.radec, field.PHASE_DIR.data[0][0])

        with warnings.catch_warnings():
            # Ignore dask chunk warnings emitted when going from 1D
            # inputs to a 2D space of chunks
            warnings.simplefilter('ignore', category=PerformanceWarning)
            vis = wsclean_predict(xds.UVW.data, lm, source_model.source_type,
                                  source_model.flux, source_model.spi,
                                  source_model.log_poly, source_model.ref_freq,
                                  source_model.gauss_shape, frequency)

        vis = fill_correlations(vis, pol)

        log.info('Field {0} DDID {1:d} rows {2} chans {3} corrs {4}',
                 field.NAME.values[0], xds.DATA_DESC_ID, vis.shape[0],
                 vis.shape[1], vis.shape[2])

        # Assign visibilities to MODEL_DATA array on the dataset
        xds = xds.assign(
            **{args.output_column: (("row", "chan", "corr"), vis)})
        # Create a write to the table
        write = xds_to_table(xds, args.ms, [args.output_column])
        # Add to the list of writes
        writes.append(write)

    with ExitStack() as stack:
        if sys.stdout.isatty():
            # Default progress bar in user terminal
            stack.enter_context(EstimatingProgressBar())
        else:
            # Log progress every 5 minutes
            stack.enter_context(EstimatingProgressBar(minimum=2 * 60, dt=5))

        # Submit all graph computations in parallel
        dask.compute(writes)
Exemple #25
def test_ms_create(Dataset, tmp_path, chunks, num_chans, corr_types, sources):
    """ Test creation of a Measurement Set and its sub-tables from scratch.

    ANTENNA, SOURCE, POLARIZATION, SPECTRAL_WINDOW and DATA_DESCRIPTION
    datasets are built from the arguments, one main table dataset is built
    per DATA_DESCRIPTION row, everything is written with ``xds_to_table``
    and the tables are read back with ``pt.table`` for comparison.

    Parameters
    ----------
    Dataset :
        Callable building a dataset from a ``{column: (dims, array)}`` dict.
    tmp_path :
        Directory under which ``create.ms`` is written.
    chunks :
        Mapping whose ``'row'`` entry holds the main table row chunk sizes.
    num_chans :
        Number of channels of each SPECTRAL_WINDOW row.
    corr_types :
        Sequence of correlation types for each POLARIZATION row.
    sources :
        ``(name, direction, rest_freq)`` tuple for each SOURCE row.
    """
    # NOTE(review): the Dataset closers, list appends and assertion operands
    # in this test were reconstructed from how the surrounding code uses
    # them -- confirm against the upstream version of the test.

    # Set up
    rs = np.random.RandomState(42)

    ms_path = tmp_path / "create.ms"

    ms_table_name = str(ms_path)
    ant_table_name = "::".join((ms_table_name, "ANTENNA"))
    ddid_table_name = "::".join((ms_table_name, "DATA_DESCRIPTION"))
    pol_table_name = "::".join((ms_table_name, "POLARIZATION"))
    spw_table_name = "::".join((ms_table_name, "SPECTRAL_WINDOW"))
    # SOURCE is an optional MS sub-table
    src_table_name = "::".join((ms_table_name, "SOURCE"))

    ms_datasets = []
    ant_datasets = []
    ddid_datasets = []
    pol_datasets = []
    spw_datasets = []
    src_datasets = []

    # For comparison
    all_data_desc_id = []
    all_data = []

    # Create ANTENNA dataset of 64 antennas
    # Each column in the ANTENNA has a fixed shape so we
    # can represent all rows with one dataset
    na = 64
    position = da.random.random((na, 3)) * 10000
    offset = da.random.random((na, 3))
    # The ``np.object`` alias was removed in NumPy 1.24,
    # the builtin ``object`` is the equivalent dtype
    names = np.array(['ANTENNA-%d' % i for i in range(na)], dtype=object)
    ds = Dataset({
        'POSITION': (("row", "xyz"), position),
        'OFFSET': (("row", "xyz"), offset),
        'NAME': (("row", ), da.from_array(names, chunks=na)),
    })
    ant_datasets.append(ds)

    # Create SOURCE datasets
    for name, direction, rest_freq in sources:
        dask_num_lines = da.full((1, ), len(rest_freq), dtype=np.int32)
        dask_direction = da.asarray(direction)[None, :]
        dask_rest_freq = da.asarray(rest_freq)[None, :]
        dask_name = da.asarray(np.asarray([name], dtype=object))
        ds = Dataset({
            "NUM_LINES": (("row", ), dask_num_lines),
            "NAME": (("row", ), dask_name),
            "REST_FREQUENCY": (("row", "line"), dask_rest_freq),
            "DIRECTION": (("row", "dir"), dask_direction),
        })
        src_datasets.append(ds)

    # Create POLARISATION datasets.
    # Dataset per output row required because column shapes are variable
    for corr_type in corr_types:
        dask_num_corr = da.full((1, ), len(corr_type), dtype=np.int32)
        dask_corr_type = da.from_array(corr_type,
                                       chunks=len(corr_type))[None, :]
        ds = Dataset({
            "NUM_CORR": (("row", ), dask_num_corr),
            "CORR_TYPE": (("row", "corr"), dask_corr_type),
        })
        pol_datasets.append(ds)

    # Create multiple MeerKAT L-band SPECTRAL_WINDOW datasets
    # Dataset per output row required because column shapes are variable
    for num_chan in num_chans:
        dask_num_chan = da.full((1, ), num_chan, dtype=np.int32)
        # num_chan samples, matching the np.linspace used in the check below
        dask_chan_freq = da.linspace(.856e9,
                                     2 * .856e9,
                                     num_chan,
                                     chunks=num_chan)[None, :]
        dask_chan_width = da.full((1, num_chan), .856e9 / num_chan)

        ds = Dataset({
            "NUM_CHAN": (("row", ), dask_num_chan),
            "CHAN_FREQ": (("row", "chan"), dask_chan_freq),
            "CHAN_WIDTH": (("row", "chan"), dask_chan_width),
        })
        spw_datasets.append(ds)

    # For each cartesian product of SPECTRAL_WINDOW and POLARIZATION
    # create a corresponding DATA_DESCRIPTION.
    # Each column has fixed shape so we handle all rows at once
    spw_ids, pol_ids = zip(
        *product(range(len(num_chans)), range(len(corr_types))))
    dask_spw_ids = da.asarray(np.asarray(spw_ids, dtype=np.int32))
    dask_pol_ids = da.asarray(np.asarray(pol_ids, dtype=np.int32))
    ddid_datasets.append(Dataset({
        "SPECTRAL_WINDOW_ID": (("row", ), dask_spw_ids),
        "POLARIZATION_ID": (("row", ), dask_pol_ids),
    }))

    # Now create the associated MS dataset
    for ddid, (spw_id, pol_id) in enumerate(zip(spw_ids, pol_ids)):
        # Infer row, chan and correlation shape
        row = sum(chunks['row'])
        chan = spw_datasets[spw_id].CHAN_FREQ.shape[1]
        corr = pol_datasets[pol_id].CORR_TYPE.shape[1]

        # Create some dask vis data
        dims = ("row", "chan", "corr")
        np_data = (rs.normal(size=(row, chan, corr)) +
                   1j * rs.normal(size=(row, chan, corr))).astype(np.complex64)

        data_chunks = (chunks['row'], chan, corr)
        dask_data = da.from_array(np_data, chunks=data_chunks)
        # Create dask ddid column
        dask_ddid = da.full(row, ddid, chunks=chunks['row'], dtype=np.int32)
        dataset = Dataset({
            'DATA': (dims, dask_data),
            'DATA_DESC_ID': (("row", ), dask_ddid)
        })
        ms_datasets.append(dataset)
        # Keep the expected values for the read-back checks
        all_data.append(dask_data)
        all_data_desc_id.append(dask_ddid)

    ms_writes = xds_to_table(ms_datasets, ms_table_name, columns="ALL")
    ant_writes = xds_to_table(ant_datasets, ant_table_name, columns="ALL")
    pol_writes = xds_to_table(pol_datasets, pol_table_name, columns="ALL")
    spw_writes = xds_to_table(spw_datasets, spw_table_name, columns="ALL")
    ddid_writes = xds_to_table(ddid_datasets, ddid_table_name, columns="ALL")
    source_writes = xds_to_table(src_datasets, src_table_name, columns="ALL")

    dask.compute(ms_writes, ant_writes, pol_writes, spw_writes, ddid_writes,
                 source_writes)

    # Check ANTENNA table correctly created
    with pt.table(ant_table_name, ack=False) as A:
        assert_array_equal(A.getcol("NAME"), names)
        assert_array_equal(A.getcol("POSITION"), position)
        assert_array_equal(A.getcol("OFFSET"), offset)

        required_desc = pt.required_ms_desc("ANTENNA")
        required_columns = {k for k in required_desc.keys()
                            if not k.startswith("_")}

        assert set(A.colnames()) == set(required_columns)

    # Check POLARIZATION table correctly created
    with pt.table(pol_table_name, ack=False) as P:
        for r, corr_type in enumerate(corr_types):
            assert_array_equal(P.getcol("CORR_TYPE", startrow=r, nrow=1),
                               [corr_type])
            assert_array_equal(P.getcol("NUM_CORR", startrow=r, nrow=1),
                               [len(corr_type)])

        required_desc = pt.required_ms_desc("POLARIZATION")
        required_columns = {k for k in required_desc.keys()
                            if not k.startswith("_")}

        assert set(P.colnames()) == set(required_columns)

    # Check SPECTRAL_WINDOW table correctly created
    with pt.table(spw_table_name, ack=False) as S:
        for r, num_chan in enumerate(num_chans):
            assert_array_equal(
                S.getcol("NUM_CHAN", startrow=r, nrow=1)[0], num_chan)
            assert_array_equal(
                S.getcol("CHAN_FREQ", startrow=r, nrow=1)[0],
                np.linspace(.856e9, 2 * .856e9, num_chan))
            assert_array_equal(
                S.getcol("CHAN_WIDTH", startrow=r, nrow=1)[0],
                np.full(num_chan, .856e9 / num_chan))

        required_desc = pt.required_ms_desc("SPECTRAL_WINDOW")
        required_columns = {k for k in required_desc.keys()
                            if not k.startswith("_")}

        assert set(S.colnames()) == set(required_columns)

    # We should get a cartesian product out
    with pt.table(ddid_table_name, ack=False) as D:
        spw_id, pol_id = zip(
            *product(range(len(num_chans)), range(len(corr_types))))
        assert_array_equal(pol_id, D.getcol("POLARIZATION_ID"))
        assert_array_equal(spw_id, D.getcol("SPECTRAL_WINDOW_ID"))

        required_desc = pt.required_ms_desc("DATA_DESCRIPTION")
        required_columns = {k for k in required_desc.keys()
                            if not k.startswith("_")}

        assert set(D.colnames()) == set(required_columns)

    # Check SOURCE table correctly created
    with pt.table(src_table_name, ack=False) as S:
        for r, (name, direction, rest_freq) in enumerate(sources):
            assert_array_equal(S.getcol("NAME", startrow=r, nrow=1)[0], [name])
            assert_array_equal(S.getcol("REST_FREQUENCY", startrow=r, nrow=1),
                               [rest_freq])
            assert_array_equal(S.getcol("DIRECTION", startrow=r, nrow=1),
                               [direction])

    with pt.table(ms_table_name, ack=False) as T:
        # DATA_DESC_ID's are all the same shape
        assert_array_equal(T.getcol("DATA_DESC_ID"),
                           da.concatenate(all_data_desc_id))

        # DATA is variably shaped (on DATA_DESC_ID) so we
        # compare each one separately.
        for ddid, data in enumerate(all_data):
            ms_data = T.getcol("DATA", startrow=ddid * row, nrow=row)
            assert_array_equal(ms_data, data)

        required_desc = pt.required_ms_desc()
        required_columns = {k for k in required_desc.keys()
                            if not k.startswith("_")}

        # Check we have the required columns
        assert set(T.colnames()) == required_columns.union(
            ["DATA", "DATA_DESC_ID"])
Exemple #26
# Build the datasets for a Measurement Set and its ANTENNA, FEED, FIELD,
# OBSERVATION, SOURCE, POLARIZATION, SPECTRAL_WINDOW and DATA_DESCRIPTION
# sub-tables, then create one ``xds_to_table`` write per table.
# NOTE(review): this listing is incomplete. The signature is cut after
# ``corr_types,`` (the body also reads ``sources``), the docstring is never
# closed, and several dict keys, dims tuples, closing brackets and list
# appends are missing, so the block does not parse as shown. Restore it from
# the original source before relying on it.
def ms_create(ms_table_name, info, ant_pos, cal_vis, timestamps, corr_types,
    '''    "info": {
        "info": {
            "L0_frequency": 1571328000.0,
            "bandwidth": 2500000.0,
            "baseband_frequency": 4092000.0,
            "location": {
                "alt": 270.0,
                "lat": -45.85177,
                "lon": 170.5456
            "name": "Signal Hill - Dunedin",
            "num_antenna": 24,
            "operating_frequency": 1575420000.0,
            "sampling_frequency": 16368000.0

    # A single spectral window containing one channel
    num_chans = [1]

    # NOTE(review): rs is not used anywhere in the visible part of the body
    rs = np.random.RandomState(42)

    # Sub-tables are addressed as "<ms name>::<SUBTABLE>"
    ant_table_name = "::".join((ms_table_name, "ANTENNA"))
    feed_table_name = "::".join((ms_table_name, "FEED"))
    field_table_name = "::".join((ms_table_name, "FIELD"))
    obs_table_name = "::".join((ms_table_name, "OBSERVATION"))
    ddid_table_name = "::".join((ms_table_name, "DATA_DESCRIPTION"))
    pol_table_name = "::".join((ms_table_name, "POLARIZATION"))
    spw_table_name = "::".join((ms_table_name, "SPECTRAL_WINDOW"))
    # SOURCE is an optional MS sub-table
    src_table_name = "::".join((ms_table_name, "SOURCE"))

    # One list of datasets per table, handed to xds_to_table further down
    ms_datasets = []
    ant_datasets = []
    feed_datasets = []
    field_datasets = []
    obs_datasets = []
    ddid_datasets = []
    pol_datasets = []
    spw_datasets = []
    src_datasets = []

    # Create ANTENNA dataset
    # Each column in the ANTENNA has a fixed shape so we
    # can represent all rows with one dataset
    na = len(ant_pos)
    position = da.asarray(ant_pos)  # da.zeros((na, 3))
    diameter = da.ones(na) * 0.025
    offset = da.zeros((na, 3))
    # NOTE(review): the np.object alias was deprecated in NumPy 1.20 and
    # removed in 1.24; the builtin ``object`` is the replacement dtype.
    names = np.array(['ANTENNA-%d' % i for i in range(na)], dtype=np.object)
    # Every antenna gets the same station name, taken from info['name']
    stations = np.array([info['name'] for i in range(na)], dtype=np.object)

    ds = Dataset({
        'POSITION': (("row", "xyz"), position),
        'OFFSET': (("row", "xyz"), offset),
        'DISH_DIAMETER': (("row", ), diameter),
        'NAME': (("row", ), da.from_array(names, chunks=na)),
        'STATION': (("row", ), da.from_array(stations, chunks=na)),

    #########################################  Create a FEED dataset. #######################################
    # There is one feed per antenna, so this should be quite similar to the ANTENNA
    antenna_ids = da.asarray(range(na))
    feed_ids = da.zeros(na)
    num_receptors = da.ones(na)
    polarization_types = np.array([['XX'] for i in range(na)], dtype=np.object)
    receptor_angles = np.array([[0.0] for i in range(na)])
    pol_response = np.array([[[0.0 + 1.0j, 1.0 - 1.0j]] for i in range(na)])

    beam_offset = np.array([[[0.0, 0.0]] for i in range(na)])

    # NOTE(review): the key/dims lines of the polarization_types,
    # receptor_angles and beam_offset entries are missing from this listing.
    ds = Dataset({
        'ANTENNA_ID': (("row", ), antenna_ids),
        'FEED_ID': (("row", ), feed_ids),
        'NUM_RECEPTORS': (("row", ), num_receptors),
        ), da.from_array(polarization_types, chunks=na)),
        'RECEPTOR_ANGLE': ((
        ), da.from_array(receptor_angles, chunks=na)),
        'POL_RESPONSE': (("row", "receptors", "receptors-2"),
                         da.from_array(pol_response, chunks=na)),
        (("row", "receptors", "radec"), da.from_array(beam_offset, chunks=na)),

    ########################################### FIELD dataset ################################################
    direction = [[np.radians(90.0), np.radians(0.0)]]  ## Phase Center in J2000
    field_direction = da.asarray(direction)[None, :]
    field_name = da.asarray(np.asarray(['up'], dtype=np.object))
    field_num_poly = da.zeros(
        1)  # Zero order polynomial in time for phase center.

    # NOTE(review): the dimension names of dir_dims are missing here
    dir_dims = (

    # The same direction is used for the phase, delay and reference columns
    ds = Dataset({
        'PHASE_DIR': (dir_dims, field_direction),
        'DELAY_DIR': (dir_dims, field_direction),
        'REFERENCE_DIR': (dir_dims, field_direction),
        'NUM_POLY': (("row", ), field_num_poly),
        'NAME': (("row", ), field_name),

    ########################################### OBSERVATION dataset ################################################

    ds = Dataset({
        (("row", ), da.asarray(np.asarray(['TART'], dtype=np.object))),
        'OBSERVER': (("row", ), da.asarray(np.asarray(['Tim'],

    #################################### Create SOURCE datasets #############################################
    for s, src in enumerate(sources):
        name = src['name']
        # The only rest frequency listed is the operating frequency
        rest_freq = [info['operating_frequency']]
        # NOTE(review): the elements of ``direction`` are missing here
        direction = [
        ]  ## FIXME these are in elevation and azimuth. Not in J2000.

        #logger.info("SOURCE: {}, timestamp: {}".format(name, timestamps))
        dask_num_lines = da.full((1, ), len(rest_freq), dtype=np.int32)
        dask_direction = da.asarray(direction)[None, :]
        # NOTE(review): dask_rest_freq and dask_time are only referenced by
        # the commented-out dataset entries below.
        dask_rest_freq = da.asarray(rest_freq)[None, :]
        dask_name = da.asarray(np.asarray([name], dtype=np.object))
        dask_time = da.asarray(np.asarray([timestamps], dtype=np.object))
        ds = Dataset({
            "NUM_LINES": (("row", ), dask_num_lines),
            "NAME": (("row", ), dask_name),
            #"TIME": (("row",), dask_time), # FIXME. Causes an error. Need to sort out TIME data fields
            #"REST_FREQUENCY": (("row", "line"), dask_rest_freq),
            "DIRECTION": (("row", "dir"), dask_direction),

    # Create POLARISATION datasets.
    # Dataset per output row required because column shapes are variable
    for r, corr_type in enumerate(corr_types):
        dask_num_corr = da.full((1, ), len(corr_type), dtype=np.int32)
        dask_corr_type = da.from_array(corr_type,
                                       chunks=len(corr_type))[None, :]
        # NOTE(review): this repeats the previous assignment verbatim
        dask_corr_type = da.from_array(corr_type,
                                       chunks=len(corr_type))[None, :]
        ds = Dataset({
            "NUM_CORR": (("row", ), dask_num_corr),
            #"CORR_PRODUCT": (("row",), dask_num_corr),
            "CORR_TYPE": (("row", "corr"), dask_corr_type),


    # Create multiple SPECTRAL_WINDOW datasets
    # Dataset per output row required because column shapes are variable

    for num_chan in num_chans:
        dask_num_chan = da.full((1, ), num_chan, dtype=np.int32)
        # A single channel at the operating frequency
        dask_chan_freq = da.asarray([[info['operating_frequency']]])
        # 2.5e6 divided evenly over the channels
        dask_chan_width = da.full((1, num_chan), 2.5e6 / num_chan)

        ds = Dataset({
            "NUM_CHAN": (("row", ), dask_num_chan),
            "CHAN_FREQ": (("row", "chan"), dask_chan_freq),
            "CHAN_WIDTH": (("row", "chan"), dask_chan_width),


    # For each cartesian product of SPECTRAL_WINDOW and POLARIZATION
    # create a corresponding DATA_DESCRIPTION.
    # Each column has fixed shape so we handle all rows at once
    spw_ids, pol_ids = zip(
        *product(range(len(num_chans)), range(len(corr_types))))
    dask_spw_ids = da.asarray(np.asarray(spw_ids, dtype=np.int32))
    dask_pol_ids = da.asarray(np.asarray(pol_ids, dtype=np.int32))
            "SPECTRAL_WINDOW_ID": (("row", ), dask_spw_ids),
            "POLARIZATION_ID": (("row", ), dask_pol_ids),

    # Now create the associated MS dataset

    vis_data, baselines = cal_vis.get_all_visibility()
    vis_array = np.array(vis_data, dtype=np.complex64)
    # All visibilities go into a single row chunk
    chunks = {
        "row": (vis_array.shape[0], ),
    baselines = np.array(baselines)
    # Antenna positions looked up per baseline
    # presumably baselines is (nbl, 2) antenna index pairs -- TODO confirm
    bl_pos = np.array(ant_pos)[baselines]
    # Baseline vector scaled by constants.L1_WAVELENGTH
    uu_a, vv_a, ww_a = -(
        bl_pos[:, 1] - bl_pos[:, 0]
    ).T / constants.L1_WAVELENGTH  # Use the - sign to get the same orientation as our tart projections.

    uvw_array = np.array([uu_a, vv_a, ww_a]).T

    for ddid, (spw_id, pol_id) in enumerate(zip(spw_ids, pol_ids)):
        # Infer row, chan and correlation shape
        logger.info("ddid:{} ({}, {})".format(ddid, spw_id, pol_id))
        row = sum(chunks['row'])
        chan = spw_datasets[spw_id].CHAN_FREQ.shape[1]
        corr = pol_datasets[pol_id].CORR_TYPE.shape[1]

        # Create some dask vis data
        dims = ("row", "chan", "corr")
        logger.info("Data size {}".format((row, chan, corr)))

        # Reshape the visibilities and uvw coordinates to the inferred shape
        np_data = vis_array.reshape((row, chan, corr))
        np_uvw = uvw_array.reshape((row, 3))

        data_chunks = tuple((chunks['row'], chan, corr))
        dask_data = da.from_array(np_data, chunks=data_chunks)

        uvw_data = da.from_array(np_uvw)
        # Create dask ddid column
        dask_ddid = da.full(row, ddid, chunks=chunks['row'], dtype=np.int32)
        # NOTE(review): the keys of the three constant-valued entries and the
        # dims of 'UVW' are missing from this listing.
        # ANTENNA1/FEED1 and ANTENNA2/FEED2 reuse the same baseline columns.
        dataset = Dataset({
            'DATA': (dims, dask_data),
            (("row", "corr"), da.from_array(0.95 * np.ones((row, corr)))),
             da.from_array(0.95 * np.ones_like(np_data, dtype=np.float64))),
             da.from_array(np.ones_like(np_data, dtype=np.float64) * 0.05)),
            'UVW': ((
            ), uvw_data),
            'ANTENNA1': (("row", ), da.from_array(baselines[:, 0])),
            'ANTENNA2': (("row", ), da.from_array(baselines[:, 1])),
            'FEED1': (("row", ), da.from_array(baselines[:, 0])),
            'FEED2': (("row", ), da.from_array(baselines[:, 1])),
            'DATA_DESC_ID': (("row", ), dask_ddid)

    # Create a write for every table; all columns of each dataset are written
    # NOTE(review): nothing in the visible part computes these writes
    ms_writes = xds_to_table(ms_datasets, ms_table_name, columns="ALL")
    ant_writes = xds_to_table(ant_datasets, ant_table_name, columns="ALL")
    feed_writes = xds_to_table(feed_datasets, feed_table_name, columns="ALL")
    field_writes = xds_to_table(field_datasets,
    obs_writes = xds_to_table(obs_datasets, obs_table_name, columns="ALL")
    pol_writes = xds_to_table(pol_datasets, pol_table_name, columns="ALL")
    spw_writes = xds_to_table(spw_datasets, spw_table_name, columns="ALL")
    ddid_writes = xds_to_table(ddid_datasets, ddid_table_name, columns="ALL")
    source_writes = xds_to_table(src_datasets, src_table_name, columns="ALL")

Exemple #27
        # Create a dataset representing the entire antenna table
        ant_table = '::'.join((args.ms, 'ANTENNA'))

        for ant_ds in xds_from_table(ant_table):
                dask.compute(ant_ds.NAME.data, ant_ds.POSITION.data,

        # Create datasets representing each row of the spw table
        spw_table = '::'.join((args.ms, 'SPECTRAL_WINDOW'))

        for spw_ds in xds_from_table(spw_table, group_cols="__row__"):

        # Create datasets from a partitioning of the MS
        datasets = list(xds_from_ms(args.ms, chunks={'row': args.chunks}))

        for ds in datasets:

            # Try write the STATE_ID column back
            write = xds_to_table(ds, args.ms, 'STATE_ID')
            with ProgressBar(), Profiler() as prof:

            # Profile
Exemple #28
def main(args):
    # get full time column and compute row chunks
    ms = table(args.ms)
    time = ms.getcol('TIME')
    row_chunks, tbin_idx, tbin_counts = chunkify_rows(
        time, args.utimes_per_chunk)
    # convert to dask arrays
    tbin_idx = da.from_array(tbin_idx, chunks=(args.utimes_per_chunk))
    tbin_counts = da.from_array(tbin_counts, chunks=(args.utimes_per_chunk))
    n_time = tbin_idx.size

    # get phase dir
    fld = table(args.ms+'::FIELD')
    radec0 = fld.getcol('PHASE_DIR').squeeze().reshape(1, 2)
    radec0 = np.tile(radec0, (n_time, 1))

    # get freqs
    freqs = table(
    n_freq = freqs.size
    freqs = da.from_array(freqs, chunks=(n_freq))

    # get source coordinates from lsm
    lsm = Tigger.load(args.sky_model)
    radec = []
    stokes = []
    spi = []
    ref_freqs = []

    for source in lsm.sources:
        radec.append([source.pos.ra, source.pos.dec])

    n_dir = len(stokes)
    radec = np.asarray(radec)
    lm = np.zeros((n_time,) + radec.shape)
    for t in range(n_time):
        lm[t] = radec_to_lm(radec, radec0[t])

    lm = da.from_array(lm, chunks=(args.utimes_per_chunk, n_dir, 2))

    # load in the model file
    n_corr = 1
    model = np.zeros((n_time, n_freq, n_dir, n_corr))
    stokes = np.asarray(stokes)
    ref_freqs = np.asarray(ref_freqs)
    spi = np.asarray(spi)
    for t in range(n_time):
        for d in range(n_dir):
            model[t, :, d, 0] = stokes[d] * (freqs/ref_freqs[d])**spi[d]

    # append antenna columns
    cols = []

    # load in gains
    jones = np.load(args.gain_file)
    jones = jones.astype(np.complex128)
    jones_shape = jones.shape
    jones = da.from_array(jones, chunks=(args.utimes_per_chunk,)
                          + jones_shape[1::])

    # change model to dask array
    model = da.from_array(model, chunks=(args.utimes_per_chunk,)
                          + model.shape[1::])

    # load data in chunks and apply gains to each chunk
    xds = xds_from_ms(args.ms, columns=cols, chunks={"row": row_chunks})[0]
    ant1 = xds.ANTENNA1.data
    ant2 = xds.ANTENNA2.data
    uvw = xds.UVW.data

    # apply gains
    data = compute_and_corrupt_vis(tbin_idx, tbin_counts, ant1, ant2,
                                   jones, model, uvw, freqs, lm)

    # Assign visibilities to args.out_col and write to ms
    xds = xds.assign(**{args.out_col: (("row", "chan", "corr"), data)})
    # Create a write to the table
    write = xds_to_table(xds, args.ms, [args.out_col])

    # Submit all graph computations in parallel
    with ProgressBar():