Пример #1
0
 def _ReplaceSlicesWithTuples(self, idx):
   """Helper method to replace slices with tuples for dynamic indexing args."""
   if isinstance(idx, slice):
     triple = idx.start, idx.stop, idx.step
     isnone = [i for i, elt in enumerate(triple) if elt is None]
     zeros = itertools.repeat(0)
     nones = itertools.repeat(None)
     out = util.subvals(triple, zip(isnone, zeros))
     return out, lambda out: slice(*util.subvals(out, zip(isnone, nones)))
   elif isinstance(idx, (tuple, list)) and idx:
     t = type(idx)
     elts, packs = zip(*map(self._ReplaceSlicesWithTuples, idx))
     return elts, lambda elts: t((pack(i) for pack, i in zip(packs, elts)))
   else:
     return idx, lambda x: x
Пример #2
0
 def f_aug(*args):
     """Evaluate ``jaxpr`` and widen its residual outputs to the full list.

     The first ``num_non_res_outputs`` results of ``jaxpr`` are kept as-is.
     The remaining results (this jaxpr's own residuals) are placed at
     positions ``res_indices`` within a list of placeholders built by
     ``ad_util.zeros_like_aval`` from ``all_res_avals``. All of these names
     come from the enclosing scope.

     Returns:
       A list: the non-residual outputs followed by the full residual list.
     """
     results = core.jaxpr_as_fun(jaxpr)(*args)
     primal_outs, own_residuals = split_list(results, [num_non_res_outputs])
     # Start with a placeholder for every residual slot, then overwrite the
     # slots this jaxpr actually produces.
     placeholders = _map(ad_util.zeros_like_aval, all_res_avals)
     padded = list(util.subvals(placeholders, zip(res_indices, own_residuals)))
     return primal_outs + padded
Пример #3
0
    def augment_jaxpr(jaxpr, res_indices):
        """Rebuild ``jaxpr`` so that it takes the full residual list as inputs.

        The leading ``len(res_indices)`` invars of ``jaxpr`` are treated as
        its residual inputs. They are placed at positions ``res_indices``
        within ``all_res_vars`` (from the enclosing scope). That merged list
        is followed by the remaining invars of ``jaxpr``.

        Args:
          jaxpr: closed jaxpr whose leading invars are residuals.
          res_indices: for each leading residual invar, its position within
            ``all_res_vars``.

        Returns:
          A new ``core.ClosedJaxpr`` with the same constvars, outvars, eqns,
          effects and consts as ``jaxpr``, but with the augmented invars.
        """
        inner = jaxpr.jaxpr
        split_at = len(res_indices)
        own_res_vars = inner.invars[:split_at]
        other_vars = inner.invars[split_at:]

        merged_res_vars = list(
            util.subvals(all_res_vars, zip(res_indices, own_res_vars)))
        rebuilt = core.Jaxpr(
            inner.constvars,
            merged_res_vars + other_vars,
            inner.outvars,
            inner.eqns,
            inner.effects,
        )
        return core.ClosedJaxpr(rebuilt, jaxpr.consts)
Пример #4
0
 def np_fun(x, indexer_with_dummies):
   """NumPy reference: index ``x`` with the indexer rebuilt from dummies.

   ``substitutes`` (from the enclosing scope) are written back into
   ``indexer_with_dummies`` via ``util.subvals``. The result is converted to
   the same container type as ``indexer`` and used to index ``np.asarray(x)``.
   """
   container = type(indexer)
   full_indexer = container(util.subvals(indexer_with_dummies, substitutes))
   array = np.asarray(x)
   return array[full_indexer]