Beispiel #1
0
def _scatter_impl(x, y, scatter_op, treedef, static_idx, dynamic_idx,
                  indices_are_sorted, unique_indices, normalize_indices):
    """Scatter ``y`` into ``x`` at a NumPy-style index using ``scatter_op``.

    The flattened index is rebuilt, lowered to a gather description, and that
    description is transposed into scatter dimension numbers.

    Args:
      x: the operand array being updated.
      y: the update values; broadcast to the shape of the indexed slice.
      scatter_op: scatter wrapper, called as ``scatter_op(operand, indices,
        updates, dnums, indices_are_sorted=..., unique_indices=...)``.
      treedef, static_idx, dynamic_idx: pieces of the flattened index, merged
        back together with ``jnp._merge_static_and_dynamic_indices``.
      indices_are_sorted, unique_indices: hints forwarded to ``scatter_op``.
      normalize_indices: forwarded to ``jnp._index_to_gather``.

    Returns:
      The scattered result, converted back to the original dtype of ``x``.
    """
    result_dtype = lax.dtype(x)
    operand, updates = jnp._promote_dtypes(x, y)

    full_index = jnp._merge_static_and_dynamic_indices(
        treedef, static_idx, dynamic_idx)
    gather_plan = jnp._index_to_gather(
        jnp.shape(operand), full_index, normalize_indices=normalize_indices)

    # Shape the updates like the selected slice: broadcast to the slice shape,
    # drop the `None`/`jnp.newaxis` axes, then undo any reversed dimensions.
    updates = jnp.squeeze(
        jnp.broadcast_to(updates, tuple(gather_plan.slice_shape)),
        axis=gather_plan.newaxis_dims)
    flipped_dims = gather_plan.reversed_y_dims
    if flipped_dims:
        updates = lax.rev(updates, flipped_dims)

    # Gather dimension numbers map one-to-one onto scatter dimension numbers
    # (cf. lax._gather_transpose_rule).
    gather_dnums = gather_plan.dnums
    scatter_dnums = lax.ScatterDimensionNumbers(
        update_window_dims=gather_dnums.offset_dims,
        inserted_window_dims=gather_dnums.collapsed_slice_dims,
        scatter_dims_to_operand_dims=gather_dnums.start_index_map)

    scattered = scatter_op(operand, gather_plan.gather_indices, updates,
                           scatter_dnums,
                           indices_are_sorted=indices_are_sorted,
                           unique_indices=unique_indices)
    return lax.convert_element_type(scattered, result_dtype)
Beispiel #2
0
def _lu_jvp_rule(primals, tangents):
    """JVP (forward-mode derivative) rule for the ``lu_p`` primitive.

    Args:
      primals: 1-tuple ``(a,)`` with the matrix, or batch of matrices, being
        factored; its last two axes have sizes ``m`` and ``n``.
      tangents: 1-tuple ``(a_dot,)`` with the tangent of ``a``.

    Returns:
      ``(primal_out, tangent_out)`` where ``primal_out`` is
      ``(lu, pivots, permutation)`` from ``lu_p.bind(a)`` and ``tangent_out``
      is ``(lu_dot, Zero, Zero)``: the pivot and permutation outputs are index
      arrays and get symbolic-zero tangents (``ad_util.Zero``).
    """
    a, = primals
    a_dot, = tangents
    lu, pivots, permutation = lu_p.bind(a)

    a_shape = jnp.shape(a)
    m, n = a_shape[-2:]
    dtype = lax.dtype(a)
    k = min(m, n)

    # Apply the factorization's row permutation to the tangent for every batch
    # element: x[..., i, :] = a_dot[..., permutation[..., i], :].
    batch_dims = a_shape[:-2]
    iotas = jnp.ix_(*(lax.iota(jnp.int32, b) for b in batch_dims + (1, )))
    x = a_dot[iotas[:-1] + (permutation, slice(None))]

    # Differentiation of Matrix Functionals Using Triangular Factorization
    # F. R. De Hoog, R. S. Anderssen, and M. A. Lukas
    #
    #     LU = A
    # ==> L'U + LU' = A'
    # ==> inv(L) . L' + U' . inv(U) = inv(L) A' inv(U)
    # ==> L' = L . tril(inv(L) . A' . inv(U), -1)
    #     U' = triu(inv(L) . A' . inv(U)) . U

    # L: strictly-lower part of the first k columns of `lu`, zero-padded on the
    # right to m columns, plus the identity -> (m, m) unit lower triangular.
    ndims = len(a_shape)
    l_padding = [(0, 0, 0)] * ndims
    l_padding[-1] = (0, m - k, 0)
    zero = jnp._constant_like(lu, 0)
    l = lax.pad(jnp.tril(lu[..., :, :k], -1), zero, l_padding)
    l = l + jnp.eye(m, m, dtype=dtype)

    # U: upper part of the first k rows of `lu`, zero-padded below to n rows,
    # plus an identity in the trailing (n - k, n - k) block -> (n, n) upper
    # triangular whose trailing diagonal entries are 1 rather than 0.
    u_eye = lax.pad(jnp.eye(n - k, n - k, dtype=dtype), zero,
                    ((k, 0, 0), (k, 0, 0)))
    u_padding = [(0, 0, 0)] * ndims
    u_padding[-2] = (0, n - k, 0)
    u = lax.pad(jnp.triu(lu[..., :k, :]), zero, u_padding) + u_eye

    # lau = inv(L) . x . inv(U), computed with two triangular solves instead
    # of explicit inverses.
    la = triangular_solve(l,
                          x,
                          left_side=True,
                          transpose_a=False,
                          lower=True,
                          unit_diagonal=True)
    lau = triangular_solve(u,
                           la,
                           left_side=False,
                           transpose_a=False,
                           lower=False)

    # Split inv(L) A' inv(U) into strictly-lower and upper parts, per the
    # derivation above; both tangent factors share the packed `lu` layout.
    l_dot = jnp.matmul(l, jnp.tril(lau, -1))
    u_dot = jnp.matmul(jnp.triu(lau), u)
    lu_dot = l_dot + u_dot
    return (lu, pivots, permutation), (lu_dot, ad_util.Zero.from_value(pivots),
                                       ad_util.Zero.from_value(permutation))
