Exemple #1
0
def _asarray(arr):
  """Convert a single array-like ``arr`` into a DeviceArray.

  The target dtype and weak-type flag come from the dtype-promotion lattice
  applied to ``arr`` alone, so no promotion against other values happens.
  This is a pared-down utility: it will not correctly handle lists or tuples.
  """
  _check_arraylike("_asarray", arr)
  # _lattice_result_type yields a (dtype, weak_type) pair, which maps directly
  # onto the trailing positional parameters of the conversion.
  return lax_internal._convert_element_type(
      arr, *dtypes._lattice_result_type(arr))
Exemple #2
0
def _promote_dtypes_complex(*args):
  """Apply NumPy-style dtype promotion, forcing a complex result.

  The common dtype of ``args`` is computed on the promotion lattice,
  canonicalized, and mapped to its complex counterpart. Every argument is then
  converted to that dtype, keeping the lattice's weak-type flag unchanged.

  Returns:
    A list with one converted value per argument.
  """
  lattice_dtype, is_weak = dtypes._lattice_result_type(*args)
  complex_dtype = dtypes._to_complex_dtype(
      dtypes.canonicalize_dtype(lattice_dtype))
  return [
      lax_internal._convert_element_type(arg, complex_dtype, is_weak)
      for arg in args
  ]
Exemple #3
0
def _promote_dtypes(*args):
  """Apply NumPy-style dtype promotion to ``args``.

  With two or more arguments, each one is converted to the canonicalized
  common dtype from the promotion lattice (with the lattice's weak-type flag),
  and a list is returned. With fewer than two arguments there is nothing to
  promote against, so ``args`` is returned unchanged (as the original tuple).
  """
  # TODO(dougalm,mattjj): This is a performance bottleneck. Consider memoizing.
  if len(args) >= 2:
    common_dtype, is_weak = dtypes._lattice_result_type(*args)
    common_dtype = dtypes.canonicalize_dtype(common_dtype)
    return [lax_internal._convert_element_type(arg, common_dtype, is_weak)
            for arg in args]
  return args
Exemple #4
0
def _promote_arg_dtypes(*args):
    """Promote ``args`` to a common inexact dtype.

    The lattice result type of ``args`` is used when it is already inexact
    (floating or complex). Otherwise, the arguments are promoted to
    ``jnp.float_`` as a strong (non-weak) type. The chosen dtype is
    canonicalized before conversion.

    Returns:
      The single converted value when exactly one argument is given, else a
      list of converted values.
    """
    target, weak = dtypes._lattice_result_type(*args)
    if not jnp.issubdtype(target, jnp.inexact):
        # Non-inexact results fall back to the default float type, and the
        # promoted values are no longer weakly typed.
        target, weak = jnp.float_, False
    target = dtypes.canonicalize_dtype(target)
    promoted = [lax._convert_element_type(arg, target, weak) for arg in args]
    return promoted[0] if len(promoted) == 1 else promoted
Exemple #5
0
def _promote_dtypes_inexact(*args):
    """Apply NumPy-style dtype promotion, forcing an inexact result.

    The common dtype of ``args`` comes from the promotion lattice. It is
    canonicalized and then mapped through ``_to_inexact_dtype``. The result
    stays weakly typed only if the lattice result was weak *and* that mapping
    left the canonical dtype unchanged.

    Returns:
      A list with one converted value per argument.
    """
    lattice_dtype, lattice_weak = dtypes._lattice_result_type(*args)
    canonical = dtypes.canonicalize_dtype(lattice_dtype)
    target = _to_inexact_dtype(canonical)
    keep_weak = lattice_weak and canonical == target
    converted = [
        lax_internal._convert_element_type(arg, target, keep_weak)
        for arg in args
    ]
    return converted
