def _run_fornav_single(self, data, out_chunks, target_geo_def, fill_value, **kwargs):
    """Build the lazy fornav resampling of ``data`` onto ``target_geo_def``.

    One fornav task set is generated per cached ll2cr block. The per-block
    results are stacked along a new leading axis (one size-1 chunk per ll2cr
    block) and then reduced along that axis with ``da.reduction``.

    ``maximum_weight_mode`` (default ``False``) and ``weight_sum_min``
    (default ``-1.0``) are filled into ``kwargs`` in place before ``kwargs``
    is forwarded to ``self._generate_fornav_dask_tasks``.

    Returns
    -------
    dask.array.Array
        The reduced result, declared with ``data.dtype``.
    """
    ll2cr_result = self.cache['ll2cr_result']
    ll2cr_blocks = self.cache['ll2cr_blocks'].items()
    if isinstance(ll2cr_result, np.ndarray):
        ll2cr_numblocks = ll2cr_result.shape
    else:
        ll2cr_numblocks = ll2cr_result.numblocks
    task_name = f"fornav-{data.name}-{ll2cr_result.name}"

    # Defaults are written into kwargs so the task generator sees them too.
    maximum_weight_mode = kwargs.setdefault('maximum_weight_mode', False)
    weight_sum_min = kwargs.setdefault('weight_sum_min', -1.0)

    task_stack = self._generate_fornav_dask_tasks(
        out_chunks, ll2cr_blocks, task_name, data.name,
        target_geo_def, fill_value, kwargs)
    graph = HighLevelGraph.from_collections(
        task_name, task_stack, dependencies=[data, ll2cr_result])

    # Leading stacking axis: one size-1 chunk per ll2cr block, i.e. the
    # product of the first two entries of ll2cr_numblocks.
    n_ll2cr_blocks = ll2cr_numblocks[0] * ll2cr_numblocks[1]
    stacked_chunks = ((1,) * n_ll2cr_blocks,) + out_chunks
    stacked = da.Array(graph, task_name, stacked_chunks, data.dtype)

    combine_func = partial(_combine_fornav,
                           maximum_weight_mode=maximum_weight_mode)
    aggregate_func = partial(_average_fornav,
                             maximum_weight_mode=maximum_weight_mode,
                             weight_sum_min=weight_sum_min,
                             dtype=data.dtype,
                             fill_value=fill_value)
    return da.reduction(stacked, _chunk_callable, aggregate_func,
                        combine=combine_func, axis=(0,),
                        dtype=data.dtype, concatenate=False)
def test_general_reduction_names():
    # A user-supplied ``name`` must prefix every reduction layer's keys.
    dtype = int
    ones = da.ones(10, dtype, chunks=2)
    reduced = da.reduction(ones, np.sum, np.sum, dtype=dtype, name="foo")

    split_keys = [key[0].rsplit("-", 1) for key in reduced.dask]
    names, tokens = list(zip_longest(*split_keys))

    assert set(names) == {"ones_like", "foo", "foo-partial", "foo-aggregate"}
    assert all(tokens)
def test_weighted_reduction():
    # Weighted reduction
    def w_sum(x, weights=None, dtype=None, computing_meta=False, **kwargs):
        """`chunk` callable for (weighted) sum"""
        if computing_meta:
            return x
        weighted = x if weights is None else x * weights
        return np.sum(weighted, dtype=dtype, **kwargs)

    # Masked input with a single masked element
    arr = 1 + np.ma.arange(60).reshape(6, 10)
    arr[2, 2] = np.ma.masked
    darr = da.from_array(arr, chunks=(4, 5))

    # Per-row weights, broadcast along the second axis
    weights = np.linspace(1, 2, 6).reshape(6, 1)

    # Without weights the reduction is a plain sum
    result = da.reduction(darr, w_sum, np.sum, dtype=darr.dtype)
    assert_eq(result, np.sum(arr), check_shape=True)

    # Weighted sum
    result = da.reduction(darr, w_sum, np.sum, dtype="f8", weights=weights)
    assert_eq(result, np.sum(arr * weights), check_shape=True)

    # Non-broadcastable weights: too short an axis, then too many dims
    for bad_weights in ([1, 2, 3], [[[2]]]):
        with pytest.raises(ValueError):
            da.reduction(darr, w_sum, np.sum, weights=bad_weights)
def reduce_loglikelihood(log_weighted_likelihoods):
    """Combine weighted log-likelihoods with ``logaddexp_reduce``.

    NumPy input is passed straight to ``logaddexp_reduce``. Any other input
    is treated as a dask array and reduced lazily along axis 0 (without
    keeping that axis) via ``da.reduction``, with result dtype ``float``.
    """
    if not isinstance(log_weighted_likelihoods, np.ndarray):
        # Sum along gaussians axis (using logAddExp to prevent underflow)
        return da.reduction(
            x=log_weighted_likelihoods,
            chunk=logaddexp_reduce,
            aggregate=logaddexp_reduce,
            axis=0,
            dtype=float,
            keepdims=False,
        )
    return logaddexp_reduce(log_weighted_likelihoods)
