def _asarray(arr):
  """Convert ``arr`` into a DeviceArray of its own promoted dtype.

  A deliberately minimal conversion: the target dtype and weak-type flag are
  whatever ``dtypes._lattice_result_type`` reports for ``arr`` alone. Lists and
  tuples are not handled correctly by this helper.
  """
  _check_arraylike("_asarray", arr)
  target_dtype, is_weak = dtypes._lattice_result_type(arr)
  return lax_internal._convert_element_type(arr, target_dtype, is_weak)
def testDeviceArrayRepr(self, dtype, weak_type):
  """repr() of a converted scalar 0 names DeviceArray, its dtype and weak type."""
  val = lax_internal._convert_element_type(0, dtype, weak_type=weak_type)
  rep = repr(val)
  self.assertStartsWith(rep, 'DeviceArray(')
  # Weakly-typed values carry an extra ``weak_type=True`` field in the repr.
  expected_tail = (f"dtype={val.dtype.name}, weak_type=True)" if weak_type
                   else f"dtype={val.dtype.name})")
  self.assertEndsWith(rep, expected_tail)
def _promote_dtypes_complex(*args):
  """Apply NumPy-style dtype promotion to ``args`` and cast them to a complex type.

  The canonicalized common dtype of all arguments is mapped to its complex
  counterpart via ``dtypes._to_complex_dtype``; each argument is converted to
  that dtype, keeping the weak-type flag produced by the promotion lattice.

  Returns:
    A list with one converted value per argument.
  """
  common_dtype, weak = dtypes._lattice_result_type(*args)
  complex_dtype = dtypes._to_complex_dtype(dtypes.canonicalize_dtype(common_dtype))
  return [lax_internal._convert_element_type(arg, complex_dtype, weak)
          for arg in args]
def _promote_dtypes(*args):
  """Apply NumPy-style dtype promotion so that all ``args`` share one dtype.

  With fewer than two arguments there is nothing to promote, and the input
  tuple ``args`` is returned untouched. Otherwise each argument is converted
  to the canonicalized common dtype (with the lattice's weak-type flag), and a
  list of converted values is returned.
  """
  # TODO(dougalm,mattjj): This is a performance bottleneck. Consider memoizing.
  if len(args) >= 2:
    common_dtype, weak = dtypes._lattice_result_type(*args)
    common_dtype = dtypes.canonicalize_dtype(common_dtype)
    return [lax_internal._convert_element_type(a, common_dtype, weak)
            for a in args]
  return args
def testUnaryPromotion(self, dtype, weak_type):
  """result_type of a single converted scalar matches the expected dtype."""
  # Regression test for https://github.com/google/jax/issues/6051
  x = lax_internal._convert_element_type(0, dtype, weak_type=weak_type)
  if not weak_type:
    expected = x.dtype
  else:
    # Weak values resolve to the default type of their kind; bfloat16 is
    # explicitly looked up under the floating kind 'f'.
    kind = 'f' if x.dtype == 'bfloat16' else x.dtype.kind
    expected = dtypes.canonicalize_dtype(dtypes._default_types[kind])
  self.assertEqual(dtypes.result_type(x), expected)
def _promote_dtypes_inexact(*args):
  """Apply NumPy-style dtype promotion to ``args`` and cast to an inexact type.

  The canonicalized common dtype is passed through ``_to_inexact_dtype``. The
  weak-type flag from the promotion lattice is kept only when that mapping left
  the dtype unchanged; if the dtype had to change, results are strongly typed.

  Returns:
    A list with one converted value per argument.
  """
  common_dtype, weak = dtypes._lattice_result_type(*args)
  common_dtype = dtypes.canonicalize_dtype(common_dtype)
  inexact_dtype = _to_inexact_dtype(common_dtype)
  keep_weak = weak and common_dtype == inexact_dtype
  return [lax_internal._convert_element_type(a, inexact_dtype, keep_weak)
          for a in args]