Exemple #6
0
def _bicgstab_solve(A,
                    b,
                    x0=None,
                    *,
                    maxiter,
                    tol=1e-5,
                    atol=0.0,
                    M=_identity):
    """Solve ``A(x) = b`` with preconditioned BiCGSTAB inside ``lax.while_loop``.

    Args:
      A: callable applying the linear operator to a (pytree) vector.
      b: right-hand side (pytree).
      x0: initial guess. NOTE(review): despite the ``None`` default, ``A(x0)``
        is called directly below, so callers must supply a concrete guess.
      maxiter: maximum number of iterations.
      tol: relative tolerance. It is squared and scaled by
        ``_vdot_real_tree(b, b)`` (presumably the real inner product <b, b>).
      atol: absolute tolerance. Iteration stops once the squared residual
        measure is at most ``max(tol**2 * <b, b>, atol**2)``.
      M: preconditioner callable (defaults to ``_identity``).

    Returns:
      The final iterate ``x``. The iteration counter and breakdown code are
      discarded, so breakdown is not reported to the caller.

    NOTE(review): ``tree_multimap`` was removed from newer JAX releases in
    favor of ``tree_map``; confirm against the pinned JAX version.
    """

    # tolerance handling uses the "non-legacy" behavior of scipy.sparse.linalg.bicgstab
    # atol2 is a *squared* tolerance, compared against squared residual norms.
    bs = _vdot_real_tree(b, b)
    atol2 = jnp.maximum(jnp.square(tol) * bs, jnp.square(atol))

    # https://en.wikipedia.org/wiki/Biconjugate_gradient_stabilized_method#Preconditioned_BiCGSTAB

    def cond_fun(value):
        # Continue while the squared residual exceeds atol2, the iteration
        # budget remains, and no (negative) breakdown code has been set.
        x, r, *_, k = value
        rs = _vdot_real_tree(r, r)
        # the last condition checks breakdown
        return (rs > atol2) & (k < maxiter) & (k >= 0)

    def body_fun(value):
        # Carry layout: (x, r, rhat, alpha, omega, rho, p, q, k). rhat is the
        # shadow residual and is passed through unchanged; q holds A(M(p)).
        x, r, rhat, alpha, omega, rho, p, q, k = value
        rho_ = _vdot_tree(rhat, r)
        beta = rho_ / rho * alpha / omega
        p_ = _add(r, _mul(beta, _sub(p, _mul(omega, q))))
        phat = M(p_)
        q_ = A(phat)
        alpha_ = rho_ / _vdot_tree(rhat, q_)
        s = _sub(r, _mul(alpha_, q_))
        # If the half-step residual s already meets the tolerance, keep only
        # the alpha step: x + alpha*phat, and r = s. jnp.where selects
        # elementwise, so the full step below is still computed either way.
        exit_early = _vdot_real_tree(s, s) < atol2
        shat = M(s)
        t = A(shat)
        omega_ = _vdot_tree(t, s) / _vdot_tree(t, t)  # make cases?
        x_ = tree_multimap(
            partial(jnp.where, exit_early), _add(x, _mul(alpha_, phat)),
            _add(x, _add(_mul(alpha_, phat), _mul(omega_, shat))))
        r_ = tree_multimap(partial(jnp.where, exit_early), s,
                           _sub(s, _mul(omega_, t)))
        # Breakdown codes: -11 when omega_ or alpha_ is zero, and -10 when
        # rho_ is zero (that check runs second, so it takes precedence). Any
        # negative k makes cond_fun stop the loop.
        k_ = jnp.where((omega_ == 0) | (alpha_ == 0), -11, k + 1)
        k_ = jnp.where((rho_ == 0), -10, k_)
        return x_, r_, rhat, alpha_, omega_, rho_, p_, q_, k_

    r0 = _sub(b, A(x0))
    # Seed rho, alpha and omega with 1 in the dtype implied by b's leaves.
    # With p = q = r0 and omega = 1, the term p - omega*q is zero, so the
    # first iteration's search direction p_ equals r. rhat is fixed to r0.
    rho0 = alpha0 = omega0 = lax._convert_element_type(
        1, *dtypes._lattice_result_type(*tree_leaves(b)))
    initial_value = (x0, r0, r0, alpha0, omega0, rho0, r0, r0, 0)

    x_final, *_ = lax.while_loop(cond_fun, body_fun, initial_value)

    return x_final
