コード例 #1
0
ファイル: test_util.py プロジェクト: Jakob-Unfried/jax
def tolerance(dtype, tol=None):
    """Return the comparison tolerance to use for values of ``dtype``.

    Args:
      dtype: anything accepted by ``np.dtype``; it is canonicalized with
        ``_dtypes.canonicalize_dtype`` before lookup.
      tol: ``None``, a non-dict tolerance (returned unchanged), or a dict
        whose keys are dtype-likes and whose values are tolerances.

    Returns:
      The caller-supplied tolerance for the canonical dtype when ``tol`` is a
      dict containing it, otherwise ``default_tolerance()[dtype]``.
    """
    if tol is None:
        tol = {}
    if not isinstance(tol, dict):
        return tol
    tol = {np.dtype(key): value for key, value in tol.items()}
    dtype = _dtypes.canonicalize_dtype(np.dtype(dtype))
    # Consult the default table only when the caller gave no value:
    # ``dict.get(k, default)`` evaluates ``default`` eagerly, which raised
    # KeyError for dtypes missing from ``default_tolerance()`` even when an
    # explicit tolerance had been supplied for them.
    if dtype in tol:
        return tol[dtype]
    return default_tolerance()[dtype]
コード例 #2
0
ファイル: dtypes_test.py プロジェクト: cloudhan/jax
 def testDefaultTypes(self, type_):
     # A value of Python scalar type ``type_`` must come back as a
     # jnp.ndarray with the canonical dtype registered for that Python type,
     # whether it is built eagerly or passed through jit.
     want = dtypes.canonicalize_dtype(dtypes.python_scalar_dtypes[type_])
     makers = (jnp.array, jax.jit(jnp.array), jax.jit(lambda x: x))
     for make in makers:
         out = make(type_(0))
         self.assertTrue(isinstance(out, jnp.ndarray), msg=(make, out))
         self.assertEqual(out.dtype, want, msg=(make, out))
コード例 #3
0
def normalize_to_xla_dtypes(val):
  """Return ``val`` with every array-like leaf cast to its canonical dtype.

  Objects exposing ``__array__`` and NumPy/Python scalars become
  ``np.ndarray`` values whose dtype is the canonicalized result type of the
  input. Tuples and lists are handled recursively and always come back as a
  tuple. Anything else raises ``TypeError``.
  """
  looks_like_array = hasattr(val, '__array__') or np.isscalar(val)
  if looks_like_array:
    target = dtypes.canonicalize_dtype(dtypes.result_type(val))
    return np.asarray(val, dtype=target)
  if isinstance(val, (tuple, list)):
    return tuple(map(normalize_to_xla_dtypes, val))
  raise TypeError('Can\'t convert to XLA: {}'.format(val))
コード例 #4
0
    def _cumulative_reduction(a,
                              axis: Optional[Union[int, Tuple[int,
                                                              ...]]] = None,
                              dtype=None,
                              out=None):
        """Run the enclosing cumulative ``reduction`` over ``a`` along ``axis``.

        ``np_reduction``, ``reduction``, ``fill_nan`` and ``fill_value`` come
        from the enclosing scope (not visible here). With ``axis=None`` or a
        scalar input, ``a`` is flattened and reduced along axis 0. Boolean
        inputs are converted to the canonical ``dtypes.int_`` unless a
        ``dtype`` is requested.

        Raises:
          NotImplementedError: if ``out`` is given.
        """
        name = np_reduction.__name__
        _check_arraylike(name, a)
        if out is not None:
            raise NotImplementedError(
                f"The 'out' argument to jnp.{name} "
                f"is not supported.")
        lax_internal._check_user_dtype_supported(dtype, name)

        flatten = axis is None or _isscalar(a)
        if flatten:
            a = lax.reshape(a, (np.size(a), ))
            axis = 0
        axis = _canonicalize_axis(axis, len(np.shape(a)))

        if fill_nan:
            # Replace NaNs with the reduction's fill value before reducing.
            a = _where(lax_internal._isnan(a), _lax_const(a, fill_value), a)

        target = dtype
        if not target and dtypes.dtype(a) == np.bool_:
            target = dtypes.canonicalize_dtype(dtypes.int_)
        if target:
            a = lax.convert_element_type(a, target)

        return reduction(a, axis)