Beispiel #3
0
def det(a):
    """Compute the determinant of a square matrix or a batch of matrices.

    2x2 and 3x3 inputs use closed-form helpers. Other square sizes are
    computed as ``sign * exp(logdet)`` from ``slogdet``.

    Args:
      a: array of shape ``[..., n, n]``.

    Returns:
      Array of shape ``[...]`` with the determinant of each matrix.

    Raises:
      ValueError: if ``a`` has fewer than two dimensions or its trailing two
        dimensions differ.
    """
    a = _promote_arg_dtypes(jnp.asarray(a))
    a_shape = jnp.shape(a)
    if len(a_shape) < 2 or a_shape[-1] != a_shape[-2]:
        # Name the public function the caller actually invoked.
        msg = "Argument to det() must have shape [..., n, n], got {}"
        raise ValueError(msg.format(a_shape))
    n = a_shape[-1]
    if n == 2:
        return _det_2x2(a)
    if n == 3:
        return _det_3x3(a)
    sign, logdet = slogdet(a)
    return sign * jnp.exp(logdet)
Beispiel #4
0
def _lu(a, permute_l):
    """LU-decompose a 2-D matrix with partial pivoting, ``a = p @ l @ u``.

    Args:
      a: 2-D array of shape ``(m, n)``.
      permute_l: if true, return ``(p @ l, u)``; otherwise ``(p, l, u)``.

    Returns:
      With ``k = min(m, n)``:

      * ``p``: the ``(m, m)`` 0/1 permutation matrix, stored in the real part
        of ``a``'s dtype.
      * ``l``: ``(m, k)`` with ones on the diagonal and zeros above it.
      * ``u``: ``(k, n)`` upper triangular.
    """
    a = np_linalg._promote_arg_dtypes(jnp.asarray(a))
    lu, _, permutation = lax_linalg.lu(a)
    dtype = lax.dtype(a)
    m, n = jnp.shape(a)
    # p[i, j] = 1 exactly when permutation[j] == i. Building the row ids in the
    # permutation's own dtype avoids an int32/int64 promotion in x64 mode
    # (matching the newer `_lu` implementation).
    row_ids = jnp.arange(m, dtype=permutation.dtype)
    p = jnp.real(
        jnp.array(permutation[None, :] == row_ids[:, None], dtype=dtype))
    k = min(m, n)
    l = jnp.tril(lu, -1)[:, :k] + jnp.eye(m, k, dtype=dtype)
    u = jnp.triu(lu)[:k, :]
    if permute_l:
        return jnp.matmul(p, l), u
    return p, l, u
Beispiel #5
0
def slogdet(a, *, method: Optional[str] = None):
  """Sign and log-magnitude of the determinant of a square (batched) matrix.

  Args:
    a: array of shape ``[..., n, n]``.
    method: ``None`` or ``"lu"`` dispatches to ``_slogdet_lu``; ``"qr"``
      dispatches to ``_slogdet_qr``.

  Returns:
    Whatever the selected helper returns. Presumably a ``(sign, logabsdet)``
    pair; TODO confirm against ``_slogdet_lu``/``_slogdet_qr``.

  Raises:
    ValueError: if ``a`` is not at least 2-D with equal trailing dimensions,
      or if ``method`` is not one of the supported names.
  """
  a, = _promote_dtypes_inexact(jnp.asarray(a))
  a_shape = jnp.shape(a)
  is_square = len(a_shape) >= 2 and a_shape[-1] == a_shape[-2]
  if not is_square:
    msg = "Argument to slogdet() must have shape [..., n, n], got {}"
    raise ValueError(msg.format(a_shape))

  # Choose the implementation first, then invoke it once.
  if method is None or method == "lu":
    impl = _slogdet_lu
  elif method == "qr":
    impl = _slogdet_qr
  else:
    raise ValueError(f"Unknown slogdet method '{method}'. Supported methods "
                     "are 'lu' (`None`), and 'qr'.")
  return impl(a)