def compute_xy_bbox(xy_coords: Union[xr.DataArray, np.ndarray, da.Array]) \
        -> Tuple[float, float, float, float]:
    """Compute the bounding box of ``xy_coords`` as plain floats.

    The coordinates are reduced over axes 1 and 2 with
    ``compute_xy_bbox_chunk`` / ``compute_xy_bbox_aggregate``. The flattened
    result must hold exactly four values ordered x_min, x_max, y_min, y_max.
    It is returned reordered as ``(x_min, y_min, x_max, y_max)``.

    NOTE(review): ``xy_coords`` is presumably shaped (2, height, width) with
    x in row 0 and y in row 1 — confirm against callers.
    """
    coords = da.asarray(xy_coords)
    bbox = da.reduction(
        coords,
        compute_xy_bbox_chunk,
        compute_xy_bbox_aggregate,
        axis=(1, 2),
        keepdims=True,
        dtype=coords.dtype,
        meta=np.array([[0, 0], [0, 0]], dtype=coords.dtype),
    )
    x_min, x_max, y_min, y_max = (float(v) for v in bbox.compute().flatten())
    return x_min, y_min, x_max, y_max
def test_chunk_structure_independence(axes, split_every, chunks):
    # Reducing an array should not depend on its chunk-structure!!!
    # See Issue #8541: https://github.com/dask/dask/issues/8541
    def identity(block, axis, keepdims):
        # Used for both the chunk and aggregate steps: pass blocks through.
        return block

    shape = tuple(np.sum(c) for c in chunks)
    expected = np.arange(np.prod(shape)).reshape(*shape)
    arr = da.from_array(expected, chunks=chunks)

    reduced = da.reduction(
        arr,
        identity,
        identity,
        keepdims=True,
        axis=axes,
        split_every=split_every,
        dtype=arr.dtype,
        meta=arr._meta,
    )
    _assert_eq(reduced, expected, check_chunks=False, check_shape=False)
