Example 1
    def _run_fornav_single(self, data, out_chunks, target_geo_def, fill_value,
                           **kwargs):
        ll2cr_result = self.cache['ll2cr_result']
        ll2cr_blocks = self.cache['ll2cr_blocks'].items()
        ll2cr_numblocks = ll2cr_result.shape if isinstance(
            ll2cr_result, np.ndarray) else ll2cr_result.numblocks
        fornav_task_name = f"fornav-{data.name}-{ll2cr_result.name}"
        maximum_weight_mode = kwargs.setdefault('maximum_weight_mode', False)
        weight_sum_min = kwargs.setdefault('weight_sum_min', -1.0)
        output_stack = self._generate_fornav_dask_tasks(
            out_chunks, ll2cr_blocks, fornav_task_name, data.name,
            target_geo_def, fill_value, kwargs)

        dsk_graph = HighLevelGraph.from_collections(
            fornav_task_name, output_stack, dependencies=[data, ll2cr_result])
        stack_chunks = (
            (1, ) * (ll2cr_numblocks[0] * ll2cr_numblocks[1]), ) + out_chunks
        out_stack = da.Array(dsk_graph, fornav_task_name, stack_chunks,
                             data.dtype)
        combine_fornav_with_kwargs = partial(
            _combine_fornav, maximum_weight_mode=maximum_weight_mode)
        average_fornav_with_kwargs = partial(
            _average_fornav,
            maximum_weight_mode=maximum_weight_mode,
            weight_sum_min=weight_sum_min,
            dtype=data.dtype,
            fill_value=fill_value)
        out = da.reduction(out_stack,
                           _chunk_callable,
                           average_fornav_with_kwargs,
                           combine=combine_fornav_with_kwargs,
                           axis=(0, ),
                           dtype=data.dtype,
                           concatenate=False)
        return out
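
The helpers _chunk_callable, _combine_fornav and _average_fornav used above are not shown here. As a point of reference, this is a minimal self-contained sketch (not the pyresample helpers) of the chunk/combine/aggregate contract that da.reduction expects: each callable receives a block, or the partial results, together with axis and keepdims keyword arguments.

import numpy as np
import dask.array as da

def _chunk(x, axis, keepdims, **kwargs):
    # Applied to every input block; produces one partial result per block.
    return x.sum(axis=axis, keepdims=keepdims)

def _combine(parts, axis, keepdims, **kwargs):
    # Optional intermediate tree step. With the default concatenate=True the
    # partial results arrive as a single ndarray concatenated along `axis`.
    return parts.sum(axis=axis, keepdims=keepdims)

def _aggregate(parts, axis, keepdims, **kwargs):
    # Final step producing the output block.
    return parts.sum(axis=axis, keepdims=keepdims)

x = da.ones((8, 8), chunks=4)
total = da.reduction(x, _chunk, _aggregate, combine=_combine,
                     axis=(0,), dtype=x.dtype)
assert (total.compute() == 8).all()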
Example 2
def test_general_reduction_names():
    dtype = int
    a = da.reduction(
        da.ones(10, dtype, chunks=2), np.sum, np.sum, dtype=dtype, name="foo"
    )
    names, tokens = list(zip_longest(*[key[0].rsplit("-", 1) for key in a.dask]))
    assert set(names) == {"ones_like", "foo", "foo-partial", "foo-aggregate"}
    assert all(tokens)
Example 3
def test_weighted_reduction():
    # Weighted reduction
    def w_sum(x, weights=None, dtype=None, computing_meta=False, **kwargs):
        """`chunk` callable for (weighted) sum"""
        if computing_meta:
            return x
        if weights is not None:
            x = x * weights
        return np.sum(x, dtype=dtype, **kwargs)

    # Arrays
    a = 1 + np.ma.arange(60).reshape(6, 10)
    a[2, 2] = np.ma.masked
    dx = da.from_array(a, chunks=(4, 5))
    # Weights
    w = np.linspace(1, 2, 6).reshape(6, 1)

    # No weights (i.e. normal sum)
    x = da.reduction(dx, w_sum, np.sum, dtype=dx.dtype)
    assert_eq(x, np.sum(a), check_shape=True)

    # Weighted sum
    x = da.reduction(dx, w_sum, np.sum, dtype="f8", weights=w)
    assert_eq(x, np.sum(a * w), check_shape=True)

    # Non-broadcastable weights (short axis)
    with pytest.raises(ValueError):
        da.reduction(dx, w_sum, np.sum, weights=[1, 2, 3])

    # Non-broadcastable weights (too many dims)
    with pytest.raises(ValueError):
        da.reduction(dx, w_sum, np.sum, weights=[[[2]]])
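
As the test above demonstrates, da.reduction also accepts a weights argument: the weights are broadcast against the input, sliced per block, and forwarded to the chunk function as the weights keyword, while non-broadcastable weights raise ValueError at graph-construction time. A minimal standalone usage sketch of the same pattern:

import numpy as np
import dask.array as da

def w_sum(x, weights=None, dtype=None, computing_meta=False, **kwargs):
    # Chunk callable: dask slices `weights` to match the block and passes it
    # as the `weights` keyword argument.
    if computing_meta:
        return x
    if weights is not None:
        x = x * weights
    return np.sum(x, dtype=dtype, **kwargs)

dx = da.ones((6, 10), chunks=(4, 5))
w = np.linspace(1, 2, 6).reshape(6, 1)  # per-row weights
total = da.reduction(dx, w_sum, np.sum, dtype="f8", weights=w)
assert np.allclose(total.compute(), (np.ones((6, 10)) * w).sum())  # 90.0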
Example 4
def reduce_loglikelihood(log_weighted_likelihoods):
    if isinstance(log_weighted_likelihoods, np.ndarray):
        log_likelihood = logaddexp_reduce(log_weighted_likelihoods)
    else:
        # Sum along the gaussians axis (using logaddexp to prevent underflow)
        log_likelihood = da.reduction(
            x=log_weighted_likelihoods,
            chunk=logaddexp_reduce,
            aggregate=logaddexp_reduce,
            axis=0,
            dtype=float,
            keepdims=False,
        )
    return log_likelihood
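
logaddexp_reduce itself is not part of the snippet. A hypothetical sketch of such a reducer, usable both on a plain ndarray and as the chunk/aggregate callable here (it must tolerate the axis, keepdims and computing_meta keywords that da.reduction may pass), could look like the following; the axis=0 default is an assumption matching the "sum along the gaussians axis" comment.

import numpy as np

def logaddexp_reduce(x, axis=0, keepdims=False, computing_meta=False, **kwargs):
    # Numerically stable log-sum-exp along `axis` (hypothetical helper).
    if computing_meta:
        return x
    m = np.max(x, axis=axis, keepdims=True)
    out = m + np.log(np.sum(np.exp(x - m), axis=axis, keepdims=True))
    return out if keepdims else np.squeeze(out, axis=axis)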
Example 5
def compute_xy_bbox(xy_coords: Union[xr.DataArray, np.ndarray, da.Array]) \
        -> Tuple[float, float, float, float]:
    xy_coords = da.asarray(xy_coords)
    result = da.reduction(
        xy_coords,
        compute_xy_bbox_chunk,
        compute_xy_bbox_aggregate,
        keepdims=True,
        # concatenate=False,
        dtype=xy_coords.dtype,
        axis=(1, 2),
        meta=np.array([[0, 0], [0, 0]],
                      dtype=xy_coords.dtype)
    )
    x_min, x_max, y_min, y_max = map(float, result.compute().flatten())
    return x_min, y_min, x_max, y_max
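
The compute_xy_bbox_chunk and compute_xy_bbox_aggregate helpers are not shown. One hypothetical way to write them, assuming xy_coords has shape (2, height, width) with the x plane first and a single chunk along that leading axis, is to let every partial carry the four extrema as rows, so that the flattened result unpacks as x_min, x_max, y_min, y_max exactly as above:

import numpy as np

def compute_xy_bbox_chunk(block, axis, keepdims, **kwargs):
    # block has shape (2, h, w): x plane at index 0, y plane at index 1.
    # Emit the four extrema as a (4, 1, 1) partial so that dask can
    # concatenate partials along the reduced axes (1, 2) before aggregating.
    x, y = block[0], block[1]
    return np.array([np.nanmin(x), np.nanmax(x),
                     np.nanmin(y), np.nanmax(y)]).reshape(4, 1, 1)

def compute_xy_bbox_aggregate(parts, axis, keepdims, **kwargs):
    # parts has shape (4, n, m): rows 0 and 2 hold partial minima,
    # rows 1 and 3 hold partial maxima.
    return np.array([parts[0].min(), parts[1].max(),
                     parts[2].min(), parts[3].max()]).reshape(4, 1, 1)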
Example 6
def test_chunk_structure_independence(axes, split_every, chunks):
    # Reducing an array should not depend on its chunk-structure!!!
    # See Issue #8541: https://github.com/dask/dask/issues/8541
    shape = tuple(np.sum(s) for s in chunks)
    np_array = np.arange(np.prod(shape)).reshape(*shape)
    x = da.from_array(np_array, chunks=chunks)
    reduced_x = da.reduction(
        x,
        lambda x, axis, keepdims: x,
        lambda x, axis, keepdims: x,
        keepdims=True,
        axis=axes,
        split_every=split_every,
        dtype=x.dtype,
        meta=x._meta,
    )
    _assert_eq(reduced_x, np_array, check_chunks=False, check_shape=False)
Example 7
def dataset_chunks(datasets, time_bin_secs, max_row_chunks):
    """
    Given ``max_row_chunks`` determine a chunking strategy
    for each dataset that prevents binning unique times in
    separate chunks.
    """
    # Calculate (utime, idx, counts) tuple for each dataset
    # then transpose to get lists for each tuple entry
    if len(datasets) == 0:
        return (), ()

    utimes = []
    interval_avg = []
    counts = []
    monotonicity_checks = []

    for ds in datasets:
        # Compute unique times, their counts and interval sum
        # for each row chunk
        block_values = da.blockwise(_time_interval_sum,
                                    "r",
                                    ds.TIME.data,
                                    "r",
                                    ds.INTERVAL.data,
                                    "r",
                                    meta=np.empty((0, ), dtype=object),
                                    dtype=object)

        # Reduce each row chunk's values
        reduction = da.reduction(block_values,
                                 chunk=_chunk,
                                 combine=_time_int_combine,
                                 aggregate=_time_int_agg,
                                 concatenate=False,
                                 split_every=16,
                                 meta=np.empty((0, ), dtype=object),
                                 dtype=object)

        # Pull out the final unique times, counts and interval average
        utime = reduction.map_blocks(getitem, 0, dtype=ds.TIME.dtype)
        count = reduction.map_blocks(getitem, 1, dtype=np.int32)
        int_avg = reduction.map_blocks(getitem, 2, dtype=ds.INTERVAL.dtype)

        # Check monotonicity of TIME while we're at it
        is_monotonic = da.all(da.diff(ds.TIME.data) >= 0.0)

        utimes.append(utime)
        counts.append(count)
        interval_avg.append(int_avg)
        monotonicity_checks.append(is_monotonic)

    # Work out the unique times, average intervals for those times
    # and the frequency of those times
    (ds_utime, ds_avg_intervals, ds_counts,
     ds_monotonicity_checks) = dask.compute(utimes, interval_avg, counts,
                                            monotonicity_checks)

    if not all(ds_monotonicity_checks):
        raise ValueError("TIME is not monotonically increasing. "
                         "This is required.")

    # Produce row and time chunking strategies for each dataset
    ds_row_chunks = []
    ds_time_chunks = []
    ds_interval_secs = []

    it = zip(ds_utime, ds_avg_intervals, ds_counts)
    for di, (utime, avg_interval, counts) in enumerate(it):
        # Maintain row and time chunks for this dataset
        row_chunks = []
        time_chunks = []
        interval_secs = []

        # Start out with first entries
        bin_rows = counts[0]
        bin_times = 1
        bin_secs = avg_interval[0]

        dsit = enumerate(zip(utime[1:], avg_interval[1:], counts[1:]))
        for ti, (ut, avg_int, count) in dsit:
            if count > max_row_chunks:
                logger.warning(
                    "Unique time {:3f} occurred {:d} times "
                    "in dataset {:d} but this exceeds the "
                    "requested row chunks {:d}. "
                    "Consider increasing --row-chunks", ut, count, di,
                    max_row_chunks)

            if avg_int > time_bin_secs:
                logger.warning(
                    "The average INTERVAL associated with "
                    "unique time {:3f} in dataset {:d} "
                    "is {:3f} but this exceeds the requested "
                    "number of seconds in a time bin {:3f}s. "
                    "Consider increasing --time-bin-secs", ut, di, avg_int,
                    time_bin_secs)

            next_rows = bin_rows + count

            # If we're still within the number of rows for this bin
            # keep going
            if next_rows < max_row_chunks:
                bin_rows = next_rows
                bin_times += 1
                bin_secs += avg_int
            # Otherwise finalize this bin and
            # start a new one with the counts
            # we were trying to add
            else:
                row_chunks.append(bin_rows)
                time_chunks.append(bin_times)
                interval_secs.append(bin_secs)
                bin_rows = count
                bin_times = 1
                bin_secs = avg_int

        # Finish any remaining bins
        if bin_rows > 0:
            assert bin_times > 0
            row_chunks.append(bin_rows)
            time_chunks.append(bin_times)
            interval_secs.append(bin_secs)

        row_chunks = tuple(row_chunks)
        time_chunks = tuple(time_chunks)
        interval_secs = tuple(interval_secs)
        ds_row_chunks.append(row_chunks)
        ds_time_chunks.append(time_chunks)
        ds_interval_secs.append(interval_secs)

    logger.info("Dataset Chunking: (r)ow - (t)imes - (s)econds")

    it = zip(datasets, ds_row_chunks, ds_time_chunks, ds_interval_secs)
    for di, (ds, ds_rcs, ds_tcs, ds_int_secs) in enumerate(it):
        ds_rows = ds.dims['row']
        ds_crows = sum(ds_rcs)

        if not ds_rows == ds_crows:
            raise ValueError("Number of dataset rows %d "
                             "does not match the sum %d "
                             "of the row chunks %s" %
                             (ds_rows, ds_crows, ds_rcs))

        log_str = ", ".join("(%dr,%dt,%.1fs)" % (rc, tc, its)
                            for rc, tc, its in zip(*(ds_rcs, ds_tcs,
                                                     ds_int_secs)))

        logger.info("Dataset {d}: {s}", d=di, s=log_str)

    return ds_row_chunks, ds_time_chunks
Example 8
def dataset_chunks(datasets, time_bin_secs, max_row_chunks):
    """
    Given ``max_row_chunks`` determine a chunking strategy
    for each dataset that prevents binning unique times in
    separate chunks.
    """
    # Calculate (utime, idx, counts) tuple for each dataset
    # then transpose to get lists for each tuple entry
    if len(datasets) == 0:
        return (), ()

    utimes = []
    interval_avg = []
    counts = []
    monotonicity_checks = []

    for ds in datasets:
        # Compute unique times, their counts and interval sum
        # for each row chunk
        block_values = da.blockwise(_time_interval_sum,
                                    "r",
                                    ds.TIME.data,
                                    "r",
                                    ds.INTERVAL.data,
                                    "r",
                                    meta=np.empty((0, ), dtype=object),
                                    dtype=object)

        # Reduce each row chunk's values
        reduction = da.reduction(block_values,
                                 chunk=_chunk,
                                 combine=_time_int_combine,
                                 aggregate=_time_int_agg,
                                 concatenate=False,
                                 split_every=16,
                                 meta=np.empty((0, ), dtype=object),
                                 dtype=object)

        # Pull out the final unique times, counts and interval average
        utime = reduction.map_blocks(getitem, 0, dtype=ds.TIME.dtype)
        count = reduction.map_blocks(getitem, 1, dtype=np.int32)
        int_avg = reduction.map_blocks(getitem, 2, dtype=ds.INTERVAL.dtype)

        # Check monotonicity of TIME while we're at it
        is_monotonic = da.all(da.diff(ds.TIME.data) >= 0.0)

        utimes.append(utime)
        counts.append(count)
        interval_avg.append(int_avg)
        monotonicity_checks.append(is_monotonic)

    # Work out the unique times, average intervals for those times
    # and the frequency of those times
    (ds_utime, ds_avg_intervals, ds_counts,
     ds_monotonicity_checks) = dask.compute(utimes, interval_avg, counts,
                                            monotonicity_checks)

    if not all(ds_monotonicity_checks):
        raise ValueError("TIME is not monotonically increasing. "
                         "This is required.")

    grouper = DatasetGrouper(time_bin_secs, max_row_chunks)
    res = grouper.group(ds_utime, ds_avg_intervals, ds_counts)
    ds_row_chunks, ds_time_chunks, ds_interval_secs = res

    logger.info("Dataset Chunking: (r)ow - (t)imes - (s)econds")

    it = zip(datasets, ds_row_chunks, ds_time_chunks, ds_interval_secs)
    for di, (ds, ds_rcs, ds_tcs, ds_int_secs) in enumerate(it):
        ds_rows = ds.dims['row']
        ds_crows = sum(ds_rcs)

        if not ds_rows == ds_crows:
            raise ValueError("Number of dataset rows %d "
                             "does not match the sum %d "
                             "of the row chunks %s" %
                             (ds_rows, ds_crows, ds_rcs))

        log_str = ", ".join("(%dr,%dt,%.1fs)" % (rc, tc, its)
                            for rc, tc, its in zip(*(ds_rcs, ds_tcs,
                                                     ds_int_secs)))

        logger.info("Dataset {d}: {s}", d=di, s=log_str)

    return ds_row_chunks, ds_time_chunks
Example 9
def default(glyph, df, schema, canvas, summary, cuda=False):
    shape, bounds, st, axis = shape_bounds_st_and_axis(df, canvas, glyph)

    # Compile functions
    create, info, append, combine, finalize = \
        compile_components(summary, schema, glyph, cuda=cuda)
    x_mapper = canvas.x_axis.mapper
    y_mapper = canvas.y_axis.mapper
    extend = glyph._build_extend(x_mapper, y_mapper, info, append)

    # Here be dragons
    # Get the dataframe graph
    graph = df.__dask_graph__()

    # Guess a reasonable output dtype from combination of dataframe dtypes
    dtypes = []

    for dt in df.dtypes:
        if isinstance(dt, pd.CategoricalDtype):
            continue
        elif isinstance(dt, pd.api.extensions.ExtensionDtype):
            # RaggedArray implementation and
            # https://github.com/pandas-dev/pandas/issues/22224
            try:
                subdtype = dt.subtype
            except AttributeError:
                continue
            else:
                dtypes.append(subdtype)
        else:
            dtypes.append(dt)

    dtype = np.result_type(*dtypes)
    # Create a meta object so that dask.array doesn't try to look
    # too closely at the type of the chunks it's wrapping:
    # they're actually dataframes, but we tell dask they're ndarrays
    meta = np.empty((0, ), dtype=dtype)
    # Create a chunks tuple, a singleton for each dataframe chunk
    # The number of chunks + structure needs to match that of
    # the dataframe, so that we can use the dataframe graph keys,
    # but we don't have to be precise with the chunk size.
    # We could use np.nan instead of 1 to indicate that we actually
    # don't know how large the chunk is
    chunks = (tuple(1 for _ in range(df.npartitions)), )

    # Now create a dask array from the dataframe graph layer
    # It's a dask array of dataframes, which is dodgy but useful
    # for the following reasons:
    #
    # (1) The dataframes get converted to a single array by
    #     the datashader reduction functions anyway
    # (2) dask.array.reduction is handy for coding a tree
    #     reduction of arrays
    df_array = da.Array(graph, df._name, chunks, meta=meta)
    # A sufficient condition for ensuring the chimera holds together
    assert list(df_array.__dask_keys__()) == list(df.__dask_keys__())

    def chunk(df, axis, keepdims):
        """ used in the dask.array.reduction chunk step """
        aggs = create(shape)
        extend(aggs, df, st, bounds)
        return aggs

    def wrapped_combine(x, axis, keepdims):
        """ wrap datashader combine in dask.array.reduction combine """
        if isinstance(x, list):
            # list of tuples of ndarrays
            # assert all(isinstance(item, tuple) and
            #            len(item) == 1 and
            #            isinstance(item[0], np.ndarray)
            #            for item in x)
            return combine(x)
        elif isinstance(x, tuple):
            # tuple with single ndarray
            # assert len(x) == 1 and isinstance(x[0], np.ndarray)
            return x
        else:
            raise TypeError("Unknown type %s in wrapped_combine" % type(x))

    local_axis = axis

    def aggregate(x, axis, keepdims):
        """ Wrap datashader finalize in dask.array.reduction aggregate """
        return finalize(wrapped_combine(x, axis, keepdims),
                        cuda=cuda,
                        coords=local_axis,
                        dims=[glyph.y_label, glyph.x_label])

    R = da.reduction(
        df_array,
        aggregate=aggregate,
        chunk=chunk,
        combine=wrapped_combine,
        # Control granularity of tree branching
        # less is more
        split_every=2,
        # We don't want np.concatenate called
        # during combine and aggregate. It'll
        # fail because we're handling tuples of ndarrays
        # and lists of tuples of ndarrays
        concatenate=False,
        # Prevent dask from internally inspecting
                       # chunk, combine and aggregate
        meta=meta,
        # Provide some sort of dtype for the
        # resultant dask array
        dtype=meta.dtype)

    return R, R.name
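
The concatenate=False behaviour that the comments above rely on can be seen in isolation: combine and aggregate then receive lists of the chunk outputs instead of one concatenated ndarray, which is what lets datashader pass tuples of ndarrays through the tree. A minimal standalone sketch (simple sums rather than datashader aggregates):

import numpy as np
import dask.array as da

def chunk(x, axis, keepdims):
    # Per-block partial result; any object may be returned here.
    return x.sum()

def combine(parts, axis, keepdims):
    # With concatenate=False, `parts` is a list of the partial results
    # produced by chunk (or by earlier combine rounds).
    return sum(np.asarray(p).sum() for p in parts)

def aggregate(parts, axis, keepdims):
    return combine(parts, axis, keepdims)

x = da.arange(100, chunks=10)
total = da.reduction(x, chunk, aggregate, combine=combine,
                     split_every=2, concatenate=False,
                     dtype=x.dtype, meta=np.empty((0,), dtype=x.dtype))
assert int(total.compute()) == 4950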
Example 10
def antenna_flags_field(msname, fields=None, antennas=None):
    ds_ant = xds_from_table(msname+"::ANTENNA")[0]
    ds_field = xds_from_table(msname+"::FIELD")[0]
    ds_obs = xds_from_table(msname+"::OBSERVATION")[0]

    ant_names = ds_ant.NAME.data.compute()
    field_names = ds_field.NAME.data.compute()
    ant_positions = ds_ant.POSITION.data.compute()

    try:
        # Get observatory name and centre of array
        obs_name = ds_obs.TELESCOPE_NAME.data.compute()[0]
        me = casacore.measures.measures()
        obs_cofa = me.observatory(obs_name)
        lon, lat, alt = (obs_cofa['m0']['value'],
                         obs_cofa['m1']['value'],
                         obs_cofa['m2']['value'])
        cofa = wgs84_to_ecef(lon, lat, alt)
    except Exception:
        # Otherwise fall back to the position of the first antenna
        cofa = ant_positions[0]

    if fields:
        if isinstance(fields[0], str):
            # Map field names to their ids in the FIELD table
            field_ids = [list(field_names).index(f) for f in fields]
        else:
            field_ids = fields
    else:
        field_ids = list(range(len(field_names)))

    if antennas:
        if isinstance(antennas[0], str):
            # Map antenna names to their ids in the ANTENNA table
            ant_ids = [list(ant_names).index(a) for a in antennas]
        else:
            ant_ids = antennas
    else:
        ant_ids = list(range(len(ant_names)))

    nant = len(ant_ids)
    nfield = len(field_ids)
    
    fields_str = ", ".join(map(str, field_ids))
    ds_mss = xds_from_ms(msname, group_cols=["FIELD_ID", "DATA_DESC_ID"], 
            chunks={'row': 100000}, taql_where="FIELD_ID IN [%s]" % fields_str)
    flag_sum_computes = []
    for ds in ds_mss:
        flag_sums = da.blockwise(_get_flags, ("row",),
                                    ant_ids, ("ant",),
                                    ds.ANTENNA1.data, ("row",),
                                    ds.ANTENNA2.data, ("row",),
                                    ds.FLAG.data, ("row","chan", "corr"),
                                    adjust_chunks={"row": nant },
                                    dtype=numpy.ndarray)
    
        flags_redux = da.reduction(flag_sums,
                                 chunk=_chunk,
                                 combine=_combine,
                                 aggregate=_aggregate,
                                 concatenate=False,
                                 dtype=numpy.float64)
        flag_sum_computes.append(flags_redux)

    #flag_sum_computes[0].visualize("graph.pdf")
    sum_per_field_spw = dask.compute(flag_sum_computes)[0]
    sum_all = sum(sum_per_field_spw)
    fractions = sum_all[:, 0] / sum_all[:, 1]
    stats = {}
    for i, aid in enumerate(ant_ids):
        ant_stats = {}
        ant_pos = list(ant_positions[aid])
        ant_stats["name"] = ant_names[aid]
        ant_stats["position"] = ant_pos
        ant_stats["array_centre_dist"] = _distance(cofa, ant_pos)
        ant_stats["frac"] = fractions[i]
        ant_stats["sum"] = sum_all[i][0]
        ant_stats["counts"] = sum_all[i][1]
        stats[aid] = ant_stats

    return stats
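
The _chunk, _combine and _aggregate helpers in this reduction are not shown. Since sum_all is later indexed as [:, 0] and [:, 1], each partial presumably is an (n_antenna, 2) array of [flagged, total] visibility counts; a hypothetical sketch compatible with concatenate=False would simply sum those partials elementwise:

import numpy

def _chunk(x, axis, keepdims):
    # Per row-block partial produced by _get_flags; passed through unchanged.
    return x

def _combine(parts, axis, keepdims):
    # concatenate=False: `parts` is a (possibly nested) list of partials;
    # merge them by elementwise summation of the per-antenna count arrays.
    if isinstance(parts, list):
        return sum(_combine(p, axis, keepdims) for p in parts)
    return numpy.asarray(parts)

def _aggregate(parts, axis, keepdims):
    return _combine(parts, axis, keepdims)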
Example 11
def bda_average_spw(out_datasets, ddid_ds, spw_ds):
    """
    Parameters
    ----------
    out_datasets : list of Datasets
        list of Datasets
    ddid_ds : Dataset
        DATA_DESCRIPTION dataset
    spw_ds : list of Datasets
        list of Datasets, each describing a single Spectral Window

    Returns
    -------
    output_ds : list of Datasets
        list of Datasets
    spw_ds : list of Datasets
        list of Datasets, each describing an averaged Spectral Window
    """

    channelisations = []

    # Over the entire set of datasets, determine the complete
    # set of channelisations, per input DDID and
    # reduce down to a single object
    for out_ds in out_datasets:
        transform = da.blockwise(_channelisations, ("row",),
                                 out_ds.DATA_DESC_ID.data, ("row",),
                                 out_ds.NUM_CHAN.data, ("row",),
                                 ddid_ds.SPECTRAL_WINDOW_ID.data, ("ddid",),
                                 ddid_ds.POLARIZATION_ID.data, ("ddid",),
                                 meta=np.empty((0,), dtype=object))

        result = da.reduction(transform,
                              chunk=_noop,
                              combine=combine,
                              aggregate=combine,
                              concatenate=False,
                              keepdims=True,
                              meta=np.empty((0,), dtype=object),
                              dtype=object)

        channelisations.append(result)

    # Final reduction object, note the aggregate method
    # which generates the mapping
    ddid_chan_map = da.reduction(da.concatenate(channelisations),
                                 chunk=_noop,
                                 combine=combine,
                                 aggregate=aggregate,
                                 concatenate=False,
                                 keepdims=False,
                                 meta=np.empty((), dtype=object),
                                 dtype=object)

    def _squeeze_tuplify(*args):
        return tuple(a.squeeze() for a in args)

    chan_freqs = da.blockwise(_squeeze_tuplify, ("row", "chan"),
                              *(a for spw in spw_ds for a
                                in (spw.CHAN_FREQ.data, ("row", "chan"))),
                              concatenate=False,
                              align_arrays=False,
                              adjust_chunks={"chan": lambda c: np.nan},
                              meta=np.empty((0, 0), dtype=object))

    chan_widths = da.blockwise(_squeeze_tuplify, ("row", "chan"),
                               *(a for spw in spw_ds for a
                                 in (spw.CHAN_WIDTH.data, ("row", "chan"))),
                               concatenate=False,
                               align_arrays=False,
                               adjust_chunks={"chan": lambda c: np.nan},
                               meta=np.empty((0, 0), dtype=object))

    ref_freqs = da.blockwise(_squeeze_tuplify, ("row",),
                             *(a for spw in spw_ds for a
                               in (spw.REF_FREQUENCY.data, ("row",))),
                             concatenate=False,
                             align_arrays=False,
                             meta=np.empty((0,), dtype=object))

    meas_freq_refs = da.blockwise(_squeeze_tuplify, ("row",),
                                  *(a for spw in spw_ds for a
                                    in (spw.REF_FREQUENCY.data, ("row",))),
                                  concatenate=False,
                                  align_arrays=False,
                                  meta=np.empty((0,), dtype=object))

    result = da.blockwise(ddid_and_spw_factory, ("row", "chan"),
                          chan_freqs, ("row", "chan"),
                          chan_widths, ("row", "chan"),
                          ref_freqs, ("row",),
                          meas_freq_refs, ("row",),
                          ddid_chan_map, (),
                          meta=np.empty((0, 0), dtype=object))

    # There should only be one chunk
    assert result.npartitions == 1

    chan_freq = da.blockwise(getitem, ("row", "chan"),
                             result, ("row", "chan"),
                             0, None,
                             dtype=np.float64)

    chan_width = da.blockwise(getitem, ("row", "chan"),
                              result, ("row", "chan"),
                              1, None,
                              dtype=np.float64)

    num_chan = da.blockwise(lambda d, i: d[0][i], ("row",),
                            result, ("row", "chan"),
                            2, None,
                            dtype=np.int32)

    ref_freq = da.blockwise(lambda d, i: d[0][i], ("row",),
                            result, ("row", "chan"),
                            3, None,
                            dtype=np.float64)

    meas_freq_refs = da.blockwise(lambda d, i: d[0][i], ("row",),
                                  result, ("row", "chan"),
                                  4, None,
                                  dtype=np.float64)

    total_bw = da.blockwise(lambda d, i: d[0][i], ("row",),
                            result, ("row", "chan"),
                            5, None,
                            dtype=np.float64)

    spectral_window_id = da.blockwise(lambda d, i: d[0][i], ("row",),
                                      result, ("row", "chan"),
                                      6, None,
                                      dtype=np.int32)

    polarization_id = da.blockwise(lambda d, i: d[0][i], ("row",),
                                   result, ("row", "chan"),
                                   7, None,
                                   dtype=np.int32)

    ddid_map = da.blockwise(lambda d, i: d[0][i], ("row",),
                            result, ("row", "chan"),
                            8, None,
                            dtype=np.int32)

    for o, out_ds in enumerate(out_datasets):
        data_desc_id = da.blockwise(_new_ddids, ("row",),
                                    out_ds.DATA_DESC_ID.data, ("row",),
                                    out_ds.NUM_CHAN.data, ("row",),
                                    ddid_map, ("ddid",),
                                    dtype=out_ds.DATA_DESC_ID.dtype)

        dv = dict(out_ds.data_vars)
        dv["DATA_DESC_ID"] = (("row",), data_desc_id)
        del dv["NUM_CHAN"]
        del dv["DECORR_CHAN_WIDTH"]

        out_datasets[o] = Dataset(dv, out_ds.coords, out_ds.attrs)

    out_spw_ds = Dataset({
        "CHAN_FREQ": (("row", "chan"), chan_freq),
        "CHAN_WIDTH": (("row", "chan"), chan_width),
        "EFFECTIVE_BW": (("row", "chan"), chan_width),
        "RESOLUTION": (("row", "chan"), chan_width),
        "NUM_CHAN": (("row",), num_chan),
        "REF_FREQUENCY": (("row",), ref_freq),
        "TOTAL_BANDWIDTH": (("row",), total_bw)
    })

    out_ddid_ds = Dataset({
        "SPECTRAL_WINDOW_ID": (("row",), spectral_window_id),
        "POLARIZATION_ID": (("row",), polarization_id),
    })

    return out_datasets, [out_spw_ds], out_ddid_ds
Example 12
def pairwise_distance(
    x: ArrayLike,
    metric: MetricTypes = "euclidean",
    split_every: typing.Optional[int] = None,
) -> da.array:
    """Calculates the pairwise distance between all pairs of row vectors in the
    given two dimensional array x.

    To illustrate the algorithm consider the following (4, 5) two dimensional array:

    [e.00, e.01, e.02, e.03, e.04]
    [e.10, e.11, e.12, e.13, e.14]
    [e.20, e.21, e.22, e.23, e.24]
    [e.30, e.31, e.32, e.33, e.34]

    The rows of the above matrix are the set of vectors. Now let's label all
    the vectors as v0, v1, v2, v3.

    The result will be a two dimensional symmetric matrix which will contain
    the distance between all pairs. Since there are 4 vectors, calculating the
    distance between each vector and every other vector, will result in 16
    distances and the resultant array will be of size (4, 4) as follows:

    [v0.v0, v0.v1, v0.v2, v0.v3]
    [v1.v0, v1.v1, v1.v2, v1.v3]
    [v2.v0, v2.v1, v2.v2, v2.v3]
    [v3.v0, v3.v1, v3.v2, v3.v3]

    The (i, j) position in the resulting array (matrix) denotes the distance
    between vi and vj vectors.

    Negative and nan values are considered as missing values. They are ignored
    for all distance metric calculations.

    Parameters
    ----------
    x
        [array-like, shape: (M, N)]
        An array like two dimensional matrix. The rows are the
        vectors used for comparison, i.e. for pairwise distance.
    metric
        The distance metric to use. The distance function can be
        'euclidean' or 'correlation'.
    split_every
        Determines the depth of the recursive aggregation in the reduction
        step. This argument is passed directly to the ``da.reduction`` call
        in the reduce step of the map-reduce.

        Omit to let dask heuristically decide a good default. A default can
        also be set globally with the split_every key in dask.config.

    Returns
    -------

    [array-like, shape: (M, M)]
    A two dimensional distance matrix, which will be symmetric. The dimension
    will be (M, M). The (i, j) position in the resulting array
    (matrix) denotes the distance between ith and jth row vectors
    in the input array.

    Examples
    --------

    >>> from sgkit.distance.api import pairwise_distance
    >>> import dask.array as da
    >>> x = da.array([[6, 4, 1,], [4, 5, 2], [9, 7, 3]]).rechunk(2, 2)
    >>> pairwise_distance(x, metric='euclidean').compute()
    array([[0.        , 2.44948974, 4.69041576],
           [2.44948974, 0.        , 5.47722558],
           [4.69041576, 5.47722558, 0.        ]])

    >>> import numpy as np
    >>> x = np.array([[6, 4, 1,], [4, 5, 2], [9, 7, 3]])
    >>> pairwise_distance(x, metric='euclidean').compute()
    array([[0.        , 2.44948974, 4.69041576],
           [2.44948974, 0.        , 5.47722558],
           [4.69041576, 5.47722558, 0.        ]])

    >>> x = np.array([[6, 4, 1,], [4, 5, 2], [9, 7, 3]])
    >>> pairwise_distance(x, metric='correlation').compute()
    array([[-4.44089210e-16,  2.62956526e-01,  2.82353505e-03],
           [ 2.62956526e-01,  0.00000000e+00,  2.14285714e-01],
           [ 2.82353505e-03,  2.14285714e-01,  0.00000000e+00]])
    """
    try:
        metric_map_func = getattr(metrics, f"{metric}_map")
        metric_reduce_func = getattr(metrics, f"{metric}_reduce")
        n_map_param = metrics.N_MAP_PARAM[metric]
    except AttributeError:
        raise NotImplementedError(
            f"Given metric: {metric} is not implemented.")

    x = da.asarray(x)
    if x.ndim != 2:
        raise ValueError(f"2-dimensional array expected, got '{x.ndim}'")

    # Set this variable outside of _pairwise to avoid its recreation
    # in every iteration, which would otherwise significantly increase
    # dask graph serialisation/deserialisation time
    metric_param = np.empty(n_map_param, dtype=x.dtype)

    def _pairwise(f: ArrayLike, g: ArrayLike) -> ArrayLike:
        result: ArrayLike = metric_map_func(f[:, None, :], g, metric_param)
        # Adding a new axis to help combine chunks along this axis in the
        # reduction step (see the _aggregate and _combine functions below).
        return result[..., np.newaxis]

    # concatenate in blockwise leads to high memory footprints, so instead
    # we perform blockwise without contraction followed by reduction.
    # More about this issue: https://github.com/dask/dask/issues/6874
    out = da.blockwise(
        _pairwise,
        "ijk",
        x,
        "ik",
        x,
        "jk",
        dtype=x.dtype,
        concatenate=False,
    )

    def _aggregate(x_chunk: ArrayLike, **_: typing.Any) -> ArrayLike:
        """Last function to be executed when resolving the dask graph,
        producing the final output. It is always invoked, even when the reduced
        Array has a single chunk along the reduced axes."""
        x_chunk = x_chunk.reshape(x_chunk.shape[:-2] + (-1, n_map_param))
        result: ArrayLike = metric_reduce_func(x_chunk)
        return result

    def _chunk(x_chunk: ArrayLike, **_: typing.Any) -> ArrayLike:
        return x_chunk

    def _combine(x_chunk: ArrayLike, **_: typing.Any) -> ArrayLike:
        """Function used for intermediate recursive aggregation (see
        split_every argument to ``da.reduction below``).  If the
        reduction can be performed in less than 3 steps, it will
        not be invoked at all."""
        # reduce chunks by summing along the -2 axis
        x_chunk_reshaped = x_chunk.reshape(x_chunk.shape[:-2] +
                                           (-1, n_map_param))
        return x_chunk_reshaped.sum(axis=-2)[..., np.newaxis]

    r = da.reduction(
        out,
        chunk=_chunk,
        combine=_combine,
        aggregate=_aggregate,
        axis=-1,
        dtype=x.dtype,
        meta=np.ndarray((0, 0), dtype=x.dtype),
        split_every=split_every,
        name="pairwise",
    )

    t = da.triu(r)
    return t + t.T