Esempio n. 1
0
def backwards_recursion(C, c, F, f):
    def backwards_step(carry, params):
        V_tp1, v_tp1 = carry
        (C_t, c_t, F_t, f_t) = params
        n = f_t.shape[0]

        FTV_tp1 = F_t.T @ V_tp1
        Q_t = C_t + FTV_tp1 @ F_t
        q_t = c_t + FTV_tp1 @ f_t + F_t.T @ v_tp1

        Q_x, Q_u = np.split(Q_t, (n, ))
        Q_xx, Q_xu = np.split(Q_x, (n, ), axis=1)
        Q_ux, Q_uu = np.split(Q_u, (n, ), axis=1)
        q_x, q_u = np.split(q_t, (n, ))

        Quxqu = np.concatenate((Q_ux, q_u[:, None]), axis=1)
        Ktkt = -jax.scipy.linalg.solve(Q_uu, Quxqu, sym_pos=True)
        K_t, k_t = np.split(Ktkt, (n, ), axis=1)
        k_t = k_t[:, 0]

        KTQuu = K_t.T @ Q_uu
        V_t = Q_xx + Q_xu @ K_t + K_t.T @ Q_ux + KTQuu @ K_t
        v_t = q_x + Q_xu @ k_t + K_t.T @ q_u + KTQuu @ k_t
        return (V_t, v_t), (K_t, k_t)

    V_T = np.zeros((f.shape[1], f.shape[1]))
    v_T = np.zeros((f.shape[1], ))
    K, kvec = lax.scan(backwards_step, (V_T, v_T), (C, c, F, f))[1]
    return lax.rev(K, (0, )), lax.rev(kvec, (0, ))
Esempio n. 2
0
def svd(a,
        full_matrices: bool = True,
        compute_uv: bool = True,
        hermitian: bool = False):
    a, = _promote_dtypes_inexact(jnp.asarray(a))
    if hermitian:
        w, v = lax_linalg.eigh(a)
        s = lax.abs(v)
        if compute_uv:
            sign = lax.sign(v)
            idxs = lax.broadcasted_iota(np.int64,
                                        s.shape,
                                        dimension=s.ndim - 1)
            s, idxs, sign = lax.sort((s, idxs, sign), dimension=-1, num_keys=1)
            s = lax.rev(s, dimensions=[s.ndim - 1])
            idxs = lax.rev(idxs, dimensions=[s.ndim - 1])
            sign = lax.rev(sign, dimensions=[s.ndim - 1])
            u = jnp.take_along_axis(w, idxs[..., None, :], axis=-1)
            vh = _H(u * sign[..., None, :])
            return u, s, vh
        else:
            return lax.rev(lax.sort(s, dimension=-1), dimensions=[s.ndim - 1])

    return lax_linalg.svd(a,
                          full_matrices=full_matrices,
                          compute_uv=compute_uv)
Esempio n. 3
0
def rev_dependency_rule(outstart, outcount, operand, dimensions):
    instart = [
        size - (start + outsize) if d in dimensions else start
        for d, (
            size, outsize,
            start) in enumerate(zip(operand.shape, outcount.shape, outstart))
    ]
    return ([(instart, outcount.shape)], [
        Ones(outcount.shape) if is_ones(outcount) else lax.rev(
            outcount, dimensions)
    ], lambda inslice: lax.rev(inslice, dimensions))
Esempio n. 4
0
def _scatter_impl(x, y, scatter_op, treedef, static_idx, dynamic_idx,
                  indices_are_sorted, unique_indices, normalize_indices):
    dtype = lax.dtype(x)
    x, y = jnp._promote_dtypes(x, y)

    idx = jnp._merge_static_and_dynamic_indices(treedef, static_idx,
                                                dynamic_idx)
    indexer = jnp._index_to_gather(jnp.shape(x),
                                   idx,
                                   normalize_indices=normalize_indices)

    # Broadcast `y` to the slice output shape.
    y = jnp.broadcast_to(y, tuple(indexer.slice_shape))
    # Collapse any `None`/`jnp.newaxis` dimensions.
    y = jnp.squeeze(y, axis=indexer.newaxis_dims)
    if indexer.reversed_y_dims:
        y = lax.rev(y, indexer.reversed_y_dims)

    # Transpose the gather dimensions into scatter dimensions (cf.
    # lax._gather_transpose_rule)
    dnums = lax.ScatterDimensionNumbers(
        update_window_dims=indexer.dnums.offset_dims,
        inserted_window_dims=indexer.dnums.collapsed_slice_dims,
        scatter_dims_to_operand_dims=indexer.dnums.start_index_map)
    out = scatter_op(x,
                     indexer.gather_indices,
                     y,
                     dnums,
                     indices_are_sorted=indices_are_sorted,
                     unique_indices=unique_indices)
    return lax.convert_element_type(out, dtype)
