Example #1
    def get_value(self, group, corr, extras, flag, flag_row, chanslice):
        coldata = self.get_column_data(group)
        # correlation may be pre-set by plot type, or may be passed to us
        corr = self.corr if self.corr is not None else corr
        # apply correlation reduction
        if coldata is not None and coldata.ndim == 3:
            assert corr is not None
            # the mapper can't have a specific axis set
            if self.mapper.axis is not None:
                raise TypeError(f"{self.name}: unexpected column with ndim=3")
            coldata = self.ms.corr_data_mappers[corr](coldata)
        # apply mapping function
        coldata = self.mapper.mapper(
            coldata, **{name: extras[name]
                        for name in self.mapper.extras})
        # scalar expanded to row vector
        if numpy.isscalar(coldata):
            coldata = da.full_like(flag_row,
                                   fill_value=coldata,
                                   dtype=type(coldata))
            flag = flag_row
        else:
            # apply channel slicing, if there's a channel axis in the array (and the array is a DataArray)
            if type(coldata) is xarray.DataArray and 'chan' in coldata.dims:
                coldata = coldata[dict(chan=chanslice)]
            # determine flags -- start with original flags
            if flag is not None:
                if coldata.ndim == 2:
                    flag = self.ms.corr_flag_mappers[corr](flag)
                elif coldata.ndim == 1:
                    if not self.mapper.axis:
                        flag = flag_row
                    elif self.mapper.axis == 1:
                        flag = None
                # shapes must now match
                if flag is not None and coldata.shape != flag.shape:
                    raise TypeError(f"{self.name}: unexpected column shape")
        # discretize
        if self.nlevels:
            # minmax set? discretize over that
            if self.discretized_delta is not None:
                coldata = da.floor(
                    (coldata - self.minmax[0]) / self.discretized_delta)
                coldata = da.minimum(da.maximum(coldata, 0),
                                     self.nlevels - 1).astype(COUNT_DTYPE)
            else:
                # booleans are already 0/1; anything else must be integer to
                # be discretized by taking the remainder modulo nlevels
                if coldata.dtype == bool:
                    coldata = coldata.astype(COUNT_DTYPE)
                else:
                    if not numpy.issubdtype(coldata.dtype, numpy.integer):
                        raise TypeError(
                            f"{self.name}: min/max must be set to colour by non-integer values"
                        )
                    coldata = da.remainder(coldata,
                                           self.nlevels).astype(COUNT_DTYPE)

        if flag is not None:
            flag |= ~da.isfinite(coldata)
            return dama.masked_array(coldata, flag)
        else:
            return dama.masked_array(coldata, ~da.isfinite(coldata))
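
# Illustrative sketch (not part of the class above): the scalar-expansion step
# in get_value relies on da.full_like, which builds a new dask array with the
# same shape and chunking as a template array, filled with the given scalar.
# The template array below is an assumption for illustration only.
import dask.array as da

flag_row = da.zeros(8, chunks=4, dtype=bool)   # stand-in for a FLAG_ROW column
expanded = da.full_like(flag_row, fill_value=1.5, dtype=float)
print(expanded.compute())   # eight values of 1.5, chunked like flag_row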
Example #2
def test_init_score(
        task,
        output,
        client):
    if task == 'ranking' and output == 'scipy_csr_matrix':
        pytest.skip('LGBMRanker is not currently tested on sparse matrices')

    if task == 'ranking':
        _, _, _, _, dX, dy, dw, dg = _create_ranking_data(
            output=output,
            group=None
        )
        model_factory = lgb.DaskLGBMRanker
    else:
        _, _, _, dX, dy, dw = _create_data(
            objective=task,
            output=output,
        )
        dg = None
        if task == 'classification':
            model_factory = lgb.DaskLGBMClassifier
        elif task == 'regression':
            model_factory = lgb.DaskLGBMRegressor

    params = {
        'n_estimators': 1,
        'num_leaves': 2,
        'time_out': 5
    }
    init_score = random.random()
    if output.startswith('dataframe'):
        init_scores = dy.map_partitions(lambda x: pd.Series([init_score] * x.size))
    else:
        init_scores = da.full_like(dy, fill_value=init_score, dtype=np.float64)
    model = model_factory(client=client, **params)
    model.fit(dX, dy, sample_weight=dw, init_score=init_scores, group=dg)
    # value of the root node is 0 when init_score is set
    assert model.booster_.trees_to_dataframe()['value'][0] == 0

    client.close(timeout=CLIENT_CLOSE_TIMEOUT)
Example #3
import dask.array as da

# assumed setup for this fragment: the original definition of my_ones_arr was
# cut off, so a chunked array of ones (matching the name) is used here
my_ones_arr = da.ones((4, 4), chunks=(2, 2))
print(my_ones_arr.compute())

my_custom_arr = da.random.randint(10, size=(4, 4), chunks=(1, 4))
print(my_custom_arr.compute())
print(my_custom_arr.mean(axis=0).compute())
print(my_custom_arr.mean(axis=1).compute())

# supports slicing
print(my_custom_arr[1:3, 2:4].compute())

# Supports broadcasting
my_small_arr = da.ones(4, chunks=2)
brd_example1 = da.add(my_custom_arr, my_small_arr)
print(brd_example1.compute())

ten_arr = da.full_like(my_small_arr, 10)
brd_example2 = da.add(my_custom_arr, ten_arr)
print(brd_example2.compute())

# supports reshaping
print(my_custom_arr.shape)
custom_arr_1d = my_custom_arr.reshape(16)
print(custom_arr_1d.compute())

# supports stacking
stacked_arr = da.stack([brd_example1, brd_example2])
print(stacked_arr.compute())
another_stacked = da.stack([brd_example1, brd_example2], axis=1)
print(another_stacked.compute())

# supports concatenation
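# Illustrative sketch (the original snippet is cut off here): da.concatenate
# joins the chunked arrays defined above along an existing axis.
concat_rows = da.concatenate([brd_example1, brd_example2], axis=0)  # shape (8, 4)
print(concat_rows.compute())
concat_cols = da.concatenate([brd_example1, brd_example2], axis=1)  # shape (4, 8)
print(concat_cols.compute())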
Example #4
def _vcfzarr_to_dataset(
    vcfzarr: zarr.Array,
    contig: Optional[str] = None,
    variant_contig_names: Optional[List[str]] = None,
    fix_strings: bool = True,
    field_defs: Optional[Dict[str, Dict[str, Any]]] = None,
) -> xr.Dataset:

    variant_position = da.from_zarr(vcfzarr["variants/POS"])

    if contig is None:
        # Get the contigs from variants/CHROM
        variants_chrom = da.from_zarr(vcfzarr["variants/CHROM"]).astype(str)
        variant_contig, variant_contig_names = encode_array(
            variants_chrom.compute())
        variant_contig = variant_contig.astype("i1")
        variant_contig_names = list(variant_contig_names)
    else:
        # Single contig: contig names were passed in
        assert variant_contig_names is not None
        contig_index = variant_contig_names.index(contig)
        variant_contig = da.full_like(variant_position, contig_index)

    # For variant alleles, combine REF and ALT into a single array
    variants_ref = da.from_zarr(vcfzarr["variants/REF"])
    variants_alt = da.from_zarr(vcfzarr["variants/ALT"])
    variant_allele = da.concatenate(
        [_ensure_2d(variants_ref),
         _ensure_2d(variants_alt)], axis=1)
    # rechunk so there's a single chunk in alleles axis
    variant_allele = variant_allele.rechunk((None, variant_allele.shape[1]))

    if "variants/ID" in vcfzarr:
        variants_id = da.from_zarr(vcfzarr["variants/ID"]).astype(str)
    else:
        variants_id = None

    ds = create_genotype_call_dataset(
        variant_contig_names=variant_contig_names,
        variant_contig=variant_contig,
        variant_position=variant_position,
        variant_allele=variant_allele,
        sample_id=da.from_zarr(vcfzarr["samples"]).astype(str),
        call_genotype=da.from_zarr(vcfzarr["calldata/GT"]),
        variant_id=variants_id,
    )

    # Add a mask for variant ID
    if variants_id is not None:
        ds["variant_id_mask"] = (
            [DIM_VARIANT],
            variants_id == ".",
        )

    # Add any other fields
    field_defs = field_defs or {}
    default_info_fields = [
        "ALT", "CHROM", "ID", "POS", "REF", "QUAL", "FILTER_PASS"
    ]
    default_format_fields = ["GT"]
    for key in set(
            vcfzarr["variants"].array_keys()) - set(default_info_fields):
        category = "INFO"
        vcfzarr_key = f"variants/{key}"
        variable_name = f"variant_{key}"
        dims = [DIM_VARIANT]
        field = f"{category}/{key}"
        field_def = field_defs.get(field, {})
        _add_field_to_dataset(category, key, vcfzarr_key, variable_name, dims,
                              field_def, vcfzarr, ds)
    for key in set(
            vcfzarr["calldata"].array_keys()) - set(default_format_fields):
        category = "FORMAT"
        vcfzarr_key = f"calldata/{key}"
        variable_name = f"call_{key}"
        dims = [DIM_VARIANT, DIM_SAMPLE]
        field = f"{category}/{key}"
        field_def = field_defs.get(field, {})
        _add_field_to_dataset(category, key, vcfzarr_key, variable_name, dims,
                              field_def, vcfzarr, ds)

    # Fix string types to include length
    if fix_strings:
        for (var, arr) in ds.data_vars.items():
            kind = arr.dtype.kind
            if kind in ["O", "U", "S"]:
                # Compute fixed-length string dtype for array
                if kind == "O" or var in ("variant_id", "variant_allele"):
                    kind = "S"
                max_len = max_str_len(arr).values  # type: ignore[union-attr]
                dt = f"{kind}{max_len}"
                ds[var] = arr.astype(dt)

                if var in {"variant_id", "variant_allele"}:
                    ds.attrs[f"max_length_{var}"] = max_len

    return ds
Example #5
def test_full_like_error_nonscalar_fill_value():
    x = np.full((3, 3), 1, dtype="i8")
    with pytest.raises(ValueError, match="fill_value must be scalar"):
        da.full_like(x, [100, 100], chunks=(2, 2), dtype="i8")
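
# Illustrative sketch (not part of the original test suite): the same call
# succeeds when fill_value is a scalar, which is what da.full_like expects.
def test_full_like_scalar_fill_value_sketch():
    x = np.full((3, 3), 1, dtype="i8")
    y = da.full_like(x, 100, chunks=(2, 2), dtype="i8")
    assert (y.compute() == 100).all()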
Example #6
File: test_cupy.py Project: uellue/dask
 lambda x: da.expm1(x),
 lambda x: 2 * x,
 lambda x: x / 2,
 lambda x: x ** 2,
 lambda x: x + x,
 lambda x: x * x,
 lambda x: x[0],
 lambda x: x[:, 1],
 lambda x: x[:1, None, 1:3],
 lambda x: x.T,
 lambda x: da.transpose(x, (1, 2, 0)),
 lambda x: x.sum(),
 lambda x: da.empty_like(x),
 lambda x: da.ones_like(x),
 lambda x: da.zeros_like(x),
 lambda x: da.full_like(x, 5),
 pytest.param(
     lambda x: x.mean(),
     marks=pytest.mark.skipif(
         not IS_NEP18_ACTIVE or cupy.__version__ < LooseVersion("6.4.0"),
         reason="NEP-18 support is not available in NumPy or CuPy older than "
         "6.4.0 (requires https://github.com/cupy/cupy/pull/2418)",
     ),
 ),
 pytest.param(
     lambda x: x.moment(order=0),
 ),
 lambda x: x.moment(order=2),
 pytest.param(
     lambda x: x.std(),
     marks=pytest.mark.skipif(
Example #7
File: psf.py Project: ratt-ru/pfb-clean
def _psf(**kw):
    args = OmegaConf.create(kw)
    from omegaconf import ListConfig
    if not isinstance(args.ms, list) and not isinstance(args.ms, ListConfig):
        args.ms = [args.ms]
    OmegaConf.set_struct(args, True)

    import numpy as np
    from pfb.utils.misc import chan_to_band_mapping
    import dask
    # from dask.distributed import performance_report
    from dask.graph_manipulation import clone
    from daskms import xds_from_storage_ms as xds_from_ms
    from daskms import xds_from_storage_table as xds_from_table
    from daskms import Dataset
    from daskms.experimental.zarr import xds_to_zarr
    import dask.array as da
    from africanus.constants import c as lightspeed
    from africanus.gridding.wgridder.dask import dirty as vis2im
    from ducc0.fft import good_size
    from pfb.utils.misc import stitch_images, plan_row_chunk
    from pfb.utils.fits import set_wcs, save_fits

    # chan <-> band mapping
    ms = args.ms
    nband = args.nband
    freqs, freq_bin_idx, freq_bin_counts, freq_out, band_mapping, chan_chunks = chan_to_band_mapping(
        ms, nband=nband)

    # gridder memory budget
    max_chan_chunk = 0
    max_freq = 0
    for ims in args.ms:
        for spw in freqs[ims]:
            counts = freq_bin_counts[ims][spw].compute()
            freq = freqs[ims][spw].compute()
            max_chan_chunk = np.maximum(max_chan_chunk, counts.max())
            max_freq = np.maximum(max_freq, freq.max())

    # assumes measurement sets have the same columns,
    # number of correlations etc.
    xds = xds_from_ms(args.ms[0])
    ncorr = xds[0].dims['corr']
    nrow = xds[0].dims['row']
    # we still have to cater for complex valued data because we cast
    # the weights to complex, but we no longer need to factor the
    # weight column into our memory budget
    data_bytes = getattr(xds[0], args.data_column).data.itemsize
    bytes_per_row = max_chan_chunk * ncorr * data_bytes
    memory_per_row = bytes_per_row

    # flags (uint8 or bool)
    memory_per_row += bytes_per_row / 8

    # UVW
    memory_per_row += xds[0].UVW.data.itemsize * 3

    # ANTENNA1/2
    memory_per_row += xds[0].ANTENNA1.data.itemsize * 2

    # TIME
    memory_per_row += xds[0].TIME.data.itemsize

    # data column is not actually read into memory, just used to infer
    # dtype and chunking
    columns = (args.data_column, args.weight_column, args.flag_column, 'UVW',
               'ANTENNA1', 'ANTENNA2', 'TIME')

    # flag row
    if 'FLAG_ROW' in xds[0]:
        columns += ('FLAG_ROW', )
        memory_per_row += xds[0].FLAG_ROW.data.itemsize

    # imaging weights
    if args.imaging_weight_column is not None:
        columns += (args.imaging_weight_column, )
        memory_per_row += bytes_per_row / 2

    # Mueller term (complex valued)
    if args.mueller_column is not None:
        columns += (args.mueller_column, )
        memory_per_row += bytes_per_row

    # get max uv coords over all fields
    uvw = []
    u_max = 0.0
    v_max = 0.0
    for ims in args.ms:
        xds = xds_from_ms(ims, columns=('UVW',), chunks={'row': -1})

        for ds in xds:
            uvw = ds.UVW.data
            u_max = da.maximum(u_max, abs(uvw[:, 0]).max())
            v_max = da.maximum(v_max, abs(uvw[:, 1]).max())
            uv_max = da.maximum(u_max, v_max)

    uv_max = uv_max.compute()
    del uvw

    # image size
    cell_N = 1.0 / (2 * uv_max * max_freq / lightspeed)

    if args.cell_size is not None:
        cell_size = args.cell_size
        cell_rad = cell_size * np.pi / 60 / 60 / 180
        if cell_N / cell_rad < 1:
            raise ValueError(
                "Requested cell size too small. "
                "Super resolution factor = ", cell_N / cell_rad)
        print("Super resolution factor = %f" % (cell_N / cell_rad), file=log)
    else:
        cell_rad = cell_N / args.super_resolution_factor
        cell_size = cell_rad * 60 * 60 * 180 / np.pi
        print("Cell size set to %5.5e arcseconds" % cell_size, file=log)

    if args.nx is None:
        fov = args.field_of_view * 3600
        npix = int(args.psf_oversize * fov / cell_size)
        if npix % 2:
            npix += 1
        nx = npix
        ny = npix
    else:
        nx = args.nx
        ny = args.ny if args.ny is not None else nx

    print("PSF size set to (%i, %i, %i)" % (nband, nx, ny), file=log)

    # get approx image size
    # this is not a conservative estimate when multiple SPW's map to a single
    # imaging band
    pixel_bytes = np.dtype(args.output_type).itemsize
    band_size = nx * ny * pixel_bytes

    if args.host_address is None:
        # full image on single node
        row_chunk = plan_row_chunk(args.mem_limit / args.nworkers, band_size,
                                   nrow, memory_per_row,
                                   args.nthreads_per_worker)

    else:
        # single band per node
        row_chunk = plan_row_chunk(args.mem_limit, band_size, nrow,
                                   memory_per_row, args.nthreads_per_worker)

    if args.row_chunks is not None:
        row_chunk = int(args.row_chunks)
        if row_chunk == -1:
            row_chunk = nrow

    print(
        "nrows = %i, row chunks set to %i for a total of %i chunks per node" %
        (nrow, row_chunk, int(np.ceil(nrow / row_chunk))),
        file=log)

    chunks = {}
    for ims in args.ms:
        chunks[ims] = []  # xds_from_ms expects a list per ds
        for spw in freqs[ims]:
            chunks[ims].append({
                'row': row_chunk,
                'chan': chan_chunks[ims][spw]['chan']
            })

    psfs = []
    radec = None  # assumes we are only imaging field 0 of first MS
    out_datasets = []
    for ims in args.ms:
        xds = xds_from_ms(ims, chunks=chunks[ims], columns=columns)

        # subtables
        ddids = xds_from_table(ims + "::DATA_DESCRIPTION")
        fields = xds_from_table(ims + "::FIELD")
        spws = xds_from_table(ims + "::SPECTRAL_WINDOW")
        pols = xds_from_table(ims + "::POLARIZATION")

        # subtable data
        ddids = dask.compute(ddids)[0]
        fields = dask.compute(fields)[0]
        spws = dask.compute(spws)[0]
        pols = dask.compute(pols)[0]

        for ds in xds:
            field = fields[ds.FIELD_ID]

            # check fields match
            if radec is None:
                radec = field.PHASE_DIR.data.squeeze()

            if not np.array_equal(radec, field.PHASE_DIR.data.squeeze()):
                continue

            # this is not correct, need to use spw
            spw = ds.DATA_DESC_ID

            uvw = clone(ds.UVW.data)

            data_type = getattr(ds, args.data_column).data.dtype
            data_shape = getattr(ds, args.data_column).data.shape
            data_chunks = getattr(ds, args.data_column).data.chunks

            weights = getattr(ds, args.weight_column).data
            if len(weights.shape) < 3:
                weights = da.broadcast_to(weights[:, None, :],
                                          data_shape,
                                          chunks=data_chunks)

            if args.imaging_weight_column is not None:
                imaging_weights = getattr(ds, args.imaging_weight_column).data
                if len(imaging_weights.shape) < 3:
                    imaging_weights = da.broadcast_to(imaging_weights[:,
                                                                      None, :],
                                                      data_shape,
                                                      chunks=data_chunks)

                weightsxx = imaging_weights[:, :, 0] * weights[:, :, 0]
                weightsyy = imaging_weights[:, :, -1] * weights[:, :, -1]
            else:
                weightsxx = weights[:, :, 0]
                weightsyy = weights[:, :, -1]

            # apply mueller term
            if args.mueller_column is not None:
                mueller = getattr(ds, args.mueller_column).data
                weightsxx *= da.absolute(mueller[:, :, 0])**2
                weightsyy *= da.absolute(mueller[:, :, -1])**2

            # weighted sum corr to Stokes I
            weights = weightsxx + weightsyy

            # MS may contain auto-correlations
            if 'FLAG_ROW' in xds[0]:
                frow = ds.FLAG_ROW.data | (ds.ANTENNA1.data
                                           == ds.ANTENNA2.data)
            else:
                frow = (ds.ANTENNA1.data == ds.ANTENNA2.data)

            # only keep data where both corrs are unflagged
            flag = getattr(ds, args.flag_column).data
            flagxx = flag[:, :, 0]
            flagyy = flag[:, :, -1]
            # ducc0 uses uint8 mask not flag
            mask = ~da.logical_or((flagxx | flagyy), frow[:, None])

            psf = vis2im(uvw,
                         freqs[ims][spw],
                         weights.astype(data_type),
                         freq_bin_idx[ims][spw],
                         freq_bin_counts[ims][spw],
                         nx,
                         ny,
                         cell_rad,
                         flag=mask.astype(np.uint8),
                         nthreads=args.nvthreads,
                         epsilon=args.epsilon,
                         do_wstacking=args.wstack,
                         double_accum=args.double_accum)

            psfs.append(psf)

            data_vars = {
                'FIELD_ID': (('row', ),
                             da.full_like(ds.TIME.data,
                                          ds.FIELD_ID,
                                          chunks=args.row_out_chunk)),
                'DATA_DESC_ID': (('row', ),
                                 da.full_like(ds.TIME.data,
                                              ds.DATA_DESC_ID,
                                              chunks=args.row_out_chunk)),
                'WEIGHT':
                (('row', 'chan'), weights.rechunk({0: args.row_out_chunk
                                                   })),  # why no 'f4'?
                'UVW': (('row', 'uvw'), uvw.rechunk({0: args.row_out_chunk}))
            }

            coords = {'chan': (('chan', ), freqs[ims][spw])}

            out_ds = Dataset(data_vars, coords)

            out_datasets.append(out_ds)

    writes = xds_to_zarr(out_datasets,
                         args.output_filename + '.zarr',
                         columns='ALL')

    # dask.visualize(writes, filename=args.output_filename + '_psf_writes_graph.pdf', optimize_graph=False)
    # dask.visualize(psfs, filename=args.output_filename + '_psf_graph.pdf', optimize_graph=False)

    if not args.mock:
        # psfs = dask.compute(psfs, writes, optimize_graph=False)[0]
        # with performance_report(filename=args.output_filename + '_psf_per.html'):
        psfs = dask.compute(psfs, writes, optimize_graph=False)[0]

        psf = stitch_images(psfs, nband, band_mapping)

        hdr = set_wcs(cell_size / 3600, cell_size / 3600, nx, ny, radec,
                      freq_out)
        save_fits(args.output_filename + '_psf.fits',
                  psf,
                  hdr,
                  dtype=args.output_type)

        psf_mfs = np.sum(psf, axis=0)
        wsum = psf_mfs.max()
        psf_mfs /= wsum

        hdr_mfs = set_wcs(cell_size / 3600, cell_size / 3600, nx, ny, radec,
                          np.mean(freq_out))
        save_fits(args.output_filename + '_psf_mfs.fits',
                  psf_mfs,
                  hdr_mfs,
                  dtype=args.output_type)

    print("All done here.", file=log)
Example #8
def _main(args):
    tic = time.time()

    log.info(banner())

    if args.disable_post_mortem:
        log.warn("Disabling crash debugging with the "
                 "Interactive Python Debugger, as per user request")
        post_mortem_handler.disable_pdb_on_error()

    log.info("Flagging on the {0:s} column".format(args.data_column))
    data_column = args.data_column
    masked_channels = [
        load_mask(fn, dilate=args.dilate_masks) for fn in collect_masks()
    ]
    GD = args.config

    log_configuration(args)

    # Group datasets by these columns
    group_cols = ["FIELD_ID", "DATA_DESC_ID", "SCAN_NUMBER"]
    # Index datasets by these columns
    index_cols = ['TIME']

    # Reopen the datasets using the aggregated row ordering
    columns = [data_column, "FLAG", "TIME", "ANTENNA1", "ANTENNA2"]

    if args.subtract_model_column is not None:
        columns.append(args.subtract_model_column)

    xds = list(
        xds_from_ms(args.ms,
                    columns=tuple(columns),
                    group_cols=group_cols,
                    index_cols=index_cols,
                    chunks={"row": args.row_chunks}))

    # Get support tables
    st = support_tables(args.ms)
    ddid_ds = st["DATA_DESCRIPTION"]
    field_ds = st["FIELD"]
    pol_ds = st["POLARIZATION"]
    spw_ds = st["SPECTRAL_WINDOW"]
    ant_ds = st["ANTENNA"]

    assert len(ant_ds) == 1
    assert len(ddid_ds) == 1

    antspos = ant_ds[0].POSITION.data
    antsnames = ant_ds[0].NAME.data
    fieldnames = [fds.NAME.data[0] for fds in field_ds]

    avail_scans = [ds.SCAN_NUMBER for ds in xds]
    args.scan_numbers = list(
        set(avail_scans).intersection(args.scan_numbers if args.scan_numbers
                                      is not None else avail_scans))

    if args.scan_numbers != []:
        log.info("Only considering scans '{0:s}' as "
                 "per user selection criterion".format(", ".join(
                     map(str, map(int, args.scan_numbers)))))

    if args.field_names != []:
        flatten_field_names = []
        for f in args.field_names:
            # accept comma lists per specification
            flatten_field_names += [x.strip() for x in f.split(",")]
        for f in flatten_field_names:
            if re.match(r"^\d+$", f) and int(f) < len(fieldnames):
                flatten_field_names.append(fieldnames[int(f)])
        flatten_field_names = list(
            set(
                filter(lambda x: not re.match(r"^\d+$", x),
                       flatten_field_names)))
        log.info("Only considering fields '{0:s}' for flagging per "
                 "user "
                 "selection criterion.".format(", ".join(flatten_field_names)))
        if not set(flatten_field_names) <= set(fieldnames):
            raise ValueError("One or more fields cannot be "
                             "found in dataset '{0:s}' "
                             "You specified {1:s}, but "
                             "only {2:s} are available".format(
                                 args.ms, ",".join(flatten_field_names),
                                 ",".join(fieldnames)))

        field_dict = {fieldnames.index(fn): fn for fn in flatten_field_names}
    else:
        field_dict = {i: fn for i, fn in enumerate(fieldnames)}

    # List which hold our dask compute graphs for each dataset
    write_computes = []
    original_stats = []
    final_stats = []

    # Iterate through each dataset
    for ds in xds:
        if ds.FIELD_ID not in field_dict:
            continue

        if (args.scan_numbers is not None
                and ds.SCAN_NUMBER not in args.scan_numbers):
            continue

        log.info("Adding field '{0:s}' scan {1:d} to "
                 "compute graph for processing".format(field_dict[ds.FIELD_ID],
                                                       ds.SCAN_NUMBER))

        ddid = ddid_ds[ds.attrs['DATA_DESC_ID']]
        spw_info = spw_ds[ddid.SPECTRAL_WINDOW_ID.data[0]]
        pol_info = pol_ds[ddid.POLARIZATION_ID.data[0]]

        nrow, nchan, ncorr = getattr(ds, data_column).data.shape

        # Visibilities from the dataset
        vis = getattr(ds, data_column).data
        if args.subtract_model_column is not None:
            log.info("Forming residual data between '{0:s}' and "
                     "'{1:s}' for flagging.".format(
                         data_column, args.subtract_model_column))
            vismod = getattr(ds, args.subtract_model_column).data
            vis = vis - vismod

        antenna1 = ds.ANTENNA1.data
        antenna2 = ds.ANTENNA2.data
        chan_freq = spw_info.CHAN_FREQ.data[0]
        chan_width = spw_info.CHAN_WIDTH.data[0]

        # Generate unflagged defaults if we should ignore existing flags
        # otherwise take flags from the dataset
        if args.ignore_flags is True:
            flags = da.full_like(vis, False, dtype=bool)
            log.critical("Completely ignoring measurement set "
                         "flags as per '-if' request. "
                         "Strategy WILL NOT or with original flags, even if "
                         "specified!")
        else:
            flags = ds.FLAG.data

        # If we're flagging on polarised intensity,
        # we convert visibilities to polarised intensity
        # and any flagged correlation will flag the entire visibility
        if args.flagging_strategy == "polarisation":
            corr_type = pol_info.CORR_TYPE.data[0].tolist()
            stokes_map = stokes_corr_map(corr_type)
            stokes_pol = tuple(v for k, v in stokes_map.items() if k != "I")
            vis = polarised_intensity(vis, stokes_pol)
            flags = da.any(flags, axis=2, keepdims=True)
        elif args.flagging_strategy == "total_power":
            if args.subtract_model_column is None:
                log.critical("You requested to flag total quadrature "
                             "power, but not on residuals. "
                             "This is not advisable and the flagger "
                             "may mistake fringes of "
                             "off-axis sources for broadband RFI.")
            corr_type = pol_info.CORR_TYPE.data[0].tolist()
            stokes_map = stokes_corr_map(corr_type)
            stokes_pol = tuple(v for k, v in stokes_map.items())
            vis = polarised_intensity(vis, stokes_pol)
            flags = da.any(flags, axis=2, keepdims=True)
        elif args.flagging_strategy == "standard":
            if args.subtract_model_column is None:
                log.critical("You requested to flag per correlation, "
                             "but not on residuals. "
                             "This is not advisable and the flagger "
                             "may mistake fringes of off-axis sources "
                             "for broadband RFI.")
        else:
            raise ValueError("Invalid flagging strategy '%s'" %
                             args.flagging_strategy)

        ubl = unique_baselines(antenna1, antenna2)
        utime, time_inv = da.unique(ds.TIME.data, return_inverse=True)
        utime, ubl = dask.compute(utime, ubl)
        ubl = ubl.view(np.int32).reshape(-1, 2)
        # Stack the baseline index with the unique baselines
        bl_range = np.arange(ubl.shape[0], dtype=ubl.dtype)[:, None]
        ubl = np.concatenate([bl_range, ubl], axis=1)
        ubl = da.from_array(ubl, chunks=(args.baseline_chunks, 3))

        vis_windows, flag_windows = pack_data(time_inv,
                                              ubl,
                                              antenna1,
                                              antenna2,
                                              vis,
                                              flags,
                                              utime.shape[0],
                                              backend=args.window_backend,
                                              path=args.temporary_directory)

        original_stats.append(
            window_stats(flag_windows, ubl, chan_freq, antsnames,
                         ds.SCAN_NUMBER, field_dict[ds.FIELD_ID],
                         ds.attrs['DATA_DESC_ID']))

        with StrategyExecutor(antspos, ubl, chan_freq, chan_width,
                              masked_channels, GD['strategies']) as se:

            flag_windows = se.apply_strategies(flag_windows, vis_windows)

        final_stats.append(
            window_stats(flag_windows, ubl, chan_freq, antsnames,
                         ds.SCAN_NUMBER, field_dict[ds.FIELD_ID],
                         ds.attrs['DATA_DESC_ID']))

        # Unpack window data for writing back to the MS
        unpacked_flags = unpack_data(antenna1, antenna2, time_inv, ubl,
                                     flag_windows)

        # Flag entire visibility if any correlations are flagged
        equalized_flags = da.sum(unpacked_flags, axis=2, keepdims=True) > 0
        corr_flags = da.broadcast_to(equalized_flags, (nrow, nchan, ncorr))

        if corr_flags.chunks != ds.FLAG.data.chunks:
            raise ValueError("Output flag chunking does not "
                             "match input flag chunking")

        # Create new dataset containing new flags
        new_ds = ds.assign(FLAG=(("row", "chan", "corr"), corr_flags))

        # Write back to original dataset
        writes = xds_to_table(new_ds, args.ms, "FLAG")
        # original should also have .compute called because we need stats
        write_computes.append(writes)

    if len(write_computes) > 0:
        # Combine stats from all datasets
        original_stats = combine_window_stats(original_stats)
        final_stats = combine_window_stats(final_stats)

        with contextlib.ExitStack() as stack:
            # Create dask profiling contexts
            profilers = []

            if can_profile:
                profilers.append(stack.enter_context(Profiler()))
                profilers.append(stack.enter_context(CacheProfiler()))
                profilers.append(stack.enter_context(ResourceProfiler()))

            if sys.stdout.isatty():
                # Interactive terminal, default ProgressBar
                stack.enter_context(ProgressBar())
            else:
                # Non-interactive, emit a bar every 5 minutes so
                # as not to spam the log
                stack.enter_context(ProgressBar(minimum=1, dt=5 * 60))

            _, original_stats, final_stats = dask.compute(
                write_computes, original_stats, final_stats)
        if can_profile:
            visualize(profilers)

        toc = time.time()

        # Log each summary line
        for line in summarise_stats(final_stats, original_stats):
            log.info(line)

        elapsed = toc - tic
        log.info("Data flagged successfully in "
                 "{0:02.0f}h{1:02.0f}m{2:02.0f}s".format((elapsed // 60) // 60,
                                                         (elapsed // 60) % 60,
                                                         elapsed % 60))
    else:
        log.info("User data selection criteria resulted in empty dataset. "
                 "Nothing to be done. Bye!")
Example #9
def _vcfzarr_to_dataset(
    vcfzarr: zarr.Array,
    contig: Optional[str] = None,
    variant_contig_names: Optional[List[str]] = None,
    fix_strings: bool = True,
) -> xr.Dataset:

    variant_position = da.from_zarr(vcfzarr["variants/POS"])

    if contig is None:
        # Get the contigs from variants/CHROM
        variants_chrom = da.from_zarr(vcfzarr["variants/CHROM"]).astype(str)
        variant_contig, variant_contig_names = encode_array(
            variants_chrom.compute())
        variant_contig = variant_contig.astype("int16")
        variant_contig_names = list(variant_contig_names)
    else:
        # Single contig: contig names were passed in
        assert variant_contig_names is not None
        contig_index = variant_contig_names.index(contig)
        variant_contig = da.full_like(variant_position, contig_index)

    # For variant alleles, combine REF and ALT into a single array
    variants_ref = da.from_zarr(vcfzarr["variants/REF"])
    variants_alt = da.from_zarr(vcfzarr["variants/ALT"])
    variant_allele = da.concatenate(
        [_ensure_2d(variants_ref),
         _ensure_2d(variants_alt)], axis=1)
    # rechunk so there's a single chunk in alleles axis
    variant_allele = variant_allele.rechunk((None, variant_allele.shape[1]))

    if "variants/ID" in vcfzarr:
        variants_id = da.from_zarr(vcfzarr["variants/ID"]).astype(str)
    else:
        variants_id = None

    ds = create_genotype_call_dataset(
        variant_contig_names=variant_contig_names,
        variant_contig=variant_contig,
        variant_position=variant_position,
        variant_allele=variant_allele,
        sample_id=da.from_zarr(vcfzarr["samples"]).astype(str),
        call_genotype=da.from_zarr(vcfzarr["calldata/GT"]),
        variant_id=variants_id,
    )

    # Add a mask for variant ID
    if variants_id is not None:
        ds["variant_id_mask"] = (
            [DIM_VARIANT],
            variants_id == ".",
        )

    # Fix string types to include length
    if fix_strings:
        for (var, arr) in ds.data_vars.items():
            kind = arr.dtype.kind
            if kind in ["O", "U", "S"]:
                # Compute fixed-length string dtype for array
                if kind == "O" or var in ("variant_id", "variant_allele"):
                    kind = "S"
                max_len = max_str_len(arr).values
                dt = f"{kind}{max_len}"
                ds[var] = arr.astype(dt)  # type: ignore[no-untyped-call]

                if var in {"variant_id", "variant_allele"}:
                    ds.attrs[f"max_{var}_length"] = max_len

    return ds