def dataset_chunks(datasets, time_bin_secs, max_row_chunks):
    """
    Given ``max_row_chunks`` determine a chunking strategy for each
    dataset that prevents binning unique times in separate chunks.

    Parameters
    ----------
    datasets : list of Datasets
        Each must provide ``TIME`` and ``INTERVAL`` variables and a
        ``row`` dimension.
    time_bin_secs : float
        Requested number of seconds in a time bin. Only used to warn
        about unique times whose average INTERVAL exceeds it.
    max_row_chunks : int
        Requested number of rows per chunk. A bin is closed as soon as
        adding the next unique time would bring it to this many rows or
        more. A single unique time that alone exceeds it still forms a
        bin (with a warning).

    Returns
    -------
    ds_row_chunks : list of tuple of int
        Number of rows in each chunk, per dataset.
    ds_time_chunks : list of tuple of int
        Number of unique times in each chunk, per dataset.

    Both are empty tuples when ``datasets`` is empty.

    Raises
    ------
    ValueError
        If any dataset's TIME is not monotonically increasing, or if a
        dataset's row chunks do not sum to its number of rows.
    """
    # Calculate (utime, count, interval average) for each dataset
    # then transpose to get lists for each tuple entry
    if len(datasets) == 0:
        return (), ()

    utimes = []
    interval_avg = []
    counts = []
    monotonicity_checks = []

    for ds in datasets:
        # Compute unique times, their counts and interval sum
        # for each row chunk.
        # Plain ``object`` is used as the dtype: the ``np.object`` alias
        # was deprecated in NumPy 1.20 and removed in NumPy 1.24.
        block_values = da.blockwise(_time_interval_sum, "r",
                                    ds.TIME.data, "r",
                                    ds.INTERVAL.data, "r",
                                    meta=np.empty((0,), dtype=object),
                                    dtype=object)

        # Reduce each row chunk's values
        reduction = da.reduction(block_values,
                                 chunk=_chunk,
                                 combine=_time_int_combine,
                                 aggregate=_time_int_agg,
                                 concatenate=False,
                                 split_every=16,
                                 meta=np.empty((0,), dtype=object),
                                 dtype=object)

        # Pull out the final unique times, counts and interval average
        utime = reduction.map_blocks(getitem, 0, dtype=ds.TIME.dtype)
        count = reduction.map_blocks(getitem, 1, dtype=np.int32)
        int_avg = reduction.map_blocks(getitem, 2, dtype=ds.INTERVAL.dtype)

        # Check monotonicity of TIME while we're at it
        is_monotonic = da.all(da.diff(ds.TIME.data) >= 0.0)

        utimes.append(utime)
        counts.append(count)
        interval_avg.append(int_avg)
        monotonicity_checks.append(is_monotonic)

    # Work out the unique times, average intervals for those times
    # and the frequency of those times in a single compute
    (ds_utime,
     ds_avg_intervals,
     ds_counts,
     ds_monotonicity_checks) = dask.compute(utimes,
                                            interval_avg,
                                            counts,
                                            monotonicity_checks)

    if not all(ds_monotonicity_checks):
        raise ValueError("TIME is not monotonically increasing. "
                         "This is required.")

    # Produce row and time chunking strategies for each dataset
    ds_row_chunks = []
    ds_time_chunks = []
    ds_interval_secs = []

    it = zip(ds_utime, ds_avg_intervals, ds_counts)

    for di, (utime, avg_interval, ds_count) in enumerate(it):
        # Maintain row and time chunks for this dataset
        row_chunks = []
        time_chunks = []
        interval_secs = []

        # A dataset without any unique times produces no chunks
        # (rather than failing on ds_count[0] below)
        if len(utime) == 0:
            ds_row_chunks.append(())
            ds_time_chunks.append(())
            ds_interval_secs.append(())
            continue

        # Start out with first entries
        bin_rows = ds_count[0]
        bin_times = 1
        bin_secs = avg_interval[0]

        for ut, avg_int, count in zip(utime[1:],
                                      avg_interval[1:],
                                      ds_count[1:]):
            if count > max_row_chunks:
                logger.warning(
                    "Unique time {:.3f} occurred {:d} times "
                    "in dataset {:d} but this exceeds the "
                    "requested row chunks {:d}. "
                    "Consider increasing --row-chunks",
                    ut, count, di, max_row_chunks)

            if avg_int > time_bin_secs:
                logger.warning(
                    "The average INTERVAL associated with "
                    "unique time {:.3f} in dataset {:d} "
                    "is {:.3f} but this exceeds the requested "
                    "number of seconds in a time bin {:.3f}s. "
                    "Consider increasing --time-bin-secs",
                    ut, di, avg_int, time_bin_secs)

            next_rows = bin_rows + count

            if next_rows < max_row_chunks:
                # Still within the number of rows for this bin: keep going
                bin_rows = next_rows
                bin_times += 1
                bin_secs += avg_int
            else:
                # Finalize this bin and start a new one with the
                # counts we were trying to add
                row_chunks.append(bin_rows)
                time_chunks.append(bin_times)
                interval_secs.append(bin_secs)

                bin_rows = count
                bin_times = 1
                bin_secs = avg_int

        # Finish any remaining bins
        if bin_rows > 0:
            assert bin_times > 0
            row_chunks.append(bin_rows)
            time_chunks.append(bin_times)
            interval_secs.append(bin_secs)

        ds_row_chunks.append(tuple(row_chunks))
        ds_time_chunks.append(tuple(time_chunks))
        ds_interval_secs.append(tuple(interval_secs))

    logger.info("Dataset Chunking: (r)ow - (t)imes - (s)econds")

    it = zip(datasets, ds_row_chunks, ds_time_chunks, ds_interval_secs)
    for di, (ds, ds_rcs, ds_tcs, ds_int_secs) in enumerate(it):
        ds_rows = ds.dims['row']
        ds_crows = sum(ds_rcs)

        if ds_rows != ds_crows:
            raise ValueError("Number of dataset rows %d "
                             "does not match the sum %d "
                             "of the row chunks %s"
                             % (ds_rows, ds_crows, ds_rcs))

        log_str = ", ".join("(%dr,%dt,%.1fs)" % (rc, tc, its)
                            for rc, tc, its
                            in zip(ds_rcs, ds_tcs, ds_int_secs))

        logger.info("Dataset {d}: {s}", d=di, s=log_str)

    return ds_row_chunks, ds_time_chunks