コード例 #5
0
def _mean(a,
          axis: Optional[Union[int, Tuple[int, ...]]] = None,
          dtype=None,
          out=None,
          keepdims=False,
          *,
          where=None):
    """Arithmetic mean of ``a`` over ``axis`` (implementation of jnp.mean).

    With ``where`` the divisor is the number of selected entries (the sum of
    the broadcast mask); otherwise it is the size of the reduced axes. When
    ``dtype`` is omitted the result uses the inexact form of ``a``'s dtype.

    Raises:
      NotImplementedError: if ``out`` is given.
    """
    _check_arraylike("mean", a)
    lax_internal._check_user_dtype_supported(dtype, "mean")
    if out is not None:
        raise NotImplementedError(
            "The 'out' argument to jnp.mean is not supported.")

    # The divisor must be computed before ``dtype`` is resolved below: the
    # masked count is summed with the caller's original ``dtype``.
    if where is not None:
        count = sum(_broadcast_to(where, np.shape(a)), axis,
                    dtype=dtype, keepdims=keepdims)
    elif axis is None:
        count = core.dimension_as_value(np.size(a))
    else:
        count = core.dimension_as_value(_axis_size(a, axis))

    if dtype is None:
        result_dtype = dtypes._to_inexact_dtype(dtypes.dtype(a))
    else:
        result_dtype = dtype
    result_dtype = dtypes.canonicalize_dtype(result_dtype)

    total = sum(a, axis, dtype=result_dtype, keepdims=keepdims, where=where)
    return lax.div(total, lax.convert_element_type(count, result_dtype))
コード例 #6
0
def count_nonzero(a,
                  axis: Optional[Union[int, Tuple[int, ...]]] = None,
                  keepdims=False):
    """Count the entries of ``a`` that are not equal to zero along ``axis``.

    The boolean mask ``a != 0`` is summed using the canonical form of
    ``np.int_`` as the accumulation dtype.
    """
    _check_arraylike("count_nonzero", a)
    is_nonzero = lax.ne(a, _lax_const(a, 0))
    count_dtype = dtypes.canonicalize_dtype(np.int_)
    return sum(is_nonzero, axis=axis, dtype=count_dtype, keepdims=keepdims)
コード例 #7
0
ファイル: dtypes_test.py プロジェクト: xueeinstein/jax
 def test_canonicalize_type(self):
     # Choose the expectation table for the active x64 mode, then check that
     # every input dtype canonicalizes to its listed result.
     table = (_EXPECTED_CANONICALIZE_X64 if config.x64_enabled
              else _EXPECTED_CANONICALIZE_X32)
     for source, target in table.items():
         self.assertEqual(dtypes.canonicalize_dtype(source), target)
コード例 #8
0
ファイル: util.py プロジェクト: xueeinstein/jax
def _promote_dtypes(*args):
  """Promote ``args`` to one common dtype using NumPy-style promotion.

  With two or more arguments, each one is converted to the canonicalized
  lattice result type, using the weak-type flag reported by the lattice, and
  a list is returned. With fewer, ``args`` is returned untouched (a tuple).
  """
  # TODO(dougalm,mattjj): This is a performance bottleneck. Consider memoizing.
  if len(args) >= 2:
    target, weak = dtypes._lattice_result_type(*args)
    target = dtypes.canonicalize_dtype(target)
    return [lax_internal._convert_element_type(arg, target, weak)
            for arg in args]
  return args
コード例 #9
0
ファイル: custom_object_test.py プロジェクト: John1Tang/jax
 def __init__(self,
              shape,
              dtype,
              index_dtype,
              nnz,
              weak_type=False,
              named_shape=None):
     # Beyond the aval set up by the base class (shape + canonical dtype),
     # keep avals for the ``nnz`` stored values and their
     # ``(nnz, len(shape))`` indices.
     value_dtype = dtypes.canonicalize_dtype(dtype)
     super().__init__(shape, value_dtype)
     if named_shape is None:
         named_shape = {}
     self.index_dtype = index_dtype
     self.nnz = nnz
     self.data_aval = core.ShapedArray((nnz, ), value_dtype, weak_type,
                                       named_shape)
     self.indices_aval = core.ShapedArray(
         (nnz, len(shape)),
         dtypes.canonicalize_dtype(index_dtype),
         named_shape=named_shape)
コード例 #10
0
ファイル: util.py プロジェクト: xueeinstein/jax
def _promote_dtypes_complex(*args):
  """Promote ``args`` to a common complex dtype, NumPy-style.

  The lattice result type is canonicalized and mapped to its complex
  counterpart. Every argument is converted to it, keeping the lattice's
  weak-type flag.
  """
  lattice_dtype, weak = dtypes._lattice_result_type(*args)
  target = dtypes._to_complex_dtype(dtypes.canonicalize_dtype(lattice_dtype))
  converted = []
  for arg in args:
    converted.append(lax_internal._convert_element_type(arg, target, weak))
  return converted
コード例 #11
0
ファイル: test_util.py プロジェクト: MichaelMarien/jax
def to_default_dtype(arr):
    """Return ``arr`` as a NumPy array cast to JAX's default dtype for its kind.

    Used to convert values returned by NumPy functions so that their dtypes
    reflect the ``jax_enable_x64`` and ``jax_default_dtype_bits`` flags. If
    ``_dtypes._default_types`` has no entry for the array's dtype kind, the
    ``np.asarray`` result is returned without casting.
    """
    as_np = np.asarray(arr)
    default = _dtypes._default_types.get(as_np.dtype.kind)
    if not default:
        return as_np
    return as_np.astype(_dtypes.canonicalize_dtype(default))