Beispiel #6
0
def _scatter_impl(x, y, scatter_op, treedef, static_idx, dynamic_idx,
                  indices_are_sorted, unique_indices, mode, normalize_indices):
    """Scatter ``y`` into ``x`` at a NumPy-style index using ``scatter_op``.

    Args:
      x: operand array being updated; the result keeps its dtype and weak type.
      y: update values, broadcast to the shape of the indexed slice.
      scatter_op: scatter wrapper, called with the promoted operand, the gather
        indices, the prepared updates, scatter dimension numbers and the
        ``indices_are_sorted`` / ``unique_indices`` / ``mode`` keywords.
      treedef, static_idx, dynamic_idx: pieces of the flattened index, merged
        back together with ``jnp._merge_static_and_dynamic_indices``.
      indices_are_sorted, unique_indices: caller hints, OR-ed with the hints
        the indexer derives itself.
      mode: forwarded unchanged to ``scatter_op``.
      normalize_indices: forwarded to ``jnp._index_to_gather``.

    Returns:
      The updated array converted back to ``x``'s dtype and weak type, or
      ``x`` itself when the indexed slice is empty.
    """
    dtype = lax.dtype(x)
    weak_type = dtypes.is_weakly_typed(x)

    # Warn when promoting `x` with `y` would give a dtype other than `x`'s.
    # The result is cast back to `x`'s dtype below regardless.
    if dtype != dtypes.result_type(x, y):
        # TODO(jakevdp): change this to an error after the deprecation period.
        warnings.warn(
            "scatter inputs have incompatible types: cannot safely cast "
            f"value from dtype={lax.dtype(y)} to dtype={lax.dtype(x)}. "
            "In future JAX releases this will result in an error.",
            FutureWarning)

    idx = jnp._merge_static_and_dynamic_indices(treedef, static_idx,
                                                dynamic_idx)
    indexer = jnp._index_to_gather(jnp.shape(x),
                                   idx,
                                   normalize_indices=normalize_indices)

    # Avoid calling scatter if the slice shape is empty, both as a fast path and
    # to handle cases like zeros(0)[array([], int32)].
    if core.is_empty_shape(indexer.slice_shape):
        return x

    x, y = jnp._promote_dtypes(x, y)

    # Broadcast `y` to the slice output shape.
    y = jnp.broadcast_to(y, tuple(indexer.slice_shape))
    # Collapse any `None`/`jnp.newaxis` dimensions.
    y = jnp.squeeze(y, axis=indexer.newaxis_dims)
    if indexer.reversed_y_dims:
        y = lax.rev(y, indexer.reversed_y_dims)

    # Transpose the gather dimensions into scatter dimensions (cf.
    # lax._gather_transpose_rule)
    dnums = lax.ScatterDimensionNumbers(
        update_window_dims=indexer.dnums.offset_dims,
        inserted_window_dims=indexer.dnums.collapsed_slice_dims,
        scatter_dims_to_operand_dims=indexer.dnums.start_index_map)
    # Sortedness/uniqueness hints hold if either the indexer or the caller
    # asserts them.
    out = scatter_op(x,
                     indexer.gather_indices,
                     y,
                     dnums,
                     indices_are_sorted=indexer.indices_are_sorted
                     or indices_are_sorted,
                     unique_indices=indexer.unique_indices or unique_indices,
                     mode=mode)
    # Restore the original dtype (and weak-typedness) of `x`.
    return lax_internal._convert_element_type(out, dtype, weak_type)
Beispiel #7
0
def _lu(a, permute_l):
    """Pivoted LU factorisation of a 2-D matrix as explicit ``p, l, u`` factors.

    Args:
      a: 2-D array of shape ``(m, n)``.
      permute_l: when true, fold the permutation into ``l`` and return
        ``(p @ l, u)``; otherwise return ``(p, l, u)``.

    Returns:
      With ``k = min(m, n)``: ``p`` is the ``(m, m)`` 0/1 permutation matrix
      (real part of ``a``'s dtype), ``l`` is ``(m, k)`` with a unit diagonal,
      and ``u`` is ``(k, n)`` upper triangular.
    """
    a, = _promote_dtypes_inexact(jnp.asarray(a))
    lu, _, permutation = lax_linalg.lu(a)
    dtype = lax.dtype(a)
    m, n = jnp.shape(a)
    k = min(m, n)

    # Entry [i, j] is set exactly when permutation[j] == i.
    row_ids = jnp.arange(m, dtype=permutation.dtype)
    hits = permutation[None, :] == row_ids[:, None]
    p = jnp.real(jnp.array(hits, dtype=dtype))

    lower = jnp.tril(lu, -1)[:, :k] + jnp.eye(m, k, dtype=dtype)
    upper = jnp.triu(lu)[:k, :]
    if not permute_l:
        return p, lower, upper
    return jnp.matmul(p, lower), upper
Beispiel #8
0
def slogdet(a):
    """Sign and natural log of ``|det(a)|`` from a pivoted LU factorisation.

    Args:
      a: array of shape ``[..., n, n]``.

    Returns:
      ``(sign, logdet)``. ``sign`` is 0 for singular inputs, ±1 for real
      inputs, and a unit-modulus phase for complex inputs. ``logdet`` is real
      and ``-inf`` for singular inputs.

    Raises:
      ValueError: if ``a`` is not at least 2-D with equal trailing dimensions.
    """
    a = _promote_arg_dtypes(jnp.asarray(a))
    dtype = lax.dtype(a)
    a_shape = jnp.shape(a)
    if not (len(a_shape) >= 2 and a_shape[-1] == a_shape[-2]):
        msg = "Argument to slogdet() must have shape [..., n, n], got {}"
        raise ValueError(msg.format(a_shape))

    lu, pivot, _ = lax_linalg.lu(a)
    diag = jnp.diagonal(lu, axis1=-2, axis2=-1)
    singular = jnp.any(diag == jnp.array(0, dtype=dtype), axis=-1)

    # Every pivot entry that differs from its own index is a row interchange,
    # and each interchange flips the determinant's sign.
    flips = jnp.count_nonzero(pivot != jnp.arange(a_shape[-1]), axis=-1)
    if jnp.iscomplexobj(a):
        phase = jnp.prod(diag / jnp.abs(diag), axis=-1)
    else:
        # Real case: negative diagonal entries are further sign flips.
        phase = jnp.array(1, dtype=dtype)
        flips = flips + jnp.count_nonzero(diag < 0, axis=-1)

    parity_sign = jnp.array(-2 * (flips % 2) + 1, dtype=dtype)
    sign = jnp.where(singular, jnp.array(0, dtype=dtype), phase * parity_sign)

    log_abs = jnp.sum(jnp.log(jnp.abs(diag)), axis=-1)
    logdet = jnp.where(singular, jnp.array(-jnp.inf, dtype=dtype), log_abs)
    return sign, jnp.real(logdet)