def _bicgstab_solve(A, b, x0=None, *, maxiter, tol=1e-5, atol=0.0, M=_identity):
  """Approximately solve ``A x = b`` with preconditioned BiCGSTAB.

  Args:
    A: callable applying the linear operator to a pytree shaped like ``b``.
    b: right-hand-side pytree.
    x0: initial guess. NOTE(review): defaults to None, but ``A(x0)`` is called
      unconditionally below, so callers presumably always supply it — confirm.
    maxiter: maximum number of iterations (keyword-only).
    tol: relative tolerance; atol: absolute tolerance. Iteration continues while
      ``_vdot_real_tree(r, r) > max(tol**2 * _vdot_real_tree(b, b), atol**2)``.
    M: preconditioner callable, applied as ``M(p_)`` and ``M(s)``.

  Returns:
    The final iterate ``x``. The iteration counter / breakdown code ``k`` is
    discarded, so the caller cannot distinguish convergence, reaching
    ``maxiter``, or breakdown.
  """

  # tolerance handling uses the "non-legacy" behavior of scipy.sparse.linalg.bicgstab
  bs = _vdot_real_tree(b, b)
  # Squared stopping threshold, compared against squared residual quantities.
  atol2 = jnp.maximum(jnp.square(tol) * bs, jnp.square(atol))

  # https://en.wikipedia.org/wiki/Biconjugate_gradient_stabilized_method#Preconditioned_BiCGSTAB

  def cond_fun(value):
    # Only the residual r and the counter k are needed to decide whether to continue.
    x, r, *_, k = value
    rs = _vdot_real_tree(r, r)
    # the last condition checks breakdown
    # (body_fun sets k to -10 / -11 on breakdown, which makes k >= 0 false).
    return (rs > atol2) & (k < maxiter) & (k >= 0)

  def body_fun(value):
    # Loop state: iterate x, residual r, fixed shadow residual rhat, the scalars
    # alpha / omega / rho from the previous step, directions p and q, counter k.
    x, r, rhat, alpha, omega, rho, p, q, k = value
    rho_ = _vdot_tree(rhat, r)
    beta = rho_ / rho * alpha / omega
    p_ = _add(r, _mul(beta, _sub(p, _mul(omega, q))))
    phat = M(p_)
    q_ = A(phat)
    alpha_ = rho_ / _vdot_tree(rhat, q_)
    # Intermediate residual after the alpha half-step.
    s = _sub(r, _mul(alpha_, q_))
    # If s is already below the threshold, the omega half-step is skipped.
    exit_early = _vdot_real_tree(s, s) < atol2
    shat = M(s)
    t = A(shat)
    omega_ = _vdot_tree(t, s) / _vdot_tree(t, t)  # make cases?
    # exit_early: x + alpha*phat and r = s; otherwise the full update that also
    # applies the omega step.
    x_ = tree_map(partial(jnp.where, exit_early),
                  _add(x, _mul(alpha_, phat)),
                  _add(x, _add(_mul(alpha_, phat), _mul(omega_, shat))))
    r_ = tree_map(partial(jnp.where, exit_early),
                  s, _sub(s, _mul(omega_, t)))
    # Breakdown codes: -11 when omega_ or alpha_ is zero, -10 when rho_ is zero
    # (the rho_ check is applied last, so it takes precedence).
    k_ = jnp.where((omega_ == 0) | (alpha_ == 0), -11, k + 1)
    k_ = jnp.where((rho_ == 0), -10, k_)
    return x_, r_, rhat, alpha_, omega_, rho_, p_, q_, k_

  r0 = _sub(b, A(x0))
  # rho, alpha and omega start at 1, using the dtype / weak type obtained from
  # the leaves of b.
  rho0 = alpha0 = omega0 = lax_internal._convert_element_type(
      1, *dtypes._lattice_result_type(*tree_leaves(b)))
  # rhat, p and q all start as r0. With omega = 1 and p = q = r0, the term
  # p - omega*q vanishes on the first step, so the first p_ equals r0.
  initial_value = (x0, r0, r0, alpha0, omega0, rho0, r0, r0, 0)

  x_final, *_ = lax.while_loop(cond_fun, body_fun, initial_value)

  return x_final