コード例 #12
0
ファイル: dtypes_test.py プロジェクト: xueeinstein/jax
 def testUnaryPromotion(self, dtype, weak_type):
     # Regression test for https://github.com/google/jax/issues/6051
     x = lax_internal._convert_element_type(0, dtype, weak_type=weak_type)
     expected = x.dtype
     if weak_type:
         # A weakly-typed value is expected to resolve to the default dtype for
         # its kind; bfloat16 is looked up under the floating kind 'f'.
         kind = 'f' if x.dtype == 'bfloat16' else x.dtype.kind
         expected = dtypes.canonicalize_dtype(dtypes._default_types[kind])
     self.assertEqual(dtypes.result_type(x), expected)
コード例 #13
0
ファイル: reductions.py プロジェクト: cloudhan/jax
def _average(a,
             axis: Optional[Union[int, Tuple[int, ...]]] = None,
             weights=None,
             returned=False):
    """Weighted average of ``a`` along ``axis`` (implementation of jnp.average).

    Args:
      a: array-like input.
      axis: ``None`` to average over all elements, or an int. When
        ``weights`` is ``None`` a tuple of ints is also accepted.
      weights: optional weights. If their shape differs from ``a``'s they
        must be 1-D with length ``a.shape[axis]``, and ``axis`` must be given.
      returned: if true, also return the sum of weights broadcast to the
        shape of the average.

    Returns:
      ``avg``, or ``(avg, weights_sum)`` when ``returned`` is true.

    Raises:
      ValueError: if ``weights`` cannot be matched against ``a``.
    """
    a = _asarray(a)

    if weights is None:  # Treat all weights as 1
        avg = mean(a, axis=axis)
        if axis is None:
            weights_sum = lax.full((),
                                   core.dimension_as_value(np.size(a)),
                                   dtype=avg.dtype)
        elif isinstance(axis, tuple):
            # ``mean`` accepts a tuple of axes; the implicit weight sum is the
            # product of the reduced dimension sizes. Indexing ``a.shape`` with
            # the tuple itself used to raise TypeError.
            count = 1
            for ax in axis:
                count *= core.dimension_as_value(a.shape[ax])
            weights_sum = lax.full_like(avg, count, dtype=avg.dtype)
        else:
            weights_sum = lax.full_like(avg,
                                        core.dimension_as_value(a.shape[axis]),
                                        dtype=avg.dtype)
    else:
        weights = _asarray(weights)

        # Integer/bool inputs are averaged in a floating dtype.
        if dtypes.issubdtype(a.dtype, np.inexact):
            out_dtype = dtypes.result_type(a.dtype, weights.dtype)
        else:
            out_dtype = dtypes.result_type(a.dtype, weights.dtype,
                                           dtypes.float_)
        out_dtype = dtypes.canonicalize_dtype(out_dtype)

        a_shape = np.shape(a)
        a_ndim = len(a_shape)
        weights_shape = np.shape(weights)
        axis = None if axis is None else _canonicalize_axis(axis, a_ndim)

        if a_shape != weights_shape:
            # Make sure the dimensions work out
            if axis is None:
                raise ValueError("Axis must be specified when shapes of a and "
                                 "weights differ.")
            if len(weights_shape) != 1:
                raise ValueError("1D weights expected when shapes of a and "
                                 "weights differ.")
            if not core.symbolic_equal_dim(weights_shape[0], a_shape[axis]):
                raise ValueError("Length of weights not "
                                 "compatible with specified axis.")

            # Reshape the 1-D weights so they broadcast along ``axis``.
            weights = _broadcast_to(weights,
                                    (a_ndim - 1) * (1, ) + weights_shape)
            weights = _moveaxis(weights, -1, axis)

        weights_sum = sum(weights, axis=axis, dtype=out_dtype)
        avg = sum(a * weights, axis=axis, dtype=out_dtype) / weights_sum

    if returned:
        if avg.shape != weights_sum.shape:
            weights_sum = _broadcast_to(weights_sum, avg.shape)
        return avg, weights_sum
    return avg
コード例 #14
0
def _get_random_data(x: jnp.ndarray) -> Any:
    """Return NumPy random data with the shape and canonical dtype of ``x``.

    Integers are drawn from ``[0, 100)``, floats from ``np.random.uniform``
    (``[0, 1)``), and booleans uniformly from ``{False, True}``.

    Raises:
      ValueError: for any other dtype (e.g. complex).
    """
    dtype = dtypes.canonicalize_dtype(x.dtype)
    if np.issubdtype(dtype, np.integer):
        return np.random.randint(0, 100, size=x.shape, dtype=dtype)
    elif np.issubdtype(dtype, np.floating):
        return np.array(np.random.uniform(size=x.shape), dtype=dtype)
    elif np.issubdtype(dtype, np.bool_):
        # The previous ``np.bool`` alias does not exist in NumPy 1.24-1.26
        # (AttributeError); ``np.bool_`` works on every NumPy version.
        return np.random.choice(a=[False, True], size=x.shape)
    else:
        raise ValueError(
            f"Unsupported dtype for numerical comparison: {dtype}")
