def get_value(self, group, corr, extras, flag, flag_row, chanslice):
    coldata = self.get_column_data(group)
    # correlation may be pre-set by plot type, or may be passed to us
    corr = self.corr if self.corr is not None else corr
    # apply correlation reduction
    if coldata is not None and coldata.ndim == 3:
        assert corr is not None
        # the mapper can't have a specific axis set
        if self.mapper.axis is not None:
            raise TypeError(f"{self.name}: unexpected column with ndim=3")
        coldata = self.ms.corr_data_mappers[corr](coldata)
    # apply mapping function
    coldata = self.mapper.mapper(
        coldata, **{name: extras[name] for name in self.mapper.extras})
    # scalar expanded to row vector
    if numpy.isscalar(coldata):
        coldata = da.full_like(flag_row, fill_value=coldata, dtype=type(coldata))
        flag = flag_row
    else:
        # apply channel slicing, if there's a channel axis in the array (and the array is a DataArray)
        if type(coldata) is xarray.DataArray and 'chan' in coldata.dims:
            coldata = coldata[dict(chan=chanslice)]
        # determine flags -- start with original flags
        if flag is not None:
            if coldata.ndim == 2:
                flag = self.ms.corr_flag_mappers[corr](flag)
            elif coldata.ndim == 1:
                if not self.mapper.axis:
                    flag = flag_row
                elif self.mapper.axis == 1:
                    flag = None
            # shapes must now match
            if flag is not None and coldata.shape != flag.shape:
                raise TypeError(f"{self.name}: unexpected column shape")
    # discretize
    if self.nlevels:
        # minmax set? discretize over that
        if self.discretized_delta is not None:
            coldata = da.floor((coldata - self.minmax[0]) / self.discretized_delta)
            coldata = da.minimum(da.maximum(coldata, 0),
                                 self.nlevels - 1).astype(COUNT_DTYPE)
        else:
            if coldata.dtype is bool:
                if not numpy.issubdtype(coldata.dtype, numpy.integer):
                    raise TypeError(
                        f"{self.name}: min/max must be set to colour by non-integer values")
            coldata = da.remainder(coldata, self.nlevels).astype(COUNT_DTYPE)

    if flag is not None:
        flag |= ~da.isfinite(coldata)
        return dama.masked_array(coldata, flag)
    else:
        return dama.masked_array(coldata, ~da.isfinite(coldata))
def test_init_score(task, output, client):
    if task == 'ranking' and output == 'scipy_csr_matrix':
        pytest.skip('LGBMRanker is not currently tested on sparse matrices')

    if task == 'ranking':
        _, _, _, _, dX, dy, dw, dg = _create_ranking_data(
            output=output,
            group=None
        )
        model_factory = lgb.DaskLGBMRanker
    else:
        _, _, _, dX, dy, dw = _create_data(
            objective=task,
            output=output,
        )
        dg = None
        if task == 'classification':
            model_factory = lgb.DaskLGBMClassifier
        elif task == 'regression':
            model_factory = lgb.DaskLGBMRegressor

    params = {
        'n_estimators': 1,
        'num_leaves': 2,
        'time_out': 5
    }
    init_score = random.random()
    if output.startswith('dataframe'):
        init_scores = dy.map_partitions(lambda x: pd.Series([init_score] * x.size))
    else:
        init_scores = da.full_like(dy, fill_value=init_score, dtype=np.float64)

    model = model_factory(client=client, **params)
    model.fit(dX, dy, sample_weight=dw, init_score=init_scores, group=dg)

    # value of the root node is 0 when init_score is set
    assert model.booster_.trees_to_dataframe()['value'][0] == 0

    client.close(timeout=CLIENT_CLOSE_TIMEOUT)
print(my_ones_arr.compute())

my_custom_arr = da.random.randint(10, size=(4, 4), chunks=(1, 4))
print(my_custom_arr.compute())
print(my_custom_arr.mean(axis=0).compute())
print(my_custom_arr.mean(axis=1).compute())

# supports slicing
print(my_custom_arr[1:3, 2:4].compute())

# supports broadcasting
my_small_arr = da.ones(4, chunks=2)
brd_example1 = da.add(my_custom_arr, my_small_arr)
print(brd_example1.compute())
ten_arr = da.full_like(my_small_arr, 10)
brd_example2 = da.add(my_custom_arr, ten_arr)
print(brd_example2.compute())

# supports reshaping
print(my_custom_arr.shape)
custom_arr_1d = my_custom_arr.reshape(16)
print(custom_arr_1d.compute())

# supports stacking
stacked_arr = da.stack([brd_example1, brd_example2])
print(stacked_arr.compute())
another_stacked = da.stack([brd_example1, brd_example2], axis=1)
print(another_stacked.compute())

# supports concatenation
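# A minimal sketch of the concatenation case mentioned above (assumption:
# continuing with the arrays already defined in this example). Unlike
# da.stack, da.concatenate joins arrays along an existing axis instead of
# creating a new one.
concat_example = da.concatenate([brd_example1, brd_example2], axis=0)
print(concat_example.compute())  # shape (8, 4): two (4, 4) blocks joined on axis 0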
def _vcfzarr_to_dataset(
    vcfzarr: zarr.Array,
    contig: Optional[str] = None,
    variant_contig_names: Optional[List[str]] = None,
    fix_strings: bool = True,
    field_defs: Optional[Dict[str, Dict[str, Any]]] = None,
) -> xr.Dataset:
    variant_position = da.from_zarr(vcfzarr["variants/POS"])

    if contig is None:
        # Get the contigs from variants/CHROM
        variants_chrom = da.from_zarr(vcfzarr["variants/CHROM"]).astype(str)
        variant_contig, variant_contig_names = encode_array(variants_chrom.compute())
        variant_contig = variant_contig.astype("i1")
        variant_contig_names = list(variant_contig_names)
    else:
        # Single contig: contig names were passed in
        assert variant_contig_names is not None
        contig_index = variant_contig_names.index(contig)
        variant_contig = da.full_like(variant_position, contig_index)

    # For variant alleles, combine REF and ALT into a single array
    variants_ref = da.from_zarr(vcfzarr["variants/REF"])
    variants_alt = da.from_zarr(vcfzarr["variants/ALT"])
    variant_allele = da.concatenate(
        [_ensure_2d(variants_ref), _ensure_2d(variants_alt)], axis=1)
    # rechunk so there's a single chunk in the alleles axis
    variant_allele = variant_allele.rechunk((None, variant_allele.shape[1]))

    if "variants/ID" in vcfzarr:
        variants_id = da.from_zarr(vcfzarr["variants/ID"]).astype(str)
    else:
        variants_id = None

    ds = create_genotype_call_dataset(
        variant_contig_names=variant_contig_names,
        variant_contig=variant_contig,
        variant_position=variant_position,
        variant_allele=variant_allele,
        sample_id=da.from_zarr(vcfzarr["samples"]).astype(str),
        call_genotype=da.from_zarr(vcfzarr["calldata/GT"]),
        variant_id=variants_id,
    )

    # Add a mask for variant ID
    if variants_id is not None:
        ds["variant_id_mask"] = (
            [DIM_VARIANT],
            variants_id == ".",
        )

    # Add any other fields
    field_defs = field_defs or {}
    default_info_fields = ["ALT", "CHROM", "ID", "POS", "REF", "QUAL", "FILTER_PASS"]
    default_format_fields = ["GT"]

    for key in set(vcfzarr["variants"].array_keys()) - set(default_info_fields):
        category = "INFO"
        vcfzarr_key = f"variants/{key}"
        variable_name = f"variant_{key}"
        dims = [DIM_VARIANT]
        field = f"{category}/{key}"
        field_def = field_defs.get(field, {})
        _add_field_to_dataset(category, key, vcfzarr_key, variable_name, dims,
                              field_def, vcfzarr, ds)

    for key in set(vcfzarr["calldata"].array_keys()) - set(default_format_fields):
        category = "FORMAT"
        vcfzarr_key = f"calldata/{key}"
        variable_name = f"call_{key}"
        dims = [DIM_VARIANT, DIM_SAMPLE]
        field = f"{category}/{key}"
        field_def = field_defs.get(field, {})
        _add_field_to_dataset(category, key, vcfzarr_key, variable_name, dims,
                              field_def, vcfzarr, ds)

    # Fix string types to include length
    if fix_strings:
        for (var, arr) in ds.data_vars.items():
            kind = arr.dtype.kind
            if kind in ["O", "U", "S"]:
                # Compute fixed-length string dtype for array
                if kind == "O" or var in ("variant_id", "variant_allele"):
                    kind = "S"
                max_len = max_str_len(arr).values  # type: ignore[union-attr]
                dt = f"{kind}{max_len}"
                ds[var] = arr.astype(dt)

                if var in {"variant_id", "variant_allele"}:
                    ds.attrs[f"max_length_{var}"] = max_len

    return ds
def test_full_like_error_nonscalar_fill_value():
    x = np.full((3, 3), 1, dtype="i8")
    with pytest.raises(ValueError, match="fill_value must be scalar"):
        da.full_like(x, [100, 100], chunks=(2, 2), dtype="i8")
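# For contrast, a minimal sketch of the accepted usage (not part of the test
# above; assumes numpy as np and dask.array as da are imported at module
# level): da.full_like takes only a scalar fill value, mirroring np.full_like.
y = da.full_like(np.full((3, 3), 1, dtype="i8"), 100, chunks=(2, 2), dtype="i8")
assert (y.compute() == 100).all()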
    lambda x: da.expm1(x),
    lambda x: 2 * x,
    lambda x: x / 2,
    lambda x: x ** 2,
    lambda x: x + x,
    lambda x: x * x,
    lambda x: x[0],
    lambda x: x[:, 1],
    lambda x: x[:1, None, 1:3],
    lambda x: x.T,
    lambda x: da.transpose(x, (1, 2, 0)),
    lambda x: x.sum(),
    lambda x: da.empty_like(x),
    lambda x: da.ones_like(x),
    lambda x: da.zeros_like(x),
    lambda x: da.full_like(x, 5),
    pytest.param(
        lambda x: x.mean(),
        marks=pytest.mark.skipif(
            not IS_NEP18_ACTIVE or cupy.__version__ < LooseVersion("6.4.0"),
            reason="NEP-18 support is not available in NumPy or CuPy older than "
            "6.4.0 (requires https://github.com/cupy/cupy/pull/2418)",
        ),
    ),
    pytest.param(
        lambda x: x.moment(order=0),
    ),
    lambda x: x.moment(order=2),
    pytest.param(
        lambda x: x.std(),
        marks=pytest.mark.skipif(
def _psf(**kw):
    args = OmegaConf.create(kw)
    from omegaconf import ListConfig
    if not isinstance(args.ms, list) and not isinstance(args.ms, ListConfig):
        args.ms = [args.ms]
    OmegaConf.set_struct(args, True)

    import numpy as np
    from pfb.utils.misc import chan_to_band_mapping
    import dask
    # from dask.distributed import performance_report
    from dask.graph_manipulation import clone
    from daskms import xds_from_storage_ms as xds_from_ms
    from daskms import xds_from_storage_table as xds_from_table
    from daskms import Dataset
    from daskms.experimental.zarr import xds_to_zarr
    import dask.array as da
    from africanus.constants import c as lightspeed
    from africanus.gridding.wgridder.dask import dirty as vis2im
    from ducc0.fft import good_size
    from pfb.utils.misc import stitch_images, plan_row_chunk
    from pfb.utils.fits import set_wcs, save_fits

    # chan <-> band mapping
    ms = args.ms
    nband = args.nband
    freqs, freq_bin_idx, freq_bin_counts, freq_out, band_mapping, chan_chunks = chan_to_band_mapping(
        ms, nband=nband)

    # gridder memory budget
    max_chan_chunk = 0
    max_freq = 0
    for ims in args.ms:
        for spw in freqs[ims]:
            counts = freq_bin_counts[ims][spw].compute()
            freq = freqs[ims][spw].compute()
            max_chan_chunk = np.maximum(max_chan_chunk, counts.max())
            max_freq = np.maximum(max_freq, freq.max())

    # assumes measurement sets have the same columns,
    # number of correlations etc.
    xds = xds_from_ms(args.ms[0])
    ncorr = xds[0].dims['corr']
    nrow = xds[0].dims['row']
    # we still have to cater for complex valued data because we cast
    # the weights to complex, but we no longer need to factor the
    # weight column into our memory budget
    data_bytes = getattr(xds[0], args.data_column).data.itemsize
    bytes_per_row = max_chan_chunk * ncorr * data_bytes
    memory_per_row = bytes_per_row

    # flags (uint8 or bool)
    memory_per_row += bytes_per_row / 8

    # UVW
    memory_per_row += xds[0].UVW.data.itemsize * 3

    # ANTENNA1/2
    memory_per_row += xds[0].ANTENNA1.data.itemsize * 2

    # TIME
    memory_per_row += xds[0].TIME.data.itemsize

    # data column is not actually read into memory, just used to infer
    # dtype and chunking
    columns = (args.data_column, args.weight_column, args.flag_column,
               'UVW', 'ANTENNA1', 'ANTENNA2', 'TIME')

    # flag row
    if 'FLAG_ROW' in xds[0]:
        columns += ('FLAG_ROW', )
        memory_per_row += xds[0].FLAG_ROW.data.itemsize

    # imaging weights
    if args.imaging_weight_column is not None:
        columns += (args.imaging_weight_column, )
        memory_per_row += bytes_per_row / 2

    # Mueller term (complex valued)
    if args.mueller_column is not None:
        columns += (args.mueller_column, )
        memory_per_row += bytes_per_row

    # get max uv coords over all fields
    uvw = []
    u_max = 0.0
    v_max = 0.0
    for ims in args.ms:
        xds = xds_from_ms(ims, columns=('UVW'), chunks={'row': -1})

        for ds in xds:
            uvw = ds.UVW.data
            u_max = da.maximum(u_max, abs(uvw[:, 0]).max())
            v_max = da.maximum(v_max, abs(uvw[:, 1]).max())
            uv_max = da.maximum(u_max, v_max)

    uv_max = uv_max.compute()
    del uvw

    # image size
    cell_N = 1.0 / (2 * uv_max * max_freq / lightspeed)

    if args.cell_size is not None:
        cell_size = args.cell_size
        cell_rad = cell_size * np.pi / 60 / 60 / 180
        if cell_N / cell_rad < 1:
            raise ValueError("Requested cell size too small. \n"
                             "Super resolution factor = ", cell_N / cell_rad)
        print("Super resolution factor = %f" % (cell_N / cell_rad), file=log)
    else:
        cell_rad = cell_N / args.super_resolution_factor
        cell_size = cell_rad * 60 * 60 * 180 / np.pi
        print("Cell size set to %5.5e arcseconds" % cell_size, file=log)

    if args.nx is None:
        fov = args.field_of_view * 3600
        npix = int(args.psf_oversize * fov / cell_size)
        if npix % 2:
            npix += 1
        nx = npix
        ny = npix
    else:
        nx = args.nx
        ny = args.ny if args.ny is not None else nx

    print("PSF size set to (%i, %i, %i)" % (nband, nx, ny), file=log)

    # get approx image size
    # this is not a conservative estimate when multiple SPW's map to a single
    # imaging band
    pixel_bytes = np.dtype(args.output_type).itemsize
    band_size = nx * ny * pixel_bytes

    if args.host_address is None:
        # full image on single node
        row_chunk = plan_row_chunk(args.mem_limit / args.nworkers, band_size,
                                   nrow, memory_per_row,
                                   args.nthreads_per_worker)
    else:
        # single band per node
        row_chunk = plan_row_chunk(args.mem_limit, band_size, nrow,
                                   memory_per_row, args.nthreads_per_worker)

    if args.row_chunks is not None:
        row_chunk = int(args.row_chunks)
        if row_chunk == -1:
            row_chunk = nrow

    print("nrows = %i, row chunks set to %i for a total of %i chunks per node" %
          (nrow, row_chunk, int(np.ceil(nrow / row_chunk))), file=log)

    chunks = {}
    for ims in args.ms:
        chunks[ims] = []  # xds_from_ms expects a list per ds
        for spw in freqs[ims]:
            chunks[ims].append({
                'row': row_chunk,
                'chan': chan_chunks[ims][spw]['chan']
            })

    psfs = []
    radec = None  # assumes we are only imaging field 0 of first MS
    out_datasets = []
    for ims in args.ms:
        xds = xds_from_ms(ims, chunks=chunks[ims], columns=columns)

        # subtables
        ddids = xds_from_table(ims + "::DATA_DESCRIPTION")
        fields = xds_from_table(ims + "::FIELD")
        spws = xds_from_table(ims + "::SPECTRAL_WINDOW")
        pols = xds_from_table(ims + "::POLARIZATION")

        # subtable data
        ddids = dask.compute(ddids)[0]
        fields = dask.compute(fields)[0]
        spws = dask.compute(spws)[0]
        pols = dask.compute(pols)[0]

        for ds in xds:
            field = fields[ds.FIELD_ID]

            # check fields match
            if radec is None:
                radec = field.PHASE_DIR.data.squeeze()

            if not np.array_equal(radec, field.PHASE_DIR.data.squeeze()):
                continue

            # this is not correct, need to use spw
            spw = ds.DATA_DESC_ID

            uvw = clone(ds.UVW.data)

            data_type = getattr(ds, args.data_column).data.dtype
            data_shape = getattr(ds, args.data_column).data.shape
            data_chunks = getattr(ds, args.data_column).data.chunks

            weights = getattr(ds, args.weight_column).data
            if len(weights.shape) < 3:
                weights = da.broadcast_to(weights[:, None, :],
                                          data_shape,
                                          chunks=data_chunks)

            if args.imaging_weight_column is not None:
                imaging_weights = getattr(ds, args.imaging_weight_column).data
                if len(imaging_weights.shape) < 3:
                    imaging_weights = da.broadcast_to(imaging_weights[:, None, :],
                                                      data_shape,
                                                      chunks=data_chunks)

                weightsxx = imaging_weights[:, :, 0] * weights[:, :, 0]
                weightsyy = imaging_weights[:, :, -1] * weights[:, :, -1]
            else:
                weightsxx = weights[:, :, 0]
                weightsyy = weights[:, :, -1]

            # apply mueller term
            if args.mueller_column is not None:
                mueller = getattr(ds, args.mueller_column).data
                weightsxx *= da.absolute(mueller[:, :, 0])**2
                weightsyy *= da.absolute(mueller[:, :, -1])**2

            # weighted sum corr to Stokes I
            weights = weightsxx + weightsyy

            # MS may contain auto-correlations
            if 'FLAG_ROW' in xds[0]:
                frow = ds.FLAG_ROW.data | (ds.ANTENNA1.data == ds.ANTENNA2.data)
            else:
                frow = (ds.ANTENNA1.data == ds.ANTENNA2.data)

            # only keep data where both corrs are unflagged
            flag = getattr(ds, args.flag_column).data
            flagxx = flag[:, :, 0]
            flagyy = flag[:, :, -1]
            # ducc0 uses uint8 mask not flag
            mask = ~da.logical_or((flagxx | flagyy), frow[:, None])

            psf = vis2im(uvw,
                         freqs[ims][spw],
                         weights.astype(data_type),
                         freq_bin_idx[ims][spw],
                         freq_bin_counts[ims][spw],
                         nx,
                         ny,
                         cell_rad,
                         flag=mask.astype(np.uint8),
                         nthreads=args.nvthreads,
                         epsilon=args.epsilon,
                         do_wstacking=args.wstack,
                         double_accum=args.double_accum)

            psfs.append(psf)

            data_vars = {
                'FIELD_ID': (('row',),
                             da.full_like(ds.TIME.data,
                                          ds.FIELD_ID,
                                          chunks=args.row_out_chunk)),
                'DATA_DESC_ID': (('row',),
                                 da.full_like(ds.TIME.data,
                                              ds.DATA_DESC_ID,
                                              chunks=args.row_out_chunk)),
                'WEIGHT': (('row', 'chan'),
                           weights.rechunk({0: args.row_out_chunk})),  # why no 'f4'?
                'UVW': (('row', 'uvw'), uvw.rechunk({0: args.row_out_chunk}))
            }

            coords = {'chan': (('chan',), freqs[ims][spw])}

            out_ds = Dataset(data_vars, coords)

            out_datasets.append(out_ds)

    writes = xds_to_zarr(out_datasets, args.output_filename + '.zarr',
                         columns='ALL')

    # dask.visualize(writes, filename=args.output_filename + '_psf_writes_graph.pdf', optimize_graph=False)
    # dask.visualize(psfs, filename=args.output_filename + '_psf_graph.pdf', optimize_graph=False)

    if not args.mock:
        # psfs = dask.compute(psfs, writes, optimize_graph=False)[0]
        # with performance_report(filename=args.output_filename + '_psf_per.html'):
        psfs = dask.compute(psfs, writes, optimize_graph=False)[0]

        psf = stitch_images(psfs, nband, band_mapping)

        hdr = set_wcs(cell_size / 3600, cell_size / 3600, nx, ny, radec, freq_out)
        save_fits(args.output_filename + '_psf.fits', psf, hdr,
                  dtype=args.output_type)

        psf_mfs = np.sum(psf, axis=0)
        wsum = psf_mfs.max()
        psf_mfs /= wsum
        hdr_mfs = set_wcs(cell_size / 3600, cell_size / 3600, nx, ny, radec,
                          np.mean(freq_out))
        save_fits(args.output_filename + '_psf_mfs.fits', psf_mfs, hdr_mfs,
                  dtype=args.output_type)

    print("All done here.", file=log)
def _main(args):
    tic = time.time()

    log.info(banner())

    if args.disable_post_mortem:
        log.warn("Disabling crash debugging with the "
                 "Interactive Python Debugger, as per user request")
        post_mortem_handler.disable_pdb_on_error()

    log.info("Flagging on the {0:s} column".format(args.data_column))
    data_column = args.data_column
    masked_channels = [
        load_mask(fn, dilate=args.dilate_masks) for fn in collect_masks()
    ]
    GD = args.config

    log_configuration(args)

    # Group datasets by these columns
    group_cols = ["FIELD_ID", "DATA_DESC_ID", "SCAN_NUMBER"]
    # Index datasets by these columns
    index_cols = ['TIME']

    # Reopen the datasets using the aggregated row ordering
    columns = [data_column, "FLAG", "TIME", "ANTENNA1", "ANTENNA2"]

    if args.subtract_model_column is not None:
        columns.append(args.subtract_model_column)

    xds = list(
        xds_from_ms(args.ms,
                    columns=tuple(columns),
                    group_cols=group_cols,
                    index_cols=index_cols,
                    chunks={"row": args.row_chunks}))

    # Get support tables
    st = support_tables(args.ms)
    ddid_ds = st["DATA_DESCRIPTION"]
    field_ds = st["FIELD"]
    pol_ds = st["POLARIZATION"]
    spw_ds = st["SPECTRAL_WINDOW"]
    ant_ds = st["ANTENNA"]

    assert len(ant_ds) == 1
    assert len(ddid_ds) == 1

    antspos = ant_ds[0].POSITION.data
    antsnames = ant_ds[0].NAME.data
    fieldnames = [fds.NAME.data[0] for fds in field_ds]

    avail_scans = [ds.SCAN_NUMBER for ds in xds]
    args.scan_numbers = list(
        set(avail_scans).intersection(
            args.scan_numbers if args.scan_numbers is not None else avail_scans))

    if args.scan_numbers != []:
        log.info("Only considering scans '{0:s}' as "
                 "per user selection criterion".format(", ".join(
                     map(str, map(int, args.scan_numbers)))))

    if args.field_names != []:
        flatten_field_names = []
        for f in args.field_names:
            # accept comma lists per specification
            flatten_field_names += [x.strip() for x in f.split(",")]
        for f in flatten_field_names:
            if re.match(r"^\d+$", f) and int(f) < len(fieldnames):
                flatten_field_names.append(fieldnames[int(f)])
        flatten_field_names = list(
            set(filter(lambda x: not re.match(r"^\d+$", x),
                       flatten_field_names)))
        log.info("Only considering fields '{0:s}' for flagging per "
                 "user selection criterion.".format(
                     ", ".join(flatten_field_names)))
        if not set(flatten_field_names) <= set(fieldnames):
            raise ValueError("One or more fields cannot be "
                             "found in dataset '{0:s}' "
                             "You specified {1:s}, but "
                             "only {2:s} are available".format(
                                 args.ms, ",".join(flatten_field_names),
                                 ",".join(fieldnames)))
        field_dict = {fieldnames.index(fn): fn for fn in flatten_field_names}
    else:
        field_dict = {i: fn for i, fn in enumerate(fieldnames)}

    # Lists which hold our dask compute graphs for each dataset
    write_computes = []
    original_stats = []
    final_stats = []

    # Iterate through each dataset
    for ds in xds:
        if ds.FIELD_ID not in field_dict:
            continue

        if (args.scan_numbers is not None
                and ds.SCAN_NUMBER not in args.scan_numbers):
            continue

        log.info("Adding field '{0:s}' scan {1:d} to "
                 "compute graph for processing".format(field_dict[ds.FIELD_ID],
                                                       ds.SCAN_NUMBER))

        ddid = ddid_ds[ds.attrs['DATA_DESC_ID']]
        spw_info = spw_ds[ddid.SPECTRAL_WINDOW_ID.data[0]]
        pol_info = pol_ds[ddid.POLARIZATION_ID.data[0]]

        nrow, nchan, ncorr = getattr(ds, data_column).data.shape

        # Visibilities from the dataset
        vis = getattr(ds, data_column).data
        if args.subtract_model_column is not None:
            log.info("Forming residual data between '{0:s}' and "
                     "'{1:s}' for flagging.".format(
                         data_column, args.subtract_model_column))
            vismod = getattr(ds, args.subtract_model_column).data
            vis = vis - vismod

        antenna1 = ds.ANTENNA1.data
        antenna2 = ds.ANTENNA2.data
        chan_freq = spw_info.CHAN_FREQ.data[0]
        chan_width = spw_info.CHAN_WIDTH.data[0]

        # Generate unflagged defaults if we should ignore existing flags
        # otherwise take flags from the dataset
        if args.ignore_flags is True:
            flags = da.full_like(vis, False, dtype=np.bool)
            log.critical("Completely ignoring measurement set "
                         "flags as per '-if' request. "
                         "Strategy WILL NOT or with original flags, even if "
                         "specified!")
        else:
            flags = ds.FLAG.data

        # If we're flagging on polarised intensity,
        # we convert visibilities to polarised intensity
        # and any flagged correlation will flag the entire visibility
        if args.flagging_strategy == "polarisation":
            corr_type = pol_info.CORR_TYPE.data[0].tolist()
            stokes_map = stokes_corr_map(corr_type)
            stokes_pol = tuple(v for k, v in stokes_map.items() if k != "I")
            vis = polarised_intensity(vis, stokes_pol)
            flags = da.any(flags, axis=2, keepdims=True)
        elif args.flagging_strategy == "total_power":
            if args.subtract_model_column is None:
                log.critical("You requested to flag total quadrature "
                             "power, but not on residuals. "
                             "This is not advisable and the flagger "
                             "may mistake fringes of "
                             "off-axis sources for broadband RFI.")
            corr_type = pol_info.CORR_TYPE.data[0].tolist()
            stokes_map = stokes_corr_map(corr_type)
            stokes_pol = tuple(v for k, v in stokes_map.items())
            vis = polarised_intensity(vis, stokes_pol)
            flags = da.any(flags, axis=2, keepdims=True)
        elif args.flagging_strategy == "standard":
            if args.subtract_model_column is None:
                log.critical("You requested to flag per correlation, "
                             "but not on residuals. "
                             "This is not advisable and the flagger "
                             "may mistake fringes of off-axis sources "
                             "for broadband RFI.")
        else:
            raise ValueError("Invalid flagging strategy '%s'" %
                             args.flagging_strategy)

        ubl = unique_baselines(antenna1, antenna2)
        utime, time_inv = da.unique(ds.TIME.data, return_inverse=True)
        utime, ubl = dask.compute(utime, ubl)
        ubl = ubl.view(np.int32).reshape(-1, 2)

        # Stack the baseline index with the unique baselines
        bl_range = np.arange(ubl.shape[0], dtype=ubl.dtype)[:, None]
        ubl = np.concatenate([bl_range, ubl], axis=1)
        ubl = da.from_array(ubl, chunks=(args.baseline_chunks, 3))

        vis_windows, flag_windows = pack_data(time_inv,
                                              ubl,
                                              antenna1,
                                              antenna2,
                                              vis,
                                              flags,
                                              utime.shape[0],
                                              backend=args.window_backend,
                                              path=args.temporary_directory)

        original_stats.append(
            window_stats(flag_windows, ubl, chan_freq, antsnames,
                         ds.SCAN_NUMBER, field_dict[ds.FIELD_ID],
                         ds.attrs['DATA_DESC_ID']))

        with StrategyExecutor(antspos, ubl, chan_freq, chan_width,
                              masked_channels, GD['strategies']) as se:
            flag_windows = se.apply_strategies(flag_windows, vis_windows)

        final_stats.append(
            window_stats(flag_windows, ubl, chan_freq, antsnames,
                         ds.SCAN_NUMBER, field_dict[ds.FIELD_ID],
                         ds.attrs['DATA_DESC_ID']))

        # Unpack window data for writing back to the MS
        unpacked_flags = unpack_data(antenna1, antenna2, time_inv, ubl,
                                     flag_windows)

        # Flag entire visibility if any correlations are flagged
        equalized_flags = da.sum(unpacked_flags, axis=2, keepdims=True) > 0
        corr_flags = da.broadcast_to(equalized_flags, (nrow, nchan, ncorr))

        if corr_flags.chunks != ds.FLAG.data.chunks:
            raise ValueError("Output flag chunking does not "
                             "match input flag chunking")

        # Create new dataset containing new flags
        new_ds = ds.assign(FLAG=(("row", "chan", "corr"), corr_flags))

        # Write back to original dataset
        writes = xds_to_table(new_ds, args.ms, "FLAG")
        # original should also have .compute called because we need stats
        write_computes.append(writes)

    if len(write_computes) > 0:
        # Combine stats from all datasets
        original_stats = combine_window_stats(original_stats)
        final_stats = combine_window_stats(final_stats)

        with contextlib.ExitStack() as stack:
            # Create dask profiling contexts
            profilers = []

            if can_profile:
                profilers.append(stack.enter_context(Profiler()))
                profilers.append(stack.enter_context(CacheProfiler()))
                profilers.append(stack.enter_context(ResourceProfiler()))

            if sys.stdout.isatty():
                # Interactive terminal, default ProgressBar
                stack.enter_context(ProgressBar())
            else:
                # Non-interactive, emit a bar every 5 minutes so
                # as not to spam the log
                stack.enter_context(ProgressBar(minimum=1, dt=5 * 60))

            _, original_stats, final_stats = dask.compute(
                write_computes, original_stats, final_stats)

        if can_profile:
            visualize(profilers)

        toc = time.time()

        # Log each summary line
        for line in summarise_stats(final_stats, original_stats):
            log.info(line)

        elapsed = toc - tic
        log.info("Data flagged successfully in "
                 "{0:02.0f}h{1:02.0f}m{2:02.0f}s".format((elapsed // 60) // 60,
                                                         (elapsed // 60) % 60,
                                                         elapsed % 60))
    else:
        log.info("User data selection criteria resulted in empty dataset. "
                 "Nothing to be done. Bye!")
def _vcfzarr_to_dataset(
    vcfzarr: zarr.Array,
    contig: Optional[str] = None,
    variant_contig_names: Optional[List[str]] = None,
    fix_strings: bool = True,
) -> xr.Dataset:
    variant_position = da.from_zarr(vcfzarr["variants/POS"])

    if contig is None:
        # Get the contigs from variants/CHROM
        variants_chrom = da.from_zarr(vcfzarr["variants/CHROM"]).astype(str)
        variant_contig, variant_contig_names = encode_array(variants_chrom.compute())
        variant_contig = variant_contig.astype("int16")
        variant_contig_names = list(variant_contig_names)
    else:
        # Single contig: contig names were passed in
        assert variant_contig_names is not None
        contig_index = variant_contig_names.index(contig)
        variant_contig = da.full_like(variant_position, contig_index)

    # For variant alleles, combine REF and ALT into a single array
    variants_ref = da.from_zarr(vcfzarr["variants/REF"])
    variants_alt = da.from_zarr(vcfzarr["variants/ALT"])
    variant_allele = da.concatenate(
        [_ensure_2d(variants_ref), _ensure_2d(variants_alt)], axis=1)
    # rechunk so there's a single chunk in the alleles axis
    variant_allele = variant_allele.rechunk((None, variant_allele.shape[1]))

    if "variants/ID" in vcfzarr:
        variants_id = da.from_zarr(vcfzarr["variants/ID"]).astype(str)
    else:
        variants_id = None

    ds = create_genotype_call_dataset(
        variant_contig_names=variant_contig_names,
        variant_contig=variant_contig,
        variant_position=variant_position,
        variant_allele=variant_allele,
        sample_id=da.from_zarr(vcfzarr["samples"]).astype(str),
        call_genotype=da.from_zarr(vcfzarr["calldata/GT"]),
        variant_id=variants_id,
    )

    # Add a mask for variant ID
    if variants_id is not None:
        ds["variant_id_mask"] = (
            [DIM_VARIANT],
            variants_id == ".",
        )

    # Fix string types to include length
    if fix_strings:
        for (var, arr) in ds.data_vars.items():
            kind = arr.dtype.kind
            if kind in ["O", "U", "S"]:
                # Compute fixed-length string dtype for array
                if kind == "O" or var in ("variant_id", "variant_allele"):
                    kind = "S"
                max_len = max_str_len(arr).values
                dt = f"{kind}{max_len}"
                ds[var] = arr.astype(dt)  # type: ignore[no-untyped-call]

                if var in {"variant_id", "variant_allele"}:
                    ds.attrs[f"max_{var}_length"] = max_len

    return ds