Beispiel #9
0
def _scatter_impl(x, y, scatter_op, treedef, static_idx, dynamic_idx,
                  indices_are_sorted, unique_indices, mode, normalize_indices):
    """Scatter ``y`` into ``x`` at a NumPy-style index using ``scatter_op``.

    Args:
      x: operand array being updated; the result keeps its dtype and weak type.
      y: update values, broadcast to the shape of the indexed slice.
      scatter_op: scatter wrapper, called with the promoted operand, the gather
        indices, the prepared updates, scatter dimension numbers and the
        ``indices_are_sorted`` / ``unique_indices`` / ``mode`` keywords.
      treedef, static_idx, dynamic_idx: pieces of the flattened index, merged
        back together with ``jnp._merge_static_and_dynamic_indices``.
      indices_are_sorted, unique_indices: caller hints, OR-ed with the hints
        the indexer derives itself.
      mode: forwarded unchanged to ``scatter_op``.
      normalize_indices: forwarded to ``jnp._index_to_gather``.

    Returns:
      The updated array converted back to ``x``'s dtype and weak type, or
      ``x`` itself when the indexed slice is empty.
    """
    target_dtype = lax.dtype(x)
    target_is_weak = dtypes.is_weakly_typed(x)

    full_index = jnp._merge_static_and_dynamic_indices(
        treedef, static_idx, dynamic_idx)
    plan = jnp._index_to_gather(
        jnp.shape(x), full_index, normalize_indices=normalize_indices)

    # Nothing to write into an empty slice. Returning early is a fast path and
    # also covers cases like zeros(0)[array([], int32)].
    if core.is_empty_shape(plan.slice_shape):
        return x

    operand, updates = jnp._promote_dtypes(x, y)

    # Shape the updates like the selected slice: broadcast, drop the
    # `None`/`jnp.newaxis` axes, then undo any reversed dimensions.
    updates = jnp.squeeze(
        jnp.broadcast_to(updates, tuple(plan.slice_shape)),
        axis=plan.newaxis_dims)
    if plan.reversed_y_dims:
        updates = lax.rev(updates, plan.reversed_y_dims)

    # Gather dimension numbers map one-to-one onto scatter dimension numbers
    # (cf. lax._gather_transpose_rule).
    gdn = plan.dnums
    scatter_dnums = lax.ScatterDimensionNumbers(
        update_window_dims=gdn.offset_dims,
        inserted_window_dims=gdn.collapsed_slice_dims,
        scatter_dims_to_operand_dims=gdn.start_index_map)

    sorted_hint = plan.indices_are_sorted or indices_are_sorted
    unique_hint = plan.unique_indices or unique_indices
    scattered = scatter_op(operand, plan.gather_indices, updates,
                           scatter_dnums,
                           indices_are_sorted=sorted_hint,
                           unique_indices=unique_hint,
                           mode=mode)
    return lax._convert_element_type(scattered, target_dtype, target_is_weak)
Beispiel #10
0
def norm(x,
         ord=None,
         axis: Union[None, Tuple[int, ...], int] = None,
         keepdims=False):
    """Vector or matrix norm, following ``numpy.linalg.norm`` semantics.

    Args:
      x: input array.
      ord: order of the norm.
        - Vector norms (one axis): ``None``/``2``, ``inf``, ``-inf``, ``0``,
          ``1``, or any other number ``p``.
        - Matrix norms (two axes): ``None``/``'fro'``/``'f'``, ``1``, ``-1``,
          ``inf``, ``-inf``, ``'nuc'``, ``2`` or ``-2``.
      axis: ``None``, an int, or a tuple of one or two ints naming the axes to
        reduce. ``None`` reduces over all axes.
      keepdims: if true, the reduced axes are kept with size one.

    Returns:
      The norm, with the reduced axes removed (or kept if ``keepdims``).

    Raises:
      ValueError: for an order that does not apply to the selected number of
        axes, for repeated axes, or when ``axis`` selects neither one nor two
        dimensions.
    """
    x = _promote_arg_dtypes(jnp.asarray(x))
    x_shape = jnp.shape(x)
    ndim = len(x_shape)

    if axis is None:
        # NumPy has an undocumented behavior that admits arbitrary rank inputs if
        # `ord` is None: https://github.com/numpy/numpy/issues/14215
        if ord is None:
            return jnp.sqrt(
                jnp.sum(jnp.real(x * jnp.conj(x)), keepdims=keepdims))
        axis = tuple(range(ndim))
    elif isinstance(axis, tuple):
        # `ax`, not `x`, so the comprehension does not shadow the input array.
        axis = tuple(canonicalize_axis(ax, ndim) for ax in axis)
    else:
        axis = (canonicalize_axis(axis, ndim), )

    # NumPy rejects repeated axes (e.g. (0, -2) on a 2-D input). Check after
    # canonicalisation so negative aliases are caught too.
    if len(set(axis)) != len(axis):
        raise ValueError("Duplicate axes given.")

    num_axes = len(axis)
    if num_axes == 1:
        if ord is None or ord == 2:
            return jnp.sqrt(
                jnp.sum(jnp.real(x * jnp.conj(x)),
                        axis=axis,
                        keepdims=keepdims))
        elif ord == jnp.inf:
            return jnp.amax(jnp.abs(x), axis=axis, keepdims=keepdims)
        elif ord == -jnp.inf:
            return jnp.amin(jnp.abs(x), axis=axis, keepdims=keepdims)
        elif ord == 0:
            return jnp.sum(x != 0,
                           dtype=jnp.finfo(lax.dtype(x)).dtype,
                           axis=axis,
                           keepdims=keepdims)
        elif ord == 1:
            # Numpy has a special case for ord == 1 as an optimization. We don't
            # really need the optimization (XLA could do it for us), but the Numpy
            # code has slightly different type promotion semantics, so we need a
            # special case too.
            return jnp.sum(jnp.abs(x), axis=axis, keepdims=keepdims)
        elif isinstance(ord, str):
            # String orders ('fro', 'nuc', ...) are matrix-only. Without this
            # check they would reach `lax._const` below and fail obscurely.
            raise ValueError(f"Invalid order '{ord}' for vector norm.")
        else:
            abs_x = jnp.abs(x)
            ord = lax._const(abs_x, ord)
            out = jnp.sum(abs_x**ord, axis=axis, keepdims=keepdims)
            return jnp.power(out, 1. / ord)

    elif num_axes == 2:
        row_axis, col_axis = cast(Tuple[int, ...], axis)
        if ord is None or ord in ('f', 'fro'):
            return jnp.sqrt(
                jnp.sum(jnp.real(x * jnp.conj(x)),
                        axis=axis,
                        keepdims=keepdims))
        elif ord == 1:
            # After reducing row_axis, col_axis shifts down if it came later.
            if not keepdims and col_axis > row_axis:
                col_axis -= 1
            return jnp.amax(jnp.sum(jnp.abs(x),
                                    axis=row_axis,
                                    keepdims=keepdims),
                            axis=col_axis,
                            keepdims=keepdims)
        elif ord == -1:
            if not keepdims and col_axis > row_axis:
                col_axis -= 1
            return jnp.amin(jnp.sum(jnp.abs(x),
                                    axis=row_axis,
                                    keepdims=keepdims),
                            axis=col_axis,
                            keepdims=keepdims)
        elif ord == jnp.inf:
            # After reducing col_axis, row_axis shifts down if it came later.
            if not keepdims and row_axis > col_axis:
                row_axis -= 1
            return jnp.amax(jnp.sum(jnp.abs(x),
                                    axis=col_axis,
                                    keepdims=keepdims),
                            axis=row_axis,
                            keepdims=keepdims)
        elif ord == -jnp.inf:
            if not keepdims and row_axis > col_axis:
                row_axis -= 1
            return jnp.amin(jnp.sum(jnp.abs(x),
                                    axis=col_axis,
                                    keepdims=keepdims),
                            axis=row_axis,
                            keepdims=keepdims)
        elif ord in ('nuc', 2, -2):
            # Singular-value based norms: move the matrix axes last, take the
            # singular values, then reduce them.
            x = jnp.moveaxis(x, axis, (-2, -1))
            if ord == 2:
                reducer = jnp.amax
            elif ord == -2:
                reducer = jnp.amin
            else:
                reducer = jnp.sum
            y = reducer(svd(x, compute_uv=False), axis=-1)
            if keepdims:
                result_shape = list(x_shape)
                result_shape[axis[0]] = 1
                result_shape[axis[1]] = 1
                y = jnp.reshape(y, result_shape)
            return y
        else:
            raise ValueError("Invalid order '{}' for matrix norm.".format(ord))
    else:
        raise ValueError(
            "Invalid axis values ({}) for jnp.linalg.norm.".format(axis))