Esempio n. 5
0
  def testReverseGrad(self):
    rev = lambda operand: lax.rev(operand, dimensions)

    dimensions = [0]
    check_grads(rev, (onp.array([3., 2., 1.]),), 2)

    dimensions = [0, 1]
    check_grads(rev, (onp.array([[6., 5., 4.], [3., 2., 1.]]),), 2,
                rtol={onp.float32: 3e-3})
Esempio n. 6
0
  def testRevLax(self):
    fun = lambda x: lax.rev(x, [0])
    R = np.random.RandomState(0).randn
    x = R(2, 3)

    ans = vmap(fun)(x)
    expected_ans = x[:, ::-1]
    self.assertAllClose(ans, expected_ans, check_dtypes=False)

    ans = vmap(fun, (1,), 1)(x)
    expected_ans = x[::-1, :]
    self.assertAllClose(ans, expected_ans, check_dtypes=False)
Esempio n. 7
0
def _matrix_put(ndarray, idx, val, block_size=1):
    """Similar to numpy.put using LAX operations."""
    idx_i, idx_j = idx
    sli, row_rev = _canonical_idx(ndarray.shape, idx_i, -2, block_size)
    slj, col_rev = _canonical_idx(ndarray.shape, idx_j, -1, block_size)
    if not sli.step == slj.step == 1:
        raise TypeError("Non-unit step not supported in assigment.")

    if row_rev or col_rev:
        val = lax.rev(val, *onp.where([row_rev, col_rev]))

    start_indices = [0] * (ndarray.ndim - 2) + [sli.start, slj.start]
    return lax.dynamic_update_slice(ndarray, val, start_indices)
Esempio n. 8
0
def _matrix_take(ndarray, idx, block_size=1):
    """Similar to numpy.take using LAX operations."""
    idx_i, idx_j = idx
    sli, row_rev = _canonical_idx(ndarray.shape, idx_i, -2, block_size)
    slj, col_rev = _canonical_idx(ndarray.shape, idx_j, -1, block_size)

    start_indices = [0] * (ndarray.ndim - 2) + [sli.start, slj.start]
    limit_indices = list(ndarray.shape[:-2]) + [sli.stop, slj.stop]
    strides = [1] * (ndarray.ndim - 2) + [sli.step, slj.step]
    out = lax.slice(ndarray, start_indices, limit_indices, strides)

    if row_rev or col_rev:
        out = lax.rev(out, *onp.where([row_rev, col_rev]))
    return out
Esempio n. 9
0
def _scatter_impl(x, y, scatter_op, treedef, static_idx, dynamic_idx,
                  indices_are_sorted, unique_indices, mode, normalize_indices):
    dtype = lax.dtype(x)
    weak_type = dtypes.is_weakly_typed(x)

    if dtype != dtypes.result_type(x, y):
        # TODO(jakevdp): change this to an error after the deprecation period.
        warnings.warn(
            "scatter inputs have incompatible types: cannot safely cast "
            f"value from dtype={lax.dtype(y)} to dtype={lax.dtype(x)}. "
            "In future JAX releases this will result in an error.",
            FutureWarning)

    idx = jnp._merge_static_and_dynamic_indices(treedef, static_idx,
                                                dynamic_idx)
    indexer = jnp._index_to_gather(jnp.shape(x),
                                   idx,
                                   normalize_indices=normalize_indices)

    # Avoid calling scatter if the slice shape is empty, both as a fast path and
    # to handle cases like zeros(0)[array([], int32)].
    if core.is_empty_shape(indexer.slice_shape):
        return x

    x, y = jnp._promote_dtypes(x, y)

    # Broadcast `y` to the slice output shape.
    y = jnp.broadcast_to(y, tuple(indexer.slice_shape))
    # Collapse any `None`/`jnp.newaxis` dimensions.
    y = jnp.squeeze(y, axis=indexer.newaxis_dims)
    if indexer.reversed_y_dims:
        y = lax.rev(y, indexer.reversed_y_dims)

    # Transpose the gather dimensions into scatter dimensions (cf.
    # lax._gather_transpose_rule)
    dnums = lax.ScatterDimensionNumbers(
        update_window_dims=indexer.dnums.offset_dims,
        inserted_window_dims=indexer.dnums.collapsed_slice_dims,
        scatter_dims_to_operand_dims=indexer.dnums.start_index_map)
    out = scatter_op(x,
                     indexer.gather_indices,
                     y,
                     dnums,
                     indices_are_sorted=indexer.indices_are_sorted
                     or indices_are_sorted,
                     unique_indices=indexer.unique_indices or unique_indices,
                     mode=mode)
    return lax_internal._convert_element_type(out, dtype, weak_type)
