Example #1
def tolerance(dtype, tol=None):
    tol = {} if tol is None else tol
    if not isinstance(tol, dict):
        return tol
    tol = {np.dtype(key): value for key, value in tol.items()}
    dtype = _dtypes.canonicalize_dtype(np.dtype(dtype))
    return tol.get(dtype, default_tolerance()[dtype])
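A minimal self-contained sketch of the same lookup pattern (the default table below is hypothetical, standing in for default_tolerance(); it is not JAX's actual defaults):

import numpy as np

# Hypothetical per-dtype defaults standing in for default_tolerance().
_DEFAULT_TOL = {np.dtype(np.float32): 1e-6, np.dtype(np.float64): 1e-12}

def tolerance_sketch(dtype, tol=None):
    # A scalar tolerance applies to every dtype; a dict is keyed by dtype.
    if tol is not None and not isinstance(tol, dict):
        return tol
    tol = {np.dtype(k): v for k, v in (tol or {}).items()}
    return tol.get(np.dtype(dtype), _DEFAULT_TOL[np.dtype(dtype)])

print(tolerance_sketch(np.float32))                      # 1e-06, table default
print(tolerance_sketch(np.float32, {np.float32: 1e-3}))  # 0.001, per-dtype override
print(tolerance_sketch(np.float64, 1e-8))                # 1e-08, scalar applies to all dtypes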
Example #2
 def testDefaultTypes(self, type_):
     expected_dtype = dtypes.canonicalize_dtype(
         dtypes.python_scalar_dtypes[type_])
     for f in [jnp.array, jax.jit(jnp.array), jax.jit(lambda x: x)]:
         y = f(type_(0))
         self.assertTrue(isinstance(y, jnp.ndarray), msg=(f, y))
         self.assertEqual(y.dtype, expected_dtype, msg=(f, y))
Example #3
def normalize_to_xla_dtypes(val):
  """Normalize dtypes in a value."""
  if hasattr(val, '__array__') or np.isscalar(val):
    return np.asarray(val, dtype=dtypes.canonicalize_dtype(dtypes.result_type(val)))
  elif isinstance(val, (tuple, list)):
    return tuple(normalize_to_xla_dtypes(x) for x in val)
  raise TypeError(f"Can't convert to XLA: {val}")
Example #4
    def _cumulative_reduction(a,
                              axis: Optional[Union[int, Tuple[int, ...]]] = None,
                              dtype=None,
                              out=None):
        _check_arraylike(np_reduction.__name__, a)
        if out is not None:
            raise NotImplementedError(
                f"The 'out' argument to jnp.{np_reduction.__name__} "
                f"is not supported.")
        lax_internal._check_user_dtype_supported(dtype, np_reduction.__name__)

        if axis is None or _isscalar(a):
            a = lax.reshape(a, (np.size(a), ))
            axis = 0

        a_shape = list(np.shape(a))
        num_dims = len(a_shape)
        axis = _canonicalize_axis(axis, num_dims)

        if fill_nan:
            a = _where(lax_internal._isnan(a), _lax_const(a, fill_value), a)

        if not dtype and dtypes.dtype(a) == np.bool_:
            dtype = dtypes.canonicalize_dtype(dtypes.int_)
        if dtype:
            a = lax.convert_element_type(a, dtype)

        return reduction(a, axis)
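The bool-to-int branch above is visible through the public API; a quick check (dtype shown assumes jax_enable_x64 is off):

import jax.numpy as jnp

x = jnp.array([True, False, True])
print(jnp.cumsum(x))        # [1 1 2] -- booleans accumulate as integers
print(jnp.cumsum(x).dtype)  # int32 under the default x32 config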
Example #5
def _mean(a,
          axis: Optional[Union[int, Tuple[int, ...]]] = None,
          dtype=None,
          out=None,
          keepdims=False,
          *,
          where=None):
    _check_arraylike("mean", a)
    lax_internal._check_user_dtype_supported(dtype, "mean")
    if out is not None:
        raise NotImplementedError(
            "The 'out' argument to jnp.mean is not supported.")

    if where is None:
        if axis is None:
            normalizer = core.dimension_as_value(np.size(a))
        else:
            normalizer = core.dimension_as_value(_axis_size(a, axis))
    else:
        normalizer = sum(_broadcast_to(where, np.shape(a)),
                         axis,
                         dtype=dtype,
                         keepdims=keepdims)

    if dtype is None:
        dtype = dtypes._to_inexact_dtype(dtypes.dtype(a))
    dtype = dtypes.canonicalize_dtype(dtype)

    return lax.div(sum(a, axis, dtype=dtype, keepdims=keepdims, where=where),
                   lax.convert_element_type(normalizer, dtype))
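A small usage check of the masked path: with a where mask, the normalizer is the count of True entries rather than the axis size.

import jax.numpy as jnp

x = jnp.array([1., 2., 3., 4.])
mask = jnp.array([True, True, False, False])
print(jnp.mean(x, where=mask))  # 1.5 -- only the first two elements are averaged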
Example #6
def count_nonzero(a,
                  axis: Optional[Union[int, Tuple[int, ...]]] = None,
                  keepdims=False):
    _check_arraylike("count_nonzero", a)
    return sum(lax.ne(a, _lax_const(a, 0)),
               axis=axis,
               dtype=dtypes.canonicalize_dtype(np.int_),
               keepdims=keepdims)
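Usage; the result dtype comes from canonicalizing np.int_, so it is int32 unless jax_enable_x64 is enabled:

import jax.numpy as jnp

a = jnp.array([[0, 1, 2], [3, 0, 0]])
print(jnp.count_nonzero(a))          # 3
print(jnp.count_nonzero(a, axis=0))  # [1 1 1]
print(jnp.count_nonzero(a).dtype)    # int32 under the default x32 config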
Example #7
 def test_canonicalize_type(self):
     expected = {
         True: _EXPECTED_CANONICALIZE_X64,
         False: _EXPECTED_CANONICALIZE_X32,
     }
     for in_dtype, expected_dtype in expected[config.x64_enabled].items():
         self.assertEqual(dtypes.canonicalize_dtype(in_dtype),
                          expected_dtype)
Example #8
def _promote_dtypes(*args):
  """Convenience function to apply Numpy argument dtype promotion."""
  # TODO(dougalm,mattjj): This is a performance bottleneck. Consider memoizing.
  if len(args) < 2:
    return args
  else:
    to_dtype, weak_type = dtypes._lattice_result_type(*args)
    to_dtype = dtypes.canonicalize_dtype(to_dtype)
    return [lax_internal._convert_element_type(x, to_dtype, weak_type) for x in args]
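The weak_type flag carried through _lattice_result_type is what lets Python scalars combine with narrow dtypes without widening them; a check using public ops:

import jax.numpy as jnp
import numpy as np

x = jnp.ones(3, dtype=jnp.float16)
print((x + 1.0).dtype)              # float16 -- the Python float is weakly typed
print((x + np.float32(1.0)).dtype)  # float32 -- a strongly typed float32 wins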
Example #9
 def __init__(self,
              shape,
              dtype,
              index_dtype,
              nnz,
              weak_type=False,
              named_shape=None):
     super().__init__(shape, dtypes.canonicalize_dtype(dtype))
     named_shape = {} if named_shape is None else named_shape
     self.index_dtype = index_dtype
     self.nnz = nnz
     self.data_aval = core.ShapedArray((nnz, ),
                                       dtypes.canonicalize_dtype(dtype),
                                       weak_type, named_shape)
     self.indices_aval = core.ShapedArray(
         (nnz, len(shape)),
         dtypes.canonicalize_dtype(index_dtype),
         named_shape=named_shape)
Example #10
def _promote_dtypes_complex(*args):
  """Convenience function to apply Numpy argument dtype promotion.

  Promotes arguments to a complex type."""
  to_dtype, weak_type = dtypes._lattice_result_type(*args)
  to_dtype = dtypes.canonicalize_dtype(to_dtype)
  to_dtype_complex = dtypes._to_complex_dtype(to_dtype)
  return [lax_internal._convert_element_type(x, to_dtype_complex, weak_type)
          for x in args]
Example #11
def to_default_dtype(arr):
    """Convert a value to an array with JAX's default dtype.

  This is generally used for type conversions of values returned by numpy functions,
  to make their dtypes take into account the state of the ``jax_enable_x64`` and
  ``jax_default_dtype_bits`` flags.
  """
    arr = np.asarray(arr)
    dtype = _dtypes._default_types.get(arr.dtype.kind)
    return arr.astype(_dtypes.canonicalize_dtype(dtype)) if dtype else arr
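The canonicalization step is observable directly through the public jax.dtypes module (results assume jax_enable_x64 is off):

import numpy as np
from jax import dtypes

print(dtypes.canonicalize_dtype(np.float64))  # float32 under x32
print(dtypes.canonicalize_dtype(np.int64))    # int32 under x32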
Example #12
 def testUnaryPromotion(self, dtype, weak_type):
     # Regression test for https://github.com/google/jax/issues/6051
     x = lax_internal._convert_element_type(0, dtype, weak_type=weak_type)
     if weak_type:
         expected = dtypes.canonicalize_dtype(
             dtypes._default_types['f' if x.dtype == 'bfloat16' else x.dtype.kind])
     else:
         expected = x.dtype
     self.assertEqual(dtypes.result_type(x), expected)
Example #13
def _average(a,
             axis: Optional[Union[int, Tuple[int, ...]]] = None,
             weights=None,
             returned=False):
    a = _asarray(a)

    if weights is None:  # Treat all weights as 1
        avg = mean(a, axis=axis)
        if axis is None:
            weights_sum = lax.full((),
                                   core.dimension_as_value(np.size(a)),
                                   dtype=avg.dtype)
        else:
            weights_sum = lax.full_like(avg,
                                        core.dimension_as_value(a.shape[axis]),
                                        dtype=avg.dtype)
    else:
        weights = _asarray(weights)

        if dtypes.issubdtype(a.dtype, np.inexact):
            out_dtype = dtypes.result_type(a.dtype, weights.dtype)
        else:
            out_dtype = dtypes.result_type(a.dtype, weights.dtype,
                                           dtypes.float_)
        out_dtype = dtypes.canonicalize_dtype(out_dtype)

        a_shape = np.shape(a)
        a_ndim = len(a_shape)
        weights_shape = np.shape(weights)
        axis = None if axis is None else _canonicalize_axis(axis, a_ndim)

        if a_shape != weights_shape:
            # Make sure the dimensions work out
            if axis is None:
                raise ValueError("Axis must be specified when shapes of a and "
                                 "weights differ.")
            if len(weights_shape) != 1:
                raise ValueError("1D weights expected when shapes of a and "
                                 "weights differ.")
            if not core.symbolic_equal_dim(weights_shape[0], a_shape[axis]):
                raise ValueError("Length of weights not "
                                 "compatible with specified axis.")

            weights = _broadcast_to(weights,
                                    (a_ndim - 1) * (1, ) + weights_shape)
            weights = _moveaxis(weights, -1, axis)

        weights_sum = sum(weights, axis=axis, dtype=out_dtype)
        avg = sum(a * weights, axis=axis, dtype=out_dtype) / weights_sum

    if returned:
        if avg.shape != weights_sum.shape:
            weights_sum = _broadcast_to(weights_sum, avg.shape)
        return avg, weights_sum
    return avg
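Usage through the public wrapper, with the returned weight sum:

import jax.numpy as jnp

a = jnp.array([1., 2., 3.])
w = jnp.array([3., 1., 0.])
avg, wsum = jnp.average(a, weights=w, returned=True)
print(avg)   # 1.25 -- (1*3 + 2*1 + 3*0) / 4
print(wsum)  # 4.0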
Example #14
def _get_random_data(x: jnp.ndarray) -> Any:
    dtype = dtypes.canonicalize_dtype(x.dtype)
    if np.issubdtype(dtype, np.integer):
        return np.random.randint(0, 100, size=x.shape, dtype=dtype)
    elif np.issubdtype(dtype, np.floating):
        return np.array(np.random.uniform(size=x.shape), dtype=dtype)
    elif dtype == np.bool_:
        return np.random.choice(a=[False, True], size=x.shape)
    else:
        raise ValueError(
            f"Unsupported dtype for numerical comparison: {dtype}")
Example #15
def zeros(key, shape, dtype: DType = jnp.float_):
    """An initializer that returns a constant array full of zeros.

  The ``key`` argument is ignored.

  >>> import jax, jax.numpy as jnp
  >>> jax.nn.initializers.zeros(jax.random.PRNGKey(42), (2, 3), jnp.float32)
  DeviceArray([[0., 0., 0.],
               [0., 0., 0.]], dtype=float32)
  """
    return jnp.zeros(shape, dtypes.canonicalize_dtype(dtype))
Example #16
def _promote_arg_dtypes(*args):
    """Promotes `args` to a common inexact type."""
    dtype, weak_type = dtypes._lattice_result_type(*args)
    if not jnp.issubdtype(dtype, jnp.inexact):
        dtype, weak_type = jnp.float_, False
    dtype = dtypes.canonicalize_dtype(dtype)
    args = [lax._convert_element_type(arg, dtype, weak_type) for arg in args]
    if len(args) == 1:
        return args[0]
    else:
        return args
Example #17
def _reduction_init_val(a, init_val):
    # This function uses np.* functions because lax pattern matches against the
    # specific concrete values of the reduction inputs.
    a_dtype = dtypes.canonicalize_dtype(dtypes.dtype(a))
    if a_dtype == 'bool':
        return np.array(init_val > 0, dtype=a_dtype)
    try:
        return np.array(init_val, dtype=a_dtype)
    except OverflowError:
        assert dtypes.issubdtype(a_dtype, np.integer)
        sign, info = np.sign(init_val), dtypes.iinfo(a_dtype)
        return np.array(info.min if sign < 0 else info.max, dtype=a_dtype)
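The OverflowError branch handles init values like -inf that have no integer representation; a standalone numpy illustration of the fallback:

import numpy as np

try:
    np.array(-np.inf, dtype=np.int32)  # what the try block attempts for a max-reduction
except OverflowError:
    print(np.iinfo(np.int32).min)      # -2147483648, the clamped init value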
Example #18
 def testResultTypeWeakFlag(self):
     float_ = dtypes.canonicalize_dtype(dtypes.float_)
     x_weak = jnp.array(1.)
     x_strong = x_weak.astype(float_)
     self.assertEqual(dtypes.result_type(x_weak), float_)
     self.assertEqual(
         dtypes.result_type(x_weak, return_weak_type_flag=True),
         (float_, True))
     self.assertEqual(dtypes.result_type(x_strong), float_)
     self.assertEqual(
         dtypes.result_type(x_strong, return_weak_type_flag=True),
         (float_, False))
Example #19
def _promote_dtypes_inexact(*args):
    """Convenience function to apply Numpy argument dtype promotion.

  Promotes arguments to an inexact type."""
    to_dtype, weak_type = dtypes._lattice_result_type(*args)
    to_dtype = dtypes.canonicalize_dtype(to_dtype)
    to_dtype_inexact = _to_inexact_dtype(to_dtype)
    weak_type = (weak_type and to_dtype == to_dtype_inexact)
    return [
        lax_internal._convert_element_type(x, to_dtype_inexact, weak_type)
        for x in args
    ]
Example #20
def _promote_arg_dtypes(*args):
    """Promotes `args` to a common inexact type."""
    def _to_inexact_type(type):
        return type if jnp.issubdtype(type, jnp.inexact) else jnp.float_

    inexact_types = [_to_inexact_type(jnp._dtype(arg)) for arg in args]
    dtype = dtypes.canonicalize_dtype(jnp.result_type(*inexact_types))
    args = [lax.convert_element_type(arg, dtype) for arg in args]
    if len(args) == 1:
        return args[0]
    else:
        return args
Example #21
def one_hot(x: Array,
            num_classes: int,
            *,
            dtype: Any = jnp.float64,
            axis: Union[int, AxisName] = -1) -> Array:
    """One-hot encodes the given indicies.

  Each index in the input ``x`` is encoded as a vector of zeros of length
  ``num_classes`` with the element at ``index`` set to one::

    >>> jax.nn.one_hot(jnp.array([0, 1, 2]), 3)
    DeviceArray([[1., 0., 0.],
                 [0., 1., 0.],
                 [0., 0., 1.]], dtype=float32)

  Indices outside the range [0, num_classes) will be encoded as zeros::

    >>> jax.nn.one_hot(jnp.array([-1, 3]), 3)
    DeviceArray([[0., 0., 0.],
                 [0., 0., 0.]], dtype=float32)

  Args:
    x: A tensor of indices.
    num_classes: Number of classes in the one-hot dimension.
    dtype: optional, a float dtype for the returned values (default float64 if
      jax_enable_x64 is true, otherwise float32).
    axis: the axis or axes along which the function should be
      computed.
  """
    num_classes = core.concrete_or_error(
        int, num_classes,
        "The error arose in jax.nn.one_hot argument `num_classes`.")
    dtype = dtypes.canonicalize_dtype(dtype)
    x = jnp.asarray(x)
    try:
        output_pos_axis = util.canonicalize_axis(axis, x.ndim + 1)
    except TypeError:
        axis_size = lax.psum(1, axis)
        if num_classes != axis_size:
            raise ValueError(
                f"Expected num_classes to match the size of axis {axis}, "
                f"but {num_classes} != {axis_size}") from None
        axis_idx = lax.axis_index(axis)
        return jnp.asarray(x == axis_idx, dtype=dtype)
    axis = operator.index(axis)
    lhs = lax.expand_dims(x, (axis, ))
    rhs_shape = [1] * x.ndim
    rhs_shape.insert(output_pos_axis, num_classes)
    rhs = lax.broadcast_in_dim(jnp.arange(num_classes, dtype=x.dtype),
                               rhs_shape, (output_pos_axis, ))
    return jnp.asarray(lhs == rhs, dtype=dtype)
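The axis argument controls where the one-hot dimension is inserted; for example:

import jax.numpy as jnp
from jax.nn import one_hot

print(one_hot(jnp.array([0, 2]), 3).shape)          # (2, 3): classes on the last axis
print(one_hot(jnp.array([0, 2]), 3, axis=0).shape)  # (3, 2): classes on the first axis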
Example #22
def _get_random_data(dtype: jnp.dtype, shape: Tuple[int, ...], seed=0) -> Any:
  dtype = dtypes.canonicalize_dtype(dtype)
  np.random.seed(seed)
  # Adjust the max values of the numbers based on the seed, so different seeds
  # result in different ranges.
  max_value = max(1, 100*seed)
  if np.issubdtype(dtype, np.integer):
    return np.random.randint(0, max_value, size=shape, dtype=dtype)
  elif np.issubdtype(dtype, np.floating):
    return np.array(np.random.uniform(size=shape), dtype=dtype) * max_value
  elif dtype == np.bool_:
    return np.random.choice(a=[False, True], size=shape)
  else:
    raise ValueError(f"Unsupported dtype for numerical comparison: {dtype}")
Example #23
def _numpy_array_constant(x: np.ndarray,
                          canonicalize_types) -> Sequence[ir.Value]:
    if canonicalize_types:
        x = np.asarray(x, dtypes.canonicalize_dtype(x.dtype))
    aval = xla.abstractify(x)
    ir_type = aval_to_ir_type(aval)
    if x.dtype == np.bool_:
        x = np.packbits(x, bitorder='little')
    elif x.dtype == dtypes.bfloat16:
        x = x.view(np.uint16)
    x = np.ascontiguousarray(x)
    attr = ir.DenseElementsAttr.get(x,
                                    type=ir_type.element_type,
                                    shape=aval.shape)
    return (mhlo.ConstOp(ir_type, attr).result, )
Example #24
def ldexp(x1, x2):
    _check_arraylike("ldexp", x1, x2)
    x1_dtype = dtypes.dtype(x1)
    x2_dtype = dtypes.dtype(x2)
    if (dtypes.issubdtype(x1_dtype, np.complexfloating)
            or dtypes.issubdtype(x2_dtype, np.inexact)):
        raise ValueError(
            f"ldexp not supported for input types {(x1_dtype, x2_dtype)}")

    x1, x2 = _promote_shapes("ldexp", x1, x2)

    dtype = dtypes.canonicalize_dtype(dtypes._to_inexact_dtype(x1_dtype))
    info = dtypes.finfo(dtype)
    int_type = _INT_DTYPES[info.bits]

    x1 = lax.convert_element_type(x1, dtype)
    x2 = lax.convert_element_type(x2, int_type)

    mask = (1 << info.nexp) - 1
    bias = ((1 << info.nexp) - 1) >> 1
    x, e = _normalize_float(x1)
    x2 += e + ((x >> info.nmant) & mask) - bias

    # find underflow/overflow before denormalization
    underflow_cond = x2 < -(bias + info.nmant)
    overflow_cond = x2 > bias

    m = lax.full_like(x, 1, dtype=dtype)

    # denormals
    cond = x2 < -bias + 1
    x2 = _where(cond, x2 + info.nmant, x2)
    m = _where(cond, m / (1 << info.nmant), m)

    x2 = lax.convert_element_type(x2, np.int32)
    x &= ~(mask << info.nmant)
    x |= ((lax.convert_element_type(x2, int_type) + bias) << info.nmant)

    x = lax.convert_element_type(m, dtype) * lax.bitcast_convert_type(x, dtype)

    # underflow
    x = _where(underflow_cond, lax.full_like(x, 0, dtype=dtype), x)
    # overflow
    x = _where(overflow_cond, lax.sign(x1) * lax.full_like(x, np.inf), x)
    # ldexp(x1, x2) = x1 for x1 = inf, -inf, nan, 0
    return _where(isinf(x1) | isnan(x1) | (x1 == 0), x1, x)
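Numerically, ldexp(x1, x2) computes x1 * 2**x2 by manipulating the exponent bits directly; a quick check:

import jax.numpy as jnp

print(jnp.ldexp(1.5, 3))      # 12.0, i.e. 1.5 * 2**3
print(jnp.ldexp(jnp.inf, 1))  # inf -- inf, nan and 0 inputs pass through unchanged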
Example #25
 def init(key, shape, dtype=dtype):
     dtype = dtypes.canonicalize_dtype(dtype)
     if len(shape) < 2:
         raise ValueError(
             "orthogonal initializer requires at least a 2D shape")
     n_rows, n_cols = prod(shape) // shape[column_axis], shape[column_axis]
     matrix_shape = (n_cols, n_rows) if n_rows < n_cols else (n_rows, n_cols)
     A = random.normal(key, matrix_shape, dtype)
     Q, R = jnp.linalg.qr(A)
     diag_sign = lax.broadcast_to_rank(jnp.sign(jnp.diag(R)), rank=Q.ndim)
     Q *= diag_sign  # needed for a uniform distribution
     if n_rows < n_cols: Q = Q.T
     Q = jnp.reshape(
         Q,
         tuple(np.delete(shape, column_axis)) + (shape[column_axis], ))
     Q = jnp.moveaxis(Q, -1, column_axis)
     return scale * Q
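This closure appears to be the body returned by jax.nn.initializers.orthogonal; a usage check that the result is orthonormal:

import jax
import jax.numpy as jnp

init = jax.nn.initializers.orthogonal()
W = init(jax.random.PRNGKey(0), (4, 4), jnp.float32)
print(jnp.allclose(W.T @ W, jnp.eye(4), atol=1e-5))  # True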
Example #26
def _numpy_array_constant(x: np.ndarray,
                          canonicalize_types) -> Sequence[ir.Value]:
    if canonicalize_types:
        x = np.asarray(x, dtypes.canonicalize_dtype(x.dtype))
    ir_type = ir.RankedTensorType.get(x.shape, dtype_to_ir_type(x.dtype))
    shape = x.shape
    if x.dtype == np.bool_:
        nelems = x.size
        x = np.packbits(x, bitorder='little')
        # TODO(b/209005197): Work around for MLIR crash for non-splat single element
        # buffers.
        if nelems == 1:
            x = np.array(0 if x.item() == 0 else 0xff, np.uint8)
    elif x.dtype == dtypes.bfloat16:
        x = x.view(np.uint16)
    x = np.ascontiguousarray(x)
    attr = ir.DenseElementsAttr.get(x, type=ir_type.element_type, shape=shape)
    return (mhlo.ConstOp(ir_type, attr).result, )
Example #27
def _var_promote_types(a_dtype, dtype):
    if dtype:
        if (not dtypes.issubdtype(dtype, np.complexfloating)
                and dtypes.issubdtype(a_dtype, np.complexfloating)):
            msg = (
                "jax.numpy.var does not yet support real dtype parameters when "
                "computing the variance of an array of complex values. The "
                "semantics of numpy.var seem unclear in this case. Please comment "
                "on https://github.com/google/jax/issues/2283 if this behavior is "
                "important to you.")
            raise ValueError(msg)
        a_dtype = dtypes.promote_types(a_dtype, dtype)
    else:
        if not dtypes.issubdtype(a_dtype, np.inexact):
            dtype = a_dtype = dtypes.canonicalize_dtype(dtypes.float_)
        else:
            dtype = _complex_elem_type(a_dtype)
            a_dtype = dtypes.promote_types(a_dtype, np.float32)
    return a_dtype, dtype
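The complex branch means the variance of complex data is accumulated in a complex dtype but reported with a real one:

import jax.numpy as jnp

z = jnp.array([1 + 1j, 2 - 1j, 3 + 0j])
print(jnp.var(z).dtype)  # float32 under the default x32 config: real, not complex64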
Example #28
def _gmres_batched(A, b, x0, unit_residual, residual_norm, ptol, restart, M):
  """
  Implements a single restart of GMRES. The ``restart``-dimensional Krylov
  subspace
  K(A, x0) = span(A(x0), A@x0, A@A@x0, ..., A^restart @ x0) is built, and the
  projection of the true solution into this subspace is returned.

  This implementation solves a dense linear problem instead of building
  a QR factorization during the Arnoldi process.
  """
  del ptol  # unused
  # https://www-users.cs.umn.edu/~saad/Calais/PREC.pdf
  V = tree_map(
      lambda x: jnp.pad(x[..., None], ((0, 0),) * x.ndim + ((0, restart),)),
      unit_residual,
  )
  dtype, weak_type = dtypes._lattice_result_type(*tree_leaves(b))
  dtype = dtypes.canonicalize_dtype(dtype)
  H = lax_internal._convert_element_type(
      jnp.eye(restart, restart + 1, dtype=dtype), weak_type=weak_type)

  def loop_cond(carry):
    _, _, breakdown, k = carry
    return jnp.logical_and(k < restart, jnp.logical_not(breakdown))

  def arnoldi_process(carry):
    V, H, _, k = carry
    V, H, breakdown = _kth_arnoldi_iteration(k, A, M, V, H)
    return V, H, breakdown, k + 1

  carry = (V, H, False, 0)
  V, H, _, _ = lax.while_loop(loop_cond, arnoldi_process, carry)

  beta_vec = jnp.zeros_like(H, shape=(restart + 1,)).at[0].set(residual_norm.astype(dtype))
  y = _lstsq(H.T, beta_vec)
  dx = tree_map(lambda X: _dot(X[..., :-1], y), V)

  x = _add(x0, dx)

  residual = M(_sub(b, A(x)))
  unit_residual, residual_norm = _safe_normalize(residual)
  return x, unit_residual, residual_norm
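This helper backs the batched solve mode of the public solver; a hedged usage sketch:

import jax.numpy as jnp
from jax.scipy.sparse.linalg import gmres

A = jnp.array([[4., 1.], [1., 3.]])
b = jnp.array([1., 2.])
x, info = gmres(A, b, solve_method='batched')
print(jnp.allclose(A @ x, b, atol=1e-4))  # True for this well-conditioned system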
Example #29
def _safe_normalize(x, thresh=None):
  """
  Returns the L2-normalized vector (which can be a pytree) x, and optionally
  the computed norm. If the computed norm is less than the threshold `thresh`,
  which by default is the machine precision of x's dtype, it will be
  taken to be 0, and the normalized x to be the zero vector.
  """
  norm = _norm(x)
  dtype, weak_type = dtypes._lattice_result_type(*tree_leaves(x))
  dtype = dtypes.canonicalize_dtype(dtype)
  if thresh is None:
    thresh = jnp.finfo(norm.dtype).eps
  thresh = thresh.astype(dtype).real

  use_norm = norm > thresh

  norm_cast = lax_internal._convert_element_type(norm, dtype, weak_type)
  normalized_x = tree_map(lambda y: jnp.where(use_norm, y / norm_cast, 0.0), x)
  norm = jnp.where(use_norm, norm, 0.0)
  return normalized_x, norm
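A self-contained sketch of the same guard against dividing by a near-zero norm, for a single flat array rather than a pytree:

import jax.numpy as jnp

def safe_normalize_sketch(x, thresh=None):
    norm = jnp.linalg.norm(x)
    if thresh is None:
        thresh = jnp.finfo(x.dtype).eps  # default threshold: machine epsilon
    use_norm = norm > thresh
    safe = jnp.where(use_norm, norm, 1.0)  # avoid dividing by zero
    return jnp.where(use_norm, x / safe, 0.0), jnp.where(use_norm, norm, 0.0)

v, n = safe_normalize_sketch(jnp.array([3., 4.]))
print(v, n)  # [0.6 0.8] 5.0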
Example #30
 def testBinaryPromotion(self, swap, jit):
     testcases = [
         (jnp.array(1.), 0., jnp.float64),
         (jnp.array(1.), jnp.array(0.), jnp.float64),
         (jnp.array(1.), jnp.array(0., dtype=jnp.float16), jnp.float16),
         (jnp.array(1.), jnp.array(0., dtype=jnp.float32), jnp.float32),
         (jnp.array(1.), jnp.array(0., dtype=jnp.float64), jnp.float64),
         (jnp.array(1., dtype=jnp.float16), 0., jnp.float16),
         (jnp.array(1., dtype=jnp.float32), 0., jnp.float32),
         (jnp.array(1., dtype=jnp.float64), 0., jnp.float64),
         (jnp.array(1., dtype=jnp.float16),
          jnp.array(0., dtype=jnp.float16), jnp.float16),
         (jnp.array(1., dtype=jnp.float16),
          jnp.array(0., dtype=jnp.float32), jnp.float32),
         (jnp.array(1., dtype=jnp.float16),
          jnp.array(0., dtype=jnp.float64), jnp.float64),
         (jnp.array(1., dtype=jnp.float32),
          jnp.array(0., dtype=jnp.float32), jnp.float32),
         (jnp.array(1., dtype=jnp.float32),
          jnp.array(0., dtype=jnp.float64), jnp.float64),
         (jnp.array(1., dtype=jnp.float64),
          jnp.array(0., dtype=jnp.float64), jnp.float64),
         (jnp.array([1.]), 0., jnp.float_),
         (jnp.array([1.]), jnp.array(0.), jnp.float_),
         (jnp.array([1.]), jnp.array(0., dtype=jnp.float16), jnp.float_),
         (jnp.array([1.]), jnp.array(0., dtype=jnp.float32), jnp.float_),
         (jnp.array([1.]), jnp.array(0., dtype=jnp.float64), jnp.float64),
         (jnp.array([1.], dtype=jnp.float32),
          jnp.array(0., dtype=jnp.float16), jnp.float32),
         (jnp.array([1.], dtype=jnp.float16),
          jnp.array(0., dtype=jnp.float32), jnp.float32),
         (jnp.array([1.], dtype=jnp.float16), 0., jnp.float16),
     ]
     op = jax.jit(operator.add) if jit else operator.add
     for x, y, dtype in testcases:
         x, y = (y, x) if swap else (x, y)
         z = op(x, y)
         self.assertTrue(isinstance(z, jnp.ndarray), msg=(x, y, z))
         self.assertEqual(z.dtype,
                          dtypes.canonicalize_dtype(dtype),
                          msg=(x, y, z))