def _promote_arg_dtypes(*args):
  """Promote ``args`` to a common inexact dtype.

  If the promoted dtype is not inexact, ``jnp.float_`` is used instead and
  weak typing is dropped. The chosen dtype is canonicalized before conversion.

  Returns:
    The single converted value when exactly one argument is given, otherwise a
    list of converted values.
  """
  target, weak = dtypes._lattice_result_type(*args)
  if not jnp.issubdtype(target, jnp.inexact):
    target, weak = jnp.float_, False
  target = dtypes.canonicalize_dtype(target)
  promoted = [lax_internal._convert_element_type(a, target, weak) for a in args]
  return promoted[0] if len(promoted) == 1 else promoted
def _scatter_impl(x, y, scatter_op, treedef, static_idx, dynamic_idx,
                  indices_are_sorted, unique_indices, mode,
                  normalize_indices):
  """Scatter ``y`` into ``x`` with ``scatter_op`` at the given index expression.

  The index is rebuilt from ``treedef`` / ``static_idx`` / ``dynamic_idx``,
  translated into gather form, and the gather dimension numbers are transposed
  into scatter dimension numbers. Issues a FutureWarning when x's dtype differs
  from ``dtypes.result_type(x, y)``. The result is converted back to x's
  original dtype and weak-type flag; if the slice shape is empty, ``x`` is
  returned unchanged.
  """
  out_dtype = lax.dtype(x)
  out_weak_type = dtypes.is_weakly_typed(x)

  if out_dtype != dtypes.result_type(x, y):
    # TODO(jakevdp): change this to an error after the deprecation period.
    warnings.warn(
        "scatter inputs have incompatible types: cannot safely cast "
        f"value from dtype={lax.dtype(y)} to dtype={lax.dtype(x)}. "
        "In future JAX releases this will result in an error.",
        FutureWarning)

  idx = jnp._merge_static_and_dynamic_indices(treedef, static_idx, dynamic_idx)
  indexer = jnp._index_to_gather(jnp.shape(x), idx,
                                 normalize_indices=normalize_indices)

  # Skip the scatter entirely for an empty slice: a fast path, and it also
  # covers cases like zeros(0)[array([], int32)].
  if core.is_empty_shape(indexer.slice_shape):
    return x

  x, y = jnp._promote_dtypes(x, y)

  # Bring y into update layout: broadcast to the slice shape, drop
  # None/jnp.newaxis dimensions, then undo any reversed dimensions.
  y = jnp.broadcast_to(y, tuple(indexer.slice_shape))
  y = jnp.squeeze(y, axis=indexer.newaxis_dims)
  if indexer.reversed_y_dims:
    y = lax.rev(y, indexer.reversed_y_dims)

  # Transpose the gather dimensions into scatter dimensions (cf.
  # lax._gather_transpose_rule)
  gather_dnums = indexer.dnums
  dnums = lax.ScatterDimensionNumbers(
      update_window_dims=gather_dnums.offset_dims,
      inserted_window_dims=gather_dnums.collapsed_slice_dims,
      scatter_dims_to_operand_dims=gather_dnums.start_index_map)

  sorted_hint = indexer.indices_are_sorted or indices_are_sorted
  unique_hint = indexer.unique_indices or unique_indices
  out = scatter_op(x, indexer.gather_indices, y, dnums,
                   indices_are_sorted=sorted_hint,
                   unique_indices=unique_hint,
                   mode=mode)
  return lax_internal._convert_element_type(out, out_dtype, out_weak_type)
