Example No. 1
def nanvar(a,
           axis: Optional[Union[int, Tuple[int, ...]]] = None,
           dtype=None,
           out=None,
           ddof=0,
           keepdims=False,
           where=None):
    _check_arraylike("nanvar", a)
    lax_internal._check_user_dtype_supported(dtype, "nanvar")
    if out is not None:
        raise NotImplementedError(
            "The 'out' argument to jnp.nanvar is not supported.")

    a_dtype, dtype = _var_promote_types(dtypes.dtype(a), dtype)
    a_mean = nanmean(a, axis, dtype=a_dtype, keepdims=True, where=where)

    centered = _where(lax_internal._isnan(a), 0,
                      a - a_mean)  # double-where trick for gradients.
    if dtypes.issubdtype(centered.dtype, np.complexfloating):
        centered = lax.real(lax.mul(centered, lax.conj(centered)))
    else:
        centered = lax.square(centered)

    normalizer = sum(lax_internal.bitwise_not(lax_internal._isnan(a)),
                     axis=axis,
                     keepdims=keepdims,
                     where=where)
    normalizer = normalizer - ddof
    normalizer_mask = lax.le(normalizer, 0)
    result = sum(centered, axis, keepdims=keepdims, where=where)
    result = _where(normalizer_mask, np.nan, result)
    divisor = _where(normalizer_mask, 1, normalizer)
    out = lax.div(result, lax.convert_element_type(divisor, result.dtype))
    return lax.convert_element_type(out, dtype)
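
A minimal usage sketch, assuming this snippet backs jnp.nanvar (the printed results in the comments are illustrative):

import jax.numpy as jnp

x = jnp.array([1.0, 2.0, jnp.nan, 4.0])
print(jnp.nanvar(x))          # variance computed over the non-NaN entries only
print(jnp.nanvar(x, ddof=3))  # normalizer - ddof <= 0, so the normalizer_mask path yields NaN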
Example No. 2
def _normalize_float(x):
  info = dtypes.finfo(dtypes.dtype(x))
  cond = lax.abs(x) < info.tiny
  x1 = _where(cond, x * _lax_const(x, 1 << info.nmant), x)
  x2 = _where(cond, lax.full_like(x, -info.nmant, dtype=np.int32), lax.full_like(x, 0, dtype=np.int32))
  int_type = _INT_DTYPES[info.bits]
  return lax.bitcast_convert_type(x1, int_type), x2
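
The helper pre-scales subnormal inputs by 2**nmant so the exponent field of the bit pattern is populated, and returns the compensating exponent (-nmant) for the caller (frexp/ldexp) to add back. A minimal NumPy sketch of the same pre-scaling idea, assuming float32 with 23 mantissa bits:

import numpy as np

x = np.array(1e-40, dtype=np.float32)   # subnormal: exponent bits are all zero
scaled = x * np.float32(1 << 23)        # now a normal number carrying the same mantissa
bits = scaled.view(np.int32)            # the exponent can safely be read from bits >> 23
correction = -23                        # must be added back by the caller (cf. x2 above)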
Example No. 3
def _normalize_float(x):
    info = dtypes.finfo(dtypes.dtype(x))
    int_type = _INT_DTYPES[info.bits]
    cond = lax.abs(x) < info.tiny
    x1 = _where(cond, x * _lax_const(x, 1 << info.nmant), x)
    x2 = _where(cond, int_type(-info.nmant), int_type(0))
    return lax.bitcast_convert_type(x1, int_type), x2
Example No. 4
def sinc(x):
    _check_arraylike("sinc", x)
    x, = _promote_dtypes_inexact(x)
    eq_zero = lax.eq(x, _lax_const(x, 0))
    pi_x = lax.mul(_lax_const(x, np.pi), x)
    safe_pi_x = _where(eq_zero, _lax_const(x, 1), pi_x)
    return _where(eq_zero, _sinc_maclaurin(0, pi_x),
                  lax.div(lax.sin(safe_pi_x), safe_pi_x))
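
A short check of the removable singularity, assuming this snippet is jnp.sinc:

import jax
import jax.numpy as jnp

print(jnp.sinc(0.0))            # 1.0; the x == 0 branch takes the Maclaurin-series value
print(jax.grad(jnp.sinc)(0.0))  # 0.0; the safe denominator keeps the gradient finite at zero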
Example No. 5
def _roots_with_zeros(p, num_leading_zeros):
    # Avoid lapack errors when p is all zero
    p = _where(len(p) == num_leading_zeros, 1.0, p)
    # Roll any leading zeros to the end & compute the roots
    roots = _roots_no_zeros(roll(p, -num_leading_zeros))
    # Sort zero roots to the end.
    roots = lax.sort_key_val(roots == 0, roots)[1]
    # Set roots associated with num_leading_zeros to NaN
    return _where(
        arange(roots.size) < roots.size - num_leading_zeros, roots,
        complex(np.nan, np.nan))
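
Assuming this helper is what jnp.roots calls when strip_zeros=False, leading zeros of the polynomial show up as NaN placeholders so the output size stays fixed:

import jax.numpy as jnp

p = jnp.array([0.0, 1.0, 0.0, -1.0])     # x**2 - 1 with one leading zero
print(jnp.roots(p, strip_zeros=False))   # two finite roots plus one NaN placeholder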
Example No. 6
def arccosh(x):
    # Note: arccosh is multi-valued for complex input, and lax.acosh uses a different
    # convention than np.arccosh.
    out = lax.acosh(*_promote_args_inexact("arccosh", x))
    if dtypes.issubdtype(out.dtype, np.complexfloating):
        out = _where(real(out) < 0, lax.neg(out), out)
    return out
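
The sign flip pins the complex branch so the result agrees with np.arccosh. A small comparison, assuming this snippet is jnp.arccosh:

import numpy as np
import jax.numpy as jnp

z = -2.0 + 0.0j
print(jnp.arccosh(z))   # real part flipped to be non-negative, matching NumPy's branch
print(np.arccosh(z))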
Example No. 7
    def _cumulative_reduction(a,
                              axis: Optional[Union[int, Tuple[int,
                                                              ...]]] = None,
                              dtype=None,
                              out=None):
        _check_arraylike(np_reduction.__name__, a)
        if out is not None:
            raise NotImplementedError(
                f"The 'out' argument to jnp.{np_reduction.__name__} "
                f"is not supported.")
        lax_internal._check_user_dtype_supported(dtype, np_reduction.__name__)

        if axis is None or _isscalar(a):
            a = lax.reshape(a, (np.size(a), ))
            axis = 0

        a_shape = list(np.shape(a))
        num_dims = len(a_shape)
        axis = _canonicalize_axis(axis, num_dims)

        if fill_nan:
            a = _where(lax_internal._isnan(a), _lax_const(a, fill_value), a)

        if not dtype and dtypes.dtype(a) == np.bool_:
            dtype = dtypes.canonicalize_dtype(dtypes.int_)
        if dtype:
            a = lax.convert_element_type(a, dtype)

        return reduction(a, axis)
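
Assuming this factory body backs jnp.cumsum / jnp.nancumsum and friends, the fill_nan branch is what lets the nan-variants ignore NaNs:

import jax.numpy as jnp

x = jnp.array([1.0, jnp.nan, 2.0])
print(jnp.cumsum(x))     # [1. nan nan] - NaNs propagate through the plain reduction
print(jnp.nancumsum(x))  # [1. 1. 3.]   - NaNs are replaced by the fill value 0 before reducing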
Example No. 8
def modf(x, out=None):
    _check_arraylike("modf", x)
    if out is not None:
        raise NotImplementedError(
            "The 'out' argument to jnp.modf is not supported.")
    whole = _where(lax.ge(x, lax_internal._zero(x)), floor(x), ceil(x))
    return x - whole, whole
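
A quick usage sketch, assuming this snippet is jnp.modf:

import jax.numpy as jnp

frac, whole = jnp.modf(jnp.array([2.75, -2.75]))
print(frac)   # [ 0.75 -0.75]
print(whole)  # [ 2.   -2.  ] - truncation toward zero via floor for positive, ceil for negative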
Example No. 9
def ldexp(x1, x2):
    _check_arraylike("ldexp", x1, x2)
    x1_dtype = dtypes.dtype(x1)
    x2_dtype = dtypes.dtype(x2)
    if (dtypes.issubdtype(x1_dtype, np.complexfloating)
            or dtypes.issubdtype(x2_dtype, np.inexact)):
        raise ValueError(
            f"ldexp not supported for input types {(x1_dtype, x2_dtype)}")

    x1, x2 = _promote_shapes("ldexp", x1, x2)

    dtype = dtypes.canonicalize_dtype(dtypes._to_inexact_dtype(x1_dtype))
    info = dtypes.finfo(dtype)
    int_type = _INT_DTYPES[info.bits]

    x1 = lax.convert_element_type(x1, dtype)
    x2 = lax.convert_element_type(x2, int_type)

    mask = (1 << info.nexp) - 1
    bias = ((1 << info.nexp) - 1) >> 1
    x, e = _normalize_float(x1)
    x2 += e + ((x >> info.nmant) & mask) - bias

    # find underflow/overflow before denormalization
    underflow_cond = x2 < -(bias + info.nmant)
    overflow_cond = x2 > bias

    m = lax.full_like(x, 1, dtype=dtype)

    # denormals
    cond = x2 < -bias + 1
    x2 = _where(cond, x2 + info.nmant, x2)
    m = _where(cond, m / (1 << info.nmant), m)

    x2 = lax.convert_element_type(x2, np.int32)
    x &= ~(mask << info.nmant)
    x |= ((lax.convert_element_type(x2, int_type) + bias) << info.nmant)

    x = lax.convert_element_type(m, dtype) * lax.bitcast_convert_type(x, dtype)

    # underflow
    x = _where(underflow_cond, lax.full_like(x, 0, dtype=dtype), x)
    # overflow
    x = _where(overflow_cond, lax.sign(x1) * lax.full_like(x, np.inf), x)
    # ldexp(x1, x2) = x1 for x1 = inf, -inf, nan, 0
    return _where(isinf(x1) | isnan(x1) | (x1 == 0), x1, x)
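
A usage sketch, assuming this snippet is jnp.ldexp; the final _where is what lets special values pass through untouched:

import jax.numpy as jnp

print(jnp.ldexp(1.5, 4))                          # 24.0  (1.5 * 2**4)
print(jnp.ldexp(jnp.inf, 3), jnp.ldexp(0.0, 3))   # inf 0.0 - inf/NaN/0 are returned unchanged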
Example No. 10
def sign(x):
    _check_arraylike('sign', x)
    dtype = dtypes.dtype(x)
    if dtypes.issubdtype(dtype, np.complexfloating):
        re = lax.real(x)
        return lax.complex(lax.sign(_where(re != 0, re, lax.imag(x))),
                           _constant_like(re, 0))
    return lax.sign(x)
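
A usage sketch, assuming this snippet is jnp.sign:

import jax.numpy as jnp

print(jnp.sign(jnp.array([-3.0, 0.0, 2.5])))   # [-1.  0.  1.]
# For complex input, the snippet above returns sign(real) as a complex number with zero
# imaginary part, falling back to sign(imag) when the real part is exactly zero.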
Example No. 11
def _power(x1, x2):
  x1, x2 = _promote_args("power", x1, x2)
  dtype = dtypes.dtype(x1)
  if not dtypes.issubdtype(dtype, np.integer):
    return lax.pow(x1, x2)

  # Integer power => use binary exponentiation.

  # TODO(phawkins): add integer pow support to XLA.
  bits = 6  # Anything more would overflow for any x1 > 1
  zero = _constant_like(x2, 0)
  one = _constant_like(x2, 1)
  # Initialize acc carefully such that pow(0, x2) is zero for x2 != 0
  acc = _where(lax.bitwise_and(lax.eq(x1, zero), lax.ne(x2, zero)), zero, one)
  for _ in range(bits):
    acc = _where(lax.bitwise_and(x2, one), lax.mul(acc, x1), acc)
    x1 = lax.mul(x1, x1)
    x2 = lax.shift_right_logical(x2, one)
  return acc
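
Integer inputs cannot go through lax.pow, so the loop performs binary exponentiation over at most six bits of the exponent. A usage sketch, assuming this helper backs jnp.power for integer dtypes:

import jax.numpy as jnp

print(jnp.power(3, 4))              # 81, computed by repeated squaring
print(jnp.power(jnp.arange(5), 3))  # [ 0  1  8 27 64]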
Example No. 12
def frexp(x):
    _check_arraylike("frexp", x)
    x, = _promote_dtypes_inexact(x)
    if dtypes.issubdtype(x.dtype, np.complexfloating):
        raise TypeError("frexp does not support complex-valued inputs")

    dtype = dtypes.dtype(x)
    info = dtypes.finfo(dtype)
    mask = (1 << info.nexp) - 1
    bias = ((1 << info.nexp) - 1) >> 1

    x1, x2 = _normalize_float(x)
    x2 += ((x1 >> info.nmant) & mask) - bias + 1
    x1 &= ~(mask << info.nmant)
    x1 |= (bias - 1) << info.nmant
    x1 = lax.bitcast_convert_type(x1, dtype)

    cond = isinf(x) | isnan(x) | (x == 0)
    x2 = _where(cond, lax_internal._zeros(x2), x2)
    return _where(cond, x, x1), lax.convert_element_type(x2, np.int32)
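
A usage sketch, assuming this snippet is jnp.frexp:

import jax.numpy as jnp

m, e = jnp.frexp(jnp.array([8.0, 0.0, jnp.inf]))
print(m)  # [0.5 0.  inf] - mantissa in [0.5, 1) for finite, nonzero input
print(e)  # [4 0 0]       - 8.0 == 0.5 * 2**4; zero/inf/NaN keep their value with e == 0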
Example No. 13
def floor_divide(x1, x2):
  x1, x2 = _promote_args("floor_divide", x1, x2)
  dtype = dtypes.dtype(x1)
  if dtypes.issubdtype(dtype, np.integer):
    quotient = lax.div(x1, x2)
    select = logical_and(lax.sign(x1) != lax.sign(x2), lax.rem(x1, x2) != 0)
    # TODO(mattjj): investigate why subtracting a scalar was causing promotion
    return _where(select, quotient - 1, quotient)
  elif dtypes.issubdtype(dtype, np.complexfloating):
    x1r = lax.real(x1)
    x1i = lax.imag(x1)
    x2r = lax.real(x2)
    x2i = lax.imag(x2)
    which = lax.ge(lax.abs(x2r), lax.abs(x2i))
    rat1 = _where(which, lax.full_like(x2i, 1), lax.div(x2r, x2i))
    rat2 = _where(which, lax.div(x2i, x2r), _lax_const(x2i, 1))
    out = lax.floor(lax.div(lax.add(lax.mul(x1r, rat1), lax.mul(x1i, rat2)),
                            lax.add(lax.mul(x2r, rat1), lax.mul(x2i, rat2))))
    return lax.convert_element_type(out, dtype)
  else:
    return _float_divmod(x1, x2)[0]
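
For integers, lax.div truncates toward zero, so the quotient is corrected whenever the signs differ and there is a nonzero remainder. A usage sketch, assuming this snippet is jnp.floor_divide:

import jax.numpy as jnp

print(jnp.floor_divide(7, 2))    # 3
print(jnp.floor_divide(-7, 2))   # -4 - rounding toward -inf, not toward zero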
Example No. 14
def ldexp(x1, x2):
  _check_arraylike("ldexp", x1, x2)
  dtype = dtypes.canonicalize_dtype(_result_dtype(np.ldexp, x1, x2))
  x1, x2 = _promote_shapes("ldexp", x1, x2)
  x1 = lax.convert_element_type(x1, dtype)

  info = dtypes.finfo(dtype)
  mask = (1 << info.nexp) - 1
  bias = ((1 << info.nexp) - 1) >> 1

  int_type = _INT_DTYPES[info.bits]

  x, e = _normalize_float(x1)
  x2 += e + ((x >> info.nmant) & mask) - bias

  # find underflow/overflow before denormalization
  underflow_cond = x2 < -(bias + info.nmant)
  overflow_cond = x2 > bias

  m = lax.full_like(x, 1, dtype=dtype)

  # denormals
  cond = x2 < -bias + 1
  x2 = _where(cond, x2 + info.nmant, x2)
  m = _where(cond, m / (1 << info.nmant), m)

  x2 = lax.convert_element_type(x2, np.int32)
  x &= ~(mask << info.nmant)
  x |= ((lax.convert_element_type(x2, int_type) + bias) << info.nmant)

  x = lax.convert_element_type(m, dtype) * lax.bitcast_convert_type(x, dtype)

  # underflow
  x = _where(underflow_cond, lax.full_like(x, 0, dtype=dtype), x)
  # overflow
  x = _where(overflow_cond, lax.sign(x1) * lax.full_like(x, np.inf), x)
  # ldexp(x1, x2) = x1 for x1 = inf, -inf, nan, 0
  return _where(isinf(x1) | isnan(x1) | (x1 == 0), x1, x)
Example No. 15
def _nan_reduction(a,
                   name,
                   jnp_reduction,
                   init_val,
                   nan_if_all_nan,
                   axis=None,
                   keepdims=None,
                   **kwargs):
    _check_arraylike(name, a)
    if not dtypes.issubdtype(dtypes.dtype(a), np.inexact):
        return jnp_reduction(a, axis=axis, keepdims=keepdims, **kwargs)

    out = jnp_reduction(_where(lax_internal._isnan(a),
                               _reduction_init_val(a, init_val), a),
                        axis=axis,
                        keepdims=keepdims,
                        **kwargs)
    if nan_if_all_nan:
        return _where(
            all(lax_internal._isnan(a), axis=axis, keepdims=keepdims),
            _lax_const(a, np.nan), out)
    else:
        return out
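
Assuming this helper backs jnp.nanmax / jnp.nanmin and similar reductions, NaNs are swapped for the reduction's identity before reducing, and an all-NaN input is mapped back to NaN:

import jax.numpy as jnp

x = jnp.array([jnp.nan, 1.0, 3.0])
print(jnp.nanmax(x))                     # 3.0 - NaNs replaced by the reduction's init value
print(jnp.nanmax(jnp.array([jnp.nan])))  # nan - all-NaN input yields NaN when nan_if_all_nan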
Example No. 16
def roots(p, *, strip_zeros=True):
    _check_arraylike("roots", p)
    p = atleast_1d(*_promote_dtypes_inexact(p))
    if p.ndim != 1:
        raise ValueError("Input must be a rank-1 array.")
    if p.size < 2:
        return array([], dtype=dtypes._to_complex_dtype(p.dtype))
    num_leading_zeros = _where(all(p == 0), len(p), argmin(p == 0))

    if strip_zeros:
        num_leading_zeros = core.concrete_or_error(
            int, num_leading_zeros,
            "The error occurred in the jnp.roots() function. To use this within a "
            "JIT-compiled context, pass strip_zeros=False, but be aware that leading zeros "
            "will be result in some returned roots being set to NaN.")
        return _roots_no_zeros(p[num_leading_zeros:])
    else:
        return _roots_with_zeros(p, num_leading_zeros)
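
The concrete_or_error call is why strip_zeros=True cannot run on a traced polynomial: the number of leading zeros must be known at trace time. A sketch, assuming this snippet is jnp.roots:

import jax
import jax.numpy as jnp

p = jnp.array([0.0, 1.0, 0.0, -1.0])    # x**2 - 1 with a leading zero
print(jnp.roots(p))                     # eager call: the leading zero is stripped first
# Under jit the leading-zero count is traced, so strip_zeros=False must be used,
# and the stripped positions come back as NaN roots instead:
print(jax.jit(lambda q: jnp.roots(q, strip_zeros=False))(p))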
Example No. 17
def _reduction(a,
               name,
               np_fun,
               op,
               init_val,
               has_identity=True,
               preproc=None,
               bool_op=None,
               upcast_f16_for_computation=False,
               axis=None,
               dtype=None,
               out=None,
               keepdims=False,
               initial=None,
               where_=None,
               parallel_reduce=None):
    bool_op = bool_op or op
    # Note: we must accept out=None as an argument, because numpy reductions delegate to
    # object methods. For example `np.sum(x)` will call `x.sum()` if the `sum()` method
    # exists, passing along all its arguments.
    if out is not None:
        raise NotImplementedError(
            f"The 'out' argument to jnp.{name} is not supported.")
    _check_arraylike(name, a)
    lax_internal._check_user_dtype_supported(dtype, name)
    axis = core.concrete_or_error(None, axis,
                                  f"axis argument to jnp.{name}().")

    if initial is None and not has_identity and where_ is not None:
        raise ValueError(
            f"reduction operation {name} does not have an identity, so to use a "
            f"where mask one has to specify 'initial'")

    a = a if isinstance(a, ndarray) else _asarray(a)
    a = preproc(a) if preproc else a
    pos_dims, dims = _reduction_dims(a, axis)

    if initial is None and not has_identity:
        shape = np.shape(a)
        if not _all(core.greater_equal_dim(shape[d], 1) for d in pos_dims):
            raise ValueError(
                f"zero-size array to reduction operation {name} which has no identity"
            )

    result_dtype = dtypes.canonicalize_dtype(
        dtype or dtypes.dtype(np_fun(np.ones((), dtype=dtypes.dtype(a)))))
    if upcast_f16_for_computation and dtypes.issubdtype(
            result_dtype, np.inexact):
        computation_dtype = _upcast_f16(result_dtype)
    else:
        computation_dtype = result_dtype
    a = lax.convert_element_type(a, computation_dtype)
    op = op if computation_dtype != np.bool_ else bool_op
    # NB: in XLA, init_val must be an identity for the op, so the user-specified
    # initial value must be applied afterward.
    init_val = _reduction_init_val(a, init_val)
    if where_ is not None:
        a = _where(where_, a, init_val)
    if pos_dims is not dims:
        if parallel_reduce is None:
            raise NotImplementedError(
                f"Named reductions not implemented for jnp.{name}()")
        result = parallel_reduce(a, dims)
    else:
        result = lax.reduce(a, init_val, op, dims)
    if initial is not None:
        result = op(lax.convert_element_type(initial, a.dtype), result)
    if keepdims:
        result = lax.expand_dims(result, pos_dims)
    return lax.convert_element_type(result, dtype or result_dtype)
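
The where_ mask substitutes the identity value before lax.reduce runs, and any user-supplied initial is folded in afterwards. A sketch, assuming this machinery drives jnp.sum / jnp.min:

import jax.numpy as jnp

x = jnp.array([[1, 2], [3, 4]])
mask = jnp.array([[True, False], [True, True]])
print(jnp.sum(x, where=mask))         # 8 - masked-out entries are replaced by the identity 0
print(jnp.min(x, axis=0, initial=0))  # [0 0] - 'initial' is combined with the result afterwards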
Example No. 18
def heaviside(x1, x2):
    _check_arraylike("heaviside", x1, x2)
    x1, x2 = _promote_dtypes_inexact(x1, x2)
    zero = _lax_const(x1, 0)
    return _where(lax.lt(x1, zero), zero,
                  _where(lax.gt(x1, zero), _lax_const(x1, 1), x2))
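
A usage sketch, assuming this snippet is jnp.heaviside:

import jax.numpy as jnp

x = jnp.array([-1.5, 0.0, 2.0])
print(jnp.heaviside(x, 0.5))   # [0.  0.5 1. ] - the second argument supplies the value at x == 0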
Example No. 19
def fmod(x1, x2):
    _check_arraylike("fmod", x1, x2)
    if dtypes.issubdtype(dtypes.result_type(x1, x2), np.integer):
        x2 = _where(x2 == 0, lax_internal._ones(x2), x2)
    return lax.rem(*_promote_args("fmod", x1, x2))
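
A usage sketch, assuming this snippet is jnp.fmod:

import jax.numpy as jnp

print(jnp.fmod(-7, 3))   # -1 - fmod takes the sign of the dividend, unlike jnp.mod
print(jnp.fmod(7, -3))   # 1
# For integer dtypes a zero divisor is first replaced by 1, so lax.rem never divides by zero.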
Example No. 20
def copysign(x1, x2):
    x1, x2 = _promote_args_inexact("copysign", x1, x2)
    if dtypes.issubdtype(dtypes.dtype(x1), np.complexfloating):
        raise TypeError("copysign does not support complex-valued inputs")
    return _where(signbit(x2).astype(bool), -lax.abs(x1), lax.abs(x1))
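
A usage sketch, assuming this snippet is jnp.copysign:

import jax.numpy as jnp

print(jnp.copysign(jnp.array([1.0, -2.0, 3.0]), -1.0))   # [-1. -2. -3.]
print(jnp.copysign(2.0, -0.0))                           # -2.0 - signbit also respects -0.0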