def body_fun(i, vals):
  """Runs one scan step using the closed-over ``bs``, ``jaxpr``, ``consts`` and ``fields``.

  Those four names (and ``core``) come from an enclosing scope not visible here.

  Args:
    i: the current loop index.
    vals: a ``(carry, state)`` pair; ``state`` is a sequence of buffers whose
      row ``i`` (along axis 0) is overwritten.

  Returns:
    ``(a_out, state_out)``: the result of evaluating ``jaxpr`` and the list of
    updated buffers.
  """
  carry, buffers = vals
  # Take element i along axis 0 (dropping that axis) of every array in `bs`.
  sliced = [lax.dynamic_index_in_dim(x, i, keepdims=False) for x in bs]
  a_out = core.eval_jaxpr(jaxpr, consts, (), carry, core.pack(sliced))
  # Pick the requested output fields, then write each one into row i of the
  # buffer it is paired with (val[None, ...] adds the leading length-1 axis).
  picked = [tuple(a_out)[j] for j in fields]
  state_out = []
  for val, buf in zip(picked, buffers):
    state_out.append(
        lax.dynamic_update_index_in_dim(buf, val[None, ...], i, axis=0))
  return a_out, state_out
def body_fun(i, state):
  """Loop body (presumably for ``lax.fori_loop``) that adds ``arr[i]`` to a total.

  Args:
    i: the current loop index.
    state: an ``(arr, total, _)`` triple; the third slot is ignored.

  Returns:
    ``(arr, total + arr[i], ())``, with ``arr`` passed through unchanged.
  """
  values, running, _ignored = state
  element = lax.dynamic_index_in_dim(values, i, axis=0, keepdims=False)
  return values, lax.add(running, element), ()
def body_fun(i, state):
  """Loop body (presumably for ``lax.fori_loop``) over a dict-shaped state.

  Args:
    i: the current loop index.
    state: a mapping with keys ``'arr'`` and ``'total'``.

  Returns:
    A new dict with ``'arr'`` passed through unchanged and ``'total'`` set to
    the old total plus ``arr[i]``.
  """
  values = state['arr']
  running = state['total']
  element = lax.dynamic_index_in_dim(values, i, axis=0, keepdims=False)
  new_total = lax.add(running, element)
  return {'arr': values, 'total': new_total}
def body_fun(state):
  """Loop body (presumably for ``lax.while_loop``) accumulating ``arr[i]``.

  Args:
    state: an ``(arr, num, i, total)`` tuple; ``arr`` and ``num`` are passed
      through unchanged.

  Returns:
    ``(arr, num, i + 1, total + arr[i])``.
  """
  values, count, idx, acc = state
  picked = lax.dynamic_index_in_dim(values, idx, axis=0, keepdims=False)
  next_idx = lax.add(idx, 1)
  next_acc = lax.add(acc, picked)
  return values, count, next_idx, next_acc
def body_fun(state):
  """Loop body (presumably for an inner ``lax.while_loop``) over ``arr[i, j]``.

  Extracts ``arr[i][j]`` and hands it to ``update_entry`` (defined elsewhere;
  presumably it stores the value at position ``(i, j)`` of ``out``).

  Args:
    state: an ``(i, j, arr, out)`` tuple.

  Returns:
    ``(i, j + 1, arr, new_out)``, where ``new_out`` is ``update_entry``'s result.
  """
  row, col, src, dst = state
  src_row = lax.dynamic_index_in_dim(src, row, axis=0, keepdims=False)
  elem = lax.dynamic_index_in_dim(src_row, col, axis=0, keepdims=False)
  dst = update_entry(dst, elem, row, col)
  return row, lax.add(col, 1), src, dst
