def test_vmap_after(self):
  """Compare vmapped `lax.approx_max_k` neighbours against exact `lax.top_k`.

  The approximate search is applied per slice of the trailing axis via
  `jax.vmap`, and its recall against the exact result must exceed 0.95.
  """
  batch, qy_size, db_size, feature_dim, k = 4, 128, 1024, 32, 10
  recall = 0.95

  sampler = jtu.rand_default(self.rng())
  # Queries are sampled before the database so the random stream is consumed
  # in a fixed order.
  qy = sampler([qy_size, feature_dim, batch], np.float32)
  db = sampler([db_size, feature_dim, batch], np.float32)

  def as_rows(neighbor_ids):
    # NOTE(review): if the incoming layout is already (batch, qy, k), this
    # permutation mixes rank and query positions before flattening. Both
    # sides go through the same transform, so the comparison is consistent --
    # confirm the intended layout.
    permuted = lax.transpose(neighbor_ids, [2, 0, 1])
    return lax.reshape(permuted, [qy_size * batch, k])

  # Ground truth: contract the feature axis, batch over the trailing axis.
  contract_and_batch = (([1], [1]), ([2], [2]))
  exact_scores = lax.dot_general(qy, db, contract_and_batch)
  gt_args = as_rows(lax.top_k(exact_scores, k)[1])

  # Test target.
  def approx_max_k(qy, db):
    scores = qy @ db.transpose()
    return lax.approx_max_k(scores, k)

  ann_args = as_rows(jax.vmap(approx_max_k, (2, 2))(qy, db)[1])

  ann_recall = compute_recall(np.asarray(ann_args), np.asarray(gt_args))
  self.assertGreater(ann_recall, recall)
