Exemplo n.º 1
0
 def fn(x1, x2):
   """Element-wise comparison of ``x1`` and ``x2`` using ``lax_fn``.

   Both operands are first promoted to a common dtype. Complex operands are
   compared lexicographically on the (real, imag) pair: the real parts decide
   unless they are equal, in which case the imaginary parts decide.
   """
   x1, x2 = _promote_args(numpy_fn.__name__, x1, x2)
   if not dtypes.issubdtype(dtypes.dtype(x1), np.complexfloating):
     return lax_fn(x1, x2)
   real1, real2 = lax.real(x1), lax.real(x2)
   # Tie-breaker on the imaginary parts, used only where the real parts match.
   imag_order = lax_fn(lax.imag(x1), lax.imag(x2))
   real_order = lax_fn(real1, real2)
   return lax.select(lax.eq(real1, real2), imag_order, real_order)
Exemplo n.º 2
0
def nanvar(a,
           axis: Optional[Union[int, Tuple[int, ...]]] = None,
           dtype=None,
           out=None,
           ddof=0,
           keepdims=False,
           where=None):
    """Variance of ``a`` along ``axis``, ignoring NaN entries.

    Args:
      a: array-like input.
      axis: axis or axes to reduce over; ``None`` reduces over all axes.
      dtype: requested dtype; combined with ``a``'s dtype by
        ``_var_promote_types`` to choose the computation dtype and the
        output dtype.
      out: unsupported; must be ``None``.
      ddof: delta degrees of freedom subtracted from the non-NaN count.
      keepdims: keep reduced axes as size-1 dimensions.
      where: optional mask of elements to include, forwarded to the
        mean and sum reductions.

    Returns:
      The variance converted to the output dtype. Positions where the count
      of non-NaN elements minus ``ddof`` is <= 0 are NaN.

    Raises:
      NotImplementedError: if ``out`` is not ``None``.
    """
    _check_arraylike("nanvar", a)
    lax_internal._check_user_dtype_supported(dtype, "nanvar")
    if out is not None:
        raise NotImplementedError(
            "The 'out' argument to jnp.nanvar is not supported.")

    computation_dtype, dtype = _var_promote_types(dtypes.dtype(a), dtype)
    # Cast the input itself (not only the mean) so that centering, squaring
    # and summation all run in the computation dtype instead of whatever
    # ``a - a_mean`` would promote to. convert_element_type also accepts
    # Python scalars, which have no ``astype`` method.
    a = lax.convert_element_type(a, computation_dtype)
    a_mean = nanmean(a, axis, dtype=computation_dtype, keepdims=True,
                     where=where)

    centered = _where(lax_internal._isnan(a), 0,
                      a - a_mean)  # double-where trick for gradients.
    if dtypes.issubdtype(centered.dtype, np.complexfloating):
        # Squared magnitude: real(z * conj(z)).
        centered = lax.real(lax.mul(centered, lax.conj(centered)))
    else:
        centered = lax.square(centered)

    # Number of non-NaN entries (honouring ``where``), minus ddof.
    normalizer = sum(lax_internal.bitwise_not(lax_internal._isnan(a)),
                     axis=axis,
                     keepdims=keepdims,
                     where=where)
    normalizer = normalizer - ddof
    normalizer_mask = lax.le(normalizer, 0)
    result = sum(centered, axis, keepdims=keepdims, where=where)
    # Non-positive degrees of freedom produce NaN; divide by 1 at those
    # positions so the division itself is well defined.
    result = _where(normalizer_mask, np.nan, result)
    divisor = _where(normalizer_mask, 1, normalizer)
    out = lax.div(result, lax.convert_element_type(divisor, result.dtype))
    return lax.convert_element_type(out, dtype)
Exemplo n.º 3
0
def sign(x):
    """Element-wise sign of ``x``.

    For complex input the result is ``sign(real(x)) + 0j`` wherever the real
    part is nonzero, and ``sign(imag(x)) + 0j`` where the real part is zero.
    Other dtypes use ``lax.sign`` directly.
    """
    _check_arraylike('sign', x)
    if not dtypes.issubdtype(dtypes.dtype(x), np.complexfloating):
        return lax.sign(x)
    real_part = lax.real(x)
    # Fall back to the imaginary component only where the real one is zero.
    source = _where(real_part != 0, real_part, lax.imag(x))
    return lax.complex(lax.sign(source), _constant_like(real_part, 0))