def _gmres_batched(A, b, x0, unit_residual, residual_norm, ptol, restart, M):
  """Run a single restart (outer iteration) of GMRES.

  Builds a Krylov basis ``V`` with up to ``restart`` Arnoldi steps (via
  ``_kth_arnoldi_iteration``), starting from ``unit_residual`` as the first
  basis vector. It then solves a small dense least-squares problem for the
  basis coefficients ``y`` and adds ``V @ y`` to ``x0``.

  This implementation solves a dense linear problem instead of building
  a QR factorization during the Arnoldi process.

  Args:
    A: callable linear operator.
    b: right-hand-side pytree.
    x0: current iterate; the returned ``x`` is ``x0 + dx``.
    unit_residual: normalized residual for ``x0``; becomes column 0 of ``V``.
    residual_norm: norm of that residual; the least-squares right-hand side is
      ``residual_norm`` in entry 0 and zeros elsewhere.
    ptol: unused (deleted immediately).
    restart: maximum number of Arnoldi steps.
    M: preconditioner callable.

  Returns:
    ``(x, unit_residual, residual_norm)`` for the updated iterate. The residual
    is recomputed as ``M(b - A(x))`` and normalized with ``_safe_normalize``.
  """
  del ptol  # unused
  # https://www-users.cs.umn.edu/~saad/Calais/PREC.pdf
  # Each leaf gets a trailing axis of length restart + 1: slot 0 holds
  # unit_residual, the remaining slots start as zeros for later basis vectors.
  V = tree_map(
      lambda x: jnp.pad(x[..., None], ((0, 0), ) * x.ndim + ((0, restart), )),
      unit_residual,
  )
  dtype, weak_type = dtypes._lattice_result_type(*tree_leaves(b))
  # H has shape (restart, restart + 1), i.e. it is stored transposed relative
  # to the least-squares matrix (H.T is used below). NOTE(review): it is
  # initialized to an identity rather than zeros, presumably so rows left
  # unfilled after an early breakdown keep the system well-posed — confirm.
  H = lax_internal._convert_element_type(
      jnp.eye(restart, restart + 1, dtype=dtype), weak_type=weak_type)

  def loop_cond(carry):
    # Continue until `restart` steps are done or an Arnoldi step reports breakdown.
    _, _, breakdown, k = carry
    return jnp.logical_and(k < restart, jnp.logical_not(breakdown))

  def arnoldi_process(carry):
    V, H, _, k = carry
    V, H, breakdown = _kth_arnoldi_iteration(k, A, M, V, H)
    return V, H, breakdown, k + 1

  carry = (V, H, False, 0)
  V, H, _, _ = lax.while_loop(loop_cond, arnoldi_process, carry)

  # Right-hand side of length restart + 1, with residual_norm in entry 0.
  beta_vec = jnp.zeros_like(H, shape=(restart + 1, )).at[0].set(residual_norm)
  y = _lstsq(H.T, beta_vec)
  # y has `restart` entries, so the last of the restart + 1 basis slots is dropped.
  dx = tree_map(lambda X: _dot(X[..., :-1], y), V)

  x = _add(x0, dx)

  residual = M(_sub(b, A(x)))
  unit_residual, residual_norm = _safe_normalize(residual)
  return x, unit_residual, residual_norm
def _safe_normalize(x, thresh=None):
  """Return the L2-normalized pytree ``x`` together with its norm.

  If the computed norm does not exceed ``thresh`` (default: the machine
  epsilon of the norm's dtype), the norm is taken to be 0 and the normalized
  ``x`` is the zero vector.

  Args:
    x: pytree of arrays to normalize.
    thresh: optional threshold. It may be a Python scalar, a NumPy scalar, or
      a JAX array.

  Returns:
    ``(normalized_x, norm)``. The norm is always returned, and is 0 whenever
    normalization was suppressed.
  """
  norm = _norm(x)
  dtype, weak_type = dtypes._lattice_result_type(*tree_leaves(x))
  dtype = dtypes.canonicalize_dtype(dtype)
  if thresh is None:
    thresh = jnp.finfo(norm.dtype).eps
  # jnp.asarray lets callers pass plain Python floats, which have no .astype.
  # .real keeps the threshold real when x is complex-valued.
  thresh = jnp.asarray(thresh).astype(dtype).real
  use_norm = norm > thresh
  norm_cast = lax_internal._convert_element_type(norm, dtype, weak_type)
  normalized_x = tree_map(lambda y: jnp.where(use_norm, y / norm_cast, 0.0), x)
  norm = jnp.where(use_norm, norm, 0.0)
  return normalized_x, norm
