Exemple #1
0
 def body_fun(i, vals):
     """Run one loop step and record selected outputs at position ``i``.

     ``vals`` is a pair ``(carry, buffers)``.  The i-th slice (along axis 0)
     of every array in the enclosing ``bs`` is packed and passed, together
     with the carry, to the enclosing ``jaxpr``.  The outputs named by the
     enclosing ``fields`` are then written into row ``i`` of the matching
     buffer.

     Returns:
       A pair ``(a_out, state_out)``: the raw jaxpr output and the list of
       updated buffers.
     """
     carry, buffers = vals

     # Slice out the i-th entry (leading axis dropped) of each input array.
     step_inputs = [
         lax.dynamic_index_in_dim(operand, i, keepdims=False)
         for operand in bs
     ]
     a_out = core.eval_jaxpr(jaxpr, consts, (), carry, core.pack(step_inputs))

     # Pick the requested output fields and store each one as row i of its
     # buffer; ``[None, ...]`` restores the leading axis the update expects.
     out_fields = tuple(a_out)
     selected = [out_fields[j] for j in fields]
     state_out = [
         lax.dynamic_update_index_in_dim(buf, picked[None, ...], i, axis=0)
         for picked, buf in zip(selected, buffers)
     ]
     return a_out, state_out
Exemple #2
0
 def body_fun(i, state):
     """Add the i-th element of the carried array to the running total.

     ``state`` is a 3-tuple ``(arr, total, _)``; the third entry is ignored
     and replaced by an empty tuple in the result.

     Returns:
       ``(arr, total + arr[i], ())`` with ``arr`` passed through unchanged.
     """
     values, running_sum, _ = state
     # Index along axis 0 and drop that axis so the element can be added.
     element = lax.dynamic_index_in_dim(values, i, axis=0, keepdims=False)
     new_sum = lax.add(running_sum, element)
     return (values, new_sum, ())
Exemple #3
0
 def body_fun(i, state):
     """Add the i-th element of ``state['arr']`` to ``state['total']``.

     Returns:
       A new dict with the same ``'arr'`` and the updated ``'total'``.
     """
     values = state['arr']
     running_sum = state['total']
     # Index along axis 0 and drop that axis so the element can be added.
     element = lax.dynamic_index_in_dim(values, i, axis=0, keepdims=False)
     new_sum = lax.add(running_sum, element)
     return {'arr': values, 'total': new_sum}
Exemple #4
0
 def body_fun(state):
     """Accumulate ``arr[i]`` into the total and advance the counter ``i``.

     ``state`` is a 4-tuple ``(arr, num, i, total)``.

     Returns:
       ``(arr, num, i + 1, total + arr[i])``; ``arr`` and ``num`` pass
       through unchanged.
     """
     values, count, position, running_sum = state
     # Index along axis 0 and drop that axis so the element can be added.
     element = lax.dynamic_index_in_dim(values, position, axis=0,
                                        keepdims=False)
     next_position = lax.add(position, 1)
     new_sum = lax.add(running_sum, element)
     return (values, count, next_position, new_sum)
Exemple #5
0
 def body_fun(state):
     """Process element ``arr[i][j]`` and advance the inner counter ``j``.

     ``state`` is a 4-tuple ``(i, j, arr, out)``.  The element is selected
     with two successive axis-0 lookups and handed to ``update_entry``
     together with ``out`` and both indices.

     Returns:
       ``(i, j + 1, arr, updated_out)``; ``i`` and ``arr`` pass through.
     """
     row, col, grid, result = state
     # First lookup selects row ``row``; the second selects ``col`` within it.
     row_slice = lax.dynamic_index_in_dim(grid, row, axis=0, keepdims=False)
     entry = lax.dynamic_index_in_dim(row_slice, col, axis=0, keepdims=False)
     # NOTE(review): update_entry is defined outside this snippet; presumably
     # it writes ``entry`` into ``result`` at (row, col) -- confirm.
     updated = update_entry(result, entry, row, col)
     return (row, lax.add(col, 1), grid, updated)