def dataset_chunks(datasets, time_bin_secs, max_row_chunks):
    """
    Given ``max_row_chunks`` determine a chunking strategy for each
    dataset that prevents binning unique times in separate chunks.

    The unique times, their counts and average intervals are computed
    per dataset with dask. The actual binning is delegated to
    ``DatasetGrouper(time_bin_secs, max_row_chunks).group``.

    Parameters
    ----------
    datasets : list of Datasets
        Each must provide ``TIME`` and ``INTERVAL`` variables and a
        ``row`` dimension.
    time_bin_secs : float
        Requested number of seconds in a time bin (passed to the grouper).
    max_row_chunks : int
        Requested number of rows per chunk (passed to the grouper).

    Returns
    -------
    ds_row_chunks, ds_time_chunks
        Per-dataset row chunks and time chunks as produced by the grouper.
        Both are empty tuples when ``datasets`` is empty.

    Raises
    ------
    ValueError
        If any dataset's TIME is not monotonically increasing, or if a
        dataset's row chunks do not sum to its number of rows.
    """
    # Calculate (utime, count, interval average) for each dataset
    # then transpose to get lists for each tuple entry
    if len(datasets) == 0:
        return (), ()

    utimes = []
    interval_avg = []
    counts = []
    monotonicity_checks = []

    for ds in datasets:
        # Compute unique times, their counts and interval sum
        # for each row chunk.
        # Plain ``object`` is used as the dtype: the ``np.object`` alias
        # was deprecated in NumPy 1.20 and removed in NumPy 1.24.
        block_values = da.blockwise(_time_interval_sum, "r",
                                    ds.TIME.data, "r",
                                    ds.INTERVAL.data, "r",
                                    meta=np.empty((0,), dtype=object),
                                    dtype=object)

        # Reduce each row chunk's values
        reduction = da.reduction(block_values,
                                 chunk=_chunk,
                                 combine=_time_int_combine,
                                 aggregate=_time_int_agg,
                                 concatenate=False,
                                 split_every=16,
                                 meta=np.empty((0,), dtype=object),
                                 dtype=object)

        # Pull out the final unique times, counts and interval average
        utime = reduction.map_blocks(getitem, 0, dtype=ds.TIME.dtype)
        count = reduction.map_blocks(getitem, 1, dtype=np.int32)
        int_avg = reduction.map_blocks(getitem, 2, dtype=ds.INTERVAL.dtype)

        # Check monotonicity of TIME while we're at it
        is_monotonic = da.all(da.diff(ds.TIME.data) >= 0.0)

        utimes.append(utime)
        counts.append(count)
        interval_avg.append(int_avg)
        monotonicity_checks.append(is_monotonic)

    # Work out the unique times, average intervals for those times
    # and the frequency of those times in a single compute
    (ds_utime,
     ds_avg_intervals,
     ds_counts,
     ds_monotonicity_checks) = dask.compute(utimes,
                                            interval_avg,
                                            counts,
                                            monotonicity_checks)

    if not all(ds_monotonicity_checks):
        raise ValueError("TIME is not monotonically increasing. "
                         "This is required.")

    grouper = DatasetGrouper(time_bin_secs, max_row_chunks)
    res = grouper.group(ds_utime, ds_avg_intervals, ds_counts)
    ds_row_chunks, ds_time_chunks, ds_interval_secs = res

    logger.info("Dataset Chunking: (r)ow - (t)imes - (s)econds")

    it = zip(datasets, ds_row_chunks, ds_time_chunks, ds_interval_secs)
    for di, (ds, ds_rcs, ds_tcs, ds_int_secs) in enumerate(it):
        ds_rows = ds.dims['row']
        ds_crows = sum(ds_rcs)

        if ds_rows != ds_crows:
            raise ValueError("Number of dataset rows %d "
                             "does not match the sum %d "
                             "of the row chunks %s"
                             % (ds_rows, ds_crows, ds_rcs))

        log_str = ", ".join("(%dr,%dt,%.1fs)" % (rc, tc, its)
                            for rc, tc, its
                            in zip(ds_rcs, ds_tcs, ds_int_secs))

        logger.info("Dataset {d}: {s}", d=di, s=log_str)

    return ds_row_chunks, ds_time_chunks
def default(glyph, df, schema, canvas, summary, cuda=False):
    """Aggregate a dask dataframe onto ``canvas`` via a dask.array tree reduction.

    The dataframe's own task graph is re-wrapped as a dask array with one
    nominal chunk per partition. ``da.reduction`` then drives the compiled
    datashader steps:

    - ``create`` / ``extend`` run once per partition;
    - ``combine`` merges partial aggregates;
    - ``finalize`` produces the final result.

    Returns
    -------
    tuple
        ``(R, R.name)`` where ``R`` is the resulting dask array.
    """
    # ``coords_axis`` is handed to ``finalize`` as its ``coords``; it is kept
    # distinct from the ``axis`` argument dask passes to the callables below.
    shape, bounds, st, coords_axis = shape_bounds_st_and_axis(df, canvas, glyph)

    # Compile functions
    create, info, append, combine, finalize = compile_components(
        summary, schema, glyph, cuda=cuda)
    extend = glyph._build_extend(
        canvas.x_axis.mapper, canvas.y_axis.mapper, info, append)

    # Here be dragons
    # Get the dataframe graph
    graph = df.__dask_graph__()

    # Guess a reasonable output dtype from the column dtypes. Categoricals
    # are skipped, as are extension dtypes that expose no ``subtype``.
    column_dtypes = []
    for col_dtype in df.dtypes:
        if isinstance(col_dtype, pd.CategoricalDtype):
            continue
        if not isinstance(col_dtype, pd.api.extensions.ExtensionDtype):
            column_dtypes.append(col_dtype)
            continue
        # RaggedArray implementation and
        # https://github.com/pandas-dev/pandas/issues/22224
        try:
            column_dtypes.append(col_dtype.subtype)
        except AttributeError:
            pass

    # Meta object so dask.array doesn't look too closely at the chunks it
    # wraps: they are really dataframes, but dask is told they are ndarrays.
    meta = np.empty((0,), dtype=np.result_type(*column_dtypes))

    # One size-1 chunk per dataframe partition. Only the block structure
    # has to match the dataframe (so its graph keys can be reused); the
    # chunk sizes are nominal.
    chunks = ((1,) * df.npartitions,)

    # A dask array of dataframes: dodgy, but the datashader reduction
    # functions convert each dataframe to arrays anyway, and
    # dask.array.reduction gives us a tree reduction for free.
    df_array = da.Array(graph, df._name, chunks, meta=meta)

    # A sufficient condition for ensuring the chimera holds together
    assert list(df_array.__dask_keys__()) == list(df.__dask_keys__())

    def chunk(partition, axis, keepdims):
        """dask.array.reduction chunk step: aggregate one dataframe partition."""
        aggs = create(shape)
        extend(aggs, partition, st, bounds)
        return aggs

    def wrapped_combine(x, axis, keepdims):
        """Adapt datashader ``combine`` to the dask.array.reduction combine step."""
        if isinstance(x, list):
            # list of tuples of ndarrays: merge them
            return combine(x)
        if isinstance(x, tuple):
            # tuple with single ndarray: nothing to merge
            return x
        raise TypeError("Unknown type %s in wrapped_combine" % type(x))

    def aggregate(x, axis, keepdims):
        """Combine remaining partials, then hand them to datashader ``finalize``."""
        return finalize(wrapped_combine(x, axis, keepdims),
                        cuda=cuda,
                        coords=coords_axis,
                        dims=[glyph.y_label, glyph.x_label])

    R = da.reduction(
        df_array,
        aggregate=aggregate,
        chunk=chunk,
        combine=wrapped_combine,
        # Control granularity of tree branching; less is more
        split_every=2,
        # np.concatenate must not be called during combine/aggregate: the
        # inputs are tuples of ndarrays and lists of tuples of ndarrays
        concatenate=False,
        # Prevent dask from internally inspecting chunk, combine and aggregate
        meta=meta,
        # Provide some sort of dtype for the resultant dask array
        dtype=meta.dtype,
    )
    return R, R.name