Beispiel #11
0
def _cofactor_solve(a, b):
    """Equivalent to det(a)*solve(a, b) for nonsingular matrices.

  Intermediate function used for jvp and vjp of det.
  This function borrows heavily from jax.numpy.linalg.solve and
  jax.numpy.linalg.slogdet to compute the gradient of the determinant
  in a way that is well defined even for low rank matrices.

  This function handles two different cases:
  * rank(a) == n or n-1
  * rank(a) < n-1

  For rank n-1 matrices, the gradient of the determinant is a rank 1 matrix.
  Rather than computing det(a)*solve(a, b), which would return NaN, we work
  directly with the LU decomposition. If a = p @ l @ u, then
  det(a)*solve(a, b) =
  prod(diag(u)) * u^-1 @ l^-1 @ p^-1 b =
  prod(diag(u)) * triangular_solve(u, solve(p @ l, b))
  If a is rank n-1, then the lower right corner of u will be zero and the
  triangular_solve will fail.
  Let x = solve(p @ l, b) and y = det(a)*solve(a, b).
  Then y_{n} =
  x_{n} / u_{nn} * prod_{i=1...n}(u_{ii}) =
  x_{n} * prod_{i=1...n-1}(u_{ii})
  So by replacing the lower-right corner of u with prod_{i=1...n-1}(u_{ii})^-1
  we can avoid the triangular_solve failing.
  To correctly compute the rest of y_{i} for i != n, we simply multiply
  x_{i} by det(a) for all i != n, which will be zero if rank(a) = n-1.

  For the second case, a check is done on the matrix to see if `solve`
  returns NaN or Inf, and gives a matrix of zeros as a result, as the
  gradient of the determinant of a matrix with rank less than n-1 is 0.
  This will still return the correct value for rank n-1 matrices, as the check
  is applied *after* the lower right corner of u has been updated.

  Args:
    a: A square matrix or batch of matrices, possibly singular.
    b: A matrix, or batch of matrices of the same dimension as a.

  Returns:
    det(a) and cofactor(a)^T*b, aka adjugate(a)*b

  Raises:
    ValueError: if a is not [..., m, m] or b's trailing dims differ from a's.
  """
    a = _promote_arg_dtypes(jnp.asarray(a))
    b = _promote_arg_dtypes(jnp.asarray(b))
    a_shape = jnp.shape(a)
    b_shape = jnp.shape(b)
    a_ndims = len(a_shape)
    if not (a_ndims >= 2 and a_shape[-1] == a_shape[-2]
            and b_shape[-2:] == a_shape[-2:]):
        msg = ("The arguments to _cofactor_solve must have shapes "
               "a=[..., m, m] and b=[..., m, m]; got a={} and b={}")
        raise ValueError(msg.format(a_shape, b_shape))
    # 1x1 case: det(a) = a[0, 0] and the adjugate is [[1]], so adjugate(a)*b = b.
    if a_shape[-1] == 1:
        return a[..., 0, 0], b
    # lu contains u in the upper triangular matrix and l in the strict lower
    # triangular matrix.
    # The diagonal of l is set to ones without loss of generality.
    lu, pivots, permutation = lax_linalg.lu(a)
    dtype = lax.dtype(a)
    # Broadcast b (as x) and lu to a common batch shape.
    batch_dims = lax.broadcast_shapes(lu.shape[:-2], b.shape[:-2])
    x = jnp.broadcast_to(b, batch_dims + b.shape[-2:])
    lu = jnp.broadcast_to(lu, batch_dims + lu.shape[-2:])
    # Compute (partial) determinant, ignoring last diagonal of LU
    diag = jnp.diagonal(lu, axis1=-2, axis2=-1)
    # Each pivot differing from its row index is a row swap; an odd number of
    # swaps negates the determinant.
    parity = jnp.count_nonzero(pivots != jnp.arange(a_shape[-1]), axis=-1)
    sign = jnp.asarray(-2 * (parity % 2) + 1, dtype=dtype)
    # partial_det[:, -1] contains the full determinant and
    # partial_det[:, -2] contains det(u) / u_{nn}.
    partial_det = jnp.cumprod(diag, axis=-1) * sign[..., None]
    # Replace u_{nn} with the reciprocal of the signed partial determinant, as
    # described in the docstring, so the back substitution stays finite for
    # rank n-1 inputs.
    lu = lu.at[..., -1, -1].set(1.0 / partial_det[..., -2])
    permutation = jnp.broadcast_to(permutation, batch_dims + (a_shape[-1], ))
    # Batch index arrays used to apply the row permutation per matrix.
    iotas = jnp.ix_(*(lax.iota(jnp.int32, b) for b in batch_dims + (1, )))
    # filter out any matrices that are not full rank
    # Solve u d = 1 with the modified u; any NaN/Inf marks a matrix treated as
    # rank < n-1, whose output is forced to zero.
    d = jnp.ones(x.shape[:-1], x.dtype)
    d = lax_linalg.triangular_solve(lu, d, left_side=True, lower=False)
    d = jnp.any(jnp.logical_or(jnp.isnan(d), jnp.isinf(d)), axis=-1)
    d = jnp.tile(d[..., None, None], d.ndim * (1, ) + x.shape[-2:])
    x = jnp.where(d, jnp.zeros_like(x), x)  # first filter
    # Apply the LU row permutation to x.
    x = x[iotas[:-1] + (permutation, slice(None))]
    # Forward substitution with the unit lower-triangular l stored in lu.
    x = lax_linalg.triangular_solve(lu,
                                    x,
                                    left_side=True,
                                    lower=True,
                                    unit_diagonal=True)
    # Scale every row except the last by the full determinant; the last row's
    # factor comes from the modified u_{nn} in the solve below.
    x = jnp.concatenate(
        (x[..., :-1, :] * partial_det[..., -1, None, None], x[..., -1:, :]),
        axis=-2)
    # Back substitution with the modified upper-triangular u.
    x = lax_linalg.triangular_solve(lu, x, left_side=True, lower=False)
    x = jnp.where(d, jnp.zeros_like(x), x)  # second filter

    return partial_det[..., -1], x
