def broadcast_and_rechunk(array_tuple, chunks=None):
    dest_array_id = max_array_id(array_tuple)
    dest_shape = array_tuple[dest_array_id].shape

    if chunks is None:
        chunks = array_tuple[dest_array_id].chunks

    broadcast_list = [None] * len(array_tuple)
    for array_id, array in enumerate(array_tuple):
        try:
            arr_chunks = array.chunks
            if arr_chunks != chunks:
                broadcast_list[array_id] = (da.broadcast_to(
                    add_dims_to_broadcast(array, dest_shape),
                    dest_shape).rechunk(chunks))
            elif dest_shape != array.shape:
                broadcast_list[array_id] = (da.broadcast_to(
                    add_dims_to_broadcast(array, dest_shape), dest_shape))
            else:
                broadcast_list[array_id] = array
        except AttributeError:
            broadcast_list[array_id] = da.from_array(
                array,
                chunks=remove_chunks_from_shape(array, dest_shape, chunks))
            if dest_shape != array.shape:
                broadcast_list[array_id] = (da.broadcast_to(
                    add_dims_to_broadcast(broadcast_list[array_id],
                                          dest_shape),
                    dest_shape).rechunk(chunks))
    return broadcast_list
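A minimal sketch of the pattern this helper wraps: broadcast an input to a destination shape with da.broadcast_to, then align its chunking with rechunk. The helpers max_array_id, add_dims_to_broadcast and remove_chunks_from_shape used above are assumed to be defined elsewhere, so the sketch sticks to plain dask calls and made-up shapes.

import dask.array as da
import numpy as np

# Hypothetical destination: a chunked 3-D dask array.
dest = da.ones((4, 6, 8), chunks=(2, 3, 4))

# A NumPy vector that should end up with the destination's shape and chunks.
vec = np.arange(8)

# Wrap, broadcast to the destination shape, then match the chunking.
vec_b = da.broadcast_to(da.from_array(vec, chunks=8), dest.shape).rechunk(dest.chunks)

assert vec_b.shape == dest.shape
assert vec_b.chunks == dest.chunks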
def _get_zhai_data_frame(datasets, lat_constraint):
    """Get :class:`pandas.DataFrame` including the data for ``zhai``."""
    cl_cube = _get_cube(datasets, 'cl')
    wap_cube = _get_cube(datasets, 'wap')
    tos_cube = _get_cube(datasets, 'tos')

    # Add air_pressure coordinate if necessary
    if not cl_cube.coords('air_pressure'):
        if cl_cube.coords('altitude'):
            add_plev_from_altitude(cl_cube)
        elif cl_cube.coords('atmosphere_sigma_coordinate'):
            add_sigma_factory(cl_cube)
        else:
            raise ValueError(f"No 'air_pressure' coord available in cube "
                             f"{cl_cube.summary(shorten=True)}")

    # Apply common mask (only ocean)
    mask_2d = da.ma.getmaskarray(tos_cube.core_data())
    mask_3d = mask_2d[:, np.newaxis, ...]
    mask_3d = da.broadcast_to(mask_3d, cl_cube.shape)
    wap_cube.data = da.ma.masked_array(wap_cube.core_data(), mask=mask_2d)
    cl_cube.data = da.ma.masked_array(cl_cube.core_data(), mask=mask_3d)

    # Calculate SST mean and MBLC fraction
    tos_cube = _get_mean_over_subsidence(tos_cube, wap_cube, lat_constraint)
    mblc_cube = _get_seasonal_mblc_fraction(cl_cube, wap_cube, lat_constraint)
    return pd.DataFrame(
        {
            'tos': tos_cube.data,
            'mblc_fraction': mblc_cube.data
        },
        index=pd.Index(np.arange(12) + 1, name='month'),
    )
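The ocean-mask step above is a common dask idiom: take the mask of the surface field, insert a level axis, and broadcast it over the 4-D cloud field. A sketch with synthetic arrays (the shapes are assumptions, not the real cube shapes):

import dask.array as da
import numpy as np

# Assumed shapes: sst (time, lat, lon), cloud fraction cl (time, plev, lat, lon).
sst = da.random.random((12, 10, 20), chunks=(3, 10, 20))
sst = da.ma.masked_where(sst < 0.3, sst)          # stand-in for a land/sea mask
cl = da.random.random((12, 5, 10, 20), chunks=(3, 5, 10, 20))

mask_2d = da.ma.getmaskarray(sst)                 # (time, lat, lon)
mask_3d = mask_2d[:, np.newaxis, ...]             # (time, 1, lat, lon)
mask_3d = da.broadcast_to(mask_3d, cl.shape)      # (time, plev, lat, lon)
cl_masked = da.ma.masked_array(cl, mask=mask_3d)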
Example #3
def compute(fieldset):
    # Calculating vertical weighted average
    for f in [fieldset.U, fieldset.V]:
        for tind in f.loaded_time_indices:
            data = da.sum(f.data[tind, :] * DZ, axis=0) / sum(dz)
            data = da.broadcast_to(data, (1, f.grid.zdim, f.grid.ydim, f.grid.xdim))
            f.data = f.data_concatenate(f.data, data, tind)
def tsLab(series, seg, label):
    segnp = sitk.GetArrayFromImage(seg)
    # The mask indicates which entries to ignore.
    m3 = da.broadcast_to(segnp != label, series.shape)
    smk = da.ma.masked_array(series, m3)
    p = smk.mean(axis=[2, 3])
    l = p.compute()
    return l
Example #5
def _mask_cube(cube, selections):
    cubelist = iris.cube.CubeList()
    for id_, select in selections.items():
        _cube = cube.copy()
        _cube.add_aux_coord(
            iris.coords.AuxCoord(id_, units='no_unit', long_name="shape_id"))
        select = da.broadcast_to(select, _cube.shape)
        _cube.data = da.ma.masked_where(~select, _cube.core_data())
        cubelist.append(_cube)
    return fix_coordinate_ordering(cubelist.merge_cube())
Example #6
def add_cell_measure(cube, fx_cube, measure):
    """
    Broadcast fx_cube and add it as a cell_measure in
    the cube containing the data.

    Parameters
    ----------
    cube: iris.cube.Cube
        Iris cube with input data.
    fx_cube: iris.cube.Cube
        Iris cube with fx data.
    measure: str
        Name of the measure, can be 'area' or 'volume'.

    Returns
    -------
    iris.cube.Cube
        Cube with added cell measure.

    Raises
    ------
    ValueError
        If measure name is not 'area' or 'volume'.
    ValueError
        If fx_cube cannot be broadcast to cube.
    """
    if measure not in ['area', 'volume']:
        raise ValueError(f"measure name must be 'area' or 'volume', "
                         f"got {measure} instead")
    try:
        fx_data = da.broadcast_to(fx_cube.core_data(), cube.shape)
    except ValueError as exc:
        raise ValueError(f"Dimensions of {cube.var_name} and "
                         f"{fx_cube.var_name} cubes do not match. "
                         "Cannot broadcast cubes.") from exc
    measure = iris.coords.CellMeasure(fx_data,
                                      standard_name=fx_cube.standard_name,
                                      units=fx_cube.units,
                                      measure=measure,
                                      var_name=fx_cube.var_name,
                                      attributes=fx_cube.attributes)
    cube.add_cell_measure(measure, range(0, measure.ndim))
    logger.debug('Added %s as cell measure in cube of %s.', fx_cube.var_name,
                 cube.var_name)
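The broadcast-or-fail logic above can be exercised directly on dask arrays: an fx field whose trailing dimensions match the cube broadcasts cleanly, anything else raises ValueError. A small sketch with made-up shapes:

import dask.array as da

cube_like = da.zeros((12, 10, 20), chunks=(1, 10, 20))
area_2d = da.ones((10, 20), chunks=(10, 20))

# Trailing-aligned 2-D fx data broadcasts to the full cube shape ...
assert da.broadcast_to(area_2d, cube_like.shape).shape == cube_like.shape

# ... while an incompatible shape raises ValueError, which the helper re-raises
# with a more descriptive message.
try:
    da.broadcast_to(da.ones((11, 20)), cube_like.shape)
except ValueError:
    pass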
Example #7
def diag_dot(diag_array, x, return_diag=False):
    """Computes dot product between diag_array and x

    Parameters
    ----------
    diag_array : array_like, shape (K, ) or (K, 1)
                Diagonal entries of a diagonal matrix
                [d1, d2, d3, ..., dk] -> [d1*e1, d2*e2, d3*e3, ..., dk*ek], where ei is the ith unit column vector
    x : array_like, shape (K, ...)
    return_diag : boolean
        If return_diag is True, return broadcasted array prepped for operation

    Returns
    -------
    out : array_like, shape of x
    """
    if len(x.shape) not in [1, 2]:
        raise ValueError(
            "x must have (M, K) or (K, ). Current Shape = {}".format(x.shape))
    if diag_array.shape[0] != x.shape[0]:
        raise ValueError(
            'shapes {} and {} not aligned: {} (dim 0 and 1) != {} (dim 0)'.
            format(diag_array.shape, x.shape, diag_array.shape[0], x.shape[0]))
    if len(diag_array.shape) not in [1, 2]:
        raise ValueError(
            'diag_array must have dimension (K, ) or (K, 1). Current shape = {}'
            .format(diag_array.shape))

    if len(x.shape) == 1:
        if len(diag_array.shape) == 2:
            d = np.squeeze(diag_array)
        else:
            d = diag_array
    else:
        if len(diag_array.shape) == 1:
            d = diag_array[:, np.newaxis]
        else:
            d = diag_array
        d = broadcast_to(d, x.shape)
    if return_diag:
        return d
    else:
        return np.multiply(d, x)
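A usage sketch for the dense case (the inputs below are made up; broadcast_to in the function body is assumed to resolve to numpy's or dask's broadcast_to):

import numpy as np

d = np.array([1.0, 2.0, 3.0])
x = np.arange(12.0).reshape(3, 4)

# Row-wise scaling: same result as np.diag(d) @ x without forming the K x K matrix.
out = np.asarray(diag_dot(d, x))   # np.asarray also materialises a dask result
assert np.allclose(out, np.diag(d) @ x)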
Example #8
    def coriolis(self, lats, ndim):
        """Compute Coriolis force.

        Parameters
        ----------
        lats: iris.coord.Coord
            Latitude coordinate.
        ndim: tuple
            Shape of the target array to broadcast to.

        Returns
        -------
        fcor: da.array
            Array containing Coriolis force.
        """
        fcor = 2.0 * self.omega * np.sin(np.radians(lats.points))
        fcor = fcor[np.newaxis, np.newaxis, :, np.newaxis]
        fcor = da.broadcast_to(fcor, ndim)

        return fcor
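The final broadcast mirrors the same idiom used elsewhere in this section: a 1-D latitude profile gets singleton time, level and longitude axes and is then expanded to the full field shape. A standalone sketch (values and shapes are assumptions):

import dask.array as da
import numpy as np

omega = 7.292e-5                                     # Earth's angular velocity [rad/s]
lats = np.linspace(-90.0, 90.0, 73)                  # hypothetical latitude points

fcor = 2.0 * omega * np.sin(np.radians(lats))        # (lat,)
fcor = fcor[np.newaxis, np.newaxis, :, np.newaxis]   # (1, 1, lat, 1)
fcor = da.broadcast_to(fcor, (12, 17, 73, 144))      # (time, plev, lat, lon)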
Example #9
def get_time_weights(cube):
    """Compute the weighting of the time axis.

    Parameters
    ----------
    cube: iris.cube.Cube
        input cube.

    Returns
    -------
    numpy.array
        Array of time weights for averaging.
    """
    time = cube.coord('time')
    time_weights = time.bounds[..., 1] - time.bounds[..., 0]
    time_weights = time_weights.squeeze()
    if time_weights.shape == ():
        time_weights = da.broadcast_to(time_weights, cube.shape)
    else:
        time_weights = iris.util.broadcast_to_shape(time_weights, cube.shape,
                                                    cube.coord_dims('time'))
    return time_weights
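A sketch of how such weights are typically consumed, namely a weighted mean over the time axis (the data and month lengths below are made up rather than derived from real time bounds):

import numpy as np

# Hypothetical monthly data (time, lat, lon) and month lengths as weights.
data = np.random.random((12, 10, 20))
weights = np.array([31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31], dtype=float)

# np.average broadcasts the 1-D weights along the time axis.
weighted_mean = np.average(data, axis=0, weights=weights)
assert weighted_mean.shape == (10, 20)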
Example #10
    def apply(X: Array, YP: Array, BX: Array, BYP: Array) -> Array:
        # Collapse selected variant blocks and alphas into single
        # new covariate dimension
        assert YP.shape[2] == BYP.shape[2]
        n_group_covar = n_covar + BYP.shape[2] * n_alpha_1

        BYP = BYP.reshape((n_outcome, n_sample_block, -1))
        BG = da.concatenate((BX, BYP), axis=-1)
        BG = BG.rechunk((-1, None, -1))
        assert_block_shape(BG, 1, n_sample_block, 1)
        assert_chunk_shape(BG, n_outcome, 1, n_group_covar)
        assert_array_shape(BG, n_outcome, n_sample_block, n_group_covar)

        YP = YP.reshape((n_outcome, n_sample, -1))
        XYP = da.broadcast_to(X, (n_outcome, n_sample, n_covar))
        XG = da.concatenate((XYP, YP), axis=-1)
        XG = XG.rechunk((-1, None, -1))
        assert_block_shape(XG, 1, n_sample_block, 1)
        assert_chunk_shape(XG, n_outcome, sample_chunks[0], n_group_covar)
        assert_array_shape(XG, n_outcome, n_sample, n_group_covar)

        YG = da.map_blocks(
            # Block chunks:
            # (n_outcome, sample_chunks[0], n_group_covar) @
            # (n_outcome, n_group_covar, 1) [after transpose]
            lambda x, b: x @ b.transpose((0, 2, 1)),
            XG,
            BG,
            chunks=(n_outcome, sample_chunks, 1),
        )
        assert_block_shape(YG, 1, n_sample_block, 1)
        assert_chunk_shape(YG, n_outcome, sample_chunks[0], 1)
        assert_array_shape(YG, n_outcome, n_sample, 1)
        YG = da.squeeze(YG, axis=-1).T
        assert_block_shape(YG, n_sample_block, 1)
        assert_chunk_shape(YG, sample_chunks[0], n_outcome)
        assert_array_shape(YG, n_sample, n_outcome)
        return YG
Example #11
def push(array, n, axis):
    """
    Dask-aware bottleneck.push
    """
    import bottleneck
    import dask.array as da
    import numpy as np

    def _fill_with_last_one(a, b):
        # cumreduction applies the push func over all the blocks first, so the only missing part is
        # filling the missing values using the last data of the previous chunk
        return np.where(~np.isnan(b), b, a)

    if n is not None and 0 < n < array.shape[axis] - 1:
        arange = da.broadcast_to(
            da.arange(array.shape[axis],
                      chunks=array.chunks[axis],
                      dtype=array.dtype).reshape(
                          tuple(size if i == axis else 1
                                for i, size in enumerate(array.shape))),
            array.shape,
            array.chunks,
        )
        valid_arange = da.where(da.notnull(array), arange, np.nan)
        valid_limits = (arange - push(valid_arange, None, axis)) <= n
        # omit the forward fills that violate the limit
        return da.where(valid_limits, push(array, None, axis), np.nan)

    # The method parameter makes the tests for Python 3.7 fail.
    return da.reductions.cumreduction(
        func=bottleneck.push,
        binop=_fill_with_last_one,
        ident=np.nan,
        x=array,
        axis=axis,
        dtype=array.dtype,
    )
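A quick usage sketch of the limited forward fill (requires dask and bottleneck; the input is made up):

import dask.array as da
import numpy as np

x = da.from_array(np.array([1.0, np.nan, np.nan, np.nan, 5.0, np.nan]), chunks=3)

# Forward-fill at most two consecutive NaNs along axis 0.
filled = push(x, n=2, axis=0).compute()
# Expected result: [1., 1., 1., nan, 5., 5.]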
Example #12
def add_ancillary_variable(cube, fx_cube):
    """
    Broadcast fx_cube and add it as an ancillary_variable in
    the cube containing the data.

    Parameters
    ----------
    cube: iris.cube.Cube
        Iris cube with input data.
    fx_cube: iris.cube.Cube
        Iris cube with fx data.

    Returns
    -------
    iris.cube.Cube
        Cube with added ancillary variables

    Raises
    ------
    ValueError
        If fx_cube cannot be broadcast to cube.
    """
    try:
        fx_data = da.broadcast_to(fx_cube.core_data(), cube.shape)
    except ValueError as exc:
        raise ValueError(f"Dimensions of {cube.var_name} and "
                         f"{fx_cube.var_name} cubes do not match. "
                         "Cannot broadcast cubes.") from exc
    ancillary_var = iris.coords.AncillaryVariable(
        fx_data,
        standard_name=fx_cube.standard_name,
        units=fx_cube.units,
        var_name=fx_cube.var_name,
        attributes=fx_cube.attributes)
    cube.add_ancillary_variable(ancillary_var, range(0, ancillary_var.ndim))
    logger.debug('Added %s as ancillary variable in cube of %s.',
                 fx_cube.var_name, cube.var_name)
Example #13
def _psf(**kw):
    args = OmegaConf.create(kw)
    from omegaconf import ListConfig
    if not isinstance(args.ms, list) and not isinstance(args.ms, ListConfig):
        args.ms = [args.ms]
    OmegaConf.set_struct(args, True)

    import numpy as np
    from pfb.utils.misc import chan_to_band_mapping
    import dask
    # from dask.distributed import performance_report
    from dask.graph_manipulation import clone
    from daskms import xds_from_storage_ms as xds_from_ms
    from daskms import xds_from_storage_table as xds_from_table
    from daskms import Dataset
    from daskms.experimental.zarr import xds_to_zarr
    import dask.array as da
    from africanus.constants import c as lightspeed
    from africanus.gridding.wgridder.dask import dirty as vis2im
    from ducc0.fft import good_size
    from pfb.utils.misc import stitch_images, plan_row_chunk
    from pfb.utils.fits import set_wcs, save_fits

    # chan <-> band mapping
    ms = args.ms
    nband = args.nband
    freqs, freq_bin_idx, freq_bin_counts, freq_out, band_mapping, chan_chunks = chan_to_band_mapping(
        ms, nband=nband)

    # gridder memory budget
    max_chan_chunk = 0
    max_freq = 0
    for ims in args.ms:
        for spw in freqs[ims]:
            counts = freq_bin_counts[ims][spw].compute()
            freq = freqs[ims][spw].compute()
            max_chan_chunk = np.maximum(max_chan_chunk, counts.max())
            max_freq = np.maximum(max_freq, freq.max())

    # assumes measurement sets have the same columns,
    # number of correlations etc.
    xds = xds_from_ms(args.ms[0])
    ncorr = xds[0].dims['corr']
    nrow = xds[0].dims['row']
    # we still have to cater for complex valued data because we cast
    # the weights to complex, but we no longer need to factor the
    # weight column into our memory budget
    data_bytes = getattr(xds[0], args.data_column).data.itemsize
    bytes_per_row = max_chan_chunk * ncorr * data_bytes
    memory_per_row = bytes_per_row

    # flags (uint8 or bool)
    memory_per_row += bytes_per_row / 8

    # UVW
    memory_per_row += xds[0].UVW.data.itemsize * 3

    # ANTENNA1/2
    memory_per_row += xds[0].ANTENNA1.data.itemsize * 2

    # TIME
    memory_per_row += xds[0].TIME.data.itemsize

    # data column is not actually read into memory just used to infer
    # dtype and chunking
    columns = (args.data_column, args.weight_column, args.flag_column, 'UVW',
               'ANTENNA1', 'ANTENNA2', 'TIME')

    # flag row
    if 'FLAG_ROW' in xds[0]:
        columns += ('FLAG_ROW', )
        memory_per_row += xds[0].FLAG_ROW.data.itemsize

    # imaging weights
    if args.imaging_weight_column is not None:
        columns += (args.imaging_weight_column, )
        memory_per_row += bytes_per_row / 2

    # Mueller term (complex valued)
    if args.mueller_column is not None:
        columns += (args.mueller_column, )
        memory_per_row += bytes_per_row

    # get max uv coords over all fields
    uvw = []
    u_max = 0.0
    v_max = 0.0
    for ims in args.ms:
        xds = xds_from_ms(ims, columns=('UVW',), chunks={'row': -1})

        for ds in xds:
            uvw = ds.UVW.data
            u_max = da.maximum(u_max, abs(uvw[:, 0]).max())
            v_max = da.maximum(v_max, abs(uvw[:, 1]).max())
            uv_max = da.maximum(u_max, v_max)

    uv_max = uv_max.compute()
    del uvw

    # image size
    cell_N = 1.0 / (2 * uv_max * max_freq / lightspeed)

    if args.cell_size is not None:
        cell_size = args.cell_size
        cell_rad = cell_size * np.pi / 60 / 60 / 180
        if cell_N / cell_rad < 1:
            raise ValueError(
                "Requested cell size too small. "
                "Super resolution factor = ", cell_N / cell_rad)
        print("Super resolution factor = %f" % (cell_N / cell_rad), file=log)
    else:
        cell_rad = cell_N / args.super_resolution_factor
        cell_size = cell_rad * 60 * 60 * 180 / np.pi
        print("Cell size set to %5.5e arcseconds" % cell_size, file=log)

    if args.nx is None:
        fov = args.field_of_view * 3600
        npix = int(args.psf_oversize * fov / cell_size)
        if npix % 2:
            npix += 1
        nx = npix
        ny = npix
    else:
        nx = args.nx
        ny = args.ny if args.ny is not None else nx

    print("PSF size set to (%i, %i, %i)" % (nband, nx, ny), file=log)

    # get approx image size
    # this is not a conservative estimate when multiple SPW's map to a single
    # imaging band
    pixel_bytes = np.dtype(args.output_type).itemsize
    band_size = nx * ny * pixel_bytes

    if args.host_address is None:
        # full image on single node
        row_chunk = plan_row_chunk(args.mem_limit / args.nworkers, band_size,
                                   nrow, memory_per_row,
                                   args.nthreads_per_worker)

    else:
        # single band per node
        row_chunk = plan_row_chunk(args.mem_limit, band_size, nrow,
                                   memory_per_row, args.nthreads_per_worker)

    if args.row_chunks is not None:
        row_chunk = int(args.row_chunks)
        if row_chunk == -1:
            row_chunk = nrow

    print(
        "nrows = %i, row chunks set to %i for a total of %i chunks per node" %
        (nrow, row_chunk, int(np.ceil(nrow / row_chunk))),
        file=log)

    chunks = {}
    for ims in args.ms:
        chunks[ims] = []  # xds_from_ms expects a list per ds
        for spw in freqs[ims]:
            chunks[ims].append({
                'row': row_chunk,
                'chan': chan_chunks[ims][spw]['chan']
            })

    psfs = []
    radec = None  # assumes we are only imaging field 0 of first MS
    out_datasets = []
    for ims in args.ms:
        xds = xds_from_ms(ims, chunks=chunks[ims], columns=columns)

        # subtables
        ddids = xds_from_table(ims + "::DATA_DESCRIPTION")
        fields = xds_from_table(ims + "::FIELD")
        spws = xds_from_table(ims + "::SPECTRAL_WINDOW")
        pols = xds_from_table(ims + "::POLARIZATION")

        # subtable data
        ddids = dask.compute(ddids)[0]
        fields = dask.compute(fields)[0]
        spws = dask.compute(spws)[0]
        pols = dask.compute(pols)[0]

        for ds in xds:
            field = fields[ds.FIELD_ID]

            # check fields match
            if radec is None:
                radec = field.PHASE_DIR.data.squeeze()

            if not np.array_equal(radec, field.PHASE_DIR.data.squeeze()):
                continue

            # this is not correct, need to use spw
            spw = ds.DATA_DESC_ID

            uvw = clone(ds.UVW.data)

            data_type = getattr(ds, args.data_column).data.dtype
            data_shape = getattr(ds, args.data_column).data.shape
            data_chunks = getattr(ds, args.data_column).data.chunks

            weights = getattr(ds, args.weight_column).data
            if len(weights.shape) < 3:
                weights = da.broadcast_to(weights[:, None, :],
                                          data_shape,
                                          chunks=data_chunks)

            if args.imaging_weight_column is not None:
                imaging_weights = getattr(ds, args.imaging_weight_column).data
                if len(imaging_weights.shape) < 3:
                    imaging_weights = da.broadcast_to(imaging_weights[:,
                                                                      None, :],
                                                      data_shape,
                                                      chunks=data_chunks)

                weightsxx = imaging_weights[:, :, 0] * weights[:, :, 0]
                weightsyy = imaging_weights[:, :, -1] * weights[:, :, -1]
            else:
                weightsxx = weights[:, :, 0]
                weightsyy = weights[:, :, -1]

            # apply mueller term
            if args.mueller_column is not None:
                mueller = getattr(ds, args.mueller_column).data
                weightsxx *= da.absolute(mueller[:, :, 0])**2
                weightsyy *= da.absolute(mueller[:, :, -1])**2

            # weighted sum corr to Stokes I
            weights = weightsxx + weightsyy

            # MS may contain auto-correlations
            if 'FLAG_ROW' in xds[0]:
                frow = ds.FLAG_ROW.data | (ds.ANTENNA1.data
                                           == ds.ANTENNA2.data)
            else:
                frow = (ds.ANTENNA1.data == ds.ANTENNA2.data)

            # only keep data where both corrs are unflagged
            flag = getattr(ds, args.flag_column).data
            flagxx = flag[:, :, 0]
            flagyy = flag[:, :, -1]
            # ducc0 uses uint8 mask not flag
            mask = ~da.logical_or((flagxx | flagyy), frow[:, None])

            psf = vis2im(uvw,
                         freqs[ims][spw],
                         weights.astype(data_type),
                         freq_bin_idx[ims][spw],
                         freq_bin_counts[ims][spw],
                         nx,
                         ny,
                         cell_rad,
                         flag=mask.astype(np.uint8),
                         nthreads=args.nvthreads,
                         epsilon=args.epsilon,
                         do_wstacking=args.wstack,
                         double_accum=args.double_accum)

            psfs.append(psf)

            data_vars = {
                'FIELD_ID': (('row', ),
                             da.full_like(ds.TIME.data,
                                          ds.FIELD_ID,
                                          chunks=args.row_out_chunk)),
                'DATA_DESC_ID': (('row', ),
                                 da.full_like(ds.TIME.data,
                                              ds.DATA_DESC_ID,
                                              chunks=args.row_out_chunk)),
                'WEIGHT':
                (('row', 'chan'), weights.rechunk({0: args.row_out_chunk
                                                   })),  # why no 'f4'?
                'UVW': (('row', 'uvw'), uvw.rechunk({0: args.row_out_chunk}))
            }

            coords = {'chan': (('chan', ), freqs[ims][spw])}

            out_ds = Dataset(data_vars, coords)

            out_datasets.append(out_ds)

    writes = xds_to_zarr(out_datasets,
                         args.output_filename + '.zarr',
                         columns='ALL')

    # dask.visualize(writes, filename=args.output_filename + '_psf_writes_graph.pdf', optimize_graph=False)
    # dask.visualize(psfs, filename=args.output_filename + '_psf_graph.pdf', optimize_graph=False)

    if not args.mock:
        # psfs = dask.compute(psfs, writes, optimize_graph=False)[0]
        # with performance_report(filename=args.output_filename + '_psf_per.html'):
        psfs = dask.compute(psfs, writes, optimize_graph=False)[0]

        psf = stitch_images(psfs, nband, band_mapping)

        hdr = set_wcs(cell_size / 3600, cell_size / 3600, nx, ny, radec,
                      freq_out)
        save_fits(args.output_filename + '_psf.fits',
                  psf,
                  hdr,
                  dtype=args.output_type)

        psf_mfs = np.sum(psf, axis=0)
        wsum = psf_mfs.max()
        psf_mfs /= wsum

        hdr_mfs = set_wcs(cell_size / 3600, cell_size / 3600, nx, ny, radec,
                          np.mean(freq_out))
        save_fits(args.output_filename + '_psf_mfs.fits',
                  psf_mfs,
                  hdr_mfs,
                  dtype=args.output_type)

    print("All done here.", file=log)
Example #14
def get_plot_data(msinfo, group_cols, mytaql, chan_freqs,
                  chanslice, subset,
                  noflags, noconj,
                  iter_field, iter_spw, iter_scan,
                  join_corrs=False,
                  row_chunk_size=100000):

    ms_cols = {'ANTENNA1', 'ANTENNA2'}
    if not noflags:
        ms_cols.update({'FLAG', 'FLAG_ROW'})
    # get visibility columns
    for axis in DataAxis.all_axes.values():
        ms_cols.update(axis.columns)

    # get MS data
    msdata = daskms.xds_from_ms(msinfo.msname, columns=list(ms_cols), group_cols=group_cols, taql_where=mytaql,
                                chunks=dict(row=row_chunk_size))

    log.info(f': Indexing MS and building dataframes (chunk size is {row_chunk_size})')

    np = 0  # number of points to plot

    # output dataframes, indexed by (field, spw, scan, antenna, correlation)
    # If any of these axes is not being iterated over, then the index is None
    output_dataframes = OrderedDict()

    # # make prototype dataframe
    # import pandas
    #
    #

    # iterate over groups
    for group in msdata:
        ddid     =  group.DATA_DESC_ID  # always present
        fld      =  group.FIELD_ID # always present
        if fld not in subset.field or ddid not in subset.spw:
            log.debug(f"field {fld} ddid {ddid} not in selection, skipping")
            continue
        scan    = getattr(group, 'SCAN_NUMBER', None)  # will be present if iterating over scans

        # TODO: antenna iteration. None forces no iteration, for now
        antenna = None

        # always read flags -- easier that way
        flag = group.FLAG if not noflags else None
        flag_row = group.FLAG_ROW if not noflags else None


        baselines = group.ANTENNA1*len(msinfo.antenna) + group.ANTENNA2

        freqs = chan_freqs[ddid]
        chans = xarray.DataArray(range(len(freqs)), dims=("chan",))
        wavel = freq_to_wavel(freqs)
        extras = dict(chans=chans, freqs=freqs, wavel=wavel, rows=group.row, baselines=baselines)

        nchan = len(group.chan)
        if flag is not None:
            flag = flag[dict(chan=chanslice)]
            nchan = flag.shape[1]
        shape = (len(group.row), nchan)

        datums = OrderedDict()

        for corr in subset.corr.numbers:
            # make dictionary of extra values for DataMappers
            extras['corr'] = corr
            # loop over datums to be computed
            for axis in DataAxis.all_axes.values():
                value = datums[axis.label][-1] if axis.label in datums else None
                # a datum was already computed?
                if value is not None:
                    # if not joining correlations, then that's the only one we'll need, so continue
                    if not join_corrs:
                        continue
                    # joining correlations, and datum has a correlation dependence: compute another one
                    if axis.corr is None:
                        value = None
                if value is None:
                    value = axis.get_value(group, corr, extras, flag=flag, flag_row=flag_row, chanslice=chanslice)
                    # reshape values of shape NTIME to (NTIME,1) and NFREQ to (1,NFREQ), and scalar to (NTIME,1)
                    if value.ndim == 1:
                        timefreq_axis = axis.mapper.axis or 0
                        assert value.shape[0] == shape[timefreq_axis], \
                               f"{axis.mapper.fullname}: size {value.shape[0]}, expected {shape[timefreq_axis]}"
                        shape1 = [1,1]
                        shape1[timefreq_axis] = value.shape[0]
                        value = value.reshape(shape1)
                        if timefreq_axis > 0:
                            value = da.broadcast_to(value, shape)
                        log.debug(f"axis {axis.mapper.fullname} has shape {value.shape}")
                    # else 2D value better match expected shape
                    else:
                        assert value.shape == shape, f"{axis.mapper.fullname}: shape {value.shape}, expected {shape}"
                datums.setdefault(axis.label, []).append(value)

        # if joining correlations, stick all elements together. Otherwise, we'd better have one per label
        if join_corrs:
            datums = OrderedDict({label: da.concatenate(arrs) for label, arrs in datums.items()})
        else:
            assert all([len(arrs) == 1 for arrs in datums.values()])
            datums = OrderedDict({label: arrs[0] for label, arrs in datums.items()})

        # broadcast to same shape, and unravel all datums
        datums = OrderedDict({ key: arr.ravel() for key, arr in zip(datums.keys(),
                                                                    da.broadcast_arrays(*datums.values()))})

        # if any axis needs to be conjugated, double up all of them
        if not noconj and any([axis.conjugate for axis in DataAxis.all_axes.values()]):
            for axis in DataAxis.all_axes.values():
                if axis.conjugate:
                    datums[axis.label] = da.concatenate([datums[axis.label], -datums[axis.label]])
                else:
                    datums[axis.label] = da.concatenate([datums[axis.label], datums[axis.label]])

        labels, values = list(datums.keys()), list(datums.values())
        np += values[0].size

        # now stack them all into a big dataframe
        rectype = [(axis.label, numpy.int32 if axis.nlevels else numpy.float32) for axis in DataAxis.all_axes.values()]
        recarr = da.empty_like(values[0], dtype=rectype)
        ddf = dask_df.from_array(recarr)
        for label, value in zip(labels, values):
            ddf[label] = value

        # now, are we iterating or concatenating? Make frame key accordingly
        dataframe_key = (fld if iter_field else None,
                         ddid if iter_spw else None,
                         scan if iter_scan else None,
                         antenna)

        # do we already have a frame for this key
        ddf0 = output_dataframes.get(dataframe_key)

        if ddf0 is None:
            log.debug(f"first frame for {dataframe_key}")
            output_dataframes[dataframe_key] = ddf
        else:
            log.debug(f"appending to frame for {dataframe_key}")
            output_dataframes[dataframe_key] = ddf0.append(ddf)

    # convert discrete axes into categoricals
    if data_mappers.USE_COUNT_CAT:
        categorical_axes = [axis.label for axis in DataAxis.all_axes.values() if axis.nlevels]
        if categorical_axes:
            log.info(": counting colours")
            for key, ddf in list(output_dataframes.items()):
                output_dataframes[key] = ddf.categorize(categorical_axes)

    log.info(": complete")
    return output_dataframes, np
Example #15
def extract_shape(cube,
                  shapefile,
                  method='contains',
                  crop=True,
                  decomposed=False):
    """
    Extract a region defined by a shapefile.

    Note that this function does not work for shapes crossing the
    prime meridian or poles.

    Parameters
    ----------
    cube: iris.cube.Cube
       input cube.
    shapefile: str
        A shapefile defining the region(s) to extract.
    method: str, optional
        Select all points contained by the shape or select a single
        representative point. Choose either 'contains' or 'representative'.
        If 'contains' is used, but not a single grid point is contained by the
        shape, a representative point will be selected.
    crop: bool, optional
        Crop the resulting cube using `extract_region()`. Note that data on
        irregular grids will not be cropped.
    decomposed: bool, optional
        Whether or not to retain the sub shapes of the shapefile in the output.
        If this is set to True, the output cube has a dimension for the sub
        shapes.

    Returns
    -------
    iris.cube.Cube
        Cube containing the extracted region.

    See Also
    --------
    extract_region : Extract a region from a cube.
    """
    with fiona.open(shapefile) as geometries:

        # get parameters specific to the shapefile (typical Natural Earth
        # use cases, e.g. longitudes in [-180, 180], missing latitudes,
        # or overflowing edges)
        cmor_coords = True
        pad_north_pole = False
        pad_hawaii = False
        if geometries.bounds[0] < 0:
            cmor_coords = False
        if geometries.bounds[1] > -90. and geometries.bounds[1] < -85.:
            pad_north_pole = True
        if geometries.bounds[0] > -180. and geometries.bounds[0] < 179.:
            pad_hawaii = True

        if crop:
            cube = _crop_cube(cube,
                              *geometries.bounds,
                              cmor_coords=cmor_coords)

        lon, lat = _correct_coords_from_shapefile(cube, cmor_coords,
                                                  pad_north_pole, pad_hawaii)

        selections = _get_masks_from_geometries(geometries,
                                                lon,
                                                lat,
                                                method=method,
                                                decomposed=decomposed)

    cubelist = iris.cube.CubeList()

    for id_, select in selections.items():
        _cube = cube.copy()
        _cube.add_aux_coord(
            iris.coords.AuxCoord(id_, units='no_unit', long_name="shape_id"))

        select = da.broadcast_to(select, _cube.shape)
        _cube.data = da.ma.masked_where(~select, _cube.core_data())
        cubelist.append(_cube)

    cube = cubelist.merge_cube()

    return fix_coordinate_ordering(cube)
Example #16
def extract_region(cube, start_longitude, end_longitude, start_latitude,
                   end_latitude):
    """
    Extract a region from a cube.

    Function that subsets a cube on a box (start_longitude, end_longitude,
    start_latitude, end_latitude)
    This function is a restriction of masked_cube_lonlat().

    Parameters
    ----------
    cube: iris.cube.Cube
        input data cube.
    start_longitude: float
        Western boundary longitude.
    end_longitude: float
        Eastern boundary longitude.
    start_latitude: float
        Southern Boundary latitude.
    end_latitude: float
        Northern Boundary Latitude.

    Returns
    -------
    iris.cube.Cube
        smaller cube.
    """
    if abs(start_latitude) > 90.:
        raise ValueError(f"Invalid start_latitude: {start_latitude}")
    if abs(end_latitude) > 90.:
        raise ValueError(f"Invalid end_latitude: {end_latitude}")
    if cube.coord('latitude').ndim == 1:
        # Iris checks if any point of the cell is inside the region.
        # To check only the center, ignore_bounds must be set to
        # True (the default is False)
        region_subset = cube.intersection(
            longitude=(start_longitude, end_longitude),
            latitude=(start_latitude, end_latitude),
            ignore_bounds=True,
        )
        region_subset = region_subset.intersection(longitude=(0., 360.))
        return region_subset
    # Irregular grids
    lats = cube.coord('latitude').points
    lons = cube.coord('longitude').points
    # Convert longitudes to valid range
    if start_longitude != 360.:
        start_longitude %= 360.
    if end_longitude != 360.:
        end_longitude %= 360.

    if start_longitude <= end_longitude:
        select_lons = (lons >= start_longitude) & (lons <= end_longitude)
    else:
        select_lons = (lons >= start_longitude) | (lons <= end_longitude)

    if start_latitude <= end_latitude:
        select_lats = (lats >= start_latitude) & (lats <= end_latitude)
    else:
        select_lats = (lats >= start_latitude) | (lats <= end_latitude)

    selection = select_lats & select_lons
    selection = da.broadcast_to(selection, cube.shape)
    cube.data = da.ma.masked_where(~selection, cube.core_data())
    return cube
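For the irregular-grid branch, the selection is a 2-D boolean field that must be expanded over the leading (e.g. time) dimension before masking. A sketch with a synthetic curvilinear grid (shapes are assumptions):

import dask.array as da
import numpy as np

# Hypothetical 2-D lat/lon coordinates and a (time, y, x) field.
lats = np.linspace(-80, 80, 50)[:, None] * np.ones((1, 100))
lons = np.linspace(0, 359, 100)[None, :] * np.ones((50, 1))
field = da.random.random((12, 50, 100), chunks=(1, 50, 100))

# Select a lon/lat box, broadcast the 2-D selection over time, and mask the rest.
select = (lons >= 30) & (lons <= 60) & (lats >= -10) & (lats <= 10)
select = da.broadcast_to(select, field.shape)
field = da.ma.masked_where(~select, field)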
Example #17
def broadcast_signals(*args, ignore_axis=None):
    """Broadcasts all passed signals according to the HyperSpy broadcasting
    rules: signal and navigation spaces are each separately broadcasted
    according to the numpy broadcasting rules. One axis can be ignored and
    left untouched (or set to be size 1) across all signals.

    Parameters
    ----------
    *args : BaseSignal
        Signals to broadcast together
    ignore_axis : {None, str, int, Axis}
        The axis to be ignored when broadcasting

    Returns
    -------
    list of signals
    """
    if len(args) < 2:
        raise ValueError(
            "This function requires at least two signal instances")
    args = list(args)
    if not are_signals_aligned(*args, ignore_axis=ignore_axis):
        raise ValueError("The signals cannot be broadcasted")
    else:
        if ignore_axis is not None:
            for s in args:
                try:
                    ignore_axis = s.axes_manager[ignore_axis]
                    break
                except ValueError:
                    pass
        new_nav_axes = []
        new_nav_shapes = []
        for axes in zip_longest(*[s.axes_manager.navigation_axes
                                  for s in args], fillvalue=None):
            only_left = filter(lambda x: x is not None, axes)
            longest = sorted(only_left, key=lambda x: x.size, reverse=True)[0]
            new_nav_axes.append(longest)
            new_nav_shapes.append(longest.size if (ignore_axis is None or
                                                   ignore_axis not in
                                                   axes)
                                  else None)
        new_sig_axes = []
        new_sig_shapes = []
        for axes in zip_longest(*[s.axes_manager.signal_axes
                                  for s in args], fillvalue=None):
            only_left = filter(lambda x: x is not None, axes)
            longest = sorted(only_left, key=lambda x: x.size, reverse=True)[0]
            new_sig_axes.append(longest)
            new_sig_shapes.append(longest.size if (ignore_axis is None or
                                                   ignore_axis not in
                                                   axes)
                                  else None)

        results = []
        new_axes = new_nav_axes[::-1] + new_sig_axes[::-1]
        new_data_shape = new_nav_shapes[::-1] + new_sig_shapes[::-1]
        for s in args:
            data = s._data_aligned_with_axes
            sam = s.axes_manager
            sdim_diff = len(new_sig_axes) - sam.signal_dimension
            while sdim_diff > 0:
                slices = (slice(None),) * sam.navigation_dimension
                slices += (None, Ellipsis)
                data = data[slices]
                sdim_diff -= 1
            thisshape = new_data_shape.copy()
            if ignore_axis is not None:
                _id = new_data_shape.index(None)
                newlen = data.shape[_id] if len(data.shape) > _id else 1
                thisshape[_id] = newlen
            thisshape = tuple(thisshape)
            if data.shape != thisshape:
                if isinstance(data, np.ndarray):
                    data = np.broadcast_to(data, thisshape)
                else:
                    data = da.broadcast_to(data, thisshape)

            ns = s._deepcopy_with_new_data(data)
            ns.axes_manager._axes = [ax.copy() for ax in new_axes]
            ns.get_dimensions_from_data()
            results.append(ns.transpose(signal_axes=len(new_sig_axes)))
        return results
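The closing branch simply dispatches on the array type so lazy (dask-backed) signals stay lazy while in-memory signals use plain numpy. In isolation (the target shape is an assumption):

import dask.array as da
import numpy as np

target_shape = (5, 3, 4)
a_np = np.ones((3, 4))
a_da = da.ones((3, 4), chunks=(3, 4))

# NumPy data broadcasts eagerly, dask data broadcasts lazily.
b_np = np.broadcast_to(a_np, target_shape)
b_da = da.broadcast_to(a_da, target_shape)
assert b_np.shape == b_da.shape == target_shape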
Example #18
    def make_psf(self):
        print("Making PSF")
        psfs = []
        for ims in self.ms:
            xds = xds_from_ms(ims, group_cols=('FIELD_ID', 'DATA_DESC_ID'),
                              chunks=self.chunks[ims],
                              columns=self.columns)

            # subtables
            ddids = xds_from_table(ims + "::DATA_DESCRIPTION")
            fields = xds_from_table(ims + "::FIELD", group_cols="__row__")
            spws = xds_from_table(ims + "::SPECTRAL_WINDOW", group_cols="__row__")
            pols = xds_from_table(ims + "::POLARIZATION", group_cols="__row__")

            # subtable data
            ddids = dask.compute(ddids)[0]
            fields = dask.compute(fields)[0]
            spws = dask.compute(spws)[0]
            pols = dask.compute(pols)[0]

            for ds in xds:
                field = fields[ds.FIELD_ID]
                radec = field.PHASE_DIR.data.squeeze()
                if not np.array_equal(radec, self.radec):
                    continue

                spw = ds.DATA_DESC_ID  # this is not correct, need to use spw
                
                freq_bin_idx = self.freq_bin_idx[ims][spw]
                freq_bin_counts = self.freq_bin_counts[ims][spw]
                freq = self.freq[ims][spw]

                uvw = ds.UVW.data

                flag = getattr(ds, self.flag_column).data

                weights = getattr(ds, self.weight_column).data
                if len(weights.shape) < 3:
                    weights = da.broadcast_to(weights[:, None, :], flag.shape, chunks=flag.chunks)
                
                if self.imaging_weight_column is not None:
                    imaging_weights = getattr(ds, self.imaging_weight_column).data
                    if len(imaging_weights.shape) < 3:
                        imaging_weights = da.broadcast_to(imaging_weights[:, None, :], flag.shape, chunks=flag.chunks)

                    weightsxx = imaging_weights[:, :, 0] * weights[:, :, 0]
                    weightsyy = imaging_weights[:, :, -1] * weights[:, :, -1]
                else:
                    weightsxx = weights[:, :, 0]
                    weightsyy = weights[:, :, -1]

                # weighted sum corr to Stokes I
                weights = weightsxx + weightsyy
                data = weights.astype(np.complex64)
                
                # only keep data where both corrs are unflagged
                flagxx = flag[:, :, 0]
                flagyy = flag[:, :, -1]
                flag = ~ (flagxx | flagyy)  # ducc0 convention

                psf = vis2im(uvw, freq, data, freq_bin_idx, freq_bin_counts,
                             2*self.nx, 2*self.ny, self.cell, flag=flag.astype(np.uint8),
                             nthreads=self.nthreads, epsilon=self.epsilon, do_wstacking=self.do_wstacking)

                psfs.append(psf)

        psfs = dask.compute(psfs)[0]
                
        return accumulate_dirty(psfs, self.nband, self.band_mapping).astype(np.float64)
Example #19
    def __get_seasonal_means_with_ttest_stats_dask_lazy(
            self,
            data,
            season_to_monthperiod=None,
            start_year=-np.Inf,
            end_year=np.Inf,
            convert_monthly_accumulators_to_daily=False):

        # mask the resulting fields
        epsilon = 1.0e-5
        mask = np.less_equal(np.abs(data[0, :, :] - self.missing_value),
                             epsilon)

        print("data.shape = ", data.shape)

        data_sel, times_sel = data, self.time

        # select the interval of interest

        if convert_monthly_accumulators_to_daily:
            ndays = da.from_array(
                np.array([
                    calendar.monthrange(d.year, d.month)[1] for d in times_sel
                ]), (100, ))
            ndays = da.transpose(da.broadcast_to(
                da.from_array(ndays, ndays.shape),
                data_sel.shape[1:] + ndays.shape),
                                 axes=(2, 0, 1))

            data_sel = data_sel / ndays

        year_month_to_index_arr = defaultdict(list)
        for i, t in enumerate(times_sel):
            year_month_to_index_arr[t.year, t.month].append(i)

        # calculate monthly means
        monthly_data = {}
        for y in range(start_year, end_year + 1):
            for m in range(1, 13):
                aslice = slice(year_month_to_index_arr[y, m][0],
                               year_month_to_index_arr[y, m][-1] + 1)
                print(aslice, data_sel.shape)
                monthly_data[y, m] = data_sel[aslice, :, :].mean(axis=0)

        result = OrderedDict()
        for season, month_period in season_to_monthperiod.items():
            assert isinstance(month_period, MonthPeriod)

            seasonal_means = []
            ndays_per_season = []

            for p in month_period.get_season_periods(start_year=start_year,
                                                     end_year=end_year):
                lmos = da.stack([
                    monthly_data[start.year, start.month]
                    for start in p.range("months")
                ])
                ndays_per_month = np.array([
                    calendar.monthrange(start.year, start.month)[1]
                    for start in p.range("months")
                ])
                ndays_per_month = da.from_array(ndays_per_month,
                                                ndays_per_month.shape)

                print(p)
                print(lmos.shape, ndays_per_month.shape, ndays_per_month.sum())
                seasonal_mean = da.tensordot(
                    lmos, ndays_per_month, axes=([
                        0,
                    ], [
                        0,
                    ])) / ndays_per_month.sum()

                seasonal_means.append(seasonal_mean)
                ndays_per_season.append(ndays_per_month.sum())

            seasonal_means = da.stack(seasonal_means)
            ndays_per_season = np.array(ndays_per_season)
            ndays_per_season = da.from_array(ndays_per_season,
                                             ndays_per_season.shape)

            print(seasonal_means.shape, ndays_per_season.shape)

            assert seasonal_means.shape[0] == ndays_per_season.shape[0]

            clim_mean = da.tensordot(
                seasonal_means, ndays_per_season, axes=([
                    0,
                ], [
                    0,
                ])) / ndays_per_season.sum()

            clim_std = ((seasonal_means -
                         da.broadcast_to(clim_mean, seasonal_means.shape))**2 *
                        ndays_per_season[:, np.newaxis, np.newaxis]).sum(
                            axis=0) / ndays_per_season.sum()

            clim_std = clim_std**0.5

            result[season] = [clim_mean, clim_std, ndays_per_season.shape[0]]

        return result, mask
Example #20
def broadcast_signals(*args, ignore_axis=None):
    """Broadcasts all passed signals according to the HyperSpy broadcasting
    rules: signal and navigation spaces are each separately broadcasted
    according to the numpy broadcasting rules. One axis can be ignored and
    left untouched (or set to be size 1) across all signals.

    Parameters
    ----------
    *args : BaseSignal
        Signals to broadcast together
    ignore_axis : {None, str, int, Axis}
        The axis to be ignored when broadcasting

    Returns
    -------
    list of signals
    """
    if len(args) < 2:
        raise ValueError(
            "This function requires at least two signal instances")
    args = list(args)
    if not are_signals_aligned(*args):
        raise ValueError("The signals cannot be broadcasted")
    else:
        if ignore_axis is not None:
            for s in args:
                try:
                    ignore_axis = s.axes_manager[ignore_axis]
                    break
                except ValueError:
                    pass
        new_nav_axes = []
        new_nav_shapes = []
        for axes in zip_longest(*[
                s.axes_manager.navigation_axes for s in args
        ],
                                fillvalue=None):
            only_left = filter(lambda x: x is not None, axes)
            longest = sorted(only_left, key=lambda x: x.size, reverse=True)[0]
            new_nav_axes.append(longest)
            new_nav_shapes.append(longest.size if (
                ignore_axis is None or ignore_axis not in axes) else None)
        new_sig_axes = []
        new_sig_shapes = []
        for axes in zip_longest(*[s.axes_manager.signal_axes for s in args],
                                fillvalue=None):
            only_left = filter(lambda x: x is not None, axes)
            longest = sorted(only_left, key=lambda x: x.size, reverse=True)[0]
            new_sig_axes.append(longest)
            new_sig_shapes.append(longest.size if (
                ignore_axis is None or ignore_axis not in axes) else None)

        results = []
        new_axes = new_nav_axes[::-1] + new_sig_axes[::-1]
        new_data_shape = new_nav_shapes[::-1] + new_sig_shapes[::-1]
        for s in args:
            data = s._data_aligned_with_axes
            sam = s.axes_manager
            sdim_diff = len(new_sig_axes) - sam.signal_dimension
            while sdim_diff > 0:
                slices = (slice(None), ) * sam.navigation_dimension
                slices += (None, Ellipsis)
                data = data[slices]
                sdim_diff -= 1
            thisshape = new_data_shape.copy()
            if ignore_axis is not None:
                _id = new_data_shape.index(None)
                newlen = data.shape[_id] if len(data.shape) > _id else 1
                thisshape[_id] = newlen
            thisshape = tuple(thisshape)
            if data.shape != thisshape:
                if isinstance(data, np.ndarray):
                    data = np.broadcast_to(data, thisshape)
                else:
                    data = da.broadcast_to(data, thisshape)

            ns = s._deepcopy_with_new_data(data)
            ns.axes_manager._axes = [ax.copy() for ax in new_axes]
            ns.get_dimensions_from_data()
            results.append(ns.transpose(signal_axes=len(new_sig_axes)))
        return results
Example #21
    def make_dirty(self):
        print("Making dirty", file=log)
        dirty = da.zeros((self.nband, self.nx, self.ny),
                         dtype=np.float32,
                         chunks=(1, self.nx, self.ny),
                         name=False)
        dirties = []
        for ims in self.ms:
            xds = xds_from_ms(ims,
                              group_cols=('FIELD_ID', 'DATA_DESC_ID'),
                              chunks=self.chunks[ims],
                              columns=self.columns)

            # subtables
            ddids = xds_from_table(ims + "::DATA_DESCRIPTION")
            fields = xds_from_table(ims + "::FIELD", group_cols="__row__")
            spws = xds_from_table(ims + "::SPECTRAL_WINDOW",
                                  group_cols="__row__")
            pols = xds_from_table(ims + "::POLARIZATION", group_cols="__row__")

            # subtable data
            ddids = dask.compute(ddids)[0]
            fields = dask.compute(fields)[0]
            spws = dask.compute(spws)[0]
            pols = dask.compute(pols)[0]

            for ds in xds:
                field = fields[ds.FIELD_ID]
                radec = field.PHASE_DIR.data.squeeze()
                if not np.array_equal(radec, self.radec):
                    continue

                spw = ds.DATA_DESC_ID  # this is not correct, need to use spw

                freq_bin_idx = self.freq_bin_idx[ims][spw]
                freq_bin_counts = self.freq_bin_counts[ims][spw]
                freq = self.freq[ims][spw]
                freq_chunk = freq_bin_counts[0].compute()

                uvw = ds.UVW.data

                data = getattr(ds, self.data_column).data
                dataxx = data[:, :, 0]
                datayy = data[:, :, -1]

                weights = getattr(ds, self.weight_column).data
                if len(weights.shape) < 3:
                    weights = da.broadcast_to(weights[:, None, :],
                                              data.shape,
                                              chunks=data.chunks)

                if self.imaging_weight_column is not None:
                    imaging_weights = getattr(ds,
                                              self.imaging_weight_column).data
                    if len(imaging_weights.shape) < 3:
                        imaging_weights = da.broadcast_to(
                            imaging_weights[:, None, :],
                            data.shape,
                            chunks=data.chunks)

                    weightsxx = imaging_weights[:, :, 0] * weights[:, :, 0]
                    weightsyy = imaging_weights[:, :, -1] * weights[:, :, -1]
                else:
                    weightsxx = weights[:, :, 0]
                    weightsyy = weights[:, :, -1]

                # apply adjoint of mueller term.
                # Phases modify data amplitudes modify weights.
                if self.mueller_column is not None:
                    mueller = getattr(ds, self.mueller_column).data
                    dataxx *= da.exp(-1j * da.angle(mueller[:, :, 0]))
                    datayy *= da.exp(-1j * da.angle(mueller[:, :, -1]))
                    weightsxx *= da.absolute(mueller[:, :, 0])
                    weightsyy *= da.absolute(mueller[:, :, -1])

                # weighted sum corr to Stokes I
                weights = weightsxx + weightsyy
                data = (weightsxx * dataxx + weightsyy * datayy)
                # TODO - turn off this stupid warning
                data = da.where(weights, data / weights, 0.0j)

                # only keep data where both corrs are unflagged
                flag = getattr(ds, self.flag_column).data
                flagxx = flag[:, :, 0]
                flagyy = flag[:, :, -1]
                # ducc0 convention uses uint8 mask not flag
                flag = ~(flagxx | flagyy)

                dirty = vis2im(uvw,
                               freq,
                               data,
                               freq_bin_idx,
                               freq_bin_counts,
                               self.nx,
                               self.ny,
                               self.cell,
                               weights=weights,
                               flag=flag.astype(np.uint8),
                               nthreads=self.nthreads,
                               epsilon=self.epsilon,
                               do_wstacking=self.do_wstacking,
                               double_accum=True)

                dirties.append(dirty)

        dirties = dask.compute(dirties, scheduler='single-threaded')[0]

        return accumulate_dirty(dirties, self.nband,
                                self.band_mapping).astype(self.real_type)
Example #22
    def compute_weights(self, robust):
        from pfb.utils.weighting import compute_counts, counts_to_weights
        # compute counts
        counts = []
        for ims in self.ms:
            xds = xds_from_ms(ims,
                              group_cols=('FIELD_ID', 'DATA_DESC_ID'),
                              chunks=self.chunks[ims],
                              columns=('UVW',))

            # subtables
            ddids = xds_from_table(ims + "::DATA_DESCRIPTION")
            fields = xds_from_table(ims + "::FIELD", group_cols="__row__")
            spws = xds_from_table(ims + "::SPECTRAL_WINDOW",
                                  group_cols="__row__")
            pols = xds_from_table(ims + "::POLARIZATION", group_cols="__row__")

            # subtable data
            ddids = dask.compute(ddids)[0]
            fields = dask.compute(fields)[0]
            spws = dask.compute(spws)[0]
            pols = dask.compute(pols)[0]

            for ds in xds:
                field = fields[ds.FIELD_ID]
                radec = field.PHASE_DIR.data.squeeze()
                if not np.array_equal(radec, self.radec):
                    continue

                spw = ds.DATA_DESC_ID  # not optimal, need to use spw

                freq_bin_idx = self.freq_bin_idx[ims][spw]
                freq_bin_counts = self.freq_bin_counts[ims][spw]
                freq = self.freq[ims][spw]

                uvw = ds.UVW.data

                count = compute_counts(uvw, freq, freq_bin_idx,
                                       freq_bin_counts, self.nx, self.ny,
                                       self.cell, self.cell, np.float32)

                counts.append(count)

        counts = dask.compute(counts)[0]

        counts = accumulate_dirty(counts, self.nband, self.band_mapping)

        counts = da.from_array(counts, chunks=(1, -1, -1))

        # convert counts to weights
        writes = []
        for ims in self.ms:
            xds = xds_from_ms(ims,
                              group_cols=('FIELD_ID', 'DATA_DESC_ID'),
                              chunks=self.chunks[ims],
                              columns=self.columns)

            # subtables
            ddids = xds_from_table(ims + "::DATA_DESCRIPTION")
            fields = xds_from_table(ims + "::FIELD", group_cols="__row__")
            spws = xds_from_table(ims + "::SPECTRAL_WINDOW",
                                  group_cols="__row__")
            pols = xds_from_table(ims + "::POLARIZATION", group_cols="__row__")

            # subtable data
            ddids = dask.compute(ddids)[0]
            fields = dask.compute(fields)[0]
            spws = dask.compute(spws)[0]
            pols = dask.compute(pols)[0]

            out_data = []
            for ds in xds:
                field = fields[ds.FIELD_ID]
                radec = field.PHASE_DIR.data.squeeze()
                if not np.array_equal(radec, self.radec):
                    continue

                spw = ds.DATA_DESC_ID  # assumes a one-to-one DATA_DESC_ID -> SPECTRAL_WINDOW_ID mapping

                freq_bin_idx = self.freq_bin_idx[ims][spw]
                freq_bin_counts = self.freq_bin_counts[ims][spw]
                freq = self.freq[ims][spw]

                uvw = ds.UVW.data

                weights = counts_to_weights(counts, uvw, freq, freq_bin_idx,
                                            freq_bin_counts, self.nx, self.ny,
                                            self.cell, self.cell, np.float32,
                                            robust)

                # use the data column only to obtain the target shape and chunking (see the sketch after this example)
                data = getattr(ds, self.data_column).data

                weights = da.broadcast_to(weights[:, :, None],
                                          data.shape,
                                          chunks=data.chunks)
                out_ds = ds.assign(**{
                    self.imaging_weight_column: (("row", "chan", "corr"),
                                                 weights)
                })
                out_data.append(out_ds)
            writes.append(
                xds_to_table(out_data,
                             ims,
                             columns=[self.imaging_weight_column]))
        dask.compute(writes)
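
A self-contained sketch (toy shapes, not this pipeline's real columns) of the broadcasting step above: per-(row, chan) imaging weights are expanded to the (row, chan, corr) shape of the data column while preserving its chunking, which is what the chunks= argument of da.broadcast_to is for.

import numpy as np
import dask.array as da

nrow, nchan, ncorr = 1000, 64, 4
data = da.zeros((nrow, nchan, ncorr), chunks=(250, 32, ncorr), dtype=np.complex64)
weights = da.ones((nrow, nchan), chunks=(250, 32), dtype=np.float32)

# add a trailing corr axis, then broadcast with the data column's chunking
weights3d = da.broadcast_to(weights[:, :, None], data.shape, chunks=data.chunks)
assert weights3d.shape == data.shape
assert weights3d.chunks == data.chunks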
Example #23
def main():
    args = parse_args()
    dask.config.set(num_workers=args.workers)

    # Lightweight open with no data - just to create telstate and identify the CBID
    ds = TelstateDataSource.from_url(args.source,
                                     upgrade_flags=False,
                                     chunk_store=None)
    # View the CBID, but not any specific stream
    cbid = ds.capture_block_id
    telstate = ds.telstate.root().view(cbid)
    streams = get_streams(telstate, args.streams)

    # Find all arrays in the selected streams, and also ensure we're not
    # trying to write things back on top of an existing dataset.
    arrays = {}
    for stream_name in streams:
        sts = view_capture_stream(telstate, cbid, stream_name)
        try:
            chunk_info = sts['chunk_info']
        except KeyError as exc:
            raise RuntimeError('Could not get chunk info for {!r}: {}'.format(
                stream_name, exc))
        for array_name, array_info in chunk_info.items():
            if args.new_prefix is not None:
                array_info['prefix'] = (args.new_prefix + '-' +
                                        stream_name.replace('_', '-'))
            prefix = array_info['prefix']
            path = os.path.join(args.dest, prefix)
            if os.path.exists(path):
                raise RuntimeError(
                    'Directory {!r} already exists'.format(path))
            store = get_chunk_store(args.source, sts, array_name)
            # Older files have dtype as an object that can't be encoded in msgpack
            dtype = np.dtype(array_info['dtype'])
            array_info['dtype'] = np.lib.format.dtype_to_descr(dtype)
            arrays[(stream_name, array_name)] = Array(stream_name, array_name,
                                                      store, array_info)

    # Apply DATA_LOST bits to the flags arrays. This is a less efficient approach than
    # datasources.py, but much simpler.
    for stream_name in streams:
        flags_array = arrays.get((stream_name, 'flags'))
        if not flags_array:
            continue
        sources = [stream_name]
        sts = view_capture_stream(telstate, cbid, stream_name)
        sources += sts['src_streams']
        for src_stream in sources:
            if src_stream not in streams:
                continue
            src_ts = view_capture_stream(telstate, cbid, src_stream)
            for array_name in src_ts['chunk_info']:
                if array_name == 'flags' and src_stream != stream_name:
                    # Upgraded flags completely replace the source stream's
                    # flags, rather than augmenting them. Thus, data lost in
                    # the source stream has no effect.
                    continue
                lost_flags = arrays[(src_stream, array_name)].lost_flags
                lost_flags = lost_flags.rechunk(
                    flags_array.data.chunks[:lost_flags.ndim])
                # weights_channel doesn't have a baseline axis
                while lost_flags.ndim < flags_array.data.ndim:
                    lost_flags = lost_flags[..., np.newaxis]
                lost_flags = da.broadcast_to(lost_flags,
                                             flags_array.data.shape,
                                             chunks=flags_array.data.chunks)
                flags_array.data |= lost_flags

    # Apply the rechunking specs
    for spec in args.spec:
        key = (spec.stream, spec.array)
        if key not in arrays:
            raise RuntimeError('{}/{} is not a known array'.format(
                spec.stream, spec.array))
        arrays[key].data = arrays[key].data.rechunk({
            0: spec.time,
            1: spec.freq
        })

    # Write out the new data
    dest_store = NpyFileChunkStore(args.dest)
    stores = []
    for array in arrays.values():
        full_name = dest_store.join(array.chunk_info['prefix'],
                                    array.array_name)
        dest_store.create_array(full_name)
        stores.append(dest_store.put_dask_array(full_name, array.data))
        array.chunk_info['chunks'] = array.data.chunks
    stores = da.compute(*stores)
    # put_dask_array returns an array with an exception object per chunk
    for result_set in stores:
        for result in result_set.flat:
            if result is not None:
                raise result

    # Fix up chunk_info for new chunking
    for stream_name in streams:
        sts = view_capture_stream(telstate, cbid, stream_name)
        chunk_info = sts['chunk_info']
        for array_name in chunk_info.keys():
            chunk_info[array_name] = arrays[(stream_name,
                                             array_name)].chunk_info
        sts.wrapped.delete('chunk_info')
        sts.wrapped['chunk_info'] = chunk_info
        # s3_endpoint_url is for the old version of the data
        sts.wrapped.delete('s3_endpoint_url')
        if args.s3_endpoint_url is not None:
            sts.wrapped['s3_endpoint_url'] = args.s3_endpoint_url

    # Write updated RDB file
    url_parts = urllib.parse.urlparse(args.source, scheme='file')
    dest_file = os.path.join(args.dest, args.new_prefix or cbid,
                             os.path.basename(url_parts.path))
    os.makedirs(os.path.dirname(dest_file), exist_ok=True)
    with RDBWriter(dest_file) as writer:
        writer.save(telstate.backend)
Example #24
def calculate_corners(center_lat, center_lon):
    """Calculate corner coordinates by averaging neighbor cells
    """

    # get rank
    rank = len(center_lat.dims)

    if rank == 1:
        # get dimensions
        nlon = center_lon.size
        nlat = center_lat.size

        # convert center points from 1d to 2d
        center_lat2d = da.broadcast_to(center_lat.values[None, :],
                                       (nlon, nlat))
        center_lon2d = da.broadcast_to(center_lon.values[:, None],
                                       (nlon, nlat))
    elif rank == 2:
        # get dimensions
        dims = center_lon.shape
        nlon = dims[0]
        nlat = dims[1]

        # just rename and convert to dask array
        center_lat2d = da.from_array(center_lat)
        center_lon2d = da.from_array(center_lon)
    else:
        print(
            'Unrecognized grid! The rank of coordinate variables can be 1 or 2 but it is {}.'
            .format(rank))
        sys.exit(2)

    # calculate corner coordinates for latitude, counterclockwise order, imposing Fortran ordering
    center_lat2d_ext = da.from_array(
        np.pad(center_lat2d.compute(), (1, 1),
               mode='reflect',
               reflect_type='odd'))

    ur = (center_lat2d_ext[1:-1, 1:-1] + center_lat2d_ext[0:-2, 1:-1] +
          center_lat2d_ext[1:-1, 2:] + center_lat2d_ext[0:-2, 2:]) / 4.0
    ul = (center_lat2d_ext[1:-1, 1:-1] + center_lat2d_ext[0:-2, 1:-1] +
          center_lat2d_ext[1:-1, 0:-2] + center_lat2d_ext[0:-2, 0:-2]) / 4.0
    ll = (center_lat2d_ext[1:-1, 1:-1] + center_lat2d_ext[1:-1, 0:-2] +
          center_lat2d_ext[2:, 1:-1] + center_lat2d_ext[2:, 0:-2]) / 4.0
    lr = (center_lat2d_ext[1:-1, 1:-1] + center_lat2d_ext[1:-1, 2:] +
          center_lat2d_ext[2:, 1:-1] + center_lat2d_ext[2:, 2:]) / 4.0

    # this looks like clockwise ordering, but after the transpose it becomes counterclockwise, matching NCL bit-for-bit
    corner_lat = da.stack([
        ul.T.reshape((-1, )).T,
        ll.T.reshape((-1, )).T,
        lr.T.reshape((-1, )).T,
        ur.T.reshape((-1, )).T
    ],
                          axis=1)

    # calculate corner coordinates for longitude, counterclockwise order, imposing Fortran ordering
    center_lon2d_ext = da.from_array(
        np.pad(center_lon2d.compute(), (1, 1),
               mode='reflect',
               reflect_type='odd'))

    ur = (center_lon2d_ext[1:-1, 1:-1] + center_lon2d_ext[0:-2, 1:-1] +
          center_lon2d_ext[1:-1, 2:] + center_lon2d_ext[0:-2, 2:]) / 4.0
    ul = (center_lon2d_ext[1:-1, 1:-1] + center_lon2d_ext[0:-2, 1:-1] +
          center_lon2d_ext[1:-1, 0:-2] + center_lon2d_ext[0:-2, 0:-2]) / 4.0
    ll = (center_lon2d_ext[1:-1, 1:-1] + center_lon2d_ext[1:-1, 0:-2] +
          center_lon2d_ext[2:, 1:-1] + center_lon2d_ext[2:, 0:-2]) / 4.0
    lr = (center_lon2d_ext[1:-1, 1:-1] + center_lon2d_ext[1:-1, 2:] +
          center_lon2d_ext[2:, 1:-1] + center_lon2d_ext[2:, 2:]) / 4.0

    # this looks like clockwise ordering, but after the transpose it becomes counterclockwise, matching NCL bit-for-bit
    corner_lon = da.stack([
        ul.T.reshape((-1, )).T,
        ll.T.reshape((-1, )).T,
        lr.T.reshape((-1, )).T,
        ur.T.reshape((-1, )).T
    ],
                          axis=1)

    return center_lat2d, center_lon2d, corner_lat, corner_lon
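
A quick, hypothetical sanity check of the averaging scheme above (it assumes calculate_corners and its imports are in scope): on a uniform one-degree grid every corner must land exactly half a cell away from its center.

import numpy as np
import xarray as xr

lat = xr.DataArray(np.arange(-89.5, 90.0, 1.0), dims="lat")  # 180 cell centers
lon = xr.DataArray(np.arange(0.5, 360.0, 1.0), dims="lon")   # 360 cell centers

center_lat2d, center_lon2d, corner_lat, corner_lon = calculate_corners(lat, lon)

# corner_lat is (nlon * nlat, 4), flattened in Fortran order of the (nlon, nlat) centers
centers = np.asarray(center_lat2d.compute()).T.reshape(-1)
corners = np.asarray(corner_lat.compute())
assert np.allclose(np.abs(corners - centers[:, None]), 0.5)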
Example #25
    def make_residual(self, x):
        # Note deprecated (does not support Jones terms)
        print("Making residual", file=log)
        x = da.from_array(x.astype(self.real_type),
                          chunks=(1, self.nx, self.ny),
                          name=False)
        residuals = []
        for ims in self.ms:
            xds = xds_from_ms(ims,
                              group_cols=('FIELD_ID', 'DATA_DESC_ID'),
                              chunks=self.chunks[ims],
                              columns=self.columns)

            # subtables
            ddids = xds_from_table(ims + "::DATA_DESCRIPTION")
            fields = xds_from_table(ims + "::FIELD", group_cols="__row__")
            spws = xds_from_table(ims + "::SPECTRAL_WINDOW",
                                  group_cols="__row__")
            pols = xds_from_table(ims + "::POLARIZATION", group_cols="__row__")

            # subtable data
            ddids = dask.compute(ddids)[0]
            fields = dask.compute(fields)[0]
            spws = dask.compute(spws)[0]
            pols = dask.compute(pols)[0]

            for ds in xds:
                field = fields[ds.FIELD_ID]
                radec = field.PHASE_DIR.data.squeeze()
                if not np.array_equal(radec, self.radec):
                    continue

                spw = ds.DATA_DESC_ID  # assumes a one-to-one DATA_DESC_ID -> SPECTRAL_WINDOW_ID mapping

                freq_bin_idx = self.freq_bin_idx[ims][spw]
                freq_bin_counts = self.freq_bin_counts[ims][spw]
                freq = self.freq[ims][spw]

                uvw = ds.UVW.data

                data = getattr(ds, self.data_column).data
                dataxx = data[:, :, 0]
                datayy = data[:, :, -1]

                weights = getattr(ds, self.weight_column).data
                if len(weights.shape) < 3:
                    weights = da.broadcast_to(weights[:, None, :],
                                              data.shape,
                                              chunks=data.chunks)

                if self.imaging_weight_column is not None:
                    imaging_weights = getattr(ds,
                                              self.imaging_weight_column).data
                    if len(imaging_weights.shape) < 3:
                        imaging_weights = da.broadcast_to(
                            imaging_weights[:, None, :],
                            data.shape,
                            chunks=data.chunks)

                    weightsxx = imaging_weights[:, :, 0] * weights[:, :, 0]
                    weightsyy = imaging_weights[:, :, -1] * weights[:, :, -1]
                else:
                    weightsxx = weights[:, :, 0]
                    weightsyy = weights[:, :, -1]

                # weighted sum corr to Stokes I
                weights = weightsxx + weightsyy
                data = (weightsxx * dataxx + weightsyy * datayy)
                data = da.where(weights, data / weights, 0.0j)

                # only keep data where both corrs are unflagged
                flag = getattr(ds, self.flag_column).data
                flagxx = flag[:, :, 0]
                flagyy = flag[:, :, -1]
                flag = ~(flagxx | flagyy)  # ducc0 convention

                bands = self.band_mapping[ims][spw]
                model = x[list(bands), :, :]
                residual = im2residim(uvw,
                                      freq,
                                      model,
                                      data,
                                      freq_bin_idx,
                                      freq_bin_counts,
                                      self.cell,
                                      weights=weights,
                                      flag=flag.astype(np.uint8),
                                      nthreads=self.nthreads,
                                      epsilon=self.epsilon,
                                      do_wstacking=self.do_wstacking,
                                      double_accum=True)

                residuals.append(residual)

        residuals = dask.compute(residuals)[0]

        return accumulate_dirty(residuals, self.nband,
                                self.band_mapping).astype(self.real_type)
Example #26
dz["st_ocean"] = depth["st_ocean"]

for var in line_tracer_vars:
    ds_out[f"{var}_segment_001"] = ds_out[f"{var}_segment_001"].expand_dims(
        "ny_segment_001", axis=1)

for var in surface_vars:
    # add the y dimension
    ds_out[f"{var}_segment_001"] = ds_out[f"{var}_segment_001"].expand_dims(
        "ny_segment_001", axis=2)

    ds_out[f"vc_{var}_segment_001"] = (
        ["time", f"nz_segment_001_{var}", "ny_segment_001", "nx_segment_001"],
        da.broadcast_to(
            depth.data[None, :, None, None],
            ds_out[f"{var}_segment_001"].shape,
            chunks=(1, None, None, None),
        ),
    )
    ds_out[f"dz_{var}_segment_001"] = (
        ["time", f"nz_segment_001_{var}", "ny_segment_001", "nx_segment_001"],
        da.broadcast_to(
            dz.data[None, :, None, None],
            ds_out[f"{var}_segment_001"].shape,
            chunks=(1, None, None, None),
        ),
    )

with ProgressBar():
    ds_out.to_netcdf("forcing_obc.nc")
Example #27
    def make_psf(self):
        print("Making PSF", file=log)
        psfs = []
        self.stokes_weights = {}
        self.uvws = {}
        for ims in self.ms:
            xds = xds_from_ms(ims,
                              group_cols=('FIELD_ID', 'DATA_DESC_ID'),
                              chunks=self.chunks[ims],
                              columns=self.columns)

            # subtables
            ddids = xds_from_table(ims + "::DATA_DESCRIPTION")
            fields = xds_from_table(ims + "::FIELD", group_cols="__row__")
            spws = xds_from_table(ims + "::SPECTRAL_WINDOW",
                                  group_cols="__row__")
            pols = xds_from_table(ims + "::POLARIZATION", group_cols="__row__")

            # subtable data
            ddids = dask.compute(ddids)[0]
            fields = dask.compute(fields)[0]
            spws = dask.compute(spws)[0]
            pols = dask.compute(pols)[0]
            self.stokes_weights[ims] = {}
            self.uvws[ims] = {}

            for ds in xds:
                field = fields[ds.FIELD_ID]
                radec = field.PHASE_DIR.data.squeeze()
                if not np.array_equal(radec, self.radec):
                    continue

                # assumes a one-to-one DATA_DESC_ID -> SPECTRAL_WINDOW_ID mapping
                spw = ds.DATA_DESC_ID

                freq_bin_idx = self.freq_bin_idx[ims][spw]
                freq_bin_counts = self.freq_bin_counts[ims][spw]
                freq = self.freq[ims][spw]

                uvw = ds.UVW.data

                flag = getattr(ds, self.flag_column).data

                weights = getattr(ds, self.weight_column).data
                if len(weights.shape) < 3:
                    weights = da.broadcast_to(weights[:, None, :],
                                              flag.shape,
                                              chunks=flag.chunks)

                if self.imaging_weight_column is not None:
                    imaging_weights = getattr(ds,
                                              self.imaging_weight_column).data
                    if len(imaging_weights.shape) < 3:
                        imaging_weights = da.broadcast_to(
                            imaging_weights[:, None, :],
                            flag.shape,
                            chunks=flag.chunks)

                    weightsxx = imaging_weights[:, :, 0] * weights[:, :, 0]
                    weightsyy = imaging_weights[:, :, -1] * weights[:, :, -1]
                else:
                    weightsxx = weights[:, :, 0]
                    weightsyy = weights[:, :, -1]

                # for the PSF we need to scale the weights by the
                # Mueller amplitudes squared
                if self.mueller_column is not None:
                    mueller = getattr(ds, self.mueller_column).data
                    weightsxx *= da.absolute(mueller[:, :, 0])**2
                    weightsyy *= da.absolute(mueller[:, :, -1])**2

                # weighted sum corr to Stokes I
                weights = weightsxx + weightsyy

                # only keep data where both corrs are unflagged
                flagxx = flag[:, :, 0]
                flagyy = flag[:, :, -1]
                flag = ~(flagxx | flagyy)  # ducc0 convention

                weights *= flag

                data = weights.astype(np.complex64)

                psf = vis2im(uvw,
                             freq,
                             data,
                             freq_bin_idx,
                             freq_bin_counts,
                             self.nx_psf,
                             self.ny_psf,
                             self.cell,
                             flag=flag.astype(np.uint8),
                             nthreads=self.nthreads,
                             epsilon=self.epsilon,
                             do_wstacking=self.do_wstacking,
                             double_accum=True)

                psfs.append(psf)

                # assumes that stokes weights and uvw fit into memory
                # self.stokes_weights[ims][spw] = dask.persist(weights.rechunk({0:-1}))[0]
                # self.uvws[ims][spw] = dask.persist(uvw.rechunk({0:-1}))[0]

                # for comparison with numpy implementation
                # self.stokes_weights[ims][spw] = dask.compute(weights)[0]
                # self.uvws[ims][spw] = dask.compute(uvw)[0]

        # import pdb
        # pdb.set_trace()

        psfs = dask.compute(psfs, scheduler='single-threaded')[0]
        return accumulate_dirty(psfs, self.nband,
                                self.band_mapping).astype(self.real_type)
Example #28
def main(args):
    """
    Flags outliers in data given a model and rescale weights so that whitened residuals have a
    mean amplitude of sqrt(2). 
    
    Flags and weights are computed per chunk of data
    """
    radec_ref = None
    writes = []
    for ims in args.ms:
        xds = xds_from_ms(ims,
                          group_cols=('FIELD_ID', 'DATA_DESC_ID'),
                          chunks={
                              "row": args.row_chunks,
                              "chan": args.chan_chunks
                          },
                          columns=('UVW', args.data_column, args.weight_column,
                                   args.model_column, args.flag_column,
                                   'FLAG_ROW'))

        # subtables
        ddids = xds_from_table(ims + "::DATA_DESCRIPTION")
        fields = xds_from_table(ims + "::FIELD", group_cols="__row__")
        spws = xds_from_table(ims + "::SPECTRAL_WINDOW", group_cols="__row__")
        pols = xds_from_table(ims + "::POLARIZATION", group_cols="__row__")

        # subtable data
        ddids = dask.compute(ddids)[0]
        fields = dask.compute(fields)[0]
        spws = dask.compute(spws)[0]
        pols = dask.compute(pols)[0]

        out_data = []
        for ds in xds:
            field = fields[ds.FIELD_ID]
            radec = field.PHASE_DIR.data.squeeze()

            # check fields match
            if radec_ref is None:
                radec_ref = radec

            if not np.array_equal(radec, radec_ref):
                continue

            # load in data and compute whitened residuals
            data = getattr(ds, args.data_column).data
            model = getattr(ds, args.model_column).data
            flag = getattr(ds, args.flag_column).data
            flag = da.logical_or(flag, ds.FLAG_ROW.data[:, None, None])
            weights = getattr(ds, args.weight_column).data
            if len(weights.shape) < 3:
                weights = da.broadcast_to(weights[:, None, :],
                                          data.shape,
                                          chunks=data.chunks)

            if args.trim_channels:
                flag = trim_chans(flag, args.trim_channels)

            # Stokes I vis
            weights = (~flag) * weights
            resid_vis = (data - model) * weights
            wsums = (weights[:, :, 0] + weights[:, :, -1])
            resid_vis_I = da.where(
                wsums, (resid_vis[:, :, 0] + resid_vis[:, :, -1]) / wsums,
                0.0j)

            # whiten and take abs
            white_resid = resid_vis_I * da.sqrt(wsums)
            abs_resid_vis_I = da.absolute(white_resid)

            # mean amp
            sum_amp = da.sum(abs_resid_vis_I)
            count = da.sum(wsums > 0)
            mean_amp = sum_amp / count

            flag_legacy = flag[:, :, 0] | flag[:, :, -1]
            flag_I = da.logical_or(abs_resid_vis_I > args.sigma_cut * mean_amp,
                                   flag_legacy)

            # new flags
            updated_flag = da.broadcast_to(flag_I[:, :, None],
                                           flag.shape,
                                           chunks=flag.chunks)

            # scale weights (whitened residuals should have mean amplitude of 1/sqrt(2))
            if args.scale_weights:
                # recompute mean amp with new flags
                weights = (~updated_flag) * weights
                resid_vis = (data - model) * weights
                wsums = (weights[:, :, 0] + weights[:, :, -1])
                resid_vis_I = da.where(
                    wsums, (resid_vis[:, :, 0] + resid_vis[:, :, -1]) / wsums,
                    0.0j)
                white_resid = resid_vis_I * da.sqrt(wsums)
                abs_resid_vis_I = da.absolute(white_resid)
                sum_amp = da.sum(abs_resid_vis_I)
                count = da.sum(wsums > 0)
                mean_amp = sum_amp / count
                updated_weight = 2**0.5 * weights / mean_amp**2
            else:
                updated_weight = weights

            ds = ds.assign(**{
                args.weight_out_column: (("row", "chan", "corr"),
                                         updated_weight)
            })
            ds = ds.assign(**{
                args.flag_out_column: (("row", "chan", "corr"), updated_flag)
            })

            out_data.append(ds)
        writes.append(
            xds_to_table(
                out_data,
                ims,
                columns=[args.flag_out_column, args.weight_out_column]))

    with ProgressBar():
        dask.compute(writes)

    # report new mean amp
    if args.report_means:
        radec_ref = None
        mean_amps = []
        for ims in args.ms:
            xds = xds_from_ms(
                ims,
                group_cols=('FIELD_ID', 'DATA_DESC_ID'),
                chunks={
                    "row": args.row_chunks,
                    "chan": args.chan_chunks
                },
                columns=('UVW', args.data_column, args.weight_out_column,
                         args.model_column, args.flag_out_column, 'FLAG_ROW'))

            # subtables
            ddids = xds_from_table(ims + "::DATA_DESCRIPTION")
            fields = xds_from_table(ims + "::FIELD", group_cols="__row__")
            spws = xds_from_table(ims + "::SPECTRAL_WINDOW",
                                  group_cols="__row__")
            pols = xds_from_table(ims + "::POLARIZATION", group_cols="__row__")

            # subtable data
            ddids = dask.compute(ddids)[0]
            fields = dask.compute(fields)[0]
            spws = dask.compute(spws)[0]
            pols = dask.compute(pols)[0]

            for ds in xds:
                field = fields[ds.FIELD_ID]
                radec = field.PHASE_DIR.data.squeeze()

                # check fields match
                if radec_ref is None:
                    radec_ref = radec

                if not np.array_equal(radec, radec_ref):
                    continue

                # load in data and compute whitened residuals
                data = getattr(ds, args.data_column).data
                model = getattr(ds, args.model_column).data
                flag = getattr(ds, args.flag_out_column).data
                flag = da.logical_or(flag, ds.FLAG_ROW.data[:, None, None])
                weights = getattr(ds, args.weight_out_column).data
                if len(weights.shape) < 3:
                    weights = da.broadcast_to(weights[:, None, :],
                                              data.shape,
                                              chunks=data.chunks)

                # Stokes I vis
                weights = (~flag) * weights
                resid_vis = (data - model) * weights
                wsums = (weights[:, :, 0] + weights[:, :, -1])
                resid_vis_I = da.where(
                    wsums, (resid_vis[:, :, 0] + resid_vis[:, :, -1]) / wsums,
                    0.0j)

                # whiten and take abs
                white_resid = resid_vis_I * da.sqrt(wsums)
                abs_resid_vis_I = da.absolute(white_resid)

                # mean amp
                sum_amp = da.sum(abs_resid_vis_I)
                count = da.sum(wsums > 0)
                mean_amps.append(sum_amp / count)

        mean_amps = dask.compute(mean_amps)[0]

        print(mean_amps)
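
A toy numpy illustration of the clipping criterion used above (synthetic values, two correlations, inverse-variance weights): form the weighted Stokes-I residual, whiten it by the square root of the summed weights, and flag samples whose amplitude exceeds sigma_cut times the mean amplitude.

import numpy as np

rng = np.random.default_rng(42)
nrow, nchan, sigma = 500, 16, 0.1
resid = (rng.normal(scale=sigma, size=(nrow, nchan, 2))
         + 1j * rng.normal(scale=sigma, size=(nrow, nchan, 2)))
weights = np.full((nrow, nchan, 2), 1.0 / sigma**2)

wsums = weights[:, :, 0] + weights[:, :, -1]
resid_I = (weights[:, :, 0] * resid[:, :, 0]
           + weights[:, :, -1] * resid[:, :, -1]) / wsums
abs_white = np.abs(resid_I * np.sqrt(wsums))

sigma_cut = 5.0
flag_I = abs_white > sigma_cut * abs_white.mean()
print(flag_I.mean())  # fraction flagged; essentially zero for pure Gaussian noise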
Example #29
def _main(args):
    tic = time.time()

    log.info(banner())

    if args.disable_post_mortem:
        log.warning("Disabling crash debugging with the "
                    "Interactive Python Debugger, as per user request")
        post_mortem_handler.disable_pdb_on_error()

    log.info("Flagging on the {0:s} column".format(args.data_column))
    data_column = args.data_column
    masked_channels = [
        load_mask(fn, dilate=args.dilate_masks) for fn in collect_masks()
    ]
    GD = args.config

    log_configuration(args)

    # Group datasets by these columns
    group_cols = ["FIELD_ID", "DATA_DESC_ID", "SCAN_NUMBER"]
    # Index datasets by these columns
    index_cols = ['TIME']

    # Reopen the datasets using the aggregated row ordering
    columns = [data_column, "FLAG", "TIME", "ANTENNA1", "ANTENNA2"]

    if args.subtract_model_column is not None:
        columns.append(args.subtract_model_column)

    xds = list(
        xds_from_ms(args.ms,
                    columns=tuple(columns),
                    group_cols=group_cols,
                    index_cols=index_cols,
                    chunks={"row": args.row_chunks}))

    # Get support tables
    st = support_tables(args.ms)
    ddid_ds = st["DATA_DESCRIPTION"]
    field_ds = st["FIELD"]
    pol_ds = st["POLARIZATION"]
    spw_ds = st["SPECTRAL_WINDOW"]
    ant_ds = st["ANTENNA"]

    assert len(ant_ds) == 1
    assert len(ddid_ds) == 1

    antspos = ant_ds[0].POSITION.data
    antsnames = ant_ds[0].NAME.data
    fieldnames = [fds.NAME.data[0] for fds in field_ds]

    avail_scans = [ds.SCAN_NUMBER for ds in xds]
    args.scan_numbers = list(
        set(avail_scans).intersection(args.scan_numbers if args.scan_numbers
                                      is not None else avail_scans))

    if args.scan_numbers != []:
        log.info("Only considering scans '{0:s}' as "
                 "per user selection criterion".format(", ".join(
                     map(str, map(int, args.scan_numbers)))))

    if args.field_names != []:
        flatten_field_names = []
        for f in args.field_names:
            # accept comma lists per specification
            flatten_field_names += [x.strip() for x in f.split(",")]
        for f in flatten_field_names:
            if re.match(r"^\d+$", f) and int(f) < len(fieldnames):
                flatten_field_names.append(fieldnames[int(f)])
        flatten_field_names = list(
            set(
                filter(lambda x: not re.match(r"^\d+$", x),
                       flatten_field_names)))
        log.info("Only considering fields '{0:s}' for flagging per "
                 "user "
                 "selection criterion.".format(", ".join(flatten_field_names)))
        if not set(flatten_field_names) <= set(fieldnames):
            raise ValueError("One or more fields cannot be "
                             "found in dataset '{0:s}' "
                             "You specified {1:s}, but "
                             "only {2:s} are available".format(
                                 args.ms, ",".join(flatten_field_names),
                                 ",".join(fieldnames)))

        field_dict = {fieldnames.index(fn): fn for fn in flatten_field_names}
    else:
        field_dict = {i: fn for i, fn in enumerate(fieldnames)}

    # List which hold our dask compute graphs for each dataset
    write_computes = []
    original_stats = []
    final_stats = []

    # Iterate through each dataset
    for ds in xds:
        if ds.FIELD_ID not in field_dict:
            continue

        if (args.scan_numbers is not None
                and ds.SCAN_NUMBER not in args.scan_numbers):
            continue

        log.info("Adding field '{0:s}' scan {1:d} to "
                 "compute graph for processing".format(field_dict[ds.FIELD_ID],
                                                       ds.SCAN_NUMBER))

        ddid = ddid_ds[ds.attrs['DATA_DESC_ID']]
        spw_info = spw_ds[ddid.SPECTRAL_WINDOW_ID.data[0]]
        pol_info = pol_ds[ddid.POLARIZATION_ID.data[0]]

        nrow, nchan, ncorr = getattr(ds, data_column).data.shape

        # Visibilities from the dataset
        vis = getattr(ds, data_column).data
        if args.subtract_model_column is not None:
            log.info("Forming residual data between '{0:s}' and "
                     "'{1:s}' for flagging.".format(
                         data_column, args.subtract_model_column))
            vismod = getattr(ds, args.subtract_model_column).data
            vis = vis - vismod

        antenna1 = ds.ANTENNA1.data
        antenna2 = ds.ANTENNA2.data
        chan_freq = spw_info.CHAN_FREQ.data[0]
        chan_width = spw_info.CHAN_WIDTH.data[0]

        # Generate unflagged defaults if we should ignore existing flags
        # otherwise take flags from the dataset
        if args.ignore_flags:
            flags = da.full_like(vis, False, dtype=bool)
            log.critical("Completely ignoring measurement set "
                         "flags as per '-if' request. "
                         "Strategy flags WILL NOT be OR'ed with the original "
                         "flags, even if specified!")
        else:
            flags = ds.FLAG.data

        # If we're flagging on polarised intensity,
        # we convert visibilities to polarised intensity
        # and any flagged correlation will flag the entire visibility
        if args.flagging_strategy == "polarisation":
            corr_type = pol_info.CORR_TYPE.data[0].tolist()
            stokes_map = stokes_corr_map(corr_type)
            stokes_pol = tuple(v for k, v in stokes_map.items() if k != "I")
            vis = polarised_intensity(vis, stokes_pol)
            flags = da.any(flags, axis=2, keepdims=True)
        elif args.flagging_strategy == "total_power":
            if args.subtract_model_column is None:
                log.critical("You requested to flag total quadrature "
                             "power, but not on residuals. "
                             "This is not advisable and the flagger "
                             "may mistake fringes of "
                             "off-axis sources for broadband RFI.")
            corr_type = pol_info.CORR_TYPE.data[0].tolist()
            stokes_map = stokes_corr_map(corr_type)
            stokes_pol = tuple(v for k, v in stokes_map.items())
            vis = polarised_intensity(vis, stokes_pol)
            flags = da.any(flags, axis=2, keepdims=True)
        elif args.flagging_strategy == "standard":
            if args.subtract_model_column is None:
                log.critical("You requested to flag per correlation, "
                             "but not on residuals. "
                             "This is not advisable and the flagger "
                             "may mistake fringes of off-axis sources "
                             "for broadband RFI.")
        else:
            raise ValueError("Invalid flagging strategy '%s'" %
                             args.flagging_strategy)

        ubl = unique_baselines(antenna1, antenna2)
        utime, time_inv = da.unique(ds.TIME.data, return_inverse=True)
        utime, ubl = dask.compute(utime, ubl)
        ubl = ubl.view(np.int32).reshape(-1, 2)
        # Stack the baseline index with the unique baselines
        bl_range = np.arange(ubl.shape[0], dtype=ubl.dtype)[:, None]
        ubl = np.concatenate([bl_range, ubl], axis=1)
        ubl = da.from_array(ubl, chunks=(args.baseline_chunks, 3))

        vis_windows, flag_windows = pack_data(time_inv,
                                              ubl,
                                              antenna1,
                                              antenna2,
                                              vis,
                                              flags,
                                              utime.shape[0],
                                              backend=args.window_backend,
                                              path=args.temporary_directory)

        original_stats.append(
            window_stats(flag_windows, ubl, chan_freq, antsnames,
                         ds.SCAN_NUMBER, field_dict[ds.FIELD_ID],
                         ds.attrs['DATA_DESC_ID']))

        with StrategyExecutor(antspos, ubl, chan_freq, chan_width,
                              masked_channels, GD['strategies']) as se:

            flag_windows = se.apply_strategies(flag_windows, vis_windows)

        final_stats.append(
            window_stats(flag_windows, ubl, chan_freq, antsnames,
                         ds.SCAN_NUMBER, field_dict[ds.FIELD_ID],
                         ds.attrs['DATA_DESC_ID']))

        # Unpack window data for writing back to the MS
        unpacked_flags = unpack_data(antenna1, antenna2, time_inv, ubl,
                                     flag_windows)

        # Flag entire visibility if any correlations are flagged
        equalized_flags = da.sum(unpacked_flags, axis=2, keepdims=True) > 0
        corr_flags = da.broadcast_to(equalized_flags, (nrow, nchan, ncorr))

        if corr_flags.chunks != ds.FLAG.data.chunks:
            raise ValueError("Output flag chunking does not "
                             "match input flag chunking")

        # Create new dataset containing new flags
        new_ds = ds.assign(FLAG=(("row", "chan", "corr"), corr_flags))

        # Write back to original dataset
        writes = xds_to_table(new_ds, args.ms, "FLAG")
        # deferred; computed below together with the original and final flag stats
        write_computes.append(writes)

    if len(write_computes) > 0:
        # Combine stats from all datasets
        original_stats = combine_window_stats(original_stats)
        final_stats = combine_window_stats(final_stats)

        with contextlib.ExitStack() as stack:
            # Create dask profiling contexts
            profilers = []

            if can_profile:
                profilers.append(stack.enter_context(Profiler()))
                profilers.append(stack.enter_context(CacheProfiler()))
                profilers.append(stack.enter_context(ResourceProfiler()))

            if sys.stdout.isatty():
                # Interactive terminal, default ProgressBar
                stack.enter_context(ProgressBar())
            else:
                # Non-interactive, emit a bar every 5 minutes so
                # as not to spam the log
                stack.enter_context(ProgressBar(minimum=1, dt=5 * 60))

            _, original_stats, final_stats = dask.compute(
                write_computes, original_stats, final_stats)
        if can_profile:
            visualize(profilers)

        toc = time.time()

        # Log each summary line
        for line in summarise_stats(final_stats, original_stats):
            log.info(line)

        elapsed = toc - tic
        log.info("Data flagged successfully in "
                 "{0:02.0f}h{1:02.0f}m{2:02.0f}s".format((elapsed // 60) // 60,
                                                         (elapsed // 60) % 60,
                                                         elapsed % 60))
    else:
        log.info("User data selection criteria resulted in empty dataset. "
                 "Nothing to be done. Bye!")
Example #30
    def __call__(self, signal, out=None, axes=None):
        """Slice the signal according to the ROI, and return it.

        Arguments
        ---------
        signal : Signal
            The signal to slice with the ROI.
        out : Signal, default = None
            If the 'out' argument is supplied, the sliced output will be put
            into this instead of returning a Signal. See Signal.__getitem__()
            for more details on 'out'.
        axes : specification of axes to use, default = None
            The axes argument specifies which axes the ROI will be applied on.
            The items in the collection can be either of the following:
                * a tuple of:
                    - DataAxis. These will not be checked with
                      signal.axes_manager.
                    - anything that will index signal.axes_manager
                * For any other value, it will check whether the navigation
                  space can fit the right number of axes, and use that if it
                  fits. If not, it will try the signal space.
        """

        if axes is None and signal in self.signal_map:
            axes = self.signal_map[signal][1]
        else:
            axes = self._parse_axes(axes, signal.axes_manager)

        natax = signal.axes_manager._get_axes_in_natural_order()
        # Slice original data with a circumscribed rectangle
        cx = self.cx + 0.5001 * axes[0].scale
        cy = self.cy + 0.5001 * axes[1].scale
        ranges = [[cx - self.r, cx + self.r],
                  [cy - self.r, cy + self.r]]
        slices = self._make_slices(natax, axes, ranges)
        ir = [slices[natax.index(axes[0])],
              slices[natax.index(axes[1])]]
        vx = axes[0].axis[ir[0]] - cx
        vy = axes[1].axis[ir[1]] - cy
        gx, gy = np.meshgrid(vx, vy)
        gr = gx**2 + gy**2
        mask = gr > self.r**2
        if self.r_inner != t.Undefined:
            mask |= gr < self.r_inner**2
        tiles = []
        shape = []
        chunks = []
        for i in range(len(slices)):
            if signal._lazy:
                chunks.append(signal.data.chunks[i][0])
            if i == natax.index(axes[0]):
                thisshape = mask.shape[0]
                tiles.append(thisshape)
                shape.append(thisshape)
            elif i == natax.index(axes[1]):
                thisshape = mask.shape[1]
                tiles.append(thisshape)
                shape.append(thisshape)
            else:
                tiles.append(signal.axes_manager._axes[i].size)
                shape.append(1)
        mask = mask.reshape(shape)

        nav_axes = [ax.navigate for ax in axes]
        nav_dim = signal.axes_manager.navigation_dimension
        if True in nav_axes:
            if False in nav_axes:

                slicer = signal.inav[slices[:nav_dim]].isig.__getitem__
                slices = slices[nav_dim:]
            else:
                slicer = signal.inav.__getitem__
                slices = slices[0:nav_dim]
        else:
            slicer = signal.isig.__getitem__
            slices = slices[nav_dim:]

        roi = slicer(slices, out=out)
        roi = out or roi
        if roi._lazy:
            import dask.array as da
            mask = da.from_array(mask, chunks=chunks)
            mask = da.broadcast_to(mask, tiles)
            # By default promotes dtype to float if required
            roi.data = da.where(mask, np.nan, roi.data)
        else:
            mask = np.broadcast_to(mask, tiles)
            roi.data = np.ma.masked_array(roi.data, mask, hard_mask=True)
        if out is None:
            return roi
        else:
            out.events.data_changed.trigger(out)
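
For context, a hypothetical usage sketch, assuming the method above is HyperSpy's CircleROI.__call__ and that HyperSpy is installed (names and values are illustrative): applying the ROI to a lazy Signal2D returns the circumscribed rectangle with everything outside the annulus replaced by NaN.

import numpy as np
import hyperspy.api as hs

sig = hs.signals.Signal2D(np.random.random((64, 64))).as_lazy()
roi = hs.roi.CircleROI(cx=32.0, cy=32.0, r=10.0, r_inner=3.0)

sliced = roi(sig)     # lazy branch: data masked with NaN via da.where
sliced.compute()      # materialize the dask-backed data in place
print(np.isnan(sliced.data).any())  # True: pixels outside the annulus are NaN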
Example #32
def extract_rot_cube(cube, min_lat, min_lon, max_lat, max_lon):
    """
    Extracts a specific region from a cube on a rotated coordinate system.
    args
    ----
    cube: cube on rotated coord system, used as reference grid for transformation.
    min_lat: The minimum latitude point of the desired extracted cube.
    min_lon: The minimum longitude point of the desired extracted cube.
    max_lat: The maximum latitude point of the desired extracted cube.
    max_lon: The maximum longitude point of the desired extracted cube.
    Returns
    -------
    extracted_cube: the cube cut down to the requested latitude/longitude box.
    An example:
    >>> file = os.path.join(conf.DATA_DIR, 'rcm_monthly.pp')
    >>> cube = iris.load_cube(file, 'air_temperature')
    >>> min_lat = 50
    >>> min_lon = -10
    >>> max_lat = 60
    >>> max_lon = 0
    >>> extracted_cube = extract_rot_cube(cube, min_lat, min_lon, max_lat, max_lon)
    >>> max_lat_cube =  np.max(extracted_cube.coord('latitude').points)
    >>> print(f'{max_lat_cube:.3f}')
    61.365
    >>> min_lat_cube = np.min(extracted_cube.coord('latitude').points)
    >>> print(f'{min_lat_cube:.3f}')
    48.213
    >>> max_lon_cube = np.max(extracted_cube.coord('longitude').points)
    >>> print(f'{max_lon_cube:.3f}')
    3.643
    >>> min_lon_cube = np.min(extracted_cube.coord('longitude').points)
    >>> print(f'{min_lon_cube:.3f}')
    -16.292
    """

    # adding unrotated coords to the cube
    cube = add_aux_unrotated_coords(cube)

    # mask the cube using the true lat and lon
    lats = cube.coord("latitude").points
    lons = cube.coord("longitude").points
    select_lons = (lons >= min_lon) & (lons <= max_lon)
    select_lats = (lats >= min_lat) & (lats <= max_lat)
    selection = select_lats & select_lons
    selection = da.broadcast_to(selection, cube.shape)
    cube.data = da.ma.masked_where(~selection, cube.core_data())

    # grab a single 2D slice of X and Y and take the mask
    lon_coord = cube.coord(axis="X", dim_coords=True)
    lat_coord = cube.coord(axis="Y", dim_coords=True)
    for yx_slice in cube.slices(["grid_latitude", "grid_longitude"]):
        cmask = yx_slice.data.mask
        break

    # now cut the cube down along X and Y coords
    x1, x2, y1, y2 = _get_xy_noborder(cmask)
    idx = len(cube.shape) * [slice(None)]

    idx[cube.coord_dims(cube.coord(axis="x",
                                   dim_coords=True))[0]] = slice(x1, x2, 1)
    idx[cube.coord_dims(cube.coord(axis="y",
                                   dim_coords=True))[0]] = slice(y1, y2, 1)

    extracted_cube = cube[tuple(idx)]

    return extracted_cube