Exemple #7
0
def _gmres_batched(A, b, x0, unit_residual, residual_norm, ptol, restart, M):
    """Run a single restart cycle of GMRES.

    An Arnoldi basis ``V`` is grown from ``unit_residual`` for up to
    ``restart`` steps via ``_kth_arnoldi_iteration``, which receives both
    ``A`` and ``M``. The loop stops early if that helper reports breakdown.

    This implementation does not update a QR factorization during the Arnoldi
    process. Instead, the correction is obtained afterwards by solving a small
    dense least-squares problem with ``_lstsq``.

    Args:
      A: callable applying the linear operator.
      b: right-hand side (pytree).
      x0: current iterate (pytree).
      unit_residual: normalized residual; becomes the first Arnoldi basis
        vector.
      residual_norm: norm of that residual; scales ``e_1`` in the
        least-squares right-hand side.
      ptol: unused here (deleted immediately). Presumably kept so the
        signature matches other restart variants — confirm with callers.
      restart: number of Arnoldi steps (Krylov dimension) per cycle.
      M: preconditioner callable.

    Returns:
      ``(x, unit_residual, residual_norm)``: the updated iterate, followed by
      the normalized preconditioned residual ``M(b - A(x))`` and its norm, as
      returned by ``_safe_normalize``.
    """
    del ptol  # unused
    # https://www-users.cs.umn.edu/~saad/Calais/PREC.pdf
    # Each leaf of unit_residual gets a trailing axis of length restart + 1:
    # column 0 holds unit_residual, and the remaining columns start as zeros.
    V = tree_map(
        lambda x: jnp.pad(x[..., None], ((0, 0), ) * x.ndim +
                          ((0, restart), )),
        unit_residual,
    )
    dtype, weak_type = dtypes._lattice_result_type(*tree_leaves(b))
    # H is allocated as (restart, restart + 1), the transpose of the usual
    # Hessenberg layout; hence H.T in the least-squares solve below.
    # NOTE(review): starting from eye rather than zeros presumably keeps the
    # system well-posed if breakdown stops the loop early — confirm.
    H = lax_internal._convert_element_type(jnp.eye(restart,
                                                   restart + 1,
                                                   dtype=dtype),
                                           weak_type=weak_type)

    def loop_cond(carry):
        # Stop after `restart` steps or as soon as breakdown is flagged.
        _, _, breakdown, k = carry
        return jnp.logical_and(k < restart, jnp.logical_not(breakdown))

    def arnoldi_process(carry):
        V, H, _, k = carry
        V, H, breakdown = _kth_arnoldi_iteration(k, A, M, V, H)
        return V, H, breakdown, k + 1

    # carry layout: (V, H, breakdown, k)
    carry = (V, H, False, 0)
    V, H, _, _ = lax.while_loop(loop_cond, arnoldi_process, carry)

    # Right-hand side residual_norm * e_1, in H's dtype.
    beta_vec = jnp.zeros_like(H,
                              shape=(restart + 1, )).at[0].set(residual_norm)
    y = _lstsq(H.T, beta_vec)
    # Combine the first `restart` basis vectors (the last column is dropped)
    # with the least-squares coefficients y.
    dx = tree_map(lambda X: _dot(X[..., :-1], y), V)

    x = _add(x0, dx)

    residual = M(_sub(b, A(x)))
    unit_residual, residual_norm = _safe_normalize(residual)
    return x, unit_residual, residual_norm
Exemple #8
0
    def testPromoteDtypesStrict(self):
        def expect_error(lhs, rhs):
            # Strict mode must refuse to promote this pair.
            self.assertRaises(dtypes.TypePromotionError,
                              dtypes.promote_types, lhs, rhs)

        # Strong types: only identical dtypes promote (diagonal table).
        for lhs in all_dtypes:
            for rhs in all_dtypes:
                if lhs == rhs:
                    self.assertEqual(lhs, dtypes.promote_types(lhs, rhs))
                    continue
                expect_error(lhs, rhs)

        # Weak (Python scalar) types promote among themselves like Python
        # arithmetic does, and the result stays weak.
        py_scalars = (int, float, complex)
        for lhs in py_scalars:
            for rhs in py_scalars:
                expected = type(lhs(0) + rhs(0))
                got_dtype, got_weak = dtypes._lattice_result_type(lhs, rhs)
                self.assertTrue(got_weak)
                self.assertEqual(got_dtype, np.dtype(expected))

        # Strong + weak promotion is allowed only when the strong dtype would
        # not need to change; each entry lists the weak types a family absorbs.
        absorbs = (
            (bool_dtypes, ()),
            (signed_dtypes + unsigned_dtypes, (int,)),
            (float_dtypes, (int, float)),
            (complex_dtypes, (int, float, complex)),
        )
        for family, allowed in absorbs:
            for strong in family:
                for weak in py_scalars:
                    if weak in allowed:
                        self.assertEqual(
                            dtypes.promote_types(strong, weak), strong)
                    else:
                        expect_error(strong, weak)
