def _ppermute_batcher(vals_in, dims_in, axis_size, axis_name, perm):
  """Batching (vmap) rule for ``ppermute``.

  The permutation along the batched axis is realised as a gather
  (``lax_numpy.take``) along each operand's batch dimension.

  Args:
    vals_in: operands of the primitive.
    dims_in: batch dimension of each operand, or ``batching.not_mapped``.
    axis_size: size of the batched axis; must equal ``len(perm)``.
    axis_name: name of the batched axis (not used by this rule).
    perm: sequence of ``(source, destination)`` index pairs.

  Returns:
    ``(vals_out, dims_in)``: permuted operands; batch dims are unchanged.
  """
  assert len(perm) == axis_size, "Permutation doesn't match the axis size!"
  # take() gathers out[i] = v[perm_indices[i]]. A (src, dst) pair moves the
  # slice at ``src`` to ``dst``, so the gather index stored at ``dst`` is
  # ``src``. Storing ``dst`` at ``src`` would apply the inverse permutation.
  perm_indices = np.full((axis_size,), -1, dtype=np.int32)
  for src, dst in perm:
    perm_indices[dst] = src
  # A leftover -1 would silently gather the last element; reject it.
  assert (perm_indices >= 0).all(), (
      "perm must send exactly one source to every destination")
  # Unmapped operands do not vary along the batched axis, so permuting that
  # axis leaves them unchanged.
  vals_out = [v if bdim is batching.not_mapped
              else lax_numpy.take(v, perm_indices, bdim)
              for v, bdim in zip(vals_in, dims_in)]
  return vals_out, dims_in
def _ppermute_batcher(frame, vals_in, dims_in, axis_name, perm): assert len(perm) == frame.size, "Permutation doesn't match the axis size!" assert axis_name == frame.name, "ppermute batcher called with wrong axis name" (v, ), (d, ) = vals_in, dims_in assert d is not batching.not_mapped perm_indices = [None] * frame.size for src, dst in perm: perm_indices[src] = dst return lax_numpy.take(v, perm_indices, d), d