else:
    model_predict = model

# set up coords for DFT
ll, mm = np.meshgrid(l_coord, m_coord)
lm = np.vstack((ll.flatten(), mm.flatten())).T
lm = da.from_array(lm, chunks=(npix_tot, 2))
model_predict = np.transpose(model_predict.reshape(nchan, ncorr, npix_tot),
                             [2, 0, 1])
model_predict = da.from_array(model_predict, chunks=(npix_tot, nchan, ncorr))
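
# model_predict now has shape (npix_tot, nchan, ncorr): im_to_vis takes the
# pixel/source axis first, pairing each (l, m) coordinate with its spectrum.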
ms_freqs = spw_ds.CHAN_FREQ.data

# do the predict
writes = []
for xds in xds_from_ms(args.ms,
                       columns=["UVW", args.colname],
                       chunks={"row": args.row_chunks}):
    uvw = xds.UVW.data
    vis = im_to_vis(model_predict, uvw, lm, ms_freqs)

    data = getattr(xds, args.colname)
    if data.shape != vis.shape:
        print("Assuming only Stokes I passed in")
        if vis.shape[-1] == 1 and data.shape[-1] == 4:
            tmp_zero = da.zeros(vis.shape, chunks=(args.row_chunks, nchan, 1))
            vis = da.concatenate((vis, tmp_zero, tmp_zero, vis), axis=-1)
        elif vis.shape[-1] == 1 and data.shape[-1] == 2:
            vis = da.concatenate((vis, vis), axis=-1)
        else:
            raise ValueError("Incompatible corr axes")
        vis = vis.rechunk((args.row_chunks, nchan, data.shape[-1]))
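
    # Sketch of the likely continuation (this example is truncated): following
    # the write-back pattern in the later examples (which use xr.DataArray,
    # xds_to_table and dask.compute), assign the predicted visibilities to the
    # output column and accumulate a write operation.
    model_data = xr.DataArray(vis, dims=["row", "chan", "corr"])
    xds = xds.assign(**{args.colname: model_data})
    writes.append(xds_to_table(xds, args.ms, [args.colname]))

# Execute all accumulated writes
dask.compute(writes)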
Example 2
def create_parser():
    p = argparse.ArgumentParser()
    p.add_argument("ms")
    return p


args = create_parser().parse_args()

# Find unique baselines and the maximum baseline distance in this Measurement Set.
# The maximum baseline length is computed from the UVW values.
xds = list(
    xds_from_ms(
        args.ms,
        # We only need the antenna and uvw columns
        columns=("UVW", "ANTENNA1", "ANTENNA2"),
        group_cols=[],
        index_cols=[],
        chunks={"row": 1e6}))

# Should only have one dataset
assert len(xds) == 1
# The unique baselines for one scan are the same for every scan in the Measurement Set
ds = xds[0]

# Calculate Maximum baseline
uvw = ds.UVW.data
bl_max_dist = da.sqrt(da.max(da.sum(uvw**2, axis=1)))
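
# bl_max_dist is a lazy dask expression; calling compute() evaluates the
# graph and returns the maximum baseline length in metres.
print(bl_max_dist.compute())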

Example 3
def predict(args):
    # get inclusion regions
    include_regions = []
    exclude_regions = []
    if args.within:
        from regions import read_ds9
        import tempfile
        # kludge because regions cries over "FK5", wants lowercase
        with tempfile.NamedTemporaryFile(mode="w") as tmpfile, open(
                args.within) as regfile:
            tmpfile.write(regfile.read().lower())
            tmpfile.flush()
            include_regions = read_ds9(tmpfile.name)
            log.info("read {} inclusion region(s) from {}".format(
                len(include_regions), args.within))

    # Import source data from WSClean component list
    # See https://sourceforge.net/p/wsclean/wiki/ComponentList
    (comp_type, radec, stokes, spec_coeff, ref_freq, log_spec_ind,
     gaussian_shape) = import_from_wsclean(args.sky_model,
                                           include_regions=include_regions,
                                           exclude_regions=exclude_regions,
                                           point_only=args.points_only,
                                           num=args.num_sources or None)

    # Get the support tables
    tables = support_tables(
        args, ["FIELD", "DATA_DESCRIPTION", "SPECTRAL_WINDOW", "POLARIZATION"])

    field_ds = tables["FIELD"]
    ddid_ds = tables["DATA_DESCRIPTION"]
    spw_ds = tables["SPECTRAL_WINDOW"]
    pol_ds = tables["POLARIZATION"]
    frequencies = np.sort(
        np.concatenate([spw_ds[dd].CHAN_FREQ.values.flatten()
                        for dd in range(len(spw_ds))]))

    # cluster sources and refit. This only works for delta scale sources
    def __cluster(comp_type, radec, stokes, spec_coeff, ref_freq, log_spec_ind,
                  gaussian_shape, frequencies):
        uniq_radec = np.unique(radec, axis=0)
        ncomp_type = []
        nradec = []
        nstokes = []
        nspec_coef = []
        nref_freq = []
        nlog_spec_ind = []
        ngaussian_shape = []

        for urd in uniq_radec:
            possel = np.all(radec == urd, axis=1)
            deltasel = comp_type[possel] == "POINT"
            polyspecsel = np.logical_not(log_spec_ind[possel])
            sel = deltasel & polyspecsel

            Is = stokes[possel][sel, 0, None] * frequencies[None, :]**0
            for jj in range(spec_coeff.shape[1]):
                Is += spec_coeff[possel][sel, jj, None] * (
                    frequencies[None, :] / ref_freq[possel][sel, None] -
                    1)**(jj + 1)
            # collapse over all the sources at this position
            Ispoly = np.sum(Is, axis=0)

            logpolyspecsel = log_spec_ind[possel]
            sel = deltasel & logpolyspecsel

            Is = np.log(stokes[possel][sel, 0, None] * frequencies[None, :]**0)
            for jj in range(spec_coeff.shape[1]):
                Is += spec_coeff[possel][sel, jj, None] * np.log(
                    (frequencies[None, :] / ref_freq[possel][sel, None])**(jj + 1))
            Is = np.exp(Is)
            # collapse over all the sources at this position
            Islogpoly = np.sum(Is, axis=0)

            popt, pfitvar = curve_fit(
                lambda nu, i0, a, b, c, d: i0 + a *
                (nu / ref_freq[0] - 1) + b *
                (nu / ref_freq[0] - 1)**2 + c *
                (nu / ref_freq[0] - 1)**3 + d *
                (nu / ref_freq[0] - 1)**4, frequencies,
                Ispoly + Islogpoly)
            if not np.all(np.isfinite(pfitvar)):
                popt[0] = np.sum(stokes[possel][sel, 0], axis=0)
                popt[1:] = np.inf
                log.warning(
                    "Refitting at position {0:s} failed. Assuming flat spectrum source of {1:.2f} Jy"
                    .format(str(urd), popt[0]))
            else:
                pcov = np.sqrt(np.diag(pfitvar))
                log.info(
                    "New fitted flux {0:.3f} Jy at position {1:s} with standard error {2:s}"
                    .format(popt[0], str(urd),
                            ", ".join([str(pc) for pc in pcov])))

            ncomp_type.append("POINT")
            nradec.append(urd)
            nstokes.append(popt[0])
            nspec_coef.append(popt[1:])
            nref_freq.append(ref_freq[0])
            nlog_spec_ind.append(False)

        # add back all the gaussians
        sel = comp_type == "GAUSSIAN"
        for rd, stks, spec, ref, lspec, gs in zip(radec[sel], stokes[sel],
                                                  spec_coeff[sel],
                                                  ref_freq[sel],
                                                  log_spec_ind[sel],
                                                  gaussian_shape[sel]):
            ncomp_type.append("GAUSSIAN")
            nradec.append(rd)
            nstokes.append(stks)
            nspec_coef.append(spec)
            nref_freq.append(ref)
            nlog_spec_ind.append(lspec)
            ngaussian_shape.append(gs)

        log.info(
            "Reduced {0:d} components to {1:d} components by refitting"
            .format(len(comp_type), len(ncomp_type)))
        return (np.array(ncomp_type), np.array(nradec), np.array(nstokes),
                np.array(nspec_coef), np.array(nref_freq),
                np.array(nlog_spec_ind), np.array(ngaussian_shape))

    if not args.dontcluster:
        (comp_type, radec, stokes, spec_coeff, ref_freq, log_spec_ind,
         gaussian_shape) = __cluster(comp_type, radec, stokes, spec_coeff,
                                     ref_freq, log_spec_ind, gaussian_shape,
                                     frequencies)

    # Add output column if it isn't present
    ms_rows, ms_datatype = ms_preprocess(args)

    # sort out resources
    args.row_chunks, args.model_chunks = get_budget(
        comp_type.shape[0], ms_rows, max([ss.NUM_CHAN.data for ss in spw_ds]),
        max([ss.NUM_CORR.data for ss in pol_ds]), ms_datatype, args)

    radec = da.from_array(radec, chunks=(args.model_chunks, 2))
    stokes = da.from_array(stokes, chunks=(args.model_chunks, 4))

    if np.count_nonzero(comp_type == 'GAUSSIAN') > 0:
        gaussian_components = True
        gshape_chunks = (args.model_chunks, 3)
        gaussian_shape = da.from_array(gaussian_shape, chunks=gshape_chunks)
    else:
        gaussian_components = False

    if args.spectra:
        spec_chunks = (args.model_chunks, spec_coeff.shape[1])
        spec_coeff = da.from_array(spec_coeff, chunks=spec_chunks)
        ref_freq = da.from_array(ref_freq, chunks=(args.model_chunks, ))

    # List of write operations
    writes = []

    # Construct a graph for each DATA_DESC_ID
    for xds in xds_from_ms(args.ms,
                           columns=["UVW", "ANTENNA1", "ANTENNA2", "TIME"],
                           group_cols=["FIELD_ID", "DATA_DESC_ID"],
                           chunks={"row": args.row_chunks}):
        if xds.attrs['FIELD_ID'] != args.fieldid:
            continue

        # Extract frequencies from the spectral window associated
        # with this data descriptor id
        field = field_ds[xds.attrs['FIELD_ID']]
        ddid = ddid_ds[xds.attrs['DATA_DESC_ID']]
        spw = spw_ds[ddid.SPECTRAL_WINDOW_ID.values]
        pol = pol_ds[ddid.POLARIZATION_ID.values]
        frequency = spw.CHAN_FREQ.data

        corrs = pol.NUM_CORR.values

        lm = radec_to_lm(radec, field.PHASE_DIR.data)

        if args.exp_sign_convention == 'casa':
            uvw = -xds.UVW.data
        elif args.exp_sign_convention == 'thompson':
            uvw = xds.UVW.data
        else:
            raise ValueError("Invalid sign convention '%s'" % args.sign)

        if args.spectra:
            # flux density at reference frequency ...
            # ... for logarithmic polynomial functions
            if log_spec_ind:
                Is = da.log(stokes[:, 0, None]) * frequency[None, :]**0
                # ... or for ordinary polynomial functions
            else:
                Is = stokes[:, 0, None] * frequency[None, :]**0
            # additional terms of SED ...
            for jj in range(spec_coeff.shape[1]):
                # ... for logarithmic polynomial functions
                if log_spec_ind:
                    Is += spec_coeff[:, jj, None] * da.log(
                        (frequency[None, :] / ref_freq[:, None])**(jj + 1))
                    # ... or for ordinary polynomial functions
                else:
                    Is += spec_coeff[:, jj, None] * (
                        frequency[None, :] / ref_freq[:, None] - 1)**(jj + 1)
            if log_spec_ind: Is = da.exp(Is)
            Qs = da.zeros_like(Is)
            Us = da.zeros_like(Is)
            Vs = da.zeros_like(Is)
            spectrum = da.stack(
                [Is, Qs, Us, Vs], axis=-1
            )  # stack along new axis and make it the last axis of the new array
            spectrum = spectrum.rechunk(spectrum.chunks[:2] +
                                        (spectrum.shape[2], ))
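
            # In summary, the SED evaluated above is, per source:
            #   ordinary polynomial:    I(nu) = I0 + sum_j c_j (nu/nu_ref - 1)**(j+1)
            #   logarithmic polynomial: I(nu) = exp(log I0 + sum_j c_j log((nu/nu_ref)**(j+1)))
            # Q, U and V are set to zero, so only Stokes I varies with frequency.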

        log.info('-------------------------------------------')
        log.info('Nr sources        = {0:d}'.format(stokes.shape[0]))
        log.info('-------------------------------------------')
        log.info('stokes.shape      = {0:}'.format(stokes.shape))
        log.info('frequency.shape   = {0:}'.format(frequency.shape))
        if args.spectra: log.info('Is.shape          = {0:}'.format(Is.shape))
        if args.spectra:
            log.info('spectrum.shape    = {0:}'.format(spectrum.shape))

        # (source, row, frequency)
        phase = phase_delay(lm, uvw, frequency)
        # If at least one Gaussian component is present in the component list then all
        # sources are modelled as Gaussian components (Delta components have zero width)
        if gaussian_components:
            phase *= gaussian(uvw, frequency, gaussian_shape)
        # (source, frequency, corr_products)
        brightness = convert(spectrum if args.spectra else stokes,
                             ["I", "Q", "U", "V"], corr_schema(pol))

        log.info('brightness.shape  = {0:}'.format(brightness.shape))
        log.info('phase.shape       = {0:}'.format(phase.shape))
        log.info('-------------------------------------------')
        log.info('Attempting phase-brightness einsum with "{0:s}"'.format(
            einsum_schema(pol, args.spectra)))

        # (source, row, frequency, corr_products)
        jones = da.einsum(einsum_schema(pol, args.spectra), phase, brightness)
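        # The einsum broadcasts phase (source, row, chan) against brightness
        # (source, chan, corr) into per-source Jones terms; a schema such as
        # "srf,sfc->srfc" (illustrative -- the actual one comes from
        # einsum_schema) keeps the source axis, which predict_vis sums over.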
        log.info('jones.shape       = {0:}'.format(jones.shape))
        log.info('-------------------------------------------')
        if gaussian_components: log.info('Some Gaussian sources found')
        else: log.info('All sources are Delta functions')
        log.info('-------------------------------------------')

        # Identify time indices
        _, time_index = da.unique(xds.TIME.data, return_inverse=True)

        # Predict visibilities
        vis = predict_vis(time_index, xds.ANTENNA1.data, xds.ANTENNA2.data,
                          None, jones, None, None, None, None)

        # Reshape (2, 2) correlation to shape (4,)
        if corrs == 4:
            vis = vis.reshape(vis.shape[:2] + (4, ))

        # Assign visibilities to MODEL_DATA array on the dataset
        model_data = xr.DataArray(vis, dims=["row", "chan", "corr"])
        xds = xds.assign(**{args.output_column: model_data})
        # Create a write to the table
        write = xds_to_table(xds, args.ms, [args.output_column])
        # Add to the list of writes
        writes.append(write)

    # Submit all graph computations in parallel
    if args.num_workers:
        with ProgressBar(), dask.config.set(num_workers=args.num_workers):
            dask.compute(writes)
    else:
        with ProgressBar():
            dask.compute(writes)
Example 4
def predict(args):
    # Numpy arrays

    # Convert source data into dask arrays
    radec, stokes = parse_sky_model(args.sky_model)
    radec = da.from_array(radec, chunks=(SOURCE_CHUNKS, 2))
    stokes = da.from_array(stokes, chunks=(SOURCE_CHUNKS, 4))

    # Get the support tables
    tables = support_tables(args, ["FIELD", "DATA_DESCRIPTION",
                                   "SPECTRAL_WINDOW", "POLARIZATION"])

    field_ds = tables["FIELD"]
    ddid_ds = tables["DATA_DESCRIPTION"]
    spw_ds = tables["SPECTRAL_WINDOW"]
    pol_ds = tables["POLARIZATION"]

    # List of write operations
    writes = []

    # Construct a graph for each DATA_DESC_ID
    for xds in xds_from_ms(args.ms,
                           columns=["UVW", "ANTENNA1", "ANTENNA2", "TIME"],
                           group_cols=["FIELD_ID", "DATA_DESC_ID"],
                           chunks={"row": args.row_chunks}):

        # Extract frequencies from the spectral window associated
        # with this data descriptor id
        field = field_ds[xds.attrs['FIELD_ID']]
        ddid = ddid_ds[xds.attrs['DATA_DESC_ID']]
        spw = spw_ds[ddid.SPECTRAL_WINDOW_ID.values]
        pol = pol_ds[ddid.POLARIZATION_ID.values]
        frequency = spw.CHAN_FREQ.data

        corrs = pol.NUM_CORR.values

        lm = radec_to_lm(radec, field.PHASE_DIR.data)
        uvw = -xds.UVW.data if args.invert_uvw else xds.UVW.data
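        # Negating UVW flips the sign convention of the phase term, matching
        # tools (e.g. CASA) that use the opposite convention -- compare the
        # exp_sign_convention handling in the previous example.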

        # (source, row, frequency)
        phase = phase_delay(lm, uvw, frequency)

        brightness = convert(stokes, ["I", "Q", "U", "V"],
                             corr_schema(pol))

        # (source, row, frequency, corr1, corr2)
        jones = da.einsum(einsum_schema(pol), phase, brightness)

        # Identify time indices
        _, time_index = da.unique(xds.TIME.data, return_inverse=True)

        # Predict visibilities
        vis = predict_vis(time_index, xds.ANTENNA1.data, xds.ANTENNA2.data,
                          None, jones, None, None, None, None)

        # Reshape (2, 2) correlation to shape (4,)
        if corrs == 4:
            vis = vis.reshape(vis.shape[:2] + (4,))

        # Assign visibilities to MODEL_DATA array on the dataset
        model_data = xr.DataArray(vis, dims=["row", "chan", "corr"])
        xds = xds.assign(MODEL_DATA=model_data)
        # Create a write to the table
        write = xds_to_table(xds, args.ms, ['MODEL_DATA'])
        # Add to the list of writes
        writes.append(write)

    # Submit all graph computations in parallel
    with ProgressBar():
        dask.compute(writes)
Example 5
        # Create a dataset representing the entire antenna table
        ant_table = '::'.join((args.ms, 'ANTENNA'))

        for ant_ds in xds_from_table(ant_table):
            print(ant_ds)
            print(
                dask.compute(ant_ds.NAME.data, ant_ds.POSITION.data,
                             ant_ds.DISH_DIAMETER.data))

        # Create datasets representing each row of the spw table
        spw_table = '::'.join((args.ms, 'SPECTRAL_WINDOW'))

        for spw_ds in xds_from_table(spw_table, group_cols="__row__"):
            print(spw_ds)
            print(spw_ds.NUM_CHAN.values)
            print(spw_ds.CHAN_FREQ.values)

        # Create datasets from a partitioning of the MS
        datasets = list(xds_from_ms(args.ms, chunks={'row': args.chunks}))

        for ds in datasets:
            print(ds)

            # Try to write the STATE_ID column back
            write = xds_to_table(ds, args.ms, 'STATE_ID')
            with ProgressBar(), Profiler() as prof:
                write.compute()

            # Profile
            prof.visualize(file_path="chunked.html")
Example 6
import argparse

from africanus.constants import c as lightspeed
from africanus.filters import convolution_filter
from xarrayms import xds_from_ms, xds_from_table


def create_parser():
    p = argparse.ArgumentParser()
    p.add_argument("ms")
    p.add_argument("-np", "--npix", default=1024, type=int)
    p.add_argument("-nc", "--chunks", default=10000, type=int)
    p.add_argument("-sc", "--cell-size", type=float)
    return p


args = create_parser().parse_args()

xds = list(xds_from_ms(args.ms, chunks={"row": args.chunks}))[0]
spw_ds = list(
    xds_from_table("::".join((args.ms, "SPECTRAL_WINDOW")),
                   group_cols="__row__"))[0]
wavelength = (lightspeed / spw_ds.CHAN_FREQ.data).compute()
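# Per-channel wavelengths (lambda = c / nu); compute() is needed because
# CHAN_FREQ is a lazy dask array.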

# Determine UVW Coordinate extents
query = """
SELECT
MIN([SELECT ABS(UVW[0]) FROM {ms}]) AS ABS_UMIN,
MAX([SELECT ABS(UVW[0]) FROM {ms}]) as ABS_UMAX,
MIN([SELECT ABS(UVW[1]) FROM {ms}]) AS ABS_VMIN,
MAX([SELECT ABS(UVW[1]) FROM {ms}]) as ABS_VMAX,
MIN([SELECT UVW[2] FROM {ms}]) AS WMIN,
MAX([SELECT UVW[2] FROM {ms}]) AS WMAX
""".format(ms=args.ms)
Example 7
        assert bl_uvw.shape[0] == in_rows

        tot_rows = out_rows if rem == 0 else out_rows + 1
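
        # A partial bin of `rem` leftover rows does not fill a complete bin of
        # `bins` rows; it is averaged into one extra output row below.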

        avg_uvw = np.empty((tot_rows, 3), dtype=uvw.dtype)
        avg_uvw[:out_rows, :] = bl_uvw[:out_rows * bins, :].reshape(
            out_rows, bins, 3).mean(axis=1)

        if rem > 0:
            avg_uvw[out_rows:, :] = bl_uvw[out_rows * bins:, :].mean(axis=0)

    return data


# Main Method
# Read the MS
xds = list(
    xds_from_ms(args.ms,
                columns=["TIME", "ANTENNA1", "ANTENNA2", "UVW", "FLAG_ROW"],
                group_cols=[],
                index_cols=[],
                chunks={"row": 1e9}))

ds = xds[0]
max_uvw = np.sqrt(np.max(np.sum(ds.UVW.data**2, axis=1))).compute()

# Call baseline_average_scan(time, ant1, ant2, uvw, flag_row, max_uvw)
baseline_average_scan(ds.TIME.data, ds.ANTENNA1.data, ds.ANTENNA2.data,
                      ds.UVW.data, ds.FLAG_ROW.data, max_uvw)
Example 8
    # Create short names mapped to the full table path
    table_name = {
        short: "::".join((args.ms, full))
        for short, full in [('antenna', "ANTENNA"),
                            ('ddid', "DATA_DESCRIPTION"),
                            ('spw', "SPECTRAL_WINDOW"),
                            ('pol', "POLARIZATION"),
                            ('field', "FIELD")]
    }

    with scheduler_context(args):
        # Get datasets from the main MS
        # partition by FIELD_ID and DATA_DESC_ID
        # and sorted by TIME
        datasets = xds_from_ms(args.ms,
                               group_cols=("FIELD_ID", "DATA_DESC_ID"),
                               index_cols="TIME")

        # Get the antenna dataset
        ant_ds = list(xds_from_table(table_name['antenna']))
        assert len(ant_ds) == 1
        ant_ds = ant_ds[0].rename({'row': 'antenna'}).drop('table_row')
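        # The single ANTENNA dataset is relabelled: its 'row' dimension becomes
        # 'antenna' and the 'table_row' coordinate added by xds_from_table is
        # dropped.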

        # Get datasets for DATA_DESCRIPTION, SPECTRAL_WINDOW
        # POLARIZATION and FIELD, partitioned by row
        ddid_ds = list(xds_from_table(table_name['ddid'],
                                      group_cols="__row__"))
        spwds = list(xds_from_table(table_name['spw'], group_cols="__row__"))
        pds = list(xds_from_table(table_name['pol'], group_cols="__row__"))
        field_ds = list(
            xds_from_table(table_name['field'], group_cols="__row__"))