コード例 #15
0
ファイル: initializers.py プロジェクト: xueeinstein/jax
def zeros(key, shape, dtype: DType = jnp.float_):
    """Initializer returning an all-zeros array of ``shape``.

    ``key`` is accepted for initializer-API compatibility but not used.
    ``dtype`` is passed through ``dtypes.canonicalize_dtype`` first.

    >>> import jax, jax.numpy as jnp
    >>> jax.nn.initializers.zeros(jax.random.PRNGKey(42), (2, 3), jnp.float32)
    DeviceArray([[0., 0., 0.],
                 [0., 0., 0.]], dtype=float32)
    """
    del key  # unused
    canonical = dtypes.canonicalize_dtype(dtype)
    return jnp.zeros(shape, canonical)
コード例 #16
0
ファイル: linalg.py プロジェクト: ahoenselaar/jax
def _promote_arg_dtypes(*args):
    """Promote ``args`` to one shared inexact dtype.

    If the lattice result type is not inexact, a strongly-typed
    ``jnp.float_`` is used instead. The chosen dtype is canonicalized before
    conversion. A single argument is returned bare; otherwise a list.
    """
    common, weak = dtypes._lattice_result_type(*args)
    if not jnp.issubdtype(common, jnp.inexact):
        common, weak = jnp.float_, False
    common = dtypes.canonicalize_dtype(common)
    promoted = [lax._convert_element_type(arg, common, weak) for arg in args]
    return promoted[0] if len(promoted) == 1 else promoted
コード例 #17
0
def _reduction_init_val(a, init_val):
    """Return ``init_val`` as a NumPy 0-d array of ``a``'s canonical dtype.

    For boolean ``a`` the value is ``init_val > 0``. If ``init_val`` cannot be
    represented (``OverflowError``), it is clamped to the integer dtype's min
    or max according to its sign.
    """
    # np.* is used deliberately: lax pattern matches against the specific
    # concrete values of the reduction inputs.
    target = dtypes.canonicalize_dtype(dtypes.dtype(a))
    if target == 'bool':
        return np.array(init_val > 0, dtype=target)
    try:
        return np.array(init_val, dtype=target)
    except OverflowError:
        assert dtypes.issubdtype(target, np.integer)
        bounds = dtypes.iinfo(target)
        clamped = bounds.min if np.sign(init_val) < 0 else bounds.max
        return np.array(clamped, dtype=target)
コード例 #18
0
ファイル: dtypes_test.py プロジェクト: xueeinstein/jax
 def testResultTypeWeakFlag(self):
     default_float = dtypes.canonicalize_dtype(dtypes.float_)
     weak = jnp.array(1.)
     strong = weak.astype(default_float)
     # Both report the default float dtype; only the weak-type flag differs.
     for value, is_weak in ((weak, True), (strong, False)):
         self.assertEqual(dtypes.result_type(value), default_float)
         self.assertEqual(
             dtypes.result_type(value, return_weak_type_flag=True),
             (default_float, is_weak))
コード例 #19
0
def _promote_dtypes_inexact(*args):
    """Promote ``args`` to a common inexact dtype, NumPy-style.

    The lattice result type is canonicalized and mapped to an inexact dtype
    with ``_to_inexact_dtype``. The lattice weak-type flag is kept only when
    that mapping leaves the canonical dtype unchanged.
    """
    lattice_dtype, weak = dtypes._lattice_result_type(*args)
    canonical = dtypes.canonicalize_dtype(lattice_dtype)
    target = _to_inexact_dtype(canonical)
    keep_weak = weak and canonical == target
    return [lax_internal._convert_element_type(arg, target, keep_weak)
            for arg in args]
コード例 #20
0
ファイル: linalg.py プロジェクト: nbswords/jax
def _promote_arg_dtypes(*args):
    """Promote ``args`` to a common inexact type.

    Each argument's dtype is replaced by ``jnp.float_`` unless it is already
    inexact. The result type of those dtypes is canonicalized, and every
    argument is converted to it. A single argument is returned bare;
    otherwise a list is returned.
    """
    def _to_inexact_type(dt):
        # Parameter renamed from ``type`` so it no longer shadows the builtin.
        return dt if jnp.issubdtype(dt, jnp.inexact) else jnp.float_

    inexact_types = [_to_inexact_type(jnp._dtype(arg)) for arg in args]
    dtype = dtypes.canonicalize_dtype(jnp.result_type(*inexact_types))
    args = [lax.convert_element_type(arg, dtype) for arg in args]
    if len(args) == 1:
        return args[0]
    return args
