Example #1
def _asarray(arr):
  """
  Pared-down utility to convert object to a DeviceArray.
  Note this will not correctly handle lists or tuples.
  """
  _check_arraylike("_asarray", arr)
  dtype, weak_type = dtypes._lattice_result_type(arr)
  return lax_internal._convert_element_type(arr, dtype, weak_type)
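For context, here is a minimal sketch of how the two private helpers compose on a Python scalar. The `jax._src` import paths are an assumption and vary between JAX versions; `_lattice_result_type` and `_convert_element_type` are internal APIs.

from jax._src import dtypes
from jax._src.lax import lax as lax_internal

# A Python float is "weakly typed": it picks up the default float dtype
# but does not force promotion of other operands.
dtype, weak_type = dtypes._lattice_result_type(3.0)
arr = lax_internal._convert_element_type(3.0, dtype, weak_type)
print(arr.dtype, dtypes.is_weakly_typed(arr))  # float32 True (default 32-bit mode)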
Example #2
 def testDeviceArrayRepr(self, dtype, weak_type):
     val = lax_internal._convert_element_type(0, dtype, weak_type=weak_type)
     rep = repr(val)
     self.assertStartsWith(rep, 'DeviceArray(')
     if weak_type:
         self.assertEndsWith(rep,
                             f"dtype={val.dtype.name}, weak_type=True)")
     else:
         self.assertEndsWith(rep, f"dtype={val.dtype.name})")
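The reprs being asserted look roughly like the following, on JAX versions where on-device values still print as `DeviceArray` (newer releases print `Array` instead). This is only an illustration of the test above, using the same private APIs.

import numpy as np
from jax._src import dtypes
from jax._src.lax import lax as lax_internal

weak = lax_internal._convert_element_type(0, np.dtype('float32'), weak_type=True)
strong = lax_internal._convert_element_type(0, np.dtype('float32'), weak_type=False)
print(repr(weak))    # e.g. DeviceArray(0., dtype=float32, weak_type=True)
print(repr(strong))  # e.g. DeviceArray(0., dtype=float32)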
Example #3
def _promote_dtypes_complex(*args):
  """Convenience function to apply Numpy argument dtype promotion.

  Promotes arguments to a complex type."""
  to_dtype, weak_type = dtypes._lattice_result_type(*args)
  to_dtype = dtypes.canonicalize_dtype(to_dtype)
  to_dtype_complex = dtypes._to_complex_dtype(to_dtype)
  return [lax_internal._convert_element_type(x, to_dtype_complex, weak_type)
          for x in args]
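A rough public-API approximation of what `_promote_dtypes_complex` computes, ignoring the weak_type bookkeeping; `jnp.promote_types` stands in here for the private `_to_complex_dtype` helper.

import numpy as np
import jax.numpy as jnp

args = (np.float32(1.0), np.int32(2))
to_dtype = jnp.result_type(*args)                       # float32
to_complex = jnp.promote_types(to_dtype, np.complex64)  # complex64
promoted = [jnp.asarray(a, dtype=to_complex) for a in args]
print([a.dtype for a in promoted])                      # [complex64, complex64]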
Example #4
def _promote_dtypes(*args):
  """Convenience function to apply Numpy argument dtype promotion."""
  # TODO(dougalm,mattjj): This is a performance bottleneck. Consider memoizing.
  if len(args) < 2:
    return args
  else:
    to_dtype, weak_type = dtypes._lattice_result_type(*args)
    to_dtype = dtypes.canonicalize_dtype(to_dtype)
    return [lax_internal._convert_element_type(x, to_dtype, weak_type) for x in args]
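The observable behavior is the promotion exposed publicly through `jnp.result_type`: every argument is cast to one common dtype. A small sketch:

import jax.numpy as jnp

x = jnp.arange(3, dtype=jnp.int32)
y = jnp.float32(0.5)
print(jnp.result_type(x, y))   # float32: the common dtype both args promote to
print((x + y).dtype)           # float32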
Example #5
 def testUnaryPromotion(self, dtype, weak_type):
     # Regression test for https://github.com/google/jax/issues/6051
     x = lax_internal._convert_element_type(0, dtype, weak_type=weak_type)
     if weak_type:
         expected = dtypes.canonicalize_dtype(
             dtypes._default_types['f' if x.dtype ==
                                   'bfloat16' else x.dtype.kind])
     else:
         expected = x.dtype
     self.assertEqual(dtypes.result_type(x), expected)
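The behavior under test: a weakly typed value's `result_type` is the default dtype for its kind, while a strongly typed value keeps its dtype. A sketch using the same private APIs as the test (import paths assumed as in the other examples):

import numpy as np
from jax._src import dtypes
from jax._src.lax import lax as lax_internal

weak = lax_internal._convert_element_type(0, np.dtype('float16'), weak_type=True)
strong = lax_internal._convert_element_type(0, np.dtype('float16'), weak_type=False)
print(dtypes.result_type(weak))    # float32: default float in 32-bit mode
print(dtypes.result_type(strong))  # float16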
Example #6
def _promote_dtypes_inexact(*args):
    """Convenience function to apply Numpy argument dtype promotion.

    Promotes arguments to an inexact type."""
    to_dtype, weak_type = dtypes._lattice_result_type(*args)
    to_dtype = dtypes.canonicalize_dtype(to_dtype)
    to_dtype_inexact = _to_inexact_dtype(to_dtype)
    weak_type = (weak_type and to_dtype == to_dtype_inexact)
    return [
        lax_internal._convert_element_type(x, to_dtype_inexact, weak_type)
        for x in args
    ]
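Functions that require inexact inputs rely on this kind of promotion: integer arguments are converted to the default floating-point type. A public-API illustration:

import jax.numpy as jnp

x = jnp.arange(3)          # int32
print(jnp.sin(x).dtype)    # float32 (float64 when x64 mode is enabled)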
Example #7
def _bicgstab_solve(A,
                    b,
                    x0=None,
                    *,
                    maxiter,
                    tol=1e-5,
                    atol=0.0,
                    M=_identity):

    # tolerance handling uses the "non-legacy" behavior of scipy.sparse.linalg.bicgstab
    bs = _vdot_real_tree(b, b)
    atol2 = jnp.maximum(jnp.square(tol) * bs, jnp.square(atol))

    # https://en.wikipedia.org/wiki/Biconjugate_gradient_stabilized_method#Preconditioned_BiCGSTAB

    def cond_fun(value):
        x, r, *_, k = value
        rs = _vdot_real_tree(r, r)
        # the last condition checks breakdown
        return (rs > atol2) & (k < maxiter) & (k >= 0)

    def body_fun(value):
        x, r, rhat, alpha, omega, rho, p, q, k = value
        rho_ = _vdot_tree(rhat, r)
        beta = rho_ / rho * alpha / omega
        p_ = _add(r, _mul(beta, _sub(p, _mul(omega, q))))
        phat = M(p_)
        q_ = A(phat)
        alpha_ = rho_ / _vdot_tree(rhat, q_)
        s = _sub(r, _mul(alpha_, q_))
        exit_early = _vdot_real_tree(s, s) < atol2
        shat = M(s)
        t = A(shat)
        omega_ = _vdot_tree(t, s) / _vdot_tree(t, t)  # make cases?
        x_ = tree_map(partial(jnp.where, exit_early),
                      _add(x, _mul(alpha_, phat)),
                      _add(x, _add(_mul(alpha_, phat), _mul(omega_, shat))))
        r_ = tree_map(partial(jnp.where, exit_early), s,
                      _sub(s, _mul(omega_, t)))
        k_ = jnp.where((omega_ == 0) | (alpha_ == 0), -11, k + 1)
        k_ = jnp.where((rho_ == 0), -10, k_)
        return x_, r_, rhat, alpha_, omega_, rho_, p_, q_, k_

    r0 = _sub(b, A(x0))
    rho0 = alpha0 = omega0 = lax_internal._convert_element_type(
        1, *dtypes._lattice_result_type(*tree_leaves(b)))
    initial_value = (x0, r0, r0, alpha0, omega0, rho0, r0, r0, 0)

    x_final, *_ = lax.while_loop(cond_fun, body_fun, initial_value)

    return x_final
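`_bicgstab_solve` is the iteration behind the public `jax.scipy.sparse.linalg.bicgstab` wrapper, which supplies the `x0` and `maxiter` defaults. A minimal usage sketch of the public entry point; the small SPD system is just an illustration.

import jax.numpy as jnp
from jax.scipy.sparse.linalg import bicgstab

A = jnp.array([[4.0, 1.0],
               [1.0, 3.0]])
b = jnp.array([1.0, 2.0])
x, info = bicgstab(lambda v: A @ v, b, tol=1e-6)
print(jnp.allclose(A @ x, b, atol=1e-4))   # True once the solver has converged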
Example #8
def _promote_arg_dtypes(*args):
    """Promotes `args` to a common inexact type."""
    dtype, weak_type = dtypes._lattice_result_type(*args)
    if not jnp.issubdtype(dtype, jnp.inexact):
        dtype, weak_type = jnp.float_, False
    dtype = dtypes.canonicalize_dtype(dtype)
    args = [
        lax_internal._convert_element_type(arg, dtype, weak_type)
        for arg in args
    ]
    if len(args) == 1:
        return args[0]
    else:
        return args
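A rough public-API sketch of the same rule: non-inexact inputs fall back to the default float, then everything is cast to one canonical dtype. The weak_type handling of the original is omitted, and the helper name below is made up for illustration.

import jax.numpy as jnp
from jax import dtypes as jax_dtypes

def promote_arg_dtypes_sketch(*args):
    # Hypothetical approximation of _promote_arg_dtypes using public APIs only.
    dtype = jnp.result_type(*args)
    if not jnp.issubdtype(dtype, jnp.inexact):
        dtype = jax_dtypes.canonicalize_dtype(jnp.float64)  # default float
    args = [jnp.asarray(a, dtype=dtype) for a in args]
    return args[0] if len(args) == 1 else args

print(promote_arg_dtypes_sketch(jnp.arange(3)).dtype)  # float32 by default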
Example #9
def _scatter_impl(x, y, scatter_op, treedef, static_idx, dynamic_idx,
                  indices_are_sorted, unique_indices, mode, normalize_indices):
    dtype = lax.dtype(x)
    weak_type = dtypes.is_weakly_typed(x)

    if dtype != dtypes.result_type(x, y):
        # TODO(jakevdp): change this to an error after the deprecation period.
        warnings.warn(
            "scatter inputs have incompatible types: cannot safely cast "
            f"value from dtype={lax.dtype(y)} to dtype={lax.dtype(x)}. "
            "In future JAX releases this will result in an error.",
            FutureWarning)

    idx = jnp._merge_static_and_dynamic_indices(treedef, static_idx,
                                                dynamic_idx)
    indexer = jnp._index_to_gather(jnp.shape(x),
                                   idx,
                                   normalize_indices=normalize_indices)

    # Avoid calling scatter if the slice shape is empty, both as a fast path and
    # to handle cases like zeros(0)[array([], int32)].
    if core.is_empty_shape(indexer.slice_shape):
        return x

    x, y = jnp._promote_dtypes(x, y)

    # Broadcast `y` to the slice output shape.
    y = jnp.broadcast_to(y, tuple(indexer.slice_shape))
    # Collapse any `None`/`jnp.newaxis` dimensions.
    y = jnp.squeeze(y, axis=indexer.newaxis_dims)
    if indexer.reversed_y_dims:
        y = lax.rev(y, indexer.reversed_y_dims)

    # Transpose the gather dimensions into scatter dimensions (cf.
    # lax._gather_transpose_rule)
    dnums = lax.ScatterDimensionNumbers(
        update_window_dims=indexer.dnums.offset_dims,
        inserted_window_dims=indexer.dnums.collapsed_slice_dims,
        scatter_dims_to_operand_dims=indexer.dnums.start_index_map)
    out = scatter_op(x,
                     indexer.gather_indices,
                     y,
                     dnums,
                     indices_are_sorted=indexer.indices_are_sorted
                     or indices_are_sorted,
                     unique_indices=indexer.unique_indices or unique_indices,
                     mode=mode)
    return lax_internal._convert_element_type(out, dtype, weak_type)
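The final `_convert_element_type(out, dtype, weak_type)` restores the dtype and weak type of the original operand after the internal promotion, which is why indexed updates keep the dtype of the array being updated. A public-API illustration:

import jax.numpy as jnp

x = jnp.zeros(3, dtype=jnp.float16)
out = x.at[0].set(1.0)     # 1.0 is weakly typed, so x is not promoted
print(out.dtype)           # float16: the scatter result keeps x's original dtype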
Example #10
def _gmres_batched(A, b, x0, unit_residual, residual_norm, ptol, restart, M):
    """
    Implements a single restart of GMRES. The ``restart``-dimensional Krylov
    subspace K(A, x0) = span(A(x0), A@x0, A@A@x0, ..., A^restart @ x0) is
    built, and the projection of the true solution into this subspace is
    returned.

    This implementation solves a dense linear problem instead of building
    a QR factorization during the Arnoldi process.
    """
    del ptol  # unused
    # https://www-users.cs.umn.edu/~saad/Calais/PREC.pdf
    V = tree_map(
        lambda x: jnp.pad(x[..., None], ((0, 0), ) * x.ndim +
                          ((0, restart), )),
        unit_residual,
    )
    dtype, weak_type = dtypes._lattice_result_type(*tree_leaves(b))
    H = lax_internal._convert_element_type(jnp.eye(restart,
                                                   restart + 1,
                                                   dtype=dtype),
                                           weak_type=weak_type)

    def loop_cond(carry):
        _, _, breakdown, k = carry
        return jnp.logical_and(k < restart, jnp.logical_not(breakdown))

    def arnoldi_process(carry):
        V, H, _, k = carry
        V, H, breakdown = _kth_arnoldi_iteration(k, A, M, V, H)
        return V, H, breakdown, k + 1

    carry = (V, H, False, 0)
    V, H, _, _ = lax.while_loop(loop_cond, arnoldi_process, carry)

    beta_vec = jnp.zeros_like(H,
                              shape=(restart + 1, )).at[0].set(residual_norm)
    y = _lstsq(H.T, beta_vec)
    dx = tree_map(lambda X: _dot(X[..., :-1], y), V)

    x = _add(x0, dx)

    residual = M(_sub(b, A(x)))
    unit_residual, residual_norm = _safe_normalize(residual)
    return x, unit_residual, residual_norm
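This is the batched solve path of the public `jax.scipy.sparse.linalg.gmres`; the wrapper drives repeated restarts and supplies defaults. A minimal usage sketch of the public API:

import jax.numpy as jnp
from jax.scipy.sparse.linalg import gmres

A = jnp.array([[3.0, 1.0],
               [1.0, 2.0]])
b = jnp.array([5.0, 5.0])
x, info = gmres(lambda v: A @ v, b, restart=2, solve_method='batched')
print(jnp.allclose(A @ x, b, atol=1e-4))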
Example #11
def _safe_normalize(x, thresh=None):
  """
  Returns the L2-normalized vector x (which can be a pytree) along with its
  computed norm. If the computed norm is less than the threshold `thresh`,
  which by default is the machine epsilon of x's dtype, the norm is taken
  to be 0 and the normalized x is the zero vector.
  """
  norm = _norm(x)
  dtype, weak_type = dtypes._lattice_result_type(*tree_leaves(x))
  dtype = dtypes.canonicalize_dtype(dtype)
  if thresh is None:
    thresh = jnp.finfo(norm.dtype).eps
  thresh = thresh.astype(dtype).real

  use_norm = norm > thresh

  norm_cast = lax_internal._convert_element_type(norm, dtype, weak_type)
  normalized_x = tree_map(lambda y: jnp.where(use_norm, y / norm_cast, 0.0), x)
  norm = jnp.where(use_norm, norm, 0.0)
  return normalized_x, norm
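A hedged sketch of the same idea for plain arrays, using public APIs only (no pytrees, no weak_type handling); `safe_normalize_sketch` is a made-up name.

import jax.numpy as jnp

def safe_normalize_sketch(x, thresh=None):
    # Normalize x by its L2 norm; treat norms below thresh as zero.
    norm = jnp.linalg.norm(x)
    if thresh is None:
        thresh = jnp.finfo(x.dtype).eps
    use_norm = norm > thresh
    safe_norm = jnp.where(use_norm, norm, 1.0)       # avoid 0/0 under the mask
    normalized = jnp.where(use_norm, x / safe_norm, 0.0)
    return normalized, jnp.where(use_norm, norm, 0.0)

print(safe_normalize_sketch(jnp.array([3.0, 4.0]))[1])  # 5.0
print(safe_normalize_sketch(jnp.zeros(2))[0])           # [0. 0.]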
Example #12
    def testBinaryNonPromotion(self, dtype, weak_type, promotion):
        # Regression test for https://github.com/google/jax/issues/6051
        x = lax_internal._convert_element_type(0, dtype, weak_type=weak_type)
        with jax.numpy_dtype_promotion(promotion):
            y = (x + x)

        if promotion == 'standard' or not weak_type or dtype == dtypes.bool_:
            expected_dtype = dtype
        elif dtypes.issubdtype(dtype, np.complexfloating):
            expected_dtype = dtypes.complex_
        elif dtypes.issubdtype(dtype, np.floating):
            expected_dtype = dtypes.float_
        else:
            expected_dtype = dtypes.int_

        # No boolean weak types.
        expected_weak_type = weak_type and dtype != bool
        expected_dtype = dtypes.canonicalize_dtype(expected_dtype)

        self.assertEqual(y.dtype, expected_dtype)
        self.assertEqual(dtypes.is_weakly_typed(y), expected_weak_type)
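A sketch of the two promotion modes the test is parameterized over, assuming the same private construction as above; `jax.numpy_dtype_promotion` accepts 'standard' and 'strict'.

import jax
import numpy as np
from jax._src.lax import lax as lax_internal

x = lax_internal._convert_element_type(0, np.dtype('float16'), weak_type=True)
with jax.numpy_dtype_promotion('standard'):
    print((x + x).dtype)   # float16: standard mode keeps the weak value's dtype
with jax.numpy_dtype_promotion('strict'):
    print((x + x).dtype)   # float32: per the test, non-standard mode falls back
                           # to the default float for weak values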
Example #13
def _scatter_impl(x, y, scatter_op, treedef, static_idx, dynamic_idx,
                  indices_are_sorted, unique_indices, mode, normalize_indices):
    dtype = lax.dtype(x)
    weak_type = dtypes.is_weakly_typed(x)

    idx = jnp._merge_static_and_dynamic_indices(treedef, static_idx,
                                                dynamic_idx)
    indexer = jnp._index_to_gather(jnp.shape(x),
                                   idx,
                                   normalize_indices=normalize_indices)

    # Avoid calling scatter if the slice shape is empty, both as a fast path and
    # to handle cases like zeros(0)[array([], int32)].
    if core.is_empty_shape(indexer.slice_shape):
        return x

    x, y = jnp._promote_dtypes(x, y)

    # Broadcast `y` to the slice output shape.
    y = jnp.broadcast_to(y, tuple(indexer.slice_shape))
    # Collapse any `None`/`jnp.newaxis` dimensions.
    y = jnp.squeeze(y, axis=indexer.newaxis_dims)
    if indexer.reversed_y_dims:
        y = lax.rev(y, indexer.reversed_y_dims)

    # Transpose the gather dimensions into scatter dimensions (cf.
    # lax._gather_transpose_rule)
    dnums = lax.ScatterDimensionNumbers(
        update_window_dims=indexer.dnums.offset_dims,
        inserted_window_dims=indexer.dnums.collapsed_slice_dims,
        scatter_dims_to_operand_dims=indexer.dnums.start_index_map)
    out = scatter_op(x,
                     indexer.gather_indices,
                     y,
                     dnums,
                     indices_are_sorted=indexer.indices_are_sorted
                     or indices_are_sorted,
                     unique_indices=indexer.unique_indices or unique_indices,
                     mode=mode)
    return lax_internal._convert_element_type(out, dtype, weak_type)
Example #14
 def testBinaryNonPromotion(self, dtype, weak_type):
     # Regression test for https://github.com/google/jax/issues/6051
     x = lax_internal._convert_element_type(0, dtype, weak_type=weak_type)
     y = (x + x)
     assert x.dtype == y.dtype
     assert dtypes.is_weakly_typed(y) == dtypes.is_weakly_typed(x)