# Code example #1
0
def take_along_axis(
    a: np.ndarray | da.Array, ind: np.ndarray | da.Array
) -> np.ndarray | da.Array:
    """Easily use the outputs of argsort on ND arrays to pick the results.

    Values are always picked along the last axis of ``a``, using the integer
    indices stored along the last axis of ``ind``.

    Parameters
    ----------
    a
        Array to pick values from. It may have fewer or more axes than
        ``ind``; if fewer, size-1 axes are prepended to it.
    ind
        Integer indices into the last axis of ``a``. Its leading axes are
        matched against the axes of ``a`` that sit immediately to the left
        of the last one.

    Returns
    -------
    When both inputs are NumPy arrays, the result of
    ``np.take_along_axis(a, ind, axis=-1)``. Otherwise a dask array of shape
    ``<extra leading axes of a> + <common axes> + ind.shape[-1:]``.
    """
    if isinstance(a, np.ndarray) and isinstance(ind, np.ndarray):
        # Left-pad whichever array has fewer axes with size-1 axes. A
        # non-positive repeat count yields an empty tuple, so the reshape of
        # the array that already has enough axes is a no-op.
        a = a.reshape((1,) * (ind.ndim - a.ndim) + a.shape)
        ind = ind.reshape((1,) * (a.ndim - ind.ndim) + ind.shape)
        return np.take_along_axis(a, ind, axis=-1)

    # a and/or ind are dask arrays. This is not yet implemented upstream.
    # Upstream tracker: https://github.com/dask/dask/issues/3663

    # This is going to be an ugly and slow mess, as dask does not support
    # fancy indexing.

    # Normalize a and ind. The end result is that a can have more axes than
    # ind on the left, but not vice versa, and that all axes except the
    # extra ones on the left and the rightmost one (the axis to take
    # along) are the same shape.
    if ind.ndim > a.ndim:
        a = a.reshape((1,) * (ind.ndim - a.ndim) + a.shape)
    common_shape = tuple(np.maximum(a.shape[-ind.ndim : -1], ind.shape[:-1]))
    a_extra_shape = a.shape[: -ind.ndim]
    a = broadcast_to(a, a_extra_shape + common_shape + a.shape[-1:])
    ind = broadcast_to(ind, common_shape + ind.shape[-1:])

    final_shape = a_extra_shape + ind.shape

    if 0 in ind.shape:
        # Nothing to pick: either the index axis or one of the common axes
        # is empty. The code below would divide by zero (empty index axis) or
        # try to stack an empty list (empty common axis), whereas the NumPy
        # branch above simply returns an empty array; do the same here.
        return da.zeros(tuple(int(n) for n in final_shape), dtype=a.dtype)

    # Flatten all common axes onto axis -2
    ind = ind.reshape(ind.size // ind.shape[-1], ind.shape[-1])
    a = a.reshape(*a_extra_shape, ind.shape[0], a.shape[-1])

    # Now we have a[..., i, j] and ind[i, j], where i are the flattened
    # common axes and j is the axis to take along.
    res = []

    # Cycle a and ind along i, perform 1D slices, and then stack them back
    # together
    for i in range(ind.shape[0]):
        a_i = a[..., i, :]
        ind_i = ind[i, :]

        # The per-row selection below is expressed on dask arrays, so wrap a
        # non-dask slice of a into a single-chunk dask array first.
        if not isinstance(a_i, da.Array):
            a_i = da.from_array(a_i, chunks=a_i.shape)

        if isinstance(ind_i, da.Array):
            # A dask index cannot be used with plain fancy indexing; go
            # through the dedicated helper instead.
            res_i = slice_with_int_dask_array_on_axis(a_i, ind_i, axis=a_i.ndim - 1)
        else:
            res_i = a_i[..., ind_i]
        res.append(res_i)

    # Each res_i has shape a_extra_shape + (ind.shape[-1],); stacking on axis
    # -2 restores the flattened common axis i right before the take axis.
    res_arr = da.stack(res, axis=-2)
    # Un-flatten axis i
    res_arr = res_arr.reshape(*final_shape)
    return res_arr