Beispiel #12
0
def _ndtri(p):
    """Implements ndtri core logic.

    Computes Cephes' ``ndtri`` (the inverse of the standard normal CDF) with
    piecewise rational approximations. Values ``p <= 0`` map to ``-inf`` and
    ``p >= 1`` map to ``+inf``.

    Args:
      p: array of probabilities; its dtype sets the dtype of every constant.

    Returns:
      Array with the same shape and dtype as ``p``.
    """

    # Constants used in piece-wise rational approximations. Taken from the cephes
    # library:
    # https://root.cern.ch/doc/v608/SpecFuncCephesInv_8cxx_source.html
    # Each list is written highest-degree first and reversed, so that index 0
    # holds the constant term expected by `_create_polynomial`.
    p0 = list(
        reversed([
            -5.99633501014107895267E1, 9.80010754185999661536E1,
            -5.66762857469070293439E1, 1.39312609387279679503E1,
            -1.23916583867381258016E0
        ]))
    q0 = list(
        reversed([
            1.0, 1.95448858338141759834E0, 4.67627912898881538453E0,
            8.63602421390890590575E1, -2.25462687854119370527E2,
            2.00260212380060660359E2, -8.20372256168333339912E1,
            1.59056225126211695515E1, -1.18331621121330003142E0
        ]))
    p1 = list(
        reversed([
            4.05544892305962419923E0, 3.15251094599893866154E1,
            5.71628192246421288162E1, 4.40805073893200834700E1,
            1.46849561928858024014E1, 2.18663306850790267539E0,
            -1.40256079171354495875E-1, -3.50424626827848203418E-2,
            -8.57456785154685413611E-4
        ]))
    q1 = list(
        reversed([
            1.0, 1.57799883256466749731E1, 4.53907635128879210584E1,
            4.13172038254672030440E1, 1.50425385692907503408E1,
            2.50464946208309415979E0, -1.42182922854787788574E-1,
            -3.80806407691578277194E-2, -9.33259480895457427372E-4
        ]))
    p2 = list(
        reversed([
            3.23774891776946035970E0, 6.91522889068984211695E0,
            3.93881025292474443415E0, 1.33303460815807542389E0,
            2.01485389549179081538E-1, 1.23716634817820021358E-2,
            3.01581553508235416007E-4, 2.65806974686737550832E-6,
            6.23974539184983293730E-9
        ]))
    q2 = list(
        reversed([
            1.0, 6.02427039364742014255E0, 3.67983563856160859403E0,
            1.37702099489081330271E0, 2.16236993594496635890E-1,
            1.34204006088543189037E-2, 3.28014464682127739104E-4,
            2.89247864745380683936E-6, 6.79019408009981274425E-9
        ]))

    dtype = lax.dtype(p).type
    shape = jnp.shape(p)

    def _create_polynomial(var, coeffs):
        """Compute n_th order polynomial via Horner's method."""
        # coeffs[0] is the constant term: c0 + var * (c1 + var * (...)).
        coeffs = np.array(coeffs, dtype)
        if not coeffs.size:
            return jnp.zeros_like(var)
        return coeffs[0] + _create_polynomial(var, coeffs[1:]) * var

    # For p > 1 - exp(-2) (= -expm1(-2)), work with the complement 1 - p.
    maybe_complement_p = jnp.where(p > dtype(-np.expm1(-2.)), dtype(1.) - p, p)
    # Write in an arbitrary value in place of 0 for p since 0 will cause NaNs
    # later on. The result from the computation when p == 0 is not used so any
    # number that doesn't result in NaNs is fine.
    sanitized_mcp = jnp.where(maybe_complement_p <= dtype(0.),
                              jnp.full(shape, dtype(0.5)), maybe_complement_p)

    # Compute x for p > exp(-2): x/sqrt(2pi) = w + w**3 P0(w**2)/Q0(w**2).
    w = sanitized_mcp - dtype(0.5)
    ww = lax.square(w)
    x_for_big_p = w + w * ww * (_create_polynomial(ww, p0) /
                                _create_polynomial(ww, q0))
    x_for_big_p *= -dtype(np.sqrt(2. * np.pi))

    # Compute x for p <= exp(-2): x = z - log(z)/z - (1/z) P(1/z) / Q(1/z),
    # where z = sqrt(-2. * log(p)), and P/Q are chosen between two different
    # arrays based on whether p < exp(-32).
    z = lax.sqrt(dtype(-2.) * lax.log(sanitized_mcp))
    first_term = z - lax.log(z) / z
    second_term_small_p = (_create_polynomial(dtype(1.) / z, p2) /
                           _create_polynomial(dtype(1.) / z, q2) / z)
    second_term_otherwise = (_create_polynomial(dtype(1.) / z, p1) /
                             _create_polynomial(dtype(1.) / z, q1) / z)
    x_for_small_p = first_term - second_term_small_p
    x_otherwise = first_term - second_term_otherwise

    # z >= 8 corresponds to p <= exp(-32), which selects the (p2, q2) pair.
    x = jnp.where(sanitized_mcp > dtype(np.exp(-2.)), x_for_big_p,
                  jnp.where(z >= dtype(8.0), x_for_small_p, x_otherwise))

    # Negate unless p was replaced by its complement above.
    x = jnp.where(p > dtype(1. - np.exp(-2.)), x, -x)
    # Endpoints: p <= 0 -> -inf, p >= 1 -> +inf.
    infinity = jnp.full(shape, dtype(np.inf))
    x_nan_replaced = jnp.where(p <= dtype(0.0), -infinity,
                               jnp.where(p >= dtype(1.0), infinity, x))
    return x_nan_replaced