Exemple #9
0
    def testPromoteDtypesStrict(self):
        msg = ("Input dtypes .* have no available implicit dtype promotion "
               "path when jax_numpy_dtype_promotion=strict")
        # Every rejected promotion must raise TypePromotionError matching msg.
        expect_strict_error = functools.partial(
            self.assertRaisesRegex, dtypes.TypePromotionError, msg,
            dtypes.promote_types)

        # Strong types: only identical dtypes promote (diagonal table).
        for lhs in all_dtypes:
            for rhs in all_dtypes:
                if lhs == rhs:
                    self.assertEqual(lhs, dtypes.promote_types(lhs, rhs))
                    continue
                expect_strict_error(lhs, rhs)

        # Weak Python scalar types promote like Python arithmetic and stay weak.
        py_scalars = (int, float, complex)
        for lhs in py_scalars:
            for rhs in py_scalars:
                expected = np.dtype(type(lhs(0) + rhs(0)))
                got_dtype, got_weak = dtypes._lattice_result_type(lhs, rhs)
                self.assertTrue(got_weak)
                self.assertEqual(got_dtype, expected)

        # Strong + weak promotion only succeeds if the strong dtype is kept.
        # Each entry maps a dtype family to the weak types it can absorb.
        absorbs = (
            (bool_dtypes, ()),
            (signed_dtypes + unsigned_dtypes, (int,)),
            (float_dtypes, (int, float)),
            (complex_dtypes, (int, float, complex)),
        )
        for family, allowed in absorbs:
            for strong in family:
                for weak in py_scalars:
                    if weak in allowed:
                        self.assertEqual(
                            dtypes.promote_types(strong, weak), strong)
                    else:
                        expect_strict_error(strong, weak)
Exemple #10
0
def _safe_normalize(x, thresh=None):
  """L2-normalize the vector (or pytree) ``x``, treating tiny norms as zero.

  Args:
    x: array or pytree of arrays to normalize.
    thresh: optional threshold. A computed norm less than or equal to it is
      treated as 0. Defaults to the machine epsilon of the norm's dtype. May be
      a Python scalar, a NumPy scalar, or an array.

  Returns:
    A tuple ``(normalized_x, norm)``. When ``norm <= thresh``,
    ``normalized_x`` is all zeros and the returned ``norm`` is 0.
  """
  norm = _norm(x)
  dtype, weak_type = dtypes._lattice_result_type(*tree_leaves(x))
  dtype = dtypes.canonicalize_dtype(dtype)
  if thresh is None:
    thresh = jnp.finfo(norm.dtype).eps
  # jnp.asarray lets callers pass plain Python scalars, which have no
  # ``.astype``. ``.real`` drops the imaginary part when x is complex.
  thresh = jnp.asarray(thresh).astype(dtype).real

  use_norm = norm > thresh

  # Cast the norm to x's dtype so the division below does not change dtypes.
  norm_cast = lax_internal._convert_element_type(norm, dtype, weak_type)
  normalized_x = tree_map(lambda y: jnp.where(use_norm, y / norm_cast, 0.0), x)
  norm = jnp.where(use_norm, norm, 0.0)
  return normalized_x, norm
Exemple #11
0
    def testPromoteDtypesStandard(self):
        promote = dtypes.promote_types
        complex128 = np.dtype(np.complex128)

        for lhs in all_dtypes:
            # Promotion is idempotent, bool acts as the identity, and
            # complex128 absorbs every dtype.
            self.assertEqual(lhs, promote(lhs, lhs))
            self.assertEqual(lhs, promote(lhs, np.bool_))
            self.assertEqual(complex128, promote(lhs, np.complex128))
            # The promotion table is symmetric.
            for rhs in all_dtypes:
                self.assertEqual(promote(lhs, rhs), promote(rhs, lhs))

        self.assertEqual(np.dtype(np.float32),
                         promote(np.float16, dtypes.bfloat16))

        exact_dtypes = bool_dtypes + signed_dtypes + unsigned_dtypes
        # Mixing an inexact dtype with an exact one always yields the inexact.
        for inexact in float_dtypes + complex_dtypes:
            for exact in exact_dtypes:
                self.assertEqual(inexact, promote(inexact, exact))

        # Within the exact group, and within the NumPy inexact group, results
        # agree with NumPy's own promotion.
        for group in (exact_dtypes, np_float_dtypes + complex_dtypes):
            for lhs, rhs in itertools.combinations(group, 2):
                self.assertEqual(np.promote_types(lhs, rhs),
                                 promote(lhs, rhs))

        # Weak Python scalar types promote like Python arithmetic and stay weak.
        py_scalars = (int, float, complex)
        for lhs in py_scalars:
            for rhs in py_scalars:
                got_dtype, got_weak = dtypes._lattice_result_type(lhs, rhs)
                self.assertTrue(got_weak)
                self.assertEqual(got_dtype, np.dtype(type(lhs(0) + rhs(0))))
Exemple #12
0
def zeros_like_array(x):
    """Build a zero value shaped like ``x`` via ``ad_util.zeros_like_aval``.

    The dtype is the canonicalized lattice result type of ``x`` alone. The
    lattice's weak-type flag is carried onto the ``ShapedArray`` abstract value.
    """
    lattice_dtype, is_weak = dtypes._lattice_result_type(x)
    aval = ShapedArray(np.shape(x),
                       dtypes.canonicalize_dtype(lattice_dtype),
                       weak_type=is_weak)
    return ad_util.zeros_like_aval(aval)