Esempio n. 10
0
def _scatter_impl(x, y, scatter_op, treedef, static_idx, dynamic_idx,
                  indices_are_sorted, unique_indices, mode, normalize_indices):
    dtype = lax.dtype(x)
    weak_type = dtypes.is_weakly_typed(x)

    idx = jnp._merge_static_and_dynamic_indices(treedef, static_idx,
                                                dynamic_idx)
    indexer = jnp._index_to_gather(jnp.shape(x),
                                   idx,
                                   normalize_indices=normalize_indices)

    # Avoid calling scatter if the slice shape is empty, both as a fast path and
    # to handle cases like zeros(0)[array([], int32)].
    if core.is_empty_shape(indexer.slice_shape):
        return x

    x, y = jnp._promote_dtypes(x, y)

    # Broadcast `y` to the slice output shape.
    y = jnp.broadcast_to(y, tuple(indexer.slice_shape))
    # Collapse any `None`/`jnp.newaxis` dimensions.
    y = jnp.squeeze(y, axis=indexer.newaxis_dims)
    if indexer.reversed_y_dims:
        y = lax.rev(y, indexer.reversed_y_dims)

    # Transpose the gather dimensions into scatter dimensions (cf.
    # lax._gather_transpose_rule)
    dnums = lax.ScatterDimensionNumbers(
        update_window_dims=indexer.dnums.offset_dims,
        inserted_window_dims=indexer.dnums.collapsed_slice_dims,
        scatter_dims_to_operand_dims=indexer.dnums.start_index_map)
    out = scatter_op(x,
                     indexer.gather_indices,
                     y,
                     dnums,
                     indices_are_sorted=indexer.indices_are_sorted
                     or indices_are_sorted,
                     unique_indices=indexer.unique_indices or unique_indices,
                     mode=mode)
    return lax._convert_element_type(out, dtype, weak_type)
Esempio n. 11
0
 def rev2(x):
   return lax.rev(x, (1,))
Esempio n. 12
0
 def rev1(x):
   return lax.rev(x, (0,))
Esempio n. 13
0
def test_rev(shape, dtype, dimensions, rng_factory):
    rng = rng_factory(np.random)
    arg = rng(shape, dtype)
    tu.check_lazy_fun(lambda x: lax.rev(x, dimensions=dimensions), arg)