コード例 #21
0
def one_hot(x: Array,
            num_classes: int,
            *,
            dtype: Any = jnp.float64,
            axis: Union[int, AxisName] = -1) -> Array:
    """One-hot encodes the given indices.

    Each index in the input ``x`` is encoded as a vector of zeros of length
    ``num_classes`` with the element at ``index`` set to one::

      >>> jax.nn.one_hot(jnp.array([0, 1, 2]), 3)
      DeviceArray([[1., 0., 0.],
                   [0., 1., 0.],
                   [0., 0., 1.]], dtype=float32)

    Indices outside the range [0, num_classes) will be encoded as zeros::

      >>> jax.nn.one_hot(jnp.array([-1, 3]), 3)
      DeviceArray([[0., 0., 0.],
                   [0., 0., 0.]], dtype=float32)

    Args:
      x: A tensor of indices.
      num_classes: Number of classes in the one-hot dimension. Must be a
        concrete (non-traced) integer.
      dtype: optional, a float dtype for the returned values (default float64
        if jax_enable_x64 is true, otherwise float32).
      axis: the axis or axes along which the function should be
        computed. If ``axis`` is not a valid positional axis it is treated as
        a named axis, whose size must equal ``num_classes``.

    Raises:
      ValueError: if ``axis`` is a named axis whose size differs from
        ``num_classes``.
    """
    num_classes = core.concrete_or_error(
        int, num_classes,
        "The error arose in jax.nn.one_hot argument `num_classes`.")
    dtype = dtypes.canonicalize_dtype(dtype)
    x = jnp.asarray(x)
    try:
        output_pos_axis = util.canonicalize_axis(axis, x.ndim + 1)
    except TypeError:
        # Non-integer axis: encode against this device's index along the
        # named axis instead of materializing a class dimension.
        axis_size = lax.psum(1, axis)
        if num_classes != axis_size:
            raise ValueError(
                f"Expected num_classes to match the size of axis {axis}, "
                f"but {num_classes} != {axis_size}") from None
        axis_idx = lax.axis_index(axis)
        return jnp.asarray(x == axis_idx, dtype=dtype)
    axis = operator.index(axis)
    # Compare the indices (with a new size-1 axis) against a broadcastable
    # arange over the class dimension; out-of-range indices match nothing.
    lhs = lax.expand_dims(x, (axis, ))
    rhs_shape = [1] * x.ndim
    rhs_shape.insert(output_pos_axis, num_classes)
    rhs = lax.broadcast_in_dim(jnp.arange(num_classes, dtype=x.dtype),
                               rhs_shape, (output_pos_axis, ))
    return jnp.asarray(lhs == rhs, dtype=dtype)
コード例 #22
0
def _get_random_data(dtype: jnp.dtype, shape: Tuple[int, ...], seed=0) -> Any:
  """Return seeded NumPy random data of ``shape`` and canonical ``dtype``.

  Seeds NumPy's global RNG with ``seed`` and sets
  ``max_value = max(1, 100 * seed)``. Integers come from
  ``[0, max_value)``, floats are uniform ``[0, 1)`` samples scaled by
  ``max_value``, and booleans are drawn uniformly from ``{False, True}``.

  Raises:
    ValueError: for any other dtype (e.g. complex).
  """
  dtype = dtypes.canonicalize_dtype(dtype)
  np.random.seed(seed)
  # Adjust the max values of the numbers based on the seed, so different seeds
  # result in different ranges.
  max_value = max(1, 100*seed)
  if np.issubdtype(dtype, np.integer):
    return np.random.randint(0, max_value, size=shape, dtype=dtype)
  elif np.issubdtype(dtype, np.floating):
    return np.array(np.random.uniform(size=shape), dtype=dtype) * max_value
  elif np.issubdtype(dtype, np.bool_):
    # The previous ``np.bool`` alias does not exist in NumPy 1.24-1.26
    # (AttributeError); ``np.bool_`` works on every NumPy version.
    return np.random.choice(a=[False, True], size=shape)
  else:
    raise ValueError(f"Unsupported dtype for numerical comparison: {dtype}")
コード例 #23
0
ファイル: mlir.py プロジェクト: rsepassi/jax
def _numpy_array_constant(x: np.ndarray,
                          canonicalize_types) -> Sequence[ir.Value]:
    """Emit ``x`` as an MHLO constant and return its result in a 1-tuple.

    Booleans are bit-packed (little-endian bit order) and bfloat16 values are
    viewed as uint16 before building the dense elements attribute. The
    attribute's element type and shape come from ``x``'s abstract value.
    """
    if canonicalize_types:
        x = np.asarray(x, dtypes.canonicalize_dtype(x.dtype))
    aval = xla.abstractify(x)
    ir_type = aval_to_ir_type(aval)
    if x.dtype == np.bool_:
        payload = np.packbits(x, bitorder='little')
    elif x.dtype == dtypes.bfloat16:
        payload = x.view(np.uint16)
    else:
        payload = x
    attr = ir.DenseElementsAttr.get(np.ascontiguousarray(payload),
                                    type=ir_type.element_type,
                                    shape=aval.shape)
    return (mhlo.ConstOp(ir_type, attr).result, )
