def tolerance(dtype, tol=None):
  tol = {} if tol is None else tol
  if not isinstance(tol, dict):
    return tol
  tol = {np.dtype(key): value for key, value in tol.items()}
  dtype = _dtypes.canonicalize_dtype(np.dtype(dtype))
  return tol.get(dtype, default_tolerance()[dtype])

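# Hypothetical usage sketch for `tolerance` above, assuming the surrounding
# test-utility module provides `default_tolerance()` and that `_dtypes` is
# jax's dtypes module. It shows how per-dtype overrides are merged with the
# canonicalized defaults.
import numpy as np

custom_tol = {np.float64: 1e-12}           # override only for float64 inputs
atol = tolerance(np.float64, custom_tol)   # 1e-12 if x64 is enabled; otherwise
                                           # the default for canonical float32
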
def testDefaultTypes(self, type_):
  expected_dtype = dtypes.canonicalize_dtype(dtypes.python_scalar_dtypes[type_])
  for f in [jnp.array, jax.jit(jnp.array), jax.jit(lambda x: x)]:
    y = f(type_(0))
    self.assertTrue(isinstance(y, jnp.ndarray), msg=(f, y))
    self.assertEqual(y.dtype, expected_dtype, msg=(f, y))

def normalize_to_xla_dtypes(val):
  """Normalize dtypes in a value."""
  if hasattr(val, '__array__') or np.isscalar(val):
    return np.asarray(val,
                      dtype=dtypes.canonicalize_dtype(dtypes.result_type(val)))
  elif isinstance(val, (tuple, list)):
    return tuple(normalize_to_xla_dtypes(x) for x in val)
  raise TypeError('Can\'t convert to XLA: {}'.format(val))

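# A minimal usage sketch for `normalize_to_xla_dtypes`, assuming `np` is numpy
# and `dtypes` is jax.dtypes, with x64 disabled: a Python float is normalized
# to float32, an int64 array to int32, and nested tuples/lists recurse.
val = (1.0, [np.arange(3, dtype=np.int64)])
normalized = normalize_to_xla_dtypes(val)
# -> (array(1., dtype=float32), (array([0, 1, 2], dtype=int32),))
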
def _cumulative_reduction(a, axis: Optional[Union[int, Tuple[int, ...]]] = None,
                          dtype=None, out=None):
  _check_arraylike(np_reduction.__name__, a)
  if out is not None:
    raise NotImplementedError(f"The 'out' argument to jnp.{np_reduction.__name__} "
                              f"is not supported.")
  lax_internal._check_user_dtype_supported(dtype, np_reduction.__name__)

  if axis is None or _isscalar(a):
    a = lax.reshape(a, (np.size(a),))
    axis = 0

  a_shape = list(np.shape(a))
  num_dims = len(a_shape)
  axis = _canonicalize_axis(axis, num_dims)

  if fill_nan:
    a = _where(lax_internal._isnan(a), _lax_const(a, fill_value), a)

  if not dtype and dtypes.dtype(a) == np.bool_:
    dtype = dtypes.canonicalize_dtype(dtypes.int_)
  if dtype:
    a = lax.convert_element_type(a, dtype)

  return reduction(a, axis)

def _mean(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
          out=None, keepdims=False, *, where=None):
  _check_arraylike("mean", a)
  lax_internal._check_user_dtype_supported(dtype, "mean")
  if out is not None:
    raise NotImplementedError("The 'out' argument to jnp.mean is not supported.")

  if where is None:
    if axis is None:
      normalizer = core.dimension_as_value(np.size(a))
    else:
      normalizer = core.dimension_as_value(_axis_size(a, axis))
  else:
    normalizer = sum(_broadcast_to(where, np.shape(a)), axis,
                     dtype=dtype, keepdims=keepdims)

  if dtype is None:
    dtype = dtypes._to_inexact_dtype(dtypes.dtype(a))
  dtype = dtypes.canonicalize_dtype(dtype)

  return lax.div(
      sum(a, axis, dtype=dtype, keepdims=keepdims, where=where),
      lax.convert_element_type(normalizer, dtype))

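# Illustration of the promotion behavior `_mean` implements, via the public
# jnp.mean API; a sketch, and the exact output dtype depends on jax_enable_x64.
import jax.numpy as jnp

x = jnp.arange(6, dtype=jnp.int32).reshape(2, 3)
jnp.mean(x, axis=0)        # integer input is promoted to an inexact dtype
jnp.mean(x, where=x > 1)   # `where` changes the normalizer, not the dtype
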
def count_nonzero(a, axis: Optional[Union[int, Tuple[int, ...]]] = None,
                  keepdims=False):
  _check_arraylike("count_nonzero", a)
  return sum(lax.ne(a, _lax_const(a, 0)), axis=axis,
             dtype=dtypes.canonicalize_dtype(np.int_), keepdims=keepdims)

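# Usage example for count_nonzero via the public API; the result dtype is the
# canonicalized default integer type (int32 unless jax_enable_x64 is set).
import jax.numpy as jnp

a = jnp.array([[0, 1, 2], [3, 0, 0]])
jnp.count_nonzero(a)           # -> 3
jnp.count_nonzero(a, axis=1)   # -> [2, 1]
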
def test_canonicalize_type(self):
  expected = {
      True: _EXPECTED_CANONICALIZE_X64,
      False: _EXPECTED_CANONICALIZE_X32,
  }
  for in_dtype, expected_dtype in expected[config.x64_enabled].items():
    self.assertEqual(dtypes.canonicalize_dtype(in_dtype), expected_dtype)

def _promote_dtypes(*args):
  """Convenience function to apply Numpy argument dtype promotion."""
  # TODO(dougalm,mattjj): This is a performance bottleneck. Consider memoizing.
  if len(args) < 2:
    return args
  else:
    to_dtype, weak_type = dtypes._lattice_result_type(*args)
    to_dtype = dtypes.canonicalize_dtype(to_dtype)
    return [lax_internal._convert_element_type(x, to_dtype, weak_type)
            for x in args]

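# Sketch of what `_promote_dtypes` does, assuming the snippet's imports
# (dtypes, lax_internal, jnp) are in scope: weakly-typed Python scalars are
# promoted alongside arrays without forcing a wider dtype.
xs = _promote_dtypes(jnp.array([1, 2], dtype=jnp.float16), 1.0)
[x.dtype for x in xs]   # both float16; the weak scalar does not pull the
                        # array up to float32
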
def __init__(self, shape, dtype, index_dtype, nnz, weak_type=False,
             named_shape=None):
  super().__init__(shape, dtypes.canonicalize_dtype(dtype))
  named_shape = {} if named_shape is None else named_shape
  self.index_dtype = index_dtype
  self.nnz = nnz
  self.data_aval = core.ShapedArray((nnz,), dtypes.canonicalize_dtype(dtype),
                                    weak_type, named_shape)
  self.indices_aval = core.ShapedArray(
      (nnz, len(shape)), dtypes.canonicalize_dtype(index_dtype),
      named_shape=named_shape)

def _promote_dtypes_complex(*args):
  """Convenience function to apply Numpy argument dtype promotion.

  Promotes arguments to a complex type."""
  to_dtype, weak_type = dtypes._lattice_result_type(*args)
  to_dtype = dtypes.canonicalize_dtype(to_dtype)
  to_dtype_complex = dtypes._to_complex_dtype(to_dtype)
  return [lax_internal._convert_element_type(x, to_dtype_complex, weak_type)
          for x in args]

def to_default_dtype(arr):
  """Convert a value to an array with JAX's default dtype.

  This is generally used for type conversions of values returned by numpy
  functions, to make their dtypes take into account the state of the
  ``jax_enable_x64`` and ``jax_default_dtype_bits`` flags.
  """
  arr = np.asarray(arr)
  dtype = _dtypes._default_types.get(arr.dtype.kind)
  return arr.astype(_dtypes.canonicalize_dtype(dtype)) if dtype else arr

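# Usage sketch for `to_default_dtype`, assuming `_dtypes` is jax's dtypes
# module: a float64 numpy result is narrowed to JAX's default float dtype
# (float32 unless jax_enable_x64 is enabled); other kinds go through the
# same `_default_types` lookup.
out = to_default_dtype(np.linspace(0.0, 1.0, 3))   # float32 by default
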
def testUnaryPromotion(self, dtype, weak_type):
  # Regression test for https://github.com/google/jax/issues/6051
  x = lax_internal._convert_element_type(0, dtype, weak_type=weak_type)
  if weak_type:
    expected = dtypes.canonicalize_dtype(
        dtypes._default_types['f' if x.dtype == 'bfloat16' else x.dtype.kind])
  else:
    expected = x.dtype
  self.assertEqual(dtypes.result_type(x), expected)

def _average(a, axis: Optional[Union[int, Tuple[int, ...]]] = None,
             weights=None, returned=False):
  a = _asarray(a)

  if weights is None:  # Treat all weights as 1
    avg = mean(a, axis=axis)
    if axis is None:
      weights_sum = lax.full((), core.dimension_as_value(np.size(a)),
                             dtype=avg.dtype)
    else:
      weights_sum = lax.full_like(avg, core.dimension_as_value(a.shape[axis]),
                                  dtype=avg.dtype)
  else:
    weights = _asarray(weights)

    if dtypes.issubdtype(a.dtype, np.inexact):
      out_dtype = dtypes.result_type(a.dtype, weights.dtype)
    else:
      out_dtype = dtypes.result_type(a.dtype, weights.dtype, dtypes.float_)
    out_dtype = dtypes.canonicalize_dtype(out_dtype)

    a_shape = np.shape(a)
    a_ndim = len(a_shape)
    weights_shape = np.shape(weights)
    axis = None if axis is None else _canonicalize_axis(axis, a_ndim)

    if a_shape != weights_shape:
      # Make sure the dimensions work out
      if axis is None:
        raise ValueError("Axis must be specified when shapes of a and "
                         "weights differ.")
      if len(weights_shape) != 1:
        raise ValueError("1D weights expected when shapes of a and "
                         "weights differ.")
      if not core.symbolic_equal_dim(weights_shape[0], a_shape[axis]):
        raise ValueError("Length of weights not "
                         "compatible with specified axis.")

      weights = _broadcast_to(weights, (a_ndim - 1) * (1,) + weights_shape)
      weights = _moveaxis(weights, -1, axis)

    weights_sum = sum(weights, axis=axis, dtype=out_dtype)
    avg = sum(a * weights, axis=axis, dtype=out_dtype) / weights_sum

  if returned:
    if avg.shape != weights_sum.shape:
      weights_sum = _broadcast_to(weights_sum, avg.shape)
    return avg, weights_sum
  return avg

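# Public-API illustration of the weighted path in `_average`; a sketch, and
# the output dtype assumes x64 is disabled (canonicalized float32).
import jax.numpy as jnp

a = jnp.array([[1, 2], [3, 4]])
w = jnp.array([0.25, 0.75])
jnp.average(a, axis=1, weights=w)                 # -> [1.75, 3.75]
jnp.average(a, axis=1, weights=w, returned=True)  # also returns the weight sums
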
def _get_random_data(x: jnp.ndarray) -> Any:
  dtype = dtypes.canonicalize_dtype(x.dtype)
  if np.issubdtype(dtype, np.integer):
    return np.random.randint(0, 100, size=x.shape, dtype=dtype)
  elif np.issubdtype(dtype, np.floating):
    return np.array(np.random.uniform(size=x.shape), dtype=dtype)
  elif dtype == np.bool_:  # np.bool is a removed numpy alias
    return np.random.choice(a=[False, True], size=x.shape)
  else:
    raise ValueError(f"Unsupported dtype for numerical comparison: {dtype}")

def zeros(key, shape, dtype: DType = jnp.float_):
  """An initializer that returns a constant array full of zeros.

  The ``key`` argument is ignored.

  >>> import jax, jax.numpy as jnp
  >>> jax.nn.initializers.zeros(jax.random.PRNGKey(42), (2, 3), jnp.float32)
  DeviceArray([[0., 0., 0.],
               [0., 0., 0.]], dtype=float32)
  """
  return jnp.zeros(shape, dtypes.canonicalize_dtype(dtype))

def _promote_arg_dtypes(*args):
  """Promotes `args` to a common inexact type."""
  dtype, weak_type = dtypes._lattice_result_type(*args)
  if not jnp.issubdtype(dtype, jnp.inexact):
    dtype, weak_type = jnp.float_, False
  dtype = dtypes.canonicalize_dtype(dtype)
  args = [lax._convert_element_type(arg, dtype, weak_type) for arg in args]
  if len(args) == 1:
    return args[0]
  else:
    return args

def _reduction_init_val(a, init_val):
  # This function uses np.* functions because lax pattern matches against the
  # specific concrete values of the reduction inputs.
  a_dtype = dtypes.canonicalize_dtype(dtypes.dtype(a))
  if a_dtype == 'bool':
    return np.array(init_val > 0, dtype=a_dtype)
  try:
    return np.array(init_val, dtype=a_dtype)
  except OverflowError:
    assert dtypes.issubdtype(a_dtype, np.integer)
    sign, info = np.sign(init_val), dtypes.iinfo(a_dtype)
    return np.array(info.min if sign < 0 else info.max, dtype=a_dtype)

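# Sketch of how `_reduction_init_val` clamps infinite init values for integer
# reductions (assumes the snippet's numpy/dtypes imports are in scope): -inf
# cannot be represented in int32, so the dtype's minimum is used instead.
_reduction_init_val(np.zeros(3, np.int32), -np.inf)   # -> array(-2147483648)
_reduction_init_val(np.zeros(3, np.bool_), -np.inf)   # -> array(False)
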
def testResultTypeWeakFlag(self):
  float_ = dtypes.canonicalize_dtype(dtypes.float_)
  x_weak = jnp.array(1.)
  x_strong = x_weak.astype(float_)
  self.assertEqual(dtypes.result_type(x_weak), float_)
  self.assertEqual(dtypes.result_type(x_weak, return_weak_type_flag=True),
                   (float_, True))
  self.assertEqual(dtypes.result_type(x_strong), float_)
  self.assertEqual(dtypes.result_type(x_strong, return_weak_type_flag=True),
                   (float_, False))

def _promote_dtypes_inexact(*args):
  """Convenience function to apply Numpy argument dtype promotion.

  Promotes arguments to an inexact type."""
  to_dtype, weak_type = dtypes._lattice_result_type(*args)
  to_dtype = dtypes.canonicalize_dtype(to_dtype)
  to_dtype_inexact = _to_inexact_dtype(to_dtype)
  weak_type = (weak_type and to_dtype == to_dtype_inexact)
  return [lax_internal._convert_element_type(x, to_dtype_inexact, weak_type)
          for x in args]

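# Illustration of `_promote_dtypes_inexact` (assumes the snippet's helpers are
# in scope): integer inputs are promoted to an inexact dtype, and the weak-type
# flag is only kept when the dtype was already inexact.
out = _promote_dtypes_inexact(jnp.array([1, 2], dtype=jnp.int32))
out[0].dtype   # float32 (float64 with jax_enable_x64)
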
def _promote_arg_dtypes(*args):
  """Promotes `args` to a common inexact type."""
  def _to_inexact_type(type):
    return type if jnp.issubdtype(type, jnp.inexact) else jnp.float_
  inexact_types = [_to_inexact_type(jnp._dtype(arg)) for arg in args]
  dtype = dtypes.canonicalize_dtype(jnp.result_type(*inexact_types))
  args = [lax.convert_element_type(arg, dtype) for arg in args]
  if len(args) == 1:
    return args[0]
  else:
    return args

def one_hot(x: Array, num_classes: int, *,
            dtype: Any = jnp.float64, axis: Union[int, AxisName] = -1) -> Array:
  """One-hot encodes the given indices.

  Each index in the input ``x`` is encoded as a vector of zeros of length
  ``num_classes`` with the element at ``index`` set to one::

    >>> jax.nn.one_hot(jnp.array([0, 1, 2]), 3)
    DeviceArray([[1., 0., 0.],
                 [0., 1., 0.],
                 [0., 0., 1.]], dtype=float32)

  Indices outside the range [0, num_classes) will be encoded as zeros::

    >>> jax.nn.one_hot(jnp.array([-1, 3]), 3)
    DeviceArray([[0., 0., 0.],
                 [0., 0., 0.]], dtype=float32)

  Args:
    x: A tensor of indices.
    num_classes: Number of classes in the one-hot dimension.
    dtype: optional, a float dtype for the returned values (default float64 if
      jax_enable_x64 is true, otherwise float32).
    axis: the axis or axes along which the function should be computed.
  """
  num_classes = core.concrete_or_error(
      int, num_classes,
      "The error arose in jax.nn.one_hot argument `num_classes`.")
  dtype = dtypes.canonicalize_dtype(dtype)
  x = jnp.asarray(x)
  try:
    output_pos_axis = util.canonicalize_axis(axis, x.ndim + 1)
  except TypeError:
    axis_size = lax.psum(1, axis)
    if num_classes != axis_size:
      raise ValueError(f"Expected num_classes to match the size of axis {axis}, "
                       f"but {num_classes} != {axis_size}") from None
    axis_idx = lax.axis_index(axis)
    return jnp.asarray(x == axis_idx, dtype=dtype)
  axis = operator.index(axis)
  lhs = lax.expand_dims(x, (axis,))
  rhs_shape = [1] * x.ndim
  rhs_shape.insert(output_pos_axis, num_classes)
  rhs = lax.broadcast_in_dim(jnp.arange(num_classes, dtype=x.dtype),
                             rhs_shape, (output_pos_axis,))
  return jnp.asarray(lhs == rhs, dtype=dtype)

def _get_random_data(dtype: jnp.dtype, shape: Tuple[int, ...], seed=0) -> Any:
  dtype = dtypes.canonicalize_dtype(dtype)
  np.random.seed(seed)
  # Adjust the max values of the numbers based on the seed, so different seeds
  # result in different ranges.
  max_value = max(1, 100 * seed)
  if np.issubdtype(dtype, np.integer):
    return np.random.randint(0, max_value, size=shape, dtype=dtype)
  elif np.issubdtype(dtype, np.floating):
    return np.array(np.random.uniform(size=shape), dtype=dtype) * max_value
  elif dtype == np.bool_:  # np.bool is a removed numpy alias
    return np.random.choice(a=[False, True], size=shape)
  else:
    raise ValueError(f"Unsupported dtype for numerical comparison: {dtype}")

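# Usage sketch for this `_get_random_data` variant: generate comparison data
# for a given dtype and shape; distinct seeds give distinct value ranges.
ints = _get_random_data(jnp.int32, (2, 3), seed=1)     # integers in [0, 100)
floats = _get_random_data(jnp.float32, (4,), seed=2)   # floats in [0, 200)
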
def _numpy_array_constant(x: np.ndarray,
                          canonicalize_types) -> Sequence[ir.Value]:
  if canonicalize_types:
    x = np.asarray(x, dtypes.canonicalize_dtype(x.dtype))
  aval = xla.abstractify(x)
  ir_type = aval_to_ir_type(aval)
  if x.dtype == np.bool_:
    x = np.packbits(x, bitorder='little')
  elif x.dtype == dtypes.bfloat16:
    x = x.view(np.uint16)
  x = np.ascontiguousarray(x)
  attr = ir.DenseElementsAttr.get(x, type=ir_type.element_type,
                                  shape=aval.shape)
  return (mhlo.ConstOp(ir_type, attr).result,)

def ldexp(x1, x2):
  _check_arraylike("ldexp", x1, x2)
  x1_dtype = dtypes.dtype(x1)
  x2_dtype = dtypes.dtype(x2)
  if (dtypes.issubdtype(x1_dtype, np.complexfloating) or
      dtypes.issubdtype(x2_dtype, np.inexact)):
    raise ValueError(
        f"ldexp not supported for input types {(x1_dtype, x2_dtype)}")

  x1, x2 = _promote_shapes("ldexp", x1, x2)

  dtype = dtypes.canonicalize_dtype(dtypes._to_inexact_dtype(x1_dtype))
  info = dtypes.finfo(dtype)
  int_type = _INT_DTYPES[info.bits]

  x1 = lax.convert_element_type(x1, dtype)
  x2 = lax.convert_element_type(x2, int_type)

  mask = (1 << info.nexp) - 1
  bias = ((1 << info.nexp) - 1) >> 1
  x, e = _normalize_float(x1)
  x2 += e + ((x >> info.nmant) & mask) - bias

  # find underflow/overflow before denormalization
  underflow_cond = x2 < -(bias + info.nmant)
  overflow_cond = x2 > bias

  m = lax.full_like(x, 1, dtype=dtype)

  # denormals
  cond = x2 < -bias + 1
  x2 = _where(cond, x2 + info.nmant, x2)
  m = _where(cond, m / (1 << info.nmant), m)

  x2 = lax.convert_element_type(x2, np.int32)
  x &= ~(mask << info.nmant)
  x |= ((lax.convert_element_type(x2, int_type) + bias) << info.nmant)

  x = lax.convert_element_type(m, dtype) * lax.bitcast_convert_type(x, dtype)

  # underflow
  x = _where(underflow_cond, lax.full_like(x, 0, dtype=dtype), x)
  # overflow
  x = _where(overflow_cond, lax.sign(x1) * lax.full_like(x, np.inf), x)
  # ldexp(x1, x2) = x1 for x1 = inf, -inf, nan, 0
  return _where(isinf(x1) | isnan(x1) | (x1 == 0), x1, x)

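# Public-API example of what the `ldexp` implementation above computes:
# ldexp(x1, x2) = x1 * 2**x2, with overflow, underflow, inf, and nan handled
# explicitly through bit manipulation of the float representation.
import jax.numpy as jnp

jnp.ldexp(1.5, 3)                  # -> 12.0
jnp.ldexp(jnp.float32(1.0), 200)   # -> inf (exponent overflows float32)
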
def init(key, shape, dtype=dtype):
  dtype = dtypes.canonicalize_dtype(dtype)
  if len(shape) < 2:
    raise ValueError("orthogonal initializer requires at least a 2D shape")
  n_rows, n_cols = prod(shape) // shape[column_axis], shape[column_axis]
  matrix_shape = (n_cols, n_rows) if n_rows < n_cols else (n_rows, n_cols)
  A = random.normal(key, matrix_shape, dtype)
  Q, R = jnp.linalg.qr(A)
  diag_sign = lax.broadcast_to_rank(jnp.sign(jnp.diag(R)), rank=Q.ndim)
  Q *= diag_sign  # needed for a uniform distribution
  if n_rows < n_cols:
    Q = Q.T
  Q = jnp.reshape(Q, tuple(np.delete(shape, column_axis)) + (shape[column_axis],))
  Q = jnp.moveaxis(Q, -1, column_axis)
  return scale * Q

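# Usage example for the orthogonal initializer whose inner `init` is shown
# above, via the public constructor; a sketch, with dtype defaulting to the
# canonical float (float32 unless x64 is enabled).
import jax

init_fn = jax.nn.initializers.orthogonal(scale=1.0, column_axis=-1)
W = init_fn(jax.random.PRNGKey(0), (64, 32))   # columns form an orthonormal set
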
def _numpy_array_constant(x: np.ndarray,
                          canonicalize_types) -> Sequence[ir.Value]:
  if canonicalize_types:
    x = np.asarray(x, dtypes.canonicalize_dtype(x.dtype))
  ir_type = ir.RankedTensorType.get(x.shape, dtype_to_ir_type(x.dtype))
  shape = x.shape
  if x.dtype == np.bool_:
    nelems = x.size
    x = np.packbits(x, bitorder='little')
    # TODO(b/209005197): Work around for MLIR crash for non-splat single element
    # buffers.
    if nelems == 1:
      x = np.array(0 if x.item() == 0 else 0xff, np.uint8)
  elif x.dtype == dtypes.bfloat16:
    x = x.view(np.uint16)
  x = np.ascontiguousarray(x)
  attr = ir.DenseElementsAttr.get(x, type=ir_type.element_type, shape=shape)
  return (mhlo.ConstOp(ir_type, attr).result,)

def _var_promote_types(a_dtype, dtype):
  if dtype:
    if (not dtypes.issubdtype(dtype, np.complexfloating) and
        dtypes.issubdtype(a_dtype, np.complexfloating)):
      msg = ("jax.numpy.var does not yet support real dtype parameters when "
             "computing the variance of an array of complex values. The "
             "semantics of numpy.var seem unclear in this case. Please comment "
             "on https://github.com/google/jax/issues/2283 if this behavior is "
             "important to you.")
      raise ValueError(msg)
    a_dtype = dtypes.promote_types(a_dtype, dtype)
  else:
    if not dtypes.issubdtype(a_dtype, np.inexact):
      dtype = a_dtype = dtypes.canonicalize_dtype(dtypes.float_)
    else:
      dtype = _complex_elem_type(a_dtype)
      a_dtype = dtypes.promote_types(a_dtype, np.float32)
  return a_dtype, dtype

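# Sketch of the promotion rules `_var_promote_types` encodes, seen through the
# public jnp.var API: integer inputs compute in the default float dtype, and
# complex inputs accumulate in complex but return a real-valued variance.
import jax.numpy as jnp

jnp.var(jnp.array([1, 2, 3])).dtype      # float32 (float64 with x64 enabled)
jnp.var(jnp.array([1j, 2j, 3j])).dtype   # real dtype, e.g. float32
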
def _gmres_batched(A, b, x0, unit_residual, residual_norm, ptol, restart, M):
  """
  Implements a single restart of GMRES. The ``restart``-dimensional Krylov
  subspace
  K(A, x0) = span(A(x0), A@x0, A@A@x0, ..., A^restart @ x0)
  is built, and the projection of the true solution into this subspace is
  returned.

  This implementation solves a dense linear problem instead of building
  a QR factorization during the Arnoldi process.
  """
  del ptol  # unused
  # https://www-users.cs.umn.edu/~saad/Calais/PREC.pdf
  V = tree_map(
      lambda x: jnp.pad(x[..., None], ((0, 0),) * x.ndim + ((0, restart),)),
      unit_residual,
  )
  dtype, weak_type = dtypes._lattice_result_type(*tree_leaves(b))
  dtype = dtypes.canonicalize_dtype(dtype)
  H = lax_internal._convert_element_type(
      jnp.eye(restart, restart + 1, dtype=dtype), weak_type=weak_type)

  def loop_cond(carry):
    _, _, breakdown, k = carry
    return jnp.logical_and(k < restart, jnp.logical_not(breakdown))

  def arnoldi_process(carry):
    V, H, _, k = carry
    V, H, breakdown = _kth_arnoldi_iteration(k, A, M, V, H)
    return V, H, breakdown, k + 1

  carry = (V, H, False, 0)
  V, H, _, _ = lax.while_loop(loop_cond, arnoldi_process, carry)

  beta_vec = jnp.zeros_like(H, shape=(restart + 1,)).at[0].set(
      residual_norm.astype(dtype))
  y = _lstsq(H.T, beta_vec)
  dx = tree_map(lambda X: _dot(X[..., :-1], y), V)

  x = _add(x0, dx)

  residual = M(_sub(b, A(x)))
  unit_residual, residual_norm = _safe_normalize(residual)
  return x, unit_residual, residual_norm

def _safe_normalize(x, thresh=None):
  """
  Returns the L2-normalized vector (which can be a pytree) x, and optionally
  the computed norm. If the computed norm is less than the threshold `thresh`,
  which by default is the machine precision of x's dtype, it will be taken to
  be 0, and the normalized x to be the zero vector.
  """
  norm = _norm(x)
  dtype, weak_type = dtypes._lattice_result_type(*tree_leaves(x))
  dtype = dtypes.canonicalize_dtype(dtype)
  if thresh is None:
    thresh = jnp.finfo(norm.dtype).eps
  thresh = thresh.astype(dtype).real

  use_norm = norm > thresh
  norm_cast = lax_internal._convert_element_type(norm, dtype, weak_type)
  normalized_x = tree_map(lambda y: jnp.where(use_norm, y / norm_cast, 0.0), x)
  norm = jnp.where(use_norm, norm, 0.0)
  return normalized_x, norm

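# Usage sketch for `_safe_normalize` (assumes the snippet's `_norm` and
# tree/lax helpers are in scope): vectors with norm below machine epsilon map
# to the zero vector with a reported norm of 0, avoiding division blow-ups
# inside the GMRES restart above.
v, v_norm = _safe_normalize(jnp.array([3.0, 4.0]))     # v ~ [0.6, 0.8], norm 5
z, z_norm = _safe_normalize(jnp.array([0.0, 1e-30]))   # z = [0, 0], norm 0
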
def testBinaryPromotion(self, swap, jit):
  testcases = [
    (jnp.array(1.), 0., jnp.float64),
    (jnp.array(1.), jnp.array(0.), jnp.float64),
    (jnp.array(1.), jnp.array(0., dtype=jnp.float16), jnp.float16),
    (jnp.array(1.), jnp.array(0., dtype=jnp.float32), jnp.float32),
    (jnp.array(1.), jnp.array(0., dtype=jnp.float64), jnp.float64),
    (jnp.array(1., dtype=jnp.float16), 0., jnp.float16),
    (jnp.array(1., dtype=jnp.float32), 0., jnp.float32),
    (jnp.array(1., dtype=jnp.float64), 0., jnp.float64),
    (jnp.array(1., dtype=jnp.float16), jnp.array(0., dtype=jnp.float16), jnp.float16),
    (jnp.array(1., dtype=jnp.float16), jnp.array(0., dtype=jnp.float32), jnp.float32),
    (jnp.array(1., dtype=jnp.float16), jnp.array(0., dtype=jnp.float64), jnp.float64),
    (jnp.array(1., dtype=jnp.float32), jnp.array(0., dtype=jnp.float32), jnp.float32),
    (jnp.array(1., dtype=jnp.float32), jnp.array(0., dtype=jnp.float64), jnp.float64),
    (jnp.array(1., dtype=jnp.float64), jnp.array(0., dtype=jnp.float64), jnp.float64),
    (jnp.array([1.]), 0., jnp.float_),
    (jnp.array([1.]), jnp.array(0.), jnp.float_),
    (jnp.array([1.]), jnp.array(0., dtype=jnp.float16), jnp.float_),
    (jnp.array([1.]), jnp.array(0., dtype=jnp.float32), jnp.float_),
    (jnp.array([1.]), jnp.array(0., dtype=jnp.float64), jnp.float64),
    (jnp.array([1.], dtype=jnp.float32), jnp.array(0., dtype=jnp.float16), jnp.float32),
    (jnp.array([1.], dtype=jnp.float16), jnp.array(0., dtype=jnp.float32), jnp.float32),
    (jnp.array([1.], dtype=jnp.float16), 0., jnp.float16),
  ]
  op = jax.jit(operator.add) if jit else operator.add
  for x, y, dtype in testcases:
    x, y = (y, x) if swap else (x, y)
    z = op(x, y)
    self.assertTrue(isinstance(z, jnp.ndarray), msg=(x, y, z))
    self.assertEqual(z.dtype, dtypes.canonicalize_dtype(dtype), msg=(x, y, z))