def testBinaryNonPromotion(self, dtype, weak_type, promotion):
  """x + x under the given promotion mode yields the expected dtype and weak type."""
  # Regression test for https://github.com/google/jax/issues/6051
  x = lax_internal._convert_element_type(0, dtype, weak_type=weak_type)
  with jax.numpy_dtype_promotion(promotion):
    y = (x + x)

  def _expected_dtype():
    # Input dtype is kept under 'standard' promotion, for strongly typed
    # inputs, and for bool; otherwise it maps to the category default.
    if promotion == 'standard' or not weak_type or dtype == dtypes.bool_:
      return dtype
    if dtypes.issubdtype(dtype, np.complexfloating):
      return dtypes.complex_
    if dtypes.issubdtype(dtype, np.floating):
      return dtypes.float_
    return dtypes.int_

  # No boolean weak types.
  expected_weak_type = weak_type and dtype != bool
  self.assertEqual(y.dtype, dtypes.canonicalize_dtype(_expected_dtype()))
  self.assertEqual(dtypes.is_weakly_typed(y), expected_weak_type)
def _scatter_impl(x, y, scatter_op, treedef, static_idx, dynamic_idx,
                  indices_are_sorted, unique_indices, mode,
                  normalize_indices):
  """Scatter ``y`` into ``x`` with ``scatter_op`` at the given index expression.

  The index is rebuilt from ``treedef`` / ``static_idx`` / ``dynamic_idx``,
  translated into gather form, and the gather dimension numbers are transposed
  into scatter dimension numbers. The result is converted back to x's
  original dtype and weak-type flag; if the slice shape is empty, ``x`` is
  returned unchanged.
  """
  out_dtype = lax.dtype(x)
  out_weak_type = dtypes.is_weakly_typed(x)

  idx = jnp._merge_static_and_dynamic_indices(treedef, static_idx, dynamic_idx)
  indexer = jnp._index_to_gather(jnp.shape(x), idx,
                                 normalize_indices=normalize_indices)

  # Skip the scatter entirely for an empty slice: a fast path, and it also
  # covers cases like zeros(0)[array([], int32)].
  if core.is_empty_shape(indexer.slice_shape):
    return x

  x, y = jnp._promote_dtypes(x, y)

  # Bring y into update layout: broadcast to the slice shape, drop
  # None/jnp.newaxis dimensions, then undo any reversed dimensions.
  y = jnp.broadcast_to(y, tuple(indexer.slice_shape))
  y = jnp.squeeze(y, axis=indexer.newaxis_dims)
  if indexer.reversed_y_dims:
    y = lax.rev(y, indexer.reversed_y_dims)

  # Transpose the gather dimensions into scatter dimensions (cf.
  # lax._gather_transpose_rule)
  gather_dnums = indexer.dnums
  dnums = lax.ScatterDimensionNumbers(
      update_window_dims=gather_dnums.offset_dims,
      inserted_window_dims=gather_dnums.collapsed_slice_dims,
      scatter_dims_to_operand_dims=gather_dnums.start_index_map)

  sorted_hint = indexer.indices_are_sorted or indices_are_sorted
  unique_hint = indexer.unique_indices or unique_indices
  out = scatter_op(x, indexer.gather_indices, y, dnums,
                   indices_are_sorted=sorted_hint,
                   unique_indices=unique_hint,
                   mode=mode)
  return lax_internal._convert_element_type(out, out_dtype, out_weak_type)
def testBinaryNonPromotion(self, dtype, weak_type):
  """x + x preserves both the dtype and the weak-type flag of x."""
  # Regression test for https://github.com/google/jax/issues/6051
  x = lax_internal._convert_element_type(0, dtype, weak_type=weak_type)
  y = (x + x)
  # Use TestCase assertions rather than bare `assert`: bare asserts are
  # stripped under `python -O` and report no values on failure.
  self.assertEqual(x.dtype, y.dtype)
  self.assertEqual(dtypes.is_weakly_typed(y), dtypes.is_weakly_typed(x))