def threefry_random_bits(key: jnp.ndarray, bit_width, shape):
  """Sample uniform random bits of given width and shape using PRNG key.

  Args:
    key: a threefry PRNG key (validated with `_is_threefry_prng_key`).
    bit_width: one of 8, 16, 32 or 64.
    shape: requested output shape; it may carry named axes, in which case the
      key is folded with the index along each named axis.

  Returns:
    An array of the unsigned dtype `UINT_DTYPES[bit_width]`, reshaped to
    `shape`.

  Raises:
    TypeError: if `key` is not a threefry key or `bit_width` is unsupported.
    ValueError: if a named axis size disagrees with the size observed via
      `lax.psum`.
  """
  if not _is_threefry_prng_key(key):
    raise TypeError("_random_bits got invalid prng key.")
  if bit_width not in (8, 16, 32, 64):
    raise TypeError("requires 8-, 16-, 32- or 64-bit field width.")

  shape = core.as_named_shape(shape)
  for name, size in shape.named_items:
    real_size = lax.psum(1, name)
    if real_size != size:
      raise ValueError(
          f"The shape of axis {name} was specified as {size}, "
          f"but it really is {real_size}")
    # Make the key distinct per position along the named axis.
    key = threefry_fold_in(key, lax.axis_index(name))

  size = prod(shape.positional)

  # Compute ceil(bit_width * size / 32) in a way that is friendly to shape
  # polymorphism.
  max_count, r = divmod(bit_width * size, 32)
  if r > 0:
    max_count += 1

  u32_max = jnp.iinfo(np.uint32).max
  if core.is_constant_dim(max_count):
    nblocks, rem = divmod(max_count, u32_max)
  else:
    nblocks, rem = 0, max_count

  if nblocks:
    # More words are needed than one uint32 counter range provides: hash
    # `nblocks` full ranges with separate subkeys, plus one partial range.
    keys = threefry_split(key, nblocks + 1)
    subkeys, last_key = keys[:-1], keys[-1]
    hash_block = vmap(threefry_2x32, in_axes=(0, None))
    blocks = hash_block(subkeys, lax.iota(np.uint32, u32_max))
    last = threefry_2x32(last_key, lax.iota(np.uint32, rem))
    bits = lax.concatenate([blocks.ravel(), last], 0)
  else:
    bits = threefry_2x32(key, lax.iota(np.uint32, rem))

  dtype = UINT_DTYPES[bit_width]
  if bit_width == 64:
    # Combine the two halves of the 32-bit stream into 64-bit words.
    high, low = (lax.convert_element_type(half, dtype)
                 for half in jnp.split(bits, 2))
    bits = lax.shift_left(high, dtype(32)) | low
  elif bit_width in [8, 16]:
    # This is essentially bits.view(dtype)[:size].
    lane_mask = np.uint32(np.iinfo(dtype).max)
    widened = lax.broadcast(bits, (1,))
    shift_amounts = lax.mul(
        np.uint32(bit_width),
        lax.broadcasted_iota(np.uint32, (32 // bit_width, 1), 0))
    bits = lax.bitwise_and(
        lane_mask, lax.shift_right_logical(widened, shift_amounts))
    bits = lax.reshape(
        bits, (np.uint32(max_count * 32 // bit_width),), (1, 0))
    bits = lax.convert_element_type(bits, dtype)[:size]

  return lax.reshape(bits, shape)
def _lu_python(x):
  """Default LU decomposition in Python, where no better version exists.

  Leading batch dimensions (everything before the last two axes) are
  flattened into a single axis, `_lu_blocked` is mapped over it with
  `api.vmap`, and the batch dimensions are restored on the outputs.

  Args:
    x: array whose last two axes form the (m, n) matrices to factor.

  Returns:
    A pair `(lu, pivot)`; `lu` has the shape of `x` and `pivot` has shape
    `batch_dims + (min(m, n),)`.
  """
  m, n = x.shape[-2:]
  batch_dims = x.shape[:-2]
  if len(batch_dims) > 0:
    # The expression computing the flattened batch size had lost its leading
    # tokens; it is restored here as the product of the batch dimensions.
    batch_size = onp.prod(batch_dims, dtype=onp.int64)
    pivot, lu = api.vmap(_lu_blocked)(lax.reshape(x, (batch_size, m, n)))
    pivot = lax.reshape(pivot, batch_dims + (min(m, n),))
    lu = lax.reshape(lu, batch_dims + (m, n))
  else:
    pivot, lu = _lu_blocked(x)
  return lu, pivot
def repeat(a, repeats, axis=None):
  """Repeat each element of `a` `repeats` times along `axis`.

  Only a scalar `repeats` is supported. When `axis` is None (or `a` is a
  scalar) the input is flattened first and repeated along axis 0.

  Raises:
    NotImplementedError: if `repeats` is not a scalar.
    ValueError: if `axis` is out of bounds for `a`.
  """
  if not isscalar(repeats):
    raise NotImplementedError(
        "np.repeat implementation only supports scalar repeats")

  if axis is None or isscalar(a):
    a = ravel(a)
    axis = 0

  out_shape = list(shape(a))
  num_dims = len(out_shape)
  if axis < 0:
    axis = axis + num_dims
  if not 0 <= axis < num_dims:
    raise ValueError(
        "axis {} is out of bounds for array of dimension {}".format(
            axis, num_dims))

  # Broadcasts to [..., X, repeats, ...] and reshapes to [..., X * repeats, ...]
  expanded_shape = out_shape[:axis + 1] + [repeats] + out_shape[axis + 1:]
  # Every operand axis maps to itself, skipping the freshly inserted axis.
  kept_dims = [d for d in range(num_dims + 1) if d != axis + 1]
  out_shape[axis] *= repeats

  expanded = lax.broadcast_in_dim(a, expanded_shape, kept_dims)
  return lax.reshape(expanded, out_shape)
def array(object, dtype=None, copy=True, order="K", ndmin=0):
  """Build an array from `object`, optionally converting it to `dtype`.

  Handles, in order: objects exposing `__asarray__`, `ndarray` instances,
  lists/tuples (converted element-wise and concatenated along a new leading
  axis), and scalars. `copy` is accepted but ignored.

  Raises:
    NotImplementedError: if `order` is not "K" or `ndmin` is not 0.
    TypeError: if `object` is none of the supported kinds.
  """
  del copy  # Unused.
  if ndmin != 0 or order != "K":
    raise NotImplementedError("Only implemented for order='K', ndmin=0.")

  def _with_requested_dtype(value):
    # Convert only when a dtype was requested and it differs from the current one.
    if dtype and _dtype(value) != dtype:
      return lax.convert_element_type(value, dtype)
    return value

  if hasattr(object, '__asarray__'):
    return object.__asarray__(dtype)

  if isinstance(object, ndarray):
    return _with_requested_dtype(object)

  if isinstance(object, (list, tuple)):
    if not object:
      return onp.array([], dtype)
    # Each element gets a leading unit axis so the pieces can be concatenated.
    pieces = [expand_dims(array(item, dtype=dtype), 0) for item in object]
    return concatenate(pieces)

  if isscalar(object):
    return _with_requested_dtype(lax.reshape(object, ()))

  raise TypeError("Unexpected input type for array: {}".format(
      type(object)))
def threefry_2x32(keypair, count):
  """Apply the Threefry 2x32 hash.

  Args:
    keypair: a pair of 32bit unsigned integers used for the key.
    count: an array of dtype uint32 used for the counts.

  Returns:
    An array of dtype uint32 with the same shape as `count`.

  Raises:
    TypeError: if any of the key words or `count` is not uint32.
    core.InconclusiveDimensionOperation: if the parity of `count.size`
      cannot be decided.
  """
  key1, key2 = keypair
  operands = [key1, key2, count]
  if not all(lax.dtype(x) == np.uint32 for x in operands):
    msg = "threefry_2x32 requires uint32 arguments, got {}"
    raise TypeError(msg.format([lax.dtype(x) for x in operands]))

  try:
    odd_size = count.size % 2
  except core.InconclusiveDimensionOperation as e:
    msg = (
        "jax.random functions have limited support for shape polymorphism. "
        "In particular, the product of the known dimensions must be even.")
    raise core.InconclusiveDimensionOperation(msg) from e

  flat = count.ravel()
  if odd_size:
    # Pad with one zero word so the counts split into two equal halves.
    flat = jnp.concatenate([flat, np.uint32([0])])
  first_half, second_half = jnp.split(flat, 2)

  hashed = threefry2x32_p.bind(key1, key2, first_half, second_half)
  out = jnp.concatenate(hashed)
  assert out.dtype == np.uint32
  if odd_size:
    out = out[:-1]  # Drop the word produced by the padding.
  return lax.reshape(out, count.shape)
def _cumulative_reduction(a,
                          axis: Optional[Union[int, Tuple[int, ...]]] = None,
                          dtype=None,
                          out=None):
  """Apply the enclosing scope's `reduction` to `a` along `axis`.

  The input is flattened when `axis` is None or `a` is a scalar, NaNs are
  replaced with `fill_value` when `fill_nan` is set, and boolean inputs
  without an explicit `dtype` are converted to the canonical integer dtype.

  Raises:
    NotImplementedError: if `out` is provided.
  """
  _check_arraylike(np_reduction.__name__, a)
  if out is not None:
    raise NotImplementedError(
        f"The 'out' argument to jnp.{np_reduction.__name__} "
        f"is not supported.")
  lax_internal._check_user_dtype_supported(dtype, np_reduction.__name__)

  flatten_input = axis is None or _isscalar(a)
  if flatten_input:
    a = lax.reshape(a, (np.size(a),))
    axis = 0

  rank = len(list(np.shape(a)))
  axis = _canonicalize_axis(axis, rank)

  if fill_nan:
    a = _where(lax_internal._isnan(a), _lax_const(a, fill_value), a)

  if not dtype and dtypes.dtype(a) == np.bool_:
    dtype = dtypes.canonicalize_dtype(dtypes.int_)
  if dtype:
    a = lax.convert_element_type(a, dtype)

  return reduction(a, axis)
def _lu_solve_core(lu, pivots, b, trans):
  """Solve a linear system from an LU factorization and its pivots.

  Args:
    lu: the packed LU factors; `lu.shape[0]` gives the system size.
    pivots: pivot indices, converted with `lu_pivots_to_permutation`.
    b: right-hand side; reshaped to `(m, -1)` for the solves.
    trans: 0 solves with the factors as given, 1 with their transpose and
      2 with their conjugate transpose.

  Returns:
    The solution, reshaped to `b.shape`.

  Raises:
    ValueError: if `trans` is not 0, 1 or 2.
  """
  m = lu.shape[0]
  permutation = lu_pivots_to_permutation(pivots, m)
  rhs = np.reshape(b, (m, -1))

  if trans == 0:
    # Permute the rows, then unit-lower solve followed by upper solve.
    rhs = rhs[permutation, :]
    rhs = triangular_solve(lu, rhs, left_side=True, lower=True,
                           unit_diagonal=True)
    rhs = triangular_solve(lu, rhs, left_side=True, lower=False)
  elif trans == 1 or trans == 2:
    # Transposed system: upper solve first, unit-lower second, then undo the
    # permutation.
    conj = trans == 2
    rhs = triangular_solve(lu, rhs, left_side=True, lower=False,
                           transpose_a=True, conjugate_a=conj)
    rhs = triangular_solve(lu, rhs, left_side=True, lower=True,
                           unit_diagonal=True, transpose_a=True,
                           conjugate_a=conj)
    inverse_permutation = np.argsort(permutation)
    rhs = rhs[inverse_permutation, :]
  else:
    raise ValueError(
        "'trans' value must be 0, 1, or 2, got {}".format(trans))

  return lax.reshape(rhs, b.shape)
def threefry_seed(seed: int) -> jnp.ndarray:
  """Create a single raw threefry PRNG key given an integer seed.

  Args:
    seed: a 64- or 32-bit integer used as the value of the key.

  Returns:
    The PRNG key contents, modeled as an array of shape (2,) and dtype
    uint32. The key is constructed from a 64-bit seed by effectively
    bit-casting to a pair of uint32 values (or from a 32-bit seed by first
    padding out with zeros).

  Raises:
    TypeError: if the seed is not a scalar or not of integer dtype.
  """
  # Avoid overflowerror in X32 mode by first converting ints to int64.
  # This breaks JIT invariance for large ints, but supports the common
  # use-case of instantiating with Python hashes in X32 mode.
  seed_arr = jnp.asarray(np.int64(seed) if isinstance(seed, int) else seed)

  if seed_arr.shape:
    raise TypeError(f"PRNG key seed must be a scalar; got {seed!r}.")
  if not np.issubdtype(seed_arr.dtype, np.integer):
    raise TypeError(f"PRNG key seed must be an integer; got {seed!r}")

  def as_key_word(value):
    # Each key word is a length-1 uint32 vector so the two can be concatenated.
    return lax.reshape(lax.convert_element_type(value, np.uint32), [1])

  high_bits = lax.shift_right_logical(seed_arr, lax._const(seed_arr, 32))
  k1 = as_key_word(high_bits)
  low_bits = jnp.bitwise_and(seed_arr, np.uint32(0xFFFFFFFF))
  k2 = as_key_word(low_bits)
  return lax.concatenate([k1, k2], 0)
def scan_reference(f, init, xs):
  """Reference scan: thread a carry through `f` over `xs` with a Python loop.

  Args:
    f: callable `(carry, x) -> (carry, y)`.
    init: initial carry.
    xs: iterable of inputs.

  Returns:
    `(final_carry, ys)` where `ys` concatenates every `y` along a new
    leading axis.
  """
  carry = init
  stacked_pieces = []
  for item in xs:
    carry, output = f(carry, item)
    # Give each output a leading unit axis so they concatenate along axis 0.
    leading_unit_shape = (1,) + onp.shape(output)
    stacked_pieces.append(lax.reshape(output, leading_unit_shape))
  return carry, lax.concatenate(stacked_pieces, 0)
def promote_shapes(*args, shape=()):
  """Prepend unit dimensions so every argument reaches the broadcast rank.

  The target rank is that of the broadcast of `shape` with all argument
  shapes. With fewer than two arguments and an empty `shape`, `args` is
  returned unchanged; otherwise a list is returned.
  """
  # adapted from lax.lax_numpy
  if len(args) < 2 and not shape:
    return args

  arg_shapes = [np.shape(arg) for arg in args]
  target_rank = len(lax.broadcast_shapes(shape, *arg_shapes))

  promoted = []
  for arg, arg_shape in zip(args, arg_shapes):
    missing = target_rank - len(arg_shape)
    if missing > 0:
      arg = lax.reshape(arg, (1,) * missing + arg_shape)
    promoted.append(arg)
  return promoted
def _promote_shapes(*args): """Prepend implicit leading singleton dimensions for Numpy broadcasting.""" if len(args) < 2: return args else: shapes = [shape(arg) for arg in args] nd = len(_broadcast_shapes(*shapes)) return [lax.reshape(arg, (1,) * (nd - len(shp)) + shp) if len(shp) != nd else arg for arg, shp in zip(args, shapes)]
def squeeze(a, axis=None):
  """Remove size-1 dimensions from `a`.

  With `axis=None` every size-1 dimension is dropped. Otherwise only the
  listed axes (negative values wrap) are dropped, and only where their size
  is 1; listed axes of any other size are left in place without an error.
  When `a` has no size-1 dimension it is returned unchanged.
  """
  dims = shape(a)
  if 1 not in dims:
    return a

  if axis is None:
    kept = [d for d in dims if d != 1]
  else:
    selected = frozenset(onp.mod(axis, ndim(a)).reshape(-1))
    kept = [d for i, d in enumerate(dims)
            if not (d == 1 and i in selected)]
  return lax.reshape(a, kept)
def matmul(a, b):  # pylint: disable=missing-docstring
  # Matrix product with broadcasting over leading batch dimensions. 1-D
  # operands are temporarily promoted to matrices and the added unit axis is
  # removed from the result.
  _check_arraylike("matmul", a, b)
  a_is_vec = ndim(a) == 1
  b_is_vec = ndim(b) == 1
  if a_is_vec:
    a = lax.reshape(a, (1,) + shape(a))   # row vector
  if b_is_vec:
    b = lax.reshape(b, shape(b) + (1,))   # column vector

  a, b = _promote_dtypes(a, b)

  # Broadcast the leading (batch) dimensions of both operands together.
  batch_shape = _broadcast_shapes(shape(a)[:-2], shape(b)[:-2])
  a = broadcast_to(a, batch_shape + shape(a)[-2:])
  b = broadcast_to(b, batch_shape + shape(b)[-2:])

  batch_dims = tuple(range(len(batch_shape)))
  contracting_dims = ((ndim(a) - 1,), (ndim(b) - 2,))
  result = lax.dot_general(a, b, (contracting_dims, (batch_dims, batch_dims)))

  if not (a_is_vec or b_is_vec):
    return result

  # Strip the unit axes that were introduced for the vector operands.
  m, n = shape(result)[-2:]
  row_part = () if a_is_vec else (m,)
  col_part = () if b_is_vec else (n,)
  return lax.reshape(result, batch_shape + row_part + col_part)
def testSparseReshapeWithDimensions(self, shape, new_shape, n_batch, n_dense,
                                    dimensions):
  """A sparsified `lax.reshape` with `dimensions` must agree on dense and BCOO inputs."""
  sampler = jtu.rand_some_zero(self.rng())
  dense_in = sampler(shape, 'int32')
  sparse_in = BCOO.fromdense(dense_in, n_batch=n_batch, n_dense=n_dense)

  def reshape_with_dims(x):
    return lax.reshape(x, new_shape, dimensions=dimensions)

  sparsified = self.sparsify(reshape_with_dims)
  dense_out = sparsified(dense_in)
  sparse_out = sparsified(sparse_in)
  self.assertArraysEqual(dense_out, sparse_out.todense())
def reshape(a, newshape, order="C"):  # pylint: disable=missing-docstring
  # Reshape `a` to `newshape`; `order` selects the dimension order that is
  # handed to `lax.reshape` ("C"/None, "F", or "A").
  rank = ndim(a)
  if order == "C" or order is None:
    dims = None
  elif order == "F":
    dims = onp.arange(rank)[::-1]
  elif order == "A":
    forward = onp.arange(rank)
    dims = forward[::-1] if isfortran(a) else forward
  else:
    raise ValueError("Unexpected value for 'order' argument: {}.".format(order))

  # Let NumPy resolve the target shape (e.g. a -1 entry) on a zero-stride
  # placeholder with the input's shape.
  placeholder = onp.broadcast_to(0, a.shape)  # zero strides
  resolved_shape = onp.reshape(placeholder, newshape).shape
  return lax.reshape(a, resolved_shape, dims)
def chooser_taylor_rule(primals_in, series_in, **params):
  """Propagate a Taylor series through the reduction `chooser_fun`.

  Each series term is averaged over the positions where the operand equals
  the reduced primal output.

  Returns:
    `(primal_out, series_out)`, with one output term per input term.
  """
  (operand,) = primals_in
  (gs,) = series_in
  primal_out = chooser_fun(operand, **params)
  axes = params.pop("axes", None)
  primal_dtype = gs[0].dtype

  # Shape of the primal output with the reduced axes kept as size 1, so it
  # can be compared against the operand.
  keepdims_shape = [1 if i in axes else d for i, d in enumerate(operand.shape)]
  matches = lax._eq_meet(operand, lax.reshape(primal_out, keepdims_shape))
  location_indicators = lax.convert_element_type(matches, primal_dtype)
  counts = lax._reduce_sum(location_indicators, axes)

  series_out = []
  for g in gs:
    selected_sum = lax._reduce_sum(lax.mul(g, location_indicators), axes)
    series_out.append(lax.div(selected_sum, counts))
  return primal_out, series_out
def dot(a, b):  # pylint: disable=missing-docstring
  """Dot product of two arrays.

  Scalars are multiplied element-wise; operands of rank at most 2 go
  straight to `lax.dot`. For higher ranks, `a` is flattened to a matrix over
  its last axis and `b` (if it has more than two dimensions) is rearranged
  so its second-to-last axis leads, then the product is reshaped to
  `a.shape[:-1] + b.shape[:-2] + b.shape[-1:]`.
  """
  _check_arraylike("dot", a, b)
  a, b = _promote_dtypes(a, b)
  a_ndim, b_ndim = ndim(a), ndim(b)
  if a_ndim == 0 or b_ndim == 0:
    return lax.mul(a, b)
  # The three matrix-product calls below had lost their callee and first
  # argument; they are restored as `lax.dot(...)` on the operands each
  # branch prepares.
  if _max(a_ndim, b_ndim) <= 2:
    return lax.dot(a, b)

  a_reshaped = reshape(a, (-1, shape(a)[-1]))
  if _ndim(b) in {1, 2}:
    out = lax.dot(a_reshaped, b)
  else:
    # Bring the contracted (second-to-last) axis of `b` to the front and
    # flatten the remaining axes into columns.
    b_reshaped = reshape(moveaxis(b, -2, 0), (shape(b)[-2], -1))
    out = lax.dot(a_reshaped, b_reshaped)
  return lax.reshape(out, a.shape[:-1] + b.shape[:-2] + b.shape[-2:][1:])
def reduction(a, axis=None, dtype=None, out=None, keepdims=False):
  """Reduce `a` over `axis` with the enclosing scope's `op` and `init_val`.

  The input is first converted to the dtype that `np_fun` yields for a
  scalar of the input's dtype; `dtype`, if given and different, is applied
  to the result afterwards.

  Raises:
    ValueError: if `out` is provided.
  """
  if out is not None:
    raise ValueError("reduction does not support `out` argument.")

  if not isinstance(a, ndarray):
    a = asarray(a)
  dims = _reduction_dims(a, axis)

  # Probe the NumPy function with a scalar to learn its result dtype.
  probe = onp.ones((), dtype=_dtype(a))
  result_dtype = _dtype(np_fun(probe))
  if _dtype(a) != result_dtype:
    a = lax.convert_element_type(a, result_dtype)

  result = lax.reduce(a, _reduction_init_val(a, init_val), op, dims)

  if keepdims:
    # Re-insert the reduced axes as size-1 dimensions.
    unit_sizes = (1,) * len(dims)
    shape_with_singletons = lax.subvals(shape(a), zip(dims, unit_sizes))
    result = lax.reshape(result, shape_with_singletons)

  if dtype and onp.dtype(dtype) != onp.dtype(result_dtype):
    result = lax.convert_element_type(result, dtype)
  return result
def _reshape_papply_rule(name, vals, axes, new_sizes, dimensions, old_sizes): operand, = vals axis, = axes def filter_ones(xs): return filter(lambda x: x != 1, xs) def find_new_axis(old_axis, old_sizes, new_sizes): if len(filter_ones(new_sizes)) != len(filter_ones(old_sizes)): return None num_before = len(filter_ones(old_sizes[:old_axis])) sz = old_sizes[old_axis] for i, new_sz in enumerate(new_sizes): if num_before == 0: if new_sz == sz: return i elif new_sz != 1: return None elif new_sz != 1: num_before -= 1 return None err = NotImplementedError( 'papply of reshape that would change hidden dimension size') if dimensions is None: new_axis = find_new_axis(axis, old_sizes, new_sizes) if new_axis is not None: if ([:axis]) !=[:new_axis]) or[axis + 1:]) != new_sizes[new_axis + 1:])): raise err new_sizes_ = new_sizes[:new_axis] + new_sizes[new_axis + 1:] return lax.reshape(operand, new_sizes_, dimensions=dimensions), new_axis else: raise err else: raise NotImplementedError('papply of reshape with `dimensions`')
def update_entry(arr, val, i, j):
  """Return `arr` with the entry at `(i, j)` replaced by `val`."""
  # The update must be a 1x1 block to be written with a dynamic slice update.
  patch = lax.reshape(val, [1, 1])
  start_indices = (i, j)
  return lax.dynamic_update_slice(arr, patch, start_indices)
def reshape_op(self, params, inputs):
  """Reshape `inputs` to `self.target_shape`, keeping the leading dimension.

  `params` is accepted but not used.
  """
  leading_dim = inputs.shape[0]
  output_shape = (leading_dim,) + tuple(self.target_shape)
  return lax.reshape(inputs, output_shape)
def testReshapeGrad(self, arg_shape, out_shape, permutation, dtype,
                    rng_factory):
  """Check forward- and reverse-mode gradients of `lax.reshape` up to order 2."""
  sampler = rng_factory(self.rng())
  operand = sampler(arg_shape, dtype)

  def reshape(x):
    return lax.reshape(x, out_shape, permutation)

  check_grads(reshape, (operand,), 2, ["fwd", "rev"], eps=1.)
def softmax(attn_weights, norm_dims, dtype, softmax_hparams, quant_context):
  """Normalizes attention.

  Three paths are selected from `softmax_hparams`:
    * None, or all fields None: exact log-domain softmax.
    * `quant_hparams` set: `lax.cond` on `quant_context.quantize_acts`
      between a softmax with quantized intermediates and the exact one.
    * otherwise: an approximated softmax built from `exponential` and
      `reciprocal` with `exp_hparams` / `reciprocal_hparams`.

  Args:
    attn_weights: the attention logits.
    norm_dims: dimensions over which to normalize.
    dtype: dtype of the returned array.
    softmax_hparams: hyper-parameters selecting the path, or None.
    quant_context: provides `quantize_acts` for the quantized path.

  Returns:
    The normalized attention weights cast to `dtype`.
  """
  a = attn_weights

  def unquantized_softmax(a):
    a = lax.exp(
        a - jax.scipy.special.logsumexp(a, axis=norm_dims, keepdims=True))
    return a.astype(dtype)

  # Quantize intermediate activations with QuantOps.
  # Currently only supports unscaled floating-point formats.
  def quantized_softmax(a):
    # We compute softmax as exp(x-max(x))/sum_i(exp(x_i-max(x))), quantizing
    # intermediate values. Note this differs from the log-domain
    # implementation of softmax used above.
    quant_hparams = softmax_hparams.quant_hparams
    fp_quant_config = QuantOps.FloatQuant(is_scaled=False,
                                          fp_spec=quant_hparams.prec)
    quant_ops = QuantOps.create_symmetric_fp(fp_quant=fp_quant_config,
                                             bounds=None)

    a = quant_ops.to_quantized(a, dtype=dtype)
    # Note that the max of a quantized vector is necessarily also quantized to
    # the same precision since the max of a vector must be an existing element
    # of the vector, so we don't need to explicitly insert a quantization
    # operator to the output of the max reduction.
    a_max = jnp.max(a, axis=norm_dims, keepdims=True)
    a_minus_max = quant_ops.to_quantized(a - a_max, dtype=dtype)
    a_exp = quant_ops.to_quantized(jnp.exp(a_minus_max), dtype=dtype)

    sum_exp_quantized_reduction = quantization.quantized_sum(
        a_exp,
        axis=norm_dims,
        keepdims=True,
        prec=quant_hparams.reduction_prec)
    sum_exp = quant_ops.to_quantized(sum_exp_quantized_reduction, dtype=dtype)

    inv_sum_exp = quant_ops.to_quantized(jnp.reciprocal(sum_exp), dtype=dtype)
    a_softmax = quant_ops.to_quantized(a_exp * inv_sum_exp, dtype=dtype)

    return a_softmax.astype(dtype)

  # If no params, return accurate Softmax.
  if softmax_hparams == SoftmaxHParams(None, None,
                                       None) or softmax_hparams is None:
    return unquantized_softmax(a)

  # TODO(shivaniagrawal): Partial sum quantization (if enabled) will happen for
  # the entire training run, even before the global activation start step.
  if softmax_hparams.quant_hparams is not None:
    return lax.cond(quant_context.quantize_acts, quantized_softmax,
                    unquantized_softmax, a)

  # Approximated Softmax
  exp_hparams = softmax_hparams.exp_hparams
  recip_hparams = softmax_hparams.reciprocal_hparams

  # Subtract max value from dimensions to be normalized.
  shape = jax.util.subvals(onp.shape(a), zip(norm_dims, (1,) * len(norm_dims)))
  dimadd = lambda x: lax.reshape(x, shape)
  # pylint: disable=protected-access
  amax = lax.reduce(a, lax_numpy._constant_like(a, -onp.inf), lax.max,
                    norm_dims)
  # Replace non-finite maxima with 0 so the subtraction below stays
  # well-defined.
  # NOTE(review): the head of this call was missing; it is restored as
  # `lax.select(lax.is_finite(amax), ...)` from the surviving
  # `amax, lax.full_like(amax, 0)` arguments -- confirm against the upstream
  # source.
  amax = lax.select(lax.is_finite(amax), amax, lax.full_like(amax, 0))
  amax_singletons = dimadd(amax)
  asubmax = lax.sub(a, amax_singletons)

  # Calculate approximated exponential
  approx_exp = exponential(asubmax, dtype, exp_hparams)

  # If sum_high_bound: Upper clip bound for sum(exp(x-M)).
  asumexp = dimadd(
      lax.reduce(approx_exp, lax_numpy._constant_like(a, 0), lax.add,
                 norm_dims))

  if exp_hparams.sum_high_bound is not None and exp_hparams.sum_high_bound != 0:
    sum_low_bound = 1.
    if (exp_hparams.low_bound != 0) and exp_hparams.clip_and_subtract:
      sum_low_bound = 1 - onp.exp(exp_hparams.low_bound)
    asumexp = jnp.clip(asumexp, sum_low_bound, exp_hparams.sum_high_bound)

  # Approximation of reciprocal.
  arecip = reciprocal(asumexp, dtype, recip_hparams)
  return lax.mul(approx_exp, arecip).astype(dtype)
def _threefry_split(key, num) -> jnp.ndarray:
  """Hash `2 * num` consecutive uint32 counts with `key` and return them as a `(num, 2)` array."""
  word_count = num * 2
  hashed_words = threefry_2x32(key, lax.iota(np.uint32, word_count))
  return lax.reshape(hashed_words, (num, 2))
def expand_dims(a, axis):
  """Insert a size-1 dimension into `a` at position `axis`.

  `axis` is taken modulo `ndim(a) + 1`, so negative values count from the
  end of the expanded shape.
  """
  in_shape = _shape(a)
  position = axis % (ndim(a) + 1)
  leading, trailing = in_shape[:position], in_shape[position:]
  return lax.reshape(a, leading + (1,) + trailing)
def testReshape(self, arg_shape, out_shape, dtype, dimensions, bdims,
                rng_factory):
  """Check batching of `lax.reshape` with the given `dimensions`."""
  sampler = rng_factory(self.rng())

  def op(x):
    return lax.reshape(x, out_shape, dimensions=dimensions)

  self._CheckBatching(op, 10, bdims, (arg_shape,), (dtype,), sampler)
def flatten(x):
  """Collapse the first two dimensions of `x` into a single 1-D axis."""
  rows, cols = x.shape[0], x.shape[1]
  return lax.reshape(x, (rows * cols,))
def testReshape(self, arg_shape, out_shape, dtype, dimensions, bdims):
  """Check batching of `lax.reshape` using the default random sampler."""
  sampler = jtu.rand_default(self.rng())

  def op(x):
    return lax.reshape(x, out_shape, dimensions=dimensions)

  self._CheckBatching(op, 10, bdims, (arg_shape,), (dtype,), sampler)
def _rewriting_take(arr, idx, axis=0):
  """A function like numpy.take that handles boxes and rewrites to LAX.

  Dispatches on the kind of indexer:
    * `()`, `Ellipsis`, `slice(None)`: returns `arr` unchanged.
    * `None`: adds a leading unit dimension.
    * scalar integer (concrete or abstract): index along `axis`.
    * slice with static start/stop/step: slice (and reverse if needed).
    * tuple of scalar-like entries: applied one entry at a time.
    * integer-array (advanced) indexing, with or without slices/None.

  Raises:
    IndexError: for slices with non-static bounds and for unsupported
      indexing modes.
  """
  # Handle special indexers: (), Ellipsis, slice(None), and None.
  # The empty tuple is detected by type and length instead of identity, which
  # only worked because CPython shares a single empty-tuple object.
  is_empty_tuple = isinstance(idx, tuple) and len(idx) == 0
  if is_empty_tuple or idx is Ellipsis or _is_slice_none(idx):
    return arr
  elif idx is None:
    return expand_dims(arr, 0)

  # Handle int index
  _int = lambda aval: not aval.shape and onp.issubdtype(
      aval.dtype, onp.integer)
  try:
    abstract_idx = core.get_aval(idx)
  except TypeError:
    abstract_idx = None

  if isinstance(abstract_idx, ConcreteArray) and _int(abstract_idx):
    return lax.index_in_dim(arr, idx, axis, False)
  elif isinstance(abstract_idx, ShapedArray) and _int(abstract_idx):
    # Wrap the abstract index into range before the dynamic lookup.
    idx = mod(idx, arr.shape[axis])
    return lax.dynamic_index_in_dim(arr, idx, axis, False)

  # Handle slice index (only static, otherwise an error is raised)
  elif isinstance(idx, slice):
    if not _all(
        elt is None or isinstance(core.get_aval(elt), ConcreteArray)
        for elt in (idx.start, idx.stop, idx.step)):
      msg = (
          "Array slice indices must have static start/stop/step to be used "
          "with Numpy indexing syntax. Try lax.dynamic_slice instead.")
      raise IndexError(msg)
    else:
      start, limit, stride, needs_rev = _static_idx(idx, arr.shape[axis])
      result = lax.slice_in_dim(arr, start, limit, stride, axis=axis)
      return lax.rev(result, [axis]) if needs_rev else result

  # Handle non-advanced tuple indices by recursing once
  elif isinstance(idx, tuple) and _all(onp.ndim(elt) == 0 for elt in idx):
    canonical_idx = _canonicalize_tuple_index(arr, idx)
    result, axis = arr, 0
    for elt in (elt for elt in canonical_idx if elt is not None):
      result = _rewriting_take(result, elt, axis=axis)
      axis += isinstance(elt, slice)  # advance axis index if not eliminated
    # Re-insert a unit dimension for every None entry of the index.
    unexpanded_shape_itr = iter(result.shape)
    result_shape = tuple(1 if elt is None else next(unexpanded_shape_itr)
                         for elt in canonical_idx
                         if not isinstance(elt, int))
    return lax.reshape(result, result_shape)

  # Handle advanced indexing (non-tuple sequence, ndarray of dtype int or bool,
  # or a tuple with at least one sequence object).

  # Handle integer array indexing *without* ellipsis/slices/nones
  if _is_advanced_int_indexer_without_slices(idx):
    if isinstance(idx, list):
      if _any(_shape(e) for e in idx):
        # At least one sequence element in the index list means broadcasting.
        idx = broadcast_arrays(*idx)
      else:
        # The index list is a flat list of integers.
        idx = [
            lax.concatenate([lax.reshape(e, (1,)) for e in idx], 0)
        ]
    else:
      # The indexer is just a single integer array.
      idx = [idx]

    flat_idx = tuple(
        mod(ravel(x), arr.shape[i]) for i, x in enumerate(idx))
    out = lax.index_take(arr, flat_idx, tuple(range(len(idx))))
    return lax.reshape(out, idx[0].shape + _shape(arr)[len(idx):])

  # Handle integer array indexing *with* ellipsis/slices/nones by recursing once
  elif _is_advanced_int_indexer(idx):
    canonical_idx = _canonicalize_tuple_index(arr, tuple(idx))
    # First apply only the basic (non-integer-array) part of the index.
    idx_noadvanced = [
        slice(None) if _is_int(e) else e for e in canonical_idx
    ]
    arr_sliced = _rewriting_take(arr, tuple(idx_noadvanced))

    advanced_pairs = ((e, i) for i, e in enumerate(canonical_idx)
                      if _is_int(e))
    idx_advanced, axes = zip(*advanced_pairs)
    idx_advanced = broadcast_arrays(*idx_advanced)

    flat_idx = tuple(
        mod(ravel(x), arr_sliced.shape[i])
        for i, x in zip(axes, idx_advanced))
    out = lax.index_take(arr_sliced, flat_idx, axes)
    shape_suffix = tuple(onp.delete(_shape(arr_sliced), axes))
    out = lax.reshape(out, idx_advanced[0].shape + shape_suffix)

    # When the indexed axes are adjacent, the index dimensions are moved
    # from the front back to where those axes were.
    axes_are_contiguous = onp.all(onp.diff(axes) == 1)
    if axes_are_contiguous:
      start = axes[0]
      naxes = idx_advanced[0].ndim
      out = moveaxis(out, list(range(naxes)),
                     list(range(start, start + naxes)))
    return out

  msg = "Indexing mode not yet supported. Open a feature request!\n{}"
  raise IndexError(msg.format(idx))