コード例 #24
0
def ldexp(x1, x2):
    """Compute ``ldexp(x1, x2)`` by adding ``x2`` to the exponent bits of ``x1``.

    ``x1`` is cast to the canonical inexact version of its dtype. ``x2`` is
    cast to the integer dtype of the same bit width (``_INT_DTYPES[info.bits]``).
    Results with exponents below the smallest normal exponent are built as
    denormals. Exponents below ``-(bias + nmant)`` give 0, and exponents above
    ``bias`` give ``sign(x1) * inf``. ``inf``, ``nan`` and zero inputs are
    returned unchanged.

    Raises:
      ValueError: if ``x1`` is complex or ``x2`` has an inexact dtype.
    """
    _check_arraylike("ldexp", x1, x2)
    x1_dtype = dtypes.dtype(x1)
    x2_dtype = dtypes.dtype(x2)
    if (dtypes.issubdtype(x1_dtype, np.complexfloating)
            or dtypes.issubdtype(x2_dtype, np.inexact)):
        raise ValueError(
            f"ldexp not supported for input types {(x1_dtype, x2_dtype)}")

    x1, x2 = _promote_shapes("ldexp", x1, x2)

    dtype = dtypes.canonicalize_dtype(dtypes._to_inexact_dtype(x1_dtype))
    info = dtypes.finfo(dtype)
    # Integer dtype with the same width as ``dtype``, used for bit twiddling.
    int_type = _INT_DTYPES[info.bits]

    x1 = lax.convert_element_type(x1, dtype)
    x2 = lax.convert_element_type(x2, int_type)

    # All-ones mask for the exponent field, and the exponent bias
    # (2**(nexp - 1) - 1).
    mask = (1 << info.nexp) - 1
    bias = ((1 << info.nexp) - 1) >> 1
    # ``x`` is used as the integer bit pattern of x1 (it is shifted/masked
    # below and bitcast back to ``dtype``). NOTE(review): ``e`` appears to be
    # an exponent correction from ``_normalize_float`` (presumably for
    # denormal inputs) -- confirm against its definition.
    x, e = _normalize_float(x1)
    # x2 becomes the unbiased exponent of the result: the requested shift
    # plus x1's own (unbiased) exponent field.
    x2 += e + ((x >> info.nmant) & mask) - bias

    # find underflow/overflow before denormalization
    underflow_cond = x2 < -(bias + info.nmant)
    overflow_cond = x2 > bias

    # Post-bitcast scale factor; stays 1 except for denormal results.
    m = lax.full_like(x, 1, dtype=dtype)

    # denormals
    # Exponents below the smallest normal exponent (1 - bias) are encoded
    # nmant steps higher and then scaled back down by 2**-nmant through ``m``.
    cond = x2 < -bias + 1
    x2 = _where(cond, x2 + info.nmant, x2)
    m = _where(cond, m / (1 << info.nmant), m)

    # Clear x1's exponent bits and insert the new biased exponent.
    x2 = lax.convert_element_type(x2, np.int32)
    x &= ~(mask << info.nmant)
    x |= ((lax.convert_element_type(x2, int_type) + bias) << info.nmant)

    # Reinterpret the bits as floats and apply the denormal scale factor.
    x = lax.convert_element_type(m, dtype) * lax.bitcast_convert_type(x, dtype)

    # underflow
    x = _where(underflow_cond, lax.full_like(x, 0, dtype=dtype), x)
    # overflow
    x = _where(overflow_cond, lax.sign(x1) * lax.full_like(x, np.inf), x)
    # ldexp(x1, x2) = x1 for x1 = inf, -inf, nan, 0
    return _where(isinf(x1) | isnan(x1) | (x1 == 0), x1, x)
コード例 #25
0
ファイル: initializers.py プロジェクト: xueeinstein/jax
 def init(key, shape, dtype=dtype):
     """Sample an orthogonal matrix, reshape it to ``shape``, scale by ``scale``.

     ``scale``, ``column_axis`` and ``prod`` come from the enclosing
     initializer factory (not visible here). A Gaussian matrix is
     QR-decomposed, and the Q factor is sign-corrected with ``sign(diag(R))``.
     Columns are laid out along ``column_axis``.

     Raises:
       ValueError: if ``shape`` has fewer than two dimensions.
     """
     dtype = dtypes.canonicalize_dtype(dtype)
     if len(shape) < 2:
         raise ValueError(
             "orthogonal initializer requires at least a 2D shape")
     n_cols = shape[column_axis]
     n_rows = prod(shape) // n_cols
     transposed = n_rows < n_cols
     matrix_shape = (n_cols, n_rows) if transposed else (n_rows, n_cols)
     A = random.normal(key, matrix_shape, dtype)
     Q, R = jnp.linalg.qr(A)
     # needed for a uniform distribution
     Q = Q * lax.broadcast_to_rank(jnp.sign(jnp.diag(R)), rank=Q.ndim)
     if transposed:
         Q = Q.T
     leading = tuple(np.delete(shape, column_axis))
     Q = jnp.reshape(Q, leading + (n_cols, ))
     return scale * jnp.moveaxis(Q, -1, column_axis)