Exemplo n.º 4
0
def rint(x):
    """Round ``x`` element-wise to the nearest integer, ties to even.

    Boolean and integer inputs already hold integral values, so they are only
    converted to ``dtypes.float_``. Complex inputs are rounded component-wise.

    Args:
      x: array-like input.

    Returns:
      The rounded values: ``dtypes.float_`` for boolean/integer input,
      otherwise the input's floating or complex dtype.
    """
    _check_arraylike('rint', x)
    dtype = dtypes.dtype(x)
    # bool is not a subtype of np.integer, and lax.round only accepts
    # floating-point operands, so booleans must take the conversion path too.
    if dtype == np.bool_ or dtypes.issubdtype(dtype, np.integer):
        return lax.convert_element_type(x, dtypes.float_)
    if dtypes.issubdtype(dtype, np.complexfloating):
        return lax.complex(rint(lax.real(x)), rint(lax.imag(x)))
    return lax.round(x, lax.RoundingMethod.TO_NEAREST_EVEN)
Exemplo n.º 5
0
def isinf(x):
    """Element-wise test for positive or negative infinity.

    A complex value counts as infinite when either its real or its imaginary
    component is infinite. Any other non-floating dtype (e.g. integer or
    boolean) yields an all-False boolean array of the same shape.
    """
    _check_arraylike("isinf", x)
    dtype = dtypes.dtype(x)

    def _abs_is_inf(v):
        # |v| == inf covers both +inf and -inf.
        return lax.eq(lax.abs(v), _constant_like(v, np.inf))

    if dtypes.issubdtype(dtype, np.complexfloating):
        return lax.bitwise_or(_abs_is_inf(lax.real(x)),
                              _abs_is_inf(lax.imag(x)))
    if dtypes.issubdtype(dtype, np.floating):
        return _abs_is_inf(x)
    return lax.full_like(x, False, dtype=np.bool_)
Exemplo n.º 6
0
def logaddexp(x1, x2):
  """Compute ``log(exp(x1) + exp(x2))`` element-wise in a stable form.

  Inputs are promoted to a common inexact dtype. Real inputs use
  ``max(x1, x2) + log1p(exp(-|x1 - x2|))``. Complex inputs shift by
  ``lax.max(x1, x2)`` as well, and the imaginary part of the result is wrapped
  with ``_wrap_between(..., np.pi)``.
  """
  x1, x2 = _promote_args_inexact("logaddexp", x1, x2)
  amax = lax.max(x1, x2)
  if not dtypes.issubdtype(x1.dtype, np.floating):
    # x1 + x2 - 2*amax == (x1 - amax) + (x2 - amax).
    shifted = lax.sub(lax.add(x1, x2), lax.mul(amax, _constant_like(amax, 2)))
    total = lax.add(amax, lax.log1p(lax.exp(shifted)))
    return lax.complex(lax.real(total), _wrap_between(lax.imag(total), np.pi))
  diff = lax.sub(x1, x2)
  stable = lax.add(amax, lax.log1p(lax.exp(lax.neg(lax.abs(diff)))))
  # diff is NaN for NaN inputs or for infinities of the same sign; plain
  # x1 + x2 gives the right answer in those cases.
  return lax.select(lax_internal._isnan(diff), lax.add(x1, x2), stable)
Exemplo n.º 7
0
def floor_divide(x1, x2):
  """Element-wise floor division of ``x1`` by ``x2`` after dtype promotion.

  * Integer dtypes: ``lax.div`` (which truncates) is stepped down by one
    wherever the operand signs differ and the remainder is nonzero.
  * Complex dtypes: the floor of the real part of ``x1 / x2``, converted back
    to the complex dtype.
  * Everything else: the quotient returned by ``_float_divmod``.
  """
  x1, x2 = _promote_args("floor_divide", x1, x2)
  dtype = dtypes.dtype(x1)

  if dtypes.issubdtype(dtype, np.integer):
    truncated = lax.div(x1, x2)
    needs_step_down = logical_and(lax.sign(x1) != lax.sign(x2),
                                  lax.rem(x1, x2) != 0)
    # TODO(mattjj): investigate why subtracting a scalar was causing promotion
    return _where(needs_step_down, truncated - 1, truncated)

  if not dtypes.issubdtype(dtype, np.complexfloating):
    return _float_divmod(x1, x2)[0]

  num_re, num_im = lax.real(x1), lax.imag(x1)
  den_re, den_im = lax.real(x2), lax.imag(x2)
  # Re(x1 / x2) = (num_re*den_re + num_im*den_im) / (den_re**2 + den_im**2).
  # Numerator and denominator are both divided by whichever component of x2
  # has the larger magnitude, so the squared norm is never formed directly.
  real_dominant = lax.ge(lax.abs(den_re), lax.abs(den_im))
  scale_re = _where(real_dominant, lax.full_like(den_im, 1),
                    lax.div(den_re, den_im))
  scale_im = _where(real_dominant, lax.div(den_im, den_re),
                    _lax_const(den_im, 1))
  top = lax.add(lax.mul(num_re, scale_re), lax.mul(num_im, scale_im))
  bottom = lax.add(lax.mul(den_re, scale_re), lax.mul(den_im, scale_im))
  return lax.convert_element_type(lax.floor(lax.div(top, bottom)), dtype)
