def _ReplaceSlicesWithTuples(self, idx):
    """Replace slices inside an indexing argument and return a way to undo it.

    Returns a pair ``(converted, restore)``:

    * For a ``slice``: ``converted`` is the ``(start, stop, step)`` triple
      passed through ``lax.subvals`` with every ``None`` field replaced by 0;
      ``restore`` takes such a triple, puts ``None`` back at those same
      positions and builds a ``slice`` from it.
    * For a non-empty tuple or list: each element is converted recursively;
      ``converted`` is a tuple of the converted elements and ``restore``
      rebuilds a container of the original type from restored elements.
    * Anything else (including an empty tuple/list) is returned unchanged
      together with an identity function.
    """
    if isinstance(idx, slice):
        fields = (idx.start, idx.stop, idx.step)
        # Remember which fields were None so the restorer can reinstate them.
        none_positions = [pos for pos, field in enumerate(fields)
                          if field is None]
        filled = lax.subvals(fields, zip(none_positions, itertools.repeat(0)))

        def restore_slice(out):
            with_nones = lax.subvals(
                out, zip(none_positions, itertools.repeat(None)))
            return slice(*with_nones)

        return filled, restore_slice

    if isinstance(idx, (tuple, list)) and idx:
        container = type(idx)
        converted, restorers = zip(*map(self._ReplaceSlicesWithTuples, idx))

        def restore_sequence(elts):
            return container(
                restore(elt) for restore, elt in zip(restorers, elts))

        return converted, restore_sequence

    return idx, lambda x: x
def split(ary, indices_or_sections, axis=0):
    """Split `ary` along `axis`, partitioned the way `onp.split` would.

    Args:
      ary: array to split; its `.shape` attribute is read directly.
      indices_or_sections: forwarded unchanged to `onp.split`.
      axis: axis along which to split (default 0).

    Returns:
      A list of `lax.slice` results, one per piece, in order along `axis`.
    """
    # Let onp.split work out the partitioning on a broadcast scalar stand-in
    # (zero strides); only the resulting shapes are used.
    placeholder = onp.broadcast_to(0, ary.shape)
    pieces = onp.split(placeholder, indices_or_sections, axis)
    lengths = [onp.shape(piece)[axis] for piece in pieces]
    # Running total of piece lengths gives each piece's bounds along `axis`.
    boundaries = onp.cumsum([0] + lengths)

    base_starts = [0] * ndim(ary)
    base_ends = shape(ary)

    def with_axis(bounds, value):
        # Copy of `bounds` with only the split axis entry replaced.
        return lax.subvals(bounds, [(axis, value)])

    return [
        lax.slice(ary, with_axis(base_starts, lo), with_axis(base_ends, hi))
        for lo, hi in zip(boundaries[:-1], boundaries[1:])
    ]
def reduction(a, axis=None, dtype=None, out=None, keepdims=False):
    """Reduce `a` over `axis` with `lax.reduce`.

    NOTE(review): `np_fun`, `op` and `init_val` are not defined in this block;
    presumably they come from an enclosing scope -- confirm against the caller.

    Args:
      a: input; converted with `asarray` unless it is already an `ndarray`.
      axis: forwarded to `_reduction_dims` to obtain the dimensions to reduce.
      dtype: optional dtype for the result. When given and different from the
        dtype the reduction was computed in, the result is converted to it.
      out: unsupported; must be None.
      keepdims: if true, the result is reshaped so every reduced dimension is
        kept with size 1.

    Returns:
      The reduced array.

    Raises:
      ValueError: if `out` is not None.
    """
    if out is not None:
        raise ValueError("reduction does not support `out` argument.")

    a = a if isinstance(a, ndarray) else asarray(a)
    dims = _reduction_dims(a, axis)
    # Apply np_fun to a scalar one of a's dtype to learn which dtype it
    # produces, and compute the reduction in that dtype.
    result_dtype = _dtype(np_fun(onp.ones((), dtype=_dtype(a))))
    if _dtype(a) != result_dtype:
        a = lax.convert_element_type(a, result_dtype)
    result = lax.reduce(a, _reduction_init_val(a, init_val), op, dims)
    if keepdims:
        shape_with_singletons = lax.subvals(
            shape(a), zip(dims, (1,) * len(dims)))
        result = lax.reshape(result, shape_with_singletons)
    # Compare against None explicitly: the truthiness of a dtype-like object
    # is not a reliable signal that the caller supplied one.
    if dtype is not None and onp.dtype(dtype) != onp.dtype(result_dtype):
        result = lax.convert_element_type(result, dtype)
    return result
def fun(x, indexer_with_dummies):
    """Index `x` after applying `substitutes` to `indexer_with_dummies`.

    The substituted indexer is rebuilt as the same type as `indexer`.
    NOTE(review): `indexer` and `substitutes` are not defined in this block;
    presumably they come from an enclosing scope -- confirm against the caller.
    """
    container = type(indexer)
    patched = lax.subvals(indexer_with_dummies, substitutes)
    return x[container(patched)]