Example #1
def _get_ts_dask(pt_1, pt_2, pt_3, pt_4, out_x, out_y):
    """Calculate vertical and horizontal fractional distances t and s"""

    # General case, i.e. where the corners form an irregular rectangle
    t__, s__ = _get_ts_irregular_dask(pt_1, pt_2, pt_3, pt_4, out_y, out_x)

    # Cases where verticals are parallel
    idxs = da.isnan(t__) | da.isnan(s__)
    # Remove extra dimensions
    idxs = da.ravel(idxs)

    if da.any(idxs):
        t_new, s_new = _get_ts_uprights_parallel_dask(pt_1, pt_2,
                                                      pt_3, pt_4,
                                                      out_y, out_x)

        t__ = da.where(idxs, t_new, t__)
        s__ = da.where(idxs, s_new, s__)

    # Cases where both verticals and horizontals are parallel
    idxs = da.isnan(t__) | da.isnan(s__)
    # Remove extra dimensions
    idxs = da.ravel(idxs)
    if da.any(idxs):
        t_new, s_new = _get_ts_parallellogram_dask(pt_1, pt_2, pt_3,
                                                   out_y, out_x)
        t__ = da.where(idxs, t_new, t__)
        s__ = da.where(idxs, s_new, s__)

    idxs = (t__ < 0) | (t__ > 1) | (s__ < 0) | (s__ > 1)
    t__ = da.where(idxs, np.nan, t__)
    s__ = da.where(idxs, np.nan, s__)

    return t__, s__
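
The fallback pattern above (detect invalid results with da.isnan, test them with da.any, patch them with da.where) can be tried on its own. A minimal, self-contained sketch with invented values:

import numpy as np
import dask.array as da

primary = da.from_array(np.array([0.2, np.nan, 0.8]), chunks=2)
fallback = da.from_array(np.array([0.5, 0.5, 0.5]), chunks=2)

needs_fallback = da.isnan(primary)
if da.any(needs_fallback):  # the truth test triggers a (small) compute
    primary = da.where(needs_fallback, fallback, primary)

print(primary.compute())  # [0.2 0.5 0.8]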
Example #2
def var_equal(*args):
    for iarg in args[1:]:
        if iarg is args[0]:
            continue
        if iarg.shape != args[0].shape:
            return False
        if iarg.dtype.kind == 'f':
            if da.any(~da.isclose(iarg._values, args[0]._values)):
                return False
        else:
            if da.any(iarg._values != args[0]._values):
                return False
    return True
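
For reference, the da.any / da.isclose comparison that var_equal relies on behaves like this on plain dask arrays (values invented for illustration):

import dask.array as da

a = da.ones((4, 4), chunks=2)
b = a + 1e-12   # equal within the default isclose tolerances
c = a + 1.0     # clearly different

print(bool(da.any(~da.isclose(a, b))))  # False -> values match
print(bool(da.any(~da.isclose(a, c))))  # True  -> values differ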
Example #3
def cov(*args, axis=None, **kwargs):
    """
    covariance
    """
    if axis is None:
        args = [x.flatten() for x in args]
        axis = 0

    X = da.stack(args, axis=-1).rechunk(com.CHUNKSIZE)
    cond = da.any(da.isnan(X), axis=-1)
    X = da.where(cond[..., None], np.nan, X)

    X -= da.nanmean(X, axis=axis, keepdims=True)
    X = da.where(da.isnan(X), 0, X)
    return (X.swapaxes(axis, -1) @ X.swapaxes(axis, -2).conj()
            / (X.shape[axis] - 1))
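
The NaN handling in cov amounts to masking every observation that has a NaN in any of the stacked variables. A minimal sketch of just that step, with invented data:

import numpy as np
import dask.array as da

X = da.from_array(np.array([[1.0, 2.0],
                            [np.nan, 3.0],
                            [4.0, 5.0]]), chunks=2)

row_has_nan = da.any(da.isnan(X), axis=-1)              # one flag per observation
X_masked = da.where(row_has_nan[..., None], np.nan, X)  # second row becomes all NaN
print(X_masked.compute())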
Example #4
def test_reductions():
    x = np.arange(5).astype('f4')
    a = da.from_array(x, chunks=(2,))

    assert eq(da.all(a), np.all(x))
    assert eq(da.any(a), np.any(x))
    assert eq(da.argmax(a, axis=0), np.argmax(x, axis=0))
    assert eq(da.argmin(a, axis=0), np.argmin(x, axis=0))
    assert eq(da.max(a), np.max(x))
    assert eq(da.mean(a), np.mean(x))
    assert eq(da.min(a), np.min(x))
    assert eq(da.nanargmax(a, axis=0), np.nanargmax(x, axis=0))
    assert eq(da.nanargmin(a, axis=0), np.nanargmin(x, axis=0))
    assert eq(da.nanmax(a), np.nanmax(x))
    assert eq(da.nanmin(a), np.nanmin(x))
    assert eq(da.nansum(a), np.nansum(x))
    assert eq(da.nanvar(a), np.nanvar(x))
    assert eq(da.nanstd(a), np.nanstd(x))
Example #5
def test_reductions():
    x = np.arange(5).astype('f4')
    a = da.from_array(x, blockshape=(2, ))  # "blockshape" was the early-dask name for "chunks"

    assert eq(da.all(a), np.all(x))
    assert eq(da.any(a), np.any(x))
    assert eq(da.argmax(a, axis=0), np.argmax(x, axis=0))
    assert eq(da.argmin(a, axis=0), np.argmin(x, axis=0))
    assert eq(da.max(a), np.max(x))
    assert eq(da.mean(a), np.mean(x))
    assert eq(da.min(a), np.min(x))
    assert eq(da.nanargmax(a, axis=0), np.nanargmax(x, axis=0))
    assert eq(da.nanargmin(a, axis=0), np.nanargmin(x, axis=0))
    assert eq(da.nanmax(a), np.nanmax(x))
    assert eq(da.nanmin(a), np.nanmin(x))
    assert eq(da.nansum(a), np.nansum(x))
    assert eq(da.nanvar(a), np.nanvar(x))
    assert eq(da.nanstd(a), np.nanstd(x))
Example #6
def _multimodel_mask_products(products, shape):
    """Apply common mask to all cubes of products in-place."""
    # Create mask and get products used for mask
    mask = da.full(shape, False, dtype=bool)
    used_products = set()
    for product in products:
        for cube in product.cubes:
            new_mask = da.ma.getmaskarray(cube.core_data())
            mask |= new_mask
            if da.any(new_mask):
                used_products.add(product)

    # Apply common mask and update provenance information
    for product in products:
        for cube in product.cubes:
            cube.data = da.ma.masked_array(cube.core_data(), mask=mask)
        for other_product in used_products:
            if other_product.filename != product.filename:
                product.wasderivedfrom(other_product)

    return products
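
A self-contained sketch of the same common-mask accumulation, using invented arrays and the dask.array.ma helpers seen above:

import numpy as np
import dask.array as da

x = da.from_array(np.array([1.0, 2.0, 3.0, 4.0]), chunks=2)
mask_a = da.from_array(np.array([False, True, False, False]), chunks=2)
mask_b = da.from_array(np.array([False, False, True, False]), chunks=2)

common = da.full(4, False, dtype=bool, chunks=2)
contributors = []
for name, mask in [("a", mask_a), ("b", mask_b)]:
    common |= mask
    if da.any(mask):          # record which inputs actually masked something
        contributors.append(name)

print(da.ma.masked_array(x, mask=common).compute())  # [1.0 -- -- 4.0]
print(contributors)                                  # ['a', 'b']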
Example #7
def new_grid_mapping_from_coords(
    x_coords: xr.DataArray,
    y_coords: xr.DataArray,
    crs: Union[str, pyproj.crs.CRS],
    *,
    tile_size: Union[int, Tuple[int, int]] = None,
    tolerance: float = DEFAULT_TOLERANCE,
) -> GridMapping:
    crs = _normalize_crs(crs)
    assert_instance(x_coords, xr.DataArray, name='x_coords')
    assert_instance(y_coords, xr.DataArray, name='y_coords')
    assert_true(x_coords.ndim in (1, 2),
                'x_coords and y_coords must be either 1D or 2D arrays')
    assert_instance(tolerance, float, name='tolerance')
    assert_true(tolerance > 0.0, 'tolerance must be greater than zero')

    if x_coords.name and y_coords.name:
        xy_var_names = str(x_coords.name), str(y_coords.name)
    else:
        xy_var_names = _default_xy_var_names(crs)

    tile_size = _normalize_int_pair(tile_size, default=None)
    is_lon_360 = None  # None means "not yet known"
    if crs.is_geographic:
        is_lon_360 = bool(np.any(x_coords > 180))

    x_res = 0
    y_res = 0

    if x_coords.ndim == 1:
        # We have 1D x,y coordinates
        cls = Coords1DGridMapping

        assert_true(x_coords.size >= 2 and y_coords.size >= 2,
                    'sizes of x_coords and y_coords 1D arrays must be >= 2')

        size = x_coords.size, y_coords.size

        x_dim, y_dim = x_coords.dims[0], y_coords.dims[0]

        x_diff = _abs_no_zero(x_coords.diff(dim=x_dim).values)
        y_diff = _abs_no_zero(y_coords.diff(dim=y_dim).values)

        if not is_lon_360 and crs.is_geographic:
            is_anti_meridian_crossed = np.any(np.nanmax(x_diff) > 180)
            if is_anti_meridian_crossed:
                x_coords = to_lon_360(x_coords)
                x_diff = _abs_no_zero(x_coords.diff(dim=x_dim))
                is_lon_360 = True

        x_res, y_res = x_diff[0], y_diff[0]
        x_diff_equal = np.allclose(x_diff, x_res, atol=tolerance)
        y_diff_equal = np.allclose(y_diff, y_res, atol=tolerance)
        is_regular = x_diff_equal and y_diff_equal
        if is_regular:
            x_res = round_to_fraction(x_res, 5, 0.25)
            y_res = round_to_fraction(y_res, 5, 0.25)
        else:
            x_res = round_to_fraction(float(np.nanmedian(x_diff)), 2, 0.5)
            y_res = round_to_fraction(float(np.nanmedian(y_diff)), 2, 0.5)

        if tile_size is None \
                and x_coords.chunks is not None \
                and y_coords.chunks is not None:
            tile_size = (max(0, *x_coords.chunks[0]),
                         max(0, *y_coords.chunks[0]))

        # Guess j axis direction
        is_j_axis_up = bool(y_coords[0] < y_coords[-1])

    else:
        # We have 2D x,y coordinates
        cls = Coords2DGridMapping

        assert_true(
            x_coords.shape == y_coords.shape, 'shapes of x_coords and y_coords'
            ' 2D arrays must be equal')
        assert_true(
            x_coords.dims == y_coords.dims,
            'dimensions of x_coords and y_coords'
            ' 2D arrays must be equal')

        y_dim, x_dim = x_coords.dims

        height, width = x_coords.shape
        size = width, height

        x = da.asarray(x_coords)
        y = da.asarray(y_coords)

        x_x_diff = _abs_no_nan(da.diff(x, axis=1))
        x_y_diff = _abs_no_nan(da.diff(x, axis=0))
        y_x_diff = _abs_no_nan(da.diff(y, axis=1))
        y_y_diff = _abs_no_nan(da.diff(y, axis=0))

        if not is_lon_360 and crs.is_geographic:
            is_anti_meridian_crossed = da.any(da.max(x_x_diff) > 180) \
                                       or da.any(da.max(x_y_diff) > 180)
            if is_anti_meridian_crossed:
                x_coords = to_lon_360(x_coords)
                x = da.asarray(x_coords)
                x_x_diff = _abs_no_nan(da.diff(x, axis=1))
                x_y_diff = _abs_no_nan(da.diff(x, axis=0))
                is_lon_360 = True

        is_regular = False

        if da.all(x_y_diff == 0) and da.all(y_x_diff == 0):
            x_res = x_x_diff[0, 0]
            y_res = y_y_diff[0, 0]
            is_regular = \
                da.allclose(x_x_diff[0, :], x_res, atol=tolerance) \
                and da.allclose(x_x_diff[-1, :], x_res, atol=tolerance) \
                and da.allclose(y_y_diff[:, 0], y_res, atol=tolerance) \
                and da.allclose(y_y_diff[:, -1], y_res, atol=tolerance)

        if not is_regular:
            # Let diff arrays have same shape as original by
            # doubling last rows and columns.
            x_x_diff_c = da.concatenate([x_x_diff, x_x_diff[:, -1:]], axis=1)
            y_x_diff_c = da.concatenate([y_x_diff, y_x_diff[:, -1:]], axis=1)
            x_y_diff_c = da.concatenate([x_y_diff, x_y_diff[-1:, :]], axis=0)
            y_y_diff_c = da.concatenate([y_y_diff, y_y_diff[-1:, :]], axis=0)
            # Find resolution via area
            x_abs_diff = da.sqrt(da.square(x_x_diff_c) + da.square(x_y_diff_c))
            y_abs_diff = da.sqrt(da.square(y_x_diff_c) + da.square(y_y_diff_c))
            if crs.is_geographic:
                # Convert degrees into meters
                x_abs_diff_r = da.radians(x_abs_diff)
                y_abs_diff_r = da.radians(y_abs_diff)
                x_abs_diff = _ER * da.cos(x_abs_diff_r) * y_abs_diff_r
                y_abs_diff = _ER * y_abs_diff_r
            xy_areas = (x_abs_diff * y_abs_diff).flatten()
            xy_areas = da.where(xy_areas > 0, xy_areas, np.nan)
            # Get indices of min and max area
            xy_area_index_min = da.nanargmin(xy_areas)
            xy_area_index_max = da.nanargmax(xy_areas)
            # Convert area to edge length
            xy_res_min = math.sqrt(xy_areas[xy_area_index_min])
            xy_res_max = math.sqrt(xy_areas[xy_area_index_max])
            # Empirically weight min more than max
            xy_res = 0.7 * xy_res_min + 0.3 * xy_res_max
            if crs.is_geographic:
                # Convert meters back into degrees
                # print(f'xy_res in meters: {xy_res}')
                xy_res = math.degrees(xy_res / _ER)
                # print(f'xy_res in degrees: {xy_res}')
            # Because this is an estimation, we can round to a nice number
            xy_res = round_to_fraction(xy_res, digits=1, resolution=0.5)
            x_res, y_res = float(xy_res), float(xy_res)

        if tile_size is None and x_coords.chunks is not None:
            j_chunks, i_chunks = x_coords.chunks
            tile_size = max(0, *i_chunks), max(0, *j_chunks)

        if tile_size is not None:
            tile_width, tile_height = tile_size
            x_coords = x_coords.chunk((tile_height, tile_width))
            y_coords = y_coords.chunk((tile_height, tile_width))

        # Guess j axis direction
        is_j_axis_up = np.all(y_coords[0, :] < y_coords[-1, :]) or None

    assert_true(x_res > 0 and y_res > 0,
                'internal error: x_res and y_res could not be determined',
                exception_type=RuntimeError)

    x_res, y_res = _to_int_or_float(x_res), _to_int_or_float(y_res)
    x_res_05, y_res_05 = x_res / 2, y_res / 2
    x_min = _to_int_or_float(x_coords.min() - x_res_05)
    y_min = _to_int_or_float(y_coords.min() - y_res_05)
    x_max = _to_int_or_float(x_coords.max() + x_res_05)
    y_max = _to_int_or_float(y_coords.max() + y_res_05)

    return cls(x_coords=x_coords,
               y_coords=y_coords,
               crs=crs,
               size=size,
               tile_size=tile_size,
               xy_bbox=(x_min, y_min, x_max, y_max),
               xy_res=(x_res, y_res),
               xy_var_names=xy_var_names,
               xy_dim_names=(str(x_dim), str(y_dim)),
               is_regular=is_regular,
               is_lon_360=is_lon_360,
               is_j_axis_up=is_j_axis_up)
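
The anti-meridian test above reduces to asking whether neighbouring longitudes ever jump by more than 180 degrees. A small standalone sketch with invented coordinates:

import numpy as np
import dask.array as da

lon = da.from_array(np.array([170.0, 175.0, 180.0, -175.0, -170.0]), chunks=3)

# A jump larger than 180 degrees between neighbours indicates that
# [-180, 180] longitudes cross the anti-meridian.
crosses = bool(da.any(da.fabs(da.diff(lon)) > 180))
print(crosses)  # True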
Example #8
def split(X=None,
          y=None,
          instance_indexes=None,
          test_ratio=0.3,
          initial_label_rate=0.05,
          split_count=10,
          all_class=True):
    """Split given data.
    Provide one of X, y or instance_indexes to execute the split.
    Parameters
    ----------
    X: array-like, optional
        Data matrix with [n_samples, n_features]
    y: array-like, optional
        labels of given data [n_samples, n_labels] or [n_samples]
    instance_indexes: list, optional (default=None)
        List containing the instances' names (used for image datasets),
        or an index list provided instead of a data matrix.
        Must provide at least one of [instance_indexes, X, y]
    test_ratio: float, optional (default=0.3)
        Ratio of test set
    initial_label_rate: float, optional (default=0.05)
        Ratio of the initial labeled set, i.e. its size is roughly
        initial_label_rate * (1 - test_ratio) * n_samples
    split_count: int, optional (default=10)
        Randomly split the data split_count times
    all_class: bool, optional (default=True)
        Whether each split must contain at least one instance of each class.
        If False, a totally random split will be performed.

    Returns
    -------
    train_idx: list
        index of training set, shape like [n_split_count, n_training_indexes]
    test_idx: list
        index of testing set, shape like [n_split_count, n_testing_indexes]
    label_idx: list
        index of labeled set, shape like [n_split_count, n_labeled_indexes]
    unlabel_idx: list
        index of unlabeled set, shape like [n_split_count, n_unlabeled_indexes]
    """

    # check input parameters
    if X is None and y is None and instance_indexes is None:
        raise Exception("Must provide one of X, y or instance_indexes.")

    len_of_parameters = [
        len(X) if X is not None else None,
        len(y) if y is not None else None,
        len(instance_indexes) if instance_indexes is not None else None
    ]
    number_of_instance = np.unique(
        [i for i in len_of_parameters if i is not None])
    if len(number_of_instance) > 1:
        raise ValueError("Different length of instances and _labels found.")
    else:
        number_of_instance = number_of_instance[0]

    if instance_indexes is not None:
        instance_indexes = da.array(instance_indexes)
    else:
        instance_indexes = da.arange(number_of_instance)

    # split
    train_idx = []
    test_idx = []
    label_idx = []
    unlabel_idx = []

    for i in range(split_count):
        if (not all_class) or y is None:
            rp = randperm(number_of_instance)
            cutpoint = int(round((1 - test_ratio) * len(rp)))
            tp_train = instance_indexes[rp[0:cutpoint]]
            train_idx.append(tp_train)
            test_idx.append(instance_indexes[rp[cutpoint:]])
            cutpoint = int(round(initial_label_rate * len(tp_train)))
            if cutpoint <= 1:
                cutpoint = 1
            label_idx.append(tp_train[0:cutpoint])
            unlabel_idx.append(tp_train[cutpoint:])
        else:
            if y is None:
                raise Exception(
                    "y must be provided when all_class flag is True.")
            if isinstance(y, da.core.Array):
                check_array(y, ensure_2d=False, dtype=None, distributed=False)
            else:
                y = check_array(y,
                                ensure_2d=False,
                                dtype=None,
                                distributed=True)

            if y.ndim == 1:
                label_num = len(da.unique(y).compute())
            else:
                label_num = y.shape[1]
            if round((1 - test_ratio) * initial_label_rate *
                     number_of_instance) < label_num:
                raise ValueError(
                    "The initial rate is too small to guarantee that each "
                    "split will contain at least one instance for each class.")

            # check validity
            while 1:
                rp = randperm(number_of_instance)
                cutpoint = int(round((1 - test_ratio) * len(rp)))
                tp_train = instance_indexes[rp[0:cutpoint]]
                cutpointlabel = int(round(initial_label_rate * len(tp_train)))
                if cutpointlabel <= 1:
                    cutpointlabel = 1
                label_id = tp_train[0:cutpointlabel]
                if y.ndim == 1:
                    if len(da.unique(y[label_id]).compute()) == label_num:
                        break
                else:
                    temp = da.sum(y[label_id], axis=0)
                    if not da.any(temp == 0):
                        break

            train_idx.append(tp_train)
            test_idx.append(instance_indexes[rp[cutpoint:]])
            label_idx.append(tp_train[0:cutpointlabel])
            unlabel_idx.append(tp_train[cutpointlabel:])

    return compute(train_idx, test_idx, label_idx, unlabel_idx)
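
The all_class check above boils down to a da.any test over per-class counts. A minimal sketch with an invented multi-label indicator matrix:

import numpy as np
import dask.array as da

# rows = samples, columns = classes
y = da.from_array(np.array([[1, 0, 1],
                            [0, 1, 0],
                            [1, 0, 0]]), chunks=2)

candidate = y[:2]                       # candidate initial labeled set
per_class = da.sum(candidate, axis=0)   # how often each class appears
missing = bool(da.any(per_class == 0))
print(missing)  # False -> every class is represented at least once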
Example #9
def _main(args):
    tic = time.time()

    log.info(banner())

    if args.disable_post_mortem:
        log.warn("Disabling crash debugging with the "
                 "Interactive Python Debugger, as per user request")
        post_mortem_handler.disable_pdb_on_error()

    log.info("Flagging on the {0:s} column".format(args.data_column))
    data_column = args.data_column
    masked_channels = [
        load_mask(fn, dilate=args.dilate_masks) for fn in collect_masks()
    ]
    GD = args.config

    log_configuration(args)

    # Group datasets by these columns
    group_cols = ["FIELD_ID", "DATA_DESC_ID", "SCAN_NUMBER"]
    # Index datasets by these columns
    index_cols = ['TIME']

    # Reopen the datasets using the aggregated row ordering
    columns = [data_column, "FLAG", "TIME", "ANTENNA1", "ANTENNA2"]

    if args.subtract_model_column is not None:
        columns.append(args.subtract_model_column)

    xds = list(
        xds_from_ms(args.ms,
                    columns=tuple(columns),
                    group_cols=group_cols,
                    index_cols=index_cols,
                    chunks={"row": args.row_chunks}))

    # Get support tables
    st = support_tables(args.ms)
    ddid_ds = st["DATA_DESCRIPTION"]
    field_ds = st["FIELD"]
    pol_ds = st["POLARIZATION"]
    spw_ds = st["SPECTRAL_WINDOW"]
    ant_ds = st["ANTENNA"]

    assert len(ant_ds) == 1
    assert len(ddid_ds) == 1

    antspos = ant_ds[0].POSITION.data
    antsnames = ant_ds[0].NAME.data
    fieldnames = [fds.NAME.data[0] for fds in field_ds]

    avail_scans = [ds.SCAN_NUMBER for ds in xds]
    args.scan_numbers = list(
        set(avail_scans).intersection(args.scan_numbers if args.scan_numbers
                                      is not None else avail_scans))

    if args.scan_numbers != []:
        log.info("Only considering scans '{0:s}' as "
                 "per user selection criterion".format(", ".join(
                     map(str, map(int, args.scan_numbers)))))

    if args.field_names != []:
        flatten_field_names = []
        for f in args.field_names:
            # accept comma lists per specification
            flatten_field_names += [x.strip() for x in f.split(",")]
        for f in flatten_field_names:
            if re.match(r"^\d+$", f) and int(f) < len(fieldnames):
                flatten_field_names.append(fieldnames[int(f)])
        flatten_field_names = list(
            set(
                filter(lambda x: not re.match(r"^\d+$", x),
                       flatten_field_names)))
        log.info("Only considering fields '{0:s}' for flagging per "
                 "user "
                 "selection criterion.".format(", ".join(flatten_field_names)))
        if not set(flatten_field_names) <= set(fieldnames):
            raise ValueError("One or more fields cannot be "
                             "found in dataset '{0:s}' "
                             "You specified {1:s}, but "
                             "only {2:s} are available".format(
                                 args.ms, ",".join(flatten_field_names),
                                 ",".join(fieldnames)))

        field_dict = {fieldnames.index(fn): fn for fn in flatten_field_names}
    else:
        field_dict = {i: fn for i, fn in enumerate(fieldnames)}

    # List which hold our dask compute graphs for each dataset
    write_computes = []
    original_stats = []
    final_stats = []

    # Iterate through each dataset
    for ds in xds:
        if ds.FIELD_ID not in field_dict:
            continue

        if (args.scan_numbers is not None
                and ds.SCAN_NUMBER not in args.scan_numbers):
            continue

        log.info("Adding field '{0:s}' scan {1:d} to "
                 "compute graph for processing".format(field_dict[ds.FIELD_ID],
                                                       ds.SCAN_NUMBER))

        ddid = ddid_ds[ds.attrs['DATA_DESC_ID']]
        spw_info = spw_ds[ddid.SPECTRAL_WINDOW_ID.data[0]]
        pol_info = pol_ds[ddid.POLARIZATION_ID.data[0]]

        nrow, nchan, ncorr = getattr(ds, data_column).data.shape

        # Visibilities from the dataset
        vis = getattr(ds, data_column).data
        if args.subtract_model_column is not None:
            log.info("Forming residual data between '{0:s}' and "
                     "'{1:s}' for flagging.".format(
                         data_column, args.subtract_model_column))
            vismod = getattr(ds, args.subtract_model_column).data
            vis = vis - vismod

        antenna1 = ds.ANTENNA1.data
        antenna2 = ds.ANTENNA2.data
        chan_freq = spw_info.CHAN_FREQ.data[0]
        chan_width = spw_info.CHAN_WIDTH.data[0]

        # Generate unflagged defaults if we should ignore existing flags
        # otherwise take flags from the dataset
        if args.ignore_flags is True:
            flags = da.full_like(vis, False, dtype=bool)
            log.critical("Completely ignoring measurement set "
                         "flags as per '-if' request. "
                         "Strategy WILL NOT be OR'd with original flags, "
                         "even if specified!")
        else:
            flags = ds.FLAG.data

        # If we're flagging on polarised intensity,
        # we convert visibilities to polarised intensity
        # and any flagged correlation will flag the entire visibility
        if args.flagging_strategy == "polarisation":
            corr_type = pol_info.CORR_TYPE.data[0].tolist()
            stokes_map = stokes_corr_map(corr_type)
            stokes_pol = tuple(v for k, v in stokes_map.items() if k != "I")
            vis = polarised_intensity(vis, stokes_pol)
            flags = da.any(flags, axis=2, keepdims=True)
        elif args.flagging_strategy == "total_power":
            if args.subtract_model_column is None:
                log.critical("You requested to flag total quadrature "
                             "power, but not on residuals. "
                             "This is not advisable and the flagger "
                             "may mistake fringes of "
                             "off-axis sources for broadband RFI.")
            corr_type = pol_info.CORR_TYPE.data[0].tolist()
            stokes_map = stokes_corr_map(corr_type)
            stokes_pol = tuple(v for k, v in stokes_map.items())
            vis = polarised_intensity(vis, stokes_pol)
            flags = da.any(flags, axis=2, keepdims=True)
        elif args.flagging_strategy == "standard":
            if args.subtract_model_column is None:
                log.critical("You requested to flag per correlation, "
                             "but not on residuals. "
                             "This is not advisable and the flagger "
                             "may mistake fringes of off-axis sources "
                             "for broadband RFI.")
        else:
            raise ValueError("Invalid flagging strategy '%s'" %
                             args.flagging_strategy)

        ubl = unique_baselines(antenna1, antenna2)
        utime, time_inv = da.unique(ds.TIME.data, return_inverse=True)
        utime, ubl = dask.compute(utime, ubl)
        ubl = ubl.view(np.int32).reshape(-1, 2)
        # Stack the baseline index with the unique baselines
        bl_range = np.arange(ubl.shape[0], dtype=ubl.dtype)[:, None]
        ubl = np.concatenate([bl_range, ubl], axis=1)
        ubl = da.from_array(ubl, chunks=(args.baseline_chunks, 3))

        vis_windows, flag_windows = pack_data(time_inv,
                                              ubl,
                                              antenna1,
                                              antenna2,
                                              vis,
                                              flags,
                                              utime.shape[0],
                                              backend=args.window_backend,
                                              path=args.temporary_directory)

        original_stats.append(
            window_stats(flag_windows, ubl, chan_freq, antsnames,
                         ds.SCAN_NUMBER, field_dict[ds.FIELD_ID],
                         ds.attrs['DATA_DESC_ID']))

        with StrategyExecutor(antspos, ubl, chan_freq, chan_width,
                              masked_channels, GD['strategies']) as se:

            flag_windows = se.apply_strategies(flag_windows, vis_windows)

        final_stats.append(
            window_stats(flag_windows, ubl, chan_freq, antsnames,
                         ds.SCAN_NUMBER, field_dict[ds.FIELD_ID],
                         ds.attrs['DATA_DESC_ID']))

        # Unpack window data for writing back to the MS
        unpacked_flags = unpack_data(antenna1, antenna2, time_inv, ubl,
                                     flag_windows)

        # Flag entire visibility if any correlations are flagged
        equalized_flags = da.sum(unpacked_flags, axis=2, keepdims=True) > 0
        corr_flags = da.broadcast_to(equalized_flags, (nrow, nchan, ncorr))

        if corr_flags.chunks != ds.FLAG.data.chunks:
            raise ValueError("Output flag chunking does not "
                             "match input flag chunking")

        # Create new dataset containing new flags
        new_ds = ds.assign(FLAG=(("row", "chan", "corr"), corr_flags))

        # Write back to original dataset
        writes = xds_to_table(new_ds, args.ms, "FLAG")
        # original should also have .compute called because we need stats
        write_computes.append(writes)

    if len(write_computes) > 0:
        # Combine stats from all datasets
        original_stats = combine_window_stats(original_stats)
        final_stats = combine_window_stats(final_stats)

        with contextlib.ExitStack() as stack:
            # Create dask profiling contexts
            profilers = []

            if can_profile:
                profilers.append(stack.enter_context(Profiler()))
                profilers.append(stack.enter_context(CacheProfiler()))
                profilers.append(stack.enter_context(ResourceProfiler()))

            if sys.stdout.isatty():
                # Interactive terminal, default ProgressBar
                stack.enter_context(ProgressBar())
            else:
                # Non-interactive, emit a bar every 5 minutes so
                # as not to spam the log
                stack.enter_context(ProgressBar(minimum=1, dt=5 * 60))

            _, original_stats, final_stats = dask.compute(
                write_computes, original_stats, final_stats)
        if can_profile:
            visualize(profilers)

        toc = time.time()

        # Log each summary line
        for line in summarise_stats(final_stats, original_stats):
            log.info(line)

        elapsed = toc - tic
        log.info("Data flagged successfully in "
                 "{0:02.0f}h{1:02.0f}m{2:02.0f}s".format((elapsed // 60) // 60,
                                                         (elapsed // 60) % 60,
                                                         elapsed % 60))
    else:
        log.info("User data selection criteria resulted in empty dataset. "
                 "Nothing to be done. Bye!")
Example #10
# (snippet truncated: the opening of the call that produces `dirty` is not shown)
             natural_weights,
             wavelength,
             conv_filter,
             cell_size,
             ny=args.npix,
             nx=args.npix)

ncorr = dirty.shape[2]

# FFT each polarisation and then restack
fft_shifts = [da.fft.ifftshift(dirty[:, :, p]) for p in range(ncorr)]
ffts = [da.fft.ifft2(shift) for shift in fft_shifts]
dirty_fft = [da.fft.fftshift(fft) for fft in ffts]

# Flag PSF visibility if any correlations are flagged
psf_flags = da.any(xds.FLAG.data, axis=2, keepdims=True)

# Construct PSF from unity visibilities and natural weights
psf = grid(da.ones_like(psf_flags, dtype=xds.DATA.data.dtype),
           xds.UVW.data,
           psf_flags,
           da.ones_like(psf_flags, dtype=natural_weights.dtype),
           wavelength,
           conv_filter,
           cell_size,
           ny=2 * args.npix,
           nx=2 * args.npix)

# Should only be one correlation
assert psf.shape[2] == 1, psf.shape
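
Both this example and the previous one collapse the correlation axis with da.any, so a visibility is flagged as soon as any of its correlations is. A minimal sketch with invented flags:

import numpy as np
import dask.array as da

# flags shaped (row, chan, corr)
flags = da.from_array(np.array([[[False, True],
                                 [False, False]]]), chunks=1)

vis_flags = da.any(flags, axis=2, keepdims=True)
print(vis_flags.compute())
# [[[ True]
#   [False]]]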
Example #11
    def average_beams(self, threshold, mask='compute', warn=False):
        """
        Average the beams.  Note that this operation only makes sense in
        limited contexts!  Generally one would want to convolve all the beams
        to a common shape, but this method is meant to handle the "simple" case
        when all your beams are the same to within some small factor and can
        therefore be arithmetically averaged.

        Parameters
        ----------
        threshold : float
            The fractional difference between beam major, minor, and pa to
            permit
        mask : 'compute', None, or boolean array
            The mask to apply to the beams.  Useful for excluding bad channels
            and edge beams.
        warn : bool
            If True, warn that arithmetic beam averaging is being performed.

        Returns
        -------
        new_beam : radio_beam.Beam
            A new radio beam object that is the average of the unmasked beams
        """

        use_dask = isinstance(self._data, da.Array)

        if mask == 'compute':
            if use_dask:
                # If we are dealing with dask arrays, we compute the beam
                # mask once and for all since it is used multiple times in its
                # entirety in the remainder of this method.
                beam_mask = da.any(da.logical_and(
                    self._mask_include, self.goodbeams_mask[:, None, None]),
                                   axis=(1, 2))
                # da.any appears to return an object dtype instead of a bool
                beam_mask = self._compute(beam_mask).astype('bool')
            elif self.mask is not None:
                beam_mask = np.any(np.logical_and(
                    self.mask.include(), self.goodbeams_mask[:, None, None]),
                                   axis=(1, 2))
            else:
                beam_mask = self.goodbeams_mask
        else:
            if mask.ndim > 1:
                beam_mask = np.logical_and(mask, self.goodbeams_mask[:, None,
                                                                     None])
            else:
                beam_mask = np.logical_and(mask, self.goodbeams_mask)

        # use private _beams here because the public one excludes the bad beams
        # by default
        new_beam = self._beams.average_beam(includemask=beam_mask)

        if np.isnan(new_beam):
            raise ValueError(
                "Beam was not finite after averaging.  "
                "This either indicates that there was a problem "
                "with the include mask, one of the beam's values, "
                "or a bug.")

        self._check_beam_areas(threshold, mean_beam=new_beam, mask=beam_mask)
        if warn:
            warnings.warn(
                "Arithmetic beam averaging is being performed.  This is "
                "not a mathematically robust operation, but is being "
                "permitted because the beams differ by "
                "<{0}".format(threshold), BeamAverageWarning)
        return new_beam
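
The beam mask computed above is a da.any reduction over the two spatial axes. A standalone sketch with an invented (channel, y, x) include mask:

import numpy as np
import dask.array as da

include = da.from_array(np.array([[[True, False],
                                   [False, False]],
                                  [[False, False],
                                   [False, False]]]), chunks=1)

# A channel is kept if any pixel in its plane is included
beam_mask = da.any(include, axis=(1, 2))
print(beam_mask.compute())  # [ True False]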
Example #12
def _transform_array(image: da.Array,
                     scale: Tuple[float, ...],
                     offset: Tuple[float, ...],
                     shape: Tuple[int, ...],
                     chunks: Optional[Tuple[int, ...]],
                     spline_order: int,
                     recover_nan: bool) -> da.Array:
    """
    Apply affine transformation to ND-image.

    :param image: ND-image with shape (..., size_y, size_x)
    :param scale: Scaling factors (1, ..., 1, sy, sx)
    :param offset: Offset values (0, ..., 0, oy, ox)
    :param shape: (..., size_y, size_x)
    :param chunks: (..., chunk_size_y, chunk_size_x)
    :param spline_order: 0 ... 5
    :param recover_nan: True/False
    :return: Transformed ND-image.
    """
    assert_true(len(scale) == image.ndim, 'invalid scale')
    assert_true(len(offset) == image.ndim, 'invalid offset')
    assert_true(len(shape) == image.ndim, 'invalid shape')
    assert_true(chunks is None or len(chunks) == image.ndim,
                'invalid chunks')
    if _is_no_op(image, scale, offset, shape):
        return image
    # As of scipy 0.18, matrix = scale is no longer supported.
    # Therefore we use the diagonal matrix form here,
    # where scale is the diagonal.
    matrix = np.diag(scale)
    at_kwargs = dict(
        offset=offset,
        order=spline_order,
        output_shape=shape,
        output_chunks=chunks,
        mode='constant',
    )
    if recover_nan and spline_order > 0:
        # We can "recover" values that are neighbours to NaN values
        # that would otherwise become NaN too.
        mask = da.isnan(image)
        # First check if there are NaN values at all
        if da.any(mask):
            # Yes, then
            # 1. replace NaN by zero
            filled_im = da.where(mask, 0.0, image)
            # 2. transform the zero-filled image
            scaled_im = ndinterp.affine_transform(filled_im,
                                                  matrix,
                                                  **at_kwargs,
                                                  cval=0.0)
            # 3. transform the inverted mask
            scaled_norm = ndinterp.affine_transform(1.0 - mask,
                                                    matrix,
                                                    **at_kwargs,
                                                    cval=0.0)
            # 4. put back NaN where the scaled mask is (close to) zero,
            #    otherwise renormalise by the scaled mask
            return da.where(da.isclose(scaled_norm, 0.0),
                            np.nan, scaled_im / scaled_norm)

    # No dealing with NaN required
    return ndinterp.affine_transform(image, matrix, **at_kwargs, cval=np.nan)
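
The recover_nan trick (zero-fill, transform, then renormalise by the transformed inverted mask) is not specific to affine transforms. The sketch below illustrates the same idea with a toy two-point rolling average standing in for dask_image's affine_transform:

import numpy as np
import dask.array as da

x = da.from_array(np.array([1.0, np.nan, 3.0, 4.0]), chunks=2)

mask = da.isnan(x)
filled = da.where(mask, 0.0, x)    # 1. replace NaN by zero
weight = da.where(mask, 0.0, 1.0)  # inverted mask


def smooth(a):
    # toy linear operation standing in for the affine transform
    return 0.5 * (a + da.roll(a, 1))


out = smooth(filled)    # 2. transform the zero-filled data
norm = smooth(weight)   # 3. transform the inverted mask
# 4. NaN where nothing contributed, otherwise renormalise
result = da.where(da.isclose(norm, 0.0), np.nan, out / norm)
print(result.compute())  # [2.5 1.  3.  3.5]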