Example #1
def vis_factory(args, source_type, sky_model, ms, ant, field, spw, pol):
    try:
        source = sky_model[source_type]
    except KeyError:
        raise ValueError("Source type '%s' unsupported" % source_type)

    # Select single dataset rows
    corrs = pol.NUM_CORR.data[0]
    frequency = spw.CHAN_FREQ.data[0]
    phase_dir = field.PHASE_DIR.data[0][0]  # row, poly

    lm = radec_to_lm(source.radec, phase_dir)
    uvw = -ms.UVW.data if args.invert_uvw else ms.UVW.data

    # (source, row, frequency)
    phase = phase_delay(lm, uvw, frequency)

    # (source, spi, corrs)
    # Apply the spectral model to the Stokes parameters
    stokes = spectral_model(source.stokes,
                            source.spi,
                            source.ref_freq,
                            frequency,
                            base=0)

    brightness = convert(stokes, ["I", "Q", "U", "V"], corr_schema(pol))

    bl_jones_args = ["phase_delay", phase]

    # Add any visibility amplitude terms
    if source_type == "gauss":
        bl_jones_args.append("gauss_shape")
        bl_jones_args.append(gaussian_shape(uvw, frequency, source.shape))

    bl_jones_args.extend(["brightness", brightness])

    # Unique times and time index for each row chunk
    # The index is not global
    meta = np.empty((0, ), dtype=tuple)
    utime_inv = ms.TIME.data.map_blocks(np.unique,
                                        return_inverse=True,
                                        meta=meta,
                                        dtype=tuple)

    # Need unique times for parallactic angles
    nan_chunks = (tuple(np.nan for _ in utime_inv.chunks[0]), )
    utime = utime_inv.map_blocks(getitem,
                                 0,
                                 chunks=nan_chunks,
                                 dtype=ms.TIME.dtype)

    time_idx = utime_inv.map_blocks(getitem, 1, dtype=np.int32)

    jones = baseline_jones_multiply(corrs, *bl_jones_args)
    dde = dde_factory(args, ms, ant, field, pol, lm, utime, frequency)

    return predict_vis(time_idx, ms.ANTENNA1.data, ms.ANTENNA2.data, dde,
                       jones, dde, None, None, None)
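The unique-time handling above is a reusable dask idiom: np.unique(..., return_inverse=True) runs once per block and returns a tuple, which two getitem map_blocks calls then split into a unique-time array (of unknown chunk sizes, hence the NaN chunks) and a per-block inverse index. A minimal standalone sketch of the same pattern, assuming only numpy and dask:

import numpy as np
import dask.array as da
from operator import getitem

time = da.from_array(np.array([0.0, 0.0, 1.0, 1.0, 2.0, 2.0]), chunks=3)

# Each block yields a (unique_times, inverse_index) tuple
meta = np.empty((0,), dtype=tuple)
utime_inv = time.map_blocks(np.unique, return_inverse=True,
                            meta=meta, dtype=tuple)

# Unique times per block: chunk sizes are unknown until compute time
nan_chunks = (tuple(np.nan for _ in utime_inv.chunks[0]),)
utime = utime_inv.map_blocks(getitem, 0, chunks=nan_chunks, dtype=time.dtype)
# Per-block (not global) inverse index, as the comment above notes
time_idx = utime_inv.map_blocks(getitem, 1, dtype=np.int32)

print(utime.compute())     # [0. 1. 1. 2.] -- block uniques, concatenated
print(time_idx.compute())  # [0 0 1 0 1 1] -- indices local to each block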
Example #2
def vis_factory(args, source_type, sky_model, time_index, ms, field, spw, pol):
    try:
        source = sky_model[source_type]
    except KeyError:
        raise ValueError("Source type '%s' unsupported" % source_type)

    # Select single dataset rows
    corrs = pol.NUM_CORR.data[0]
    frequency = spw.CHAN_FREQ.data[0]
    phase_dir = field.PHASE_DIR.data[0][0]  # row, poly

    lm = radec_to_lm(source.radec, phase_dir)
    uvw = -ms.UVW.data if args.invert_uvw else ms.UVW.data

    # (source, row, frequency)
    phase = phase_delay(lm, uvw, frequency)

    # (source, spi, corrs)
    # Apply the spectral model to the Stokes parameters
    stokes = spectral_model(source.stokes,
                            source.spi,
                            source.ref_freq,
                            frequency,
                            base=[1, 0, 0, 0])

    brightness = convert(stokes, ["I", "Q", "U", "V"], corr_schema(pol))

    args = ["phase_delay", phase]

    # Add any visibility amplitude terms
    if source_type == "gauss":
        args.append("gauss_shape")
        args.append(gaussian_shape(uvw, frequency, source.shape))

    args.extend(["brightness", brightness])

    jones = baseline_jones_multiply(corrs, *args)

    return predict_vis(time_index, ms.ANTENNA1.data, ms.ANTENNA2.data, None,
                       jones, None, None, None, None)
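Compared with Example #1, this variant receives time_index from the caller and passes no direction-dependent (DDE) Jones terms, so every slot around source_coh is None. A minimal runnable sketch of that call shape, assuming the africanus.rime.dask.predict_vis entry point and the (time_index, antenna1, antenna2, dde1_jones, source_coh, dde2_jones, die1_jones, base_vis, die2_jones) argument order from the codex-africanus documentation:

import numpy as np
import dask.array as da
from africanus.rime.dask import predict_vis

# Tiny synthetic problem: 2 antennas, 3 rows, 1 source, 4 channels
time_index = da.from_array(np.array([0, 0, 1]), chunks=3)
antenna1 = da.from_array(np.array([0, 0, 0]), chunks=3)
antenna2 = da.from_array(np.array([1, 1, 1]), chunks=3)
# (source, row, chan, corr1, corr2) source coherencies
source_coh = da.ones((1, 3, 4, 2, 2), chunks=(1, 3, 4, 2, 2),
                     dtype=np.complex64)

# No DDE/DIE terms and no base visibilities: only source_coh is summed
vis = predict_vis(time_index, antenna1, antenna2, None, source_coh,
                  None, None, None, None)
print(vis.compute().shape)  # (3, 4, 2, 2)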
Example #3
def predict(args):
    # get inclusion regions
    include_regions = []
    exclude_regions = []
    if args.within:
        from regions import read_ds9
        import tempfile
        # kludge because regions cries over "FK5", wants lowercase
        with tempfile.NamedTemporaryFile(mode="w") as tmpfile, open(
                args.within) as regfile:
            tmpfile.write(regfile.read().lower())
            tmpfile.flush()
            include_regions = read_ds9(tmpfile.name)
            log.info("read {} inclusion region(s) from {}".format(
                len(include_regions), args.within))

    # Import source data from WSClean component list
    # See https://sourceforge.net/p/wsclean/wiki/ComponentList
    (comp_type, radec, stokes, spec_coeff, ref_freq, log_spec_ind,
     gaussian_shape) = import_from_wsclean(args.sky_model,
                                           include_regions=include_regions,
                                           exclude_regions=exclude_regions,
                                           point_only=args.points_only,
                                           num=args.num_sources or None)

    # Get the support tables
    tables = support_tables(
        args, ["FIELD", "DATA_DESCRIPTION", "SPECTRAL_WINDOW", "POLARIZATION"])

    field_ds = tables["FIELD"]
    ddid_ds = tables["DATA_DESCRIPTION"]
    spw_ds = tables["SPECTRAL_WINDOW"]
    pol_ds = tables["POLARIZATION"]
    frequencies = np.sort(np.concatenate(
        [spw_ds[dd].CHAN_FREQ.data.flatten() for dd in range(len(spw_ds))]))

    # Cluster sources and refit. This only works for delta-scale sources
    def __cluster(comp_type, radec, stokes, spec_coeff, ref_freq, log_spec_ind,
                  gaussian_shape, frequencies):
        uniq_radec = np.unique(radec)
        ncomp_type = []
        nradec = []
        nstokes = []
        nspec_coef = []
        nref_freq = []
        nlog_spec_ind = []
        ngaussian_shape = []

        for urd in uniq_radec:
            print(comp_type.shape)
            print(radec.shape)
            deltasel = comp_type[radec == urd] == "POINT"
            # Ordinary (non-logarithmic) polynomial sources at this position
            polyspecsel = np.logical_not(log_spec_ind[radec == urd])
            sel = deltasel & polyspecsel

            Is = stokes[sel, 0, None] * frequencies[None, :]**0
            for jj in range(spec_coeff.shape[1]):
                Is += spec_coeff[sel, jj, None] * (
                    frequencies[None, :] / ref_freq[sel, None] - 1)**(jj + 1)
            # Collapse over all the sources at this position
            Ispoly = np.sum(Is, axis=0)

            # Logarithmic polynomial sources at this position
            logpolyspecsel = log_spec_ind[radec == urd]
            sel = deltasel & logpolyspecsel

            Is = np.log(stokes[sel, 0, None] * frequencies[None, :]**0)
            for jj in range(spec_coeff.shape[1]):
                Is += spec_coeff[sel, jj, None] * np.log(
                    (frequencies[None, :] / ref_freq[sel, None])**(jj + 1))
            Is = np.exp(Is)
            # Collapse over all the sources at this position
            Islogpoly = np.sum(Is, axis=0)

            popt, pfitvar = curve_fit(
                lambda nu, i, a, b, c, d: i + a *
                (nu / ref_freq[0] - 1) + b *
                (nu / ref_freq[0] - 1)**2 + c *
                (nu / ref_freq[0] - 1)**3 + d *
                (nu / ref_freq[0] - 1)**4,
                frequencies, Ispoly + Islogpoly)
            if not np.all(np.isfinite(pfitvar)):
                popt[0] = np.sum(stokes[sel, 0, None], axis=0)
                popt[1:] = np.inf
                log.warning(
                    "Refitting at position {0:s} failed. Assuming flat spectrum source of {1:.2f} Jy"
                    .format(str(urd), popt[0]))
            else:
                pcov = np.sqrt(np.diag(pfitvar))
                log.info(
                    "New fitted flux {0:.3f} Jy at position {1:s} with covariance {2:s}"
                    .format(popt[0], str(urd),
                            ", ".join([str(pcovp) for pcovp in pcov])))

            ncomp_type.append("POINT")
            nradec.append(urd)
            nstokes.append(popt[0])
            nspec_coef.append(popt[1:])
            nref_freq.append(ref_freq[0])
            nlog_spec_ind.append(False)  # the refit is an ordinary polynomial

        # add back all the gaussians
        sel = comp_type == "GAUSSIAN"
        for rd, stks, spec, ref, lspec, gs in zip(radec[sel], stokes[sel],
                                                  spec_coeff[sel],
                                                  ref_freq[sel],
                                                  log_spec_ind[sel],
                                                  gaussian_shape[sel]):
            ncomp_type.append("GAUSSIAN")
            nradec.append(rd)
            nstokes.append(stks)
            nspec_coef.append(spec)
            nref_freq.append(ref)
            nlog_spec_ind.append(lspec)
            ngaussian_shape.append(gs)

        log.info(
            "Reduced {0:d} components to {1:d} components by refitting"
            .format(len(comp_type), len(ncomp_type)))
        return (np.array(ncomp_type), np.array(nradec), np.array(nstokes),
                np.array(nspec_coef), np.array(nref_freq),
                np.array(nlog_spec_ind), np.array(ngaussian_shape))

    if not args.dontcluster:
        (comp_type, radec, stokes, spec_coeff, ref_freq, log_spec_ind,
         gaussian_shape) = __cluster(comp_type, radec, stokes, spec_coeff,
                                     ref_freq, log_spec_ind, gaussian_shape,
                                     frequencies)

    # Add output column if it isn't present
    ms_rows, ms_datatype = ms_preprocess(args)

    # sort out resources
    args.row_chunks, args.model_chunks = get_budget(
        comp_type.shape[0], ms_rows,
        max([ss.NUM_CHAN.data[0] for ss in spw_ds]),
        max([ss.NUM_CORR.data[0] for ss in pol_ds]), ms_datatype, args)

    radec = da.from_array(radec, chunks=(args.model_chunks, 2))
    stokes = da.from_array(stokes, chunks=(args.model_chunks, 4))

    if np.count_nonzero(comp_type == 'GAUSSIAN') > 0:
        gaussian_components = True
        gshape_chunks = (args.model_chunks, 3)
        gaussian_shape = da.from_array(gaussian_shape, chunks=gshape_chunks)
    else:
        gaussian_components = False

    if args.spectra:
        spec_chunks = (args.model_chunks, spec_coeff.shape[1])
        spec_coeff = da.from_array(spec_coeff, chunks=spec_chunks)
        ref_freq = da.from_array(ref_freq, chunks=(args.model_chunks, ))

    # List of write operations
    writes = []

    # Construct a graph for each DATA_DESC_ID
    for xds in xds_from_ms(args.ms,
                           columns=["UVW", "ANTENNA1", "ANTENNA2", "TIME"],
                           group_cols=["FIELD_ID", "DATA_DESC_ID"],
                           chunks={"row": args.row_chunks}):
        if xds.attrs['FIELD_ID'] != args.fieldid:
            continue

        # Extract frequencies from the spectral window associated
        # with this data descriptor id
        field = field_ds[xds.attrs['FIELD_ID']]
        ddid = ddid_ds[xds.attrs['DATA_DESC_ID']]
        spw = spw_ds[ddid.SPECTRAL_WINDOW_ID.values]
        pol = pol_ds[ddid.POLARIZATION_ID.values]
        frequency = spw.CHAN_FREQ.data

        corrs = pol.NUM_CORR.values

        lm = radec_to_lm(radec, field.PHASE_DIR.data)

        if args.exp_sign_convention == 'casa':
            uvw = -xds.UVW.data
        elif args.exp_sign_convention == 'thompson':
            uvw = xds.UVW.data
        else:
            raise ValueError("Invalid sign convention '%s'" % args.sign)

        if args.spectra:
            # flux density at reference frequency ...
            # ... for logarithmic polynomial functions
            if log_spec_ind:
                Is = da.log(stokes[:, 0, None]) * frequency[None, :]**0
                # ... or for ordinary polynomial functions
            else:
                Is = stokes[:, 0, None] * frequency[None, :]**0
            # additional terms of SED ...
            for jj in range(spec_coeff.shape[1]):
                # ... for logarithmic polynomial functions
                if log_spec_ind:
                    Is += spec_coeff[:, jj, None] * da.log(
                        (frequency[None, :] / ref_freq[:, None])**(jj + 1))
                    # ... or for ordinary polynomial functions
                else:
                    Is += spec_coeff[:, jj, None] * (
                        frequency[None, :] / ref_freq[:, None] - 1)**(jj + 1)
            if log_spec_ind:
                Is = da.exp(Is)
            Qs = da.zeros_like(Is)
            Us = da.zeros_like(Is)
            Vs = da.zeros_like(Is)
            spectrum = da.stack(
                [Is, Qs, Us, Vs], axis=-1
            )  # stack along new axis and make it the last axis of the new array
            spectrum = spectrum.rechunk(spectrum.chunks[:2] +
                                        (spectrum.shape[2], ))

        log.info('-------------------------------------------')
        log.info('Nr sources        = {0:d}'.format(stokes.shape[0]))
        log.info('-------------------------------------------')
        log.info('stokes.shape      = {0:}'.format(stokes.shape))
        log.info('frequency.shape   = {0:}'.format(frequency.shape))
        if args.spectra:
            log.info('Is.shape          = {0:}'.format(Is.shape))
        if args.spectra:
            log.info('spectrum.shape    = {0:}'.format(spectrum.shape))

        # (source, row, frequency)
        phase = phase_delay(lm, uvw, frequency)
        # If at least one Gaussian component is present in the component list then all
        # sources are modelled as Gaussian components (Delta components have zero width)
        if gaussian_components:
            phase *= gaussian(uvw, frequency, gaussian_shape)
        # (source, frequency, corr_products)
        brightness = convert(spectrum if args.spectra else stokes,
                             ["I", "Q", "U", "V"], corr_schema(pol))

        log.info('brightness.shape  = {0:}'.format(brightness.shape))
        log.info('phase.shape       = {0:}'.format(phase.shape))
        log.info('-------------------------------------------')
        log.info('Attempting phase-brightness einsum with "{0:s}"'.format(
            einsum_schema(pol, args.spectra)))

        # (source, row, frequency, corr_products)
        jones = da.einsum(einsum_schema(pol, args.spectra), phase, brightness)
        log.info('jones.shape       = {0:}'.format(jones.shape))
        log.info('-------------------------------------------')
        if gaussian_components:
            log.info('Some Gaussian sources found')
        else:
            log.info('All sources are Delta functions')
        log.info('-------------------------------------------')

        # Identify time indices
        _, time_index = da.unique(xds.TIME.data, return_inverse=True)

        # Predict visibilities
        vis = predict_vis(time_index, xds.ANTENNA1.data, xds.ANTENNA2.data,
                          None, jones, None, None, None, None)

        # Reshape (2, 2) correlation to shape (4,)
        if corrs == 4:
            vis = vis.reshape(vis.shape[:2] + (4, ))

        # Assign visibilities to MODEL_DATA array on the dataset
        model_data = xr.DataArray(vis, dims=["row", "chan", "corr"])
        xds = xds.assign(**{args.output_column: model_data})
        # Create a write to the table
        write = xds_to_table(xds, args.ms, [args.output_column])
        # Add to the list of writes
        writes.append(write)

    # Submit all graph computations in parallel
    if args.num_workers:
        with ProgressBar(), dask.config.set(num_workers=args.num_workers):
            dask.compute(writes)
    else:
        with ProgressBar():
            dask.compute(writes)
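Example #3's spectra branch evaluates WSClean-style SEDs with dask; the same arithmetic is easier to verify in plain numpy. Below is a minimal sketch under the same shape conventions (stokes is (source, 4), spec_coeff is (source, terms), ref_freq is (source,)); evaluate_sed is a hypothetical helper name, not part of the codebase above:

import numpy as np

def evaluate_sed(stokes, spec_coeff, ref_freq, frequency, log_spec_ind):
    # Mirror of the dask SED evaluation above, in numpy
    if log_spec_ind:   # logarithmic polynomial
        Is = np.log(stokes[:, 0, None]) * frequency[None, :]**0
    else:              # ordinary polynomial
        Is = stokes[:, 0, None] * frequency[None, :]**0

    # Additional SED terms
    for jj in range(spec_coeff.shape[1]):
        if log_spec_ind:
            Is += spec_coeff[:, jj, None] * np.log(
                (frequency[None, :] / ref_freq[:, None])**(jj + 1))
        else:
            Is += spec_coeff[:, jj, None] * (
                frequency[None, :] / ref_freq[:, None] - 1)**(jj + 1)

    return np.exp(Is) if log_spec_ind else Is

# A flat-spectrum 1 Jy source stays flat under either convention
stokes = np.array([[1.0, 0.0, 0.0, 0.0]])
print(evaluate_sed(stokes, np.zeros((1, 2)), np.array([1.4e9]),
                   np.linspace(1.3e9, 1.5e9, 4), log_spec_ind=False))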
Example #4
def predict(args):
    # Numpy arrays

    # Convert source data into dask arrays
    radec, stokes = parse_sky_model(args.sky_model)
    radec = da.from_array(radec, chunks=(SOURCE_CHUNKS, 2))
    stokes = da.from_array(stokes, chunks=(SOURCE_CHUNKS, 4))

    # Get the support tables
    tables = support_tables(args, ["FIELD", "DATA_DESCRIPTION",
                                   "SPECTRAL_WINDOW", "POLARIZATION"])

    field_ds = tables["FIELD"]
    ddid_ds = tables["DATA_DESCRIPTION"]
    spw_ds = tables["SPECTRAL_WINDOW"]
    pol_ds = tables["POLARIZATION"]

    # List of write operations
    writes = []

    # Construct a graph for each DATA_DESC_ID
    for xds in xds_from_ms(args.ms,
                           columns=["UVW", "ANTENNA1", "ANTENNA2", "TIME"],
                           group_cols=["FIELD_ID", "DATA_DESC_ID"],
                           chunks={"row": args.row_chunks}):

        # Extract frequencies from the spectral window associated
        # with this data descriptor id
        field = field_ds[xds.attrs['FIELD_ID']]
        ddid = ddid_ds[xds.attrs['DATA_DESC_ID']]
        spw = spw_ds[ddid.SPECTRAL_WINDOW_ID.values]
        pol = pol_ds[ddid.POLARIZATION_ID.values]
        frequency = spw.CHAN_FREQ.data

        corrs = pol.NUM_CORR.values

        lm = radec_to_lm(radec, field.PHASE_DIR.data)
        uvw = -xds.UVW.data if args.invert_uvw else xds.UVW.data

        # (source, row, frequency)
        phase = phase_delay(lm, uvw, frequency)

        brightness = convert(stokes, ["I", "Q", "U", "V"],
                             corr_schema(pol))

        # (source, row, frequency, corr1, corr2)
        jones = da.einsum(einsum_schema(pol), phase, brightness)

        # Identify time indices
        _, time_index = da.unique(xds.TIME.data, return_inverse=True)

        # Predict visibilities
        vis = predict_vis(time_index, xds.ANTENNA1.data, xds.ANTENNA2.data,
                          None, jones, None, None, None, None)

        # Reshape (2, 2) correlation to shape (4,)
        if corrs == 4:
            vis = vis.reshape(vis.shape[:2] + (4,))

        # Assign visibilities to MODEL_DATA array on the dataset
        model_data = xr.DataArray(vis, dims=["row", "chan", "corr"])
        xds = xds.assign(MODEL_DATA=model_data)
        # Create a write to the table
        write = xds_to_table(xds, args.ms, ['MODEL_DATA'])
        # Add to the list of writes
        writes.append(write)

    # Submit all graph computations in parallel
    with ProgressBar():
        dask.compute(writes)
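einsum_schema(pol) itself is not shown in these examples. For four correlations laid out as a (2, 2) Jones matrix, a plausible schema for the phase-brightness contraction commented above is "srf,sij->srfij"; the numpy sketch below only demonstrates that shape bookkeeping, not the real einsum_schema output:

import numpy as np

nsrc, nrow, nchan = 3, 5, 4
phase = np.exp(2j * np.pi * np.random.rand(nsrc, nrow, nchan))  # (s, r, f)
brightness = np.random.rand(nsrc, 2, 2).astype(np.complex128)   # (s, i, j)

# (source, row, frequency, corr1, corr2), per the comment in the example
jones = np.einsum("srf,sij->srfij", phase, brightness)
print(jones.shape)  # (3, 5, 4, 2, 2)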
Example #5
def _predict(args):
    import pkg_resources
    version = pkg_resources.get_distribution("crystalball").version
    log.info("Crystalball version {0}", version)

    # get inclusion regions
    include_regions = load_regions(args.within) if args.within else []

    # Import source data from WSClean component list
    # See https://sourceforge.net/p/wsclean/wiki/ComponentList
    source_model = import_from_wsclean(args.sky_model,
                                       include_regions=include_regions,
                                       point_only=args.points_only,
                                       num=args.num_sources or None)

    # Add output column if it isn't present
    ms_rows, ms_datatype = ms_preprocess(args)

    # Get the support tables
    tables = support_tables(
        args, ["FIELD", "DATA_DESCRIPTION", "SPECTRAL_WINDOW", "POLARIZATION"])

    field_ds = tables["FIELD"]
    ddid_ds = tables["DATA_DESCRIPTION"]
    spw_ds = tables["SPECTRAL_WINDOW"]
    pol_ds = tables["POLARIZATION"]

    max_num_chan = max([ss.NUM_CHAN.data[0] for ss in spw_ds])
    max_num_corr = max([ss.NUM_CORR.data[0] for ss in pol_ds])

    # Perform resource budgeting
    nsources = source_model.source_type.shape[0]
    args.row_chunks, args.model_chunks = get_budget(nsources, ms_rows,
                                                    max_num_chan, max_num_corr,
                                                    ms_datatype, args)

    source_model = source_model_to_dask(source_model, args.model_chunks)

    # List of write operations
    writes = []

    datasets = xds_from_ms(args.ms,
                           columns=["UVW", "ANTENNA1", "ANTENNA2", "TIME"],
                           group_cols=["FIELD_ID", "DATA_DESC_ID"],
                           chunks={"row": args.row_chunks})

    field_id = select_field_id(field_ds, args.field)

    for xds in filter_datasets(datasets, field_id):
        # Extract frequencies from the spectral window associated
        # with this data descriptor id
        field = field_ds[xds.attrs['FIELD_ID']]
        ddid = ddid_ds[xds.attrs['DATA_DESC_ID']]
        spw = spw_ds[ddid.SPECTRAL_WINDOW_ID.data[0]]
        pol = pol_ds[ddid.POLARIZATION_ID.data[0]]
        frequency = spw.CHAN_FREQ.data[0]

        lm = radec_to_lm(source_model.radec, field.PHASE_DIR.data[0][0])

        with warnings.catch_warnings():
            # Ignore dask chunk warnings emitted when going from 1D
            # inputs to a 2D space of chunks
            warnings.simplefilter('ignore', category=PerformanceWarning)
            vis = wsclean_predict(xds.UVW.data, lm, source_model.source_type,
                                  source_model.flux, source_model.spi,
                                  source_model.log_poly, source_model.ref_freq,
                                  source_model.gauss_shape, frequency)

        vis = fill_correlations(vis, pol)

        log.info('Field {0} DDID {1:d} rows {2} chans {3} corrs {4}',
                 field.NAME.values[0], xds.DATA_DESC_ID, vis.shape[0],
                 vis.shape[1], vis.shape[2])

        # Assign visibilities to MODEL_DATA array on the dataset
        xds = xds.assign(
            **{args.output_column: (("row", "chan", "corr"), vis)})
        # Create a write to the table
        write = xds_to_table(xds, args.ms, [args.output_column])
        # Add to the list of writes
        writes.append(write)

    with ExitStack() as stack:
        if sys.stdout.isatty():
            # Default progress bar in user terminal
            stack.enter_context(EstimatingProgressBar())
        else:
            # Log progress every 5 minutes
            stack.enter_context(EstimatingProgressBar(minimum=2 * 60, dt=5))

        # Submit all graph computations in parallel
        dask.compute(writes)

    log.info("Finished")
Example #6
def _predict(args):
    # get inclusion regions
    include_regions = load_regions(args.within) if args.within else []

    # Import source data from WSClean component list
    # See https://sourceforge.net/p/wsclean/wiki/ComponentList
    (comp_type, radec, stokes, spec_coeff, ref_freq, log_spec_ind,
     gaussian_shape) = import_from_wsclean(args.sky_model,
                                           include_regions=include_regions,
                                           point_only=args.points_only,
                                           num=args.num_sources or None)

    # Add output column if it isn't present
    ms_rows, ms_datatype = ms_preprocess(args)

    # Get the support tables
    tables = support_tables(
        args, ["FIELD", "DATA_DESCRIPTION", "SPECTRAL_WINDOW", "POLARIZATION"])

    field_ds = tables["FIELD"]
    ddid_ds = tables["DATA_DESCRIPTION"]
    spw_ds = tables["SPECTRAL_WINDOW"]
    pol_ds = tables["POLARIZATION"]

    max_num_chan = max([ss.NUM_CHAN.data[0] for ss in spw_ds])
    max_num_corr = max([ss.NUM_CORR.data[0] for ss in pol_ds])

    # Perform resource budgeting
    args.row_chunks, args.model_chunks = get_budget(comp_type.shape[0],
                                                    ms_rows, max_num_chan,
                                                    max_num_corr, ms_datatype,
                                                    args)

    radec = da.from_array(radec, chunks=(args.model_chunks, 2))
    stokes = da.from_array(stokes, chunks=(args.model_chunks, 4))

    if np.count_nonzero(comp_type == 'GAUSSIAN') > 0:
        gaussian_components = True
        gshape_chunks = (args.model_chunks, 3)
        gaussian_shape = da.from_array(gaussian_shape, chunks=gshape_chunks)
    else:
        gaussian_components = False

    if args.spectra:
        spec_chunks = (args.model_chunks, spec_coeff.shape[1])
        spec_coeff = da.from_array(spec_coeff, chunks=spec_chunks)
        ref_freq = da.from_array(ref_freq, chunks=(args.model_chunks, ))

    # List of write operations
    writes = []

    # Construct a graph for each FIELD and DATA DESCRIPTOR
    datasets = xds_from_ms(args.ms,
                           columns=["UVW", "ANTENNA1", "ANTENNA2", "TIME"],
                           group_cols=["FIELD_ID", "DATA_DESC_ID"],
                           chunks={"row": args.row_chunks})

    select_fields = valid_field_ids(field_ds, args.fields)

    for xds in filter_datasets(datasets, select_fields):
        # Extract frequencies from the spectral window associated
        # with this data descriptor id
        field = field_ds[xds.attrs['FIELD_ID']]
        ddid = ddid_ds[xds.attrs['DATA_DESC_ID']]
        spw = spw_ds[ddid.SPECTRAL_WINDOW_ID.data[0]]
        pol = pol_ds[ddid.POLARIZATION_ID.data[0]]
        frequency = spw.CHAN_FREQ.data[0]

        corrs = pol.NUM_CORR.values

        lm = radec_to_lm(radec, field.PHASE_DIR.data[0][0])

        if args.exp_sign_convention == 'casa':
            uvw = -xds.UVW.data
        elif args.exp_sign_convention == 'thompson':
            uvw = xds.UVW.data
        else:
            raise ValueError("Invalid sign convention '%s'" % args.sign)

        if args.spectra:
            # flux density at reference frequency ...
            # ... for logarithmic polynomial functions
            if log_spec_ind:
                Is = da.log(stokes[:, 0, None]) * frequency[None, :]**0
            # ... or for ordinary polynomial functions
            else:
                Is = stokes[:, 0, None] * frequency[None, :]**0
            # additional terms of SED ...
            for jj in range(spec_coeff.shape[1]):
                # ... for logarithmic polynomial functions
                if log_spec_ind:
                    Is += spec_coeff[:, jj, None] * \
                        da.log((frequency[None, :]/ref_freq[:, None])**(jj+1))
                # ... or for ordinary polynomial functions
                else:
                    Is += spec_coeff[:, jj, None] * \
                        (frequency[None, :]/ref_freq[:, None]-1)**(jj+1)
            if log_spec_ind:
                Is = da.exp(Is)
            Qs = da.zeros_like(Is)
            Us = da.zeros_like(Is)
            Vs = da.zeros_like(Is)
            # stack along new axis and make it the last axis of the new array
            spectrum = da.stack([Is, Qs, Us, Vs], axis=-1)
            spectrum = spectrum.rechunk(spectrum.chunks[:2] +
                                        (spectrum.shape[2], ))

        print('-------------------------------------------')
        print('Nr sources        = {0:d}'.format(stokes.shape[0]))
        print('-------------------------------------------')
        print('stokes.shape      = {0:}'.format(stokes.shape))
        print('frequency.shape   = {0:}'.format(frequency.shape))
        if args.spectra:
            print('Is.shape          = {0:}'.format(Is.shape))
        if args.spectra:
            print('spectrum.shape    = {0:}'.format(spectrum.shape))

        # (source, row, frequency)
        phase = phase_delay(lm, uvw, frequency)
        # If at least one Gaussian component is present in the component
        # list then all sources are modelled as Gaussian components
        # (Delta components have zero width)
        if gaussian_components:
            phase *= gaussian(uvw, frequency, gaussian_shape)
        # (source, frequency, corr_products)
        brightness = convert(spectrum if args.spectra else stokes,
                             ["I", "Q", "U", "V"], corr_schema(pol))

        print('brightness.shape  = {0:}'.format(brightness.shape))
        print('phase.shape       = {0:}'.format(phase.shape))
        print('-------------------------------------------')
        print('Attempting phase-brightness einsum with "{0:s}"'.format(
            einsum_schema(pol, args.spectra)))

        # (source, row, frequency, corr_products)
        jones = da.einsum(einsum_schema(pol, args.spectra), phase, brightness)
        print('jones.shape       = {0:}'.format(jones.shape))
        print('-------------------------------------------')
        if gaussian_components:
            print('Some Gaussian sources found')
        else:
            print('All sources are Delta functions')
        print('-------------------------------------------')

        # Identify time indices
        _, time_index = da.unique(xds.TIME.data, return_inverse=True)

        # Predict visibilities
        vis = predict_vis(time_index, xds.ANTENNA1.data, xds.ANTENNA2.data,
                          None, jones, None, None, None, None)

        # Reshape (2, 2) correlation to shape (4,)
        if corrs == 4:
            vis = vis.reshape(vis.shape[:2] + (4, ))

        # Assign visibilities to MODEL_DATA array on the dataset
        xds = xds.assign(
            **{args.output_column: (("row", "chan", "corr"), vis)})
        # Create a write to the table
        write = xds_to_table(xds, args.ms, [args.output_column])
        # Add to the list of writes
        writes.append(write)

    with ExitStack() as stack:
        if sys.stdout.isatty():
            # Default progress bar in user terminal
            stack.enter_context(ProgressBar())
        else:
            # Log progress every 5 minutes
            stack.enter_context(ProgressBar(minimum=2 * 60, dt=5))

        # Submit all graph computations in parallel
        dask.compute(writes)
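All six examples share the same read-assign-write skeleton around dask-ms. Stripped of the RIME specifics, the round trip looks like the hedged sketch below; "data.ms" and the DOUBLE_TIME column are made-up names, and xds_from_ms/xds_to_table come from the dask-ms package (imported as daskms; older code imported them from xarrayms):

import dask
import xarray as xr
from daskms import xds_from_ms, xds_to_table

writes = []

# One dataset per FIELD_ID/DATA_DESC_ID grouping, lazily chunked by row
for xds in xds_from_ms("data.ms",
                       columns=["TIME"],
                       group_cols=["FIELD_ID", "DATA_DESC_ID"],
                       chunks={"row": 10000}):
    # Derive a new lazy column and attach it to the dataset
    doubled = xr.DataArray(xds.TIME.data * 2, dims=["row"])
    xds = xds.assign(DOUBLE_TIME=doubled)
    # Queue the write; nothing touches disk yet
    writes.append(xds_to_table(xds, "data.ms", ["DOUBLE_TIME"]))

# Execute every queued read-compute-write graph in parallel
dask.compute(writes)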