def broadcast_and_rechunk(array_tuple, chunks=None):
    dest_array_id = max_array_id(array_tuple)
    dest_shape = array_tuple[dest_array_id].shape
    if chunks is None:
        chunks = array_tuple[dest_array_id].chunks
    broadcast_list = [None] * len(array_tuple)
    for array_id, array in enumerate(array_tuple):
        try:
            arr_chunks = array.chunks
            if arr_chunks != chunks:
                broadcast_list[array_id] = (da.broadcast_to(
                    add_dims_to_broadcast(array, dest_shape),
                    dest_shape).rechunk(chunks))
            elif dest_shape != array.shape:
                broadcast_list[array_id] = (da.broadcast_to(
                    add_dims_to_broadcast(array, dest_shape), dest_shape))
            else:
                broadcast_list[array_id] = array
        except AttributeError:
            broadcast_list[array_id] = da.from_array(
                array,
                chunks=remove_chunks_from_shape(array, dest_shape, chunks))
            if dest_shape != array.shape:
                broadcast_list[array_id] = (da.broadcast_to(
                    add_dims_to_broadcast(broadcast_list[array_id],
                                          dest_shape),
                    dest_shape).rechunk(chunks))
    return broadcast_list
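# A minimal sketch of what the helper above automates, using only public
# dask.array calls; the `max_array_id`/`add_dims_to_broadcast` helpers are
# assumed to live elsewhere and are not reproduced here.
import dask.array as da
import numpy as np

small = da.from_array(np.arange(4), chunks=2)   # shape (4,)
big = da.zeros((3, 4), chunks=(1, 2))           # target shape and chunks

# Broadcast the small array against the big one's shape, then align chunks
# so the two can be combined block-wise without rechunking at compute time.
aligned = da.broadcast_to(small[np.newaxis, :], big.shape).rechunk(big.chunks)
assert aligned.chunks == big.chunks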
def _get_zhai_data_frame(datasets, lat_constraint):
    """Get :class:`pandas.DataFrame` including the data for ``zhai``."""
    cl_cube = _get_cube(datasets, 'cl')
    wap_cube = _get_cube(datasets, 'wap')
    tos_cube = _get_cube(datasets, 'tos')

    # Add air_pressure coordinate if necessary
    if not cl_cube.coords('air_pressure'):
        if cl_cube.coords('altitude'):
            add_plev_from_altitude(cl_cube)
        elif cl_cube.coords('atmosphere_sigma_coordinate'):
            add_sigma_factory(cl_cube)
        else:
            raise ValueError(f"No 'air_pressure' coord available in cube "
                             f"{cl_cube.summary(shorten=True)}")

    # Apply common mask (only ocean)
    mask_2d = da.ma.getmaskarray(tos_cube.core_data())
    mask_3d = mask_2d[:, np.newaxis, ...]
    mask_3d = da.broadcast_to(mask_3d, cl_cube.shape)
    wap_cube.data = da.ma.masked_array(wap_cube.core_data(), mask=mask_2d)
    cl_cube.data = da.ma.masked_array(cl_cube.core_data(), mask=mask_3d)

    # Calculate SST mean and MBLC fraction
    tos_cube = _get_mean_over_subsidence(tos_cube, wap_cube, lat_constraint)
    mblc_cube = _get_seasonal_mblc_fraction(cl_cube, wap_cube, lat_constraint)
    return pd.DataFrame(
        {
            'tos': tos_cube.data,
            'mblc_fraction': mblc_cube.data,
        },
        index=pd.Index(np.arange(12) + 1, name='month'),
    )
def compute(fieldset):
    # Calculating vertical weighted average
    for f in [fieldset.U, fieldset.V]:
        for tind in f.loaded_time_indices:
            data = da.sum(f.data[tind, :] * DZ, axis=0) / sum(dz)
            data = da.broadcast_to(
                data, (1, f.grid.zdim, f.grid.ydim, f.grid.xdim))
            f.data = f.data_concatenate(f.data, data, tind)
def tsLab(series, seg, label):
    segnp = sitk.GetArrayFromImage(seg)
    # The mask indicates which entries to ignore.
    m3 = da.broadcast_to(segnp != label, series.shape)
    smk = da.ma.masked_array(series, m3)
    p = smk.mean(axis=[2, 3])
    l = p.compute()
    return l
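# A small self-contained illustration of the masking pattern above, using a
# synthetic 4-D series (time, z, y, x) and an integer label image; the real
# code obtains `segnp` from SimpleITK, which is assumed here.
import dask.array as da
import numpy as np

series = da.random.random((5, 2, 8, 8), chunks=(1, 2, 8, 8))
labels = np.zeros((2, 8, 8), dtype=np.uint8)
labels[:, 2:6, 2:6] = 1

# True where the voxel should be ignored (i.e. not part of label 1).
mask = da.broadcast_to(labels != 1, series.shape)
masked = da.ma.masked_array(series, mask)
per_time_mean = masked.mean(axis=(2, 3)).compute()   # shape (5, 2)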
def _mask_cube(cube, selections):
    cubelist = iris.cube.CubeList()
    for id_, select in selections.items():
        _cube = cube.copy()
        _cube.add_aux_coord(
            iris.coords.AuxCoord(id_, units='no_unit', long_name="shape_id"))
        select = da.broadcast_to(select, _cube.shape)
        _cube.data = da.ma.masked_where(~select, _cube.core_data())
        cubelist.append(_cube)
    return fix_coordinate_ordering(cubelist.merge_cube())
def add_cell_measure(cube, fx_cube, measure):
    """Broadcast fx_cube and add it as a cell_measure in the cube containing
    the data.

    Parameters
    ----------
    cube: iris.cube.Cube
        Iris cube with input data.
    fx_cube: iris.cube.Cube
        Iris cube with fx data.
    measure: str
        Name of the measure, can be 'area' or 'volume'.

    Returns
    -------
    iris.cube.Cube
        Cube with added cell measure.

    Raises
    ------
    ValueError
        If measure name is not 'area' or 'volume'.
    ValueError
        If fx_cube cannot be broadcast to cube.
    """
    if measure not in ['area', 'volume']:
        raise ValueError(f"measure name must be 'area' or 'volume', "
                         f"got {measure} instead")
    try:
        fx_data = da.broadcast_to(fx_cube.core_data(), cube.shape)
    except ValueError as exc:
        raise ValueError(f"Dimensions of {cube.var_name} and "
                         f"{fx_cube.var_name} cubes do not match. "
                         "Cannot broadcast cubes.") from exc
    measure = iris.coords.CellMeasure(fx_data,
                                      standard_name=fx_cube.standard_name,
                                      units=fx_cube.units,
                                      measure=measure,
                                      var_name=fx_cube.var_name,
                                      attributes=fx_cube.attributes)
    cube.add_cell_measure(measure, range(0, measure.ndim))
    logger.debug('Added %s as cell measure in cube of %s.', fx_cube.var_name,
                 cube.var_name)
def diag_dot(diag_array, x, return_diag=False):
    """Computes dot product between diag_array and x

    Parameters
    ----------
    diag_array : array_like, shape (K, ) or (K, 1)
        Diagonal entries of a diagonal matrix
        [d1, d2, d3, ..., dk] -> [d1*e1, d2*e2, d3*e3, ..., dk*ek],
        where ei is the ith unit column vector
    x : array_like, shape (K, ...)
    return_diag : boolean
        If return_diag is True, return broadcasted array prepped for
        operation

    Returns
    -------
    out : array_like, shape of x
    """
    if len(x.shape) not in [1, 2]:
        raise ValueError(
            "x must have (M, K) or (K, ). Current Shape = {}".format(x.shape))
    if diag_array.shape[0] != x.shape[0]:
        raise ValueError(
            'shapes {} and {} not aligned: {} (dim 0 and 1) != {} (dim 0)'.
            format(diag_array.shape, x.shape, diag_array.shape[0],
                   x.shape[0]))
    if len(diag_array.shape) not in [1, 2]:
        raise ValueError(
            'diag_array must have dimension (K, ) or (K, 1). Current shape = {}'
            .format(diag_array.shape))
    if len(x.shape) == 1:
        if len(diag_array.shape) == 2:
            d = np.squeeze(diag_array)
        else:
            d = diag_array
    else:
        if len(diag_array.shape) == 1:
            d = diag_array[:, np.newaxis]
        else:
            d = diag_array
    d = broadcast_to(d, x.shape)
    if return_diag:
        return d
    else:
        return np.multiply(d, x)
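# Quick check of the intended semantics, assuming the function above is in
# scope together with numpy (and numpy's broadcast_to as the bare name it
# uses): multiplying by the broadcast diagonal entries is equivalent to a
# dense diagonal matrix product, without materialising the K x K matrix.
import numpy as np

d = np.array([1.0, 2.0, 3.0])
x = np.arange(6.0).reshape(3, 2)
assert np.allclose(diag_dot(d, x), np.diag(d) @ x)
assert diag_dot(d, x, return_diag=True).shape == x.shape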
def coriolis(self, lats, ndim):
    """Compute Coriolis force.

    Parameters
    ----------
    lats: iris.coord.Coord
        Latitude coordinate.
    ndim: tuple
        Shape to which the Coriolis parameter is broadcast
        (e.g. the shape of the data cube).

    Returns
    -------
    fcor: da.array
        Array containing Coriolis force.
    """
    fcor = 2.0 * self.omega * np.sin(np.radians(lats.points))
    fcor = fcor[np.newaxis, np.newaxis, :, np.newaxis]
    fcor = da.broadcast_to(fcor, ndim)
    return fcor
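# Standalone sketch of the same broadcast, assuming a plain latitude vector
# instead of an iris coordinate and Earth's rotation rate for omega:
import dask.array as da
import numpy as np

omega = 7.292e-5                      # rad / s
lat = np.array([-60.0, 0.0, 60.0])    # degrees
f = 2.0 * omega * np.sin(np.radians(lat))
# (time, plev, lat, lon) layout, matching the indexing used above
f4d = da.broadcast_to(f[np.newaxis, np.newaxis, :, np.newaxis], (2, 3, 3, 4))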
def get_time_weights(cube):
    """Compute the weighting of the time axis.

    Parameters
    ----------
    cube: iris.cube.Cube
        input cube.

    Returns
    -------
    numpy.array
        Array of time weights for averaging.
    """
    time = cube.coord('time')
    time_weights = time.bounds[..., 1] - time.bounds[..., 0]
    time_weights = time_weights.squeeze()
    if time_weights.shape == ():
        time_weights = da.broadcast_to(time_weights, cube.shape)
    else:
        time_weights = iris.util.broadcast_to_shape(time_weights, cube.shape,
                                                    cube.coord_dims('time'))
    return time_weights
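# The same weighting computed from raw bounds, assuming monthly bounds in
# "days since ..." units; each weight is simply the interval length.
import dask.array as da
import numpy as np

bounds = np.array([[0., 31.], [31., 59.], [59., 90.]])   # 3 time steps
weights = bounds[..., 1] - bounds[..., 0]                # [31., 28., 31.]
# Broadcast against a hypothetical (time, lat, lon) data shape so the
# weights can be used directly in a weighted average.
weights3d = da.broadcast_to(weights[:, np.newaxis, np.newaxis], (3, 4, 5))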
def apply(X: Array, YP: Array, BX: Array, BYP: Array) -> Array:
    # Collapse selected variant blocks and alphas into single
    # new covariate dimension
    assert YP.shape[2] == BYP.shape[2]
    n_group_covar = n_covar + BYP.shape[2] * n_alpha_1

    BYP = BYP.reshape((n_outcome, n_sample_block, -1))
    BG = da.concatenate((BX, BYP), axis=-1)
    BG = BG.rechunk((-1, None, -1))
    assert_block_shape(BG, 1, n_sample_block, 1)
    assert_chunk_shape(BG, n_outcome, 1, n_group_covar)
    assert_array_shape(BG, n_outcome, n_sample_block, n_group_covar)

    YP = YP.reshape((n_outcome, n_sample, -1))
    XYP = da.broadcast_to(X, (n_outcome, n_sample, n_covar))
    XG = da.concatenate((XYP, YP), axis=-1)
    XG = XG.rechunk((-1, None, -1))
    assert_block_shape(XG, 1, n_sample_block, 1)
    assert_chunk_shape(XG, n_outcome, sample_chunks[0], n_group_covar)
    assert_array_shape(XG, n_outcome, n_sample, n_group_covar)

    YG = da.map_blocks(
        # Block chunks:
        # (n_outcome, sample_chunks[0], n_group_covar) @
        # (n_outcome, n_group_covar, 1) [after transpose]
        lambda x, b: x @ b.transpose((0, 2, 1)),
        XG,
        BG,
        chunks=(n_outcome, sample_chunks, 1),
    )
    assert_block_shape(YG, 1, n_sample_block, 1)
    assert_chunk_shape(YG, n_outcome, sample_chunks[0], 1)
    assert_array_shape(YG, n_outcome, n_sample, 1)
    YG = da.squeeze(YG, axis=-1).T
    assert_block_shape(YG, n_sample_block, 1)
    assert_chunk_shape(YG, sample_chunks[0], n_outcome)
    assert_array_shape(YG, n_sample, n_outcome)
    return YG
def push(array, n, axis):
    """Dask-aware bottleneck.push"""
    import bottleneck
    import dask.array as da
    import numpy as np

    def _fill_with_last_one(a, b):
        # cumreduction applies the push func over all the blocks first, so
        # the only missing part is filling the missing values using the
        # last data of the previous chunk
        return np.where(~np.isnan(b), b, a)

    if n is not None and 0 < n < array.shape[axis] - 1:
        arange = da.broadcast_to(
            da.arange(
                array.shape[axis], chunks=array.chunks[axis],
                dtype=array.dtype).reshape(
                    tuple(size if i == axis else 1
                          for i, size in enumerate(array.shape))),
            array.shape,
            array.chunks,
        )
        valid_arange = da.where(da.notnull(array), arange, np.nan)
        valid_limits = (arange - push(valid_arange, None, axis)) <= n
        # omit the forward fills that violate the limit
        return da.where(valid_limits, push(array, None, axis), np.nan)

    # The method parameter makes the tests for python 3.7 fail.
    return da.reductions.cumreduction(
        func=bottleneck.push,
        binop=_fill_with_last_one,
        ident=np.nan,
        x=array,
        axis=axis,
        dtype=array.dtype,
    )
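# Example of the limited forward fill above, assuming bottleneck is
# installed: gaps longer than `n` steps along `axis` are left as NaN.
import dask.array as da
import numpy as np

x = da.from_array(np.array([1.0, np.nan, np.nan, np.nan, 5.0, np.nan]),
                  chunks=3)
filled = push(x, n=1, axis=0).compute()
# -> [1., 1., nan, nan, 5., 5.]  (only one step is filled past each value)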
def add_ancillary_variable(cube, fx_cube):
    """Broadcast fx_cube and add it as an ancillary_variable in the cube
    containing the data.

    Parameters
    ----------
    cube: iris.cube.Cube
        Iris cube with input data.
    fx_cube: iris.cube.Cube
        Iris cube with fx data.

    Returns
    -------
    iris.cube.Cube
        Cube with added ancillary variables.

    Raises
    ------
    ValueError
        If fx_cube cannot be broadcast to cube.
    """
    try:
        fx_data = da.broadcast_to(fx_cube.core_data(), cube.shape)
    except ValueError as exc:
        raise ValueError(f"Dimensions of {cube.var_name} and "
                         f"{fx_cube.var_name} cubes do not match. "
                         "Cannot broadcast cubes.") from exc
    ancillary_var = iris.coords.AncillaryVariable(
        fx_data,
        standard_name=fx_cube.standard_name,
        units=fx_cube.units,
        var_name=fx_cube.var_name,
        attributes=fx_cube.attributes)
    cube.add_ancillary_variable(ancillary_var, range(0, ancillary_var.ndim))
    logger.debug('Added %s as ancillary variable in cube of %s.',
                 fx_cube.var_name, cube.var_name)
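# The broadcast-or-fail pattern used above, shown with plain dask arrays: a
# 2-D (lat, lon) field is expanded against a 3-D (time, lat, lon) data shape,
# and a shape mismatch surfaces as the ValueError that gets re-raised.
import dask.array as da

cell_area = da.ones((4, 5), chunks=(2, 5))   # e.g. an areacella-like field
data_shape = (12, 4, 5)                      # e.g. monthly data

try:
    area3d = da.broadcast_to(cell_area, data_shape)
except ValueError as exc:
    raise ValueError("Cannot broadcast cell area to the data shape.") from exc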
def _psf(**kw): args = OmegaConf.create(kw) from omegaconf import ListConfig if not isinstance(args.ms, list) and not isinstance(args.ms, ListConfig): args.ms = [args.ms] OmegaConf.set_struct(args, True) import numpy as np from pfb.utils.misc import chan_to_band_mapping import dask # from dask.distributed import performance_report from dask.graph_manipulation import clone from daskms import xds_from_storage_ms as xds_from_ms from daskms import xds_from_storage_table as xds_from_table from daskms import Dataset from daskms.experimental.zarr import xds_to_zarr import dask.array as da from africanus.constants import c as lightspeed from africanus.gridding.wgridder.dask import dirty as vis2im from ducc0.fft import good_size from pfb.utils.misc import stitch_images, plan_row_chunk from pfb.utils.fits import set_wcs, save_fits # chan <-> band mapping ms = args.ms nband = args.nband freqs, freq_bin_idx, freq_bin_counts, freq_out, band_mapping, chan_chunks = chan_to_band_mapping( ms, nband=nband) # gridder memory budget max_chan_chunk = 0 max_freq = 0 for ims in args.ms: for spw in freqs[ims]: counts = freq_bin_counts[ims][spw].compute() freq = freqs[ims][spw].compute() max_chan_chunk = np.maximum(max_chan_chunk, counts.max()) max_freq = np.maximum(max_freq, freq.max()) # assumes measurement sets have the same columns, # number of correlations etc. xds = xds_from_ms(args.ms[0]) ncorr = xds[0].dims['corr'] nrow = xds[0].dims['row'] # we still have to cater for complex valued data because we cast # the weights to complex but we not longer need to factor the # weight column into our memory budget data_bytes = getattr(xds[0], args.data_column).data.itemsize bytes_per_row = max_chan_chunk * ncorr * data_bytes memory_per_row = bytes_per_row # flags (uint8 or bool) memory_per_row += bytes_per_row / 8 # UVW memory_per_row += xds[0].UVW.data.itemsize * 3 # ANTENNA1/2 memory_per_row += xds[0].ANTENNA1.data.itemsize * 2 # TIME memory_per_row += xds[0].TIME.data.itemsize # data column is not actually read into memory just used to infer # dtype and chunking columns = (args.data_column, args.weight_column, args.flag_column, 'UVW', 'ANTENNA1', 'ANTENNA2', 'TIME') # flag row if 'FLAG_ROW' in xds[0]: columns += ('FLAG_ROW', ) memory_per_row += xds[0].FLAG_ROW.data.itemsize # imaging weights if args.imaging_weight_column is not None: columns += (args.imaging_weight_column, ) memory_per_row += bytes_per_row / 2 # Mueller term (complex valued) if args.mueller_column is not None: columns += (args.mueller_column, ) memory_per_row += bytes_per_row # get max uv coords over all fields uvw = [] u_max = 0.0 v_max = 0.0 for ims in args.ms: xds = xds_from_ms(ims, columns=('UVW'), chunks={'row': -1}) for ds in xds: uvw = ds.UVW.data u_max = da.maximum(u_max, abs(uvw[:, 0]).max()) v_max = da.maximum(v_max, abs(uvw[:, 1]).max()) uv_max = da.maximum(u_max, v_max) uv_max = uv_max.compute() del uvw # image size cell_N = 1.0 / (2 * uv_max * max_freq / lightspeed) if args.cell_size is not None: cell_size = args.cell_size cell_rad = cell_size * np.pi / 60 / 60 / 180 if cell_N / cell_rad < 1: raise ValueError( "Requested cell size too small. 
" "Super resolution factor = ", cell_N / cell_rad) print("Super resolution factor = %f" % (cell_N / cell_rad), file=log) else: cell_rad = cell_N / args.super_resolution_factor cell_size = cell_rad * 60 * 60 * 180 / np.pi print("Cell size set to %5.5e arcseconds" % cell_size, file=log) if args.nx is None: fov = args.field_of_view * 3600 npix = int(args.psf_oversize * fov / cell_size) if npix % 2: npix += 1 nx = npix ny = npix else: nx = args.nx ny = args.ny if args.ny is not None else nx print("PSF size set to (%i, %i, %i)" % (nband, nx, ny), file=log) # get approx image size # this is not a conservative estimate when multiple SPW's map to a single # imaging band pixel_bytes = np.dtype(args.output_type).itemsize band_size = nx * ny * pixel_bytes if args.host_address is None: # full image on single node row_chunk = plan_row_chunk(args.mem_limit / args.nworkers, band_size, nrow, memory_per_row, args.nthreads_per_worker) else: # single band per node row_chunk = plan_row_chunk(args.mem_limit, band_size, nrow, memory_per_row, args.nthreads_per_worker) if args.row_chunks is not None: row_chunk = int(args.row_chunks) if row_chunk == -1: row_chunk = nrow print( "nrows = %i, row chunks set to %i for a total of %i chunks per node" % (nrow, row_chunk, int(np.ceil(nrow / row_chunk))), file=log) chunks = {} for ims in args.ms: chunks[ims] = [] # xds_from_ms expects a list per ds for spw in freqs[ims]: chunks[ims].append({ 'row': row_chunk, 'chan': chan_chunks[ims][spw]['chan'] }) psfs = [] radec = None # assumes we are only imaging field 0 of first MS out_datasets = [] for ims in args.ms: xds = xds_from_ms(ims, chunks=chunks[ims], columns=columns) # subtables ddids = xds_from_table(ims + "::DATA_DESCRIPTION") fields = xds_from_table(ims + "::FIELD") spws = xds_from_table(ims + "::SPECTRAL_WINDOW") pols = xds_from_table(ims + "::POLARIZATION") # subtable data ddids = dask.compute(ddids)[0] fields = dask.compute(fields)[0] spws = dask.compute(spws)[0] pols = dask.compute(pols)[0] for ds in xds: field = fields[ds.FIELD_ID] # check fields match if radec is None: radec = field.PHASE_DIR.data.squeeze() if not np.array_equal(radec, field.PHASE_DIR.data.squeeze()): continue # this is not correct, need to use spw spw = ds.DATA_DESC_ID uvw = clone(ds.UVW.data) data_type = getattr(ds, args.data_column).data.dtype data_shape = getattr(ds, args.data_column).data.shape data_chunks = getattr(ds, args.data_column).data.chunks weights = getattr(ds, args.weight_column).data if len(weights.shape) < 3: weights = da.broadcast_to(weights[:, None, :], data_shape, chunks=data_chunks) if args.imaging_weight_column is not None: imaging_weights = getattr(ds, args.imaging_weight_column).data if len(imaging_weights.shape) < 3: imaging_weights = da.broadcast_to(imaging_weights[:, None, :], data_shape, chunks=data_chunks) weightsxx = imaging_weights[:, :, 0] * weights[:, :, 0] weightsyy = imaging_weights[:, :, -1] * weights[:, :, -1] else: weightsxx = weights[:, :, 0] weightsyy = weights[:, :, -1] # apply mueller term if args.mueller_column is not None: mueller = getattr(ds, args.mueller_column).data weightsxx *= da.absolute(mueller[:, :, 0])**2 weightsyy *= da.absolute(mueller[:, :, -1])**2 # weighted sum corr to Stokes I weights = weightsxx + weightsyy # MS may contain auto-correlations if 'FLAG_ROW' in xds[0]: frow = ds.FLAG_ROW.data | (ds.ANTENNA1.data == ds.ANTENNA2.data) else: frow = (ds.ANTENNA1.data == ds.ANTENNA2.data) # only keep data where both corrs are unflagged flag = getattr(ds, args.flag_column).data flagxx = flag[:, 
:, 0] flagyy = flag[:, :, -1] # ducc0 uses uint8 mask not flag mask = ~da.logical_or((flagxx | flagyy), frow[:, None]) psf = vis2im(uvw, freqs[ims][spw], weights.astype(data_type), freq_bin_idx[ims][spw], freq_bin_counts[ims][spw], nx, ny, cell_rad, flag=mask.astype(np.uint8), nthreads=args.nvthreads, epsilon=args.epsilon, do_wstacking=args.wstack, double_accum=args.double_accum) psfs.append(psf) data_vars = { 'FIELD_ID': (('row', ), da.full_like(ds.TIME.data, ds.FIELD_ID, chunks=args.row_out_chunk)), 'DATA_DESC_ID': (('row', ), da.full_like(ds.TIME.data, ds.DATA_DESC_ID, chunks=args.row_out_chunk)), 'WEIGHT': (('row', 'chan'), weights.rechunk({0: args.row_out_chunk })), # why no 'f4'? 'UVW': (('row', 'uvw'), uvw.rechunk({0: args.row_out_chunk})) } coords = {'chan': (('chan', ), freqs[ims][spw])} out_ds = Dataset(data_vars, coords) out_datasets.append(out_ds) writes = xds_to_zarr(out_datasets, args.output_filename + '.zarr', columns='ALL') # dask.visualize(writes, filename=args.output_filename + '_psf_writes_graph.pdf', optimize_graph=False) # dask.visualize(psfs, filename=args.output_filename + '_psf_graph.pdf', optimize_graph=False) if not args.mock: # psfs = dask.compute(psfs, writes, optimize_graph=False)[0] # with performance_report(filename=args.output_filename + '_psf_per.html'): psfs = dask.compute(psfs, writes, optimize_graph=False)[0] psf = stitch_images(psfs, nband, band_mapping) hdr = set_wcs(cell_size / 3600, cell_size / 3600, nx, ny, radec, freq_out) save_fits(args.output_filename + '_psf.fits', psf, hdr, dtype=args.output_type) psf_mfs = np.sum(psf, axis=0) wsum = psf_mfs.max() psf_mfs /= wsum hdr_mfs = set_wcs(cell_size / 3600, cell_size / 3600, nx, ny, radec, np.mean(freq_out)) save_fits(args.output_filename + '_psf_mfs.fits', psf_mfs, hdr_mfs, dtype=args.output_type) print("All done here.", file=log)
def get_plot_data(msinfo, group_cols, mytaql, chan_freqs, chanslice, subset, noflags, noconj, iter_field, iter_spw, iter_scan, join_corrs=False, row_chunk_size=100000): ms_cols = {'ANTENNA1', 'ANTENNA2'} if not noflags: ms_cols.update({'FLAG', 'FLAG_ROW'}) # get visibility columns for axis in DataAxis.all_axes.values(): ms_cols.update(axis.columns) # get MS data msdata = daskms.xds_from_ms(msinfo.msname, columns=list(ms_cols), group_cols=group_cols, taql_where=mytaql, chunks=dict(row=row_chunk_size)) log.info(f': Indexing MS and building dataframes (chunk size is {row_chunk_size})') np = 0 # number of points to plot # output dataframes, indexed by (field, spw, scan, antenna, correlation) # If any of these axes is not being iterated over, then the index is None output_dataframes = OrderedDict() # # make prototype dataframe # import pandas # # # iterate over groups for group in msdata: ddid = group.DATA_DESC_ID # always present fld = group.FIELD_ID # always present if fld not in subset.field or ddid not in subset.spw: log.debug(f"field {fld} ddid {ddid} not in selection, skipping") continue scan = getattr(group, 'SCAN_NUMBER', None) # will be present if iterating over scans # TODO: antenna iteration. None forces no iteration, for now antenna = None # always read flags -- easier that way flag = group.FLAG if not noflags else None flag_row = group.FLAG_ROW if not noflags else None baselines = group.ANTENNA1*len(msinfo.antenna) + group.ANTENNA2 freqs = chan_freqs[ddid] chans = xarray.DataArray(range(len(freqs)), dims=("chan",)) wavel = freq_to_wavel(freqs) extras = dict(chans=chans, freqs=freqs, wavel=wavel, rows=group.row, baselines=baselines) nchan = len(group.chan) if flag is not None: flag = flag[dict(chan=chanslice)] nchan = flag.shape[1] shape = (len(group.row), nchan) datums = OrderedDict() for corr in subset.corr.numbers: # make dictionary of extra values for DataMappers extras['corr'] = corr # loop over datums to be computed for axis in DataAxis.all_axes.values(): value = datums[axis.label][-1] if axis.label in datums else None # a datum was already computed? if value is not None: # if not joining correlations, then that's the only one we'll need, so continue if not join_corrs: continue # joining correlations, and datum has a correlation dependence: compute another one if axis.corr is None: value = None if value is None: value = axis.get_value(group, corr, extras, flag=flag, flag_row=flag_row, chanslice=chanslice) # reshape values of shape NTIME to (NTIME,1) and NFREQ to (1,NFREQ), and scalar to (NTIME,1) if value.ndim == 1: timefreq_axis = axis.mapper.axis or 0 assert value.shape[0] == shape[timefreq_axis], \ f"{axis.mapper.fullname}: size {value.shape[0]}, expected {shape[timefreq_axis]}" shape1 = [1,1] shape1[timefreq_axis] = value.shape[0] value = value.reshape(shape1) if timefreq_axis > 0: value = da.broadcast_to(value, shape) log.debug(f"axis {axis.mapper.fullname} has shape {value.shape}") # else 2D value better match expected shape else: assert value.shape == shape, f"{axis.mapper.fullname}: shape {value.shape}, expected {shape}" datums.setdefault(axis.label, []).append(value) # if joining correlations, stick all elements together. 
Otherwise, we'd better have one per label if join_corrs: datums = OrderedDict({label: da.concatenate(arrs) for label, arrs in datums.items()}) else: assert all([len(arrs) == 1 for arrs in datums.values()]) datums = OrderedDict({label: arrs[0] for label, arrs in datums.items()}) # broadcast to same shape, and unravel all datums datums = OrderedDict({ key: arr.ravel() for key, arr in zip(datums.keys(), da.broadcast_arrays(*datums.values()))}) # if any axis needs to be conjugated, double up all of them if not noconj and any([axis.conjugate for axis in DataAxis.all_axes.values()]): for axis in DataAxis.all_axes.values(): if axis.conjugate: datums[axis.label] = da.concatenate([datums[axis.label], -datums[axis.label]]) else: datums[axis.label] = da.concatenate([datums[axis.label], datums[axis.label]]) labels, values = list(datums.keys()), list(datums.values()) np += values[0].size # now stack them all into a big dataframe rectype = [(axis.label, numpy.int32 if axis.nlevels else numpy.float32) for axis in DataAxis.all_axes.values()] recarr = da.empty_like(values[0], dtype=rectype) ddf = dask_df.from_array(recarr) for label, value in zip(labels, values): ddf[label] = value # now, are we iterating or concatenating? Make frame key accordingly dataframe_key = (fld if iter_field else None, ddid if iter_spw else None, scan if iter_scan else None, antenna) # do we already have a frame for this key ddf0 = output_dataframes.get(dataframe_key) if ddf0 is None: log.debug(f"first frame for {dataframe_key}") output_dataframes[dataframe_key] = ddf else: log.debug(f"appending to frame for {dataframe_key}") output_dataframes[dataframe_key] = ddf0.append(ddf) # convert discrete axes into categoricals if data_mappers.USE_COUNT_CAT: categorical_axes = [axis.label for axis in DataAxis.all_axes.values() if axis.nlevels] if categorical_axes: log.info(": counting colours") for key, ddf in list(output_dataframes.items()): output_dataframes[key] = ddf.categorize(categorical_axes) log.info(": complete") return output_dataframes, np
def extract_shape(cube, shapefile, method='contains', crop=True, decomposed=False): """ Extract a region defined by a shapefile. Note that this function does not work for shapes crossing the prime meridian or poles. Parameters ---------- cube: iris.cube.Cube input cube. shapefile: str A shapefile defining the region(s) to extract. method: str, optional Select all points contained by the shape or select a single representative point. Choose either 'contains' or 'representative'. If 'contains' is used, but not a single grid point is contained by the shape, a representative point will selected. crop: bool, optional Crop the resulting cube using `extract_region()`. Note that data on irregular grids will not be cropped. decomposed: bool, optional Whether or not to retain the sub shapes of the shapefile in the output. If this is set to True, the output cube has a dimension for the sub shapes. Returns ------- iris.cube.Cube Cube containing the extracted region. See Also -------- extract_region : Extract a region from a cube. """ with fiona.open(shapefile) as geometries: # get parameters specific to the shapefile (NE used case # eg longitudes [-180, 180] or latitude missing # or overflowing edges) cmor_coords = True pad_north_pole = False pad_hawaii = False if geometries.bounds[0] < 0: cmor_coords = False if geometries.bounds[1] > -90. and geometries.bounds[1] < -85.: pad_north_pole = True if geometries.bounds[0] > -180. and geometries.bounds[0] < 179.: pad_hawaii = True if crop: cube = _crop_cube(cube, *geometries.bounds, cmor_coords=cmor_coords) lon, lat = _correct_coords_from_shapefile(cube, cmor_coords, pad_north_pole, pad_hawaii) selections = _get_masks_from_geometries(geometries, lon, lat, method=method, decomposed=decomposed) cubelist = iris.cube.CubeList() for id_, select in selections.items(): _cube = cube.copy() _cube.add_aux_coord( iris.coords.AuxCoord(id_, units='no_unit', long_name="shape_id")) select = da.broadcast_to(select, _cube.shape) _cube.data = da.ma.masked_where(~select, _cube.core_data()) cubelist.append(_cube) cube = cubelist.merge_cube() return fix_coordinate_ordering(cube)
def extract_region(cube, start_longitude, end_longitude, start_latitude,
                   end_latitude):
    """Extract a region from a cube.

    Function that subsets a cube on a box (start_longitude, end_longitude,
    start_latitude, end_latitude). This function is a restriction of
    masked_cube_lonlat().

    Parameters
    ----------
    cube: iris.cube.Cube
        input data cube.
    start_longitude: float
        Western boundary longitude.
    end_longitude: float
        Eastern boundary longitude.
    start_latitude: float
        Southern Boundary latitude.
    end_latitude: float
        Northern Boundary Latitude.

    Returns
    -------
    iris.cube.Cube
        smaller cube.
    """
    if abs(start_latitude) > 90.:
        raise ValueError(f"Invalid start_latitude: {start_latitude}")
    if abs(end_latitude) > 90.:
        raise ValueError(f"Invalid end_latitude: {end_latitude}")
    if cube.coord('latitude').ndim == 1:
        # Iris checks if any point of the cell is inside the region.
        # To check only the center, ignore_bounds must be set to True
        # (the default is False).
        region_subset = cube.intersection(
            longitude=(start_longitude, end_longitude),
            latitude=(start_latitude, end_latitude),
            ignore_bounds=True,
        )
        region_subset = region_subset.intersection(longitude=(0., 360.))
        return region_subset
    # Irregular grids
    lats = cube.coord('latitude').points
    lons = cube.coord('longitude').points
    # Convert longitudes to valid range
    if start_longitude != 360.:
        start_longitude %= 360.
    if end_longitude != 360.:
        end_longitude %= 360.
    if start_longitude <= end_longitude:
        select_lons = (lons >= start_longitude) & (lons <= end_longitude)
    else:
        select_lons = (lons >= start_longitude) | (lons <= end_longitude)
    if start_latitude <= end_latitude:
        select_lats = (lats >= start_latitude) & (lats <= end_latitude)
    else:
        select_lats = (lats >= start_latitude) | (lats <= end_latitude)
    selection = select_lats & select_lons
    selection = da.broadcast_to(selection, cube.shape)
    cube.data = da.ma.masked_where(~selection, cube.core_data())
    return cube
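# Longitude selection across the dateline, as handled in the irregular-grid
# branch above: when start > end the box wraps around 0/360 degrees.
import numpy as np

lons = np.array([0., 90., 180., 270., 350.])
start_longitude, end_longitude = 300., 30.     # wraps across 0/360
select_lons = (lons >= start_longitude) | (lons <= end_longitude)
# -> [ True, False, False, False,  True]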
def broadcast_signals(*args, ignore_axis=None): """Broadcasts all passed signals according to the HyperSpy broadcasting rules: signal and navigation spaces are each separately broadcasted according to the numpy broadcasting rules. One axis can be ignored and left untouched (or set to be size 1) across all signals. Parameters ---------- *args : BaseSignal Signals to broadcast together ignore_axis : {None, str, int, Axis} The axis to be ignored when broadcasting Returns ------- list of signals """ if len(args) < 2: raise ValueError( "This function requires at least two signal instances") args = list(args) if not are_signals_aligned(*args, ignore_axis=ignore_axis): raise ValueError("The signals cannot be broadcasted") else: if ignore_axis is not None: for s in args: try: ignore_axis = s.axes_manager[ignore_axis] break except ValueError: pass new_nav_axes = [] new_nav_shapes = [] for axes in zip_longest(*[s.axes_manager.navigation_axes for s in args], fillvalue=None): only_left = filter(lambda x: x is not None, axes) longest = sorted(only_left, key=lambda x: x.size, reverse=True)[0] new_nav_axes.append(longest) new_nav_shapes.append(longest.size if (ignore_axis is None or ignore_axis not in axes) else None) new_sig_axes = [] new_sig_shapes = [] for axes in zip_longest(*[s.axes_manager.signal_axes for s in args], fillvalue=None): only_left = filter(lambda x: x is not None, axes) longest = sorted(only_left, key=lambda x: x.size, reverse=True)[0] new_sig_axes.append(longest) new_sig_shapes.append(longest.size if (ignore_axis is None or ignore_axis not in axes) else None) results = [] new_axes = new_nav_axes[::-1] + new_sig_axes[::-1] new_data_shape = new_nav_shapes[::-1] + new_sig_shapes[::-1] for s in args: data = s._data_aligned_with_axes sam = s.axes_manager sdim_diff = len(new_sig_axes) - sam.signal_dimension while sdim_diff > 0: slices = (slice(None),) * sam.navigation_dimension slices += (None, Ellipsis) data = data[slices] sdim_diff -= 1 thisshape = new_data_shape.copy() if ignore_axis is not None: _id = new_data_shape.index(None) newlen = data.shape[_id] if len(data.shape) > _id else 1 thisshape[_id] = newlen thisshape = tuple(thisshape) if data.shape != thisshape: if isinstance(data, np.ndarray): data = np.broadcast_to(data, thisshape) else: data = da.broadcast_to(data, thisshape) ns = s._deepcopy_with_new_data(data) ns.axes_manager._axes = [ax.copy() for ax in new_axes] ns.get_dimensions_from_data() results.append(ns.transpose(signal_axes=len(new_sig_axes))) return results
def make_psf(self): print("Making PSF") psfs = [] for ims in self.ms: xds = xds_from_ms(ims, group_cols=('FIELD_ID', 'DATA_DESC_ID'), chunks=self.chunks[ims], columns=self.columns) # subtables ddids = xds_from_table(ims + "::DATA_DESCRIPTION") fields = xds_from_table(ims + "::FIELD", group_cols="__row__") spws = xds_from_table(ims + "::SPECTRAL_WINDOW", group_cols="__row__") pols = xds_from_table(ims + "::POLARIZATION", group_cols="__row__") # subtable data ddids = dask.compute(ddids)[0] fields = dask.compute(fields)[0] spws = dask.compute(spws)[0] pols = dask.compute(pols)[0] for ds in xds: field = fields[ds.FIELD_ID] radec = field.PHASE_DIR.data.squeeze() if not np.array_equal(radec, self.radec): continue spw = ds.DATA_DESC_ID # this is not correct, need to use spw freq_bin_idx = self.freq_bin_idx[ims][spw] freq_bin_counts = self.freq_bin_counts[ims][spw] freq = self.freq[ims][spw] uvw = ds.UVW.data flag = getattr(ds, self.flag_column).data weights = getattr(ds, self.weight_column).data if len(weights.shape) < 3: weights = da.broadcast_to(weights[:, None, :], flag.shape, chunks=flag.chunks) if self.imaging_weight_column is not None: imaging_weights = getattr(ds, self.imaging_weight_column).data if len(imaging_weights.shape) < 3: imaging_weights = da.broadcast_to(imaging_weights[:, None, :], data.shape, chunks=data.chunks) weightsxx = imaging_weights[:, :, 0] * weights[:, :, 0] weightsyy = imaging_weights[:, :, -1] * weights[:, :, -1] else: weightsxx = weights[:, :, 0] weightsyy = weights[:, :, -1] # weighted sum corr to Stokes I weights = weightsxx + weightsyy data = weights.astype(np.complex64) # only keep data where both corrs are unflagged flagxx = flag[:, :, 0] flagyy = flag[:, :, -1] flag = ~ (flagxx | flagyy) # ducc0 convention psf = vis2im(uvw, freq, data, freq_bin_idx, freq_bin_counts, 2*self.nx, 2*self.ny, self.cell, flag=flag.astype(np.uint8), nthreads=self.nthreads, epsilon=self.epsilon, do_wstacking=self.do_wstacking) psfs.append(psf) psfs = dask.compute(psfs)[0] return accumulate_dirty(psfs, self.nband, self.band_mapping).astype(np.float64)
def __get_seasonal_means_with_ttest_stats_dask_lazy( self, data, season_to_monthperiod=None, start_year=-np.Inf, end_year=np.Inf, convert_monthly_accumulators_to_daily=False): # mask the resulting fields epsilon = 1.0e-5 mask = np.less_equal(np.abs(data[0, :, :] - self.missing_value), epsilon) print("data.shape = ", data.shape) data_sel, times_sel = data, self.time # select the interval of interest if convert_monthly_accumulators_to_daily: ndays = da.from_array( np.array([ calendar.monthrange(d.year, d.month)[1] for d in times_sel ]), (100, )) ndays = da.transpose(da.broadcast_to( da.from_array(ndays, ndays.shape), data_sel.shape[1:] + ndays.shape), axes=(2, 0, 1)) data_sel = data_sel / ndays year_month_to_index_arr = defaultdict(list) for i, t in enumerate(times_sel): year_month_to_index_arr[t.year, t.month].append(i) # calculate monthly means monthly_data = {} for y in range(start_year, end_year + 1): for m in range(1, 13): aslice = slice(year_month_to_index_arr[y, m][0], year_month_to_index_arr[y, m][-1] + 1) print(aslice, data_sel.shape) monthly_data[y, m] = data_sel[aslice, :, :].mean(axis=0) result = OrderedDict() for season, month_period in season_to_monthperiod.items(): assert isinstance(month_period, MonthPeriod) seasonal_means = [] ndays_per_season = [] for p in month_period.get_season_periods(start_year=start_year, end_year=end_year): lmos = da.stack([ monthly_data[start.year, start.month] for start in p.range("months") ]) ndays_per_month = np.array([ calendar.monthrange(start.year, start.month)[1] for start in p.range("months") ]) ndays_per_month = da.from_array(ndays_per_month, ndays_per_month.shape) print(p) print(lmos.shape, ndays_per_month.shape, ndays_per_month.sum()) seasonal_mean = da.tensordot( lmos, ndays_per_month, axes=([ 0, ], [ 0, ])) / ndays_per_month.sum() seasonal_means.append(seasonal_mean) ndays_per_season.append(ndays_per_month.sum()) seasonal_means = da.stack(seasonal_means) ndays_per_season = np.array(ndays_per_season) ndays_per_season = da.from_array(ndays_per_season, ndays_per_season.shape) print(seasonal_means.shape, ndays_per_season.shape) assert seasonal_means.shape[0] == ndays_per_season.shape[0] clim_mean = da.tensordot( seasonal_means, ndays_per_season, axes=([ 0, ], [ 0, ])) / ndays_per_season.sum() clim_std = ((seasonal_means - da.broadcast_to(clim_mean, seasonal_means.shape))**2 * ndays_per_season[:, np.newaxis, np.newaxis]).sum( axis=0) / ndays_per_season.sum() clim_std = clim_std**0.5 result[season] = [clim_mean, clim_std, ndays_per_season.shape[0]] return result, mask
def make_dirty(self): print("Making dirty", file=log) dirty = da.zeros((self.nband, self.nx, self.ny), dtype=np.float32, chunks=(1, self.nx, self.ny), name=False) dirties = [] for ims in self.ms: xds = xds_from_ms(ims, group_cols=('FIELD_ID', 'DATA_DESC_ID'), chunks=self.chunks[ims], columns=self.columns) # subtables ddids = xds_from_table(ims + "::DATA_DESCRIPTION") fields = xds_from_table(ims + "::FIELD", group_cols="__row__") spws = xds_from_table(ims + "::SPECTRAL_WINDOW", group_cols="__row__") pols = xds_from_table(ims + "::POLARIZATION", group_cols="__row__") # subtable data ddids = dask.compute(ddids)[0] fields = dask.compute(fields)[0] spws = dask.compute(spws)[0] pols = dask.compute(pols)[0] for ds in xds: field = fields[ds.FIELD_ID] radec = field.PHASE_DIR.data.squeeze() if not np.array_equal(radec, self.radec): continue spw = ds.DATA_DESC_ID # this is not correct, need to use spw freq_bin_idx = self.freq_bin_idx[ims][spw] freq_bin_counts = self.freq_bin_counts[ims][spw] freq = self.freq[ims][spw] freq_chunk = freq_bin_counts[0].compute() uvw = ds.UVW.data data = getattr(ds, self.data_column).data dataxx = data[:, :, 0] datayy = data[:, :, -1] weights = getattr(ds, self.weight_column).data if len(weights.shape) < 3: weights = da.broadcast_to(weights[:, None, :], data.shape, chunks=data.chunks) if self.imaging_weight_column is not None: imaging_weights = getattr(ds, self.imaging_weight_column).data if len(imaging_weights.shape) < 3: imaging_weights = da.broadcast_to( imaging_weights[:, None, :], data.shape, chunks=data.chunks) weightsxx = imaging_weights[:, :, 0] * weights[:, :, 0] weightsyy = imaging_weights[:, :, -1] * weights[:, :, -1] else: weightsxx = weights[:, :, 0] weightsyy = weights[:, :, -1] # apply adjoint of mueller term. # Phases modify data amplitudes modify weights. if self.mueller_column is not None: mueller = getattr(ds, self.mueller_column).data dataxx *= da.exp(-1j * da.angle(mueller[:, :, 0])) datayy *= da.exp(-1j * da.angle(mueller[:, :, -1])) weightsxx *= da.absolute(mueller[:, :, 0]) weightsyy *= da.absolute(mueller[:, :, -1]) # weighted sum corr to Stokes I weights = weightsxx + weightsyy data = (weightsxx * dataxx + weightsyy * datayy) # TODO - turn off this stupid warning data = da.where(weights, data / weights, 0.0j) # only keep data where both corrs are unflagged flag = getattr(ds, self.flag_column).data flagxx = flag[:, :, 0] flagyy = flag[:, :, -1] # ducc0 convention uses uint8 mask not flag flag = ~(flagxx | flagyy) dirty = vis2im(uvw, freq, data, freq_bin_idx, freq_bin_counts, self.nx, self.ny, self.cell, weights=weights, flag=flag.astype(np.uint8), nthreads=self.nthreads, epsilon=self.epsilon, do_wstacking=self.do_wstacking, double_accum=True) dirties.append(dirty) dirties = dask.compute(dirties, scheduler='single-threaded')[0] return accumulate_dirty(dirties, self.nband, self.band_mapping).astype(self.real_type)
def compute_weights(self, robust): from pfb.utils.weighting import compute_counts, counts_to_weights # compute counts counts = [] for ims in self.ms: xds = xds_from_ms(ims, group_cols=('FIELD_ID', 'DATA_DESC_ID'), chunks=self.chunks[ims], columns=('UVW')) # subtables ddids = xds_from_table(ims + "::DATA_DESCRIPTION") fields = xds_from_table(ims + "::FIELD", group_cols="__row__") spws = xds_from_table(ims + "::SPECTRAL_WINDOW", group_cols="__row__") pols = xds_from_table(ims + "::POLARIZATION", group_cols="__row__") # subtable data ddids = dask.compute(ddids)[0] fields = dask.compute(fields)[0] spws = dask.compute(spws)[0] pols = dask.compute(pols)[0] for ds in xds: field = fields[ds.FIELD_ID] radec = field.PHASE_DIR.data.squeeze() if not np.array_equal(radec, self.radec): continue spw = ds.DATA_DESC_ID # not optimal, need to use spw freq_bin_idx = self.freq_bin_idx[ims][spw] freq_bin_counts = self.freq_bin_counts[ims][spw] freq = self.freq[ims][spw] uvw = ds.UVW.data count = compute_counts(uvw, freq, freq_bin_idx, freq_bin_counts, self.nx, self.ny, self.cell, self.cell, np.float32) counts.append(count) counts = dask.compute(counts)[0] counts = accumulate_dirty(counts, self.nband, self.band_mapping) counts = da.from_array(counts, chunks=(1, -1, -1)) # convert counts to weights writes = [] for ims in self.ms: xds = xds_from_ms(ims, group_cols=('FIELD_ID', 'DATA_DESC_ID'), chunks=self.chunks[ims], columns=self.columns) # subtables ddids = xds_from_table(ims + "::DATA_DESCRIPTION") fields = xds_from_table(ims + "::FIELD", group_cols="__row__") spws = xds_from_table(ims + "::SPECTRAL_WINDOW", group_cols="__row__") pols = xds_from_table(ims + "::POLARIZATION", group_cols="__row__") # subtable data ddids = dask.compute(ddids)[0] fields = dask.compute(fields)[0] spws = dask.compute(spws)[0] pols = dask.compute(pols)[0] out_data = [] for ds in xds: field = fields[ds.FIELD_ID] radec = field.PHASE_DIR.data.squeeze() if not np.array_equal(radec, self.radec): continue spw = ds.DATA_DESC_ID # this is not correct, need to use spw freq_bin_idx = self.freq_bin_idx[ims][spw] freq_bin_counts = self.freq_bin_counts[ims][spw] freq = self.freq[ims][spw] uvw = ds.UVW.data weights = counts_to_weights(counts, uvw, freq, freq_bin_idx, freq_bin_counts, self.nx, self.ny, self.cell, self.cell, np.float32, robust) # hack to get shape and chunking info data = getattr(ds, self.data_column).data weights = da.broadcast_to(weights[:, :, None], data.shape, chunks=data.chunks) out_ds = ds.assign(**{ self.imaging_weight_column: (("row", "chan", "corr"), weights) }) out_data.append(out_ds) writes.append( xds_to_table(out_data, ims, columns=[self.imaging_weight_column])) dask.compute(writes)
def main(): args = parse_args() dask.config.set(num_workers=args.workers) # Lightweight open with no data - just to create telstate and identify the CBID ds = TelstateDataSource.from_url(args.source, upgrade_flags=False, chunk_store=None) # View the CBID, but not any specific stream cbid = ds.capture_block_id telstate = ds.telstate.root().view(cbid) streams = get_streams(telstate, args.streams) # Find all arrays in the selected streams, and also ensure we're not # trying to write things back on top of an existing dataset. arrays = {} for stream_name in streams: sts = view_capture_stream(telstate, cbid, stream_name) try: chunk_info = sts['chunk_info'] except KeyError as exc: raise RuntimeError('Could not get chunk info for {!r}: {}'.format( stream_name, exc)) for array_name, array_info in chunk_info.items(): if args.new_prefix is not None: array_info[ 'prefix'] = args.new_prefix + '-' + stream_name.replace( '_', '-') prefix = array_info['prefix'] path = os.path.join(args.dest, prefix) if os.path.exists(path): raise RuntimeError( 'Directory {!r} already exists'.format(path)) store = get_chunk_store(args.source, sts, array_name) # Older files have dtype as an object that can't be encoded in msgpack dtype = np.dtype(array_info['dtype']) array_info['dtype'] = np.lib.format.dtype_to_descr(dtype) arrays[(stream_name, array_name)] = Array(stream_name, array_name, store, array_info) # Apply DATA_LOST bits to the flags arrays. This is a less efficient approach than # datasources.py, but much simpler. for stream_name in streams: flags_array = arrays.get((stream_name, 'flags')) if not flags_array: continue sources = [stream_name] sts = view_capture_stream(telstate, cbid, stream_name) sources += sts['src_streams'] for src_stream in sources: if src_stream not in streams: continue src_ts = view_capture_stream(telstate, cbid, src_stream) for array_name in src_ts['chunk_info']: if array_name == 'flags' and src_stream != stream_name: # Upgraded flags completely replace the source stream's # flags, rather than augmenting them. Thus, data lost in # the source stream has no effect. 
continue lost_flags = arrays[(src_stream, array_name)].lost_flags lost_flags = lost_flags.rechunk( flags_array.data.chunks[:lost_flags.ndim]) # weights_channel doesn't have a baseline axis while lost_flags.ndim < flags_array.data.ndim: lost_flags = lost_flags[..., np.newaxis] lost_flags = da.broadcast_to(lost_flags, flags_array.data.shape, chunks=flags_array.data.chunks) flags_array.data |= lost_flags # Apply the rechunking specs for spec in args.spec: key = (spec.stream, spec.array) if key not in arrays: raise RuntimeError('{}/{} is not a known array'.format( spec.stream, spec.array)) arrays[key].data = arrays[key].data.rechunk({ 0: spec.time, 1: spec.freq }) # Write out the new data dest_store = NpyFileChunkStore(args.dest) stores = [] for array in arrays.values(): full_name = dest_store.join(array.chunk_info['prefix'], array.array_name) dest_store.create_array(full_name) stores.append(dest_store.put_dask_array(full_name, array.data)) array.chunk_info['chunks'] = array.data.chunks stores = da.compute(*stores) # put_dask_array returns an array with an exception object per chunk for result_set in stores: for result in result_set.flat: if result is not None: raise result # Fix up chunk_info for new chunking for stream_name in streams: sts = view_capture_stream(telstate, cbid, stream_name) chunk_info = sts['chunk_info'] for array_name in chunk_info.keys(): chunk_info[array_name] = arrays[(stream_name, array_name)].chunk_info sts.wrapped.delete('chunk_info') sts.wrapped['chunk_info'] = chunk_info # s3_endpoint_url is for the old version of the data sts.wrapped.delete('s3_endpoint_url') if args.s3_endpoint_url is not None: sts.wrapped['s3_endpoint_url'] = args.s3_endpoint_url # Write updated RDB file url_parts = urllib.parse.urlparse(args.source, scheme='file') dest_file = os.path.join(args.dest, args.new_prefix or cbid, os.path.basename(url_parts.path)) os.makedirs(os.path.dirname(dest_file), exist_ok=True) with RDBWriter(dest_file) as writer: writer.save(telstate.backend)
def calculate_corners(center_lat, center_lon):
    """Calculate corner coordinates by averaging neighbor cells."""
    # get rank
    rank = len(center_lat.dims)

    if rank == 1:
        # get dimensions
        nlon = center_lon.size
        nlat = center_lat.size
        # convert center points from 1d to 2d
        center_lat2d = da.broadcast_to(center_lat.values[None, :],
                                       (nlon, nlat))
        center_lon2d = da.broadcast_to(center_lon.values[:, None],
                                       (nlon, nlat))
    elif rank == 2:
        # get dimensions
        dims = center_lon.shape
        nlon = dims[0]
        nlat = dims[1]
        # just rename and convert to dask array
        center_lat2d = da.from_array(center_lat)
        center_lon2d = da.from_array(center_lon)
    else:
        print('Unrecognized grid! The rank of coordinate variables can be '
              '1 or 2 but it is {}.'.format(rank))
        sys.exit(2)

    # calculate corner coordinates for latitude, counterclockwise order,
    # imposing Fortran ordering
    center_lat2d_ext = da.from_array(
        np.pad(center_lat2d.compute(), (1, 1),
               mode='reflect',
               reflect_type='odd'))
    ur = (center_lat2d_ext[1:-1, 1:-1] + center_lat2d_ext[0:-2, 1:-1] +
          center_lat2d_ext[1:-1, 2:] + center_lat2d_ext[0:-2, 2:]) / 4.0
    ul = (center_lat2d_ext[1:-1, 1:-1] + center_lat2d_ext[0:-2, 1:-1] +
          center_lat2d_ext[1:-1, 0:-2] + center_lat2d_ext[0:-2, 0:-2]) / 4.0
    ll = (center_lat2d_ext[1:-1, 1:-1] + center_lat2d_ext[1:-1, 0:-2] +
          center_lat2d_ext[2:, 1:-1] + center_lat2d_ext[2:, 0:-2]) / 4.0
    lr = (center_lat2d_ext[1:-1, 1:-1] + center_lat2d_ext[1:-1, 2:] +
          center_lat2d_ext[2:, 1:-1] + center_lat2d_ext[2:, 2:]) / 4.0

    # this looks like clockwise ordering but it is transposed and becomes
    # counterclockwise, bit-to-bit with NCL
    corner_lat = da.stack([
        ul.T.reshape((-1, )).T,
        ll.T.reshape((-1, )).T,
        lr.T.reshape((-1, )).T,
        ur.T.reshape((-1, )).T
    ], axis=1)

    # calculate corner coordinates for longitude, counterclockwise order,
    # imposing Fortran ordering
    center_lon2d_ext = da.from_array(
        np.pad(center_lon2d.compute(), (1, 1),
               mode='reflect',
               reflect_type='odd'))
    ur = (center_lon2d_ext[1:-1, 1:-1] + center_lon2d_ext[0:-2, 1:-1] +
          center_lon2d_ext[1:-1, 2:] + center_lon2d_ext[0:-2, 2:]) / 4.0
    ul = (center_lon2d_ext[1:-1, 1:-1] + center_lon2d_ext[0:-2, 1:-1] +
          center_lon2d_ext[1:-1, 0:-2] + center_lon2d_ext[0:-2, 0:-2]) / 4.0
    ll = (center_lon2d_ext[1:-1, 1:-1] + center_lon2d_ext[1:-1, 0:-2] +
          center_lon2d_ext[2:, 1:-1] + center_lon2d_ext[2:, 0:-2]) / 4.0
    lr = (center_lon2d_ext[1:-1, 1:-1] + center_lon2d_ext[1:-1, 2:] +
          center_lon2d_ext[2:, 1:-1] + center_lon2d_ext[2:, 2:]) / 4.0

    # this looks like clockwise ordering but it is transposed and becomes
    # counterclockwise, bit-to-bit with NCL
    corner_lon = da.stack([
        ul.T.reshape((-1, )).T,
        ll.T.reshape((-1, )).T,
        lr.T.reshape((-1, )).T,
        ur.T.reshape((-1, )).T
    ], axis=1)

    return center_lat2d, center_lon2d, corner_lat, corner_lon
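# The rank-1 branch above in isolation: 1-D centre latitudes/longitudes are
# expanded to a common (nlon, nlat) grid by broadcasting along the new axis.
import dask.array as da
import numpy as np

center_lat = np.array([-45., 0., 45.])          # nlat = 3
center_lon = np.array([0., 90., 180., 270.])    # nlon = 4
lat2d = da.broadcast_to(center_lat[None, :], (4, 3))   # varies along axis 1
lon2d = da.broadcast_to(center_lon[:, None], (4, 3))   # varies along axis 0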
def make_residual(self, x): # Note deprecated (does not support Jones terms) print("Making residual", file=log) x = da.from_array(x.astype(self.real_type), chunks=(1, self.nx, self.ny), name=False) residuals = [] for ims in self.ms: xds = xds_from_ms(ims, group_cols=('FIELD_ID', 'DATA_DESC_ID'), chunks=self.chunks[ims], columns=self.columns) # subtables ddids = xds_from_table(ims + "::DATA_DESCRIPTION") fields = xds_from_table(ims + "::FIELD", group_cols="__row__") spws = xds_from_table(ims + "::SPECTRAL_WINDOW", group_cols="__row__") pols = xds_from_table(ims + "::POLARIZATION", group_cols="__row__") # subtable data ddids = dask.compute(ddids)[0] fields = dask.compute(fields)[0] spws = dask.compute(spws)[0] pols = dask.compute(pols)[0] for ds in xds: field = fields[ds.FIELD_ID] radec = field.PHASE_DIR.data.squeeze() if not np.array_equal(radec, self.radec): continue spw = ds.DATA_DESC_ID # this is not correct, need to use spw freq_bin_idx = self.freq_bin_idx[ims][spw] freq_bin_counts = self.freq_bin_counts[ims][spw] freq = self.freq[ims][spw] uvw = ds.UVW.data data = getattr(ds, self.data_column).data dataxx = data[:, :, 0] datayy = data[:, :, -1] weights = getattr(ds, self.weight_column).data if len(weights.shape) < 3: weights = da.broadcast_to(weights[:, None, :], data.shape, chunks=data.chunks) if self.imaging_weight_column is not None: imaging_weights = getattr(ds, self.imaging_weight_column).data if len(imaging_weights.shape) < 3: imaging_weights = da.broadcast_to( imaging_weights[:, None, :], data.shape, chunks=data.chunks) weightsxx = imaging_weights[:, :, 0] * weights[:, :, 0] weightsyy = imaging_weights[:, :, -1] * weights[:, :, -1] else: weightsxx = weights[:, :, 0] weightsyy = weights[:, :, -1] # weighted sum corr to Stokes I weights = weightsxx + weightsyy data = (weightsxx * dataxx + weightsyy * datayy) data = da.where(weights, data / weights, 0.0j) # only keep data where both corrs are unflagged flag = getattr(ds, self.flag_column).data flagxx = flag[:, :, 0] flagyy = flag[:, :, -1] flag = ~(flagxx | flagyy) # ducc0 convention bands = self.band_mapping[ims][spw] model = x[list(bands), :, :] residual = im2residim(uvw, freq, model, data, freq_bin_idx, freq_bin_counts, self.cell, weights=weights, flag=flag.astype(np.uint8), nthreads=self.nthreads, epsilon=self.epsilon, do_wstacking=self.do_wstacking, double_accum=True) residuals.append(residual) residuals = dask.compute(residuals)[0] return accumulate_dirty(residuals, self.nband, self.band_mapping).astype(self.real_type)
dz["st_ocean"] = depth["st_ocean"] for var in line_tracer_vars: ds_out[f"{var}_segment_001"] = ds_out[f"{var}_segment_001"].expand_dims( "ny_segment_001", axis=1) for var in surface_vars: # add the y dimension ds_out[f"{var}_segment_001"] = ds_out[f"{var}_segment_001"].expand_dims( "ny_segment_001", axis=2) ds_out[f"vc_{var}_segment_001"] = ( ["time", f"nz_segment_001_{var}", "ny_segment_001", "nx_segment_001"], da.broadcast_to( depth.data[None, :, None, None], ds_out[f"{var}_segment_001"].shape, chunks=(1, None, None, None), ), ) ds_out[f"dz_{var}_segment_001"] = ( ["time", f"nz_segment_001_{var}", "ny_segment_001", "nx_segment_001"], da.broadcast_to( dz.data[None, :, None, None], ds_out[f"{var}_segment_001"].shape, chunks=(1, None, None, None), ), ) with ProgressBar(): ds_out.to_netcdf("forcing_obc.nc")
def make_psf(self): print("Making PSF", file=log) psfs = [] self.stokes_weights = {} self.uvws = {} for ims in self.ms: xds = xds_from_ms(ims, group_cols=('FIELD_ID', 'DATA_DESC_ID'), chunks=self.chunks[ims], columns=self.columns) # subtables ddids = xds_from_table(ims + "::DATA_DESCRIPTION") fields = xds_from_table(ims + "::FIELD", group_cols="__row__") spws = xds_from_table(ims + "::SPECTRAL_WINDOW", group_cols="__row__") pols = xds_from_table(ims + "::POLARIZATION", group_cols="__row__") # subtable data ddids = dask.compute(ddids)[0] fields = dask.compute(fields)[0] spws = dask.compute(spws)[0] pols = dask.compute(pols)[0] self.stokes_weights[ims] = {} self.uvws[ims] = {} for ds in xds: field = fields[ds.FIELD_ID] radec = field.PHASE_DIR.data.squeeze() if not np.array_equal(radec, self.radec): continue # this is not correct, need to use spw spw = ds.DATA_DESC_ID freq_bin_idx = self.freq_bin_idx[ims][spw] freq_bin_counts = self.freq_bin_counts[ims][spw] freq = self.freq[ims][spw] uvw = ds.UVW.data flag = getattr(ds, self.flag_column).data weights = getattr(ds, self.weight_column).data if len(weights.shape) < 3: weights = da.broadcast_to(weights[:, None, :], flag.shape, chunks=flag.chunks) if self.imaging_weight_column is not None: imaging_weights = getattr(ds, self.imaging_weight_column).data if len(imaging_weights.shape) < 3: imaging_weights = da.broadcast_to( imaging_weights[:, None, :], flag.shape, chunks=flag.chunks) weightsxx = imaging_weights[:, :, 0] * weights[:, :, 0] weightsyy = imaging_weights[:, :, -1] * weights[:, :, -1] else: weightsxx = weights[:, :, 0] weightsyy = weights[:, :, -1] # for the PSF we need to scale the weights by the # Mueller amplitudes squared if self.mueller_column is not None: mueller = getattr(ds, self.mueller_column).data weightsxx *= da.absolute(mueller[:, :, 0])**2 weightsyy *= da.absolute(mueller[:, :, -1])**2 # weighted sum corr to Stokes I weights = weightsxx + weightsyy # only keep data where both corrs are unflagged flagxx = flag[:, :, 0] flagyy = flag[:, :, -1] flag = ~(flagxx | flagyy) # ducc0 convention weights *= flag data = weights.astype(np.complex64) psf = vis2im(uvw, freq, data, freq_bin_idx, freq_bin_counts, self.nx_psf, self.ny_psf, self.cell, flag=flag.astype(np.uint8), nthreads=self.nthreads, epsilon=self.epsilon, do_wstacking=self.do_wstacking, double_accum=True) psfs.append(psf) # assumes that stokes weights and uvw fit into memory # self.stokes_weights[ims][spw] = dask.persist(weights.rechunk({0:-1}))[0] # self.uvws[ims][spw] = dask.persist(uvw.rechunk({0:-1}))[0] # for comparison with numpy implementation # self.stokes_weights[ims][spw] = dask.compute(weights)[0] # self.uvws[ims][spw] = dask.compute(uvw)[0] # import pdb # pdb.set_trace() psfs = dask.compute(psfs, scheduler='single-threaded')[0] return accumulate_dirty(psfs, self.nband, self.band_mapping).astype(self.real_type)
def main(args):
    """
    Flags outliers in the data given a model and rescales the weights so
    that the whitened residuals have a mean amplitude of sqrt(2).

    Flags and weights are computed per chunk of data.
    """
    radec_ref = None
    writes = []
    for ims in args.ms:
        xds = xds_from_ms(ims,
                          group_cols=('FIELD_ID', 'DATA_DESC_ID'),
                          chunks={"row": args.row_chunks,
                                  "chan": args.chan_chunks},
                          columns=('UVW', args.data_column,
                                   args.weight_column, args.model_column,
                                   args.flag_column, 'FLAG_ROW'))

        # subtables
        ddids = xds_from_table(ims + "::DATA_DESCRIPTION")
        fields = xds_from_table(ims + "::FIELD", group_cols="__row__")
        spws = xds_from_table(ims + "::SPECTRAL_WINDOW", group_cols="__row__")
        pols = xds_from_table(ims + "::POLARIZATION", group_cols="__row__")

        # subtable data
        ddids = dask.compute(ddids)[0]
        fields = dask.compute(fields)[0]
        spws = dask.compute(spws)[0]
        pols = dask.compute(pols)[0]

        out_data = []
        for ds in xds:
            field = fields[ds.FIELD_ID]
            radec = field.PHASE_DIR.data.squeeze()

            # check fields match
            if radec_ref is None:
                radec_ref = radec

            if not np.array_equal(radec, radec_ref):
                continue

            # load in data and compute whitened residuals
            data = getattr(ds, args.data_column).data
            model = getattr(ds, args.model_column).data
            flag = getattr(ds, args.flag_column).data
            flag = da.logical_or(flag, ds.FLAG_ROW.data[:, None, None])
            weights = getattr(ds, args.weight_column).data
            if len(weights.shape) < 3:
                weights = da.broadcast_to(weights[:, None, :],
                                          data.shape,
                                          chunks=data.chunks)

            if args.trim_channels:
                flag = trim_chans(flag, args.trim_channels)

            # Stokes I vis
            weights = (~flag) * weights
            resid_vis = (data - model) * weights
            wsums = (weights[:, :, 0] + weights[:, :, -1])
            resid_vis_I = da.where(
                wsums,
                (resid_vis[:, :, 0] + resid_vis[:, :, -1]) / wsums,
                0.0j)

            # whiten and take abs
            white_resid = resid_vis_I * da.sqrt(wsums)
            abs_resid_vis_I = (white_resid).__abs__()

            # mean amp
            sum_amp = da.sum(abs_resid_vis_I)
            count = da.sum(wsums > 0)
            mean_amp = sum_amp / count

            flag_legacy = flag[:, :, 0] | flag[:, :, -1]
            flag_I = da.logical_or(abs_resid_vis_I > args.sigma_cut * mean_amp,
                                   flag_legacy)

            # new flags
            updated_flag = da.broadcast_to(flag_I[:, :, None],
                                           flag.shape,
                                           chunks=flag.chunks)

            # scale weights (whitened residuals should have
            # mean amplitude of 1/sqrt(2))
            if args.scale_weights:
                # recompute mean amp with new flags
                weights = (~updated_flag) * weights
                resid_vis = (data - model) * weights
                wsums = (weights[:, :, 0] + weights[:, :, -1])
                resid_vis_I = da.where(
                    wsums,
                    (resid_vis[:, :, 0] + resid_vis[:, :, -1]) / wsums,
                    0.0j)
                white_resid = resid_vis_I * da.sqrt(wsums)
                abs_resid_vis_I = (white_resid).__abs__()
                sum_amp = da.sum(abs_resid_vis_I)
                count = da.sum(wsums > 0)
                mean_amp = sum_amp / count
                updated_weight = 2**0.5 * weights / mean_amp**2
            else:
                updated_weight = weights

            ds = ds.assign(**{
                args.weight_out_column: (("row", "chan", "corr"),
                                         updated_weight)
            })
            ds = ds.assign(**{
                args.flag_out_column: (("row", "chan", "corr"),
                                       updated_flag)
            })

            out_data.append(ds)

        writes.append(
            xds_to_table(out_data, ims,
                         columns=[args.flag_out_column,
                                  args.weight_out_column]))

    with ProgressBar():
        dask.compute(writes)

    # report new mean amp
    if args.report_means:
        radec_ref = None
        mean_amps = []
        for ims in args.ms:
            xds = xds_from_ms(ims,
                              group_cols=('FIELD_ID', 'DATA_DESC_ID'),
                              chunks={"row": args.row_chunks,
                                      "chan": args.chan_chunks},
                              columns=('UVW', args.data_column,
                                       args.weight_out_column,
                                       args.model_column,
                                       args.flag_out_column, 'FLAG_ROW'))

            # subtables
            ddids = xds_from_table(ims + "::DATA_DESCRIPTION")
            fields = xds_from_table(ims + "::FIELD", group_cols="__row__")
            spws = xds_from_table(ims + "::SPECTRAL_WINDOW",
                                  group_cols="__row__")
            pols = xds_from_table(ims + "::POLARIZATION",
                                  group_cols="__row__")

            # subtable data
            ddids = dask.compute(ddids)[0]
            fields = dask.compute(fields)[0]
            spws = dask.compute(spws)[0]
            pols = dask.compute(pols)[0]

            for ds in xds:
                field = fields[ds.FIELD_ID]
                radec = field.PHASE_DIR.data.squeeze()

                # check fields match
                if radec_ref is None:
                    radec_ref = radec

                if not np.array_equal(radec, radec_ref):
                    continue

                # load in data and compute whitened residuals
                data = getattr(ds, args.data_column).data
                model = getattr(ds, args.model_column).data
                flag = getattr(ds, args.flag_out_column).data
                flag = da.logical_or(flag, ds.FLAG_ROW.data[:, None, None])
                weights = getattr(ds, args.weight_out_column).data
                if len(weights.shape) < 3:
                    weights = da.broadcast_to(weights[:, None, :],
                                              data.shape,
                                              chunks=data.chunks)

                # Stokes I vis
                weights = (~flag) * weights
                resid_vis = (data - model) * weights
                wsums = (weights[:, :, 0] + weights[:, :, -1])
                resid_vis_I = da.where(
                    wsums,
                    (resid_vis[:, :, 0] + resid_vis[:, :, -1]) / wsums,
                    0.0j)

                # whiten and take abs
                white_resid = resid_vis_I * da.sqrt(wsums)
                abs_resid_vis_I = (white_resid).__abs__()

                # mean amp
                sum_amp = da.sum(abs_resid_vis_I)
                count = da.sum(wsums > 0)
                mean_amps.append(sum_amp / count)

        mean_amps = dask.compute(mean_amps)[0]
        print(mean_amps)
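A small illustrative sketch of the flag-update pattern above: a single Stokes-I flag is derived per (row, chan) sample and broadcast back over the correlation axis, as in the `updated_flag` step. All names, shapes and the threshold are assumptions made for the example.

# Hedged sketch: broadcast a per-sample Stokes-I flag over all correlations.
import dask.array as da

nrow, nchan, ncorr = 500, 32, 4
flag = da.zeros((nrow, nchan, ncorr), dtype=bool, chunks=(100, 32, 4))
resid_amp = da.random.random((nrow, nchan), chunks=(100, 32))

sigma_cut = 3.0                              # assumed threshold
mean_amp = resid_amp.mean()
flag_I = resid_amp > sigma_cut * mean_amp    # shape (row, chan)

# add a correlation axis and broadcast to the full flag shape
updated_flag = da.broadcast_to(flag_I[:, :, None], flag.shape,
                               chunks=flag.chunks)
print(updated_flag.shape)  # (500, 32, 4)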
def _main(args):
    tic = time.time()

    log.info(banner())

    if args.disable_post_mortem:
        log.warn("Disabling crash debugging with the "
                 "Interactive Python Debugger, as per user request")
        post_mortem_handler.disable_pdb_on_error()

    log.info("Flagging on the {0:s} column".format(args.data_column))
    data_column = args.data_column
    masked_channels = [
        load_mask(fn, dilate=args.dilate_masks) for fn in collect_masks()
    ]
    GD = args.config

    log_configuration(args)

    # Group datasets by these columns
    group_cols = ["FIELD_ID", "DATA_DESC_ID", "SCAN_NUMBER"]
    # Index datasets by these columns
    index_cols = ['TIME']

    # Reopen the datasets using the aggregated row ordering
    columns = [data_column, "FLAG", "TIME", "ANTENNA1", "ANTENNA2"]

    if args.subtract_model_column is not None:
        columns.append(args.subtract_model_column)

    xds = list(
        xds_from_ms(args.ms,
                    columns=tuple(columns),
                    group_cols=group_cols,
                    index_cols=index_cols,
                    chunks={"row": args.row_chunks}))

    # Get support tables
    st = support_tables(args.ms)
    ddid_ds = st["DATA_DESCRIPTION"]
    field_ds = st["FIELD"]
    pol_ds = st["POLARIZATION"]
    spw_ds = st["SPECTRAL_WINDOW"]
    ant_ds = st["ANTENNA"]

    assert len(ant_ds) == 1
    assert len(ddid_ds) == 1

    antspos = ant_ds[0].POSITION.data
    antsnames = ant_ds[0].NAME.data
    fieldnames = [fds.NAME.data[0] for fds in field_ds]

    avail_scans = [ds.SCAN_NUMBER for ds in xds]
    args.scan_numbers = list(
        set(avail_scans).intersection(
            args.scan_numbers if args.scan_numbers is not None
            else avail_scans))

    if args.scan_numbers != []:
        log.info("Only considering scans '{0:s}' as "
                 "per user selection criterion".format(
                     ", ".join(map(str, map(int, args.scan_numbers)))))

    if args.field_names != []:
        flatten_field_names = []
        for f in args.field_names:
            # accept comma lists per specification
            flatten_field_names += [x.strip() for x in f.split(",")]
        for f in flatten_field_names:
            if re.match(r"^\d+$", f) and int(f) < len(fieldnames):
                flatten_field_names.append(fieldnames[int(f)])
        flatten_field_names = list(
            set(
                filter(lambda x: not re.match(r"^\d+$", x),
                       flatten_field_names)))
        log.info("Only considering fields '{0:s}' for flagging per "
                 "user selection criterion.".format(
                     ", ".join(flatten_field_names)))
        if not set(flatten_field_names) <= set(fieldnames):
            raise ValueError("One or more fields cannot be "
                             "found in dataset '{0:s}' "
                             "You specified {1:s}, but "
                             "only {2:s} are available".format(
                                 args.ms,
                                 ",".join(flatten_field_names),
                                 ",".join(fieldnames)))
        field_dict = {fieldnames.index(fn): fn for fn in flatten_field_names}
    else:
        field_dict = {i: fn for i, fn in enumerate(fieldnames)}

    # Lists which hold our dask compute graphs for each dataset
    write_computes = []
    original_stats = []
    final_stats = []

    # Iterate through each dataset
    for ds in xds:
        if ds.FIELD_ID not in field_dict:
            continue

        if (args.scan_numbers is not None
                and ds.SCAN_NUMBER not in args.scan_numbers):
            continue

        log.info("Adding field '{0:s}' scan {1:d} to "
                 "compute graph for processing".format(
                     field_dict[ds.FIELD_ID], ds.SCAN_NUMBER))

        ddid = ddid_ds[ds.attrs['DATA_DESC_ID']]
        spw_info = spw_ds[ddid.SPECTRAL_WINDOW_ID.data[0]]
        pol_info = pol_ds[ddid.POLARIZATION_ID.data[0]]

        nrow, nchan, ncorr = getattr(ds, data_column).data.shape

        # Visibilities from the dataset
        vis = getattr(ds, data_column).data
        if args.subtract_model_column is not None:
            log.info("Forming residual data between '{0:s}' and "
                     "'{1:s}' for flagging.".format(
                         data_column, args.subtract_model_column))
            vismod = getattr(ds, args.subtract_model_column).data
            vis = vis - vismod

        antenna1 = ds.ANTENNA1.data
        antenna2 = ds.ANTENNA2.data
        chan_freq = spw_info.CHAN_FREQ.data[0]
        chan_width = spw_info.CHAN_WIDTH.data[0]

        # Generate unflagged defaults if we should ignore existing flags
        # otherwise take flags from the dataset
        if args.ignore_flags is True:
            # NOTE: np.bool is removed in recent numpy; use the builtin bool
            flags = da.full_like(vis, False, dtype=bool)
            log.critical("Completely ignoring measurement set "
                         "flags as per '-if' request. "
                         "Strategy WILL NOT be OR'ed with original flags, "
                         "even if specified!")
        else:
            flags = ds.FLAG.data

        # If we're flagging on polarised intensity,
        # we convert visibilities to polarised intensity
        # and any flagged correlation will flag the entire visibility
        if args.flagging_strategy == "polarisation":
            corr_type = pol_info.CORR_TYPE.data[0].tolist()
            stokes_map = stokes_corr_map(corr_type)
            stokes_pol = tuple(v for k, v in stokes_map.items() if k != "I")
            vis = polarised_intensity(vis, stokes_pol)
            flags = da.any(flags, axis=2, keepdims=True)
        elif args.flagging_strategy == "total_power":
            if args.subtract_model_column is None:
                log.critical("You requested to flag total quadrature "
                             "power, but not on residuals. "
                             "This is not advisable and the flagger "
                             "may mistake fringes of "
                             "off-axis sources for broadband RFI.")
            corr_type = pol_info.CORR_TYPE.data[0].tolist()
            stokes_map = stokes_corr_map(corr_type)
            stokes_pol = tuple(v for k, v in stokes_map.items())
            vis = polarised_intensity(vis, stokes_pol)
            flags = da.any(flags, axis=2, keepdims=True)
        elif args.flagging_strategy == "standard":
            if args.subtract_model_column is None:
                log.critical("You requested to flag per correlation, "
                             "but not on residuals. "
                             "This is not advisable and the flagger "
                             "may mistake fringes of off-axis sources "
                             "for broadband RFI.")
        else:
            raise ValueError("Invalid flagging strategy '%s'" %
                             args.flagging_strategy)

        ubl = unique_baselines(antenna1, antenna2)
        utime, time_inv = da.unique(ds.TIME.data, return_inverse=True)
        utime, ubl = dask.compute(utime, ubl)
        ubl = ubl.view(np.int32).reshape(-1, 2)

        # Stack the baseline index with the unique baselines
        bl_range = np.arange(ubl.shape[0], dtype=ubl.dtype)[:, None]
        ubl = np.concatenate([bl_range, ubl], axis=1)
        ubl = da.from_array(ubl, chunks=(args.baseline_chunks, 3))

        vis_windows, flag_windows = pack_data(time_inv, ubl,
                                              antenna1, antenna2,
                                              vis, flags, utime.shape[0],
                                              backend=args.window_backend,
                                              path=args.temporary_directory)

        original_stats.append(
            window_stats(flag_windows, ubl, chan_freq, antsnames,
                         ds.SCAN_NUMBER, field_dict[ds.FIELD_ID],
                         ds.attrs['DATA_DESC_ID']))

        with StrategyExecutor(antspos, ubl, chan_freq, chan_width,
                              masked_channels, GD['strategies']) as se:
            flag_windows = se.apply_strategies(flag_windows, vis_windows)

        final_stats.append(
            window_stats(flag_windows, ubl, chan_freq, antsnames,
                         ds.SCAN_NUMBER, field_dict[ds.FIELD_ID],
                         ds.attrs['DATA_DESC_ID']))

        # Unpack window data for writing back to the MS
        unpacked_flags = unpack_data(antenna1, antenna2, time_inv,
                                     ubl, flag_windows)

        # Flag entire visibility if any correlations are flagged
        equalized_flags = da.sum(unpacked_flags, axis=2, keepdims=True) > 0
        corr_flags = da.broadcast_to(equalized_flags, (nrow, nchan, ncorr))

        if corr_flags.chunks != ds.FLAG.data.chunks:
            raise ValueError("Output flag chunking does not "
                             "match input flag chunking")

        # Create new dataset containing new flags
        new_ds = ds.assign(FLAG=(("row", "chan", "corr"), corr_flags))

        # Write back to original dataset
        writes = xds_to_table(new_ds, args.ms, "FLAG")
        # original should also have .compute called because we need stats
        write_computes.append(writes)

    if len(write_computes) > 0:
        # Combine stats from all datasets
        original_stats = combine_window_stats(original_stats)
        final_stats = combine_window_stats(final_stats)

        with contextlib.ExitStack() as stack:
            # Create dask profiling contexts
            profilers = []

            if can_profile:
                profilers.append(stack.enter_context(Profiler()))
                profilers.append(stack.enter_context(CacheProfiler()))
                profilers.append(stack.enter_context(ResourceProfiler()))

            if sys.stdout.isatty():
                # Interactive terminal, default ProgressBar
                stack.enter_context(ProgressBar())
            else:
                # Non-interactive, emit a bar every 5 minutes so
                # as not to spam the log
                stack.enter_context(ProgressBar(minimum=1, dt=5 * 60))

            _, original_stats, final_stats = dask.compute(
                write_computes, original_stats, final_stats)

        if can_profile:
            visualize(profilers)

        toc = time.time()

        # Log each summary line
        for line in summarise_stats(final_stats, original_stats):
            log.info(line)

        elapsed = toc - tic
        log.info("Data flagged successfully in "
                 "{0:02.0f}h{1:02.0f}m{2:02.0f}s".format(
                     (elapsed // 60) // 60,
                     (elapsed // 60) % 60,
                     elapsed % 60))
    else:
        log.info("User data selection criteria resulted in empty dataset. "
                 "Nothing to be done. Bye!")
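A minimal sketch of the flag-equalisation step at the end of the loop above: if any correlation of a visibility is flagged, flag all of its correlations by reducing over the correlation axis and broadcasting the result back to the full shape. Shapes and the random threshold are illustrative assumptions, not tricolour's actual data.

# Hedged sketch: equalise flags across correlations (assumed shapes).
import dask.array as da

nrow, nchan, ncorr = 200, 16, 4
unpacked_flags = da.random.random((nrow, nchan, ncorr),
                                  chunks=(50, 16, 4)) > 0.95

# any flagged correlation flags the whole (row, chan) sample
equalized_flags = da.sum(unpacked_flags, axis=2, keepdims=True) > 0
corr_flags = da.broadcast_to(equalized_flags, (nrow, nchan, ncorr))
print(corr_flags.shape)  # (200, 16, 4)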
def __call__(self, signal, out=None, axes=None):
    """Slice the signal according to the ROI, and return it.

    Arguments
    ---------
    signal : Signal
        The signal to slice with the ROI.
    out : Signal, default = None
        If the 'out' argument is supplied, the sliced output will be put
        into this instead of returning a Signal. See Signal.__getitem__()
        for more details on 'out'.
    axes : specification of axes to use, default = None
        The axes argument specifies which axes the ROI will be applied on.
        The items in the collection can be either of the following:

        * a tuple of:
            - DataAxis. These will not be checked with
              signal.axes_manager.
            - anything that will index signal.axes_manager
        * For any other value, it will check whether the navigation
          space can fit the right number of axis, and use that if it
          fits. If not, it will try the signal space.
    """
    if axes is None and signal in self.signal_map:
        axes = self.signal_map[signal][1]
    else:
        axes = self._parse_axes(axes, signal.axes_manager)

    natax = signal.axes_manager._get_axes_in_natural_order()
    # Slice original data with a circumscribed rectangle
    cx = self.cx + 0.5001 * axes[0].scale
    cy = self.cy + 0.5001 * axes[1].scale
    ranges = [[cx - self.r, cx + self.r],
              [cy - self.r, cy + self.r]]
    slices = self._make_slices(natax, axes, ranges)
    ir = [slices[natax.index(axes[0])],
          slices[natax.index(axes[1])]]
    vx = axes[0].axis[ir[0]] - cx
    vy = axes[1].axis[ir[1]] - cy
    gx, gy = np.meshgrid(vx, vy)
    gr = gx**2 + gy**2
    mask = gr > self.r**2
    if self.r_inner != t.Undefined:
        mask |= gr < self.r_inner**2
    tiles = []
    shape = []
    chunks = []
    for i in range(len(slices)):
        if signal._lazy:
            chunks.append(signal.data.chunks[i][0])
        if i == natax.index(axes[0]):
            thisshape = mask.shape[0]
            tiles.append(thisshape)
            shape.append(thisshape)
        elif i == natax.index(axes[1]):
            thisshape = mask.shape[1]
            tiles.append(thisshape)
            shape.append(thisshape)
        else:
            tiles.append(signal.axes_manager._axes[i].size)
            shape.append(1)
    mask = mask.reshape(shape)

    nav_axes = [ax.navigate for ax in axes]
    nav_dim = signal.axes_manager.navigation_dimension
    if True in nav_axes:
        if False in nav_axes:
            slicer = signal.inav[slices[:nav_dim]].isig.__getitem__
            slices = slices[nav_dim:]
        else:
            slicer = signal.inav.__getitem__
            slices = slices[0:nav_dim]
    else:
        slicer = signal.isig.__getitem__
        slices = slices[nav_dim:]

    roi = slicer(slices, out=out)
    roi = out or roi
    if roi._lazy:
        import dask.array as da
        mask = da.from_array(mask, chunks=chunks)
        mask = da.broadcast_to(mask, tiles)
        # By default promotes dtype to float if required
        roi.data = da.where(mask, np.nan, roi.data)
    else:
        mask = np.broadcast_to(mask, tiles)
        roi.data = np.ma.masked_array(roi.data, mask, hard_mask=True)
    if out is None:
        return roi
    else:
        out.events.data_changed.trigger(out)
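A hedged sketch of the lazy branch above, stripped of the HyperSpy machinery: a small 2-D circular mask is reshaped to align with a lazy 4-D dataset, broadcast across the navigation axes, and used to blank the masked pixels with NaN. All shapes and chunk sizes are assumptions chosen for the example.

# Hedged sketch: lazily mask a circular region with a broadcast dask mask.
import dask.array as da
import numpy as np

ny, nx = 32, 32
data = da.random.random((10, 10, ny, nx), chunks=(5, 5, ny, nx))

# circular mask in the signal plane (True outside the circle)
y, x = np.ogrid[-ny // 2:ny // 2, -nx // 2:nx // 2]
mask = (x**2 + y**2) > (nx // 4)**2
mask = mask.reshape(1, 1, ny, nx)            # align with the data's axes

mask = da.from_array(mask, chunks=(1, 1, ny, nx))
mask = da.broadcast_to(mask, data.shape, chunks=data.chunks)
masked = da.where(mask, np.nan, data)        # NaN outside the circle
print(masked.shape)  # (10, 10, 32, 32)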
def extract_rot_cube(cube, min_lat, min_lon, max_lat, max_lon):
    """
    Extract a specific region from a cube on a rotated coordinate system.

    args
    ----
    cube: cube on rotated coord system, used as reference grid for
          transformation.
    min_lat: The minimum latitude point of the desired extracted cube.
    min_lon: The minimum longitude point of the desired extracted cube.
    max_lat: The maximum latitude point of the desired extracted cube.
    max_lon: The maximum longitude point of the desired extracted cube.

    Returns
    -------
    extracted_cube: cube containing only the requested region.

    An example:

    >>> file = os.path.join(conf.DATA_DIR, 'rcm_monthly.pp')
    >>> cube = iris.load_cube(file, 'air_temperature')
    >>> min_lat = 50
    >>> min_lon = -10
    >>> max_lat = 60
    >>> max_lon = 0
    >>> extracted_cube = extract_rot_cube(cube, min_lat, min_lon, max_lat, max_lon)
    >>> max_lat_cube = np.max(extracted_cube.coord('latitude').points)
    >>> print(f'{max_lat_cube:.3f}')
    61.365
    >>> min_lat_cube = np.min(extracted_cube.coord('latitude').points)
    >>> print(f'{min_lat_cube:.3f}')
    48.213
    >>> max_lon_cube = np.max(extracted_cube.coord('longitude').points)
    >>> print(f'{max_lon_cube:.3f}')
    3.643
    >>> min_lon_cube = np.min(extracted_cube.coord('longitude').points)
    >>> print(f'{min_lon_cube:.3f}')
    -16.292
    """
    # adding unrotated coords to the cube
    cube = add_aux_unrotated_coords(cube)

    # mask the cube using the true lat and lon
    lats = cube.coord("latitude").points
    lons = cube.coord("longitude").points
    select_lons = (lons >= min_lon) & (lons <= max_lon)
    select_lats = (lats >= min_lat) & (lats <= max_lat)
    selection = select_lats & select_lons
    selection = da.broadcast_to(selection, cube.shape)
    cube.data = da.ma.masked_where(~selection, cube.core_data())

    # grab a single 2D slice of X and Y and take the mask
    lon_coord = cube.coord(axis="X", dim_coords=True)
    lat_coord = cube.coord(axis="Y", dim_coords=True)
    for yx_slice in cube.slices(["grid_latitude", "grid_longitude"]):
        cmask = yx_slice.data.mask
        break

    # now cut the cube down along X and Y coords
    x1, x2, y1, y2 = _get_xy_noborder(cmask)
    idx = len(cube.shape) * [slice(None)]
    idx[cube.coord_dims(cube.coord(axis="x",
                                   dim_coords=True))[0]] = slice(x1, x2, 1)
    idx[cube.coord_dims(cube.coord(axis="y",
                                   dim_coords=True))[0]] = slice(y1, y2, 1)
    extracted_cube = cube[tuple(idx)]

    return extracted_cube
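A small self-contained sketch of the selection step above, using synthetic data instead of an iris cube: a 2-D lat/lon selection mask is broadcast over a (time, lat, lon) stack and everything outside the region is masked. Grid sizes, coordinate ranges and chunking are assumptions made for the example.

# Hedged sketch: broadcast a 2-D region mask over a (time, lat, lon) stack.
import dask.array as da
import numpy as np

ntime, nlat, nlon = 12, 40, 50
data = da.random.random((ntime, nlat, nlon), chunks=(1, nlat, nlon))
lats = np.linspace(30.0, 70.0, nlat)[:, None]   # (nlat, 1)
lons = np.linspace(-20.0, 20.0, nlon)[None, :]  # (1, nlon)

# True inside the requested lat/lon box
selection = (lats >= 50) & (lats <= 60) & (lons >= -10) & (lons <= 0)
selection = da.broadcast_to(selection, data.shape, chunks=data.chunks)

masked = da.ma.masked_where(~selection, data)
print(masked.shape)  # (12, 40, 50)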