コード例 #26
0
def _numpy_array_constant(x: np.ndarray,
                          canonicalize_types) -> Sequence[ir.Value]:
    """Emit ``x`` as an MHLO constant and return its result in a 1-tuple.

    Booleans are bit-packed (little-endian bit order). A single packed bool
    element is replaced by an all-zero or all-one byte to avoid an MLIR
    crash. bfloat16 values are viewed as uint16. The tensor type uses
    ``x``'s shape and ``dtype_to_ir_type(x.dtype)``.
    """
    if canonicalize_types:
        x = np.asarray(x, dtypes.canonicalize_dtype(x.dtype))
    shape = x.shape
    ir_type = ir.RankedTensorType.get(shape, dtype_to_ir_type(x.dtype))
    if x.dtype == np.bool_:
        packed = np.packbits(x, bitorder='little')
        # TODO(b/209005197): Work around for MLIR crash for non-splat single element
        # buffers.
        if x.size == 1:
            packed = np.array(0 if packed.item() == 0 else 0xff, np.uint8)
        x = packed
    elif x.dtype == dtypes.bfloat16:
        x = x.view(np.uint16)
    attr = ir.DenseElementsAttr.get(np.ascontiguousarray(x),
                                    type=ir_type.element_type, shape=shape)
    return (mhlo.ConstOp(ir_type, attr).result, )
コード例 #27
0
ファイル: reductions.py プロジェクト: cloudhan/jax
def _var_promote_types(a_dtype, dtype):
    """Return ``(a_dtype, dtype)`` to use when computing a variance.

    Without a requested ``dtype``: non-inexact inputs use the canonical
    ``dtypes.float_`` for both. Inexact inputs are promoted with float32, and
    the result dtype is ``_complex_elem_type(a_dtype)``. With a requested
    ``dtype``, the input dtype is promoted with it.

    Raises:
      ValueError: if a real ``dtype`` is requested for complex input.
    """
    if not dtype:
        if not dtypes.issubdtype(a_dtype, np.inexact):
            default = dtypes.canonicalize_dtype(dtypes.float_)
            return default, default
        elem = _complex_elem_type(a_dtype)
        return dtypes.promote_types(a_dtype, np.float32), elem

    wants_real = not dtypes.issubdtype(dtype, np.complexfloating)
    if wants_real and dtypes.issubdtype(a_dtype, np.complexfloating):
        msg = (
            "jax.numpy.var does not yet support real dtype parameters when "
            "computing the variance of an array of complex values. The "
            "semantics of numpy.var seem unclear in this case. Please comment "
            "on https://github.com/google/jax/issues/2283 if this behavior is "
            "important to you.")
        raise ValueError(msg)
    return dtypes.promote_types(a_dtype, dtype), dtype
コード例 #28
0
ファイル: linalg.py プロジェクト: xueeinstein/jax
def _gmres_batched(A, b, x0, unit_residual, residual_norm, ptol, restart, M):
  """
  Implements a single restart of GMRES.

  A basis of up to ``restart`` Arnoldi vectors is built starting from
  ``unit_residual`` (each vector is produced by ``_kth_arnoldi_iteration``,
  which receives both ``A`` and the preconditioner ``M``). The solution
  update is the least-squares combination of those vectors, and
  ``x0 + update`` is returned.

  This implementation solves a dense linear problem instead of building
  a QR factorization during the Arnoldi process.

  Args:
    A, M: linear operators (callables) on pytrees shaped like ``b``.
    b: right-hand side pytree.
    x0: current solution estimate.
    unit_residual, residual_norm: the normalized residual at ``x0`` and its
      norm; presumably produced by ``_safe_normalize`` -- confirm at caller.
    ptol: unused.
    restart: number of Arnoldi iterations per restart.

  Returns:
    ``(x, unit_residual, residual_norm)`` where ``x`` is the updated solution
    and the other two come from ``_safe_normalize(M(b - A(x)))``.
  """
  del ptol  # unused
  # https://www-users.cs.umn.edu/~saad/Calais/PREC.pdf
  # Basis storage: every leaf gains a trailing axis of size restart + 1,
  # whose first column is unit_residual and whose remaining (zero) columns are
  # filled in by the Arnoldi iterations.
  V = tree_map(
      lambda x: jnp.pad(x[..., None], ((0, 0),) * x.ndim + ((0, restart),)),
      unit_residual,
  )
  dtype, weak_type = dtypes._lattice_result_type(*tree_leaves(b))
  dtype = dtypes.canonicalize_dtype(dtype)
  # Arnoldi coefficient matrix, filled in by _kth_arnoldi_iteration.
  # NOTE(review): presumably initialized from jnp.eye so that rows left
  # untouched after an early breakdown keep the least-squares problem
  # well-posed -- confirm.
  H = lax_internal._convert_element_type(
      jnp.eye(restart, restart + 1, dtype=dtype), weak_type=weak_type)

  # Continue until ``restart`` iterations have run or a breakdown is reported.
  def loop_cond(carry):
    _, _, breakdown, k = carry
    return jnp.logical_and(k < restart, jnp.logical_not(breakdown))

  def arnoldi_process(carry):
    V, H, _, k = carry
    V, H, breakdown = _kth_arnoldi_iteration(k, A, M, V, H)
    return V, H, breakdown, k + 1

  carry = (V, H, False, 0)
  V, H, _, _ = lax.while_loop(loop_cond, arnoldi_process, carry)

  # Right-hand side residual_norm * e_1 for the small least-squares problem.
  beta_vec = jnp.zeros_like(H, shape=(restart + 1,)).at[0].set(residual_norm.astype(dtype))
  y = _lstsq(H.T, beta_vec)
  # Combine the first ``restart`` basis columns with the coefficients y.
  dx = tree_map(lambda X: _dot(X[..., :-1], y), V)

  x = _add(x0, dx)

  # Recompute the (preconditioned) residual at the new iterate.
  residual = M(_sub(b, A(x)))
  unit_residual, residual_norm = _safe_normalize(residual)
  return x, unit_residual, residual_norm