def antenna_flags_field(msname, fields=None, antennas=None):
    """Compute per-antenna flag statistics for selected fields of a Measurement Set.

    Parameters
    ----------
    msname : str
        Path to the Measurement Set.
    fields : list of str or list of int, optional
        Fields to include. Names are looked up in the FIELD subtable; integers
        are used directly as FIELD_IDs. Defaults to every field.
    antennas : list of str or list of int, optional
        Antennas to report. Names are looked up in the ANTENNA subtable;
        integers are used directly as antenna ids. Defaults to every antenna.

    Returns
    -------
    dict
        Maps antenna id to a dict with keys ``"name"``, ``"position"``,
        ``"array_centre_dist"``, ``"frac"``, ``"sum"`` and ``"counts"``.
        ``"frac"`` is ``"sum" / "counts"``.

    Raises
    ------
    ValueError
        If a requested field/antenna name is not present in the MS, or if
        the field selection yields no data.
    """
    def _names_to_ids(requested, known_names, kind):
        """Map each name in ``requested`` to its row index in ``known_names``."""
        known = list(known_names)
        ids = []
        for name in requested:
            try:
                ids.append(known.index(name))
            except ValueError:
                raise ValueError("Unknown %s name %r" % (kind, name)) from None
        return ids

    ds_ant = xds_from_table(msname + "::ANTENNA")[0]
    ds_field = xds_from_table(msname + "::FIELD")[0]
    ds_obs = xds_from_table(msname + "::OBSERVATION")[0]

    ant_names = ds_ant.NAME.data.compute()
    field_names = ds_field.NAME.data.compute()
    ant_positions = ds_ant.POSITION.data.compute()

    try:
        # Get observatory name and centre of array
        obs_name = ds_obs.TELESCOPE_NAME.data.compute()[0]
        me = casacore.measures.measures()
        obs_cofa = me.observatory(obs_name)
        lon, lat, alt = (obs_cofa['m0']['value'],
                         obs_cofa['m1']['value'],
                         obs_cofa['m2']['value'])
        cofa = wgs84_to_ecef(lon, lat, alt)
    except Exception:
        # Otherwise fall back to the position of the first antenna
        cofa = ant_positions[0]

    if fields:
        if isinstance(fields[0], str):
            # Translate names into FIELD table row ids (previously this
            # mapped names to their position in ``fields`` itself).
            field_ids = _names_to_ids(fields, field_names, "field")
        else:
            field_ids = fields
    else:
        field_ids = list(range(len(field_names)))

    if antennas:
        if isinstance(antennas[0], str):
            # Translate names into ANTENNA table row ids
            ant_ids = _names_to_ids(antennas, ant_names, "antenna")
        else:
            ant_ids = antennas
    else:
        ant_ids = list(range(len(ant_names)))

    nant = len(ant_ids)
    fields_str = ", ".join(map(str, field_ids))

    ds_mss = xds_from_ms(msname,
                         group_cols=["FIELD_ID", "DATA_DESC_ID"],
                         chunks={'row': 100000},
                         taql_where="FIELD_ID IN [%s]" % fields_str)

    flag_sum_computes = []
    for ds in ds_mss:
        flag_sums = da.blockwise(_get_flags, ("row",),
                                 ant_ids, ("ant",),
                                 ds.ANTENNA1.data, ("row",),
                                 ds.ANTENNA2.data, ("row",),
                                 ds.FLAG.data, ("row", "chan", "corr"),
                                 adjust_chunks={"row": nant},
                                 dtype=numpy.ndarray)

        flags_redux = da.reduction(flag_sums,
                                   chunk=_chunk,
                                   combine=_combine,
                                   aggregate=_aggregate,
                                   concatenate=False,
                                   dtype=numpy.float64)
        flag_sum_computes.append(flags_redux)

    sum_per_field_spw = dask.compute(flag_sum_computes)[0]
    if not sum_per_field_spw:
        # sum([]) would be 0 and the column indexing below would fail obscurely
        raise ValueError("No data found in %s for FIELD_ID IN [%s]"
                         % (msname, fields_str))
    sum_all = sum(sum_per_field_spw)
    fractions = sum_all[:, 0] / sum_all[:, 1]

    stats = {}
    for i, aid in enumerate(ant_ids):
        # Row i of the sums follows the order of ant_ids, but table lookups
        # (name, position) must use the antenna id itself.
        ant_pos = list(ant_positions[aid])
        stats[aid] = {
            "name": ant_names[aid],
            "position": ant_pos,
            "array_centre_dist": _distance(cofa, ant_pos),
            "frac": fractions[i],
            "sum": sum_all[i][0],
            "counts": sum_all[i][1],
        }

    return stats
