Example #1
def _residual(ms, stack, **kw):
    args = OmegaConf.create(kw)
    OmegaConf.set_struct(args, True)
    pyscilog.log_to_file(args.output_filename + '.log')
    pyscilog.enable_memory_logging(level=3)

    # number of threads per worker
    if args.nthreads is None:
        if args.host_address is not None:
            raise ValueError(
                "You have to specify nthreads when using a distributed scheduler"
            )
        import multiprocessing
        nthreads = multiprocessing.cpu_count()
        args.nthreads = nthreads
    else:
        nthreads = args.nthreads

    # configure memory limit
    if args.mem_limit is None:
        if args.host_address is not None:
            raise ValueError(
                "You have to specify mem-limit when using a distributed scheduler"
            )
        import psutil
        mem_limit = int(psutil.virtual_memory()[0] / 1e9)  # 100% of memory by default
        args.mem_limit = mem_limit
    else:
        mem_limit = args.mem_limit

    nband = args.nband
    if args.nworkers is None:
        nworkers = nband
        args.nworkers = nworkers
    else:
        nworkers = args.nworkers

    if args.nthreads_per_worker is None:
        nthreads_per_worker = 1
        args.nthreads_per_worker = nthreads_per_worker
    else:
        nthreads_per_worker = args.nthreads_per_worker

    # the number of chunks being read in simultaneously is equal to
    # the number of dask threads
    nthreads_dask = nworkers * nthreads_per_worker

    if args.ngridder_threads is None:
        if args.host_address is not None:
            ngridder_threads = nthreads // nthreads_per_worker
        else:
            ngridder_threads = nthreads // nthreads_dask
        args.ngridder_threads = ngridder_threads
    else:
        ngridder_threads = args.ngridder_threads

    ms = list(ms)
    print('Input Options:', file=log)
    for key in kw.keys():
        print('     %25s = %s' % (key, args[key]), file=log)

    # numpy imports have to happen after this step
    from pfb import set_client
    set_client(nthreads, mem_limit, nworkers, nthreads_per_worker,
               args.host_address, stack, log)

    import numpy as np
    from pfb.utils.misc import chan_to_band_mapping
    import dask
    from dask.graph_manipulation import clone
    from dask.distributed import performance_report
    from daskms import xds_from_storage_ms as xds_from_ms
    from daskms import xds_from_storage_table as xds_from_table
    import dask.array as da
    from africanus.constants import c as lightspeed
    from africanus.gridding.wgridder.dask import dirty as vis2im
    from ducc0.fft import good_size
    from pfb.utils.misc import stitch_images, plan_row_chunk
    from pfb.utils.fits import set_wcs, save_fits

    # chan <-> band mapping
    freqs, freq_bin_idx, freq_bin_counts, freq_out, band_mapping, chan_chunks = chan_to_band_mapping(
        ms, nband=nband)

    # gridder memory budget
    max_chan_chunk = 0
    max_freq = 0
    for ims in ms:
        for spw in freqs[ims]:
            counts = freq_bin_counts[ims][spw].compute()
            freq = freqs[ims][spw].compute()
            max_chan_chunk = np.maximum(max_chan_chunk, counts.max())
            max_freq = np.maximum(max_freq, freq.max())

    # assumes measurement sets have the same columns,
    # number of correlations etc.
    xds = xds_from_ms(ms[0])
    ncorr = xds[0].dims['corr']
    nrow = xds[0].dims['row']
    data_bytes = getattr(xds[0], args.data_column).data.itemsize
    bytes_per_row = max_chan_chunk * ncorr * data_bytes
    memory_per_row = bytes_per_row

    # real valued weights
    wdims = getattr(xds[0], args.weight_column).data.ndim
    if wdims == 2:  # WEIGHT
        memory_per_row += ncorr * data_bytes / 2
    else:  # WEIGHT_SPECTRUM
        memory_per_row += bytes_per_row / 2

    # flags (uint8 or bool)
    memory_per_row += np.dtype(np.uint8).itemsize * max_chan_chunk * ncorr

    # UVW
    memory_per_row += xds[0].UVW.data.itemsize * 3

    # ANTENNA1/2
    memory_per_row += xds[0].ANTENNA1.data.itemsize * 2

    columns = (args.data_column, args.weight_column, args.flag_column, 'UVW',
               'ANTENNA1', 'ANTENNA2')

    # flag row
    if 'FLAG_ROW' in xds[0]:
        columns += ('FLAG_ROW', )
        memory_per_row += xds[0].FLAG_ROW.data.itemsize

    # imaging weights
    if args.imaging_weight_column is not None:
        columns += (args.imaging_weight_column, )
        memory_per_row += bytes_per_row / 2

    # Mueller term (complex valued)
    if args.mueller_column is not None:
        columns += (args.mueller_column, )
        memory_per_row += bytes_per_row

    # get max uv coords over all fields
    uvw = []
    u_max = 0.0
    v_max = 0.0
    for ims in ms:
        xds = xds_from_ms(ims, columns=('UVW',), chunks={'row': -1})

        for ds in xds:
            uvw = ds.UVW.data
            u_max = da.maximum(u_max, abs(uvw[:, 0]).max())
            v_max = da.maximum(v_max, abs(uvw[:, 1]).max())
            uv_max = da.maximum(u_max, v_max)

    uv_max = uv_max.compute()
    del uvw

    # image size
    cell_N = 1.0 / (2 * uv_max * max_freq / lightspeed)

    if args.cell_size is not None:
        cell_size = args.cell_size
        cell_rad = cell_size * np.pi / 60 / 60 / 180
        if cell_N / cell_rad < 1:
            raise ValueError(
                "Requested cell size too small. "
                "Super resolution factor = %f" % (cell_N / cell_rad))
        print("Super resolution factor = %f" % (cell_N / cell_rad), file=log)
    else:
        cell_rad = cell_N / args.super_resolution_factor
        cell_size = cell_rad * 60 * 60 * 180 / np.pi
        print("Cell size set to %5.5e arcseconds" % cell_size, file=log)

    if args.nx is None:
        fov = args.field_of_view * 3600
        npix = int(fov / cell_size)
        if npix % 2:
            npix += 1
        nx = good_size(npix)
        ny = good_size(npix)
    else:
        nx = args.nx
        ny = args.ny if args.ny is not None else nx

    print("Image size set to (%i, %i, %i)" % (nband, nx, ny), file=log)

    # get approx image size
    # this is not a conservative estimate when multiple SPWs map to a single
    # imaging band
    pixel_bytes = np.dtype(args.output_type).itemsize
    band_size = nx * ny * pixel_bytes

    if args.host_address is None:
        # full image on single node
        row_chunk = plan_row_chunk(mem_limit / nworkers, band_size, nrow,
                                   memory_per_row, nthreads_per_worker)

    else:
        # single band per node
        row_chunk = plan_row_chunk(mem_limit, band_size, nrow, memory_per_row,
                                   nthreads_per_worker)

    if args.row_chunks is not None:
        row_chunk = int(args.row_chunks)
        if row_chunk == -1:
            row_chunk = nrow

    print(
        "nrows = %i, row chunks set to %i for a total of %i chunks per node" %
        (nrow, row_chunk, int(np.ceil(nrow / row_chunk))),
        file=log)

    chunks = {}
    for ims in ms:
        chunks[ims] = []  # xds_from_ms expects a list per ds
        for spw in freqs[ims]:
            chunks[ims].append({
                'row': row_chunk,
                'chan': chan_chunks[ims][spw]['chan']
            })

    dirties = []
    radec = None  # assumes we are only imaging field 0 of first MS
    for ims in ms:
        xds = xds_from_ms(ims, chunks=chunks[ims], columns=columns)

        # subtables
        ddids = xds_from_table(ims + "::DATA_DESCRIPTION")
        fields = xds_from_table(ims + "::FIELD")
        spws = xds_from_table(ims + "::SPECTRAL_WINDOW")
        pols = xds_from_table(ims + "::POLARIZATION")

        # subtable data
        ddids = dask.compute(ddids)[0]
        fields = dask.compute(fields)[0]
        spws = dask.compute(spws)[0]
        pols = dask.compute(pols)[0]

        for ds in xds:
            field = fields[ds.FIELD_ID]

            # check fields match
            if radec is None:
                radec = field.PHASE_DIR.data.squeeze()

            if not np.array_equal(radec, field.PHASE_DIR.data.squeeze()):
                continue

            # this is not correct, need to use spw
            spw = ds.DATA_DESC_ID

            uvw = clone(ds.UVW.data)

            data = getattr(ds, args.data_column).data
            dataxx = data[:, :, 0]
            datayy = data[:, :, -1]

            weights = getattr(ds, args.weight_column).data
            if len(weights.shape) < 3:
                weights = da.broadcast_to(weights[:, None, :],
                                          data.shape,
                                          chunks=data.chunks)

            if args.imaging_weight_column is not None:
                imaging_weights = getattr(ds, args.imaging_weight_column).data
                if len(imaging_weights.shape) < 3:
                    imaging_weights = da.broadcast_to(imaging_weights[:,
                                                                      None, :],
                                                      data.shape,
                                                      chunks=data.chunks)

                weightsxx = imaging_weights[:, :, 0] * weights[:, :, 0]
                weightsyy = imaging_weights[:, :, -1] * weights[:, :, -1]
            else:
                weightsxx = weights[:, :, 0]
                weightsyy = weights[:, :, -1]

            # apply the adjoint of the Mueller term:
            # phases modify the data, amplitudes modify the weights
            if args.mueller_column is not None:
                mueller = getattr(ds, args.mueller_column).data
                dataxx *= da.exp(-1j * da.angle(mueller[:, :, 0]))
                datayy *= da.exp(-1j * da.angle(mueller[:, :, -1]))
                weightsxx *= da.absolute(mueller[:, :, 0])
                weightsyy *= da.absolute(mueller[:, :, -1])

            # weighted sum corr to Stokes I
            weights = weightsxx + weightsyy
            data = (weightsxx * dataxx + weightsyy * datayy)
            # TODO - suppress the spurious divide-by-zero warning here
            data = da.where(weights, data / weights, 0.0j)

            # MS may contain auto-correlations
            if 'FLAG_ROW' in xds[0]:
                frow = ds.FLAG_ROW.data | (ds.ANTENNA1.data
                                           == ds.ANTENNA2.data)
            else:
                frow = (ds.ANTENNA1.data == ds.ANTENNA2.data)

            # only keep data where both corrs are unflagged
            flag = getattr(ds, args.flag_column).data
            flagxx = flag[:, :, 0]
            flagyy = flag[:, :, -1]
            # ducc0 uses uint8 mask not flag
            mask = ~da.logical_or((flagxx | flagyy), frow[:, None])

            dirty = vis2im(uvw,
                           freqs[ims][spw],
                           data,
                           freq_bin_idx[ims][spw],
                           freq_bin_counts[ims][spw],
                           nx,
                           ny,
                           cell_rad,
                           weights=weights,
                           flag=mask.astype(np.uint8),
                           nthreads=ngridder_threads,
                           epsilon=args.epsilon,
                           do_wstacking=args.wstack,
                           double_accum=args.double_accum)

            dirties.append(dirty)

    # dask.visualize(dirties, filename=args.output_filename + '_graph.pdf', optimize_graph=False)

    if not args.mock:
        # result = dask.compute(dirties, wsum, optimize_graph=False)
        with performance_report(filename=args.output_filename + '_per.html'):
            result = dask.compute(dirties, optimize_graph=False)

        dirties = result[0]

        dirty = stitch_images(dirties, nband, band_mapping)

        hdr = set_wcs(cell_size / 3600, cell_size / 3600, nx, ny, radec,
                      freq_out)
        save_fits(args.output_filename + '_dirty.fits',
                  dirty,
                  hdr,
                  dtype=args.output_type)

    print("All done here.", file=log)
Example #2
    def compute_weights(self, robust):
        from pfb.utils.weighting import compute_counts, counts_to_weights
        # compute counts
        counts = []
        for ims in self.ms:
            xds = xds_from_ms(ims,
                              group_cols=('FIELD_ID', 'DATA_DESC_ID'),
                              chunks=self.chunks[ims],
                              columns=('UVW',))

            # subtables
            ddids = xds_from_table(ims + "::DATA_DESCRIPTION")
            fields = xds_from_table(ims + "::FIELD", group_cols="__row__")
            spws = xds_from_table(ims + "::SPECTRAL_WINDOW",
                                  group_cols="__row__")
            pols = xds_from_table(ims + "::POLARIZATION", group_cols="__row__")

            # subtable data
            ddids = dask.compute(ddids)[0]
            fields = dask.compute(fields)[0]
            spws = dask.compute(spws)[0]
            pols = dask.compute(pols)[0]

            for ds in xds:
                field = fields[ds.FIELD_ID]
                radec = field.PHASE_DIR.data.squeeze()
                if not np.array_equal(radec, self.radec):
                    continue

                spw = ds.DATA_DESC_ID  # not optimal, need to use spw

                freq_bin_idx = self.freq_bin_idx[ims][spw]
                freq_bin_counts = self.freq_bin_counts[ims][spw]
                freq = self.freq[ims][spw]

                uvw = ds.UVW.data

                count = compute_counts(uvw, freq, freq_bin_idx,
                                       freq_bin_counts, self.nx, self.ny,
                                       self.cell, self.cell, np.float32)

                counts.append(count)

        counts = dask.compute(counts)[0]

        counts = accumulate_dirty(counts, self.nband, self.band_mapping)

        counts = da.from_array(counts, chunks=(1, -1, -1))

        # convert counts to weights
        writes = []
        for ims in self.ms:
            xds = xds_from_ms(ims,
                              group_cols=('FIELD_ID', 'DATA_DESC_ID'),
                              chunks=self.chunks[ims],
                              columns=self.columns)

            # subtables
            ddids = xds_from_table(ims + "::DATA_DESCRIPTION")
            fields = xds_from_table(ims + "::FIELD", group_cols="__row__")
            spws = xds_from_table(ims + "::SPECTRAL_WINDOW",
                                  group_cols="__row__")
            pols = xds_from_table(ims + "::POLARIZATION", group_cols="__row__")

            # subtable data
            ddids = dask.compute(ddids)[0]
            fields = dask.compute(fields)[0]
            spws = dask.compute(spws)[0]
            pols = dask.compute(pols)[0]

            out_data = []
            for ds in xds:
                field = fields[ds.FIELD_ID]
                radec = field.PHASE_DIR.data.squeeze()
                if not np.array_equal(radec, self.radec):
                    continue

                spw = ds.DATA_DESC_ID  # this is not correct, need to use spw

                freq_bin_idx = self.freq_bin_idx[ims][spw]
                freq_bin_counts = self.freq_bin_counts[ims][spw]
                freq = self.freq[ims][spw]

                uvw = ds.UVW.data

                weights = counts_to_weights(counts, uvw, freq, freq_bin_idx,
                                            freq_bin_counts, self.nx, self.ny,
                                            self.cell, self.cell, np.float32,
                                            robust)

                # hack to get shape and chunking info
                data = getattr(ds, self.data_column).data

                weights = da.broadcast_to(weights[:, :, None],
                                          data.shape,
                                          chunks=data.chunks)
                out_ds = ds.assign(**{
                    self.imaging_weight_column: (("row", "chan", "corr"),
                                                 weights)
                })
                out_data.append(out_ds)
            writes.append(
                xds_to_table(out_data,
                             ims,
                             columns=[self.imaging_weight_column]))
        dask.compute(writes)
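
compute_counts and counts_to_weights are pfb utilities whose internals are not shown here. As a rough NumPy-only sketch of the idea they implement, Briggs robust weighting grids the number of visibilities falling in each uv cell and then down-weights densely sampled cells; the helper below and its inputs are illustrative:

import numpy as np

def briggs_weights(counts, robust):
    # standard Briggs definition: robust ~ -2 approaches uniform
    # weighting, robust ~ +2 approaches natural weighting
    avg_w = np.sum(counts ** 2) / np.sum(counts)
    ssq = (5.0 * 10.0 ** (-robust)) ** 2 / avg_w
    return 1.0 / (1.0 + counts * ssq)

counts = np.random.poisson(5.0, size=(64, 64)).astype(float)  # toy uv counts
weights = briggs_weights(counts, robust=0.0)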
Example #3
    def make_dirty(self):
        print("Making dirty", file=log)
        dirty = da.zeros((self.nband, self.nx, self.ny),
                         dtype=np.float32,
                         chunks=(1, self.nx, self.ny),
                         name=False)
        dirties = []
        for ims in self.ms:
            xds = xds_from_ms(ims,
                              group_cols=('FIELD_ID', 'DATA_DESC_ID'),
                              chunks=self.chunks[ims],
                              columns=self.columns)

            # subtables
            ddids = xds_from_table(ims + "::DATA_DESCRIPTION")
            fields = xds_from_table(ims + "::FIELD", group_cols="__row__")
            spws = xds_from_table(ims + "::SPECTRAL_WINDOW",
                                  group_cols="__row__")
            pols = xds_from_table(ims + "::POLARIZATION", group_cols="__row__")

            # subtable data
            ddids = dask.compute(ddids)[0]
            fields = dask.compute(fields)[0]
            spws = dask.compute(spws)[0]
            pols = dask.compute(pols)[0]

            for ds in xds:
                field = fields[ds.FIELD_ID]
                radec = field.PHASE_DIR.data.squeeze()
                if not np.array_equal(radec, self.radec):
                    continue

                spw = ds.DATA_DESC_ID  # this is not correct, need to use spw

                freq_bin_idx = self.freq_bin_idx[ims][spw]
                freq_bin_counts = self.freq_bin_counts[ims][spw]
                freq = self.freq[ims][spw]
                freq_chunk = freq_bin_counts[0].compute()

                uvw = ds.UVW.data

                data = getattr(ds, self.data_column).data
                dataxx = data[:, :, 0]
                datayy = data[:, :, -1]

                weights = getattr(ds, self.weight_column).data
                if len(weights.shape) < 3:
                    weights = da.broadcast_to(weights[:, None, :],
                                              data.shape,
                                              chunks=data.chunks)

                if self.imaging_weight_column is not None:
                    imaging_weights = getattr(ds,
                                              self.imaging_weight_column).data
                    if len(imaging_weights.shape) < 3:
                        imaging_weights = da.broadcast_to(
                            imaging_weights[:, None, :],
                            data.shape,
                            chunks=data.chunks)

                    weightsxx = imaging_weights[:, :, 0] * weights[:, :, 0]
                    weightsyy = imaging_weights[:, :, -1] * weights[:, :, -1]
                else:
                    weightsxx = weights[:, :, 0]
                    weightsyy = weights[:, :, -1]

                # apply the adjoint of the Mueller term:
                # phases modify the data, amplitudes modify the weights
                if self.mueller_column is not None:
                    mueller = getattr(ds, self.mueller_column).data
                    dataxx *= da.exp(-1j * da.angle(mueller[:, :, 0]))
                    datayy *= da.exp(-1j * da.angle(mueller[:, :, -1]))
                    weightsxx *= da.absolute(mueller[:, :, 0])
                    weightsyy *= da.absolute(mueller[:, :, -1])

                # weighted sum corr to Stokes I
                weights = weightsxx + weightsyy
                data = (weightsxx * dataxx + weightsyy * datayy)
                # TODO - suppress the spurious divide-by-zero warning here
                data = da.where(weights, data / weights, 0.0j)

                # only keep data where both corrs are unflagged
                flag = getattr(ds, self.flag_column).data
                flagxx = flag[:, :, 0]
                flagyy = flag[:, :, -1]
                # ducc0 convention uses uint8 mask not flag
                flag = ~(flagxx | flagyy)

                dirty = vis2im(uvw,
                               freq,
                               data,
                               freq_bin_idx,
                               freq_bin_counts,
                               self.nx,
                               self.ny,
                               self.cell,
                               weights=weights,
                               flag=flag.astype(np.uint8),
                               nthreads=self.nthreads,
                               epsilon=self.epsilon,
                               do_wstacking=self.do_wstacking,
                               double_accum=True)

                dirties.append(dirty)

        dirties = dask.compute(dirties, scheduler='single-threaded')[0]

        return accumulate_dirty(dirties, self.nband,
                                self.band_mapping).astype(self.real_type)
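
The correlation-to-Stokes-I step above is a per-visibility weighted average of the XX and YY correlations. A self-contained NumPy sketch of just that step, using toy arrays in place of the measurement set columns:

import numpy as np

nrow, nchan = 8, 4
dataxx = np.random.randn(nrow, nchan) + 1j * np.random.randn(nrow, nchan)
datayy = np.random.randn(nrow, nchan) + 1j * np.random.randn(nrow, nchan)
weightsxx = np.random.rand(nrow, nchan)
weightsyy = np.random.rand(nrow, nchan)

# weighted sum of the two correlations, zero where there is no weight
weights = weightsxx + weightsyy
data = weightsxx * dataxx + weightsyy * datayy
data = np.where(weights > 0, data / np.where(weights > 0, weights, 1.0), 0.0j)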
Example #4
    from dask.diagnostics import Profiler, ProgressBar

    def create_parser():
        parser = argparse.ArgumentParser()
        parser.add_argument("ms")
        parser.add_argument("-c", "--chunks", default=10000, type=int)
        parser.add_argument("-s", "--scheduler", default="threaded")
        return parser

    args = create_parser().parse_args()

    with scheduler_context(args):
        # Create a dataset representing the entire antenna table
        ant_table = '::'.join((args.ms, 'ANTENNA'))

        for ant_ds in xds_from_table(ant_table):
            print(dask.compute(ant_ds.NAME.data,
                               ant_ds.POSITION.data,
                               ant_ds.DISH_DIAMETER.data))

        # Create datasets representing each row of the spw table
        spw_table = '::'.join((args.ms, 'SPECTRAL_WINDOW'))

        for spw_ds in xds_from_table(spw_table, group_cols="__row__"):
            print(spw_ds)
            print(spw_ds.NUM_CHAN.values)
            print(spw_ds.CHAN_FREQ.values)

        # Create datasets from a partitioning of the MS
        datasets = list(xds_from_ms(args.ms, chunks={'row': args.chunks}))
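
For context, a minimal self-contained read following the same pattern; the MS path is illustrative. By default dask-ms groups main-table rows by FIELD_ID and DATA_DESC_ID, so each dataset below corresponds to one field/spectral-window combination:

import dask
from daskms import xds_from_ms

for ds in xds_from_ms("observation.ms",  # illustrative path
                      columns=["TIME", "UVW"],
                      chunks={"row": 10000}):
    uvw = dask.compute(ds.UVW.data)[0]
    print(ds.FIELD_ID, ds.DATA_DESC_ID, uvw.shape)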
Example #5
def _psf(**kw):
    args = OmegaConf.create(kw)
    from omegaconf import ListConfig
    if not isinstance(args.ms, list) and not isinstance(args.ms, ListConfig):
        args.ms = [args.ms]
    OmegaConf.set_struct(args, True)

    import numpy as np
    from pfb.utils.misc import chan_to_band_mapping
    import dask
    # from dask.distributed import performance_report
    from dask.graph_manipulation import clone
    from daskms import xds_from_storage_ms as xds_from_ms
    from daskms import xds_from_storage_table as xds_from_table
    from daskms import Dataset
    from daskms.experimental.zarr import xds_to_zarr
    import dask.array as da
    from africanus.constants import c as lightspeed
    from africanus.gridding.wgridder.dask import dirty as vis2im
    from ducc0.fft import good_size
    from pfb.utils.misc import stitch_images, plan_row_chunk
    from pfb.utils.fits import set_wcs, save_fits

    # chan <-> band mapping
    ms = args.ms
    nband = args.nband
    freqs, freq_bin_idx, freq_bin_counts, freq_out, band_mapping, chan_chunks = chan_to_band_mapping(
        ms, nband=nband)

    # gridder memory budget
    max_chan_chunk = 0
    max_freq = 0
    for ims in args.ms:
        for spw in freqs[ims]:
            counts = freq_bin_counts[ims][spw].compute()
            freq = freqs[ims][spw].compute()
            max_chan_chunk = np.maximum(max_chan_chunk, counts.max())
            max_freq = np.maximum(max_freq, freq.max())

    # assumes measurement sets have the same columns,
    # number of correlations etc.
    xds = xds_from_ms(args.ms[0])
    ncorr = xds[0].dims['corr']
    nrow = xds[0].dims['row']
    # we still have to cater for complex valued data because we cast
    # the weights to complex, but we no longer need to factor the
    # weight column into our memory budget
    data_bytes = getattr(xds[0], args.data_column).data.itemsize
    bytes_per_row = max_chan_chunk * ncorr * data_bytes
    memory_per_row = bytes_per_row

    # flags (uint8 or bool, one byte per flag)
    memory_per_row += np.dtype(np.uint8).itemsize * max_chan_chunk * ncorr

    # UVW
    memory_per_row += xds[0].UVW.data.itemsize * 3

    # ANTENNA1/2
    memory_per_row += xds[0].ANTENNA1.data.itemsize * 2

    # TIME
    memory_per_row += xds[0].TIME.data.itemsize

    # data column is not actually read into memory just used to infer
    # dtype and chunking
    columns = (args.data_column, args.weight_column, args.flag_column, 'UVW',
               'ANTENNA1', 'ANTENNA2', 'TIME')

    # flag row
    if 'FLAG_ROW' in xds[0]:
        columns += ('FLAG_ROW', )
        memory_per_row += xds[0].FLAG_ROW.data.itemsize

    # imaging weights
    if args.imaging_weight_column is not None:
        columns += (args.imaging_weight_column, )
        memory_per_row += bytes_per_row / 2

    # Mueller term (complex valued)
    if args.mueller_column is not None:
        columns += (args.mueller_column, )
        memory_per_row += bytes_per_row

    # get max uv coords over all fields
    uvw = []
    u_max = 0.0
    v_max = 0.0
    for ims in args.ms:
        xds = xds_from_ms(ims, columns=('UVW',), chunks={'row': -1})

        for ds in xds:
            uvw = ds.UVW.data
            u_max = da.maximum(u_max, abs(uvw[:, 0]).max())
            v_max = da.maximum(v_max, abs(uvw[:, 1]).max())
            uv_max = da.maximum(u_max, v_max)

    uv_max = uv_max.compute()
    del uvw

    # image size
    cell_N = 1.0 / (2 * uv_max * max_freq / lightspeed)

    if args.cell_size is not None:
        cell_size = args.cell_size
        cell_rad = cell_size * np.pi / 60 / 60 / 180
        if cell_N / cell_rad < 1:
            raise ValueError(
                "Requested cell size too small. "
                "Super resolution factor = %f" % (cell_N / cell_rad))
        print("Super resolution factor = %f" % (cell_N / cell_rad), file=log)
    else:
        cell_rad = cell_N / args.super_resolution_factor
        cell_size = cell_rad * 60 * 60 * 180 / np.pi
        print("Cell size set to %5.5e arcseconds" % cell_size, file=log)

    if args.nx is None:
        fov = args.field_of_view * 3600
        npix = int(args.psf_oversize * fov / cell_size)
        if npix % 2:
            npix += 1
        nx = npix
        ny = npix
    else:
        nx = args.nx
        ny = args.ny if args.ny is not None else nx

    print("PSF size set to (%i, %i, %i)" % (nband, nx, ny), file=log)

    # get approx image size
    # this is not a conservative estimate when multiple SPWs map to a single
    # imaging band
    pixel_bytes = np.dtype(args.output_type).itemsize
    band_size = nx * ny * pixel_bytes

    if args.host_address is None:
        # full image on single node
        row_chunk = plan_row_chunk(args.mem_limit / args.nworkers, band_size,
                                   nrow, memory_per_row,
                                   args.nthreads_per_worker)

    else:
        # single band per node
        row_chunk = plan_row_chunk(args.mem_limit, band_size, nrow,
                                   memory_per_row, args.nthreads_per_worker)

    if args.row_chunks is not None:
        row_chunk = int(args.row_chunks)
        if row_chunk == -1:
            row_chunk = nrow

    print(
        "nrows = %i, row chunks set to %i for a total of %i chunks per node" %
        (nrow, row_chunk, int(np.ceil(nrow / row_chunk))),
        file=log)

    chunks = {}
    for ims in args.ms:
        chunks[ims] = []  # xds_from_ms expects a list per ds
        for spw in freqs[ims]:
            chunks[ims].append({
                'row': row_chunk,
                'chan': chan_chunks[ims][spw]['chan']
            })

    psfs = []
    radec = None  # assumes we are only imaging field 0 of first MS
    out_datasets = []
    for ims in args.ms:
        xds = xds_from_ms(ims, chunks=chunks[ims], columns=columns)

        # subtables
        ddids = xds_from_table(ims + "::DATA_DESCRIPTION")
        fields = xds_from_table(ims + "::FIELD")
        spws = xds_from_table(ims + "::SPECTRAL_WINDOW")
        pols = xds_from_table(ims + "::POLARIZATION")

        # subtable data
        ddids = dask.compute(ddids)[0]
        fields = dask.compute(fields)[0]
        spws = dask.compute(spws)[0]
        pols = dask.compute(pols)[0]

        for ds in xds:
            field = fields[ds.FIELD_ID]

            # check fields match
            if radec is None:
                radec = field.PHASE_DIR.data.squeeze()

            if not np.array_equal(radec, field.PHASE_DIR.data.squeeze()):
                continue

            # this is not correct, need to use spw
            spw = ds.DATA_DESC_ID

            uvw = clone(ds.UVW.data)

            data_type = getattr(ds, args.data_column).data.dtype
            data_shape = getattr(ds, args.data_column).data.shape
            data_chunks = getattr(ds, args.data_column).data.chunks

            weights = getattr(ds, args.weight_column).data
            if len(weights.shape) < 3:
                weights = da.broadcast_to(weights[:, None, :],
                                          data_shape,
                                          chunks=data_chunks)

            if args.imaging_weight_column is not None:
                imaging_weights = getattr(ds, args.imaging_weight_column).data
                if len(imaging_weights.shape) < 3:
                    imaging_weights = da.broadcast_to(imaging_weights[:,
                                                                      None, :],
                                                      data_shape,
                                                      chunks=data_chunks)

                weightsxx = imaging_weights[:, :, 0] * weights[:, :, 0]
                weightsyy = imaging_weights[:, :, -1] * weights[:, :, -1]
            else:
                weightsxx = weights[:, :, 0]
                weightsyy = weights[:, :, -1]

            # apply mueller term
            if args.mueller_column is not None:
                mueller = getattr(ds, args.mueller_column).data
                weightsxx *= da.absolute(mueller[:, :, 0])**2
                weightsyy *= da.absolute(mueller[:, :, -1])**2

            # weighted sum corr to Stokes I
            weights = weightsxx + weightsyy

            # MS may contain auto-correlations
            if 'FLAG_ROW' in xds[0]:
                frow = ds.FLAG_ROW.data | (ds.ANTENNA1.data
                                           == ds.ANTENNA2.data)
            else:
                frow = (ds.ANTENNA1.data == ds.ANTENNA2.data)

            # only keep data where both corrs are unflagged
            flag = getattr(ds, args.flag_column).data
            flagxx = flag[:, :, 0]
            flagyy = flag[:, :, -1]
            # ducc0 uses uint8 mask not flag
            mask = ~da.logical_or((flagxx | flagyy), frow[:, None])

            psf = vis2im(uvw,
                         freqs[ims][spw],
                         weights.astype(data_type),
                         freq_bin_idx[ims][spw],
                         freq_bin_counts[ims][spw],
                         nx,
                         ny,
                         cell_rad,
                         flag=mask.astype(np.uint8),
                         nthreads=args.nvthreads,
                         epsilon=args.epsilon,
                         do_wstacking=args.wstack,
                         double_accum=args.double_accum)

            psfs.append(psf)

            data_vars = {
                'FIELD_ID': (('row', ),
                             da.full_like(ds.TIME.data,
                                          ds.FIELD_ID,
                                          chunks=args.row_out_chunk)),
                'DATA_DESC_ID': (('row', ),
                                 da.full_like(ds.TIME.data,
                                              ds.DATA_DESC_ID,
                                              chunks=args.row_out_chunk)),
                'WEIGHT':
                (('row', 'chan'), weights.rechunk({0: args.row_out_chunk
                                                   })),  # why no 'f4'?
                'UVW': (('row', 'uvw'), uvw.rechunk({0: args.row_out_chunk}))
            }

            coords = {'chan': (('chan', ), freqs[ims][spw])}

            out_ds = Dataset(data_vars, coords)

            out_datasets.append(out_ds)

    writes = xds_to_zarr(out_datasets,
                         args.output_filename + '.zarr',
                         columns='ALL')

    # dask.visualize(writes, filename=args.output_filename + '_psf_writes_graph.pdf', optimize_graph=False)
    # dask.visualize(psfs, filename=args.output_filename + '_psf_graph.pdf', optimize_graph=False)

    if not args.mock:
        # psfs = dask.compute(psfs, writes, optimize_graph=False)[0]
        # with performance_report(filename=args.output_filename + '_psf_per.html'):
        psfs = dask.compute(psfs, writes, optimize_graph=False)[0]

        psf = stitch_images(psfs, nband, band_mapping)

        hdr = set_wcs(cell_size / 3600, cell_size / 3600, nx, ny, radec,
                      freq_out)
        save_fits(args.output_filename + '_psf.fits',
                  psf,
                  hdr,
                  dtype=args.output_type)

        psf_mfs = np.sum(psf, axis=0)
        wsum = psf_mfs.max()
        psf_mfs /= wsum

        hdr_mfs = set_wcs(cell_size / 3600, cell_size / 3600, nx, ny, radec,
                          np.mean(freq_out))
        save_fits(args.output_filename + '_psf_mfs.fits',
                  psf_mfs,
                  hdr_mfs,
                  dtype=args.output_type)

    print("All done here.", file=log)
Example #6
def main(args):
    # get max uv coords over all fields
    uvw = []
    u_max = 0.0
    v_max = 0.0
    all_freqs = []
    for ims in args.ms:
        xds = xds_from_ms(ims,
                          group_cols=('FIELD_ID', 'DATA_DESC_ID'),
                          columns=('UVW',),
                          chunks={'row': args.row_chunks})

        spws = xds_from_table(ims + "::SPECTRAL_WINDOW", group_cols="__row__")
        spws = dask.compute(spws)[0]

        for ds in xds:
            uvw = ds.UVW.data
            u_max = da.maximum(u_max, abs(uvw[:, 0]).max())
            v_max = da.maximum(v_max, abs(uvw[:, 1]).max())
            uv_max = da.maximum(u_max, v_max)

            spw = spws[ds.DATA_DESC_ID]
            tmp_freq = spw.CHAN_FREQ.data.squeeze()
            all_freqs.append(list(tmp_freq))

    uv_max = uv_max.compute()
    del uvw

    # get Nyquist cell size
    from africanus.constants import c as lightspeed
    all_freqs = dask.compute(all_freqs)
    freq = np.unique(all_freqs)
    cell_N = 1.0 / (2 * uv_max * freq.max() / lightspeed)

    if args.cell_size is not None:
        cell_rad = args.cell_size * np.pi / 60 / 60 / 180
        print("Super resolution factor = ", cell_N / cell_rad)
    else:
        cell_rad = cell_N / args.super_resolution_factor
        args.cell_size = cell_rad * 60 * 60 * 180 / np.pi
        print("Cell size set to %5.5e arcseconds" % args.cell_size)

    if args.nx is None or args.ny is None:
        fov = args.fov * 3600
        npix = int(fov / args.cell_size)
        if npix % 2:
            npix += 1
        args.nx = npix
        args.ny = npix

    if args.nband is None:
        args.nband = freq.size

    print("Image size set to (%i, %i, %i)" % (args.nband, args.nx, args.ny))

    # init gridder
    R = Gridder(args.ms,
                args.nx,
                args.ny,
                args.cell_size,
                nband=args.nband,
                nthreads=args.nthreads,
                do_wstacking=args.do_wstacking,
                row_chunks=args.row_chunks,
                data_column=args.data_column,
                weight_column=args.weight_column,
                epsilon=args.epsilon,
                imaging_weight_column=args.imaging_weight_column,
                model_column=args.model_column,
                flag_column=args.flag_column)
    freq_out = R.freq_out
    radec = R.radec

    # get headers
    hdr = set_wcs(args.cell_size / 3600, args.cell_size / 3600, args.nx,
                  args.ny, radec, freq_out)
    hdr_mfs = set_wcs(args.cell_size / 3600, args.cell_size / 3600, args.nx,
                      args.ny, radec, np.mean(freq_out))
    hdr_psf = set_wcs(args.cell_size / 3600, args.cell_size / 3600,
                      2 * args.nx, 2 * args.ny, radec, freq_out)
    hdr_psf_mfs = set_wcs(args.cell_size / 3600, args.cell_size / 3600,
                          2 * args.nx, 2 * args.ny, radec, np.mean(freq_out))

    # psf
    if args.psf is not None:
        try:
            compare_headers(hdr_psf, fits.getheader(args.psf))
            psf_array = load_fits(args.psf)
        except Exception:
            psf_array = R.make_psf()
            save_fits(args.outfile + '_psf.fits', psf_array, hdr_psf)
    else:
        psf_array = R.make_psf()
        save_fits(args.outfile + '_psf.fits', psf_array, hdr_psf)

    psf_max = np.amax(psf_array.reshape(args.nband, 4 * args.nx * args.ny),
                      axis=1)
    wsum = np.sum(psf_max)
    counts = np.sum(psf_max > 0)
    psf_max_mean = wsum / counts  # normalisation for more intuitive sig_21 values
    psf_array /= psf_max_mean
    psf = PSF(psf_array, args.nthreads)
    psf_max = np.amax(psf_array.reshape(args.nband, 4 * args.nx * args.ny),
                      axis=1)
    wsum = np.sum(psf_max)
    psf_max[psf_max < 1e-15] = 1e-15  # LB - is this the right thing to do?

    psf_mfs = np.sum(psf_array, axis=0) / wsum
    save_fits(
        args.outfile + '_psf_mfs.fits', psf_mfs[args.nx // 2:3 * args.nx // 2,
                                                args.ny // 2:3 * args.ny // 2],
        hdr_mfs)

    # dirty
    if args.dirty is not None:
        try:
            compare_headers(hdr, fits.getheader(args.dirty))
            dirty = load_fits(args.dirty)
        except Exception:
            dirty = R.make_dirty()
            save_fits(args.outfile + '_dirty.fits', dirty, hdr)
    else:
        dirty = R.make_dirty()
        save_fits(args.outfile + '_dirty.fits', dirty, hdr)

    dirty_mfs = np.sum(dirty / psf_max_mean, axis=0) / wsum
    save_fits(args.outfile + '_dirty_mfs.fits', dirty_mfs, hdr_mfs)

    residual = dirty.copy()

    model = np.zeros((2, args.nband, args.nx, args.ny))
    recompute_residual = False
    if args.beta0 is not None:
        compare_headers(hdr, fits.getheader(args.beta0))
        model[0] = load_fits(args.beta0).squeeze()
        recompute_residual = True

    if args.alpha0 is not None:
        compare_headers(hdr, fits.getheader(args.alpha0))
        model[1] = load_fits(args.alpha0).squeeze()
        recompute_residual = True

    # normalise for more intuitive hypers
    residual /= psf_max_mean
    residual_mfs = np.sum(residual, axis=0) / wsum
    save_fits(args.outfile + '_first_residual_mfs.fits', residual_mfs, hdr_mfs)

    # mask
    if args.mask is not None:
        mask = load_fits(args.mask, dtype=np.int64)[None, :, :]
        if mask.shape != (1, args.nx, args.ny):
            raise ValueError("Mask has incorrect shape")
    else:
        mask = np.ones((1, args.nx, args.ny), dtype=np.int64)

    # point mask
    pmask = load_fits(args.point_mask, dtype=bool)[None, :, :]
    if pmask.shape != (1, args.nx, args.ny):
        raise ValueError("Point mask has incorrect shape")

    # set up splitting operator
    phi = lambda x: x[0] * pmask + x[1] * mask
    phih = lambda x: np.concatenate(
        ((pmask * x)[None], (mask * x)[None]), axis=0)

    if recompute_residual:
        image = phi(model)
        residual = R.make_residual(image) / psf_max_mean
        residual_mfs = np.sum(residual, axis=0) / wsum

    # Gaussian "prior" used for preconditioning extended emission
    A = Gauss(args.sig_l2a, args.nband, args.nx, args.ny, args.nthreads)

    #  preconditioning matrix
    def hess(x):
        return phih(psf.convolve(phi(x))) + np.concatenate(
            (x[0:1] / args.sig_l2b**2, A.idot(x[1])[None]), axis=0)
        # return  phih(psf.convolve(phi(x))) + np.concatenate((x[0:1]/args.sig_l2b**2, x[1::]/args.sig_l2a**2), axis=0)

    # M_func = lambda x: np.concatenate((x[0:1] * args.sig_l2b**2, x[1::] * args.sig_l2a**2), axis=0)
    M_func = lambda x: np.concatenate(
        (x[0:1] * args.sig_l2b**2, A.convolve(x[1])[None]), axis=0)

    par_shape = phih(dirty).shape
    if args.beta is None:
        print("Getting spectral norm of update operator")
        beta = power_method(hess,
                            par_shape,
                            tol=args.pmtol,
                            maxit=args.pmmaxit)
    else:
        beta = args.beta
    print(" beta = %f " % beta)

    # set up wavelet basis
    theta = DaskTheta(args.nband, args.nx, args.ny, nthreads=args.nthreads)
    nbasis = theta.nbasis
    weights_21 = np.ones((theta.nbasis + 1, theta.nmax), dtype=np.float64)
    tmp = np.pad(pmask.ravel(), (0, theta.nmax - args.nx * args.ny),
                 mode='constant')
    weights_21[0] = np.where(tmp, args.sig_21b / args.sig_21a, 1e15)
    dual = np.zeros((theta.nbasis + 1, args.nband, theta.nmax),
                    dtype=np.float64)

    # Reporting
    report_iters = list(np.arange(0, args.maxit, args.report_freq))
    if report_iters[-1] != args.maxit - 1:
        report_iters.append(args.maxit - 1)

    # deconvolve
    eps = 1.0
    i = 0
    rmax = np.abs(residual_mfs).max()
    rms = np.std(residual_mfs)
    print("Peak of initial residual is %f and rms is %f" % (rmax, rms))
    for i in range(1, args.maxit):
        x = pcg(hess,
                phih(residual),
                np.zeros(par_shape, dtype=np.float64),
                M=M_func,
                tol=args.cgtol,
                maxit=args.cgmaxit,
                verbosity=args.cgverbose)

        if i in report_iters:
            save_fits(args.outfile + str(i) + '_point_update.fits', x[0], hdr)
            save_fits(args.outfile + str(i) + '_fluff_update.fits', x[1], hdr)

        # update model
        modelp = model
        model = modelp + args.gamma * x
        model, dual = primal_dual(hess,
                                  model,
                                  modelp,
                                  dual,
                                  args.sig_21a,
                                  theta,
                                  weights_21,
                                  beta,
                                  tol=args.pdtol,
                                  maxit=args.pdmaxit,
                                  report_freq=100,
                                  mask=mask,
                                  positivity=args.positivity,
                                  gamma=args.gamma)

        # get residual
        image = phi(model)
        residual = R.make_residual(image) / psf_max_mean

        # check stopping criteria
        residual_mfs = np.sum(residual, axis=0) / wsum
        rmax = np.abs(residual_mfs).max()
        rms = np.std(residual_mfs)
        eps = np.linalg.norm(model - modelp) / np.linalg.norm(model)

        if i in report_iters:
            # save current iteration
            save_fits(args.outfile + str(i) + '_model.fits', image, hdr)

            save_fits(args.outfile + str(i) + '_point.fits', model[0], hdr)
            save_fits(args.outfile + str(i) + '_fluff.fits', model[1], hdr)

            model_mfs = np.mean(image, axis=0)
            save_fits(args.outfile + str(i) + '_model_mfs.fits', model_mfs,
                      hdr_mfs)

            save_fits(args.outfile + str(i) + '_residual.fits', residual, hdr)

            save_fits(args.outfile + str(i) + '_residual_mfs.fits',
                      residual_mfs, hdr_mfs)

        print(
            "At iteration %i peak of residual is %f, rms is %f, current eps is %f"
            % (i, rmax, rms, eps))

        if eps < args.tol:
            break

    if args.interp_model:
        nband = args.nband
        order = args.spectral_poly_order
        mask = np.where(model_mfs > 1e-10, 1, 0)
        I = np.argwhere(mask).squeeze()
        Ix = I[:, 0]
        Iy = I[:, 1]
        npix = I.shape[0]

        # get components
        beta = image[:, Ix, Iy]

        # fit integrated polynomial to model components
        # we are given frequencies at bin centers, convert to bin edges
        ref_freq = np.mean(freq_out)
        delta_freq = freq_out[1] - freq_out[0]
        wlow = (freq_out - delta_freq / 2.0) / ref_freq
        whigh = (freq_out + delta_freq / 2.0) / ref_freq
        wdiff = whigh - wlow

        # set design matrix for each component
        Xdesign = np.zeros([freq_out.size, args.spectral_poly_order])
        for i in range(1, args.spectral_poly_order + 1):
            Xdesign[:, i - 1] = (whigh**i - wlow**i) / (i * wdiff)

        weights = psf_max[:, None]
        dirty_comps = Xdesign.T.dot(weights * beta)

        hess_comps = Xdesign.T.dot(weights * Xdesign)

        comps = np.linalg.solve(hess_comps, dirty_comps)

        np.savez(args.outfile + "spectral_comps",
                 comps=comps,
                 ref_freq=ref_freq,
                 mask=np.any(model, axis=0))

    if args.write_model:
        if args.interp_model:
            R.write_component_model(comps, ref_freq, mask, args.row_chunks,
                                    args.chan_chunks)
        else:
            R.write_model(model)
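
The interp_model block fits a polynomial, integrated over each imaging band, to the model components by weighted least squares. A self-contained NumPy sketch of the same design-matrix construction and normal-equations solve, with synthetic band frequencies and component fluxes:

import numpy as np

freq_out = np.linspace(1.0e9, 1.4e9, 8)  # band centre frequencies (synthetic)
ref_freq = np.mean(freq_out)
delta_freq = freq_out[1] - freq_out[0]
wlow = (freq_out - delta_freq / 2.0) / ref_freq
whigh = (freq_out + delta_freq / 2.0) / ref_freq
wdiff = whigh - wlow

order = 3
# monomial basis functions averaged over each frequency bin
Xdesign = np.zeros((freq_out.size, order))
for i in range(1, order + 1):
    Xdesign[:, i - 1] = (whigh ** i - wlow ** i) / (i * wdiff)

beta = np.random.randn(freq_out.size, 5)  # toy per-band component fluxes
weights = np.ones((freq_out.size, 1))     # toy per-band weights

# weighted least squares via the normal equations
comps = np.linalg.solve(Xdesign.T @ (weights * Xdesign),
                        Xdesign.T @ (weights * beta))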
Example #7
def read_ms(ms, num_vis, res_arcmin, chunks=50000, channel=0, field_id=0):
    """
    Use dask-ms to load the necessary data to create a telescope operator
    (will use uvw positions, and antenna positions)

    -- res_arcmin: Used to calculate the maximum baselines to consider.
                   We want two pixels per smallest fringe
                   pix_res > fringe / 2

                   u sin(theta) = n (for nth fringe)
                   at small angles: theta = 1/u, or bl_max = 1 / theta

                   d sin(theta) = lambda / 2
                   d / lambda = 1 / (2 sin(theta))
                   bl_max = lambda / 2sin(theta)


    """

    # local_cluster = distributed.LocalCluster(processes=False)
    # address = local_cluster.scheduler_address
    # logging.info("Using distributed scheduler "
    # "with address '{}'".format(address))
    # client = distributed.Client()

    try:
        # Create a dataset representing the entire antenna table
        ant_table = "::".join((ms, "ANTENNA"))

        for ant_ds in xds_from_table(ant_table):
            # print(ant_ds)
            # print(dask.compute(ant_ds.NAME.data,
            # ant_ds.POSITION.data,
            # ant_ds.DISH_DIAMETER.data))
            ant_p = np.array(ant_ds.POSITION.data)
        logger.info("Antenna Positions {}".format(ant_p.shape))

        # Create a dataset representing the field
        field_table = "::".join((ms, "FIELD"))
        for field_ds in xds_from_table(field_table):
            phase_dir = np.array(field_ds.PHASE_DIR.data)[0].flatten()
            name = field_ds.NAME.data.compute()
            logger.info("Field {}: Phase Dir {}".format(
                name, np.degrees(phase_dir)))

        # Create datasets representing each row of the spw table
        spw_table = "::".join((ms, "SPECTRAL_WINDOW"))

        for spw_ds in xds_from_table(spw_table, group_cols="__row__"):
            logger.info("CHAN_FREQ.values: {}".format(
                spw_ds.CHAN_FREQ.values.shape))
            frequencies = dask.compute(spw_ds.CHAN_FREQ.values)[0].flatten()
            frequency = frequencies[channel]
            logger.info("Frequencies = {}".format(frequencies))
            logger.info("Frequency = {}".format(frequency))
            logger.info("NUM_CHAN = %f" % np.array(spw_ds.NUM_CHAN.values)[0])

        # Create datasets from a partitioning of the MS
        datasets = list(xds_from_ms(ms, chunks={"row": chunks}))
        logger.info("DataSets: N={}".format(len(datasets)))

        pol = 0

        def read_np_array(arr, title, dtype=np.float32):
            tic = time.perf_counter()
            logger.info("Reading {}...".format(title))
            ret = np.array(arr, dtype=dtype)
            toc = time.perf_counter()
            logger.info("Elapsed {:.4f} seconds".format(toc - tic))
            return ret

        for i, ds in enumerate(datasets):
            logger.info("DATASET field_id={} shape: {}".format(
                ds.FIELD_ID, ds.DATA.data.shape))
            logger.info("UVW shape: {}".format(ds.UVW.data.shape))
            logger.info("SIGMA shape: {}".format(ds.SIGMA.data.shape))
            if int(field_id) == int(ds.FIELD_ID):
                uvw = read_np_array(ds.UVW.data, "UVW")
                flags = read_np_array(ds.FLAG.data[:, channel, pol],
                                      "FLAGS",
                                      dtype=np.int32)

                #
                #
                #   Now calculate which indices we should use to get the required number of
                #   visibilities.
                #
                bl_max = get_resolution_max_baseline(res_arcmin, frequency)

                logger.info("Resolution Max UVW: {:g} meters".format(bl_max))
                logger.info("Flags: {}".format(flags.shape))

                # Now report the recommended resolution from the data.
                # 1.0 / 2*np.sin(theta) = limit_u
                limit_uvw = np.max(np.abs(uvw), 0)

                res_limit = get_baseline_resolution(limit_uvw[0], frequency)
                logger.info("Nyquist resolution: {:g} arcmin".format(
                    np.degrees(res_limit) * 60.0))

                if True:
                    bl = np.sqrt(uvw[:, 0]**2 + uvw[:, 1]**2 + uvw[:, 2]**2)
                    # good_data = np.array(np.where((flags == 0) & (np.max(np.abs(uvw), 1) < bl_max))).T.reshape((-1,))
                    good_data = np.array(np.where((flags == 0)
                                                  & (bl < bl_max))).T.reshape(
                                                      (-1, ))
                else:
                    good_data = np.array(np.where(flags == 0)).T.reshape(
                        (-1, ))
                logger.info("Good Data {}".format(good_data.shape))

                logger.info("Maximum UVW: {}".format(limit_uvw))
                logger.info("Minimum UVW: {}".format(np.min(np.abs(uvw), 0)))

                for axis in range(3):
                    p05, p50, p95 = np.percentile(np.abs(uvw[:, axis]),
                                                  [5, 50, 95])
                    logger.info("       U[{}]: {:5.2f} {:5.2f} {:5.2f}".format(
                        axis, p05, p50, p95))

                n_ant = len(ant_p)

                n_max = len(good_data)

                if n_max <= num_vis:
                    indices = np.arange(n_max)
                else:
                    indices = np.random.choice(good_data,
                                               min(num_vis, n_max),
                                               replace=False)

                # sort the indices to keep them in order (speeds up IO)
                indices = np.sort(indices)
                #
                #
                #   Now read the remaining data
                #
                sigma = read_np_array(ds.SIGMA.data[indices, pol], "SIGMA")
                # ant1   = read_np_array(ds.ANTENNA1.data[indices], "ANTENNA1")
                # ant12  = read_np_array(ds.ANTENNA1.data[indices], "ANTENNA2")
                cv_vis = read_np_array(ds.DATA.data[indices, channel, pol],
                                       "DATA",
                                       dtype=np.complex64)

                epoch_seconds = np.array(ds.TIME.data)[0]

        if "uvw" not in locals():
            raise RuntimeError("FIELD_ID ({}) is invalid".format(field_id))

        hdr = {
            "CTYPE1": ("RA---SIN", "Right ascension angle cosine"),
            "CRVAL1": np.degrees(phase_dir)[0],
            "CUNIT1": "deg     ",
            "CTYPE2": ("DEC--SIN", "Declination angle cosine "),
            "CRVAL2": np.degrees(phase_dir)[1],
            "CUNIT2": "deg     ",
            "CTYPE3": "FREQ    ",  #           / Central frequency  ",
            "CRPIX3": 1.0,
            "CRVAL3": "{}".format(frequency),
            "CDELT3": 10026896.158854,
            "CUNIT3": "Hz      ",
            "EQUINOX": "2000.",
            "DATE-OBS": "{}".format(epoch_seconds),
            "BTYPE": "Intensity",
        }

        # from astropy.wcs.utils import celestial_frame_to_wcs
        # from astropy.coordinates import FK5
        # frame = FK5(equinox='J2010')
        # wcs = celestial_frame_to_wcs(frame)
        # wcs.to_header()

        u_arr = uvw[indices, 0].T
        v_arr = uvw[indices, 1].T
        w_arr = uvw[indices, 2].T

        rms_arr = sigma.T

        logger.info("Max vis {}".format(np.max(np.abs(cv_vis))))

        # Convert from Modified Julian Date seconds to a timestamp.
        timestamp = datetime.datetime(
            1858, 11, 17, 0, 0, 0,
            tzinfo=datetime.timezone.utc) + datetime.timedelta(
                seconds=epoch_seconds)

    except Exception as e:
        logger.info("Exception {}".format(e))

    # finally:
    # client.close()
    # local_cluster.close()

    return u_arr, v_arr, w_arr, frequency, cv_vis, hdr, timestamp, rms_arr
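
The subselection logic is the heart of this reader: flagged rows are dropped, baselines longer than bl_max are cut, and a random subset of the survivors is drawn and sorted so the subsequent table reads stay sequential. A NumPy-only sketch of that selection with synthetic inputs:

import numpy as np

rng = np.random.default_rng(42)
uvw = rng.normal(scale=300.0, size=(100000, 3))  # synthetic uvw in metres
flags = rng.integers(0, 2, size=100000)          # synthetic flags
bl_max = 500.0
num_vis = 1000

bl = np.sqrt((uvw ** 2).sum(axis=1))
good_data = np.flatnonzero((flags == 0) & (bl < bl_max))
if good_data.size <= num_vis:
    indices = good_data
else:
    indices = rng.choice(good_data, num_vis, replace=False)
indices = np.sort(indices)  # sorted indices speed up the actual IO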
Example #8
def get_field_names(myms):
    field_tab = xms.xds_from_table(
        myms+'::FIELD', columns=['NAME', 'SOURCE_ID'])
    field_ids = field_tab[0].SOURCE_ID.values
    field_names = field_tab[0].NAME.values
    return field_ids, field_names
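
A hypothetical call, assuming mydata.ms exists and xms is the imported daskms module:

field_ids, field_names = get_field_names('mydata.ms')  # illustrative path
for fid, name in zip(field_ids, field_names):
    print(fid, name)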
Example #9
    def __init__(self, ms_name, nx, ny, cell_size, nband=None, nthreads=8, do_wstacking=1, Stokes='I',
                 row_chunks=100000, optimise_chunks=True, epsilon=1e-5,
                 data_column='CORRECTED_DATA', weight_column='WEIGHT_SPECTRUM',
                 model_column="MODEL_DATA", flag_column='FLAG', imaging_weight_column=None):
        if Stokes != 'I':
            raise NotImplementedError("Only Stokes I currently supported")
        self.nx = nx
        self.ny = ny
        self.cell = cell_size * np.pi/60/60/180
        self.nthreads = nthreads
        self.do_wstacking = do_wstacking
        self.epsilon = epsilon

        self.data_column = data_column
        self.weight_column = weight_column
        self.model_column = model_column
        self.flag_column = flag_column
        if isinstance(ms_name, list):
            self.ms = ms_name
        else:
            self.ms = [ms_name]

        # first pass through data to determine freq_mapping
        self.radec = None
        self.freq = {}
        all_freqs = []
        for ims in self.ms:
            xds = xds_from_ms(ims, group_cols=('FIELD_ID', 'DATA_DESC_ID'),
                              chunks={"row":-1},
                              columns=('TIME',))

            # subtables
            ddids = xds_from_table(ims + "::DATA_DESCRIPTION")
            fields = xds_from_table(ims + "::FIELD", group_cols="__row__")
            spws = xds_from_table(ims + "::SPECTRAL_WINDOW", group_cols="__row__")
            pols = xds_from_table(ims + "::POLARIZATION", group_cols="__row__")

            # subtable data
            ddids = dask.compute(ddids)[0]
            fields = dask.compute(fields)[0]
            spws = dask.compute(spws)[0]
            pols = dask.compute(pols)[0]
            

            self.freq[ims] = {}
            for ds in xds:
                field = fields[ds.FIELD_ID]
                radec = field.PHASE_DIR.data.squeeze()

                # check fields match
                if self.radec is None:
                    self.radec = radec

                if not np.array_equal(radec, self.radec):
                    continue

                
                spw = spws[ds.DATA_DESC_ID]
                tmp_freq = spw.CHAN_FREQ.data.squeeze()
                self.freq[ims][ds.DATA_DESC_ID] = tmp_freq
                all_freqs.append([tmp_freq])


        # freq mapping
        all_freqs = dask.compute(all_freqs)
        ufreqs = np.unique(all_freqs)  # returns ascending sorted
        self.nchan = ufreqs.size
        if nband is None:
            self.nband = self.nchan
        else:
            self.nband = nband
        
        # bin edges
        fmin = ufreqs[0]
        fmax = ufreqs[-1]
        fbins = np.linspace(fmin, fmax, self.nband+1)
        self.freq_out = np.zeros(self.nband)
        for band in range(self.nband):
            indl = ufreqs >= fbins[band]
            indu = ufreqs < fbins[band + 1] + 1e-6
            self.freq_out[band] = np.mean(ufreqs[indl & indu])
        
        # chan <-> band mapping
        self.band_mapping = {}
        self.chunks = {}
        self.freq_bin_idx = {}
        self.freq_bin_counts = {}
        for ims in self.freq:
            self.freq_bin_idx[ims] = {}
            self.freq_bin_counts[ims] = {}
            self.band_mapping[ims] = {}
            self.chunks[ims] = []
            for spw in self.freq[ims]:
                freq = np.atleast_1d(dask.compute(self.freq[ims][spw])[0])
                band_map = np.zeros(freq.size, dtype=np.int32)
                for band in range(self.nband):
                    indl = freq >= fbins[band]
                    indu = freq < fbins[band + 1] + 1e-6
                    band_map = np.where(indl & indu, band, band_map)
                # to dask arrays
                bands, bin_counts = np.unique(band_map, return_counts=True)
                self.band_mapping[ims][spw] = tuple(bands)
                self.chunks[ims].append({'row':(-1,), 'chan':tuple(bin_counts)})
                self.freq[ims][spw] = da.from_array(freq, chunks=tuple(bin_counts))
                bin_idx = np.append(np.array([0]), np.cumsum(bin_counts))[0:-1]
                self.freq_bin_idx[ims][spw] = da.from_array(bin_idx, chunks=1)
                self.freq_bin_counts[ims][spw] = da.from_array(bin_counts, chunks=1)

        self.imaging_weight_column = imaging_weight_column
        if imaging_weight_column is not None:
            self.columns = (self.data_column, self.weight_column,
                            self.imaging_weight_column, self.flag_column, 'UVW')
        else:
            self.columns = (self.data_column, self.weight_column, self.flag_column, 'UVW')
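The channel-to-band mapping built in this constructor reduces to a few lines of numpy. A standalone sketch with synthetic frequencies and an assumed nband of 3, mirroring the binning logic above:

import numpy as np

freq = np.linspace(1.0e9, 2.0e9, 64)  # synthetic channel frequencies in Hz
nband = 3
fbins = np.linspace(freq[0], freq[-1], nband + 1)
band_map = np.zeros(freq.size, dtype=np.int32)
for band in range(nband):
    sel = (freq >= fbins[band]) & (freq < fbins[band + 1] + 1e-6)
    band_map[sel] = band
bands, bin_counts = np.unique(band_map, return_counts=True)
bin_idx = np.append(0, np.cumsum(bin_counts))[:-1]  # first channel of each band
print(bands, bin_counts, bin_idx)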
Example #10
def both(args):
    """Generate model data, corrupted visibilities and 
    gains (phase-only or normal)"""
        
    # Set thread count to cpu count
    if args.ncpu:
        from multiprocessing.pool import ThreadPool
        import dask
        dask.config.set(pool=ThreadPool(args.ncpu))
    else:
        import multiprocessing
        args.ncpu = multiprocessing.cpu_count()

    # Get full time column and compute row chunks
    ms = xds_from_table(args.ms)[0]
   
    row_chunks, tbin_idx, tbin_counts = chunkify_rows(
        ms.TIME, args.utimes_per_chunk)
    
    # Convert time rows to dask arrays
    tbin_idx = da.from_array(tbin_idx, 
                 chunks=(args.utimes_per_chunk))
    tbin_counts = da.from_array(tbin_counts, 
                    chunks=(args.utimes_per_chunk))

    # Time axis
    n_time = tbin_idx.size

    # Get antenna columns
    ant1 = ms.ANTENNA1.data
    ant2 = ms.ANTENNA2.data

    # No. of antennas axis
    n_ant = (np.maximum(ant1.max(), ant2.max()) + 1).compute()

    # Get flag column
    flag = ms.FLAG.data

    # Get convention
    if args.phase_convention == 'CASA':
        uvw = -(ms.UVW.data.astype(np.float64))
    elif args.phase_convention == 'CODEX':
        uvw = ms.UVW.data.astype(np.float64)
    else:
        raise ValueError("Unknown sign convention for phase")
        
    # Get rest of dimensions
    n_row, n_freq, n_corr = flag.shape

    # Raise error if there are not exactly 4 correlations
    if n_corr != 4:
        raise NotImplementedError("Only 4 correlations "\
            + "currently supported")

    # Get phase direction
    radec0_table = xds_from_table(args.ms+'::FIELD')[0]
    radec0 = radec0_table.PHASE_DIR.data.squeeze().compute()

    # Get frequency column
    freq_table = xds_from_table(args.ms+'::SPECTRAL_WINDOW')[0]
    freq = freq_table.CHAN_FREQ.data.astype(np.float64)[0]

    # Check dimension
    assert freq.size == n_freq

    # Check for sky-model
    if args.sky_model == 'MODEL-1.txt':
        args.sky_model = MODEL_1
    elif args.sky_model == 'MODEL-4.txt':
        args.sky_model = MODEL_4
    elif args.sky_model == 'MODEL-50.txt':
        args.sky_model = MODEL_50
    else:
        raise ValueError(f"Sky-model {args.sky_model} not in "\
            + "kalcal/datasets/sky_model/")

    # Build source model from lsm
    lsm = Tigger.load(args.sky_model)

    # Direction axis
    n_dir = len(lsm.sources)

    # Create initial model array
    model = np.zeros((n_dir, n_freq, n_corr), dtype=np.float64)

    # Create initial coordinate array and source names
    lm = np.zeros((n_dir, 2), dtype=np.float64)
    source_names = []

    # Cycle coordinates creating a source with flux
    for d, source in enumerate(lsm.sources):
        # Extract name
        source_names.append(source.name)

        # Extract position
        radec_s = np.array([[source.pos.ra, source.pos.dec]])
        lm[d] =  radec_to_lm(radec_s, radec0)

        # Get flux - Stokes I
        if source.flux.I:
            I0 = source.flux.I

            # Get spectrum (only spi currently supported)
            tmp_spec = source.spectrum
            spi = [tmp_spec.spi if tmp_spec is not None else 0.0]
            ref_freq = [tmp_spec.freq0 if tmp_spec is not None else 1.0]

            # Generate model flux
            model[d, :, 0] = I0 * (freq/ref_freq)**spi

        # Get flux - Stokes Q
        if source.flux.Q:
            Q0 = source.flux.Q

            # Get spectrum
            tmp_spec = source.spectrum
            spi = [tmp_spec.spi if tmp_spec is not None else 0.0]
            ref_freq = [tmp_spec.freq0 if tmp_spec is not None else 1.0]

            # Generate model flux
            model[d, :, 1] = Q0 * (freq/ref_freq)**spi

        # Get flux - Stokes U
        if source.flux.U:
            U0 = source.flux.U

            # Get spectrum
            tmp_spec = source.spectrum
            spi = [tmp_spec.spi if tmp_spec is not None else 0.0]
            ref_freq = [tmp_spec.freq0 if tmp_spec is not None else 1.0]

            # Generate model flux
            model[d, :, 2] = U0 * (freq/ref_freq)**spi

        # Get flux - Stokes V
        if source.flux.V:
            V0 = source.flux.V

            # Get spectrum
            tmp_spec = source.spectrum
            spi = [tmp_spec.spi if tmp_spec is not None else 0.0]
            ref_freq = [tmp_spec.freq0 if tmp_spec is not None else 1.0]

            # Generate model flux
            model[d, :, 3] = V0 * (freq/ref_freq)**spi

    # Generate gains
    jones = None
    jones_shape = None

    # Dask to NP
    t = tbin_idx.compute()
    nu = freq.compute()

    print('==> Both-mode')
    if args.mode == "phase":
        jones = phase_gains(lm, nu, n_time, n_ant, args.alpha_std)

    elif args.mode == "normal":
        jones = normal_gains(t, nu, lm, n_ant, n_corr, 
                    args.sigma_f, args.lt, args.lnu, args.ls)
    else:
        raise ValueError("Only normal and phase modes available.")
    
    print()
    # Reduce jones to diagonals only
    jones = jones[:, :, :, :, (0, -1)]

    # Jones to complex
    jones = jones.astype(np.complex128)

    # Jones shape
    jones_shape = jones.shape

    # Generate filename
    if args.out == "":
        args.out = f"{args.mode}.npy"

    # Save gains and settings to file
    with open(args.out, 'wb') as file:        
        np.save(file, jones)

    # Build dask graph
    lm = da.from_array(lm, chunks=lm.shape)
    model = da.from_array(model, chunks=model.shape)
    jones_da = da.from_array(jones, chunks=(args.utimes_per_chunk,)
                            + jones_shape[1::])    

    # Append antenna columns
    cols = []
    cols.append('ANTENNA1')
    cols.append('ANTENNA2')
    cols.append('UVW')

    # Load data in chunks and apply gains to each chunk
    xds = xds_from_ms(args.ms, columns=cols, 
            chunks={"row": row_chunks})[0]
    ant1 = xds.ANTENNA1.data
    ant2 = xds.ANTENNA2.data

    # Adjust UVW based on phase-convention
    if args.phase_convention == 'CASA':
        uvw = -xds.UVW.data.astype(np.float64)       
    elif args.phase_convention == 'CODEX':
        uvw = xds.UVW.data.astype(np.float64)
    else:
        raise ValueError("Unknown sign convention for phase")
    

    # Get model visibilities
    model_vis = np.zeros((n_row, n_freq, n_dir, n_corr), 
                            dtype=np.complex128)    

    for s in range(n_dir):
        model_vis[:, :, s] = im_to_vis(
            model[s].reshape((1, n_freq, n_corr)),
            uvw, 
            lm[s].reshape((1, 2)), 
            freq, 
            dtype=np.complex64, convention='fourier')

    # NP to Dask
    model_vis = da.from_array(model_vis, chunks=(row_chunks, 
                                n_freq, n_dir, n_corr))

    # Convert Stokes to corr
    in_schema = ['I', 'Q', 'U', 'V']
    out_schema = [['RR', 'RL'], ['LR', 'LL']]
    model_vis = convert(model_vis, in_schema, out_schema)

    # Apply gains
    data = corrupt_vis(tbin_idx, tbin_counts, ant1, ant2,
                    jones_da, model_vis).reshape(
                        (n_row, n_freq, n_corr))
    
    # Assign model visibilities
    out_names = []
    for d in range(n_dir):
        xds = xds.assign(**{source_names[d]: 
                (("row", "chan", "corr"), 
                model_vis[:, :, d].reshape(
                    n_row, n_freq, n_corr).astype(np.complex64))})

        out_names += [source_names[d]]

    # Assign noise free visibilities to 'CLEAN_DATA'
    xds = xds.assign(**{'CLEAN_DATA': (("row", "chan", "corr"), 
            data.astype(np.complex64))})

    out_names += ['CLEAN_DATA']
    
    # Get noise realisation
    if args.sigma_n > 0.0:

        # Noise matrix
        noise = (da.random.normal(loc=0.0, scale=args.sigma_n,
                    size=(n_row, n_freq, n_corr),
                    chunks=(row_chunks, n_freq, n_corr))
                + 1.0j*da.random.normal(loc=0.0, scale=args.sigma_n,
                    size=(n_row, n_freq, n_corr),
                    chunks=(row_chunks, n_freq, n_corr)))/np.sqrt(2.0)

        # Zero matrix for off-diagonals
        zero = da.zeros_like(noise[:, :, 0])

        # Dask to NP
        noise = noise.compute()
        zero = zero.compute()

        # Remove noise on off-diagonals
        noise[:, :, 1] = zero[:, :]
        noise[:, :, 2] = zero[:, :]

        # NP to Dask
        noise = da.from_array(noise, chunks=(row_chunks, n_freq, n_corr))
        
        # Assign noise to 'NOISE'
        xds = xds.assign(**{'NOISE': (("row", "chan", "corr"), 
                noise.astype(np.complex64))})

        out_names += ['NOISE']

        # Add noise to data and assign to 'DATA'
        noisy_data = data + noise
        xds = xds.assign(**{'DATA': (("row", "chan", "corr"), 
                noisy_data.astype(np.complex64))})

        out_names += ['DATA']
        
    # Create a write to the table
    write = xds_to_table(xds, args.ms, out_names)

    # Submit all graph computations in parallel
    with ProgressBar():
        write.compute()

    print(f"==> Applied Jones to MS: {args.ms} <--> {args.out}")
Example #11
def get_chan_freqs(myms):
    spw_tab = xms.xds_from_table(
        myms+'::SPECTRAL_WINDOW', columns=['CHAN_FREQ'])
    chan_freqs = spw_tab[0].CHAN_FREQ
    return chan_freqs
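As with get_field_names in Example #8, a minimal usage sketch ('my.ms' is a placeholder path):

chan_freqs = get_chan_freqs('my.ms')
print(chan_freqs.values.ravel()[:5])  # first few channel frequencies in Hz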
Example #12
def jones(args):
    """Generate jones matrix only, but based off
    of a measurement set."""
    
    # Set thread count to cpu count
    if args.ncpu:
        from multiprocessing.pool import ThreadPool
        import dask
        dask.config.set(pool=ThreadPool(args.ncpu))
    else:
        import multiprocessing
        args.ncpu = multiprocessing.cpu_count()

    # Get full time column and compute row chunks
    ms = xds_from_table(args.ms)[0]
   
    _, tbin_idx, tbin_counts = chunkify_rows(
        ms.TIME, args.utimes_per_chunk)
    
    # Convert time rows to dask arrays
    tbin_idx = da.from_array(tbin_idx, 
                 chunks=(args.utimes_per_chunk))
    tbin_counts = da.from_array(tbin_counts, 
                    chunks=(args.utimes_per_chunk))

    # Time axis
    n_time = tbin_idx.size

    # Get antenna columns
    ant1 = ms.ANTENNA1.data
    ant2 = ms.ANTENNA2.data

    # No. of antennas axis
    n_ant = (np.maximum(ant1.max(), ant2.max()) + 1).compute()

    # Get flag column
    flag = ms.FLAG.data

    # Get convention
    if args.phase_convention == 'CASA':
        uvw = -(ms.UVW.data.astype(np.float64))
    elif args.phase_convention == 'CODEX':
        uvw = ms.UVW.data.astype(np.float64)
    else:
        raise ValueError("Unknown sign convention for phase")

    # Get rest of dimensions
    n_row, n_freq, n_corr = flag.shape

    # Raise error if there are not exactly 4 correlations
    if n_corr != 4:
        raise NotImplementedError("Only 4 correlations "\
            + "currently supported")

    # Get phase direction
    radec0_table = xds_from_table(args.ms+'::FIELD')[0]
    radec0 = radec0_table.PHASE_DIR.data.squeeze().compute()
    
    # Get frequency column
    freq_table = xds_from_table(args.ms+'::SPECTRAL_WINDOW')[0]
    freq = freq_table.CHAN_FREQ.data.astype(np.float64)[0]

    # Check dimension
    assert freq.size == n_freq

    # Check for sky-model
    if args.sky_model == 'MODEL-1.txt':
        args.sky_model = MODEL_1
    elif args.sky_model == 'MODEL-4.txt':
        args.sky_model = MODEL_4
    elif args.sky_model == 'MODEL-50.txt':
        args.sky_model = MODEL_50
    else:
        raise ValueError(f"Sky-model {args.sky_model} not in "\
            + "kalcal/datasets/sky_model/")

    # Build source model from lsm
    lsm = Tigger.load(args.sky_model)

    # Direction axis
    n_dir = len(lsm.sources)

    # Create initial coordinate array and source names
    lm = np.zeros((n_dir, 2), dtype=np.float64)

    # Cycle coordinates creating a source with flux
    for d, source in enumerate(lsm.sources):

        # Extract position
        radec_s = np.array([[source.pos.ra, source.pos.dec]])
        lm[d] =  radec_to_lm(radec_s, radec0)

    # Generate gains
    jones = None
    print('==> Jones-only mode')
    if args.mode == "phase":
        jones = phase_gains(lm, freq, n_time, n_ant, args.alpha_std)

    elif args.mode == "normal":
        jones = normal_gains(tbin_idx, freq, lm, n_ant, n_corr, 
                    args.sigma_f, args.lt, args.lnu, args.ls)
    else:
        raise ValueError("Only normal and phase modes available.")

    # Reduce jones to diagonals only
    jones = jones[:, :, :, :, (0, -1)]

    # Jones to complex
    jones = jones.astype(np.complex128)

    # Generate filename
    if args.out == "":
        args.out = f"{args.mode}.npy"
  
    # Save gains and settings to file
    with open(args.out, 'wb') as file:        
        np.save(file, jones)

    print(f"==> Created Jones data: {args.out}")
Example #13
def new(ms, sky_model, **kwargs):
    """Generate a jones matrix based on a given sky-model
    either as phase-only or normal gains, as an .npy file."""

    # Options to attributed dictionary
    if kwargs["yaml"] is not None:
        options = ocf.load(kwargs["yaml"])
    else:
        options = ocf.create(kwargs)

    # Set to struct
    ocf.set_struct(options, True)

    # Change path to a packaged sky model if one matches,
    # otherwise treat the argument as a user-supplied path
    try:
        sky_model = sky_models[sky_model.lower()]
    except KeyError:
        pass

    # Load ms
    MS = xds_from_ms(ms)[0]

    # Get dimensions (correlations need to be adapted)
    dims = ocf.create(dict(MS.sizes))
    n_chan = dims.chan
    n_corr = dims.corr

    # Get time-bin indices and counts
    _, tbin_indices, _ = np.unique(MS.TIME,
                                   return_index=True,
                                   return_counts=True)

    # Set time dimension
    n_time = len(tbin_indices)

    # Get antenna arrays (dask ignored for now)
    ant1 = MS.ANTENNA1.data.compute()
    ant2 = MS.ANTENNA2.data.compute()

    # Set antenna dimension
    n_ant = np.max((np.max(ant1), np.max(ant2))) + 1

    # Build source model from lsm
    lsm = Tigger.load(sky_model)

    # Set direction axis as per source
    n_dir = len(lsm.sources)

    # Get phase direction
    radec0_table = xds_from_table(ms + '::FIELD')[0]
    radec0 = radec0_table.PHASE_DIR.data.squeeze().compute()

    # Get frequency column
    freq_table = xds_from_table(ms + '::SPECTRAL_WINDOW')[0]
    freq = freq_table.CHAN_FREQ.data.astype(np.float64)[0]

    # Check dimension
    assert freq.size == n_chan

    # Create initial coordinate array and source names
    lm = np.zeros((n_dir, 2), dtype=np.float64)

    # Cycle coordinates creating a source with flux
    for d, source in enumerate(lsm.sources):
        # Extract position
        radec_s = np.array([[source.pos.ra, source.pos.dec]])
        lm[d] = radec_to_lm(radec_s, radec0)

    # Direction independent gains
    if options.die:
        lm = np.array(lm[0]).reshape((1, -1))
        n_dir = 1

    # Choose between phase-only or normal
    if options.type == "phase":
        # Run phase-only
        print("==> Simulating `phase-only` gains, with dimensions ("\
            + f"n_time={n_time}, n_ant={n_ant}, n_chan={n_chan}, "\
            + f"n_dir={n_dir}, n_corr={n_corr})")

        jones = phase_gains(lm, freq, n_time, n_ant, n_chan, n_dir, n_corr,
                            options.std)

    elif options.type == "normal":
        # With normal selected, get differentials
        lt, lnu, ls = options.diffs

        # Run normal
        print("==> Simulating `normal` gains, with dimensions ("\
            + f"n_time={n_time}, n_ant={n_ant}, n_chan={n_chan}, "\
            + f"n_dir={n_dir}, n_corr={n_corr})")
        jones = normal_gains(tbin_indices, freq, lm, n_time, n_ant, n_chan,
                             n_dir, n_corr, options.std, lt, lnu, ls)

    else:
        raise ValueError("Only phase and normal gain types available.")

    # Output to jones to .npy file
    gains_file = (options.type + ".npy") if options.out_file is None\
                    else options.out_file

    with open(gains_file, 'wb') as file:
        np.save(file, jones)
    print(f"==> Completed and gains saved to: {gains_file}")
Example #14
def read_ms(ms, num_vis, res_arcmin, chunks=50000, channel=0):
    '''
        Use dask-ms to load the necessary data to create a telescope operator
        (will use uvw positions, and antenna positions)

        -- res_arcmin: Used to calculate the maximum baselines to consider.
                       We want at least two pixels per smallest fringe,
                       i.e. pix_res < fringe / 2.

                       A baseline of length u (in wavelengths) produces
                       fringes of angular period 1/u on the sky, so a pixel
                       size of theta requires theta < 1 / (2 u), which
                       gives u_max = 1 / (2 theta).
    '''
    with scheduler_context():
        # Create a dataset representing the entire antenna table
        ant_table = '::'.join((ms, 'ANTENNA'))

        for ant_ds in xds_from_table(ant_table):
            ant_p = np.array(ant_ds.POSITION.data)
        logger.info("Antenna Positions {}".format(ant_p.shape))

        # Create a dataset representing the field
        field_table = '::'.join((ms, 'FIELD'))
        for field_ds in xds_from_table(field_table):
            phase_dir = np.array(field_ds.PHASE_DIR.data)[0].flatten()
        logger.info("Phase Dir {}".format(np.degrees(phase_dir)))

        # Create datasets representing each row of the spw table
        spw_table = '::'.join((ms, 'SPECTRAL_WINDOW'))

        for spw_ds in xds_from_table(spw_table, group_cols="__row__"):
            logger.info("CHAN_FREQ.values: {}".format(
                spw_ds.CHAN_FREQ.values.shape))
            frequencies = dask.compute(spw_ds.CHAN_FREQ.values)[0].flatten()
            frequency = frequencies[channel]
            logger.info("Frequencies = {}".format(frequencies))
            logger.info("Frequency = {}".format(frequency))
            logger.info("NUM_CHAN = %f" % np.array(spw_ds.NUM_CHAN.values)[0])

        # Create datasets from a partioning of the MS
        datasets = list(xds_from_ms(ms, chunks={'row': chunks}))

        pol = 0

        for ds in datasets:
            logger.info("DATA shape: {}".format(ds.DATA.data.shape))
            logger.info("UVW shape: {}".format(ds.UVW.data.shape))

            uvw = np.array(ds.UVW.data)  # UVW is stored in meters!
            ant1 = np.array(ds.ANTENNA1.data)
            ant2 = np.array(ds.ANTENNA2.data)
            flags = np.array(ds.FLAG.data)
            cv_vis = np.array(ds.DATA.data)[:, channel, pol]
            epoch_seconds = np.array(ds.TIME.data)[0]

            # Try to write the STATE_ID column back
            write = xds_to_table(ds, ms, 'STATE_ID')
            with ProgressBar(), Profiler() as prof:
                write.compute()

            # Profile
            #prof.visualize(file_path="chunked.html")

        # Now remove data that doesn't fit the image resolution

        u_max = get_resolution_max_baseline(res_arcmin, frequency)

        logger.info("Resolution Max UVW: {:g}".format(u_max))
        logger.info("Flags: {}".format(flags.shape))

        # Now report the recommended resolution from the data.
        # 1.0 / 2*np.sin(theta) = limit_u
        limit_uvw = np.max(np.abs(uvw), 0)
        res_limit = get_baseline_resolution(limit_uvw[0], frequency)
        logger.info("Nyquist resolution: {:g} arcmin".format(
            np.degrees(res_limit) * 60.0))


        good_data = np.array(
            np.where((flags[:, channel, pol] == 0)
                     & (np.max(np.abs(uvw), 1) < u_max))).T.reshape((-1, ))
        logger.info("Good Data {}".format(good_data.shape))

        logger.info("Maximum UVW: {}".format(limit_uvw))
        logger.info("Minimum UVW: {}".format(np.min(np.abs(uvw), 0)))

        n_ant = len(ant_p)

        good_vis = cv_vis[good_data]

        n_max = len(good_vis)

        indices = np.random.choice(good_data, min(num_vis, n_max))

        hdr = {
            'CTYPE1': ('RA---SIN', "Right ascension angle cosine"),
            'CRVAL1': np.degrees(phase_dir)[0],
            'CUNIT1': 'deg     ',
            'CTYPE2': ('DEC--SIN', "Declination angle cosine "),
            'CRVAL2': np.degrees(phase_dir)[1],
            'CUNIT2': 'deg     ',
            'CTYPE3': 'FREQ    ',  # Central frequency
            'CRPIX3': 1.,
            'CRVAL3': "{}".format(frequency),
            'CDELT3': 10026896.158854,
            'CUNIT3': 'Hz      ',
            'EQUINOX': '2000.',
            'DATE-OBS': "{}".format(epoch_seconds),
            'BTYPE': 'Intensity'
        }


        u_arr = uvw[indices, 0]
        v_arr = uvw[indices, 1]
        w_arr = uvw[indices, 2]

        cv_vis = cv_vis[indices]

        # Convert MS TIME (seconds since the MJD epoch, 1858-11-17 UTC) to a timestamp.
        timestamp = datetime.datetime(
            1858, 11, 17, 0, 0, 0,
            tzinfo=datetime.timezone.utc) + datetime.timedelta(
                seconds=epoch_seconds)

        return u_arr, v_arr, w_arr, frequency, cv_vis, hdr, timestamp
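get_resolution_max_baseline and get_baseline_resolution are called above but not shown. A plausible implementation of the relation spelled out in the docstring (baselines in metres, matching the UVW comment) could look like this; treat it as an assumption, not the original code:

import numpy as np

LIGHTSPEED = 299792458.0  # m/s

def get_resolution_max_baseline(res_arcmin, frequency):
    """Longest baseline (m) still sampled at two pixels per fringe."""
    theta = np.radians(res_arcmin / 60.0)  # pixel size in radians
    u_max = 1.0 / (2.0 * theta)            # in wavelengths
    return u_max * LIGHTSPEED / frequency  # in metres

def get_baseline_resolution(baseline_m, frequency):
    """Angular scale (rad) at which a baseline is sampled at two pixels per fringe."""
    u = baseline_m * frequency / LIGHTSPEED  # in wavelengths
    return 1.0 / (2.0 * u)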
Example #15
def main(args):
    if args.precision > 1e-6:
        real_type = np.float32
        complex_type = np.complex64
    else:
        real_type = np.float64
        complex_type = np.complex128

    # get max uv coords over all fields
    uvw = []
    xds = xds_from_table(args.table_name,
                         group_cols=('FIELD_ID',),
                         columns=('UVW',),
                         chunks={'row': -1})
    for ds in xds:
        uvw.append(ds.UVW.data.compute())
    uvw = np.concatenate(uvw)
    from africanus.constants import c as lightspeed
    u_max = np.abs(uvw[:, 0]).max()
    v_max = np.abs(uvw[:, 1]).max()

    # get Nyquist cell size
    freq = xds_from_table(args.table_name +
                          "::FREQ")[0].FREQ.data.compute().squeeze()
    uv_max = np.maximum(u_max, v_max)
    cell_N = 1.0 / (2 * uv_max * freq.max() / lightspeed)

    if args.cell_size is not None:
        cell_rad = args.cell_size * np.pi / 60 / 60 / 180
        print("Super resolution factor = ", cell_N / cell_rad)
    else:
        cell_rad = cell_N / args.super_resolution_factor
        args.cell_size = cell_rad * 60 * 60 * 180 / np.pi
        print("Cell size set to %5.5e arcseconds" % args.cell_size)

    if args.nx is None:
        fov = args.fov * 3600
        nx = int(fov / args.cell_size)
        from scipy.fftpack import next_fast_len
        args.nx = next_fast_len(nx)

    if args.ny is None:
        fov = args.fov * 3600
        ny = int(fov / args.cell_size)
        from scipy.fftpack import next_fast_len
        args.ny = next_fast_len(ny)

    if args.channels_out is None:
        args.channels_out = freq.size

    print("Image size set to (%i, %i, %i)" %
          (args.channels_out, args.nx, args.ny))

    # init gridder
    R = OutMemGridder(args.table_name,
                      args.nx,
                      args.ny,
                      args.cell_size,
                      freq,
                      nband=args.channels_out,
                      field=args.field,
                      precision=args.precision,
                      ncpu=args.ncpu,
                      do_wstacking=args.do_wstacking,
                      data_column=args.data_column,
                      weight_column=args.weight_column)
    freq_out = R.freq_out

    # get headers
    radec = xds_from_table(args.table_name +
                           "::RADEC")[0].RADEC.data.compute().squeeze()
    hdr = set_wcs(args.cell_size / 3600, args.cell_size / 3600, args.nx,
                  args.ny, radec, freq_out)
    hdr_mfs = set_wcs(args.cell_size / 3600, args.cell_size / 3600, args.nx,
                      args.ny, radec, np.mean(freq_out))
    hdr_psf = set_wcs(args.cell_size / 3600, args.cell_size / 3600,
                      2 * args.nx, 2 * args.ny, radec, freq_out)
    hdr_psf_mfs = set_wcs(args.cell_size / 3600, args.cell_size / 3600,
                          2 * args.nx, 2 * args.ny, radec, np.mean(freq_out))

    # make psf  LB - TODO: undersized psfs
    psf = R.make_psf()
    nband = R.nband
    psf_max = np.amax(psf.reshape(nband, 4 * args.nx * args.ny), axis=1)

    # make dirty
    dirty = R.make_dirty()

    # save dirty and psf images
    save_fits(args.outfile + '_dirty.fits', dirty, hdr, dtype=real_type)
    save_fits(args.outfile + '_psf.fits', psf, hdr_psf, dtype=real_type)

    # MFS images
    wsum = np.sum(psf_max)
    dirty_mfs = np.sum(dirty, axis=0) / wsum
    save_fits(args.outfile + '_dirty_mfs.fits', dirty_mfs, hdr_mfs)

    psf_mfs = np.sum(psf, axis=0) / wsum
    save_fits(args.outfile + '_psf_mfs.fits', psf_mfs, hdr_psf_mfs)

    rmax = np.abs(dirty_mfs).max()
    rms = np.std(dirty_mfs)
    print("Peak of dirty is %f and rms is %f" % (rmax, rms))
Example #16
    def make_psf(self):
        print("Making PSF")
        psfs = []
        for ims in self.ms:
            xds = xds_from_ms(ims, group_cols=('FIELD_ID', 'DATA_DESC_ID'),
                              chunks=self.chunks[ims],
                              columns=self.columns)

            # subtables
            ddids = xds_from_table(ims + "::DATA_DESCRIPTION")
            fields = xds_from_table(ims + "::FIELD", group_cols="__row__")
            spws = xds_from_table(ims + "::SPECTRAL_WINDOW", group_cols="__row__")
            pols = xds_from_table(ims + "::POLARIZATION", group_cols="__row__")

            # subtable data
            ddids = dask.compute(ddids)[0]
            fields = dask.compute(fields)[0]
            spws = dask.compute(spws)[0]
            pols = dask.compute(pols)[0]

            for ds in xds:
                field = fields[ds.FIELD_ID]
                radec = field.PHASE_DIR.data.squeeze()
                if not np.array_equal(radec, self.radec):
                    continue

                spw = ds.DATA_DESC_ID  # not correct in general: DATA_DESC_ID should be mapped to a SPECTRAL_WINDOW_ID via the DATA_DESCRIPTION table
                
                freq_bin_idx = self.freq_bin_idx[ims][spw]
                freq_bin_counts = self.freq_bin_counts[ims][spw]
                freq = self.freq[ims][spw]

                uvw = ds.UVW.data

                flag = getattr(ds, self.flag_column).data

                weights = getattr(ds, self.weight_column).data
                if len(weights.shape) < 3:
                    weights = da.broadcast_to(weights[:, None, :], flag.shape, chunks=flag.chunks)
                
                if self.imaging_weight_column is not None:
                    imaging_weights = getattr(ds, self.imaging_weight_column).data
                    if len(imaging_weights.shape) < 3:
                        imaging_weights = da.broadcast_to(imaging_weights[:, None, :], flag.shape, chunks=flag.chunks)

                    weightsxx = imaging_weights[:, :, 0] * weights[:, :, 0]
                    weightsyy = imaging_weights[:, :, -1] * weights[:, :, -1]
                else:
                    weightsxx = weights[:, :, 0]
                    weightsyy = weights[:, :, -1]

                # weighted sum corr to Stokes I; gridding the weights
                # themselves as complex "visibilities" yields the PSF
                weights = weightsxx + weightsyy
                data = weights.astype(np.complex64)
                
                # only keep data where both corrs are unflagged
                flagxx = flag[:, :, 0]
                flagyy = flag[:, :, -1]
                flag = ~ (flagxx | flagyy)  # ducc0 convention

                psf = vis2im(uvw, freq, data, freq_bin_idx, freq_bin_counts,
                             2*self.nx, 2*self.ny, self.cell, flag=flag.astype(np.uint8),
                             nthreads=self.nthreads, epsilon=self.epsilon, do_wstacking=self.do_wstacking)

                psfs.append(psf)

        psfs = dask.compute(psfs)[0]
                
        return accumulate_dirty(psfs, self.nband, self.band_mapping).astype(np.float64)
Example #17
def main(args):
    """
    Flags outliers in data given a model and rescale weights so that whitened residuals have a
    mean amplitude of sqrt(2). 
    
    Flags and weights are computed per chunk of data
    """
    radec_ref = None
    writes = []
    for ims in args.ms:
        xds = xds_from_ms(ims,
                          group_cols=('FIELD_ID', 'DATA_DESC_ID'),
                          chunks={
                              "row": args.row_chunks,
                              "chan": args.chan_chunks
                          },
                          columns=('UVW', args.data_column, args.weight_column,
                                   args.model_column, args.flag_column,
                                   'FLAG_ROW'))

        # subtables
        ddids = xds_from_table(ims + "::DATA_DESCRIPTION")
        fields = xds_from_table(ims + "::FIELD", group_cols="__row__")
        spws = xds_from_table(ims + "::SPECTRAL_WINDOW", group_cols="__row__")
        pols = xds_from_table(ims + "::POLARIZATION", group_cols="__row__")

        # subtable data
        ddids = dask.compute(ddids)[0]
        fields = dask.compute(fields)[0]
        spws = dask.compute(spws)[0]
        pols = dask.compute(pols)[0]

        out_data = []
        for ds in xds:
            field = fields[ds.FIELD_ID]
            radec = field.PHASE_DIR.data.squeeze()

            # check fields match
            if radec_ref is None:
                radec_ref = radec

            if not np.array_equal(radec, radec_ref):
                continue

            # load in data and compute whitened residuals
            data = getattr(ds, args.data_column).data
            model = getattr(ds, args.model_column).data
            flag = getattr(ds, args.flag_column).data
            flag = da.logical_or(flag, ds.FLAG_ROW.data[:, None, None])
            weights = getattr(ds, args.weight_column).data
            if len(weights.shape) < 3:
                weights = da.broadcast_to(weights[:, None, :],
                                          data.shape,
                                          chunks=data.chunks)

            if args.trim_channels:
                flag = trim_chans(flag, args.trim_channels)

            # Stokes I vis
            weights = (~flag) * weights
            resid_vis = (data - model) * weights
            wsums = (weights[:, :, 0] + weights[:, :, -1])
            resid_vis_I = da.where(
                wsums, (resid_vis[:, :, 0] + resid_vis[:, :, -1]) / wsums,
                0.0j)

            # whiten and take abs
            white_resid = resid_vis_I * da.sqrt(wsums)
            abs_resid_vis_I = da.absolute(white_resid)

            # mean amp
            sum_amp = da.sum(abs_resid_vis_I)
            count = da.sum(wsums > 0)
            mean_amp = sum_amp / count

            flag_legacy = flag[:, :, 0] | flag[:, :, -1]
            flag_I = da.logical_or(abs_resid_vis_I > args.sigma_cut * mean_amp,
                                   flag_legacy)

            # new flags
            updated_flag = da.broadcast_to(flag_I[:, :, None],
                                           flag.shape,
                                           chunks=flag.chunks)

            # scale weights (whitened residuals should have mean amplitude of 1/sqrt(2))
            if args.scale_weights:
                # recompute mean amp with new flags
                weights = (~updated_flag) * weights
                resid_vis = (data - model) * weights
                wsums = (weights[:, :, 0] + weights[:, :, -1])
                resid_vis_I = da.where(
                    wsums, (resid_vis[:, :, 0] + resid_vis[:, :, -1]) / wsums,
                    0.0j)
                white_resid = resid_vis_I * da.sqrt(wsums)
                abs_resid_vis_I = da.absolute(white_resid)
                sum_amp = da.sum(abs_resid_vis_I)
                count = da.sum(wsums > 0)
                mean_amp = sum_amp / count
                updated_weight = 2**0.5 * weights / mean_amp**2
            else:
                updated_weight = weights

            ds = ds.assign(**{
                args.weight_out_column: (("row", "chan", "corr"),
                                         updated_weight)
            })
            ds = ds.assign(**{
                args.flag_out_column: (("row", "chan", "corr"), updated_flag)
            })

            out_data.append(ds)
        writes.append(
            xds_to_table(
                out_data,
                ims,
                columns=[args.flag_out_column, args.weight_out_column]))

    with ProgressBar():
        dask.compute(writes)

    # report new mean amp
    if args.report_means:
        radec_ref = None
        mean_amps = []
        for ims in args.ms:
            xds = xds_from_ms(
                ims,
                group_cols=('FIELD_ID', 'DATA_DESC_ID'),
                chunks={
                    "row": args.row_chunks,
                    "chan": args.chan_chunks
                },
                columns=('UVW', args.data_column, args.weight_out_column,
                         args.model_column, args.flag_out_column, 'FLAG_ROW'))

            # subtables
            ddids = xds_from_table(ims + "::DATA_DESCRIPTION")
            fields = xds_from_table(ims + "::FIELD", group_cols="__row__")
            spws = xds_from_table(ims + "::SPECTRAL_WINDOW",
                                  group_cols="__row__")
            pols = xds_from_table(ims + "::POLARIZATION", group_cols="__row__")

            # subtable data
            ddids = dask.compute(ddids)[0]
            fields = dask.compute(fields)[0]
            spws = dask.compute(spws)[0]
            pols = dask.compute(pols)[0]

            for ds in xds:
                field = fields[ds.FIELD_ID]
                radec = field.PHASE_DIR.data.squeeze()

                # check fields match
                if radec_ref is None:
                    radec_ref = radec

                if not np.array_equal(radec, radec_ref):
                    continue

                # load in data and compute whitened residuals
                data = getattr(ds, args.data_column).data
                model = getattr(ds, args.model_column).data
                flag = getattr(ds, args.flag_out_column).data
                flag = da.logical_or(flag, ds.FLAG_ROW.data[:, None, None])
                weights = getattr(ds, args.weight_out_column).data
                if len(weights.shape) < 3:
                    weights = da.broadcast_to(weights[:, None, :],
                                              data.shape,
                                              chunks=data.chunks)

                # Stokes I vis
                weights = (~flag) * weights
                resid_vis = (data - model) * weights
                wsums = (weights[:, :, 0] + weights[:, :, -1])
                resid_vis_I = da.where(
                    wsums, (resid_vis[:, :, 0] + resid_vis[:, :, -1]) / wsums,
                    0.0j)

                # whiten and take abs
                white_resid = resid_vis_I * da.sqrt(wsums)
                abs_resid_vis_I = da.absolute(white_resid)

                # mean amp
                sum_amp = da.sum(abs_resid_vis_I)
                count = da.sum(wsums > 0)
                mean_amps.append(sum_amp / count)

        mean_amps = dask.compute(mean_amps)[0]

        print(mean_amps)
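The flagging rule above in miniature, on synthetic numpy inputs with an assumed sigma_cut of 5:

import numpy as np

rng = np.random.default_rng(42)
wsums = rng.random((1000, 16)) + 0.1  # summed correlation weights
resid_I = (rng.normal(size=wsums.shape)
           + 1j * rng.normal(size=wsums.shape)) / np.sqrt(2 * wsums)
white = np.abs(resid_I * np.sqrt(wsums))  # whitened residual amplitudes
mean_amp = white.sum() / (wsums > 0).sum()
new_flags = white > 5 * mean_amp          # sigma_cut = 5
print(new_flags.sum(), "of", new_flags.size, "visibilities flagged")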
Example #18
MAX([SELECT ABS(UVW[1]) FROM {ms}]) as ABS_VMAX,
MIN([SELECT UVW[2] FROM {ms}]) AS WMIN,
MAX([SELECT UVW[2] FROM {ms}]) AS WMAX
""".format(ms=args.ms)

with pt.taql(query) as Q:
    umin = Q.getcol("ABS_UMIN").item()
    umax = Q.getcol("ABS_UMAX").item()
    vmin = Q.getcol("ABS_VMIN").item()
    vmax = Q.getcol("ABS_VMAX").item()
    wmin = Q.getcol("WMIN").item()
    wmax = Q.getcol("WMAX").item()

xds = list(xds_from_ms(args.ms, chunks={"row": args.chunks}))[0]
spw_ds = list(
    xds_from_table("::".join((args.ms, "SPECTRAL_WINDOW")),
                   group_cols="__row__"))[0]
wavelength = (lightspeed / spw_ds.CHAN_FREQ.data[0]).compute()

if args.cell_size:
    cell_size = args.cell_size
else:
    cell_size = estimate_cell_size(umax,
                                   vmax,
                                   wavelength,
                                   factor=3,
                                   ny=args.npix,
                                   nx=args.npix).max()

# Convolution Filter
conv_filter = convolution_filter(3, 63, "kaiser-bessel")
Example #19
def new(ms, sky_model, gains, **kwargs):
    """Generate model visibilties per source (as direction axis)
    for stokes I and Q and generate relevant visibilities."""

    # Options to attributed dictionary
    if kwargs["yaml"] is not None:
        options = ocf.load(kwargs["yaml"])
    else:
        options = ocf.create(kwargs)

    # Set to struct
    ocf.set_struct(options, True)

    # Change path to a packaged sky model if one matches,
    # otherwise treat the argument as a user-supplied path
    try:
        sky_model = sky_models[sky_model.lower()]
    except KeyError:
        pass

    # Set thread count to cpu count
    if options.ncpu:
        from multiprocessing.pool import ThreadPool
        import dask
        dask.config.set(pool=ThreadPool(options.ncpu))
    else:
        import multiprocessing
        options.ncpu = multiprocessing.cpu_count()

    # Load gains to corrupt with
    with open(gains, "rb") as file:
        jones = np.load(file)

    # Load dimensions
    n_time, n_ant, n_chan, n_dir, n_corr = jones.shape
    n_row = n_time * (n_ant * (n_ant - 1) // 2)

    # Load ms
    MS = xds_from_ms(ms)[0]

    # Get time-bin indices and counts
    row_chunks, tbin_indices, tbin_counts = chunkify_rows(
        MS.TIME, options.utime)

    # Close and reopen with chunked rows
    MS.close()
    MS = xds_from_ms(ms, chunks={"row": row_chunks})[0]

    # Get antenna arrays (dask ignored for now)
    ant1 = MS.ANTENNA1.data
    ant2 = MS.ANTENNA2.data

    # Adjust UVW based on phase-convention
    if options.phase_convention.upper() == 'CASA':
        uvw = -MS.UVW.data.astype(np.float64)
    elif options.phase_convention.upper() == 'CODEX':
        uvw = MS.UVW.data.astype(np.float64)
    else:
        raise ValueError("Unknown sign convention for phase.")

    # MS dimensions
    dims = ocf.create(dict(MS.sizes))

    # Close MS
    MS.close()

    # Build source model from lsm
    lsm = Tigger.load(sky_model)

    # Check if dimensions match jones
    assert n_time * (n_ant * (n_ant - 1) // 2) == dims.row
    assert n_time == len(tbin_indices)
    assert n_ant == np.max((np.max(ant1), np.max(ant2))) + 1
    assert n_chan == dims.chan
    assert n_corr == dims.corr

    # If gains are DIE
    if options.die:
        assert n_dir == 1
        n_dir = len(lsm.sources)
    else:
        assert n_dir == len(lsm.sources)

    # Get phase direction
    radec0_table = xds_from_table(ms + '::FIELD')[0]
    radec0 = radec0_table.PHASE_DIR.data.squeeze().compute()
    radec0_table.close()

    # Get frequency column
    freq_table = xds_from_table(ms + '::SPECTRAL_WINDOW')[0]
    freq = freq_table.CHAN_FREQ.data.astype(np.float64)[0]
    freq_table.close()

    # Get feed orientation
    feed_table = xds_from_table(ms + '::FEED')[0]
    feeds = feed_table.POLARIZATION_TYPE.data[0].compute()

    # Create initial model array
    model = np.zeros((n_dir, n_chan, n_corr), dtype=np.float64)

    # Create initial coordinate array and source names
    lm = np.zeros((n_dir, 2), dtype=np.float64)
    source_names = []

    # Cycle coordinates creating a source with flux
    print("==> Building model visibilities")
    for d, source in enumerate(lsm.sources):
        # Extract name
        source_names.append(source.name)

        # Extract position
        radec_s = np.array([[source.pos.ra, source.pos.dec]])
        lm[d] = radec_to_lm(radec_s, radec0)

        # Get flux - Stokes I
        if source.flux.I:
            I0 = source.flux.I

            # Get spectrum (only spi currently supported)
            tmp_spec = source.spectrum
            spi = [tmp_spec.spi if tmp_spec is not None else 0.0]
            ref_freq = [tmp_spec.freq0 if tmp_spec is not None else 1.0]

            # Generate model flux
            model[d, :, 0] = I0 * (freq / ref_freq)**spi

        # Get flux - Stokes Q
        if source.flux.Q:
            Q0 = source.flux.Q

            # Get spectrum
            tmp_spec = source.spectrum
            spi = [tmp_spec.spi if tmp_spec is not None else 0.0]
            ref_freq = [tmp_spec.freq0 if tmp_spec is not None else 1.0]

            # Generate model flux
            model[d, :, 1] = Q0 * (freq / ref_freq)**spi

        # Get flux - Stokes U
        if source.flux.U:
            U0 = source.flux.U

            # Get spectrum
            tmp_spec = source.spectrum
            spi = [tmp_spec.spi if tmp_spec is not None else 0.0]
            ref_freq = [tmp_spec.freq0 if tmp_spec is not None else 1.0]

            # Generate model flux
            model[d, :, 2] = U0 * (freq / ref_freq)**spi

        # Get flux - Stokes V
        if source.flux.V:
            V0 = source.flux.V

            # Get spectrum
            tmp_spec = source.spectrum
            spi = [tmp_spec.spi if tmp_spec is not None else 0.0]
            ref_freq = [tmp_spec.freq0 if tmp_spec is not None else 1.0]

            # Generate model flux
            model[d, :, 3] = V0 * (freq / ref_freq)**spi

    # Close sky-model
    del lsm

    # Build dask graph
    tbin_indices = da.from_array(tbin_indices, chunks=(options.utime))
    tbin_counts = da.from_array(tbin_counts, chunks=(options.utime))
    lm = da.from_array(lm, chunks=lm.shape)
    model = da.from_array(model, chunks=model.shape)
    jones = da.from_array(jones, chunks=(options.utime, ) + jones.shape[1::])

    # Apply image to visibility for each source
    sources = []
    for s in range(n_dir):
        source_vis = im_to_vis(model[s].reshape((1, n_chan, n_corr)),
                               uvw,
                               lm[s].reshape((1, 2)),
                               freq,
                               dtype=np.complex64,
                               convention='fourier')

        sources.append(source_vis)
    model_vis = da.stack(sources, axis=2)

    # Sum over direction?
    if options.die:
        model_vis = da.sum(model_vis, axis=2, keepdims=True)
        n_dir = 1
        source_names = [options.mname]

    # Select schema based on feed orientation
    if (feeds == ["X", "Y"]).all():
        out_schema = [["XX", "XY"], ["YX", "YY"]]
    elif (feeds == ["R", "L"]).all():
        out_schema = [['RR', 'RL'], ['LR', 'LL']]
    else:
        raise ValueError("Unknown feed orientation implementation.")

    # Convert Stokes to Correlations
    in_schema = ['I', 'Q', 'U', 'V']
    model_vis = convert(model_vis, in_schema, out_schema).reshape(
        (n_row, n_chan, n_dir, n_corr))

    # Apply gains to model_vis
    print("==> Corrupting visibilities")

    data = corrupt_vis(tbin_indices, tbin_counts, ant1, ant2, jones, model_vis)

    # Reopen MS
    MS = xds_from_ms(ms, chunks={"row": row_chunks})[0]

    # Assign model visibilities
    out_names = []
    for d in range(n_dir):
        MS = MS.assign(
            **{
                source_names[d]: (("row", "chan", "corr"),
                                  model_vis[:, :, d].astype(np.complex64))
            })

        out_names += [source_names[d]]

    # Assign noise free visibilities to 'CLEAN_DATA'
    MS = MS.assign(
        **{
            'CLEAN_' + options.dname: (("row", "chan", "corr"),
                                       data.astype(np.complex64))
        })

    out_names += ['CLEAN_' + options.dname]

    # Get noise realisation
    if options.std > 0.0:

        # Noise matrix
        print(f"==> Applying noise (std={options.std}) to visibilities")
        noise = []
        for i in range(2):
            real = da.random.normal(loc=0.0,
                                    scale=options.std,
                                    size=(n_row, n_chan),
                                    chunks=(row_chunks, n_chan))
            imag = 1.0j * (da.random.normal(loc=0.0,
                                            scale=options.std,
                                            size=(n_row, n_chan),
                                            chunks=(row_chunks, n_chan)))
            noise.append(real + imag)

        # Zero matrix for off-diagonals
        zero = da.zeros((n_row, n_chan), chunks=(row_chunks, n_chan))

        noise.insert(1, zero)
        noise.insert(2, zero)

        # NP to Dask
        noise = da.stack(noise, axis=2).rechunk((row_chunks, n_chan, n_corr))

        # Assign noise to 'NOISE'
        MS = MS.assign(
            **{'NOISE': (("row", "chan", "corr"), noise.astype(np.complex64))})

        out_names += ['NOISE']

        # Add noise to data and assign to 'DATA'
        noisy_data = data + noise

        MS = MS.assign(
            **{
                options.dname: (("row", "chan", "corr"),
                                noisy_data.astype(np.complex64))
            })

        out_names += [options.dname]

    # Create a write to the table
    write = xds_to_table(MS, ms, out_names)

    # Submit all graph computations in parallel
    print(f"==> Executing `dask-ms` write to `{ms}` for the following columns: "\
            + f"{', '.join(out_names)}")

    with ProgressBar():
        write.compute()

    print(f"==> Completed.")
Example #20
def main(args):
    # get max uv coords over all fields
    uvw = []
    u_max = 0.0
    v_max = 0.0
    all_freqs = []
    for ims in args.ms:
        xds = xds_from_ms(ims,
                          group_cols=('FIELD_ID', 'DATA_DESC_ID'),
                          columns=('UVW',),
                          chunks={'row': args.row_chunks})

        spws = xds_from_table(ims + "::SPECTRAL_WINDOW", group_cols="__row__")
        spws = dask.compute(spws)[0]

        for ds in xds:
            uvw = ds.UVW.data
            u_max = da.maximum(u_max, abs(uvw[:, 0]).max())
            v_max = da.maximum(v_max, abs(uvw[:, 1]).max())
            uv_max = da.maximum(u_max, v_max)

            spw = spws[ds.DATA_DESC_ID]
            tmp_freq = spw.CHAN_FREQ.data.squeeze()
            all_freqs.append(list([tmp_freq]))

    uv_max = uv_max.compute()  # max over both |u| and |v|
    del uvw

    # get Nyquist cell size
    from africanus.constants import c as lightspeed
    all_freqs = dask.compute(all_freqs)
    freq = np.unique(all_freqs)
    cell_N = 1.0 / (2 * uv_max * freq.max() / lightspeed)

    if args.cell_size is not None:
        cell_rad = args.cell_size * np.pi / 60 / 60 / 180
        print("Super resolution factor = ", cell_N / cell_rad)
    else:
        cell_rad = cell_N / args.super_resolution_factor
        args.cell_size = cell_rad * 60 * 60 * 180 / np.pi
        print("Cell size set to %5.5e arcseconds" % args.cell_size)

    if args.nx is None or args.ny is None:
        fov = args.fov * 3600
        npix = int(fov / args.cell_size)
        if npix % 2:
            npix += 1
        args.nx = npix
        args.ny = npix

    if args.nband is None:
        args.nband = freq.size

    print("Image size set to (%i, %i, %i)" % (args.nband, args.nx, args.ny))

    # init gridder
    R = Gridder(args.ms,
                args.nx,
                args.ny,
                args.cell_size,
                nband=args.nband,
                nthreads=args.nthreads,
                do_wstacking=args.do_wstacking,
                row_chunks=args.row_chunks,
                data_column=args.data_column,
                weight_column=args.weight_column,
                epsilon=args.epsilon,
                imaging_weight_column=args.imaging_weight_column,
                model_column=args.model_column,
                flag_column=args.flag_column)
    freq_out = R.freq_out
    radec = R.radec

    # get headers
    hdr = set_wcs(args.cell_size / 3600, args.cell_size / 3600, args.nx,
                  args.ny, radec, freq_out)
    hdr_mfs = set_wcs(args.cell_size / 3600, args.cell_size / 3600, args.nx,
                      args.ny, radec, np.mean(freq_out))
    hdr_psf = set_wcs(args.cell_size / 3600, args.cell_size / 3600,
                      2 * args.nx, 2 * args.ny, radec, freq_out)
    hdr_psf_mfs = set_wcs(args.cell_size / 3600, args.cell_size / 3600,
                          2 * args.nx, 2 * args.ny, radec, np.mean(freq_out))

    # psf
    if args.psf is not None:
        try:
            compare_headers(hdr_psf, fits.getheader(args.psf))
            psf_array = load_fits(args.psf)
        except Exception:
            psf_array = R.make_psf()
            save_fits(args.outfile + '_psf.fits', psf_array, hdr_psf)
    else:
        psf_array = R.make_psf()
        save_fits(args.outfile + '_psf.fits', psf_array, hdr_psf)

    psf_max = np.amax(psf_array.reshape(args.nband, 4 * args.nx * args.ny),
                      axis=1)
    wsum = np.sum(psf_max)
    counts = np.sum(psf_max > 0)
    psf_max_mean = wsum / counts  # normalisation for more intuitive sig_21 values
    psf_array /= psf_max_mean
    psf = PSF(psf_array, args.nthreads)
    psf_max = np.amax(psf_array.reshape(args.nband, 4 * args.nx * args.ny),
                      axis=1)
    wsum = np.sum(psf_max)
    psf_max[psf_max < 1e-15] = 1e-15  # LB - is this the right thing to do?

    psf_mfs = np.sum(psf_array, axis=0) / wsum
    save_fits(
        args.outfile + '_psf_mfs.fits', psf_mfs[args.nx // 2:3 * args.nx // 2,
                                                args.ny // 2:3 * args.ny // 2],
        hdr_mfs)

    # dirty
    if args.dirty is not None:
        try:
            compare_headers(hdr, fits.getheader(args.dirty))
            dirty = load_fits(args.dirty)
        except Exception:
            dirty = R.make_dirty()
            save_fits(args.outfile + '_dirty.fits', dirty, hdr)
    else:
        dirty = R.make_dirty()
        save_fits(args.outfile + '_dirty.fits', dirty, hdr)

    dirty_mfs = np.sum(dirty / psf_max_mean, axis=0) / wsum
    save_fits(args.outfile + '_dirty_mfs.fits', dirty_mfs, hdr_mfs)

    if args.x0 is not None:
        try:
            compare_headers(hdr, fits.getheader(args.x0))
            model = load_fits(args.x0, dtype=np.float64)
            if args.first_residual is not None:
                try:
                    compare_headers(hdr, fits.getheader(args.first_residual))
                    residual = load_fits(args.first_residual, dtype=np.float64)
                except Exception:
                    residual = R.make_residual(model)
                    save_fits(args.outfile + '_first_residual.fits', residual,
                              hdr)
            else:
                residual = R.make_residual(model)
                save_fits(args.outfile + '_first_residual.fits', residual, hdr)
        except Exception:
            model = np.zeros((args.nband, args.nx, args.ny))
            residual = dirty.copy()
    else:
        model = np.zeros((args.nband, args.nx, args.ny))
        residual = dirty.copy()

    # normalise for more intuitive hypers
    residual /= psf_max_mean
    residual_mfs = np.sum(residual, axis=0) / wsum
    save_fits(args.outfile + '_first_residual_mfs.fits', residual_mfs, hdr_mfs)

    # mask
    if args.mask is not None:
        mask = load_fits(args.mask, dtype=np.int64)[None, :, :]
        if mask.shape != (1, args.nx, args.ny):
            raise ValueError("Mask has incorrect shape")
    else:
        mask = np.ones((1, args.nx, args.ny), dtype=np.int64)

    #  preconditioning matrix
    def hess(x):
        return mask * psf.convolve(mask * x) + x / args.sig_l2**2

    if args.beta is None:
        print("Getting spectral norm of update operator")
        beta = power_method(hess,
                            dirty.shape,
                            tol=args.pmtol,
                            maxit=args.pmmaxit)
    else:
        beta = args.beta
    print(" beta = %f " % beta)

    # set up wavelet basis
    if args.psi_basis is None:
        print("Using Dirac + db1-4 dictionary")
        psi = DaskPSI(args.nband,
                      args.nx,
                      args.ny,
                      nlevels=args.psi_levels,
                      nthreads=args.nthreads)
        # psi = PSI(args.nband, args.nx, args.ny, nlevels=args.psi_levels)
    else:
        if not isinstance(args.psi_basis, list):
            args.psi_basis = list(args.psi_basis)
        print("Using ", args.psi_basis, " dictionary")
        psi = DaskPSI(args.nband,
                      args.nx,
                      args.ny,
                      nlevels=args.psi_levels,
                      nthreads=args.nthreads,
                      bases=args.psi_basis)
        # psi = PSI(args.nband, args.nx, args.ny, nlevels=args.psi_levels, bases=args.psi_basis)
    nbasis = psi.nbasis
    weights_21 = np.ones((psi.nbasis, psi.nmax), dtype=np.float64)
    dual = np.zeros((psi.nbasis, args.nband, psi.nmax), dtype=np.float64)

    # Reweighting
    if args.reweight_iters is not None:
        if not isinstance(args.reweight_iters, list):
            reweight_iters = [args.reweight_iters]
        else:
            reweight_iters = list(args.reweight_iters)
    else:
        reweight_iters = list(
            np.arange(args.reweight_start, args.reweight_end,
                      args.reweight_freq))
        reweight_iters.append(args.reweight_end)

    # Reporting
    report_iters = list(np.arange(0, args.maxit, args.report_freq))
    if report_iters[-1] != args.maxit - 1:
        report_iters.append(args.maxit - 1)

    # deconvolve
    eps = 1.0
    i = 0
    rmax = np.abs(residual_mfs).max()
    rms = np.std(residual_mfs)
    M = lambda x: x * args.sig_l2**2  # preconditioner
    print("Peak of initial residual is %f and rms is %f" % (rmax, rms))
    for i in range(1, args.maxit):
        x = pcg(hess,
                mask * residual,
                np.zeros(dirty.shape, dtype=np.float64),
                M=M,
                tol=args.cgtol,
                maxit=args.cgmaxit,
                minit=args.cgminit,
                verbosity=args.cgverbose)

        if i in report_iters:
            save_fits(args.outfile + str(i) + '_update.fits', x, hdr)

        # update model
        modelp = model
        model = modelp + args.gamma * x
        model, dual = primal_dual(hess,
                                  model,
                                  modelp,
                                  dual,
                                  args.sig_21,
                                  psi,
                                  weights_21,
                                  beta,
                                  tol=args.pdtol,
                                  maxit=args.pdmaxit,
                                  report_freq=100,
                                  mask=mask,
                                  positivity=args.positivity)

        # reweighting
        if i in reweight_iters:
            v = psi.hdot(model)
            l2_norm = norm(v, axis=1)
            l2_norm = np.where(l2_norm < args.sig_21 * weights_21, 0.0,
                               l2_norm)
            for m in range(psi.nbasis):
                indnz = l2_norm[m].nonzero()
                alpha = np.percentile(l2_norm[m, indnz].flatten(),
                                      args.reweight_alpha_percent)
                alpha = np.maximum(alpha, args.reweight_alpha_min)
                print("Reweighting - ", m, alpha)
                weights_21[m] = alpha / (l2_norm[m] + alpha)
            args.reweight_alpha_percent *= args.reweight_alpha_ff
            # print(" reweight alpha percent = ", args.reweight_alpha_percent)

        # get residual
        residual = R.make_residual(model) / psf_max_mean

        # check stopping criteria
        residual_mfs = np.sum(residual, axis=0) / wsum
        rmax = np.abs(residual_mfs).max()
        rms = np.std(residual_mfs)
        eps = np.linalg.norm(model - modelp) / np.linalg.norm(model)

        if i in report_iters:
            # save current iteration
            save_fits(args.outfile + str(i) + '_model.fits', model, hdr)

            model_mfs = np.mean(model, axis=0)
            save_fits(args.outfile + str(i) + '_model_mfs.fits', model_mfs,
                      hdr_mfs)

            save_fits(args.outfile + str(i) + '_residual.fits', residual, hdr)

            save_fits(args.outfile + str(i) + '_residual_mfs.fits',
                      residual_mfs, hdr_mfs)

        print(
            "At iteration %i peak of residual is %f, rms is %f, current eps is %f"
            % (i, rmax, rms, eps))

    if args.write_model:
        R.write_model(model)

    if args.make_restored:
        x = pcg(hess,
                residual,
                np.zeros(dirty.shape, dtype=np.float64),
                M=M,
                tol=args.cgtol,
                maxit=args.cgmaxit)
        restored = model + x

        # get residual
        residual = R.make_residual(restored) / psf_max_mean
        residual_mfs = np.sum(residual, axis=0) / wsum
        rmax = np.abs(residual_mfs).max()
        rms = np.std(residual_mfs)

        print("After restoring peak of residual is %f and rms is %f" %
              (rmax, rms))

        # save current iteration
        save_fits(args.outfile + '_restored.fits', restored, hdr)

        restored_mfs = np.mean(restored, axis=0)
        save_fits(args.outfile + '_restored_mfs.fits', restored_mfs, hdr_mfs)

        save_fits(args.outfile + '_restored_residual.fits', residual, hdr)

        save_fits(args.outfile + '_restored_residual_mfs.fits', residual_mfs,
                  hdr_mfs)
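The loop above leans on power_method to estimate the spectral norm of the regularised Hessian, which sets the primal-dual step size beta. A minimal power-iteration sketch, assuming a symmetric positive operator A acting on image cubes (a hypothetical stand-in, not pfb's actual implementation):

import numpy as np

def power_method_sketch(A, im_shape, tol=1e-5, maxit=250):
    # start from a random unit-norm image
    x = np.random.randn(*im_shape)
    x /= np.linalg.norm(x)
    beta = 1.0
    for _ in range(maxit):
        xp = A(x)
        betap = np.vdot(x, xp).real   # Rayleigh quotient estimate
        xp /= np.linalg.norm(xp)
        if abs(betap - beta) / abs(betap) < tol:
            return betap
        beta, x = betap, xp
    return beta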
Exemple #21
0
def extract_dde_info(args, freqs):
    """
    Computes parallactic angles, antenna scaling and pointing information
    required for beam interpolation.
    """
    # get ms info required to compute parallactic angles and weighted sum
    nband = freqs.size
    if args.ms is not None:
        utimes = []
        unflag_counts = []
        ant_pos = None
        phase_dir = None
        for ms_name in args.ms:
            # get antenna positions
            ant = xds_from_table(ms_name + '::ANTENNA')[0].compute()
            if ant_pos is None:
                ant_pos = ant['POSITION'].data
            else: # check all are the same
                tmp = ant['POSITION'].data
                if not np.array_equal(ant_pos, tmp):
                    raise ValueError("Antenna positions not the same across measurement sets")
            
            # get phase center for field
            field = xds_from_table(ms_name + '::FIELD')[0].compute()
            if phase_dir is None:
                phase_dir = field['PHASE_DIR'][args.field].data.squeeze()
            else:
                tmp = field['PHASE_DIR'][args.field].data.squeeze()
                if not np.array_equal(phase_dir, tmp):
                    raise ValueError('Phase direction not the same across measurement sets')

            # get unique times and count flags
            xds = xds_from_ms(ms_name, columns=["TIME", "FLAG_ROW"], group_cols=["FIELD_ID"])[args.field]
            utime, time_idx = np.unique(xds.TIME.data.compute(), return_index=True)
            ntime = utime.size
            # extract subset of times
            if args.sparsify_time > 1:
                I = np.arange(0, ntime, args.sparsify_time)
                utime = utime[I]
                time_idx = time_idx[I]
                ntime = utime.size
            
            utimes.append(utime)
        
            flags = xds.FLAG_ROW.data.compute()
            unflag_count = _unflagged_counts(flags.astype(np.int32), time_idx, np.zeros(ntime, dtype=np.int32))
            unflag_counts.append(unflag_count)

        utimes = np.concatenate(utimes)
        unflag_counts = np.concatenate(unflag_counts)
        ntimes = utimes.size
        
        # compute parallactic angles
        parangles = parallactic_angles(utimes, ant_pos, phase_dir)

        # mean over antennas: nant -> 1
        parangles = np.mean(parangles, axis=1, keepdims=True)
        nant = 1

        # beam_cube_dde requirements
        ant_scale = np.ones((nant, nband, 2), dtype=np.float64)
        point_errs = np.zeros((ntimes, nant, nband, 2), dtype=np.float64)

        return (parangles,
                da.from_array(ant_scale, chunks=ant_scale.shape),
                point_errs,
                unflag_counts,
                True)
    else:
        ntimes = 1
        nant = 1
        parangles = np.zeros((ntimes, nant,), dtype=np.float64)    
        ant_scale = np.ones((nant, nband, 2), dtype=np.float64)
        point_errs = np.zeros((ntimes, nant, nband, 2), dtype=np.float64)
        unflag_counts = np.array([1])
        
        return (parangles, ant_scale, point_errs, unflag_counts, False)
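For reference, a hedged usage sketch of the parallactic angle step above, assuming africanus' parallactic_angles (its default backend needs python-casacore) and made-up inputs:

import numpy as np
from africanus.rime import parallactic_angles

utimes = np.linspace(4.5e9, 4.5e9 + 3600.0, 10)                # 10 made-up MJD-second stamps
ant_pos = np.tile([5109224.0, 2006790.0, -3239100.0], (3, 1))  # 3 toy ECEF positions
phase_dir = np.array([0.0, -0.5])                              # (ra, dec) in radians

pa = parallactic_angles(utimes, ant_pos, phase_dir)  # shape (ntime, nant)
pa = pa.mean(axis=1, keepdims=True)                  # mean over antennas, as above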
def antenna_flags_field(msname, fields=None, antennas=None):
    ds_ant = xds_from_table(msname+"::ANTENNA")[0]
    ds_field = xds_from_table(msname+"::FIELD")[0]
    ds_obs = xds_from_table(msname+"::OBSERVATION")[0]

    ant_names = ds_ant.NAME.data.compute()
    field_names = ds_field.NAME.data.compute()
    ant_positions = ds_ant.POSITION.data.compute()

    try:
        # Get observatory name and centre of array
        obs_name = ds_obs.TELESCOPE_NAME.data.compute()[0]
        me = casacore.measures.measures()
        obs_cofa = me.observatory(obs_name)
        lon, lat, alt = (obs_cofa['m0']['value'],
                         obs_cofa['m1']['value'],
                         obs_cofa['m2']['value'])
        cofa = wgs84_to_ecef(lon, lat, alt)
    except Exception:
        # otherwise fall back to the position of the first antenna
        cofa = ant_positions[0]

    if fields:
        if isinstance(fields[0], str):
            # map field names to their ids in the FIELD subtable
            field_ids = [list(field_names).index(f) for f in fields]
        else:
            field_ids = fields
    else:
        field_ids = list(range(len(field_names)))

    if antennas:
        if isinstance(antennas[0], str):
            # map antenna names to their ids in the ANTENNA subtable
            ant_ids = [list(ant_names).index(a) for a in antennas]
        else:
            ant_ids = antennas
    else:
        ant_ids = list(range(len(ant_names)))

    nant = len(ant_ids)
    nfield = len(field_ids)
    
    fields_str = ", ".join(map(str, field_ids))
    ds_mss = xds_from_ms(msname, group_cols=["FIELD_ID", "DATA_DESC_ID"], 
            chunks={'row': 100000}, taql_where="FIELD_ID IN [%s]" % fields_str)
    flag_sum_computes = []
    for ds in ds_mss:
        flag_sums = da.blockwise(_get_flags, ("row",),
                                    ant_ids, ("ant",),
                                    ds.ANTENNA1.data, ("row",),
                                    ds.ANTENNA2.data, ("row",),
                                    ds.FLAG.data, ("row","chan", "corr"),
                                    adjust_chunks={"row": nant },
                                    dtype=numpy.ndarray)
    
        flags_redux = da.reduction(flag_sums,
                                 chunk=_chunk,
                                 combine=_combine,
                                 aggregate=_aggregate,
                                 concatenate=False,
                                 dtype=numpy.float64)
        flag_sum_computes.append(flags_redux)

    #flag_sum_computes[0].visualize("graph.pdf")
    sum_per_field_spw = dask.compute(flag_sum_computes)[0]
    sum_all = sum(sum_per_field_spw)
    fractions = sum_all[:,0]/sum_all[:,1]
    stats = {}
    for i,aid in enumerate(ant_ids):
        ant_stats = {}
        ant_pos = list(ant_positions[aid])
        ant_stats["name"] = ant_names[aid]
        ant_stats["position"] = ant_pos
        ant_stats["array_centre_dist"] = _distance(cofa, ant_pos)
        ant_stats["frac"] = fractions[i]
        ant_stats["sum"] = sum_all[i][0]
        ant_stats["counts"] = sum_all[i][1]
        stats[aid] = ant_stats

    return stats
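A hedged usage sketch for the function above ('obs.ms' is a placeholder path):

stats = antenna_flags_field('obs.ms', fields=[0])
for aid, s in stats.items():
    print("%-10s dist from centre %8.1f m, flagged fraction %.3f" %
          (s["name"], s["array_centre_dist"], s["frac"]))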
Exemple #23
0
def _main(dest=sys.stdout):
    from pfb.parser import create_parser
    args = create_parser().parse_args()

    if not args.nthreads:
        import multiprocessing
        args.nthreads = multiprocessing.cpu_count()

    if not args.mem_limit:
        import psutil
        args.mem_limit = int(psutil.virtual_memory()[0] /
                             1e9)  # 100% of memory by default

    import numpy as np
    import numba
    import numexpr
    import dask
    import dask.array as da
    from daskms import xds_from_ms, xds_from_table
    from astropy.io import fits
    from pfb.utils.fits import (set_wcs, load_fits, save_fits, compare_headers,
                                data_from_header)
    from pfb.utils.restoration import fitcleanbeam
    from pfb.utils.misc import Gaussian2D
    from pfb.operators.gridder import Gridder
    from pfb.operators.psf import PSF
    from pfb.deconv.sara import sara
    from pfb.deconv.clean import clean
    from pfb.deconv.spotless import spotless
    from pfb.deconv.nnls import nnls
    from pfb.opt.pcg import pcg

    if not isinstance(args.ms, list):
        args.ms = [args.ms]

    pyscilog.log_to_file(args.outfile + '.log')
    pyscilog.enable_memory_logging(level=3)

    GD = vars(args)
    print('Input Options:', file=log)
    for key in GD.keys():
        print('     %25s = %s' % (key, GD[key]), file=log)

    # get max uv coords over all fields
    uvw = []
    u_max = 0.0
    v_max = 0.0
    all_freqs = []
    for ims in args.ms:
        xds = xds_from_ms(ims,
                          group_cols=('FIELD_ID', 'DATA_DESC_ID'),
                          columns=('UVW',),
                          chunks={'row': args.row_chunks})

        spws = xds_from_table(ims + "::SPECTRAL_WINDOW", group_cols="__row__")
        spws = dask.compute(spws)[0]

        for ds in xds:
            uvw = ds.UVW.data
            u_max = da.maximum(u_max, abs(uvw[:, 0]).max())
            v_max = da.maximum(v_max, abs(uvw[:, 1]).max())
            uv_max = da.maximum(u_max, v_max)

            spw = spws[ds.DATA_DESC_ID]
            tmp_freq = spw.CHAN_FREQ.data.squeeze()
            all_freqs.append([tmp_freq])

    uv_max = uv_max.compute()
    del uvw

    # get Nyquist cell size
    from africanus.constants import c as lightspeed
    all_freqs = dask.compute(all_freqs)
    freq = np.unique(all_freqs)
    cell_N = 1.0 / (2 * uv_max * freq.max() / lightspeed)

    if args.cell_size is not None:
        cell_rad = args.cell_size * np.pi / 60 / 60 / 180
        if cell_N / cell_rad < 1:
            raise ValueError(
                "Requested cell size too small. "
                "Super resolution factor = %f" % (cell_N / cell_rad))
        print("Super resolution factor = %f" % (cell_N / cell_rad), file=dest)
    else:
        cell_rad = cell_N / args.super_resolution_factor
        args.cell_size = cell_rad * 60 * 60 * 180 / np.pi
        print("Cell size set to %5.5e arcseconds" % args.cell_size, file=dest)

    if args.nx is None or args.ny is None:
        from ducc0.fft import good_size
        fov = args.fov * 3600
        npix = int(fov / args.cell_size)
        if npix % 2:
            npix += 1
        args.nx = good_size(npix)
        args.ny = good_size(npix)

    if args.nband is None:
        args.nband = freq.size

    print("Image size set to (%i, %i, %i)" % (args.nband, args.nx, args.ny),
          file=dest)

    # mask
    if args.mask is not None:
        mask_array = load_fits(args.mask, dtype=args.real_type).squeeze()
        if mask_array.shape != (args.nx, args.ny):
            raise ValueError("Mask has incorrect shape.")
        # add freq axis
        mask_array = mask_array[None]

        def mask(x):
            return mask_array * x
    else:
        mask_array = None

        def mask(x):
            return x

    # init gridder
    R = Gridder(
        args.ms,
        args.nx,
        args.ny,
        args.cell_size,
        nband=args.nband,
        nthreads=args.nthreads,
        do_wstacking=args.do_wstacking,
        row_chunks=args.row_chunks,
        psf_oversize=args.psf_oversize,
        data_column=args.data_column,
        epsilon=args.epsilon,
        weight_column=args.weight_column,
        imaging_weight_column=args.imaging_weight_column,
        model_column=args.model_column,
        flag_column=args.flag_column,
        weighting=args.weighting,
        robust=args.robust,
        mem_limit=int(
            0.8 * args.mem_limit))  # assumes gridding accounts for 80% memory
    freq_out = R.freq_out
    radec = R.radec

    print("PSF size set to (%i, %i, %i)" % (args.nband, R.nx_psf, R.ny_psf),
          file=dest)

    # get headers
    hdr = set_wcs(args.cell_size / 3600, args.cell_size / 3600, args.nx,
                  args.ny, radec, freq_out)
    hdr_mfs = set_wcs(args.cell_size / 3600, args.cell_size / 3600, args.nx,
                      args.ny, radec, np.mean(freq_out))
    hdr_psf = set_wcs(args.cell_size / 3600, args.cell_size / 3600, R.nx_psf,
                      R.ny_psf, radec, freq_out)
    hdr_psf_mfs = set_wcs(args.cell_size / 3600, args.cell_size / 3600,
                          R.nx_psf, R.ny_psf, radec, np.mean(freq_out))

    # psf
    if args.psf is not None:
        try:
            compare_headers(hdr_psf, fits.getheader(args.psf))
            psf = load_fits(args.psf, dtype=args.real_type).squeeze()
        except BaseException:
            # header mismatch or unreadable file: remake the PSF
            psf = R.make_psf()
            save_fits(args.outfile + '_psf.fits', psf, hdr_psf)
    else:
        psf = R.make_psf()
        save_fits(args.outfile + '_psf.fits', psf, hdr_psf)

    # Normalising by wsum (so that the PSF always sums to 1) results in the
    # most intuitive sig_21 values and by far the least bookkeeping.
    # However, we won't save the cubes that way as it destroys information
    # about the noise in image space. Note only the MFS images will have the
    # usual units of Jy/beam.
    wsums = np.amax(psf.reshape(args.nband, R.nx_psf * R.ny_psf), axis=1)
    wsum = np.sum(wsums)
    psf /= wsum
    psf_mfs = np.sum(psf, axis=0)

    # fit restoring psf
    GaussPar = fitcleanbeam(psf_mfs[None], level=0.5, pixsize=1.0)
    GaussPars = fitcleanbeam(psf, level=0.5, pixsize=1.0)

    cpsf_mfs = np.zeros(psf_mfs.shape, dtype=args.real_type)
    cpsf = np.zeros(psf.shape, dtype=args.real_type)

    lpsf = np.arange(-R.nx_psf / 2, R.nx_psf / 2)
    mpsf = np.arange(-R.ny_psf / 2, R.ny_psf / 2)
    xx, yy = np.meshgrid(lpsf, mpsf, indexing='ij')

    cpsf_mfs = Gaussian2D(xx, yy, GaussPar[0], normalise=False)

    for v in range(args.nband):
        cpsf[v] = Gaussian2D(xx, yy, GaussPars[v], normalise=False)

    from pfb.utils.fits import add_beampars
    GaussPar = list(GaussPar[0])
    GaussPar[0] *= args.cell_size / 3600
    GaussPar[1] *= args.cell_size / 3600
    GaussPar = tuple(GaussPar)
    hdr_psf_mfs = add_beampars(hdr_psf_mfs, GaussPar)

    save_fits(args.outfile + '_cpsf_mfs.fits', cpsf_mfs, hdr_psf_mfs)
    save_fits(args.outfile + '_psf_mfs.fits', psf_mfs, hdr_psf_mfs)

    GaussPars = list(GaussPars)
    for b in range(args.nband):
        GaussPars[b] = list(GaussPars[b])
        GaussPars[b][0] *= args.cell_size / 3600
        GaussPars[b][1] *= args.cell_size / 3600
        GaussPars[b] = tuple(GaussPars[b])
    GaussPars = tuple(GaussPars)
    hdr_psf = add_beampars(hdr_psf, GaussPar, GaussPars)

    save_fits(args.outfile + '_cpsf.fits', cpsf, hdr_psf)

    # dirty
    if args.dirty is not None:
        try:
            compare_headers(hdr, fits.getheader(args.dirty))
            dirty = load_fits(args.dirty).squeeze()
        except BaseException:
            # header mismatch or unreadable file: remake the dirty image
            dirty = R.make_dirty()
            save_fits(args.outfile + '_dirty.fits', dirty, hdr)
    else:
        dirty = R.make_dirty()
        save_fits(args.outfile + '_dirty.fits', dirty, hdr)

    dirty /= wsum
    dirty_mfs = np.sum(dirty, axis=0)
    save_fits(args.outfile + '_dirty_mfs.fits', dirty_mfs, hdr_mfs)

    # initial model and residual
    if args.x0 is not None:
        try:
            compare_headers(hdr, fits.getheader(args.x0))
            model = load_fits(args.x0, dtype=args.real_type).squeeze()
            if args.first_residual is not None:
                try:
                    compare_headers(hdr, fits.getheader(args.first_residual))
                    residual = load_fits(args.first_residual,
                                         dtype=args.real_type).squeeze()
                except BaseException:
                    residual = R.make_residual(model)
                    save_fits(args.outfile + '_first_residual.fits', residual,
                              hdr)
            else:
                residual = R.make_residual(model)
                save_fits(args.outfile + '_first_residual.fits', residual, hdr)
            residual /= wsum
        except BaseException:
            model = np.zeros((args.nband, args.nx, args.ny))
            residual = dirty.copy()
    else:
        model = np.zeros((args.nband, args.nx, args.ny))
        residual = dirty.copy()

    residual_mfs = np.sum(residual, axis=0)
    save_fits(args.outfile + '_first_residual_mfs.fits', residual_mfs, hdr_mfs)

    # smooth beam
    if args.beam_model is not None:
        if args.beam_model.endswith('.fits'):
            beam_image = load_fits(args.beam_model,
                                   dtype=args.real_type).squeeze()
            if beam_image.shape != (args.nband, args.nx, args.ny):
                raise ValueError("Beam has incorrect shape")

        elif args.beam_model == "JimBeam":
            from katbeam import JimBeam
            if args.band.lower() == 'l':
                beam = JimBeam('MKAT-AA-L-JIM-2020')
            else:
                beam = JimBeam('MKAT-AA-UHF-JIM-2020')
            beam_image = np.zeros((args.nband, args.nx, args.ny),
                                  dtype=args.real_type)

            l_coord, ref_l = data_from_header(hdr, axis=1)
            l_coord -= ref_l
            m_coord, ref_m = data_from_header(hdr, axis=2)
            m_coord -= ref_m
            xx, yy = np.meshgrid(l_coord, m_coord, indexing='ij')

            for v in range(args.nband):
                beam_image[v] = beam.I(xx, yy, freq_out[v])

        def beam(x):
            return beam_image * x
    else:
        beam_image = None

        def beam(x):
            return x

    if args.init_nnls:
        print("Initialising with NNLS", file=log)
        model = nnls(psf,
                     model,
                     residual,
                     mask=mask_array,
                     beam_image=beam_image,
                     hdr=hdr,
                     hdr_mfs=hdr_mfs,
                     outfile=args.outfile,
                     maxit=1,
                     nthreads=args.nthreads)

        residual = R.make_residual(beam(mask(model))) / wsum
        residual_mfs = np.sum(residual, axis=0)

    # deconvolve
    rmax = np.abs(residual_mfs).max()
    rms = np.std(residual_mfs)
    redo_dirty = False
    print("Peak of initial residual is %f and rms is %f" % (rmax, rms),
          file=dest)
    for i in range(0, args.maxit):
        # run minor cycle of choice
        modelp = model.copy()
        if args.deconv_mode == 'sara':
            model = sara(psf,
                         model,
                         residual,
                         mask=mask_array,
                         beam_image=beam_image,
                         hessian=R.convolve,
                         wsum=wsum,
                         adapt_sig21=args.adapt_sig21,
                         hdr=hdr,
                         hdr_mfs=hdr_mfs,
                         outfile=args.outfile,
                         cpsf=cpsf,
                         nthreads=args.nthreads,
                         sig_21=args.sig_21,
                         sigma_frac=args.sigma_frac,
                         maxit=args.minormaxit,
                         tol=args.minortol,
                         gamma=args.gamma,
                         psi_levels=args.psi_levels,
                         psi_basis=args.psi_basis,
                         pdtol=args.pdtol,
                         pdmaxit=args.pdmaxit,
                         pdverbose=args.pdverbose,
                         positivity=args.positivity,
                         cgtol=args.cgtol,
                         cgminit=args.cgminit,
                         cgmaxit=args.cgmaxit,
                         cgverbose=args.cgverbose,
                         pmtol=args.pmtol,
                         pmmaxit=args.pmmaxit,
                         pmverbose=args.pmverbose)

        elif args.deconv_mode == 'clean':
            model = clean(psf,
                          model,
                          residual,
                          mask=mask_array,
                          beam=beam_image,
                          nthreads=args.nthreads,
                          maxit=args.minormaxit,
                          gamma=args.gamma,
                          peak_factor=args.peak_factor,
                          threshold=args.threshold,
                          hbgamma=args.hbgamma,
                          hbpf=args.hbpf,
                          hbmaxit=args.hbmaxit,
                          hbverbose=args.hbverbose)
        elif args.deconv_mode == 'spotless':
            model = spotless(psf,
                             model,
                             residual,
                             mask=mask_array,
                             beam_image=beam_image,
                             hessian=R.convolve,
                             wsum=wsum,
                             adapt_sig21=args.adapt_sig21,
                             cpsf=cpsf_mfs,
                             hdr=hdr,
                             hdr_mfs=hdr_mfs,
                             outfile=args.outfile,
                             sig_21=args.sig_21,
                             sigma_frac=args.sigma_frac,
                             nthreads=args.nthreads,
                             gamma=args.gamma,
                             peak_factor=args.peak_factor,
                             maxit=args.minormaxit,
                             tol=args.minortol,
                             threshold=args.threshold,
                             positivity=args.positivity,
                             hbgamma=args.hbgamma,
                             hbpf=args.hbpf,
                             hbmaxit=args.hbmaxit,
                             hbverbose=args.hbverbose,
                             pdtol=args.pdtol,
                             pdmaxit=args.pdmaxit,
                             pdverbose=args.pdverbose,
                             cgtol=args.cgtol,
                             cgminit=args.cgminit,
                             cgmaxit=args.cgmaxit,
                             cgverbose=args.cgverbose,
                             pmtol=args.pmtol,
                             pmmaxit=args.pmmaxit,
                             pmverbose=args.pmverbose)
        else:
            raise ValueError("Unknown deconvolution mode ", args.deconv_mode)

        # get residual
        if redo_dirty:
            # Need to do this if weights or Jones has changed
            # (eg. if we change robustness factor, reweight or calibrate)
            psf = R.make_psf()
            wsums = np.amax(psf.reshape(args.nband, R.nx_psf * R.ny_psf),
                            axis=1)
            wsum = np.sum(wsums)
            psf /= wsum
            dirty = R.make_dirty() / wsum

        # compute in image space
        # residual = dirty - R.convolve(beam(mask(model))) / wsum
        residual = R.make_residual(beam(mask(model))) / wsum

        residual_mfs = np.sum(residual, axis=0)

        # save current iteration
        model_mfs = np.mean(model, axis=0)
        save_fits(args.outfile + '_major' + str(i + 1) + '_model_mfs.fits',
                  model_mfs, hdr_mfs)

        save_fits(args.outfile + '_major' + str(i + 1) + '_model.fits', model,
                  hdr)

        save_fits(args.outfile + '_major' + str(i + 1) + '_residual_mfs.fits',
                  residual_mfs, hdr_mfs)

        save_fits(args.outfile + '_major' + str(i + 1) + '_residual.fits',
                  residual * wsum, hdr)

        # check stopping criteria
        rmax = np.abs(residual_mfs).max()
        rms = np.std(residual_mfs)
        eps = np.linalg.norm(model - modelp) / np.linalg.norm(model)

        print("At iteration %i peak of residual is %f, rms is %f, current "
              "eps is %f" % (i + 1, rmax, rms, eps),
              file=dest)

        if eps < args.tol:
            break

    if args.mop_flux:
        print("Mopping flux", file=dest)

        # vague Gaussian prior on x
        def hess(x):
            return mask(beam(R.convolve(mask(beam(x))))) / wsum + 1e-6 * x

        def M(x):
            return x / 1e-6  # preconditioner

        x = pcg(hess,
                mask(beam(residual)),
                np.zeros(residual.shape, dtype=residual.dtype),
                M=M,
                tol=0.1 * args.cgtol,
                maxit=args.cgmaxit,
                minit=args.cgminit,
                verbosity=args.cgverbose)

        model += x
        # residual = dirty - R.convolve(beam(mask(model))) / wsum
        residual = R.make_residual(beam(mask(model))) / wsum

        save_fits(args.outfile + '_mopped_model.fits', model, hdr)
        save_fits(args.outfile + '_mopped_residual.fits', residual, hdr)
        model_mfs = np.mean(model, axis=0)
        save_fits(args.outfile + '_mopped_model_mfs.fits', model_mfs, hdr_mfs)
        residual_mfs = np.sum(residual, axis=0)
        save_fits(args.outfile + '_mopped_residual_mfs.fits', residual_mfs,
                  hdr_mfs)

        rmax = np.abs(residual_mfs).max()
        rms = np.std(residual_mfs)

        print("After mopping flux peak of residual is %f, rms is %f" %
              (rmax, rms),
              file=dest)

    # if args.interp_model:
    #     nband = args.nband
    #     order = args.spectral_poly_order
    #     phi.trim_fat(model)
    #     I = np.argwhere(phi.mask).squeeze()
    #     Ix = I[:, 0]
    #     Iy = I[:, 1]
    #     npix = I.shape[0]

    #     # get components
    #     beta = model[:, Ix, Iy]

    #     # fit integrated polynomial to model components
    #     # we are given frequencies at bin centers, convert to bin edges
    #     ref_freq = np.mean(freq_out)
    #     delta_freq = freq_out[1] - freq_out[0]
    #     wlow = (freq_out - delta_freq/2.0)/ref_freq
    #     whigh = (freq_out + delta_freq/2.0)/ref_freq
    #     wdiff = whigh - wlow

    #     # set design matrix for each component
    #     Xdesign = np.zeros([freq_out.size, args.spectral_poly_order])
    #     for i in range(1, args.spectral_poly_order+1):
    #         Xdesign[:, i-1] = (whigh**i - wlow**i)/(i*wdiff)

    #     weights = psf_max[:, None]
    #     dirty_comps = Xdesign.T.dot(weights*beta)

    #     hess_comps = Xdesign.T.dot(weights*Xdesign)

    #     comps = np.linalg.solve(hess_comps, dirty_comps)

    #     np.savez(args.outfile + "spectral_comps", comps=comps, ref_freq=ref_freq, mask=np.any(model, axis=0))

    if args.write_model:
        print("Writing model", file=dest)
        R.write_model(model)

    if args.make_restored:
        print("Making restored", file=dest)
        cpsfo = PSF(cpsf, residual.shape, nthreads=args.nthreads)
        restored = cpsfo.convolve(model)

        # residual needs to be in Jy/beam before adding to convolved model
        wsums = np.amax(psf.reshape(-1, R.nx_psf * R.ny_psf), axis=1)
        restored += residual / wsums[:, None, None]

        save_fits(args.outfile + '_restored.fits', restored, hdr)
        restored_mfs = np.mean(restored, axis=0)
        save_fits(args.outfile + '_restored_mfs.fits', restored_mfs, hdr_mfs)
        residual_mfs = np.sum(residual, axis=0)
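The restoring step above convolves the model with the fitted clean beam and adds the residual back in matching units. A condensed numpy sketch with a toy Gaussian beam (hypothetical shapes; the real code uses pfb's PSF operator and the fitcleanbeam parameters):

import numpy as np
from scipy.signal import fftconvolve

nband, nx, ny = 2, 64, 64
model = np.zeros((nband, nx, ny))
model[:, 32, 32] = 1.0                              # toy point-source model
pix = np.arange(nx) - nx // 2
xx, yy = np.meshgrid(pix, pix, indexing='ij')
cbeam = np.exp(-(xx**2 + yy**2) / (2 * 3.0**2))     # toy clean beam, sigma = 3 px

restored = np.stack([fftconvolve(model[b], cbeam, mode='same')
                     for b in range(nband)])
# per-band residuals, rescaled to Jy/beam, would be added to restored here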
Exemple #24
0
def chan_to_band_mapping(ms_name, nband=None):
    '''
    Construct dictionaries containing per MS and SPW channel to band mapping.
    Currently assumes we are only imaging field 0 of the first MS.

    Input:
    ms_name     - list of ms names
    nband       - number of imaging bands

    Output:
    freqs           - dict[MS][SPW] chunked dask arrays of the freq to band mapping
    freq_bin_idx    - dict[MS][SPW] chunked dask arrays of bin starting indices
    freq_bin_counts - dict[MS][SPW] chunked dask arrays of counts in each bin
    freq_out        - band centre frequencies (plain channel averages; LB - should a weighted sum rather be computed?)
    band_mapping    - dict[MS][SPW] identifying imaging bands going into degridder
    chan_chunks     - dict[MS][SPW] specifying dask chunking scheme over channel
    '''
    from daskms import xds_from_storage_ms as xds_from_ms
    from daskms import xds_from_storage_table as xds_from_table
    import dask
    import dask.array as da

    from omegaconf import ListConfig
    if not isinstance(ms_name, list) and not isinstance(ms_name, ListConfig):
        ms_name = [ms_name]

    # first pass through data to determine freq_mapping
    radec = None
    freqs = {}
    all_freqs = []
    spws = {}
    for ims in ms_name:
        xds = xds_from_ms(ims, chunks={"row": -1}, columns=('TIME', ))

        # subtables
        ddids = xds_from_table(ims + "::DATA_DESCRIPTION")
        fields = xds_from_table(ims + "::FIELD")
        spws_table = xds_from_table(ims + "::SPECTRAL_WINDOW")
        pols = xds_from_table(ims + "::POLARIZATION")

        # subtable data
        ddids = dask.compute(ddids)[0]
        fields = dask.compute(fields)[0]
        spws_table = dask.compute(spws_table)[0]
        pols = dask.compute(pols)[0]

        freqs[ims] = {}
        spws[ims] = []
        for ds in xds:
            field = fields[ds.FIELD_ID]

            # check fields match
            if radec is None:
                radec = field.PHASE_DIR.data.squeeze()

            if not np.array_equal(radec, field.PHASE_DIR.data.squeeze()):
                continue

            spw = spws_table[ds.DATA_DESC_ID]
            tmp_freq = spw.CHAN_FREQ.data.squeeze()
            freqs[ims][ds.DATA_DESC_ID] = tmp_freq
            all_freqs.append([tmp_freq])
            spws[ims].append(ds.DATA_DESC_ID)

    # freq mapping
    all_freqs = dask.compute(all_freqs)
    ufreqs = np.unique(all_freqs)  # sorted ascending
    nchan = ufreqs.size
    if nband is None:
        nband = nchan

    # bin edges
    fmin = ufreqs[0]
    fmax = ufreqs[-1]
    fbins = np.linspace(fmin, fmax, nband + 1)
    freq_out = np.zeros(nband)
    for band in range(nband):
        indl = ufreqs >= fbins[band]
        # upper edge padded by a small tolerance so the last channel
        # falls inside the final band
        indu = ufreqs < fbins[band + 1] + 1e-6
        freq_out[band] = np.mean(ufreqs[indl & indu])

    # chan <-> band mapping
    band_mapping = {}
    chan_chunks = {}
    freq_bin_idx = {}
    freq_bin_counts = {}
    for ims in freqs:
        freq_bin_idx[ims] = {}
        freq_bin_counts[ims] = {}
        band_mapping[ims] = {}
        chan_chunks[ims] = []
        for spw in freqs[ims]:
            freq = np.atleast_1d(dask.compute(freqs[ims][spw])[0])
            band_map = np.zeros(freq.size, dtype=np.int32)
            for band in range(nband):
                indl = freq >= fbins[band]
                indu = freq < fbins[band + 1] + 1e-6
                band_map = np.where(indl & indu, band, band_map)
            # to dask arrays
            bands, bin_counts = np.unique(band_map, return_counts=True)
            band_mapping[ims][spw] = tuple(bands)
            chan_chunks[ims].append({'chan': tuple(bin_counts)})
            freqs[ims][spw] = da.from_array(freq, chunks=tuple(bin_counts))
            bin_idx = np.append(np.array([0]), np.cumsum(bin_counts))[0:-1]
            freq_bin_idx[ims][spw] = da.from_array(bin_idx, chunks=1)
            freq_bin_counts[ims][spw] = da.from_array(bin_counts, chunks=1)

    return freqs, freq_bin_idx, freq_bin_counts, freq_out, band_mapping, chan_chunks
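A toy numpy illustration of the binning performed above, without touching an MS; np.digitize over the interior bin edges reproduces the same band assignment as the indl/indu loop:

import numpy as np

ufreqs = np.linspace(1.0e9, 1.7e9, 8)        # 8 toy channels
nband = 3
fbins = np.linspace(ufreqs[0], ufreqs[-1], nband + 1)
band_map = np.digitize(ufreqs, fbins[1:-1])  # -> [0 0 0 1 1 2 2 2]
bands, bin_counts = np.unique(band_map, return_counts=True)
bin_idx = np.append(0, np.cumsum(bin_counts))[:-1]
print(bin_idx, bin_counts)                   # [0 3 5] [3 2 3]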
Exemple #25
0
    def __init__(self,
                 ms_name,
                 nx,
                 ny,
                 cell_size,
                 nband=None,
                 nthreads=8,
                 do_wstacking=1,
                 Stokes='I',
                 row_chunks=-1,
                 chan_chunks=32,
                 optimise_chunks=True,
                 epsilon=1e-5,
                 psf_oversize=2.0,
                 weighting=None,
                 robust=None,
                 data_column='CORRECTED_DATA',
                 weight_column='WEIGHT_SPECTRUM',
                 mueller_column=None,
                 model_column="MODEL_DATA",
                 flag_column='FLAG',
                 imaging_weight_column=None,
                 real_type='f4',
                 cdir=None,
                 mem_limit=None):
        '''
        TODO - currently row_chunks and chan_chunks are only used for the
        compute_weights() and write_component_model() methods. All other
        methods assume that the data for a single imaging band per ms and
        spw fit into memory. The optimise_chunks argument is a promise to
        improve this in the future.

        TODO - current IO can probably be massively reduced if we optimize
        for specific Stokes outputs and we optimise the chunking strategy.
        In particular, we can write out the weights for Stokes I imaging in
        advance and then only load precomputed scalar weights in the convolve
        function. Since we currently load in weights, imaging weights and a
        complex "Mueller" term for all 4 correlations, we can in principle
        reduce IO and memory footprint by about a factor of 16.

        # of GB for 8 hr 8 sec 32k observation
        64*(64-1) //2 * 8 * 60 * 60 // 8 * 2**15 * 4 * 8 / 1e9 = 7610 GB

        # of GB for 8 hr 8 sec 4k observation
        64*(64-1) //2 * 8 * 60 * 60 // 8 * 2**12 * 4 * 8 / 1e9 = 951 GB
        '''
        if Stokes != 'I':
            raise NotImplementedError("Only Stokes I currently supported")
        self.nx = nx
        self.ny = ny
        self.cell = cell_size * np.pi / 60 / 60 / 180
        self.nthreads = nthreads
        self.do_wstacking = do_wstacking
        self.epsilon = epsilon
        self.row_chunks = row_chunks
        self.chan_chunks = chan_chunks
        self.psf_oversize = psf_oversize
        self.nx_psf = int(self.psf_oversize * self.nx)
        self.nx_psf += self.nx_psf % 2
        self.ny_psf = int(self.psf_oversize * self.ny)
        self.ny_psf += self.ny_psf % 2
        self.real_type = real_type

        if isinstance(ms_name, list):
            self.ms = ms_name
        else:
            self.ms = [ms_name]

        # first pass through data to determine freq_mapping
        self.radec = None
        self.freq = {}
        self.freq_np = {}
        all_freqs = []
        self.spws = {}
        for ims in self.ms:
            xds = xds_from_ms(ims,
                              group_cols=('FIELD_ID', 'DATA_DESC_ID'),
                              chunks={"row": -1},
                              columns=('TIME',))

            # subtables
            ddids = xds_from_table(ims + "::DATA_DESCRIPTION")
            fields = xds_from_table(ims + "::FIELD", group_cols="__row__")
            spws = xds_from_table(ims + "::SPECTRAL_WINDOW",
                                  group_cols="__row__")
            pols = xds_from_table(ims + "::POLARIZATION", group_cols="__row__")

            # subtable data
            ddids = dask.compute(ddids)[0]
            fields = dask.compute(fields)[0]
            spws = dask.compute(spws)[0]
            pols = dask.compute(pols)[0]

            self.freq[ims] = {}
            self.freq_np[ims] = {}
            self.spws[ims] = []
            maxchans = 0
            ncorr = 4  # TODO - get ncorr from ds
            for ds in xds:
                field = fields[ds.FIELD_ID]
                radec = field.PHASE_DIR.data.squeeze()

                # check fields match
                if self.radec is None:
                    self.radec = radec

                if not np.array_equal(radec, self.radec):
                    continue

                spw = spws[ds.DATA_DESC_ID]
                tmp_freq = spw.CHAN_FREQ.data.squeeze()
                maxchans = np.maximum(maxchans, tmp_freq.size)
                self.freq[ims][ds.DATA_DESC_ID] = tmp_freq
                self.freq_np[ims][ds.DATA_DESC_ID] = dask.compute(tmp_freq)[0]
                all_freqs.append([tmp_freq])
                self.spws[ims].append(ds.DATA_DESC_ID)

        self.data_column = data_column
        self.weight_column = weight_column
        self.model_column = model_column
        self.flag_column = flag_column

        self.columns = (self.data_column, self.weight_column, self.flag_column,
                        'UVW')

        # TODO - write jones2col if column does not exist
        self.mueller_column = mueller_column
        if mueller_column is not None:
            self.columns += (self.mueller_column, )

        # check that all measurement sets contain the required columns
        for ims in self.ms:
            xds = xds_from_ms(ims)

            for ds in xds:
                for column in self.columns:
                    try:
                        getattr(ds, column)
                    except BaseException:
                        raise ValueError("No column named %s in %s" %
                                         (column, ims))

        # freq mapping
        all_freqs = dask.compute(all_freqs)
        ufreqs = np.unique(all_freqs)  # sorted ascending
        self.nchan = ufreqs.size
        if nband is None:
            self.nband = self.nchan
        else:
            self.nband = nband

        # bin edges
        fmin = ufreqs[0]
        fmax = ufreqs[-1]
        fbins = np.linspace(fmin, fmax, self.nband + 1)
        self.freq_out = np.zeros(self.nband)
        for band in range(self.nband):
            indl = ufreqs >= fbins[band]
            # upper edge padded by a small tolerance so the last channel
            # falls inside the final band
            indu = ufreqs < fbins[band + 1] + 1e-6
            self.freq_out[band] = np.mean(ufreqs[indl & indu])

        # chan <-> band mapping
        self.band_mapping = {}
        self.chunks = {}
        self.freq_bin_idx = {}
        self.freq_bin_counts = {}
        self.freq_bin_idx_np = {}
        self.freq_bin_counts_np = {}
        for ims in self.freq:
            self.freq_bin_idx[ims] = {}
            self.freq_bin_counts[ims] = {}
            self.freq_bin_idx_np[ims] = {}
            self.freq_bin_counts_np[ims] = {}
            self.band_mapping[ims] = {}
            self.chunks[ims] = []
            for spw in self.freq[ims]:
                freq = np.atleast_1d(dask.compute(self.freq[ims][spw])[0])
                band_map = np.zeros(freq.size, dtype=np.int32)
                for band in range(self.nband):
                    indl = freq >= fbins[band]
                    indu = freq < fbins[band + 1] + 1e-6
                    band_map = np.where(indl & indu, band, band_map)
                # to dask arrays
                bands, bin_counts = np.unique(band_map, return_counts=True)
                self.band_mapping[ims][spw] = tuple(bands)
                self.chunks[ims].append({'row': -1, 'chan': tuple(bin_counts)})
                self.freq[ims][spw] = da.from_array(freq,
                                                    chunks=tuple(bin_counts))
                bin_idx = np.append(np.array([0]), np.cumsum(bin_counts))[0:-1]
                self.freq_bin_idx[ims][spw] = da.from_array(bin_idx, chunks=1)
                self.freq_bin_counts[ims][spw] = da.from_array(bin_counts,
                                                               chunks=1)
                self.freq_bin_idx_np[ims][spw] = bin_idx
                self.freq_bin_counts_np[ims][spw] = bin_counts

        # compute imaging weights
        if weighting is not None:
            if imaging_weight_column is None:
                self.imaging_weight_column = "IMAGING_WEIGHT_SPECTRUM"
            else:  # this column is always created if asked
                self.imaging_weight_column = imaging_weight_column
            print("Computing weights", file=log)
            self.compute_weights(robust)
            self.columns += (self.imaging_weight_column, )
        else:
            self.imaging_weight_column = None
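The memory estimates in the docstring follow from nbl * nrows_per_baseline * nchan * ncorr * bytes per complex visibility; a quick sanity check of that arithmetic:

nant = 64
nbl = nant * (nant - 1) // 2       # 2016 baselines
nrow = 8 * 60 * 60 // 8            # 8 h of 8 s integrations per baseline
ncorr, nbytes = 4, 8               # 4 correlations of complex64
print(nbl * nrow * 2**15 * ncorr * nbytes / 1e9)   # 32k channels: ~7610 GB
print(nbl * nrow * 2**12 * ncorr * nbytes / 1e9)   # 4k channels:  ~951 GB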
Exemple #26
0
    def __init__(self, msname=None, log=None):
        if not msname:
            return

        self.msname = msname
        self.log = log

        tab = table(msname, ack=False)
        log and log.info(f": MS {msname} contains {tab.nrows()} rows")

        self.valid_columns = set(tab.colnames())

        spw_tab = daskms.xds_from_table(msname + '::SPECTRAL_WINDOW',
                                        columns=['CHAN_FREQ'])
        self.chan_freqs = spw_tab[
            0].CHAN_FREQ  # important for this to be an xarray
        self.nspw = self.chan_freqs.shape[0]
        self.spw = NamedList("spw", list(map(str, range(self.nspw))))

        log and log.info(
            f":   {self.chan_freqs.shape} spectral windows and channels")

        self.field = NamedList(
            "field",
            table(msname + '::FIELD', ack=False).getcol("NAME"))
        log and log.info(
            f":   {len(self.field)} fields: {' '.join(self.field.names)}")

        scan_numbers = sorted(set(tab.getcol("SCAN_NUMBER")))
        log and log.info(
            f":   {len(scan_numbers)} scans, first #{scan_numbers[0]}, last #{scan_numbers[-1]}"
        )
        all_scans = NamedList("scan",
                              list(map(str, range(scan_numbers[-1] + 1))))
        self.scan = all_scans.get_subset(scan_numbers)

        self.all_antenna = NamedList(
            "antenna",
            table(msname + '::ANTENNA', ack=False).getcol("NAME"))

        self.antenna = self.all_antenna.get_subset(
            list(set(tab.getcol("ANTENNA1")) | set(tab.getcol("ANTENNA2"))))

        baselines = [(p, q) for p in self.antenna.numbers
                     for q in self.antenna.numbers if p <= q]
        self.baseline_numbering = {(p, q): i
                                   for i, (p, q) in enumerate(baselines)}
        self.baseline_numbering.update({(q, p): i
                                        for i, (p, q) in enumerate(baselines)})

        log and log.info(
            f":   {len(self.antenna)} antennas: {self.antenna.str_list()}")

        pol_tab = table(msname + '::POLARIZATION', ack=False)

        all_corr_labels = [
            STOKES_TYPES[icorr]
            for icorr in pol_tab.getcol("CORR_TYPE", 0, 1).ravel()
        ]
        self.corr = NamedList("correlation", all_corr_labels.copy())

        # Maps correlation -> callable that extracts that correlation from visibility data
        # By default, populated with slicing functions for 0...3,
        # but can also be extended with "I", "Q", etc.
        self.corr_data_mappers = OrderedDict({
            i: lambda x, icorr=i: x[..., icorr]
            for i in range(len(all_corr_labels))
        })

        # Maps correlation -> callable that extracts that correlation from flag data
        self.corr_flag_mappers = self.corr_data_mappers.copy()

        # add mappings and labels for Stokes parameters
        xx, xy, yx, yy = [
            self.corr.map.get(c) for c in ("XX", "XY", "YX", "YY")
        ]
        rr, rl, lr, ll = [
            self.corr.map.get(c) for c in ("RR", "RL", "LR", "LL")
        ]

        def add_stokes(a, b, I, J, imag=False):
            """Adds mappers for Stokes A and B as the sum/difference of components I and J, divided by 2 or 2j"""
            def _sum(x):
                return (x[..., I] + x[..., J]) / 2

            def _diff(x):
                return (x[..., I] - x[..., J]) / (2j if imag else 2)

            def _or(x):
                return (x[..., I] | x[..., J])

            nonlocal all_corr_labels
            if a not in self.corr_data_mappers:
                self.corr_data_mappers[len(all_corr_labels)] = _sum
                self.corr_flag_mappers[len(all_corr_labels)] = _or
                all_corr_labels.append(a)
            if b not in self.corr_data_mappers:
                self.corr_data_mappers[len(all_corr_labels)] = _diff
                self.corr_flag_mappers[len(all_corr_labels)] = _or
                all_corr_labels.append(b)

        if xx is not None and yy is not None:
            add_stokes("I", "Q", xx, yy)
        if rr is not None and ll is not None:
            add_stokes("I", "V", rr, ll)
        if xy is not None and yx is not None:
            add_stokes("U", "V", xy, yx, True)
        if rl is not None and lr is not None:
            add_stokes("Q", "U", rl, lr, True)

        self.all_corr = NamedList("correlation", all_corr_labels)

        log and log.info(f":   corrs/Stokes {' '.join(self.all_corr.names)}")
Exemple #27
0
    def make_residual(self, x):
        # Note deprecated (does not support Jones terms)
        print("Making residual", file=log)
        x = da.from_array(x.astype(self.real_type),
                          chunks=(1, self.nx, self.ny),
                          name=False)
        residuals = []
        for ims in self.ms:
            xds = xds_from_ms(ims,
                              group_cols=('FIELD_ID', 'DATA_DESC_ID'),
                              chunks=self.chunks[ims],
                              columns=self.columns)

            # subtables
            ddids = xds_from_table(ims + "::DATA_DESCRIPTION")
            fields = xds_from_table(ims + "::FIELD", group_cols="__row__")
            spws = xds_from_table(ims + "::SPECTRAL_WINDOW",
                                  group_cols="__row__")
            pols = xds_from_table(ims + "::POLARIZATION", group_cols="__row__")

            # subtable data
            ddids = dask.compute(ddids)[0]
            fields = dask.compute(fields)[0]
            spws = dask.compute(spws)[0]
            pols = dask.compute(pols)[0]

            for ds in xds:
                field = fields[ds.FIELD_ID]
                radec = field.PHASE_DIR.data.squeeze()
                if not np.array_equal(radec, self.radec):
                    continue

                spw = ds.DATA_DESC_ID  # this is not correct, need to use spw

                freq_bin_idx = self.freq_bin_idx[ims][spw]
                freq_bin_counts = self.freq_bin_counts[ims][spw]
                freq = self.freq[ims][spw]

                uvw = ds.UVW.data

                data = getattr(ds, self.data_column).data
                dataxx = data[:, :, 0]
                datayy = data[:, :, -1]

                weights = getattr(ds, self.weight_column).data
                if len(weights.shape) < 3:
                    weights = da.broadcast_to(weights[:, None, :],
                                              data.shape,
                                              chunks=data.chunks)

                if self.imaging_weight_column is not None:
                    imaging_weights = getattr(ds,
                                              self.imaging_weight_column).data
                    if len(imaging_weights.shape) < 3:
                        imaging_weights = da.broadcast_to(
                            imaging_weights[:, None, :],
                            data.shape,
                            chunks=data.chunks)

                    weightsxx = imaging_weights[:, :, 0] * weights[:, :, 0]
                    weightsyy = imaging_weights[:, :, -1] * weights[:, :, -1]
                else:
                    weightsxx = weights[:, :, 0]
                    weightsyy = weights[:, :, -1]

                # weighted sum corr to Stokes I
                weights = weightsxx + weightsyy
                data = (weightsxx * dataxx + weightsyy * datayy)
                data = da.where(weights, data / weights, 0.0j)

                # only keep data where both corrs are unflagged
                flag = getattr(ds, self.flag_column).data
                flagxx = flag[:, :, 0]
                flagyy = flag[:, :, -1]
                flag = ~(flagxx | flagyy)  # ducc0 convention

                bands = self.band_mapping[ims][spw]
                model = x[list(bands), :, :]
                residual = im2residim(uvw,
                                      freq,
                                      model,
                                      data,
                                      freq_bin_idx,
                                      freq_bin_counts,
                                      self.cell,
                                      weights=weights,
                                      flag=flag.astype(np.uint8),
                                      nthreads=self.nthreads,
                                      epsilon=self.epsilon,
                                      do_wstacking=self.do_wstacking,
                                      double_accum=True)

                residuals.append(residual)

        residuals = dask.compute(residuals)[0]

        return accumulate_dirty(residuals, self.nband,
                                self.band_mapping).astype(self.real_type)
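Conceptually, im2residim forms residual = R^H W (V - R x) without materialising the model visibilities separately. A dense-matrix toy analogue, for intuition only (R_op stands in for the de/gridding operator):

import numpy as np

rng = np.random.default_rng(42)
R_op = rng.standard_normal((100, 16))      # toy measurement operator: image -> vis
W = rng.random(100)                        # visibility weights
V = R_op @ rng.standard_normal(16)         # "observed" visibilities
x = np.zeros(16)                           # current model image, flattened

residual = R_op.T @ (W * (V - R_op @ x))   # same structure as im2residim's output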
Exemple #28
0
    # map short subtable keys to '::'-qualified subtable names
    table_name = {short: "::".join((args.ms, table))
                  for short, table
                  in [('antenna', "ANTENNA"),
                      ('ddid', "DATA_DESCRIPTION"),
                      ('spw', "SPECTRAL_WINDOW"),
                      ('pol', "POLARIZATION"),
                      ('field', "FIELD")]}

    with scheduler_context(args):
        # Get datasets from the main MS
        # partition by FIELD_ID and DATA_DESC_ID
        # and sorted by TIME
        datasets = xds_from_ms(args.ms,
                               group_cols=("FIELD_ID", "DATA_DESC_ID"),
                               index_cols="TIME")

        # Get the antenna dataset
        ant_ds = list(xds_from_table(table_name['antenna']))
        assert len(ant_ds) == 1
        ant_ds = ant_ds[0].rename({'row': 'antenna'})

        # Get datasets for DATA_DESCRIPTION, SPECTRAL_WINDOW
        # POLARIZATION and FIELD, partitioned by row
        ddid_ds = list(xds_from_table(table_name['ddid'],
                                      group_cols="__row__"))
        spwds = list(xds_from_table(table_name['spw'],
                                    group_cols="__row__"))
        pds = list(xds_from_table(table_name['pol'],
                                  group_cols="__row__"))
        field_ds = list(xds_from_table(table_name['field'],
                                       group_cols="__row__"))

        # For each partitioned dataset from the main MS,
Exemple #29
0
    def make_psf(self):
        print("Making PSF", file=log)
        psfs = []
        self.stokes_weights = {}
        self.uvws = {}
        for ims in self.ms:
            xds = xds_from_ms(ims,
                              group_cols=('FIELD_ID', 'DATA_DESC_ID'),
                              chunks=self.chunks[ims],
                              columns=self.columns)

            # subtables
            ddids = xds_from_table(ims + "::DATA_DESCRIPTION")
            fields = xds_from_table(ims + "::FIELD", group_cols="__row__")
            spws = xds_from_table(ims + "::SPECTRAL_WINDOW",
                                  group_cols="__row__")
            pols = xds_from_table(ims + "::POLARIZATION", group_cols="__row__")

            # subtable data
            ddids = dask.compute(ddids)[0]
            fields = dask.compute(fields)[0]
            spws = dask.compute(spws)[0]
            pols = dask.compute(pols)[0]
            self.stokes_weights[ims] = {}
            self.uvws[ims] = {}

            for ds in xds:
                field = fields[ds.FIELD_ID]
                radec = field.PHASE_DIR.data.squeeze()
                if not np.array_equal(radec, self.radec):
                    continue

                # this is not correct, need to use spw
                spw = ds.DATA_DESC_ID

                freq_bin_idx = self.freq_bin_idx[ims][spw]
                freq_bin_counts = self.freq_bin_counts[ims][spw]
                freq = self.freq[ims][spw]

                uvw = ds.UVW.data

                flag = getattr(ds, self.flag_column).data

                weights = getattr(ds, self.weight_column).data
                if len(weights.shape) < 3:
                    weights = da.broadcast_to(weights[:, None, :],
                                              flag.shape,
                                              chunks=flag.chunks)

                if self.imaging_weight_column is not None:
                    imaging_weights = getattr(ds,
                                              self.imaging_weight_column).data
                    if len(imaging_weights.shape) < 3:
                        imaging_weights = da.broadcast_to(
                            imaging_weights[:, None, :],
                            flag.shape,
                            chunks=flag.chunks)

                    weightsxx = imaging_weights[:, :, 0] * weights[:, :, 0]
                    weightsyy = imaging_weights[:, :, -1] * weights[:, :, -1]
                else:
                    weightsxx = weights[:, :, 0]
                    weightsyy = weights[:, :, -1]

                # for the PSF we need to scale the weights by the
                # Mueller amplitudes squared
                if self.mueller_column is not None:
                    mueller = getattr(ds, self.mueller_column).data
                    weightsxx *= da.absolute(mueller[:, :, 0])**2
                    weightsyy *= da.absolute(mueller[:, :, -1])**2

                # weighted sum corr to Stokes I
                weights = weightsxx + weightsyy

                # only keep data where both corrs are unflagged
                flagxx = flag[:, :, 0]
                flagyy = flag[:, :, -1]
                flag = ~(flagxx | flagyy)  # ducc0 convention

                weights *= flag

                data = weights.astype(np.complex64)

                psf = vis2im(uvw,
                             freq,
                             data,
                             freq_bin_idx,
                             freq_bin_counts,
                             self.nx_psf,
                             self.ny_psf,
                             self.cell,
                             flag=flag.astype(np.uint8),
                             nthreads=self.nthreads,
                             epsilon=self.epsilon,
                             do_wstacking=self.do_wstacking,
                             double_accum=True)

                psfs.append(psf)

                # assumes that stokes weights and uvw fit into memory
                # self.stokes_weights[ims][spw] = dask.persist(weights.rechunk({0:-1}))[0]
                # self.uvws[ims][spw] = dask.persist(uvw.rechunk({0:-1}))[0]

                # for comparison with numpy implementation
                # self.stokes_weights[ims][spw] = dask.compute(weights)[0]
                # self.uvws[ims][spw] = dask.compute(uvw)[0]

        # import pdb
        # pdb.set_trace()

        psfs = dask.compute(psfs, scheduler='single-threaded')[0]
        return accumulate_dirty(psfs, self.nband,
                                self.band_mapping).astype(self.real_type)
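In the same toy notation as the residual sketch earlier, the PSF is the adjoint applied to the weights themselves (a unit point source at phase centre has unit visibilities), which is why the weights are cast to complex and gridded above:

import numpy as np

rng = np.random.default_rng(0)
R_op = rng.standard_normal((100, 16))
W = rng.random(100)
psf = R_op.T @ (W * np.ones(100))   # toy analogue of vis2im on the weights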
Exemple #30
0
def full_ms_data(ms_name):
    """Load full ms into memory for pytest."""

    return xds_from_table(ms_name)[0]