def _ReplaceSlicesWithTuples(self, idx):
  """Swap slices inside an index expression for plain (start, stop, step) tuples.

  Used to turn indexing args into something that can be passed dynamically.

  Returns a pair ``(converted, restore)``:
    * ``converted`` mirrors ``idx`` except that every slice becomes a 3-tuple
      ``(start, stop, step)`` whose ``None`` entries are replaced by 0.
    * ``restore`` maps a value with the converted structure back to an index
      with the original structure. It re-inserts ``None`` at the positions
      where the original slice had ``None``.

  Non-empty tuples/lists are handled element-wise and recursively. In that
  case ``converted`` is a tuple and ``restore`` rebuilds the original
  container type. Any other value (including empty tuples/lists) is returned
  unchanged with an identity restorer.
  """
  if isinstance(idx, slice):
    bounds = (idx.start, idx.stop, idx.step)
    # Remember which slice fields were omitted so restore can put None back.
    missing = [pos for pos, bound in enumerate(bounds) if bound is None]
    filled = util.subvals(bounds, [(pos, 0) for pos in missing])

    def restore_slice(vals):
      return slice(*util.subvals(vals, [(pos, None) for pos in missing]))

    return filled, restore_slice

  if not (isinstance(idx, (tuple, list)) and idx):
    return idx, lambda x: x

  container = type(idx)
  converted = [self._ReplaceSlicesWithTuples(item) for item in idx]
  elts = tuple(pair[0] for pair in converted)
  restorers = tuple(pair[1] for pair in converted)

  def restore_seq(vals):
    # Apply each element's restorer positionally and rebuild the original
    # container type (tuple or list).
    return container(fn(v) for fn, v in zip(restorers, vals))

  return elts, restore_seq
def f_aug(*args):
  """Evaluate ``jaxpr`` and pad its residual outputs out to the full residual set.

  The outputs of ``jaxpr`` are split after the first ``num_non_res_outputs``
  entries. The leading part is returned as-is. The trailing residuals are
  written at positions ``res_indices`` into a list of zero placeholders
  (``ad_util.zeros_like_aval``), one per aval in ``all_res_avals``.

  ``jaxpr``, ``num_non_res_outputs``, ``all_res_avals`` and ``res_indices``
  are free variables taken from the enclosing scope.
  """
  results = core.jaxpr_as_fun(jaxpr)(*args)
  primal_outs, produced_res = split_list(results, [num_non_res_outputs])
  padded = [ad_util.zeros_like_aval(aval) for aval in all_res_avals]
  padded = util.subvals(padded, zip(res_indices, produced_res))
  return primal_outs + list(padded)
def augment_jaxpr(jaxpr, res_indices):
  """Return a new ClosedJaxpr whose residual inputs are widened to ``all_res_vars``.

  The first ``len(res_indices)`` input variables of ``jaxpr.jaxpr`` are treated
  as residuals. They are written into a copy of the enclosing ``all_res_vars``
  at the positions given by ``res_indices``, and the remaining non-residual
  inputs follow that widened list. Constvars, outvars, equations, effects and
  consts are passed through unchanged.

  ``all_res_vars`` is a free variable taken from the enclosing scope.
  """
  inner = jaxpr.jaxpr
  n_res = len(res_indices)
  res_in = inner.invars[:n_res]
  other_in = inner.invars[n_res:]
  widened = list(util.subvals(all_res_vars, zip(res_indices, res_in)))
  rebuilt = core.Jaxpr(inner.constvars, widened + other_in, inner.outvars,
                       inner.eqns, inner.effects)
  return core.ClosedJaxpr(rebuilt, jaxpr.consts)
def np_fun(x, indexer_with_dummies):
  """Index ``np.asarray(x)`` with the real indexer recovered from its dummy form.

  ``util.subvals`` applies the enclosing ``substitutes`` (presumably
  ``(position, value)`` pairs; confirm against the caller) to
  ``indexer_with_dummies``. The result is rebuilt with the same container
  type as the enclosing ``indexer`` and then used to index ``x`` as a NumPy
  array.
  """
  restored = util.subvals(indexer_with_dummies, substitutes)
  index = type(indexer)(restored)
  arr = np.asarray(x)
  return arr[index]