コード例 #29
0
ファイル: linalg.py プロジェクト: xueeinstein/jax
def _safe_normalize(x, thresh=None):
  """
  Returns the L2-normalized vector (which can be a pytree) x, together with
  the computed norm. If the computed norm is not greater than the threshold
  `thresh`, which by default is the machine epsilon of the norm's dtype, it
  is taken to be 0, and the normalized x to be the zero vector.

  Args:
    x: a pytree of arrays.
    thresh: optional threshold. Any array-like is accepted, including a
      plain Python scalar.

  Returns:
    ``(normalized_x, norm)``.
  """
  norm = _norm(x)
  dtype, weak_type = dtypes._lattice_result_type(*tree_leaves(x))
  dtype = dtypes.canonicalize_dtype(dtype)
  if thresh is None:
    thresh = jnp.finfo(norm.dtype).eps
  # ``jnp.asarray`` lets callers pass a Python float: the previous
  # ``thresh.astype(...)`` raised AttributeError for non-array thresholds.
  # ``.real`` drops the imaginary part when ``dtype`` is complex.
  thresh = jnp.asarray(thresh).astype(dtype).real

  use_norm = norm > thresh

  norm_cast = lax_internal._convert_element_type(norm, dtype, weak_type)
  normalized_x = tree_map(lambda y: jnp.where(use_norm, y / norm_cast, 0.0), x)
  norm = jnp.where(use_norm, norm, 0.0)
  return normalized_x, norm
コード例 #30
0
ファイル: dtypes_test.py プロジェクト: xueeinstein/jax
 def testBinaryPromotion(self, swap, jit):
     # Each case is (lhs, rhs, expected dtype before canonicalization).
     arr = jnp.array
     f16, f32, f64 = jnp.float16, jnp.float32, jnp.float64
     testcases = [
         (arr(1.), 0., f64),
         (arr(1.), arr(0.), f64),
         (arr(1.), arr(0., dtype=f16), f16),
         (arr(1.), arr(0., dtype=f32), f32),
         (arr(1.), arr(0., dtype=f64), f64),
         (arr(1., dtype=f16), 0., f16),
         (arr(1., dtype=f32), 0., f32),
         (arr(1., dtype=f64), 0., f64),
         (arr(1., dtype=f16), arr(0., dtype=f16), f16),
         (arr(1., dtype=f16), arr(0., dtype=f32), f32),
         (arr(1., dtype=f16), arr(0., dtype=f64), f64),
         (arr(1., dtype=f32), arr(0., dtype=f32), f32),
         (arr(1., dtype=f32), arr(0., dtype=f64), f64),
         (arr(1., dtype=f64), arr(0., dtype=f64), f64),
         (arr([1.]), 0., jnp.float_),
         (arr([1.]), arr(0.), jnp.float_),
         (arr([1.]), arr(0., dtype=f16), jnp.float_),
         (arr([1.]), arr(0., dtype=f32), jnp.float_),
         (arr([1.]), arr(0., dtype=f64), f64),
         (arr([1.], dtype=f32), arr(0., dtype=f16), f32),
         (arr([1.], dtype=f16), arr(0., dtype=f32), f32),
         (arr([1.], dtype=f16), 0., f16),
     ]
     add = jax.jit(operator.add) if jit else operator.add
     for lhs, rhs, want in testcases:
         if swap:
             lhs, rhs = rhs, lhs
         result = add(lhs, rhs)
         self.assertTrue(isinstance(result, jnp.ndarray),
                         msg=(lhs, rhs, result))
         self.assertEqual(result.dtype,
                          dtypes.canonicalize_dtype(want),
                          msg=(lhs, rhs, result))