Esempio n. 14
0
def _rewriting_take(arr, idx, axis=0):
    """A function like numpy.take that handles boxes and rewrites to LAX."""

    # Handle special indexers: (), Ellipsis, slice(None), and None.
    # TODO(mattjj): don't compare empty tuple identity (though works for CPython)
    if idx is () or idx is Ellipsis or _is_slice_none(idx):  # pylint: disable=literal-comparison
        return arr
    elif idx is None:
        return expand_dims(arr, 0)

    # Handle int index
    _int = lambda aval: not aval.shape and onp.issubdtype(
        aval.dtype, onp.integer)
    try:
        abstract_idx = core.get_aval(idx)
    except TypeError:
        abstract_idx = None

    if isinstance(abstract_idx, ConcreteArray) and _int(abstract_idx):
        return lax.index_in_dim(arr, idx, axis, False)
    elif isinstance(abstract_idx, ShapedArray) and _int(abstract_idx):
        idx = mod(idx, arr.shape[axis])
        return lax.dynamic_index_in_dim(arr, idx, axis, False)

    # Handle slice index (only static, otherwise an error is raised)
    elif isinstance(idx, slice):
        if not _all(
                elt is None or isinstance(core.get_aval(elt), ConcreteArray)
                for elt in (idx.start, idx.stop, idx.step)):
            msg = (
                "Array slice indices must have static start/stop/step to be used "
                "with Numpy indexing syntax. Try lax.dynamic_slice instead.")
            raise IndexError(msg)
        else:
            start, limit, stride, needs_rev = _static_idx(idx, arr.shape[axis])
            result = lax.slice_in_dim(arr, start, limit, stride, axis=axis)
            return lax.rev(result, [axis]) if needs_rev else result

    # Handle non-advanced tuple indices by recursing once
    elif isinstance(idx, tuple) and _all(onp.ndim(elt) == 0 for elt in idx):
        canonical_idx = _canonicalize_tuple_index(arr, idx)
        result, axis = arr, 0
        for elt in (elt for elt in canonical_idx if elt is not None):
            result = _rewriting_take(result, elt, axis=axis)
            axis += isinstance(elt,
                               slice)  # advance axis index if not eliminated
        unexpanded_shape_itr = iter(result.shape)
        result_shape = tuple(1 if elt is None else next(unexpanded_shape_itr)
                             for elt in canonical_idx
                             if not isinstance(elt, int))
        return lax.reshape(result, result_shape)

    # Handle advanced indexing (non-tuple sequence, ndarray of dtype int or bool,
    # or a tuple with at least one sequence object).
    # https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html#advanced-indexing
    # https://gist.github.com/seberg/976373b6a2b7c4188591

    # Handle integer array indexing *without* ellipsis/slices/nones
    # https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html#integer-array-indexing
    if _is_advanced_int_indexer_without_slices(idx):
        if isinstance(idx, list):
            if _any(_shape(e) for e in idx):
                # At least one sequence element in the index list means broadcasting.
                idx = broadcast_arrays(*idx)
            else:
                # The index list is a flat list of integers.
                idx = [
                    lax.concatenate([lax.reshape(e, (1, )) for e in idx], 0)
                ]
        else:
            # The indexer is just a single integer array.
            idx = [idx]

        flat_idx = tuple(
            mod(ravel(x), arr.shape[i]) for i, x in enumerate(idx))
        out = lax.index_take(arr, flat_idx, tuple(range(len(idx))))
        return lax.reshape(out, idx[0].shape + _shape(arr)[len(idx):])

    # Handle integer array indexing *with* ellipsis/slices/nones by recursing once
    # https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html#combining-advanced-and-basic-indexing
    elif _is_advanced_int_indexer(idx):
        canonical_idx = _canonicalize_tuple_index(arr, tuple(idx))
        idx_noadvanced = [
            slice(None) if _is_int(e) else e for e in canonical_idx
        ]
        arr_sliced = _rewriting_take(arr, tuple(idx_noadvanced))

        advanced_pairs = ((e, i) for i, e in enumerate(canonical_idx)
                          if _is_int(e))
        idx_advanced, axes = zip(*advanced_pairs)
        idx_advanced = broadcast_arrays(*idx_advanced)

        flat_idx = tuple(
            mod(ravel(x), arr_sliced.shape[i])
            for i, x in zip(axes, idx_advanced))
        out = lax.index_take(arr_sliced, flat_idx, axes)
        shape_suffix = tuple(onp.delete(_shape(arr_sliced), axes))
        out = lax.reshape(out, idx_advanced[0].shape + shape_suffix)

        axes_are_contiguous = onp.all(onp.diff(axes) == 1)
        if axes_are_contiguous:
            start = axes[0]
            naxes = idx_advanced[0].ndim
            out = moveaxis(out, list(range(naxes)),
                           list(range(start, start + naxes)))
        return out

    msg = "Indexing mode not yet supported. Open a feature request!\n{}"
    raise IndexError(msg.format(idx))
Esempio n. 15
0
def _dct_interleave(x, axis):
    v0 = lax.slice_in_dim(x, None, None, 2, axis)
    v1 = lax.rev(lax.slice_in_dim(x, 1, None, 2, axis), (axis, ))
    return lax.concatenate([v0, v1], axis)
Esempio n. 16
0
def _rev_j(eqn: JaxprEqn, idx: int, invals: List[ShapedArray],
           cts_in: ShapedArray) -> np.ndarray:
    inval = invals[idx]
    j = _eye_like(cts_in, inval)
    j = lax.rev(j, eqn.params['dimensions'])
    return j