Exemplo n.º 8
0
def _var(a,
         axis: Optional[Union[int, Tuple[int, ...]]] = None,
         dtype=None,
         out=None,
         ddof=0,
         keepdims=False,
         *,
         where=None):
    """Variance of ``a`` along ``axis`` (implementation behind ``jnp.var``).

    Args:
      a: array-like input.
      axis: axis or axes to reduce over; ``None`` reduces over all axes.
      dtype: requested dtype; combined with ``a``'s dtype by
        ``_var_promote_types`` to choose the computation and output dtypes.
      out: unsupported; must be ``None``.
      ddof: delta degrees of freedom subtracted from the element count.
      keepdims: keep reduced axes as size-1 dimensions.
      where: optional mask of elements to include.

    Returns:
      The variance converted to the output dtype. As in NumPy, the degrees
      of freedom are clamped at zero, so ``ddof`` >= count yields inf/nan.

    Raises:
      NotImplementedError: if ``out`` is not ``None``.
    """
    _check_arraylike("var", a)
    lax_internal._check_user_dtype_supported(dtype, "var")
    if out is not None:
        raise NotImplementedError(
            "The 'out' argument to jnp.var is not supported.")

    computation_dtype, dtype = _var_promote_types(dtypes.dtype(a), dtype)
    # convert_element_type (unlike ``.astype``) also accepts Python scalars,
    # which _check_arraylike lets through.
    a = lax.convert_element_type(a, computation_dtype)
    a_mean = mean(a, axis, dtype=computation_dtype, keepdims=True, where=where)
    centered = lax.sub(a, a_mean)
    if dtypes.issubdtype(centered.dtype, np.complexfloating):
        # Squared magnitude: real(z * conj(z)).
        centered = lax.real(lax.mul(centered, lax.conj(centered)))
    else:
        centered = lax.square(centered)

    if where is None:
        # Unmasked: the count is the number of reduced elements.
        if axis is None:
            normalizer = core.dimension_as_value(np.size(a))
        else:
            normalizer = core.dimension_as_value(_axis_size(a, axis))
    else:
        normalizer = sum(_broadcast_to(where, np.shape(a)),
                         axis,
                         dtype=dtype,
                         keepdims=keepdims)
    normalizer = normalizer - ddof

    result = sum(centered, axis, keepdims=keepdims, where=where)
    divisor = lax.convert_element_type(normalizer, result.dtype)
    # Clamp the degrees of freedom at zero, as NumPy does
    # (max(count - ddof, 0)); otherwise ddof > count would divide by a
    # negative number and report a negative variance.
    divisor = lax.max(divisor, lax.full_like(divisor, 0))
    out = lax.div(result, divisor)
    return lax.convert_element_type(out, dtype)
Exemplo n.º 9
0
def real(val):
    """Return the real part of ``val``.

    Complex inputs go through ``lax.real``; any other input is returned
    unchanged.
    """
    _check_arraylike("real", val)
    if not np.iscomplexobj(val):
        return val
    return lax.real(val)
Exemplo n.º 10
0
def angle(x, deg=False):
    """Return the counter-clockwise angle of ``x`` in the complex plane.

    Computed as ``atan2(imag(x), real(x))``. Real inputs have an implicit zero
    imaginary part, so negative values map to pi (as in ``np.angle``) rather
    than 0. Integer and boolean inputs are converted to a floating dtype
    first, because ``lax.atan2`` only accepts floating-point operands.

    Args:
      x: array-like input.
      deg: if True, return the angle in degrees instead of radians.

    Returns:
      The angle(s) of ``x``, in radians unless ``deg`` is True.
    """
    import numpy as np  # local import keeps this block self-contained

    if iscomplexobj(x):
        result = lax.atan2(lax.imag(x), lax.real(x))
    else:
        zeros = zeros_like(x)
        if zeros.dtype == np.bool_ or np.issubdtype(zeros.dtype, np.integer):
            x = lax.convert_element_type(x, float)
            zeros = zeros_like(x)
        # atan2(0, x) is 0 for x >= 0 and pi for x < 0.
        result = lax.atan2(zeros, x)
    if deg:
        result = result * (180.0 / np.pi)
    return result
Exemplo n.º 11
0
def real(x):
    """Return the real component of ``x``, or ``x`` itself when not complex."""
    if not iscomplexobj(x):
        return x
    return lax.real(x)