Beispiel #13
0
def eigh_tridiagonal(d,
                     e,
                     *,
                     eigvals_only=False,
                     select='a',
                     select_range=None,
                     tol=None):
    """Eigenvalues of a symmetric/Hermitian tridiagonal matrix, by bisection.

    Every requested eigenvalue is located by a parallel binary search. The
    searches are driven by Sturm-sequence counts and start from a slightly
    widened Gershgorin interval.

    Args:
      d: 1-D array with the ``n`` diagonal entries.
      e: 1-D array with the ``n - 1`` off-diagonal entries. It must have the
        same dtype as ``d``.
      eigvals_only: must be true; eigenvector computation is not implemented.
      select: ``'a'`` for all eigenvalues, or ``'i'`` for the eigenvalues whose
        indices lie in the inclusive range ``select_range``. ``'v'`` is not
        implemented.
      select_range: ``(lo, hi)`` index range, used when ``select == 'i'``.
      tol: optional absolute tolerance. The tolerance actually used is the
        larger of ``tol`` and ``eps * t_norm``, where ``t_norm`` is a
        Gershgorin upper bound on the matrix 2-norm.

    Returns:
      The selected eigenvalues, in ascending index order.

    Raises:
      NotImplementedError: if ``eigvals_only`` is false or ``select == 'v'``.
      TypeError: if ``d`` and ``e`` differ in dtype, or are not float32,
        float64, complex64 or complex128.
      ValueError: for an empty ``select_range`` or an unknown ``select``.
    """
    if not eigvals_only:
        raise NotImplementedError(
            "Calculation of eigenvectors is not implemented")

    def _sturm(alpha, beta_sq, pivmin, alpha0_perturbation, x):
        """Sturm-sequence count at each shift in ``x``.

        The result counts the non-positive terms of the recurrence. The
        bisection below uses it as the number of eigenvalues below ``x``.
        """
        n = alpha.shape[0]
        zeros = jnp.zeros(x.shape, dtype=jnp.int32)
        ones = jnp.ones(x.shape, dtype=jnp.int32)

        # The first step in the Sturm sequence recurrence
        # requires special care if x is equal to alpha[0].
        def sturm_step0():
            q = alpha[0] - x
            count = jnp.where(q < 0, ones, zeros)
            q = jnp.where(alpha[0] == x, alpha0_perturbation, q)
            return q, count

        # Subsequent steps all take this form. Pivots are clamped to at most
        # -pivmin to avoid overflow in the next division.
        def sturm_step(i, q, count):
            q = alpha[i] - beta_sq[i - 1] / q - x
            count = jnp.where(q <= pivmin, count + 1, count)
            q = jnp.where(q <= pivmin, jnp.minimum(q, -pivmin), q)
            return q, count

        def make_unrolled_steps(num_steps):
            """Return a loop body that applies `num_steps` Sturm steps."""
            # The step count is bound explicitly. The old code closed over a
            # variable that was rebound between the two uses.
            def unrolled_steps(args):
                start, q, count = args
                for j in range(num_steps):
                    q, count = sturm_step(start + j, q, count)
                return start + num_steps, q, count

            return unrolled_steps

        # The first step initializes q and count.
        q, count = sturm_step0()

        # Peel off ((n-1) % blocksize) steps from the main loop, so we can run
        # the bulk of the iterations unrolled by a factor of blocksize.
        blocksize = 16
        peel = (n - 1) % blocksize
        i, q, count = make_unrolled_steps(peel)((1, q, count))

        # Run the remaining steps of the Sturm sequence using a partially
        # unrolled while loop.
        def cond(iqc):
            i, _, _ = iqc
            return jnp.less(i, n)

        _, _, count = lax.while_loop(cond, make_unrolled_steps(blocksize),
                                     (i, q, count))
        return count

    alpha = jnp.asarray(d)
    beta = jnp.asarray(e)
    supported_dtypes = (jnp.float32, jnp.float64, jnp.complex64,
                        jnp.complex128)
    if alpha.dtype != beta.dtype:
        raise TypeError(
            "diagonal and off-diagonal values must have same dtype, "
            f"got {alpha.dtype} and {beta.dtype}")
    if alpha.dtype not in supported_dtypes or beta.dtype not in supported_dtypes:
        # The message lists every dtype accepted by `supported_dtypes`.
        raise TypeError(
            "Only float32, float64, complex64 and complex128 inputs are "
            "supported as inputs to jax.scipy.linalg.eigh_tridiagonal, got "
            f"{alpha.dtype} and {beta.dtype}")
    n = alpha.shape[0]
    if n <= 1:
        return jnp.real(alpha)

    if jnp.issubdtype(alpha.dtype, jnp.complexfloating):
        # Hermitian case: only real(alpha) and |beta|^2 matter.
        alpha = jnp.real(alpha)
        beta_sq = jnp.real(beta * jnp.conj(beta))
        beta_abs = jnp.sqrt(beta_sq)
    else:
        beta_abs = jnp.abs(beta)
        beta_sq = jnp.square(beta)

    # Estimate the largest and smallest eigenvalues of T using the Gershgorin
    # circle theorem.
    off_diag_abs_row_sum = jnp.concatenate(
        [beta_abs[:1], beta_abs[:-1] + beta_abs[1:], beta_abs[-1:]], axis=0)
    lambda_est_max = jnp.amax(alpha + off_diag_abs_row_sum)
    lambda_est_min = jnp.amin(alpha - off_diag_abs_row_sum)
    # Upper bound on 2-norm of T.
    t_norm = jnp.maximum(jnp.abs(lambda_est_min), jnp.abs(lambda_est_max))

    # Compute the smallest allowed pivot in the Sturm sequence to avoid
    # overflow.
    finfo = np.finfo(alpha.dtype)
    one = np.ones([], dtype=alpha.dtype)
    safemin = np.maximum(one / finfo.max, (one + finfo.eps) * finfo.tiny)
    pivmin = safemin * jnp.maximum(1, jnp.amax(beta_sq))
    alpha0_perturbation = jnp.square(finfo.eps * beta_abs[0])
    abs_tol = finfo.eps * t_norm
    if tol is not None:
        abs_tol = jnp.maximum(tol, abs_tol)

    # In the worst case, when the absolute tolerance is eps*lambda_est_max and
    # lambda_est_max = -lambda_est_min, we have to take as many bisection steps
    # as there are bits in the mantissa plus 1.
    max_it = finfo.nmant + 1

    # Determine the indices of the desired eigenvalues, based on select and
    # select_range.
    if select == 'a':
        target_counts = jnp.arange(n, dtype=jnp.int32)
    elif select == 'i':
        if select_range[0] > select_range[1]:
            raise ValueError('Got empty index range in select_range.')
        target_counts = jnp.arange(select_range[0],
                                   select_range[1] + 1,
                                   dtype=jnp.int32)
    elif select == 'v':
        # TODO(phawkins): requires dynamic shape support.
        raise NotImplementedError("eigh_tridiagonal(..., select='v') is not "
                                  "implemented")
    else:
        raise ValueError("'select' must have a value in {'a', 'i', 'v'}.")

    # Run binary search for all desired eigenvalues in parallel, starting from
    # an interval slightly wider than the estimated
    # [lambda_est_min, lambda_est_max].
    fudge = 2.1  # We widen the starting Gershgorin interval a bit.
    norm_slack = jnp.array(n, alpha.dtype) * fudge * finfo.eps * t_norm
    lower = lambda_est_min - norm_slack - 2 * fudge * pivmin
    upper = lambda_est_max + norm_slack + fudge * pivmin

    # Pre-broadcast the scalars used in the Sturm sequence for improved
    # performance.
    target_shape = jnp.shape(target_counts)
    lower = jnp.broadcast_to(lower, shape=target_shape)
    upper = jnp.broadcast_to(upper, shape=target_shape)
    mid = 0.5 * (upper + lower)
    pivmin = jnp.broadcast_to(pivmin, target_shape)
    alpha0_perturbation = jnp.broadcast_to(alpha0_perturbation, target_shape)

    # Start parallel binary searches.
    def cond(args):
        i, lower, _, upper = args
        return jnp.logical_and(jnp.less(i, max_it),
                               jnp.less(abs_tol, jnp.amax(upper - lower)))

    def body(args):
        i, lower, mid, upper = args
        counts = _sturm(alpha, beta_sq, pivmin, alpha0_perturbation, mid)
        # At most `target` eigenvalues lie below `mid`, so the target
        # eigenvalue is >= mid; otherwise it is < mid.
        lower = jnp.where(counts <= target_counts, mid, lower)
        upper = jnp.where(counts > target_counts, mid, upper)
        mid = 0.5 * (lower + upper)
        return i + 1, lower, mid, upper

    _, _, mid, _ = lax.while_loop(cond, body, (0, lower, mid, upper))
    return mid