예제 #1
0
 def _ReplaceSlicesWithTuples(self, idx):
   """Helper method to replace slices with tuples for dynamic indexing args."""
   if isinstance(idx, slice):
     triple = idx.start, idx.stop, idx.step
     isnone = [i for i, elt in enumerate(triple) if elt is None]
     zeros = itertools.repeat(0)
     nones = itertools.repeat(None)
     out = lax.subvals(triple, zip(isnone, zeros))
     return out, lambda out: slice(*lax.subvals(out, zip(isnone, nones)))
   elif isinstance(idx, (tuple, list)) and idx:
     t = type(idx)
     elts, packs = zip(*map(self._ReplaceSlicesWithTuples, idx))
     return elts, lambda elts: t((pack(i) for pack, i in zip(packs, elts)))
   else:
     return idx, lambda x: x
예제 #2
0
def split(ary, indices_or_sections, axis=0):
  """Split `ary` into a list of sub-arrays along `axis`.

  The split points are not computed by hand: `onp.split` is run on a
  placeholder with the same shape as `ary`, and the sizes of the resulting
  pieces along `axis` give the boundaries. Each piece of `ary` is then taken
  with `lax.slice`.

  Args:
    ary: array to split; its `.shape` is read to build the placeholder.
    indices_or_sections: forwarded unchanged to `onp.split`.
    axis: axis along which to split.

  Returns:
    A list with one `lax.slice` result per piece produced by `onp.split`.
  """
  # Only the shapes of the pieces are used, so a broadcast scalar suffices.
  placeholder = onp.broadcast_to(0, ary.shape)
  pieces = onp.split(placeholder, indices_or_sections, axis)
  sizes = [onp.shape(piece)[axis] for piece in pieces]
  # Running total of the piece sizes: consecutive entries bound one piece.
  boundaries = onp.cumsum([0] + sizes)
  lower = [0] * ndim(ary)
  upper = shape(ary)

  def _with_axis(bounds, value):
    # Copy of `bounds` with the entry for `axis` replaced by `value`.
    return lax.subvals(bounds, [(axis, value)])

  result = []
  for begin, stop in zip(boundaries[:-1], boundaries[1:]):
    result.append(
        lax.slice(ary, _with_axis(lower, begin), _with_axis(upper, stop)))
  return result
예제 #3
0
  def reduction(a, axis=None, dtype=None, out=None, keepdims=False):
    """Reduce `a` over the dimensions selected by `axis`.

    `np_fun`, `op` and `init_val` are not defined here; they are read from the
    enclosing scope.

    Args:
      a: input; passed through `asarray` unless it is already an `ndarray`.
      axis: forwarded to `_reduction_dims` to pick the dimensions to reduce.
      dtype: optional dtype the result is converted to when it differs from
        the dtype the reduction produced.
      out: unsupported; must be None.
      keepdims: if true, the result is reshaped so that every reduced
        dimension is kept with size 1.

    Returns:
      The result of `lax.reduce`, optionally reshaped and converted.

    Raises:
      ValueError: if `out` is not None.
    """
    if out is not None:
      raise ValueError("reduction does not support `out` argument.")

    a = a if isinstance(a, ndarray) else asarray(a)
    dims = _reduction_dims(a, axis)
    # Apply np_fun to a scalar of a's dtype to learn the dtype it produces.
    result_dtype = _dtype(np_fun(onp.ones((), dtype=_dtype(a))))
    if _dtype(a) != result_dtype:
      a = lax.convert_element_type(a, result_dtype)
    result = lax.reduce(a, _reduction_init_val(a, init_val), op, dims)
    if keepdims:
      shape_with_singletons = lax.subvals(shape(a), zip(dims, (1,) * len(dims)))
      result = lax.reshape(result, shape_with_singletons)
    # Compare against None explicitly: a non-structured `onp.dtype` instance
    # has length 0 and is therefore falsy, so a plain truthiness test would
    # silently ignore e.g. `dtype=onp.dtype('float32')`.
    if dtype is not None and onp.dtype(dtype) != onp.dtype(result_dtype):
      result = lax.convert_element_type(result, dtype)
    return result
예제 #4
0
 def fun(x, indexer_with_dummies):
   """Index `x` after substituting values into `indexer_with_dummies`.

   `indexer` and `substitutes` are read from the enclosing scope. The
   substituted sequence is converted to the same type as `indexer` before it
   is used to index `x`.
   """
   container = type(indexer)
   filled = lax.subvals(indexer_with_dummies, substitutes)
   return x[container(filled)]