Exemple #6
0
def _rewriting_take(arr, idx, axis=0):
    """A function like numpy.take that handles boxes and rewrites to LAX."""

    # Handle special indexers: (), Ellipsis, slice(None), and None.
    # TODO(mattjj): don't compare empty tuple identity (though works for CPython)
    if idx is () or idx is Ellipsis or _is_slice_none(idx):  # pylint: disable=literal-comparison
        return arr
    elif idx is None:
        return expand_dims(arr, 0)

    # Handle int index
    _int = lambda aval: not aval.shape and onp.issubdtype(
        aval.dtype, onp.integer)
    try:
        abstract_idx = core.get_aval(idx)
    except TypeError:
        abstract_idx = None

    if isinstance(abstract_idx, ConcreteArray) and _int(abstract_idx):
        return lax.index_in_dim(arr, idx, axis, False)
    elif isinstance(abstract_idx, ShapedArray) and _int(abstract_idx):
        idx = mod(idx, arr.shape[axis])
        return lax.dynamic_index_in_dim(arr, idx, axis, False)

    # Handle slice index (only static, otherwise an error is raised)
    elif isinstance(idx, slice):
        if not _all(
                elt is None or isinstance(core.get_aval(elt), ConcreteArray)
                for elt in (idx.start, idx.stop, idx.step)):
            msg = (
                "Array slice indices must have static start/stop/step to be used "
                "with Numpy indexing syntax. Try lax.dynamic_slice instead.")
            raise IndexError(msg)
        else:
            start, limit, stride, needs_rev = _static_idx(idx, arr.shape[axis])
            result = lax.slice_in_dim(arr, start, limit, stride, axis=axis)
            return lax.rev(result, [axis]) if needs_rev else result

    # Handle non-advanced tuple indices by recursing once
    elif isinstance(idx, tuple) and _all(onp.ndim(elt) == 0 for elt in idx):
        canonical_idx = _canonicalize_tuple_index(arr, idx)
        result, axis = arr, 0
        for elt in (elt for elt in canonical_idx if elt is not None):
            result = _rewriting_take(result, elt, axis=axis)
            axis += isinstance(elt,
                               slice)  # advance axis index if not eliminated
        unexpanded_shape_itr = iter(result.shape)
        result_shape = tuple(1 if elt is None else next(unexpanded_shape_itr)
                             for elt in canonical_idx
                             if not isinstance(elt, int))
        return lax.reshape(result, result_shape)

    # Handle advanced indexing (non-tuple sequence, ndarray of dtype int or bool,
    # or a tuple with at least one sequence object).
    # https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html#advanced-indexing
    # https://gist.github.com/seberg/976373b6a2b7c4188591

    # Handle integer array indexing *without* ellipsis/slices/nones
    # https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html#integer-array-indexing
    if _is_advanced_int_indexer_without_slices(idx):
        if isinstance(idx, list):
            if _any(_shape(e) for e in idx):
                # At least one sequence element in the index list means broadcasting.
                idx = broadcast_arrays(*idx)
            else:
                # The index list is a flat list of integers.
                idx = [
                    lax.concatenate([lax.reshape(e, (1, )) for e in idx], 0)
                ]
        else:
            # The indexer is just a single integer array.
            idx = [idx]

        flat_idx = tuple(
            mod(ravel(x), arr.shape[i]) for i, x in enumerate(idx))
        out = lax.index_take(arr, flat_idx, tuple(range(len(idx))))
        return lax.reshape(out, idx[0].shape + _shape(arr)[len(idx):])

    # Handle integer array indexing *with* ellipsis/slices/nones by recursing once
    # https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html#combining-advanced-and-basic-indexing
    elif _is_advanced_int_indexer(idx):
        canonical_idx = _canonicalize_tuple_index(arr, tuple(idx))
        idx_noadvanced = [
            slice(None) if _is_int(e) else e for e in canonical_idx
        ]
        arr_sliced = _rewriting_take(arr, tuple(idx_noadvanced))

        advanced_pairs = ((e, i) for i, e in enumerate(canonical_idx)
                          if _is_int(e))
        idx_advanced, axes = zip(*advanced_pairs)
        idx_advanced = broadcast_arrays(*idx_advanced)

        flat_idx = tuple(
            mod(ravel(x), arr_sliced.shape[i])
            for i, x in zip(axes, idx_advanced))
        out = lax.index_take(arr_sliced, flat_idx, axes)
        shape_suffix = tuple(onp.delete(_shape(arr_sliced), axes))
        out = lax.reshape(out, idx_advanced[0].shape + shape_suffix)

        axes_are_contiguous = onp.all(onp.diff(axes) == 1)
        if axes_are_contiguous:
            start = axes[0]
            naxes = idx_advanced[0].ndim
            out = moveaxis(out, list(range(naxes)),
                           list(range(start, start + naxes)))
        return out

    msg = "Indexing mode not yet supported. Open a feature request!\n{}"
    raise IndexError(msg.format(idx))
Exemple #7
0
 def pop(self) -> Tuple[Any, Stack]:
     """Pops from the stack, returning an (elem, updated stack) pair.

     The element is read from position ``self._size - 1`` along axis 0 of
     every leaf in ``self._data``.  The returned stack shares the same
     ``_data`` and only has its size reduced by one.
     """
     top = self._size - 1

     def take_top(leaf):
         # Select the top slot and drop the leading axis.
         return lax.dynamic_index_in_dim(leaf, top, 0, keepdims=False)

     elem = jax.tree_util.tree_map(take_top, self._data)
     return elem, Stack(top, self._data)