예제 #1
0
 def _ReplaceSlicesWithTuples(self, idx):
   """Split ``idx`` into slice-free leaves plus a function that rebuilds it.

   * A ``slice`` becomes its ``(start, stop, step)`` triple. Every field that
     is ``None`` is substituted with ``0`` via ``util.subvals``. The returned
     rebuild function substitutes ``None`` back at those same positions and
     calls ``slice(...)`` on the result.
   * A non-empty ``tuple`` or ``list`` is processed element-wise (recursively).
     Its leaves are returned as a tuple, and the rebuild function re-creates
     the original container type.
   * Anything else (including empty tuples/lists) is returned unchanged,
     together with an identity rebuild function.

   Returns:
     A pair ``(leaves, rebuild)``. ``rebuild`` accepts a leaves structure of
     the same shape and returns an index shaped like ``idx``.
   """
   if isinstance(idx, slice):
     fields = (idx.start, idx.stop, idx.step)
     # Remember which fields were None so the rebuild step can restore them.
     none_positions = [pos for pos, field in enumerate(fields) if field is None]
     placeholder = util.subvals(fields, ((pos, 0) for pos in none_positions))

     def rebuild_slice(values):
       restored = util.subvals(values, ((pos, None) for pos in none_positions))
       return slice(*restored)

     return placeholder, rebuild_slice

   if isinstance(idx, (tuple, list)) and idx:
     container = type(idx)
     # Recurse into every element first, then split results into two tuples.
     pairs = [self._ReplaceSlicesWithTuples(item) for item in idx]
     leaves = tuple(leaf for leaf, _ in pairs)
     rebuilders = tuple(fn for _, fn in pairs)

     def rebuild_container(values):
       return container(fn(value) for fn, value in zip(rebuilders, values))

     return leaves, rebuild_container

   return idx, lambda x: x
예제 #2
0
 def fun(x, indexer_with_dummies):
   """Index ``jnp.asarray(x)`` with a reconstructed copy of the indexer.

   ``util.subvals`` applies the enclosing scope's ``substitutes`` to
   ``indexer_with_dummies``. The result is wrapped in the same container type
   as the enclosing ``indexer`` and used to index ``jnp.asarray(x)``.

   NOTE(review): ``indexer``, ``substitutes``, ``util`` and ``jnp`` are free
   names bound outside this view. Presumably ``substitutes`` holds
   ``(position, value)`` pairs for the entries that were replaced by dummies;
   confirm against the enclosing test.
   """
   rebuild = type(indexer)
   full_index = rebuild(util.subvals(indexer_with_dummies, substitutes))
   array = jnp.asarray(x)
   return array[full_index]