Example #1
def div(lhs, rhs):
    if dtypes.issubdtype(dtypes.result_type(lhs), np.integer):
        quotient = np.floor_divide(lhs, rhs)
        select = np.logical_and(
            np.sign(lhs) != np.sign(rhs),
            np.remainder(lhs, rhs) != 0)
        return np.where(select, quotient + 1, quotient)
    else:
        return np.divide(lhs, rhs)
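
A minimal NumPy-only sketch (not part of the snippet above) of what the correction achieves: np.floor_divide rounds toward negative infinity, while the adjusted quotient rounds toward zero, matching C-style integer division.

import numpy as np

lhs = np.array([-7, 7, -7, 7])
rhs = np.array([2, 2, -2, -2])
floor_q = np.floor_divide(lhs, rhs)                    # [-4, 3, 3, -4]
fixup = np.logical_and(np.sign(lhs) != np.sign(rhs),
                       np.remainder(lhs, rhs) != 0)
trunc_q = np.where(fixup, floor_q + 1, floor_q)        # [-3, 3, 3, -3]
assert np.array_equal(trunc_q, np.trunc(lhs / rhs).astype(int))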
Example #2
def _reduce_window_sum_shape_rule(operand, *, window_dimensions,
                                  window_strides, padding, base_dilation,
                                  window_dilation):
    if not dtypes.issubdtype(operand.dtype, np.number):
        msg = "operand to reduce_window_sum must have a number dtype, got {}"
        raise TypeError(msg.format(np.dtype(operand.dtype).name))
    return _common_reduce_window_shape_rule(operand, window_dimensions,
                                            window_strides, padding,
                                            base_dilation, window_dilation)
Example #3
def _check_special(name, xla_shape, buf):
    assert not xla_shape.is_tuple()
    if dtypes.issubdtype(xla_shape.element_type(), np.inexact):
        if config.jax_debug_nans and np.any(np.isnan(buf.to_py())):
            raise FloatingPointError(
                f"invalid value (nan) encountered in {name}")
        if config.jax_debug_infs and np.any(np.isinf(buf.to_py())):
            raise FloatingPointError(
                f"invalid value (inf) encountered in {name}")
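
A usage note, assuming current JAX config naming: the two checks above fire only when the corresponding debug flags are enabled, e.g. (a sketch, not taken from the snippet's source file):

import jax

jax.config.update("jax_debug_nans", True)  # raise FloatingPointError on NaN
jax.config.update("jax_debug_infs", True)  # raise FloatingPointError on Inf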
Example #4
def fn(x1, x2):
    x1, x2 = _promote_args(numpy_fn.__name__, x1, x2)
    # Comparisons on complex types are defined as a lexicographic ordering on
    # the (real, imag) pair.
    if dtypes.issubdtype(dtypes.dtype(x1), np.complexfloating):
        rx = lax.real(x1)
        ry = lax.real(x2)
        return lax.select(lax.eq(rx, ry), lax_fn(lax.imag(x1), lax.imag(x2)),
                          lax_fn(rx, ry))
    return lax_fn(x1, x2)
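
A NumPy-only sketch of the lexicographic rule implemented above: compare real parts first, falling back to imaginary parts only on ties (complex_less is a hypothetical helper, not a JAX API).

import numpy as np

def complex_less(a, b):
    a, b = np.asarray(a), np.asarray(b)
    return np.where(a.real == b.real, a.imag < b.imag, a.real < b.real)

assert complex_less(1 + 5j, 2 + 0j)      # real parts decide
assert complex_less(1 + 1j, 1 + 2j)      # real parts tie, imaginary decides
assert not complex_less(1 + 2j, 1 + 2j)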
Example #5
def _var_promote_types(a_dtype, dtype):
    if dtype:
        if (not dtypes.issubdtype(dtype, np.complexfloating)
                and dtypes.issubdtype(a_dtype, np.complexfloating)):
            msg = (
                "jax.numpy.var does not yet support real dtype parameters when "
                "computing the variance of an array of complex values. The "
                "semantics of numpy.var seem unclear in this case. Please comment "
                "on https://github.com/google/jax/issues/2283 if this behavior is "
                "important to you.")
            raise ValueError(msg)
        a_dtype = dtypes.promote_types(a_dtype, dtype)
    else:
        if not dtypes.issubdtype(a_dtype, np.inexact):
            dtype = a_dtype = dtypes.canonicalize_dtype(dtypes.float_)
        else:
            dtype = _complex_elem_type(a_dtype)
            a_dtype = dtypes.promote_types(a_dtype, np.float32)
    return a_dtype, dtype
Example #6
def _average(a,
             axis: Optional[Union[int, Tuple[int, ...]]] = None,
             weights=None,
             returned=False):
    a = _asarray(a)

    if weights is None:  # Treat all weights as 1
        avg = mean(a, axis=axis)
        if axis is None:
            weights_sum = lax.full((),
                                   core.dimension_as_value(np.size(a)),
                                   dtype=avg.dtype)
        else:
            weights_sum = lax.full_like(avg,
                                        core.dimension_as_value(a.shape[axis]),
                                        dtype=avg.dtype)
    else:
        weights = _asarray(weights)

        if dtypes.issubdtype(a.dtype, np.inexact):
            out_dtype = dtypes.result_type(a.dtype, weights.dtype)
        else:
            out_dtype = dtypes.result_type(a.dtype, weights.dtype,
                                           dtypes.float_)
        out_dtype = dtypes.canonicalize_dtype(out_dtype)

        a_shape = np.shape(a)
        a_ndim = len(a_shape)
        weights_shape = np.shape(weights)
        axis = None if axis is None else _canonicalize_axis(axis, a_ndim)

        if a_shape != weights_shape:
            # Make sure the dimensions work out
            if axis is None:
                raise ValueError("Axis must be specified when shapes of a and "
                                 "weights differ.")
            if len(weights_shape) != 1:
                raise ValueError("1D weights expected when shapes of a and "
                                 "weights differ.")
            if not core.symbolic_equal_dim(weights_shape[0], a_shape[axis]):
                raise ValueError("Length of weights not "
                                 "compatible with specified axis.")

            weights = _broadcast_to(weights,
                                    (a_ndim - 1) * (1, ) + weights_shape)
            weights = _moveaxis(weights, -1, axis)

        weights_sum = sum(weights, axis=axis, dtype=out_dtype)
        avg = sum(a * weights, axis=axis, dtype=out_dtype) / weights_sum

    if returned:
        if avg.shape != weights_sum.shape:
            weights_sum = _broadcast_to(weights_sum, avg.shape)
        return avg, weights_sum
    return avg
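
A small NumPy sketch of the 1-D-weights path above: weights of length a.shape[axis] are broadcast up to a's rank and moved onto the reduction axis before the two sums (the shapes here are illustrative).

import numpy as np

a = np.arange(6.0).reshape(2, 3)
w = np.array([1.0, 2.0, 3.0])          # length matches a.shape[1]
axis = 1
wb = np.broadcast_to(w, (a.ndim - 1) * (1,) + w.shape)  # shape (1, 3)
wb = np.moveaxis(wb, -1, axis)                          # onto the axis
avg = (a * wb).sum(axis=axis) / wb.sum(axis=axis)
assert np.allclose(avg, np.average(a, axis=axis, weights=w))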
Example #7
def _reduction_init_val(a, init_val):
    # This function uses np.* functions because lax pattern matches against the
    # specific concrete values of the reduction inputs.
    a_dtype = dtypes.canonicalize_dtype(dtypes.dtype(a))
    if a_dtype == 'bool':
        return np.array(init_val > 0, dtype=a_dtype)
    try:
        return np.array(init_val, dtype=a_dtype)
    except OverflowError:
        assert dtypes.issubdtype(a_dtype, np.integer)
        sign, info = np.sign(init_val), dtypes.iinfo(a_dtype)
        return np.array(info.min if sign < 0 else info.max, dtype=a_dtype)
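
A standalone sketch of the OverflowError fallback above: -inf (the identity for a max reduction) cannot be represented in an integer dtype, so it is clamped to the dtype's minimum (reduction_init_val here is a simplified stand-in for the function above).

import numpy as np

def reduction_init_val(init_val, dtype):
    try:
        return np.array(init_val, dtype=dtype)
    except OverflowError:
        info = np.iinfo(dtype)
        return np.array(info.min if np.sign(init_val) < 0 else info.max,
                        dtype=dtype)

assert reduction_init_val(-np.inf, np.int32) == np.iinfo(np.int32).min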
Example #8
    def rand(shape, dtype):
        """The random sampler function."""
        if not _dtypes.issubdtype(dtype, np.floating):
            # only float types have inf
            return base_rand(shape, dtype)

        if _dtypes.issubdtype(dtype, np.complexfloating):
            base_dtype = np.real(np.array(0, dtype=dtype)).dtype
            out = (rand(shape, base_dtype) +
                   np.array(1j, dtype) * rand(shape, base_dtype))
            return _cast_to_shape(out, shape, dtype)

        dims = _dims_of_shape(shape)
        posinf_flips = rng.rand(*dims) < 0.1
        neginf_flips = rng.rand(*dims) < 0.1

        vals = base_rand(shape, dtype)
        vals = np.where(posinf_flips, np.array(np.inf, dtype=dtype), vals)
        vals = np.where(neginf_flips, np.array(-np.inf, dtype=dtype), vals)

        return _cast_to_shape(np.asarray(vals, dtype=dtype), shape, dtype)
Example #9
    def testBinaryNonPromotion(self, dtype, weak_type, promotion):
        # Regression test for https://github.com/google/jax/issues/6051
        x = lax_internal._convert_element_type(0, dtype, weak_type=weak_type)
        with jax.numpy_dtype_promotion(promotion):
            y = (x + x)

        if promotion == 'standard' or not weak_type or dtype == dtypes.bool_:
            expected_dtype = dtype
        elif dtypes.issubdtype(dtype, np.complexfloating):
            expected_dtype = dtypes.complex_
        elif dtypes.issubdtype(dtype, np.floating):
            expected_dtype = dtypes.float_
        else:
            expected_dtype = dtypes.int_

        # No boolean weak types.
        expected_weak_type = weak_type and dtype != bool
        expected_dtype = dtypes.canonicalize_dtype(expected_dtype)

        self.assertEqual(y.dtype, expected_dtype)
        self.assertEqual(dtypes.is_weakly_typed(y), expected_weak_type)
Example #10
def frexp(x):
  _check_arraylike("frexp", x)
  if dtypes.issubdtype(x.dtype, np.complexfloating):
    raise TypeError("frexp does not support complex-valued inputs")
  elif not dtypes.issubdtype(dtypes.dtype(x), np.floating):
    x = lax.convert_element_type(x, np.float_)

  dtype = dtypes.dtype(x)
  info = dtypes.finfo(dtype)
  mask = (1 << info.nexp) - 1
  bias = ((1 << info.nexp) - 1) >> 1

  x1, x2 = _normalize_float(x)
  x2 += ((x1 >> info.nmant) & mask) - bias + 1
  x1 &= ~(mask << info.nmant)
  x1 |= (bias - 1) << info.nmant
  x1 = lax.bitcast_convert_type(x1, dtype)

  cond = isinf(x) | isnan(x) | (x == 0)
  x2 = _where(cond, lax_internal._zeros(x2), x2)
  return _where(cond, x, x1), lax.convert_element_type(x2, np.int32)
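
For reference, a NumPy demonstration of the frexp contract the code above reimplements bit-by-bit: x == m * 2**e with 0.5 <= |m| < 1 for finite nonzero x.

import numpy as np

m, e = np.frexp(np.array([0.5, 3.0, -10.0]))
assert np.allclose(m * np.exp2(e), [0.5, 3.0, -10.0])
assert np.all((np.abs(m) >= 0.5) & (np.abs(m) < 1.0))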
Example #11
def logaddexp(x1, x2):
  x1, x2 = _promote_args_inexact("logaddexp", x1, x2)
  amax = lax.max(x1, x2)
  if dtypes.issubdtype(x1.dtype, np.floating):
    delta = lax.sub(x1, x2)
    return lax.select(lax_internal._isnan(delta),
                      lax.add(x1, x2),  # NaNs or infinities of the same sign.
                      lax.add(amax, lax.log1p(lax.exp(lax.neg(lax.abs(delta))))))
  else:
    delta = lax.sub(lax.add(x1, x2), lax.mul(amax, _constant_like(amax, 2)))
    out = lax.add(amax, lax.log1p(lax.exp(delta)))
    return lax.complex(lax.real(out), _wrap_between(lax.imag(out), np.pi))
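
A NumPy sketch of why the rearrangement above matters for the real branch: the naive formula overflows as soon as exp(x) does, while max-subtraction plus log1p stays finite.

import numpy as np

x1, x2 = 1000.0, 1000.0
with np.errstate(over='ignore'):
    naive = np.log(np.exp(x1) + np.exp(x2))             # inf: exp overflows
stable = max(x1, x2) + np.log1p(np.exp(-abs(x1 - x2)))  # 1000 + log(2)
assert np.isinf(naive) and np.isclose(stable, 1000.0 + np.log(2.0))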
Example #12
def floor_divide(x1, x2):
  x1, x2 = _promote_args("floor_divide", x1, x2)
  dtype = dtypes.dtype(x1)
  if dtypes.issubdtype(dtype, np.integer):
    quotient = lax.div(x1, x2)
    select = logical_and(lax.sign(x1) != lax.sign(x2), lax.rem(x1, x2) != 0)
    # TODO(mattjj): investigate why subtracting a scalar was causing promotion
    return _where(select, quotient - 1, quotient)
  elif dtypes.issubdtype(dtype, np.complexfloating):
    x1r = lax.real(x1)
    x1i = lax.imag(x1)
    x2r = lax.real(x2)
    x2i = lax.imag(x2)
    which = lax.ge(lax.abs(x2r), lax.abs(x2i))
    rat1 = _where(which, lax.full_like(x2i, 1), lax.div(x2r, x2i))
    rat2 = _where(which, lax.div(x2i, x2r), _lax_const(x2i, 1))
    out = lax.floor(lax.div(lax.add(lax.mul(x1r, rat1), lax.mul(x1i, rat2)),
                            lax.add(lax.mul(x2r, rat1), lax.mul(x2i, rat2))))
    return lax.convert_element_type(out, dtype)
  else:
    return _float_divmod(x1, x2)[0]
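
A standalone sketch of the complex branch above, which is a Smith-style scaled division: dividing through by the larger-magnitude component of x2 avoids overflow before flooring the real part of the quotient (complex_floor_div_real is a hypothetical scalar helper, not a JAX function).

import numpy as np

def complex_floor_div_real(x1, x2):
    if abs(x2.real) >= abs(x2.imag):
        rat1, rat2 = 1.0, x2.imag / x2.real
    else:
        rat1, rat2 = x2.real / x2.imag, 1.0
    num = x1.real * rat1 + x1.imag * rat2
    den = x2.real * rat1 + x2.imag * rat2
    return np.floor(num / den)

assert complex_floor_div_real(7 + 3j, 2 + 1j) == np.floor(((7 + 3j) / (2 + 1j)).real)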
Example #13
def signbit(x):
    x, = _promote_args("signbit", x)
    dtype = dtypes.dtype(x)
    if dtypes.issubdtype(dtype, np.integer):
        return lax.lt(x, _constant_like(x, 0))
    elif dtypes.issubdtype(dtype, np.bool_):
        return lax.full_like(x, False, dtype=np.bool_)
    elif not dtypes.issubdtype(dtype, np.floating):
        raise ValueError("jax.numpy.signbit is not well defined for %s" %
                         dtype)

    # TPU supports BF16 but not S16 types, so as a workaround, convert BF16 to
    # F32.
    if dtype == dtypes.bfloat16:
        dtype = np.float32
        x = lax.convert_element_type(x, np.float32)

    info = dtypes.finfo(dtype)
    if info.bits not in _INT_DTYPES:
        raise NotImplementedError(
            "jax.numpy.signbit only supports 16, 32, and 64-bit types.")
    int_type = _INT_DTYPES[info.bits]
    x = lax.bitcast_convert_type(x, int_type)
    return lax.convert_element_type(x >> (info.nexp + info.nmant), np.bool_)
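
A NumPy sketch of the bit trick at the end of signbit: reinterpret the float bits as an unsigned integer and shift the sign bit (bit info.nexp + info.nmant, i.e. bit 31 for float32) down to position zero.

import numpy as np

x = np.array([-1.5, 0.0, -0.0, 2.0], dtype=np.float32)
bits = x.view(np.uint32)          # bitcast, like lax.bitcast_convert_type
info = np.finfo(np.float32)       # info.nexp + info.nmant == 31
sign = (bits >> (info.nexp + info.nmant)).astype(bool)
assert np.array_equal(sign, np.signbit(x))   # [True, False, True, False]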
Example #14
    def testCsdWithSameParamAgainstNumpy(self, *, shape, dtype, fs, window,
                                         nperseg, noverlap, nfft, detrend,
                                         scaling, timeaxis, average):
        is_complex = dtypes.issubdtype(dtype, np.complexfloating)
        if is_complex and detrend is not None:
            self.skipTest(
                "Complex signal is not supported in lax-backed `signal.detrend`."
            )

        kwds = dict(fs=fs,
                    window=window,
                    nperseg=nperseg,
                    noverlap=noverlap,
                    nfft=nfft,
                    detrend=detrend,
                    return_onesided=not is_complex,
                    scaling=scaling,
                    axis=timeaxis,
                    average=average)

        def osp_fun(x, y):
            # When identical parameters are given, the jsp version behaves as
            # if the second argument were a copy, so pass a copy of y to osp.
            freqs, Pxy = osp_signal.csd(x, y.copy(), **kwds)
            # Make type-casting the same as JAX.
            return freqs.astype(_real_dtype(dtype)), Pxy.astype(
                _complex_dtype(dtype))

        jsp_fun = partial(jsp_signal.csd, **kwds)
        tol = {
            np.float32: 1e-5,
            np.float64: 1e-12,
            np.complex64: 1e-5,
            np.complex128: 1e-12
        }
        if jtu.device_under_test() == 'tpu':
            tol = _TPU_FFT_TOL

        rng = jtu.rand_default(self.rng())
        args_maker = lambda: [rng(shape, dtype)] * 2

        self._CheckAgainstNumpy(osp_fun,
                                jsp_fun,
                                args_maker,
                                rtol=tol,
                                atol=tol)
        self._CompileAndCheck(jsp_fun, args_maker, rtol=tol, atol=tol)
Example #15
def testIsSubdtype(self):
    for t in scalar_types:
        self.assertTrue(dtypes.issubdtype(t, t))
        self.assertTrue(dtypes.issubdtype(np.dtype(t).type, t))
        self.assertTrue(dtypes.issubdtype(t, np.dtype(t).type))
        self.assertTrue(dtypes.issubdtype(t, np.dtype(t)))
        if t != jnp.bfloat16:
            for category in [np.generic, jnp.inexact, jnp.integer,
                             jnp.signedinteger, jnp.unsignedinteger,
                             jnp.floating, jnp.complexfloating]:
                self.assertEqual(dtypes.issubdtype(t, category),
                                 np.issubdtype(np.dtype(t).type, category))
Example #16
    def testStftAgainstNumpy(self, *, shape, dtype, fs, window, nperseg,
                             noverlap, nfft, detrend, boundary, padded,
                             timeaxis):
        is_complex = dtypes.issubdtype(dtype, np.complexfloating)
        if is_complex and detrend is not None:
            self.skipTest(
                "Complex signal is not supported in lax-backed `signal.detrend`."
            )

        kwds = dict(fs=fs,
                    window=window,
                    nfft=nfft,
                    boundary=boundary,
                    padded=padded,
                    detrend=detrend,
                    nperseg=nperseg,
                    noverlap=noverlap,
                    axis=timeaxis,
                    return_onesided=not is_complex)

        def osp_fun(x):
            freqs, time, Pxx = osp_signal.stft(x, **kwds)
            return freqs.astype(_real_dtype(dtype)), time.astype(
                _real_dtype(dtype)), Pxx.astype(_complex_dtype(dtype))

        jsp_fun = partial(jsp_signal.stft, **kwds)
        tol = {
            np.float32: 1e-5,
            np.float64: 1e-12,
            np.complex64: 1e-5,
            np.complex128: 1e-12
        }
        if jtu.device_under_test() == 'tpu':
            tol = _TPU_FFT_TOL

        rng = jtu.rand_default(self.rng())
        args_maker = lambda: [rng(shape, dtype)]

        self._CheckAgainstNumpy(osp_fun,
                                jsp_fun,
                                args_maker,
                                rtol=tol,
                                atol=tol)
        self._CompileAndCheck(jsp_fun, args_maker, rtol=tol, atol=tol)
Example #17
def _power(x1, x2):
  x1, x2 = _promote_args("power", x1, x2)
  dtype = dtypes.dtype(x1)
  if not dtypes.issubdtype(dtype, np.integer):
    return lax.pow(x1, x2)

  # Integer power => use binary exponentiation.

  # TODO(phawkins): add integer pow support to XLA.
  bits = 6  # Anything more would overflow for any x1 > 1
  zero = _constant_like(x2, 0)
  one = _constant_like(x2, 1)
  # Initialize acc carefully such that pow(0, x2) is zero for x2 != 0
  acc = _where(lax.bitwise_and(lax.eq(x1, zero), lax.ne(x2, zero)), zero, one)
  for _ in range(bits):
    acc = _where(lax.bitwise_and(x2, one), lax.mul(acc, x1), acc)
    x1 = lax.mul(x1, x1)
    x2 = lax.shift_right_logical(x2, one)
  return acc
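
The loop above is binary exponentiation; here is the same idea in plain Python (int_pow is an illustrative stand-in, not a JAX function): square the base each round and multiply it into the accumulator whenever the low bit of the exponent is set. Six rounds cover exponents below 2**6.

def int_pow(x, n, bits=6):
    acc = 0 if (x == 0 and n != 0) else 1   # pow(0, n) == 0 for n != 0
    for _ in range(bits):
        if n & 1:
            acc *= x
        x *= x
        n >>= 1
    return acc

assert int_pow(3, 5) == 243 and int_pow(0, 4) == 0 and int_pow(0, 0) == 1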
Example #18
def nanvar(a,
           axis: Optional[Union[int, Tuple[int, ...]]] = None,
           dtype=None,
           out=None,
           ddof=0,
           keepdims=False,
           where=None):
    _check_arraylike("nanvar", a)
    lax_internal._check_user_dtype_supported(dtype, "nanvar")
    if out is not None:
        raise NotImplementedError(
            "The 'out' argument to jnp.nanvar is not supported.")

    computation_dtype, dtype = _var_promote_types(dtypes.dtype(a), dtype)
    a = a.astype(computation_dtype)
    a_mean = nanmean(a,
                     axis,
                     dtype=computation_dtype,
                     keepdims=True,
                     where=where)

    centered = _where(lax_internal._isnan(a), 0,
                      lax.sub(a, a_mean))  # double-where trick for gradients.
    if dtypes.issubdtype(centered.dtype, np.complexfloating):
        centered = lax.real(lax.mul(centered, lax.conj(centered)))
    else:
        centered = lax.square(centered)

    normalizer = sum(lax_internal.bitwise_not(lax_internal._isnan(a)),
                     axis=axis,
                     keepdims=keepdims,
                     where=where)
    normalizer = normalizer - ddof
    normalizer_mask = lax.le(normalizer, 0)
    result = sum(centered, axis, keepdims=keepdims, where=where)
    result = _where(normalizer_mask, np.nan, result)
    divisor = _where(normalizer_mask, 1, normalizer)
    out = lax.div(result, lax.convert_element_type(divisor, result.dtype))
    return lax.convert_element_type(out, dtype)
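
The "double-where trick" comment above refers to a standard JAX pattern: sanitize the input inside the masked branch as well, so the backward pass never differentiates through the invalid values. A minimal sketch with sqrt:

import jax
import jax.numpy as jnp

f_bad = lambda x: jnp.where(x > 0, jnp.sqrt(x), 0.)
f_good = lambda x: jnp.where(x > 0, jnp.sqrt(jnp.where(x > 0, x, 1.)), 0.)
print(jax.grad(f_bad)(0.))    # nan: cotangent still flows through sqrt at 0
print(jax.grad(f_good)(0.))   # 0.0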
Example #19
def _rand_dtype(rand, shape, dtype, scale=1., post=lambda x: x):
    """Produce random values given shape, dtype, scale, and post-processor.

    Args:
      rand: a function for producing random values of a given shape, e.g. a
        bound version of either np.RandomState.randn or np.RandomState.rand.
      shape: a shape value as a tuple of positive integers.
      dtype: a numpy dtype.
      scale: optional, a multiplicative scale for the random values (default 1).
      post: optional, a callable for post-processing the random values (default
        identity).

    Returns:
      An ndarray of the given shape and dtype using random values based on a call
      to rand but scaled, converted to the appropriate dtype, and post-processed.
    """
    r = lambda: np.asarray(scale * rand(*_dims_of_shape(shape)), dtype)
    if _dtypes.issubdtype(dtype, np.complexfloating):
        vals = r() + 1.0j * r()
    else:
        vals = r()
    return _cast_to_shape(np.asarray(post(vals), dtype), shape, dtype)
Example #20
def _var(a,
         axis: Optional[Union[int, Tuple[int, ...]]] = None,
         dtype=None,
         out=None,
         ddof=0,
         keepdims=False,
         *,
         where=None):
    _check_arraylike("var", a)
    lax_internal._check_user_dtype_supported(dtype, "var")
    if out is not None:
        raise NotImplementedError(
            "The 'out' argument to jnp.var is not supported.")

    computation_dtype, dtype = _var_promote_types(dtypes.dtype(a), dtype)
    a = a.astype(computation_dtype)
    a_mean = mean(a, axis, dtype=computation_dtype, keepdims=True, where=where)
    centered = lax.sub(a, a_mean)
    if dtypes.issubdtype(centered.dtype, np.complexfloating):
        centered = lax.real(lax.mul(centered, lax.conj(centered)))
    else:
        centered = lax.square(centered)

    if where is None:
        if axis is None:
            normalizer = core.dimension_as_value(np.size(a))
        else:
            normalizer = core.dimension_as_value(_axis_size(a, axis))
    else:
        normalizer = sum(_broadcast_to(where, np.shape(a)),
                         axis,
                         dtype=dtype,
                         keepdims=keepdims)
    normalizer = normalizer - ddof

    result = sum(centered, axis, keepdims=keepdims, where=where)
    out = lax.div(result, lax.convert_element_type(normalizer, result.dtype))
    return lax.convert_element_type(out, dtype)
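
A NumPy sketch of the complex branch shared by _var and nanvar above: the variance of complex data averages |z - mean|**2, i.e. real(c * conj(c)), so the result is real-valued.

import numpy as np

z = np.array([1 + 1j, 2 - 1j, 3 + 0j])
c = z - z.mean()
var = (c * np.conj(c)).real.mean()
assert np.isclose(var, np.var(z))   # np.var of complex data is also real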
Example #21
def _closure_convert_for_avals(fun, in_tree, in_avals):
    wrapped_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
    jaxpr, out_pvals, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals)
    out_tree = out_tree()

    # We only want to closure convert for constants with respect to which we're
    # differentiating. As a proxy for that, we hoist consts with float dtype.
    # TODO(frostig,mattjj): revise this approach
    from jax.numpy import inexact
    is_float = lambda c: dtypes.issubdtype(dtypes.dtype(c), inexact)
    (closure_consts, hoisted_consts), merge = partition_list(is_float, consts)
    num_consts = len(hoisted_consts)

    def converted_fun(*args_hconsts):
        num_args = len(args_hconsts) - num_consts
        args, hoisted_consts = split_list(args_hconsts, [num_args])
        consts = merge(closure_consts, hoisted_consts)
        all_args, in_tree2 = tree_flatten(tuple(args))
        assert in_tree == in_tree2
        out_flat = core.eval_jaxpr(jaxpr, consts, *all_args)
        return tree_unflatten(out_tree, out_flat)

    return converted_fun, hoisted_consts
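
partition_list comes from JAX internals; below is a hypothetical sketch consistent with how it is used above: split a list by a predicate and return a merge function that restores the original interleaving.

def partition_list(pred, xs):
    which = [bool(pred(x)) for x in xs]
    falsy = [x for x, w in zip(xs, which) if not w]
    truthy = [x for x, w in zip(xs, which) if w]
    def merge(falsy_new, truthy_new):
        f_it, t_it = iter(falsy_new), iter(truthy_new)
        return [next(t_it) if w else next(f_it) for w in which]
    return (falsy, truthy), merge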
Example #22
def _nan_reduction(a,
                   name,
                   jnp_reduction,
                   init_val,
                   nan_if_all_nan,
                   axis=None,
                   keepdims=None,
                   **kwargs):
    _check_arraylike(name, a)
    if not dtypes.issubdtype(dtypes.dtype(a), np.inexact):
        return jnp_reduction(a, axis=axis, keepdims=keepdims, **kwargs)

    out = jnp_reduction(_where(lax_internal._isnan(a),
                               _reduction_init_val(a, init_val), a),
                        axis=axis,
                        keepdims=keepdims,
                        **kwargs)
    if nan_if_all_nan:
        return _where(
            all(lax_internal._isnan(a), axis=axis, keepdims=keepdims),
            _lax_const(a, np.nan), out)
    else:
        return out
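
A NumPy sketch of the pattern above: replace NaNs with the reduction's identity before reducing, which is exactly how nansum relates to sum.

import numpy as np

a = np.array([1.0, np.nan, 2.0])
out = np.where(np.isnan(a), 0.0, a).sum()   # the identity for sum is 0
assert out == np.nansum(a) == 3.0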
Example #23
def _get_min_identity(dt):
    return np.inf if dtypes.issubdtype(dt, np.floating) else np.iinfo(dt).max
Example #24
def _to_inexact_dtype(dtype):
    """Promotes a dtype into an inexact dtype, if it is not already one."""
    return dtype if dtypes.issubdtype(
        dtype, np.inexact) else dtypes.promote_types(dtype, dtypes.float_)
Example #25
def op(*args):
    zero = lambda x: lax.full_like(x, shape=(), fill_value=0)
    args = (x if dtypes.issubdtype(dtypes.dtype(x), np.bool_)
            else lax.ne(x, zero(x)) for x in args)
    return bitwise_op(*_promote_args(np_op.__name__, *args))
Example #26
def _reduction(a,
               name,
               np_fun,
               op,
               init_val,
               has_identity=True,
               preproc=None,
               bool_op=None,
               upcast_f16_for_computation=False,
               axis=None,
               dtype=None,
               out=None,
               keepdims=False,
               initial=None,
               where_=None,
               parallel_reduce=None):
    bool_op = bool_op or op
    # Note: we must accept out=None as an argument, because numpy reductions delegate to
    # object methods. For example `np.sum(x)` will call `x.sum()` if the `sum()` method
    # exists, passing along all its arguments.
    if out is not None:
        raise NotImplementedError(
            f"The 'out' argument to jnp.{name} is not supported.")
    _check_arraylike(name, a)
    lax_internal._check_user_dtype_supported(dtype, name)
    axis = core.concrete_or_error(None, axis,
                                  f"axis argument to jnp.{name}().")

    if initial is None and not has_identity and where_ is not None:
        raise ValueError(
            f"reduction operation {name} does not have an identity, so to use a "
            f"where mask one has to specify 'initial'")

    a = a if isinstance(a, ndarray) else _asarray(a)
    a = preproc(a) if preproc else a
    pos_dims, dims = _reduction_dims(a, axis)

    if initial is None and not has_identity:
        shape = np.shape(a)
        if not _all(core.greater_equal_dim(shape[d], 1) for d in pos_dims):
            raise ValueError(
                f"zero-size array to reduction operation {name} which has no identity"
            )

    result_dtype = dtypes.canonicalize_dtype(
        dtype or dtypes.dtype(np_fun(np.ones((), dtype=dtypes.dtype(a)))))
    if upcast_f16_for_computation and dtypes.issubdtype(
            result_dtype, np.inexact):
        computation_dtype = _upcast_f16(result_dtype)
    else:
        computation_dtype = result_dtype
    a = lax.convert_element_type(a, computation_dtype)
    op = op if computation_dtype != np.bool_ else bool_op
    # NB: in XLA, init_val must be an identity for the op, so the user-specified
    # initial value must be applied afterward.
    init_val = _reduction_init_val(a, init_val)
    if where_ is not None:
        a = _where(where_, a, init_val)
    if pos_dims is not dims:
        if parallel_reduce is None:
            raise NotImplementedError(
                f"Named reductions not implemented for jnp.{name}()")
        result = parallel_reduce(a, dims)
    else:
        result = lax.reduce(a, init_val, op, dims)
    if initial is not None:
        result = op(lax.convert_element_type(initial, a.dtype), result)
    if keepdims:
        result = lax.expand_dims(result, pos_dims)
    return lax.convert_element_type(result, dtype or result_dtype)
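
A small NumPy sketch of the 'initial' handling noted above: XLA requires the reduction's init value to be the op's identity, so a user-supplied initial is folded in with one extra application of the op afterward.

import numpy as np

a = np.array([3, 5])
result = np.maximum(np.iinfo(a.dtype).min, a.max())  # reduce with identity
result = np.maximum(10, result)                      # then apply 'initial'
assert result == np.max(a, initial=10) == 10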
Example #27
def absolute(x):
    _check_arraylike('absolute', x)
    dt = dtypes.dtype(x)
    return x if dt == np.bool_ or dtypes.issubdtype(
        dt, np.unsignedinteger) else lax.abs(x)
Example #28
def copysign(x1, x2):
    x1, x2 = _promote_args_inexact("copysign", x1, x2)
    if dtypes.issubdtype(dtypes.dtype(x1), np.complexfloating):
        raise TypeError("copysign does not support complex-valued inputs")
    return _where(signbit(x2).astype(bool), -lax.abs(x1), lax.abs(x1))
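
A NumPy sketch of the signbit-based rule above: take |x1| and negate it wherever x2's sign bit is set, which (unlike an x2 < 0 test) also handles -0.0.

import numpy as np

x1 = np.array([1.0, 2.0, 3.0])
x2 = np.array([-5.0, -0.0, 4.0])
out = np.where(np.signbit(x2), -np.abs(x1), np.abs(x1))
assert np.array_equal(out, np.copysign(x1, x2))   # [-1., -2., 3.]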
Example #29
def divmod(x1, x2):
    x1, x2 = _promote_args("divmod", x1, x2)
    if dtypes.issubdtype(dtypes.dtype(x1), np.integer):
        return floor_divide(x1, x2), remainder(x1, x2)
    else:
        return _float_divmod(x1, x2)
Example #30
def fmod(x1, x2):
    _check_arraylike("fmod", x1, x2)
    if dtypes.issubdtype(dtypes.result_type(x1, x2), np.integer):
        x2 = _where(x2 == 0, lax_internal._ones(x2), x2)
    return lax.rem(*_promote_args("fmod", x1, x2))
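
A NumPy sketch of the zero-divisor guard above: swapping zeros in x2 for 1 keeps the integer remainder well defined (presumably avoiding division-by-zero in lax.rem); the affected results become 0.

import numpy as np

x1 = np.array([5, 7])
x2 = np.array([0, 2])
safe = np.where(x2 == 0, 1, x2)
print(np.fmod(x1, safe))   # [0 1]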