The examples below show real-world uses of lax.abs, drawn from JAX and JAX-based libraries.

Example #1
def hypot(x1, x2):
  _check_arraylike("hypot", x1, x2)
  x1, x2 = _promote_dtypes_inexact(x1, x2)
  x1 = lax.abs(x1)
  x2 = lax.abs(x2)
  x1, x2 = maximum(x1, x2), minimum(x1, x2)
  # Factor out the larger magnitude so squaring the ratio cannot overflow;
  # substitute 1 for a zero denominator to avoid 0/0 -> NaN when both are zero.
  safe_x1 = lax.select(x1 == 0, lax_internal._ones(x1), x1)
  return lax.select(x1 == 0, x1,
                    x1 * lax.sqrt(1 + lax.square(lax.div(x2, safe_x1))))
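Factoring out the larger magnitude keeps the squared ratio at most 1, so the result overflows only when the true answer does. A quick check through jnp.hypot, the public entry point built on this helper:

import jax.numpy as jnp

x = jnp.array(1e38, dtype=jnp.float32)
print(jnp.sqrt(x * x + x * x))  # inf: the squares overflow float32
print(jnp.hypot(x, x))          # ~1.41e+38: the scaled form stays in range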
Example #2
def _normalize(x, y):
    z = jnp.where(jnp.isinf(x), x, x + y)
    zz = jnp.where(
        lax.abs(x) > lax.abs(y),
        x - z + y,
        y - z + x,
    )
    return z, jnp.where(jnp.isinf(z), 0, zz)
Example #3
def _add2(x, y):
    (x, xx), (y, yy) = x, y
    r = x + y
    s = jnp.where(
        lax.abs(x) > lax.abs(y),
        x - r + y + yy + xx,
        y - r + x + xx + yy,
    )
    z = r + s
    zz = r - z + s
    return (z, zz)
Example #4
def _sub2(x, y):
    (x, xx), (y, yy) = x, y
    r = x - y
    s = jnp.where(
        lax.abs(x) > lax.abs(y),
        x - r - y - yy + xx,
        -y - r + x + xx - yy,
    )
    z = r + s
    zz = r - z + s
    return (z, zz)
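Examples #2-#4 are the error-free transformations behind double-double arithmetic: the head z carries the rounded result and the tail zz the rounding error, so z + zz represents the sum exactly. A scalar sketch of the same two-sum algebra (tail terms omitted):

def two_sum(x, y):
    # Same magnitude-ordered branch as _add2 above: the larger operand
    # recovers the rounding error of x + y exactly.
    z = x + y
    zz = (x - z + y) if abs(x) > abs(y) else (y - z + x)
    return z, zz

z, zz = two_sum(1.0, 1e-20)
print(z, zz)  # 1.0 1e-20: the addend lost to rounding survives in the tail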
Example #5
def isinf(x):
    _check_arraylike("isinf", x)
    dtype = dtypes.dtype(x)
    if dtypes.issubdtype(dtype, np.floating):
        return lax.eq(lax.abs(x), _constant_like(x, np.inf))
    elif dtypes.issubdtype(dtype, np.complexfloating):
        re = lax.real(x)
        im = lax.imag(x)
        return lax.bitwise_or(lax.eq(lax.abs(re), _constant_like(re, np.inf)),
                              lax.eq(lax.abs(im), _constant_like(im, np.inf)))
    else:
        return lax.full_like(x, False, dtype=np.bool_)
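The dtype dispatch means non-float inputs short-circuit to False, and a complex value counts as infinite if either component is:

import jax.numpy as jnp

print(jnp.isinf(jnp.array([1.0, jnp.inf, -jnp.inf, jnp.nan])))
# [False  True  True False]: NaN is not infinite
print(jnp.isinf(jnp.array([1, 2, 3])))
# [False False False]: integer dtypes take the all-False branch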
Example #6
def _normalize_float(x):
  info = dtypes.finfo(dtypes.dtype(x))
  cond = lax.abs(x) < info.tiny
  x1 = _where(cond, x * _lax_const(x, 1 << info.nmant), x)
  x2 = _where(cond, lax.full_like(x, -info.nmant, dtype=np.int32), lax.full_like(x, 0, dtype=np.int32))
  int_type = _INT_DTYPES[info.bits]
  return lax.bitcast_convert_type(x1, int_type), x2
Example #7
def logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False):
    if b is not None:
        a, b = jnp.broadcast_arrays(a, b)
    dims = _reduction_dims(a, axis)
    dimadd = lambda x: lax.expand_dims(x, dims)
    amax = lax.reduce(a, _constant_like(a, -np.inf), lax.max, dims)
    amax = lax.stop_gradient(
        lax.select(lax.is_finite(amax), amax, lax.full_like(amax, 0)))
    amax_singletons = dimadd(amax)
    if b is None:
        out = lax.add(
            lax.log(
                lax.reduce(lax.exp(lax.sub(a, amax_singletons)),
                           _constant_like(a, 0), lax.add, dims)), amax)
        sign = jnp.where(jnp.isnan(out), np.nan, 1.0).astype(out.dtype)
        sign = jnp.where(out == -np.inf, 0.0, sign)
    else:
        sumexp = lax.reduce(lax.mul(lax.exp(lax.sub(a, amax_singletons)), b),
                            _constant_like(a, 0), lax.add, dims)
        sign = lax.stop_gradient(lax.sign(sumexp))
        out = lax.add(lax.log(lax.abs(sumexp)), amax)
    if return_sign:
        return (dimadd(out), dimadd(sign)) if keepdims else (out, sign)
    if b is not None:
        out = jnp.where(sign < 0, np.nan, out)
    return dimadd(out) if keepdims else out
Example #8
def logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False):
    if b is not None:
        a, b = _promote_args_inexact("logsumexp", a, b)
        a = jnp.where(b != 0, a, -jnp.inf)
    pos_dims, dims = _reduction_dims(a, axis)
    amax = jnp.max(a, axis=dims, keepdims=keepdims)
    amax = lax.stop_gradient(
        lax.select(lax.is_finite(amax), amax, lax.full_like(amax, 0)))
    amax_with_dims = amax if keepdims else lax.expand_dims(amax, pos_dims)
    if b is None:
        out = lax.add(
            lax.log(
                jnp.sum(lax.exp(lax.sub(a, amax_with_dims)),
                        axis=dims,
                        keepdims=keepdims)), amax)
        sign = jnp.where(jnp.isnan(out), np.nan, 1.0).astype(out.dtype)
        sign = jnp.where(out == -np.inf, 0.0, sign)
    else:
        sumexp = jnp.sum(lax.mul(lax.exp(lax.sub(a, amax_with_dims)), b),
                         axis=dims,
                         keepdims=keepdims)
        sign = lax.stop_gradient(lax.sign(sumexp))
        out = lax.add(lax.log(lax.abs(sumexp)), amax)
    if return_sign:
        return (out, sign)
    if b is not None:
        out = jnp.where(sign < 0, np.nan, out)
    return out
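Examples #7 and #8 are successive revisions of the same max-shift trick: subtracting amax before exponentiating keeps every intermediate finite. A quick check via the public jax.scipy.special.logsumexp:

import jax.numpy as jnp
from jax.scipy.special import logsumexp

a = jnp.array([1000.0, 1000.0])
print(jnp.log(jnp.sum(jnp.exp(a))))  # inf: exp overflows without the shift
print(logsumexp(a))                  # ~1000.693 = 1000 + log(2)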
Example #9
def svd(a,
        full_matrices: bool = True,
        compute_uv: bool = True,
        hermitian: bool = False):
    a, = _promote_dtypes_inexact(jnp.asarray(a))
    if hermitian:
        # lax_linalg.eigh returns (eigenvectors, eigenvalues), so w holds the
        # vectors and v the values; singular values are the absolute eigenvalues.
        w, v = lax_linalg.eigh(a)
        s = lax.abs(v)
        if compute_uv:
            sign = lax.sign(v)
            idxs = lax.broadcasted_iota(np.int64,
                                        s.shape,
                                        dimension=s.ndim - 1)
            s, idxs, sign = lax.sort((s, idxs, sign), dimension=-1, num_keys=1)
            s = lax.rev(s, dimensions=[s.ndim - 1])
            idxs = lax.rev(idxs, dimensions=[s.ndim - 1])
            sign = lax.rev(sign, dimensions=[s.ndim - 1])
            u = jnp.take_along_axis(w, idxs[..., None, :], axis=-1)
            vh = _H(u * sign[..., None, :])
            return u, s, vh
        else:
            return lax.rev(lax.sort(s, dimension=-1), dimensions=[s.ndim - 1])

    return lax_linalg.svd(a,
                          full_matrices=full_matrices,
                          compute_uv=compute_uv)
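For a Hermitian matrix the singular values are the absolute eigenvalues, which is what the eigh branch exploits to skip a full SVD. A sketch via the public jnp.linalg.svd:

import jax.numpy as jnp

a = jnp.array([[2.0, 1.0], [1.0, 2.0]])  # eigenvalues 3 and 1
print(jnp.linalg.svd(a, hermitian=True, compute_uv=False))  # [3. 1.]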
Example #10
def _normalize_float(x):
    info = dtypes.finfo(dtypes.dtype(x))
    int_type = _INT_DTYPES[info.bits]
    cond = lax.abs(x) < info.tiny
    x1 = _where(cond, x * _lax_const(x, 1 << info.nmant), x)
    x2 = _where(cond, int_type(-info.nmant), int_type(0))
    return lax.bitcast_convert_type(x1, int_type), x2
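Both _normalize_float variants (this one and Example #6) pre-scale subnormal inputs by 2**nmant so the bit-level code that consumes the result sees a normal number; the returned -nmant offset compensates. A NumPy sketch of the idea for float32, where nmant is 23:

import numpy as np

tiny = np.float32(1e-40)             # subnormal: below the ~1.2e-38 normal minimum
scaled = tiny * np.float32(1 << 23)  # exact power-of-two scaling makes it normal
print(np.frexp(tiny)[1], np.frexp(scaled)[1] - 23)  # same exponent after correction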
Example #11
def var(a, axis=None, keepdims=False, ddof=0):
    if ddof != 0:
        raise NotImplementedError("Only implemented for ddof=0.")
    centered = subtract(a, mean(a, axis, keepdims=True))
    if iscomplexobj(centered):
        centered = lax.abs(centered)
    return mean(lax.mul(centered, centered), axis, keepdims=keepdims)
Example #12
def _logaddexp(x1, x2):
  """
  Logaddexp while ignoring the custom_jvp rule.
  """
  amax = lax.max(x1, x2)
  delta = lax.sub(x1, x2)
  return lax.select(jnp.isnan(delta),
                    lax.add(x1, x2),  # NaNs or infinities of the same sign.
                    lax.add(amax, lax.log1p(lax.exp(-lax.abs(delta)))))
Example #13
def floor_divide(x1, x2):
  x1, x2 = _promote_args("floor_divide", x1, x2)
  dtype = dtypes.dtype(x1)
  if dtypes.issubdtype(dtype, np.integer):
    quotient = lax.div(x1, x2)
    select = logical_and(lax.sign(x1) != lax.sign(x2), lax.rem(x1, x2) != 0)
    # TODO(mattjj): investigate why subtracting a scalar was causing promotion
    return _where(select, quotient - 1, quotient)
  elif dtypes.issubdtype(dtype, np.complexfloating):
    x1r = lax.real(x1)
    x1i = lax.imag(x1)
    x2r = lax.real(x2)
    x2i = lax.imag(x2)
    which = lax.ge(lax.abs(x2r), lax.abs(x2i))
    rat1 = _where(which, lax.full_like(x2i, 1), lax.div(x2r, x2i))
    rat2 = _where(which, lax.div(x2i, x2r), _lax_const(x2i, 1))
    out = lax.floor(lax.div(lax.add(lax.mul(x1r, rat1), lax.mul(x1i, rat2)),
                            lax.add(lax.mul(x2r, rat1), lax.mul(x2i, rat2))))
    return lax.convert_element_type(out, dtype)
  else:
    return _float_divmod(x1, x2)[0]
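The sign-and-remainder correction in the integer branch turns lax.div's truncation into floor semantics, matching NumPy:

import jax.numpy as jnp
from jax import lax

print(jnp.floor_divide(-7, 2))               # -4: rounds toward -inf
print(lax.div(jnp.int32(-7), jnp.int32(2)))  # -3: bare lax.div truncates toward zero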
Example #14
def logaddexp(x1, x2):
  x1, x2 = _promote_args_inexact("logaddexp", x1, x2)
  amax = lax.max(x1, x2)
  if dtypes.issubdtype(x1.dtype, np.floating):
    delta = lax.sub(x1, x2)
    return lax.select(lax_internal._isnan(delta),
                      lax.add(x1, x2),  # NaNs or infinities of the same sign.
                      lax.add(amax, lax.log1p(lax.exp(lax.neg(lax.abs(delta))))))
  else:
    delta = lax.sub(lax.add(x1, x2), lax.mul(amax, _constant_like(amax, 2)))
    out = lax.add(amax, lax.log1p(lax.exp(delta)))
    return lax.complex(lax.real(out), _wrap_between(lax.imag(out), np.pi))
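Writing the result as max + log1p(exp(-|delta|)) keeps the exponent non-positive, and the NaN test on delta catches equal infinities (inf - inf is NaN):

import jax.numpy as jnp

print(jnp.logaddexp(1000.0, 1000.0))      # ~1000.693: no overflow despite exp(1000)
print(jnp.logaddexp(-jnp.inf, -jnp.inf))  # -inf: handled by the NaN-delta branch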
Example #15
def _ndtr(x):
    """Implements ndtr core logic."""
    dtype = lax.dtype(x).type
    half_sqrt_2 = dtype(0.5) * np.sqrt(2., dtype=dtype)
    w = x * half_sqrt_2
    z = lax.abs(w)
    y = lax.select(
        lax.lt(z, half_sqrt_2),
        dtype(1.) + lax.erf(w),
        lax.select(lax.gt(w, dtype(0.)),
                   dtype(2.) - lax.erfc(z), lax.erfc(z)))
    return dtype(0.5) * y
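Branching on z < half_sqrt_2 uses erf near zero and erfc in the tails, where computing 1 - erf(w) directly would cancel catastrophically. The public wrapper is jax.scipy.special.ndtr:

from jax.scipy.special import ndtr

print(ndtr(0.0))    # 0.5
print(ndtr(-10.0))  # ~7.6e-24: the erfc branch keeps the far tail from rounding to 0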
Example #16
def logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False):
    if b is not None:
        a, b = _promote_args_inexact("logsumexp", a, b)
        a = jnp.where(b != 0, a, -jnp.inf)
    else:
        a, = _promote_args_inexact("logsumexp", a)
    pos_dims, dims = _reduction_dims(a, axis)
    amax = jnp.max(a, axis=dims, keepdims=keepdims)
    amax = lax.stop_gradient(
        lax.select(jnp.isfinite(amax), amax, lax.full_like(amax, 0)))
    amax_with_dims = amax if keepdims else lax.expand_dims(amax, pos_dims)
    # fast path if the result cannot be negative.
    if b is None and not np.issubdtype(a.dtype, np.complexfloating):
        out = lax.add(
            lax.log(
                jnp.sum(lax.exp(lax.sub(a, amax_with_dims)),
                        axis=dims,
                        keepdims=keepdims)), amax)
        sign = jnp.where(jnp.isnan(out), out, 1.0)
        sign = jnp.where(jnp.isneginf(out), 0.0, sign).astype(out.dtype)
    else:
        expsub = lax.exp(lax.sub(a, amax_with_dims))
        if b is not None:
            expsub = lax.mul(expsub, b)
        sumexp = jnp.sum(expsub, axis=dims, keepdims=keepdims)

        sign = lax.stop_gradient(jnp.sign(sumexp))
        if np.issubdtype(sumexp.dtype, np.complexfloating):
            if return_sign:
                sumexp = sign * sumexp
            out = lax.add(lax.log(sumexp), amax)
        else:
            out = lax.add(lax.log(lax.abs(sumexp)), amax)
    if return_sign:
        return (out, sign)
    if b is not None:
        if not np.issubdtype(out.dtype, np.complexfloating):
            with jax.debug_nans(False):
                out = jnp.where(sign < 0, jnp.array(np.nan, dtype=out.dtype),
                                out)
    return out
Example #17
def cdf(x, p):
    x, p = _promote_args_inexact("gennorm.cdf", x, p)
    return .5 * (1 + lax.sign(x) * lax.igamma(1 / p, lax.abs(x)**p))
Example #18
def logpdf(x, loc=0, scale=1):
    x, loc, scale = _promote_args_inexact("laplace.logpdf", x, loc, scale)
    two = _constant_like(x, 2)
    linear_term = lax.div(lax.abs(lax.sub(x, loc)), scale)
    return lax.neg(lax.add(linear_term, lax.log(lax.mul(two, scale))))
Example #19
    def __call__(self, x, training: bool):
        """Normalizes the input using batch statistics.
        Args:
            x: the input to be normalized.
        Returns:
            Normalized inputs (the same shape as inputs).
        """
        x = jnp.asarray(x, self.dtype)
        axis = self.axis if isinstance(self.axis, tuple) else (self.axis, )
        axis = _absolute_dims(x.ndim, axis)
        feature_shape = tuple(d if i in axis else 1
                              for i, d in enumerate(x.shape))
        reduced_feature_shape = tuple(d for i, d in enumerate(x.shape)
                                      if i in axis)
        reduction_axis = tuple(i for i in range(x.ndim) if i not in axis)

        # we detect if we're in initialization via empty variable tree.
        initializing = not self.has_variable('batch_stats', 'mean')

        ra_mean = self.variable('batch_stats', 'mean',
                                lambda s: jnp.zeros(s, jnp.float32),
                                reduced_feature_shape)
        ra_var = self.variable('batch_stats', 'var',
                               lambda s: jnp.ones(s, jnp.float32),
                               reduced_feature_shape)

        if not training:
            mean, var = ra_mean.value, ra_var.value
        else:
            mean = jnp.mean(x, axis=reduction_axis, keepdims=False)
            var = jnp.mean(lax.abs(x - mean),
                           axis=reduction_axis,
                           keepdims=False) * jnp.sqrt(jnp.pi / 2)
            if self.axis_name is not None and not initializing:
                concatenated_mean = jnp.concatenate([mean, var])
                mean, var = jnp.split(
                    lax.pmean(concatenated_mean,
                              axis_name=self.axis_name,
                              axis_index_groups=self.axis_index_groups), 2)

            if not initializing:
                ra_mean.value = self.momentum * ra_mean.value + (
                    1 - self.momentum) * mean
                ra_var.value = self.momentum * ra_var.value + (
                    1 - self.momentum) * var

        mean = jnp.asarray(mean, self.dtype)
        var = jnp.asarray(var, self.dtype)
        y = x - mean.reshape(feature_shape)
        mul = lax.reciprocal(var + self.epsilon)
        if self.use_scale:
            scale = self.param('scale', self.scale_init,
                               reduced_feature_shape).reshape(feature_shape)
            scale = jnp.asarray(scale, self.dtype)
            mul = mul * scale
        y = y * mul
        if self.use_bias:
            bias = self.param('bias', self.bias_init,
                              reduced_feature_shape).reshape(feature_shape)
            bias = jnp.asarray(bias, self.dtype)
            y = y + bias
        return jnp.asarray(y, self.dtype)
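The unusual variance line relies on the half-normal identity E|x - mu| = sigma * sqrt(2/pi): the mean absolute deviation times sqrt(pi/2) estimates the standard deviation, which is consistent with the normalizer using reciprocal rather than rsqrt. A quick NumPy check:

import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(loc=0.0, scale=2.0, size=100_000)
print(np.mean(np.abs(x - x.mean())) * np.sqrt(np.pi / 2))  # ~2.0, the true sigma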
Example #20
def _abs2(x):
    x, xx = x
    sign = jnp.where(lax.sign(x) == lax.sign(xx), 1, -1)
    return (lax.abs(x), sign * lax.abs(xx))
Example #21
def copysign(x1, x2):
    x1, x2 = _promote_args_inexact("copysign", x1, x2)
    if dtypes.issubdtype(dtypes.dtype(x1), np.complexfloating):
        raise TypeError("copysign does not support complex-valued inputs")
    return _where(signbit(x2).astype(bool), -lax.abs(x1), lax.abs(x1))
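Using signbit rather than a comparison with zero makes negative zero transfer its sign:

import jax.numpy as jnp

print(jnp.copysign(3.0, -0.0))  # -3.0: signbit(-0.0) is True
print(jnp.copysign(3.0, 0.0))   #  3.0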
Example #22
def absolute(x):
    _check_arraylike('absolute', x)
    dt = dtypes.dtype(x)
    # Booleans and unsigned integers are already non-negative.
    if dt == np.bool_ or dtypes.issubdtype(dt, np.unsignedinteger):
        return x
    return lax.abs(x)
Example #23
def logpdf(x, p):
    x, p = _promote_args_inexact("gennorm.logpdf", x, p)
    return lax.log(.5 * p) - lax.lgamma(1 / p) - lax.abs(x)**p
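At p = 1 the generalized normal density log(p/2) - lgamma(1/p) - |x|**p reduces to the Laplace density, a handy consistency check:

import jax.numpy as jnp
from jax.scipy.stats import gennorm, laplace

x = jnp.linspace(-2.0, 2.0, 5)
print(jnp.allclose(gennorm.logpdf(x, 1), laplace.logpdf(x)))  # True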
Example #24
def isclose(a, b, rtol=1e-05, atol=1e-08):
    a, b = _promote_args("isclose", a, b)
    rtol = lax.convert_element_type(rtol, _dtype(a))
    atol = lax.convert_element_type(atol, _dtype(a))
    return lax.le(lax.abs(lax.sub(a, b)),
                  lax.add(atol, lax.mul(rtol, lax.abs(b))))
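Note the asymmetry: the relative tolerance scales with |b| only, so near zero the absolute tolerance atol dominates:

import jax.numpy as jnp

print(jnp.isclose(1.0, 1.0 + 1e-6))  # True: 1e-6 is within rtol * |b|
print(jnp.isclose(0.0, 1e-6))        # False: 1e-6 exceeds atol = 1e-8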
Example #25
def i1(x):
    x, = _promote_args_inexact("i1", x)
    return lax.mul(lax.exp(lax.abs(x)), lax.bessel_i1e(x))
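i1e is the exponentially scaled Bessel function I1(x) * exp(-|x|), so multiplying back by exp(|x|) recovers I1 while keeping the kernel itself overflow-free. Assuming the public jax.scipy.special.i1 wraps this definition:

from jax.scipy.special import i1

print(i1(1.0))  # ~0.5652, matching scipy.special.i1(1.0)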
Example #26
def _abs(x):
    return lax.abs(x)