def backwards_recursion(C, c, F, f):
  """LQR backward recursion: computes feedback gains K_t and offsets k_t."""
  def backwards_step(carry, params):
    V_tp1, v_tp1 = carry
    (C_t, c_t, F_t, f_t) = params
    n = f_t.shape[0]
    FTV_tp1 = F_t.T @ V_tp1
    Q_t = C_t + FTV_tp1 @ F_t
    q_t = c_t + FTV_tp1 @ f_t + F_t.T @ v_tp1
    Q_x, Q_u = np.split(Q_t, (n,))
    Q_xx, Q_xu = np.split(Q_x, (n,), axis=1)
    Q_ux, Q_uu = np.split(Q_u, (n,), axis=1)
    q_x, q_u = np.split(q_t, (n,))
    Quxqu = np.concatenate((Q_ux, q_u[:, None]), axis=1)
    # Solve for [K_t | k_t] in one call; Q_uu is symmetric positive definite.
    # (Newer JAX versions spell sym_pos=True as assume_a='pos'.)
    Ktkt = -jax.scipy.linalg.solve(Q_uu, Quxqu, sym_pos=True)
    K_t, k_t = np.split(Ktkt, (n,), axis=1)
    k_t = k_t[:, 0]
    KTQuu = K_t.T @ Q_uu
    V_t = Q_xx + Q_xu @ K_t + K_t.T @ Q_ux + KTQuu @ K_t
    v_t = q_x + Q_xu @ k_t + K_t.T @ q_u + KTQuu @ k_t
    return (V_t, v_t), (K_t, k_t)

  V_T = np.zeros((f.shape[1], f.shape[1]))
  v_T = np.zeros((f.shape[1],))
  K, kvec = lax.scan(backwards_step, (V_T, v_T), (C, c, F, f))[1]
  # Reverse the stacked outputs along the time axis.
  return lax.rev(K, (0,)), lax.rev(kvec, (0,))
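# Hedged usage sketch (not from the original source; shapes and the dummy
# inputs are assumptions). With n states, m controls, and horizon T, C stacks
# the (n+m, n+m) stage costs, c the (n+m,) cost vectors, F the (n, n+m)
# dynamics matrices, and f the (n,) offsets. `np` is jax.numpy, as in the
# snippet above; note the sym_pos= keyword predates its removal from
# jax.scipy.linalg.solve.
#
#   import jax
#   import jax.numpy as np
#   from jax import lax
#
#   n, m, T = 3, 2, 10
#   C = np.tile(np.eye(n + m), (T, 1, 1))   # PSD costs, so Q_uu stays PD
#   c = np.zeros((T, n + m))
#   F = 0.1 * jax.random.normal(jax.random.PRNGKey(0), (T, n, n + m))
#   f = np.zeros((T, n))
#   K, k = backwards_recursion(C, c, F, f)  # K: (T, m, n), k: (T, m)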
def svd(a, full_matrices: bool = True, compute_uv: bool = True,
        hermitian: bool = False):
  a, = _promote_dtypes_inexact(jnp.asarray(a))
  if hermitian:
    # For a Hermitian matrix, the singular values are the absolute values of
    # the eigenvalues. (lax_linalg.eigh returns eigenvectors first, so here
    # `w` holds the eigenvectors and `v` the eigenvalues.)
    w, v = lax_linalg.eigh(a)
    s = lax.abs(v)
    if compute_uv:
      sign = lax.sign(v)
      idxs = lax.broadcasted_iota(np.int64, s.shape, dimension=s.ndim - 1)
      s, idxs, sign = lax.sort((s, idxs, sign), dimension=-1, num_keys=1)
      # The sort is ascending; lax.rev flips to the descending order that
      # numpy.linalg.svd promises.
      s = lax.rev(s, dimensions=[s.ndim - 1])
      idxs = lax.rev(idxs, dimensions=[s.ndim - 1])
      sign = lax.rev(sign, dimensions=[s.ndim - 1])
      u = jnp.take_along_axis(w, idxs[..., None, :], axis=-1)
      vh = _H(u * sign[..., None, :])
      return u, s, vh
    else:
      return lax.rev(lax.sort(s, dimension=-1), dimensions=[s.ndim - 1])

  return lax_linalg.svd(a, full_matrices=full_matrices, compute_uv=compute_uv)
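# Hedged example (not from the original source): the hermitian=True path above
# is exercised through the public jnp.linalg.svd API, which computes the SVD
# of a Hermitian matrix via an eigendecomposition.
#
#   import jax.numpy as jnp
#
#   a = jnp.array([[2., 1.], [1., 2.]])            # symmetric; eigenvalues 3, 1
#   u, s, vh = jnp.linalg.svd(a, hermitian=True)
#   print(s)                                        # [3. 1.], descending
#   print(jnp.allclose(u @ jnp.diag(s) @ vh, a))    # True (up to float tolerance)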
def rev_dependency_rule(outstart, outcount, operand, dimensions):
  # Map a requested output slice of lax.rev back to the input slice it
  # depends on: along reversed dimensions, the slice indexes from the far end.
  instart = [size - (start + outsize) if d in dimensions else start
             for d, (size, outsize, start)
             in enumerate(zip(operand.shape, outcount.shape, outstart))]
  return ([(instart, outcount.shape)],
          [Ones(outcount.shape) if is_ones(outcount)
           else lax.rev(outcount, dimensions)],
          lambda inslice: lax.rev(inslice, dimensions))
def _scatter_impl(x, y, scatter_op, treedef, static_idx, dynamic_idx,
                  indices_are_sorted, unique_indices, normalize_indices):
  dtype = lax.dtype(x)
  x, y = jnp._promote_dtypes(x, y)

  idx = jnp._merge_static_and_dynamic_indices(treedef, static_idx, dynamic_idx)
  indexer = jnp._index_to_gather(jnp.shape(x), idx,
                                 normalize_indices=normalize_indices)

  # Broadcast `y` to the slice output shape.
  y = jnp.broadcast_to(y, tuple(indexer.slice_shape))
  # Collapse any `None`/`jnp.newaxis` dimensions.
  y = jnp.squeeze(y, axis=indexer.newaxis_dims)
  if indexer.reversed_y_dims:
    y = lax.rev(y, indexer.reversed_y_dims)

  # Transpose the gather dimensions into scatter dimensions (cf.
  # lax._gather_transpose_rule)
  dnums = lax.ScatterDimensionNumbers(
      update_window_dims=indexer.dnums.offset_dims,
      inserted_window_dims=indexer.dnums.collapsed_slice_dims,
      scatter_dims_to_operand_dims=indexer.dnums.start_index_map)
  out = scatter_op(x, indexer.gather_indices, y, dnums,
                   indices_are_sorted=indices_are_sorted,
                   unique_indices=unique_indices)
  return lax.convert_element_type(out, dtype)
def testReverseGrad(self):
  rev = lambda operand: lax.rev(operand, dimensions)

  dimensions = [0]
  check_grads(rev, (onp.array([3., 2., 1.]),), 2)

  dimensions = [0, 1]
  check_grads(rev, (onp.array([[6., 5., 4.], [3., 2., 1.]]),), 2,
              rtol={onp.float32: 3e-3})
def testRevLax(self):
  fun = lambda x: lax.rev(x, [0])
  R = np.random.RandomState(0).randn
  x = R(2, 3)

  ans = vmap(fun)(x)
  expected_ans = x[:, ::-1]
  self.assertAllClose(ans, expected_ans, check_dtypes=False)

  ans = vmap(fun, (1,), 1)(x)
  expected_ans = x[::-1, :]
  self.assertAllClose(ans, expected_ans, check_dtypes=False)
def _matrix_put(ndarray, idx, val, block_size=1):
  """Similar to numpy.put using LAX operations."""
  idx_i, idx_j = idx
  sli, row_rev = _canonical_idx(ndarray.shape, idx_i, -2, block_size)
  slj, col_rev = _canonical_idx(ndarray.shape, idx_j, -1, block_size)
  if not sli.step == slj.step == 1:
    raise TypeError("Non-unit step not supported in assignment.")

  if row_rev or col_rev:
    val = lax.rev(val, *onp.where([row_rev, col_rev]))

  start_indices = [0] * (ndarray.ndim - 2) + [sli.start, slj.start]
  return lax.dynamic_update_slice(ndarray, val, start_indices)
def _matrix_take(ndarray, idx, block_size=1):
  """Similar to numpy.take using LAX operations."""
  idx_i, idx_j = idx
  sli, row_rev = _canonical_idx(ndarray.shape, idx_i, -2, block_size)
  slj, col_rev = _canonical_idx(ndarray.shape, idx_j, -1, block_size)

  start_indices = [0] * (ndarray.ndim - 2) + [sli.start, slj.start]
  limit_indices = list(ndarray.shape[:-2]) + [sli.stop, slj.stop]
  strides = [1] * (ndarray.ndim - 2) + [sli.step, slj.step]
  out = lax.slice(ndarray, start_indices, limit_indices, strides)

  if row_rev or col_rev:
    out = lax.rev(out, *onp.where([row_rev, col_rev]))
  return out
def _scatter_impl(x, y, scatter_op, treedef, static_idx, dynamic_idx,
                  indices_are_sorted, unique_indices, mode, normalize_indices):
  dtype = lax.dtype(x)
  weak_type = dtypes.is_weakly_typed(x)

  if dtype != dtypes.result_type(x, y):
    # TODO(jakevdp): change this to an error after the deprecation period.
    warnings.warn("scatter inputs have incompatible types: cannot safely cast "
                  f"value from dtype={lax.dtype(y)} to dtype={lax.dtype(x)}. "
                  "In future JAX releases this will result in an error.",
                  FutureWarning)

  idx = jnp._merge_static_and_dynamic_indices(treedef, static_idx, dynamic_idx)
  indexer = jnp._index_to_gather(jnp.shape(x), idx,
                                 normalize_indices=normalize_indices)

  # Avoid calling scatter if the slice shape is empty, both as a fast path and
  # to handle cases like zeros(0)[array([], int32)].
  if core.is_empty_shape(indexer.slice_shape):
    return x

  x, y = jnp._promote_dtypes(x, y)

  # Broadcast `y` to the slice output shape.
  y = jnp.broadcast_to(y, tuple(indexer.slice_shape))
  # Collapse any `None`/`jnp.newaxis` dimensions.
  y = jnp.squeeze(y, axis=indexer.newaxis_dims)
  if indexer.reversed_y_dims:
    y = lax.rev(y, indexer.reversed_y_dims)

  # Transpose the gather dimensions into scatter dimensions (cf.
  # lax._gather_transpose_rule)
  dnums = lax.ScatterDimensionNumbers(
      update_window_dims=indexer.dnums.offset_dims,
      inserted_window_dims=indexer.dnums.collapsed_slice_dims,
      scatter_dims_to_operand_dims=indexer.dnums.start_index_map)
  out = scatter_op(
      x, indexer.gather_indices, y, dnums,
      indices_are_sorted=indexer.indices_are_sorted or indices_are_sorted,
      unique_indices=indexer.unique_indices or unique_indices,
      mode=mode)
  return lax_internal._convert_element_type(out, dtype, weak_type)
def _scatter_impl(x, y, scatter_op, treedef, static_idx, dynamic_idx,
                  indices_are_sorted, unique_indices, mode, normalize_indices):
  dtype = lax.dtype(x)
  weak_type = dtypes.is_weakly_typed(x)

  idx = jnp._merge_static_and_dynamic_indices(treedef, static_idx, dynamic_idx)
  indexer = jnp._index_to_gather(jnp.shape(x), idx,
                                 normalize_indices=normalize_indices)

  # Avoid calling scatter if the slice shape is empty, both as a fast path and
  # to handle cases like zeros(0)[array([], int32)].
  if core.is_empty_shape(indexer.slice_shape):
    return x

  x, y = jnp._promote_dtypes(x, y)

  # Broadcast `y` to the slice output shape.
  y = jnp.broadcast_to(y, tuple(indexer.slice_shape))
  # Collapse any `None`/`jnp.newaxis` dimensions.
  y = jnp.squeeze(y, axis=indexer.newaxis_dims)
  if indexer.reversed_y_dims:
    y = lax.rev(y, indexer.reversed_y_dims)

  # Transpose the gather dimensions into scatter dimensions (cf.
  # lax._gather_transpose_rule)
  dnums = lax.ScatterDimensionNumbers(
      update_window_dims=indexer.dnums.offset_dims,
      inserted_window_dims=indexer.dnums.collapsed_slice_dims,
      scatter_dims_to_operand_dims=indexer.dnums.start_index_map)
  out = scatter_op(
      x, indexer.gather_indices, y, dnums,
      indices_are_sorted=indexer.indices_are_sorted or indices_are_sorted,
      unique_indices=indexer.unique_indices or unique_indices,
      mode=mode)
  return lax._convert_element_type(out, dtype, weak_type)
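# Hedged illustration (not from the original source): the reversed_y_dims
# branch in the scatter implementations above is what makes an update through
# a negative-step slice work, reversing y with lax.rev before the scatter.
#
#   import jax.numpy as jnp
#
#   x = jnp.zeros(3)
#   y = jnp.array([1., 2., 3.])
#   print(x.at[::-1].set(y))   # [3. 2. 1.]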
def rev2(x):
  return lax.rev(x, (1,))
def rev1(x):
  return lax.rev(x, (0,))
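# Hedged illustration (not from the original source): rev1 flips axis 0 and
# rev2 flips axis 1, matching x[::-1, :] and x[:, ::-1] respectively.
#
#   import jax.numpy as jnp
#
#   x = jnp.arange(6).reshape(2, 3)
#   print(rev1(x))   # [[3 4 5], [0 1 2]]
#   print(rev2(x))   # [[2 1 0], [5 4 3]]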
def test_rev(shape, dtype, dimensions, rng_factory):
  rng = rng_factory(np.random)
  arg = rng(shape, dtype)
  tu.check_lazy_fun(lambda x: lax.rev(x, dimensions=dimensions), arg)
def _rewriting_take(arr, idx, axis=0):
  """A function like numpy.take that handles boxes and rewrites to LAX."""

  # Handle special indexers: (), Ellipsis, slice(None), and None.
  # TODO(mattjj): don't compare empty tuple identity (though works for CPython)
  if idx is () or idx is Ellipsis or _is_slice_none(idx):  # pylint: disable=literal-comparison
    return arr
  elif idx is None:
    return expand_dims(arr, 0)

  # Handle int index
  _int = lambda aval: not aval.shape and onp.issubdtype(aval.dtype, onp.integer)
  try:
    abstract_idx = core.get_aval(idx)
  except TypeError:
    abstract_idx = None

  if isinstance(abstract_idx, ConcreteArray) and _int(abstract_idx):
    return lax.index_in_dim(arr, idx, axis, False)
  elif isinstance(abstract_idx, ShapedArray) and _int(abstract_idx):
    idx = mod(idx, arr.shape[axis])
    return lax.dynamic_index_in_dim(arr, idx, axis, False)

  # Handle slice index (only static, otherwise an error is raised)
  elif isinstance(idx, slice):
    if not _all(elt is None or isinstance(core.get_aval(elt), ConcreteArray)
                for elt in (idx.start, idx.stop, idx.step)):
      msg = ("Array slice indices must have static start/stop/step to be used "
             "with Numpy indexing syntax. Try lax.dynamic_slice instead.")
      raise IndexError(msg)
    else:
      start, limit, stride, needs_rev = _static_idx(idx, arr.shape[axis])
      result = lax.slice_in_dim(arr, start, limit, stride, axis=axis)
      return lax.rev(result, [axis]) if needs_rev else result

  # Handle non-advanced tuple indices by recursing once
  elif isinstance(idx, tuple) and _all(onp.ndim(elt) == 0 for elt in idx):
    canonical_idx = _canonicalize_tuple_index(arr, idx)
    result, axis = arr, 0
    for elt in (elt for elt in canonical_idx if elt is not None):
      result = _rewriting_take(result, elt, axis=axis)
      axis += isinstance(elt, slice)  # advance axis index if not eliminated
    unexpanded_shape_itr = iter(result.shape)
    result_shape = tuple(1 if elt is None else next(unexpanded_shape_itr)
                         for elt in canonical_idx if not isinstance(elt, int))
    return lax.reshape(result, result_shape)

  # Handle advanced indexing (non-tuple sequence, ndarray of dtype int or bool,
  # or a tuple with at least one sequence object).
  # https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html#advanced-indexing
  # https://gist.github.com/seberg/976373b6a2b7c4188591

  # Handle integer array indexing *without* ellipsis/slices/nones
  # https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html#integer-array-indexing
  if _is_advanced_int_indexer_without_slices(idx):
    if isinstance(idx, list):
      if _any(_shape(e) for e in idx):
        # At least one sequence element in the index list means broadcasting.
        idx = broadcast_arrays(*idx)
      else:
        # The index list is a flat list of integers.
        idx = [lax.concatenate([lax.reshape(e, (1,)) for e in idx], 0)]
    else:
      # The indexer is just a single integer array.
      idx = [idx]

    flat_idx = tuple(mod(ravel(x), arr.shape[i]) for i, x in enumerate(idx))
    out = lax.index_take(arr, flat_idx, tuple(range(len(idx))))
    return lax.reshape(out, idx[0].shape + _shape(arr)[len(idx):])

  # Handle integer array indexing *with* ellipsis/slices/nones by recursing once
  # https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html#combining-advanced-and-basic-indexing
  elif _is_advanced_int_indexer(idx):
    canonical_idx = _canonicalize_tuple_index(arr, tuple(idx))
    idx_noadvanced = [slice(None) if _is_int(e) else e for e in canonical_idx]
    arr_sliced = _rewriting_take(arr, tuple(idx_noadvanced))

    advanced_pairs = ((e, i) for i, e in enumerate(canonical_idx) if _is_int(e))
    idx_advanced, axes = zip(*advanced_pairs)
    idx_advanced = broadcast_arrays(*idx_advanced)

    flat_idx = tuple(mod(ravel(x), arr_sliced.shape[i])
                     for i, x in zip(axes, idx_advanced))
    out = lax.index_take(arr_sliced, flat_idx, axes)
    shape_suffix = tuple(onp.delete(_shape(arr_sliced), axes))
    out = lax.reshape(out, idx_advanced[0].shape + shape_suffix)

    axes_are_contiguous = onp.all(onp.diff(axes) == 1)
    if axes_are_contiguous:
      start = axes[0]
      naxes = idx_advanced[0].ndim
      out = moveaxis(out, list(range(naxes)), list(range(start, start + naxes)))
    return out

  msg = "Indexing mode not yet supported. Open a feature request!\n{}"
  raise IndexError(msg.format(idx))
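# Hedged example (not from the original source): the static-slice branch above
# is why a negative-step NumPy-style slice on a JAX array lowers to a strided
# lax.slice_in_dim followed by lax.rev on that axis.
#
#   import jax.numpy as jnp
#   from jax import lax
#
#   x = jnp.arange(5.)
#   print(x[::-1])                                      # [4. 3. 2. 1. 0.]
#   print(lax.rev(lax.slice_in_dim(x, 0, 5, 1), (0,)))  # same result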
def _dct_interleave(x, axis):
  v0 = lax.slice_in_dim(x, None, None, 2, axis)
  v1 = lax.rev(lax.slice_in_dim(x, 1, None, 2, axis), (axis,))
  return lax.concatenate([v0, v1], axis)
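# Hedged illustration (not from the original source): _dct_interleave takes
# the even-indexed entries in order, then the odd-indexed entries reversed,
# the permutation used in FFT-based DCT-II constructions.
#
#   import jax.numpy as jnp
#
#   x = jnp.arange(6)
#   print(_dct_interleave(x, 0))   # [0 2 4 5 3 1]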
def _rev_j(eqn: JaxprEqn, idx: int, invals: List[ShapedArray],
           cts_in: ShapedArray) -> np.ndarray:
  inval = invals[idx]
  j = _eye_like(cts_in, inval)
  j = lax.rev(j, eqn.params['dimensions'])
  return j
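# Hedged check (not from the original source): the Jacobian of lax.rev is an
# identity matrix with its rows reversed along the flipped dimensions, which
# is what _rev_j builds by applying lax.rev to an identity-like array.
#
#   import jax
#   import jax.numpy as jnp
#   from jax import lax
#
#   jac = jax.jacfwd(lambda x: lax.rev(x, (0,)))(jnp.zeros(3))
#   print(jac)   # [[0. 0. 1.], [0. 1. 0.], [1. 0. 0.]]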