def bda_average_spw(out_datasets, ddid_ds, spw_ds):
    """
    Build the averaged SPECTRAL_WINDOW and DATA_DESCRIPTION datasets for a
    set of averaged datasets and remap their DATA_DESC_IDs accordingly.

    Parameters
    ----------
    out_datasets : list of Datasets
        list of Datasets, each with ``DATA_DESC_ID``, ``NUM_CHAN`` and
        ``DECORR_CHAN_WIDTH`` variables. The list is modified in place.
    ddid_ds : Dataset
        DATA_DESCRIPTION dataset
    spw_ds : list of Datasets
        list of Datasets, each describing a single Spectral Window

    Returns
    -------
    out_datasets : list of Datasets
        The input list, with each Dataset's ``DATA_DESC_ID`` remapped and
        its ``NUM_CHAN`` and ``DECORR_CHAN_WIDTH`` variables removed.
    spw_ds : list of Datasets
        Single-element list holding a Dataset that describes the averaged
        Spectral Windows
    ddid_ds : Dataset
        DATA_DESCRIPTION dataset referencing the averaged Spectral Windows
    """
    # NOTE: plain ``object`` is used as the dtype throughout; the
    # ``np.object`` alias was deprecated in NumPy 1.20 and removed in 1.24.
    channelisations = []

    # Over the entire set of datasets, determine the complete
    # set of channelisations, per input DDID and
    # reduce down to a single object
    for out_ds in out_datasets:
        transform = da.blockwise(_channelisations, ("row",),
                                 out_ds.DATA_DESC_ID.data, ("row",),
                                 out_ds.NUM_CHAN.data, ("row",),
                                 ddid_ds.SPECTRAL_WINDOW_ID.data, ("ddid",),
                                 ddid_ds.POLARIZATION_ID.data, ("ddid",),
                                 meta=np.empty((0,), dtype=object))

        ds_channelisations = da.reduction(transform,
                                          chunk=_noop,
                                          combine=combine,
                                          aggregate=combine,
                                          concatenate=False,
                                          keepdims=True,
                                          meta=np.empty((0,), dtype=object),
                                          dtype=object)

        channelisations.append(ds_channelisations)

    # Final reduction object, note the aggregate method
    # which generates the mapping
    ddid_chan_map = da.reduction(da.concatenate(channelisations),
                                 chunk=_noop,
                                 combine=combine,
                                 aggregate=aggregate,
                                 concatenate=False,
                                 keepdims=False,
                                 meta=np.empty((), dtype=object),
                                 dtype=object)

    def _squeeze_tuplify(*args):
        return tuple(a.squeeze() for a in args)

    def _spw_column(column, dims, **kwargs):
        """Gather ``column`` from every SPW dataset into one tuple per block."""
        args = [a for spw in spw_ds
                for a in (getattr(spw, column).data, dims)]
        return da.blockwise(_squeeze_tuplify, dims, *args,
                            concatenate=False,
                            align_arrays=False,
                            meta=np.empty((0,) * len(dims), dtype=object),
                            **kwargs)

    chan_freqs = _spw_column("CHAN_FREQ", ("row", "chan"),
                             adjust_chunks={"chan": lambda c: np.nan})
    chan_widths = _spw_column("CHAN_WIDTH", ("row", "chan"),
                              adjust_chunks={"chan": lambda c: np.nan})
    ref_freqs = _spw_column("REF_FREQUENCY", ("row",))
    # NOTE(review): despite its name this reads REF_FREQUENCY, exactly as the
    # original did -- confirm whether spw.MEAS_FREQ_REF was intended.
    meas_freq_refs = _spw_column("REF_FREQUENCY", ("row",))

    result = da.blockwise(ddid_and_spw_factory, ("row", "chan"),
                          chan_freqs, ("row", "chan"),
                          chan_widths, ("row", "chan"),
                          ref_freqs, ("row",),
                          meas_freq_refs, ("row",),
                          ddid_chan_map, (),
                          meta=np.empty((0, 0), dtype=object))

    # There should only be one chunk
    assert result.npartitions == 1

    chan_freq = da.blockwise(getitem, ("row", "chan"),
                             result, ("row", "chan"),
                             0, None,
                             dtype=np.float64)

    chan_width = da.blockwise(getitem, ("row", "chan"),
                              result, ("row", "chan"),
                              1, None,
                              dtype=np.float64)

    def _first_block_item(d, i):
        # "chan" is contracted, so ``d`` is a list of chan blocks;
        # take entry ``i`` of the first (and only) one.
        return d[0][i]

    def _row_field(index, dtype):
        """Extract factory output ``index`` as a per-row array."""
        return da.blockwise(_first_block_item, ("row",),
                            result, ("row", "chan"),
                            index, None,
                            dtype=dtype)

    num_chan = _row_field(2, np.int32)
    ref_freq = _row_field(3, np.float64)
    # Factory output 4 (presumably the averaged meas freq ref) was
    # extracted but never used by the original; it is not emitted here.
    total_bw = _row_field(5, np.float64)
    spectral_window_id = _row_field(6, np.int32)
    polarization_id = _row_field(7, np.int32)
    ddid_map = _row_field(8, np.int32)

    for o, out_ds in enumerate(out_datasets):
        data_desc_id = da.blockwise(_new_ddids, ("row",),
                                    out_ds.DATA_DESC_ID.data, ("row",),
                                    out_ds.NUM_CHAN.data, ("row",),
                                    ddid_map, ("ddid",),
                                    dtype=out_ds.DATA_DESC_ID.dtype)

        dv = dict(out_ds.data_vars)
        dv["DATA_DESC_ID"] = (("row",), data_desc_id)
        del dv["NUM_CHAN"]
        del dv["DECORR_CHAN_WIDTH"]

        # Replaced in place in the caller's list
        out_datasets[o] = Dataset(dv, out_ds.coords, out_ds.attrs)

    out_spw_ds = Dataset({
        "CHAN_FREQ": (("row", "chan"), chan_freq),
        "CHAN_WIDTH": (("row", "chan"), chan_width),
        "EFFECTIVE_BW": (("row", "chan"), chan_width),
        "RESOLUTION": (("row", "chan"), chan_width),
        "NUM_CHAN": (("row",), num_chan),
        "REF_FREQUENCY": (("row",), ref_freq),
        "TOTAL_BANDWIDTH": (("row",), total_bw)
    })

    out_ddid_ds = Dataset({
        "SPECTRAL_WINDOW_ID": (("row",), spectral_window_id),
        "POLARIZATION_ID": (("row",), polarization_id),
    })

    return out_datasets, [out_spw_ds], out_ddid_ds