def _rewriting_take(arr, idx, axis=0):
  """A function like numpy.take that handles boxes and rewrites to LAX.

  Translates a NumPy-style indexer ``idx`` into LAX primitives. The forms are
  tried in this order:

  * ``()``, ``Ellipsis`` or ``slice(None)``: ``arr`` is returned unchanged.
  * ``None``: a new leading axis is inserted via ``expand_dims(arr, 0)``.
  * a scalar integer (static or traced): indexes along ``axis`` and drops it.
  * a ``slice`` whose start/stop/step are all static (or ``None``).
  * a tuple whose elements all have ``onp.ndim(elt) == 0``: handled by
    recursing element by element, then reshaping to restore ``None`` axes.
  * integer-array ("advanced") indexing, with or without basic indices mixed in.

  Args:
    arr: the array being indexed.
    idx: the indexer, as passed to ``__getitem__``.
    axis: the axis of ``arr`` that a scalar-integer or slice ``idx`` applies
      to. The tuple and advanced-indexing branches always start from axis 0.

  Returns:
    The indexed array.

  Raises:
    IndexError: if a slice has non-static start/stop/step, or the indexer form
      is not supported.
  """
  # Handle special indexers: (), Ellipsis, slice(None), and None.
  # The empty tuple is detected by type and length, not identity: `idx is ()`
  # depends on CPython caching the empty tuple and warns on Python 3.8+.
  is_empty_tuple = isinstance(idx, tuple) and not idx
  if is_empty_tuple or idx is Ellipsis or _is_slice_none(idx):
    return arr
  elif idx is None:
    return expand_dims(arr, 0)

  # Handle int index
  def _int(aval):
    # True for a rank-0 abstract value with an integer dtype.
    return not aval.shape and onp.issubdtype(aval.dtype, onp.integer)

  try:
    abstract_idx = core.get_aval(idx)
  except TypeError:
    # Presumably raised for indexers that have no abstract value (slices,
    # tuples, lists, ...); those are handled by the branches below.
    abstract_idx = None

  # ConcreteArray is tested first; presumably it is a subclass of ShapedArray,
  # so the order distinguishes static from traced integers.
  if isinstance(abstract_idx, ConcreteArray) and _int(abstract_idx):
    return lax.index_in_dim(arr, idx, axis, False)
  elif isinstance(abstract_idx, ShapedArray) and _int(abstract_idx):
    # Traced integer: reduce modulo the axis size (with NumPy-style `mod`,
    # negative indices wrap to the end) and index dynamically.
    idx = mod(idx, arr.shape[axis])
    return lax.dynamic_index_in_dim(arr, idx, axis, False)

  # Handle slice index (only static, otherwise an error is raised)
  elif isinstance(idx, slice):
    if not _all(elt is None or isinstance(core.get_aval(elt), ConcreteArray)
                for elt in (idx.start, idx.stop, idx.step)):
      msg = ("Array slice indices must have static start/stop/step to be used "
             "with Numpy indexing syntax. Try lax.dynamic_slice instead.")
      raise IndexError(msg)
    else:
      start, limit, stride, needs_rev = _static_idx(idx, arr.shape[axis])
      result = lax.slice_in_dim(arr, start, limit, stride, axis=axis)
      # `_static_idx` signals via `needs_rev` that the forward slice must be
      # reversed (presumably for a negative step).
      return lax.rev(result, [axis]) if needs_rev else result

  # Handle non-advanced tuple indices by recursing once
  elif isinstance(idx, tuple) and _all(onp.ndim(elt) == 0 for elt in idx):
    canonical_idx = _canonicalize_tuple_index(arr, idx)
    result, axis = arr, 0
    # Apply every non-None element in turn; the axes that None would insert
    # are restored by the reshape below.
    for elt in (elt for elt in canonical_idx if elt is not None):
      result = _rewriting_take(result, elt, axis=axis)
      axis += isinstance(elt, slice)  # advance axis index if not eliminated
    unexpanded_shape_itr = iter(result.shape)
    result_shape = tuple(1 if elt is None else next(unexpanded_shape_itr)
                         for elt in canonical_idx if not isinstance(elt, int))
    return lax.reshape(result, result_shape)

  # Handle advanced indexing (non-tuple sequence, ndarray of dtype int or bool,
  # or a tuple with at least one sequence object).
  # https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html#advanced-indexing
  # https://gist.github.com/seberg/976373b6a2b7c4188591

  # Handle integer array indexing *without* ellipsis/slices/nones
  # https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html#integer-array-indexing
  if _is_advanced_int_indexer_without_slices(idx):
    if isinstance(idx, list):
      if _any(_shape(e) for e in idx):
        # At least one sequence element in the index list means broadcasting.
        idx = broadcast_arrays(*idx)
      else:
        # The index list is a flat list of integers.
        idx = [lax.concatenate([lax.reshape(e, (1,)) for e in idx], 0)]
    else:
      # The indexer is just a single integer array.
      idx = [idx]

    # Index array k selects along axis k of `arr`: flatten each one and reduce
    # it modulo that axis size before gathering.
    flat_idx = tuple(mod(ravel(x), arr.shape[i]) for i, x in enumerate(idx))
    out = lax.index_take(arr, flat_idx, tuple(range(len(idx))))
    # Restore the index shape, followed by the axes that were not indexed.
    return lax.reshape(out, idx[0].shape + _shape(arr)[len(idx):])

  # Handle integer array indexing *with* ellipsis/slices/nones by recursing once
  # https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html#combining-advanced-and-basic-indexing
  elif _is_advanced_int_indexer(idx):
    canonical_idx = _canonicalize_tuple_index(arr, tuple(idx))
    # First apply only the basic indices, keeping every advanced axis whole.
    idx_noadvanced = [slice(None) if _is_int(e) else e for e in canonical_idx]
    arr_sliced = _rewriting_take(arr, tuple(idx_noadvanced))

    # Then gather along the advanced axes using broadcast index arrays.
    advanced_pairs = ((e, i) for i, e in enumerate(canonical_idx) if _is_int(e))
    idx_advanced, axes = zip(*advanced_pairs)
    idx_advanced = broadcast_arrays(*idx_advanced)

    flat_idx = tuple(mod(ravel(x), arr_sliced.shape[i])
                     for i, x in zip(axes, idx_advanced))
    out = lax.index_take(arr_sliced, flat_idx, axes)
    shape_suffix = tuple(onp.delete(_shape(arr_sliced), axes))
    out = lax.reshape(out, idx_advanced[0].shape + shape_suffix)

    # NumPy keeps the broadcast index dimensions at the position of the first
    # advanced index when the advanced indices are adjacent; otherwise they
    # stay at the front, which is where the gather above leaves them.
    axes_are_contiguous = onp.all(onp.diff(axes) == 1)
    if axes_are_contiguous:
      start = axes[0]
      naxes = idx_advanced[0].ndim
      out = moveaxis(out, list(range(naxes)), list(range(start, start + naxes)))
    return out

  msg = "Indexing mode not yet supported. Open a feature request!\n{}"
  raise IndexError(msg.format(idx))
def pop(self) -> Tuple[Any, Stack]:
  """Removes the top element of the stack.

  Nothing is mutated: the returned ``Stack`` shares ``self._data`` and only
  has its size reduced by one.

  Returns:
    An ``(elem, stack)`` pair. ``elem`` has the pytree structure of
    ``self._data``, and each leaf is indexed at ``self._size - 1`` along axis 0
    (that axis dropped). ``stack`` is ``Stack(self._size - 1, self._data)``.
  """
  top = self._size - 1

  def _take_top(leaf):
    return lax.dynamic_index_in_dim(leaf, top, 0, keepdims=False)

  elem = jax.tree_util.tree_map(_take_top, self._data)
  return elem, Stack(top, self._data)