def pairwise_distance(
    x: ArrayLike,
    metric: MetricTypes = "euclidean",
    split_every: typing.Optional[int] = None,
) -> da.Array:
    """Calculates the pairwise distance between all pairs of row vectors in the
    given two dimensional array x.

    To illustrate the algorithm consider the following (4, 5) two dimensional array:

    [e.00, e.01, e.02, e.03, e.04]
    [e.10, e.11, e.12, e.13, e.14]
    [e.20, e.21, e.22, e.23, e.24]
    [e.30, e.31, e.32, e.33, e.34]

    The rows of the above matrix are the set of vectors. Now let's label all
    the vectors as v0, v1, v2, v3.

    The result will be a two dimensional symmetric matrix which will contain
    the distance between all pairs. Since there are 4 vectors, calculating the
    distance between each vector and every other vector will result in 16
    distances, and the resultant array will be of size (4, 4) as follows:

    [v0.v0, v0.v1, v0.v2, v0.v3]
    [v1.v0, v1.v1, v1.v2, v1.v3]
    [v2.v0, v2.v1, v2.v2, v2.v3]
    [v3.v0, v3.v1, v3.v2, v3.v3]

    The (i, j) position in the resulting array (matrix) denotes the distance
    between vi and vj vectors.

    Negative and nan values are considered as missing values. They are ignored
    for all distance metric calculations.

    Parameters
    ----------
    x
        [array-like, shape: (M, N)]
        An array like two dimensional matrix. The rows are the
        vectors used for comparison, i.e. for pairwise distance.
    metric
        The distance metric to use. The distance function can be
        'euclidean' or 'correlation'.
    split_every
        Determines the depth of the recursive aggregation in the reduction
        step. This argument is passed directly to the ``dask.array.reduction``
        call in the reduce step of the map reduce.

        Omit to let dask heuristically decide a good default. A default can
        also be set globally with the split_every key in dask.config.

    Returns
    -------

    [array-like, shape: (M, M)]
    A two dimensional distance matrix, which will be symmetric. The dimension
    will be (M, M). The (i, j) position in the resulting array
    (matrix) denotes the distance between ith and jth row vectors
    in the input array.

    Raises
    ------
    NotImplementedError
        If ``metric`` has no map/reduce implementation or no registered
        number of map parameters.
    ValueError
        If ``x`` is not two dimensional.

    Examples
    --------

    >>> from sgkit.distance.api import pairwise_distance
    >>> import dask.array as da
    >>> x = da.array([[6, 4, 1,], [4, 5, 2], [9, 7, 3]]).rechunk(2, 2)
    >>> pairwise_distance(x, metric='euclidean').compute()
    array([[0.        , 2.44948974, 4.69041576],
           [2.44948974, 0.        , 5.47722558],
           [4.69041576, 5.47722558, 0.        ]])

    >>> import numpy as np
    >>> x = np.array([[6, 4, 1,], [4, 5, 2], [9, 7, 3]])
    >>> pairwise_distance(x, metric='euclidean').compute()
    array([[0.        , 2.44948974, 4.69041576],
           [2.44948974, 0.        , 5.47722558],
           [4.69041576, 5.47722558, 0.        ]])

    >>> x = np.array([[6, 4, 1,], [4, 5, 2], [9, 7, 3]])
    >>> pairwise_distance(x, metric='correlation').compute()
    array([[-4.44089210e-16,  2.62956526e-01,  2.82353505e-03],
           [ 2.62956526e-01,  0.00000000e+00,  2.14285714e-01],
           [ 2.82353505e-03,  2.14285714e-01,  0.00000000e+00]])
    """
    try:
        metric_map_func = getattr(metrics, f"{metric}_map")
        metric_reduce_func = getattr(metrics, f"{metric}_reduce")
        n_map_param = metrics.N_MAP_PARAM[metric]
    except (AttributeError, KeyError):
        # A metric missing from N_MAP_PARAM is just as unimplemented as
        # one missing its map/reduce functions.
        raise NotImplementedError(
            f"Given metric: {metric} is not implemented.") from None

    x = da.asarray(x)
    if x.ndim != 2:
        raise ValueError(f"2-dimensional array expected, got '{x.ndim}'")

    # setting this variable outside of _pairwise to avoid its recreation
    # in every iteration, which would otherwise significantly increase dask
    # graph serialisation/deserialisation time
    metric_param = np.empty(n_map_param, dtype=x.dtype)

    def _pairwise(f: ArrayLike, g: ArrayLike) -> ArrayLike:
        result: ArrayLike = metric_map_func(f[:, None, :], g, metric_param)
        # Adding a new axis to help combine chunks along this axis in the
        # reduction step (see the _aggregate and _combine functions below).
        return result[..., np.newaxis]

    # concatenate in blockwise leads to high memory footprints, so instead
    # we perform blockwise without contraction followed by reduction.
    # More about this issue: https://github.com/dask/dask/issues/6874
    out = da.blockwise(
        _pairwise,
        "ijk",
        x,
        "ik",
        x,
        "jk",
        dtype=x.dtype,
        concatenate=False,
    )

    def _aggregate(x_chunk: ArrayLike, **_: typing.Any) -> ArrayLike:
        """Last function to be executed when resolving the dask graph,
        producing the final output. It is always invoked, even when the reduced
        Array counts a single chunk along the reduced axes."""
        x_chunk = x_chunk.reshape(x_chunk.shape[:-2] + (-1, n_map_param))
        result: ArrayLike = metric_reduce_func(x_chunk)
        return result

    def _chunk(x_chunk: ArrayLike, **_: typing.Any) -> ArrayLike:
        return x_chunk

    def _combine(x_chunk: ArrayLike, **_: typing.Any) -> ArrayLike:
        """Function used for intermediate recursive aggregation (see
        split_every argument to ``da.reduction`` below). If the
        reduction can be performed in less than 3 steps, it will not be
        invoked at all."""
        # reduce chunks by summing along the -2 axis
        x_chunk_reshaped = x_chunk.reshape(x_chunk.shape[:-2] + (-1, n_map_param))
        return x_chunk_reshaped.sum(axis=-2)[..., np.newaxis]

    r = da.reduction(
        out,
        chunk=_chunk,
        combine=_combine,
        aggregate=_aggregate,
        axis=-1,
        dtype=x.dtype,
        meta=np.ndarray((0, 0), dtype=x.dtype),
        split_every=split_every,
        name="pairwise",
    )

    # Only the upper triangle is kept and mirrored, guaranteeing symmetry.
    t = da.triu(r)
    return t + t.T