예제 #1
0
def logpdf(x, loc=0, scale=1):
    """Log-density of the normal distribution, built from ``lax`` primitives.

    The arguments are first passed through ``_promote_args_inexact``. The
    value returned is
    ``-(log(2 * pi * scale**2) + (x - loc)**2 / scale**2) / 2``.
    """
    x, loc, scale = _promote_args_inexact("norm.logpdf", x, loc, scale)
    exponent = _constant_like(x, 2)
    variance = lax.pow(scale, exponent)
    two_pi = _constant_like(x, 2 * np.pi)
    log_norm = lax.log(lax.mul(two_pi, variance))
    residual = lax.sub(x, loc)
    squared_z = lax.div(lax.pow(residual, exponent), variance)
    # The constant 2 doubles as the final divisor.
    return lax.div(lax.neg(lax.add(log_norm, squared_z)), exponent)
예제 #2
0
파일: pareto.py 프로젝트: yashk2810/jax
def logpdf(x, b, loc=0, scale=1):
    """Log-density of the Pareto distribution with shape parameter ``b``.

    Evaluates ``-(log(scale / b) + (b + 1) * log((x - loc) / scale))`` and
    returns ``-inf`` wherever ``x < loc + scale``.
    """
    x, b, loc, scale = _promote_args_inexact("pareto.logpdf", x, b, loc, scale)
    unit = _constant_like(x, 1)
    z = lax.div(lax.sub(x, loc), scale)
    log_scale_over_b = lax.log(lax.div(scale, b))
    tail_term = lax.mul(lax.add(b, unit), lax.log(z))
    log_density = lax.neg(lax.add(log_scale_over_b, tail_term))
    below_support = lax.lt(x, lax.add(loc, scale))
    return where(below_support, -inf, log_density)
예제 #3
0
def logpdf(x, df, loc=0, scale=1):
  """Log-density of Student's t distribution with ``df`` degrees of freedom.

  With ``z = (x - loc) / scale`` the returned value is
  ``-(lgamma(df/2) + log(pi * scale**2 * df)/2 - lgamma(df/2 + 1/2)
      + (df/2 + 1/2) * log1p(z**2 / df))``.
  """
  x, df, loc, scale = _promote_args_inexact("t.logpdf", x, df, loc, scale)
  two = _lax_const(x, 2)
  z = lax.div(lax.sub(x, loc), scale)
  half_df = lax.div(df, two)
  half_df_plus_half = lax.add(half_df, _lax_const(x, 0.5))

  # log(pi * scale**2 * df) / 2
  pi_scale_sq = lax.mul(lax.mul(scale, scale), _lax_const(x, np.pi))
  half_log_term = lax.div(lax.log(lax.mul(pi_scale_sq, df)), two)

  log_norm = lax.sub(lax.add(lax.lgamma(half_df), half_log_term),
                     lax.lgamma(half_df_plus_half))
  z_sq_over_df = lax.div(lax.mul(z, z), df)
  kernel = lax.mul(half_df_plus_half, lax.log1p(z_sq_over_df))
  return lax.neg(lax.add(log_norm, kernel))
예제 #4
0
def logpdf(x, df, loc=0, scale=1):
    """Log-density of the chi-squared distribution with ``df`` degrees of freedom.

    With ``z = (x - loc) / scale`` the value is
    ``-(lgamma(df/2) + df*log(2)/2) - log(scale) + (df/2 - 1)*log(z) - z/2``,
    replaced by ``-inf`` wherever ``x < loc``.
    """
    x, df, loc, scale = _promote_args_inexact("chi2.logpdf", x, df, loc, scale)
    one = _constant_like(x, 1)
    two = _constant_like(x, 2)
    z = lax.div(lax.sub(x, loc), scale)
    half_df = lax.div(df, two)

    # (df/2 - 1) * log(z) - z/2
    power_term = lax.mul(lax.sub(half_df, one), lax.log(z))
    log_kernel = lax.sub(power_term, lax.div(z, two))

    # -(lgamma(df/2) + df * log(2) / 2)
    half_df_log2 = lax.div(lax.mul(lax.log(two), df), two)
    log_norm = lax.neg(lax.add(lax.lgamma(half_df), half_df_log2))

    log_density = lax.add(lax.sub(log_norm, lax.log(scale)), log_kernel)
    outside_support = lax.lt(x, loc)
    return where(outside_support, -inf, log_density)
예제 #5
0
def logaddexp2(x1, x2):
  """Compute ``log2(2**x1 + 2**x2)`` relative to the larger operand.

  Floating inputs use ``max + log1p(2**-|x1 - x2|) / log(2)``; when the
  difference is NaN the plain sum ``x1 + x2`` is returned instead. Any other
  dtype takes the second path, which builds a complex result whose imaginary
  part is passed through ``_wrap_between``.
  """
  x1, x2 = _promote_args_inexact("logaddexp2", x1, x2)
  larger = lax.max(x1, x2)
  ln2 = _constant_like(x1, np.log(2))

  if not dtypes.issubdtype(x1.dtype, np.floating):
    shift = lax.sub(lax.add(x1, x2), lax.mul(larger, _constant_like(larger, 2)))
    total = lax.add(larger, lax.div(lax.log1p(exp2(shift)), ln2))
    wrapped_imag = _wrap_between(lax.imag(total), np.pi / np.log(2))
    return lax.complex(lax.real(total), wrapped_imag)

  diff = lax.sub(x1, x2)
  # NaNs or infinities of the same sign make the difference NaN; fall back
  # to the sum in that case.
  fallback = lax.add(x1, x2)
  correction = lax.div(lax.log1p(exp2(lax.neg(lax.abs(diff)))), ln2)
  return lax.select(lax_internal._isnan(diff), fallback,
                    lax.add(larger, correction))
예제 #6
0
def _mean(a,
          axis: Optional[Union[int, Tuple[int, ...]]] = None,
          dtype=None,
          out=None,
          keepdims=False,
          *,
          where=None):
    """Arithmetic mean of ``a`` over ``axis``, optionally masked by ``where``.

    The divisor is the number of reduced elements, or the sum of the
    broadcast ``where`` mask when one is supplied. If ``dtype`` is None the
    result dtype is the inexact counterpart of ``a``'s dtype.

    Raises:
      NotImplementedError: if ``out`` is given.
    """
    _check_arraylike("mean", a)
    lax_internal._check_user_dtype_supported(dtype, "mean")
    if out is not None:
        raise NotImplementedError(
            "The 'out' argument to jnp.mean is not supported.")

    # Note: the mask count is taken with the caller's ``dtype`` (possibly
    # None), before the default dtype is resolved below.
    if where is not None:
        count = sum(_broadcast_to(where, np.shape(a)),
                    axis,
                    dtype=dtype,
                    keepdims=keepdims)
    elif axis is None:
        count = core.dimension_as_value(np.size(a))
    else:
        count = core.dimension_as_value(_axis_size(a, axis))

    result_dtype = dtypes.canonicalize_dtype(
        dtypes._to_inexact_dtype(dtypes.dtype(a)) if dtype is None else dtype)

    total = sum(a, axis, dtype=result_dtype, keepdims=keepdims, where=where)
    return lax.div(total, lax.convert_element_type(count, result_dtype))
예제 #7
0
파일: cauchy.py 프로젝트: yashk2810/jax
def logpdf(x, loc=0, scale=1):
    """Log-density of the Cauchy distribution.

    With ``z = (x - loc) / scale`` this returns
    ``-(log(pi * scale) + log1p(z**2))``.
    """
    x, loc, scale = _promote_args_inexact("cauchy.logpdf", x, loc, scale)
    z = lax.div(lax.sub(x, loc), scale)
    log_norm = lax.log(lax.mul(_constant_like(x, np.pi), scale))
    log_kernel = lax.log1p(lax.mul(z, z))
    return lax.neg(lax.add(log_norm, log_kernel))
예제 #8
0
파일: reductions.py 프로젝트: cloudhan/jax
def nanvar(a,
           axis: Optional[Union[int, Tuple[int, ...]]] = None,
           dtype=None,
           out=None,
           ddof=0,
           keepdims=False,
           where=None):
    """Variance of ``a`` over ``axis`` with NaN entries ignored.

    The divisor is the count of non-NaN entries minus ``ddof``; wherever that
    divisor is not positive the result is NaN.

    Raises:
      NotImplementedError: if ``out`` is given.
    """
    _check_arraylike("nanvar", a)
    lax_internal._check_user_dtype_supported(dtype, "nanvar")
    if out is not None:
        raise NotImplementedError(
            "The 'out' argument to jnp.nanvar is not supported.")

    a_dtype, dtype = _var_promote_types(dtypes.dtype(a), dtype)
    nan_free_mean = nanmean(a, axis, dtype=a_dtype, keepdims=True, where=where)

    # Double-where trick for gradients: NaN slots become 0 before squaring.
    deviations = _where(lax_internal._isnan(a), 0, a - nan_free_mean)
    if dtypes.issubdtype(deviations.dtype, np.complexfloating):
        squared = lax.real(lax.mul(deviations, lax.conj(deviations)))
    else:
        squared = lax.square(deviations)

    valid_count = sum(lax_internal.bitwise_not(lax_internal._isnan(a)),
                      axis=axis,
                      keepdims=keepdims,
                      where=where)
    dof = valid_count - ddof
    too_few = lax.le(dof, 0)

    total = sum(squared, axis, keepdims=keepdims, where=where)
    total = _where(too_few, np.nan, total)
    # Divide by 1 in the masked slots so the NaN numerator is kept as-is.
    safe_dof = _where(too_few, 1, dof)
    variance = lax.div(total, lax.convert_element_type(safe_dof, total.dtype))
    return lax.convert_element_type(variance, dtype)
예제 #9
0
파일: special.py 프로젝트: jbampton/jax
def zeta(x, q=None):
    """Hurwitz zeta function ``zeta(x, q)``.

    Only the two-argument (Hurwitz) form is implemented; ``q`` must be given.

    Args:
      x: the exponent (called ``s`` in the reference).
      q: the offset (called ``a`` in the reference).

    Returns:
      ``S + I + T``, the three terms of formula (5) in the reference below.

    Raises:
      NotImplementedError: if ``q`` is None (Riemann zeta is not provided).
    """
    if q is None:
        # Raised explicitly rather than via ``assert``: assertions disappear
        # under ``python -O`` and signal the wrong exception type to callers.
        raise NotImplementedError(
            "Riemann zeta function is not implemented yet.")
    # Reference: Johansson, Fredrik.
    # "Rigorous high-precision computation of the Hurwitz zeta function and its derivatives."
    # Numerical Algorithms 69.2 (2015): 253-270.
    # https://arxiv.org/abs/1309.2877 - formula (5)
    # here we keep the same notation as in reference
    s, a = _promote_args_inexact("zeta", x, q)
    dtype = lax.dtype(a).type
    s_, a_ = jnp.expand_dims(s, -1), jnp.expand_dims(a, -1)
    # precision ~ N, M
    N = M = dtype(8) if lax.dtype(a) == jnp.float32 else dtype(16)
    # Internal invariant: the Bernoulli table must cover M terms.
    assert M <= len(_BERNOULLI_COEFS)
    k = jnp.expand_dims(np.arange(N, dtype=N.dtype), tuple(range(a.ndim)))
    # S: direct sum of the first N terms (a + k)**-s.
    S = jnp.sum((a_ + k)**-s_, -1)
    # I: (a + N)**(1 - s) / (s - 1).
    I = lax.div((a + N)**(dtype(1) - s), s - dtype(1))
    T0 = (a + N)**-s
    m = jnp.expand_dims(np.arange(2 * M, dtype=M.dtype), tuple(range(s.ndim)))
    s_over_a = (s_ + m) / (a_ + N)
    # Running products of (s + m) / (a + N); every second one is kept.
    T1 = jnp.cumprod(s_over_a, -1)[..., ::2]
    # Clamp to the largest finite value of the dtype before dividing.
    T1 = jnp.clip(T1, a_max=jnp.finfo(dtype).max)
    coefs = np.expand_dims(
        np.array(_BERNOULLI_COEFS[:T1.shape[-1]], dtype=dtype),
        tuple(range(a.ndim)))
    T1 = T1 / coefs
    T = T0 * (dtype(0.5) + T1.sum(-1))
    return S + I + T
예제 #10
0
def nanmean(a,
            axis: Optional[Union[int, Tuple[int, ...]]] = None,
            dtype=None,
            out=None,
            keepdims=False,
            where=None):
    """Mean of ``a`` over ``axis`` with NaN entries ignored.

    Boolean and integer inputs are delegated to ``mean``. Otherwise the
    NaN-skipping sum is divided by the number of non-NaN entries.

    Raises:
      NotImplementedError: if ``out`` is given.
    """
    _check_arraylike("nanmean", a)
    lax_internal._check_user_dtype_supported(dtype, "nanmean")
    if out is not None:
        raise NotImplementedError(
            "The 'out' argument to jnp.nanmean is not supported.")

    input_dtype = dtypes.dtype(a)
    is_bool = dtypes.issubdtype(input_dtype, np.bool_)
    if is_bool or dtypes.issubdtype(input_dtype, np.integer):
        return mean(a, axis, dtype, out, keepdims, where=where)

    if dtype is None:
        dtype = input_dtype

    not_nan = lax_internal.bitwise_not(lax_internal._isnan(a))
    count = sum(not_nan,
                axis=axis,
                dtype=np.int32,
                keepdims=keepdims,
                where=where)
    count = lax.convert_element_type(count, dtype)
    total = nansum(a, axis, dtype=dtype, keepdims=keepdims, where=where)
    return lax.div(total, count)
예제 #11
0
def hypot(x1, x2):
  """Return ``sqrt(x1**2 + x2**2)`` as ``big * sqrt(1 + (small / big)**2)``.

  ``big`` and ``small`` are the larger and smaller absolute values. Where
  ``big`` is zero the division uses 1 as a stand-in denominator and the
  result is ``big`` itself.
  """
  _check_arraylike("hypot", x1, x2)
  x1, x2 = _promote_dtypes_inexact(x1, x2)
  abs1 = lax.abs(x1)
  abs2 = lax.abs(x2)
  big = maximum(abs1, abs2)
  small = minimum(abs1, abs2)
  big_is_zero = big == 0
  # Avoid 0/0 in the branch that gets discarded by the outer select.
  safe_big = lax.select(big_is_zero, lax_internal._ones(big), big)
  ratio = lax.div(small, safe_big)
  scaled = big * lax.sqrt(1 + lax.square(ratio))
  return lax.select(big_is_zero, big, scaled)
예제 #12
0
 def avgpool_op(self, params: Tuple, inputs: DeviceArray):
     """Average-pool ``inputs`` using this layer's window configuration.

     The windowed sum of ``inputs`` is divided by the windowed sum of an
     all-ones array, i.e. by the number of elements in each window.
     ``params`` is not used.
     """
     window_args = (self.pool_size, self.strides, self.padding)
     summed = lax.reduce_window(inputs, 0.0, lax.add, *window_args)
     # The ones array takes its extent from axes 1 and 2 of the input and
     # broadcasts over axes 0 and 3 (presumably NHWC layout - confirm).
     unit_map = jnp.ones((1, inputs.shape[1], inputs.shape[2], 1),
                         dtype=inputs.dtype)
     counts = lax.reduce_window(unit_map, 0.0, lax.add, *window_args)
     return lax.div(summed, counts)
예제 #13
0
def cdf(x, loc=0, scale=1):
    """Cumulative distribution function of the Laplace distribution.

    With ``z = (x - loc) / scale`` this is ``exp(z) / 2`` where ``z <= 0``
    and ``1 - exp(-z) / 2`` elsewhere.
    """
    x, loc, scale = _promote_args_inexact("laplace.cdf", x, loc, scale)
    half = _constant_like(x, 0.5)
    one = _constant_like(x, 1)
    zero = _constant_like(x, 0)
    z = lax.div(lax.sub(x, loc), scale)
    lower_tail = lax.mul(half, lax.exp(z))
    upper_tail = lax.sub(one, lax.mul(half, lax.exp(lax.neg(z))))
    return lax.select(lax.le(z, zero), lower_tail, upper_tail)
예제 #14
0
파일: gamma.py 프로젝트: jbampton/jax
def logpdf(x, a, loc=0, scale=1):
    """Log-density of the gamma distribution with shape parameter ``a``.

    With ``y = (x - loc) / scale`` the value is
    ``xlogy(a - 1, y) - y - (gammaln(a) + log(scale))``, replaced by
    ``-inf`` wherever ``x < loc``.
    """
    x, a, loc, scale = _promote_args_inexact("gamma.logpdf", x, a, loc, scale)
    unit = _lax_const(x, 1)
    y = lax.div(lax.sub(x, loc), scale)
    log_kernel = lax.sub(xlogy(lax.sub(a, unit), y), y)
    log_norm = lax.add(gammaln(a), lax.log(scale))
    log_density = lax.sub(log_kernel, log_norm)
    outside_support = lax.lt(x, loc)
    return where(outside_support, -inf, log_density)
예제 #15
0
def sinc(x):
    """Normalized sinc, ``sin(pi * x) / (pi * x)``.

    Entries where ``x == 0`` take the value of ``_sinc_maclaurin(0, pi * x)``
    instead of the quotient.
    """
    _check_arraylike("sinc", x)
    x, = _promote_dtypes_inexact(x)
    is_zero = lax.eq(x, _lax_const(x, 0))
    scaled = lax.mul(_lax_const(x, np.pi), x)
    # Substitute 1 at the zeros so the quotient below never evaluates 0/0.
    safe_scaled = _where(is_zero, _lax_const(x, 1), scaled)
    at_zero = _sinc_maclaurin(0, scaled)
    quotient = lax.div(lax.sin(safe_scaled), safe_scaled)
    return _where(is_zero, at_zero, quotient)
예제 #16
0
def floor_divide(x1, x2):
  """Floor division of ``x1`` by ``x2``.

  Integer inputs use ``lax.div`` and subtract one wherever the operands have
  different signs and the remainder is non-zero. Other dtypes return the
  quotient computed by ``_float_divmod``.
  """
  x1, x2 = _promote_args("floor_divide", x1, x2)
  if not onp.issubdtype(_dtype(x1), onp.integer):
    return _float_divmod(x1, x2)[0]

  truncated = lax.div(x1, x2)
  signs_differ = lax.sign(x1) != lax.sign(x2)
  has_remainder = lax.rem(x1, x2) != 0
  needs_adjust = logical_and(signs_differ, has_remainder)
  # TODO(mattjj): investigate why subtracting a scalar was causing promotion
  one = onp.array(1, _dtype(truncated))
  return where(needs_adjust, truncated - one, truncated)
예제 #17
0
def round(a, decimals=0):
    """Round ``a`` to ``decimals`` decimal places.

    Integer-typed inputs are returned unchanged. For ``decimals != 0`` the
    value is scaled by ``10**decimals``, rounded with ``lax.round`` and
    scaled back.
    """
    is_integer = onp.issubdtype(_dtype(a), onp.integer)
    if is_integer:
        # Rounding is a no-op on integer types.
        return a

    if decimals != 0:
        scale = _constant_like(a, 10**decimals)
        return lax.div(lax.round(lax.mul(a, scale)), scale)

    return lax.round(a)
예제 #18
0
def _float_divmod(x1, x2):
    """Return ``(quotient, remainder)`` of ``x1`` by ``x2``.

    Modeled on float_divmod in CPython's floatobject.c: where the remainder
    is non-zero and its sign differs from the divisor's, ``x2`` is added to
    the remainder and one is subtracted from the quotient. The quotient is
    rounded before it is returned.
    """
    remainder = lax.rem(x1, x2)
    quotient = lax.div(lax.sub(x1, remainder), x2)

    sign_mismatch = lax.sign(x2) != lax.sign(remainder)
    needs_fix = lax.bitwise_and(remainder != 0, sign_mismatch)
    fixed_remainder = lax.select(needs_fix, remainder + x2, remainder)
    fixed_quotient = lax.select(needs_fix,
                                quotient - _constant_like(quotient, 1),
                                quotient)

    return lax.round(fixed_quotient), fixed_remainder
예제 #19
0
def _segment_update(name: str,
                    data: Array,
                    segment_ids: Array,
                    scatter_op: Callable,
                    num_segments: Optional[int] = None,
                    indices_are_sorted: bool = False,
                    unique_indices: bool = False,
                    bucket_size: Optional[int] = None,
                    reducer: Optional[Callable] = None,
                    mode: Optional[lax.GatherScatterMode] = None) -> Array:
    """Scatter ``data`` into per-segment slots using ``scatter_op``.

    The output starts filled with ``_get_identity(scatter_op, dtype)`` and is
    updated through ``_scatter_update``. When ``bucket_size`` yields more than
    one bucket, each bucket is scattered into its own leading slice and the
    slices are combined with ``reducer`` along axis 0.

    Raises:
      ValueError: if ``num_segments`` is negative.
    """
    jnp._check_arraylike(name, data, segment_ids)
    if mode is None:
        mode = lax.GatherScatterMode.FILL_OR_DROP
    data = jnp.asarray(data)
    segment_ids = jnp.asarray(segment_ids)
    dtype = data.dtype

    if num_segments is None:
        num_segments = jnp.max(segment_ids) + 1
    num_segments = core.concrete_or_error(
        int, num_segments, "segment_sum() `num_segments` argument.")
    if num_segments is not None and num_segments < 0:
        raise ValueError("num_segments must be non-negative.")

    if bucket_size is None:
        num_buckets = 1
    else:
        num_buckets = util.ceil_of_ratio(segment_ids.size, bucket_size)

    trailing_shape = data.shape[1:]
    if num_buckets == 1:
        initial = jnp.full((num_segments, ) + trailing_shape,
                           _get_identity(scatter_op, dtype),
                           dtype=dtype)
        return _scatter_update(initial,
                               segment_ids,
                               data,
                               scatter_op,
                               indices_are_sorted,
                               unique_indices,
                               normalize_indices=False,
                               mode=mode)

    # Bucketize indices and perform segment_update on each bucket to improve
    # numerical stability for operations like product and sum.
    assert reducer is not None
    per_bucket = jnp.full((num_buckets, num_segments) + trailing_shape,
                          _get_identity(scatter_op, dtype),
                          dtype=dtype)
    # Element i lands in bucket i // bucket_size.
    bucket_of_element = lax.div(jnp.arange(segment_ids.shape[0]), bucket_size)
    scatter_index = np.index_exp[bucket_of_element, segment_ids[None, :]]
    per_bucket = _scatter_update(per_bucket,
                                 scatter_index,
                                 data,
                                 scatter_op,
                                 indices_are_sorted,
                                 unique_indices,
                                 normalize_indices=False,
                                 mode=mode)
    return reducer(per_bucket, axis=0).astype(dtype)
예제 #20
0
파일: special.py 프로젝트: zizai/jax
def multigammaln(a, d):
  """Log of the multivariate gamma function of dimension ``d`` at ``a``.

  Returns ``d * (d - 1) / 4 * log(pi)`` plus the sum over ``j = 0..d-1`` of
  ``gammaln(a - j / 2)``. ``d`` must be concrete.
  """
  d = core.concrete_or_error(int, d, "d argument of multigammaln")
  a, d = _promote_args_inexact("multigammaln", a, d)

  quarter_d = lax.mul(_constant_like(a, 0.25), d)
  d_minus_one = lax.sub(d, _constant_like(a, 1))
  log_pi = lax.log(_constant_like(a, np.pi))
  constant = lax.mul(lax.mul(quarter_d, d_minus_one), log_pi)

  half_steps = lax.div(jnp.arange(d), _constant_like(a, 2))
  shifted = jnp.expand_dims(a, axis=-1) - half_steps
  res = jnp.sum(gammaln(shifted), axis=-1)
  return res + constant
예제 #21
0
파일: beta.py 프로젝트: 0x0is1/jax
def logpdf(x, a, b, loc=0, scale=1):
    """Log-density of the beta distribution with shape parameters ``a``, ``b``.

    With ``y = (x - loc) / scale`` the value is
    ``-betaln(a, b) + xlogy(a - 1, y) + xlog1py(b - 1, -y) - log(scale)``,
    replaced by ``-inf`` wherever ``x`` lies outside ``[loc, loc + scale]``.
    """
    x, a, b, loc, scale = _promote_args_inexact("beta.logpdf", x, a, b, loc,
                                                scale)
    unit = lax._const(x, 1)
    neg_log_beta = lax.neg(betaln(a, b))
    y = lax.div(lax.sub(x, loc), scale)
    left_term = xlogy(lax.sub(a, unit), y)
    right_term = xlog1py(lax.sub(b, unit), lax.neg(y))
    log_kernel = lax.add(left_term, right_term)
    log_density = lax.sub(lax.add(neg_log_beta, log_kernel), lax.log(scale))
    above = lax.gt(x, lax.add(loc, scale))
    below = lax.lt(x, loc)
    return where(logical_or(above, below), -inf, log_density)
예제 #22
0
파일: special.py 프로젝트: xeransis/jax
def multigammaln(a, d):
    """Log of the multivariate gamma function of dimension ``d`` at ``a``.

    ``d`` is converted to ``a``'s dtype. The result is
    ``d * (d - 1) / 4 * log(pi)`` plus the sum over ``j = 0..d-1`` of
    ``gammaln(a - j / 2)``.
    """
    a, = _promote_args_inexact("multigammaln", a)
    d = lax.convert_element_type(d, lax.dtype(a))

    quarter_d = lax.mul(_constant_like(a, 0.25), d)
    d_minus_one = lax.sub(d, _constant_like(a, 1))
    log_pi = lax.log(_constant_like(a, np.pi))
    constant = lax.mul(lax.mul(quarter_d, d_minus_one), log_pi)

    half_steps = lax.div(jnp.arange(d), _constant_like(a, 2))
    shifted = jnp.expand_dims(a, axis=-1) - half_steps
    res = jnp.sum(gammaln(shifted), axis=-1)
    return res + constant
예제 #23
0
def floor_divide(x1, x2):
  """Floor division of ``x1`` by ``x2`` with dtype-specific handling.

  * Integer: ``lax.div`` result, minus one where the signs differ and the
    remainder is non-zero.
  * Complex: floor of a real quotient built from the real and imaginary
    parts, scaled by the ratio of the smaller to the larger component of
    ``x2``, then converted back to the input dtype.
  * Otherwise: the quotient returned by ``_float_divmod``.
  """
  x1, x2 = _promote_args("floor_divide", x1, x2)
  dtype = dtypes.dtype(x1)

  if dtypes.issubdtype(dtype, np.integer):
    truncated = lax.div(x1, x2)
    signs_differ = lax.sign(x1) != lax.sign(x2)
    has_remainder = lax.rem(x1, x2) != 0
    needs_adjust = logical_and(signs_differ, has_remainder)
    # TODO(mattjj): investigate why subtracting a scalar was causing promotion
    return _where(needs_adjust, truncated - 1, truncated)

  if dtypes.issubdtype(dtype, np.complexfloating):
    num_re, num_im = lax.real(x1), lax.imag(x1)
    den_re, den_im = lax.real(x2), lax.imag(x2)
    # Divide by whichever component of x2 is larger in magnitude.
    real_dominates = lax.ge(lax.abs(den_re), lax.abs(den_im))
    w_re = _where(real_dominates, lax.full_like(den_im, 1),
                  lax.div(den_re, den_im))
    w_im = _where(real_dominates, lax.div(den_im, den_re),
                  _lax_const(den_im, 1))
    numerator = lax.add(lax.mul(num_re, w_re), lax.mul(num_im, w_im))
    denominator = lax.add(lax.mul(den_re, w_re), lax.mul(den_im, w_im))
    floored = lax.floor(lax.div(numerator, denominator))
    return lax.convert_element_type(floored, dtype)

  return _float_divmod(x1, x2)[0]
예제 #24
0
파일: jet.py 프로젝트: MichaelMarien/jax
def _atan2_taylor(primals_in, series_in):
  """Taylor-mode rule for ``lax.atan2``.

  The series of the quotient of the two primals and the series of
  ``1 / (1 + t**2)`` evaluated on that quotient are obtained with ``jet``;
  the output coefficients are then assembled term by term from both.

  Returns:
    ``(primal_out, series_out)``, with ``series_out`` a list with as many
    entries as the quotient's series.
  """
  x, y = primals_in
  primal_out = lax.atan2(x, y)

  ratio, ratio_series = jet(lax.div, primals_in, series_in)
  deriv0, deriv_series = jet(lambda t: lax.div(1, 1 + lax.square(t)),
                             (ratio, ), (ratio_series, ))
  c = [deriv0] + deriv_series
  u = [ratio] + ratio_series

  terms = [primal_out]
  for k in range(1, len(ratio_series) + 1):
    acc = sum(_scale(k, j) * c[k-j] * u[j] for j in range(1, k + 1))
    terms.append(fact(k-1) * acc)

  primal_out, *series_out = terms
  return primal_out, series_out
예제 #25
0
파일: special.py 프로젝트: jbampton/jax
def multigammaln(a, d):
    """Log of the multivariate gamma function of dimension ``d`` at ``a``.

    ``d`` must be concrete. The result is ``d * (d - 1) / 4 * log(pi)`` plus
    the sum over ``j = 0..d-1`` of ``gammaln(a - j / 2)``; the ``j / 2``
    offsets are laid out along a new trailing axis.
    """
    d = core.concrete_or_error(int, d, "d argument of multigammaln")
    a, d_ = _promote_args_inexact("multigammaln", a, d)

    quarter_d = lax.mul(_lax_const(a, 0.25), d_)
    d_minus_one = lax.sub(d_, _lax_const(a, 1))
    log_pi = lax.log(_lax_const(a, np.pi))
    constant = lax.mul(lax.mul(quarter_d, d_minus_one), log_pi)

    half_steps = lax.div(jnp.arange(d, dtype=d_.dtype), _lax_const(a, 2))
    # Give the offsets leading singleton axes matching ``a``'s rank.
    leading_axes = tuple(range(a.ndim))
    shifted = (jnp.expand_dims(a, axis=-1) -
               jnp.expand_dims(half_steps, axis=leading_axes))
    res = jnp.sum(gammaln(shifted), axis=-1)
    return res + constant
예제 #26
0
def _var(a,
         axis: Optional[Union[int, Tuple[int, ...]]] = None,
         dtype=None,
         out=None,
         ddof=0,
         keepdims=False,
         *,
         where=None):
    """Variance of ``a`` over ``axis``, optionally masked by ``where``.

    The sum of squared deviations from the mean (squared magnitude for
    complex data) is divided by the element count minus ``ddof``. With a
    ``where`` mask the count is the sum of the broadcast mask.

    Raises:
      NotImplementedError: if ``out`` is given.
    """
    _check_arraylike("var", a)
    lax_internal._check_user_dtype_supported(dtype, "var")
    if out is not None:
        raise NotImplementedError(
            "The 'out' argument to jnp.var is not supported.")

    computation_dtype, dtype = _var_promote_types(dtypes.dtype(a), dtype)
    a = a.astype(computation_dtype)
    # keepdims=True so the mean broadcasts against ``a`` in the subtraction.
    center = mean(a, axis, dtype=computation_dtype, keepdims=True, where=where)
    deviations = lax.sub(a, center)
    if dtypes.issubdtype(deviations.dtype, np.complexfloating):
        squared = lax.real(lax.mul(deviations, lax.conj(deviations)))
    else:
        squared = lax.square(deviations)

    if where is not None:
        count = sum(_broadcast_to(where, np.shape(a)),
                    axis,
                    dtype=dtype,
                    keepdims=keepdims)
    elif axis is None:
        count = core.dimension_as_value(np.size(a))
    else:
        count = core.dimension_as_value(_axis_size(a, axis))
    dof = count - ddof

    total = sum(squared, axis, keepdims=keepdims, where=where)
    variance = lax.div(total, lax.convert_element_type(dof, total.dtype))
    return lax.convert_element_type(variance, dtype)
예제 #27
0
 def variance(self):
     """Return ``scale**2 * alpha / ((alpha - 1)**2 * (alpha - 2))``.

     Entries with ``alpha <= 2`` are reported as ``inf``.
     """
     alpha = self.alpha
     numerator = (self.scale**2) * alpha
     denominator = (alpha - 1)**2 * (alpha - 2)
     finite_variance = lax.div(numerator, denominator)
     # The closed form only applies for alpha > 2.
     return np.where(alpha <= 2, np.inf, finite_variance)
예제 #28
0
 def mean(self):
     """Return ``alpha * scale / (alpha - 1)``.

     Entries with ``alpha <= 1`` are reported as ``inf``.
     """
     alpha = self.alpha
     finite_mean = lax.div(alpha * self.scale, alpha - 1)
     # The closed form only applies for alpha > 1.
     return np.where(alpha <= 1, np.inf, finite_mean)
예제 #29
0
파일: special.py 프로젝트: GregCT/jax
def expit(x):
    """Logistic sigmoid, ``1 / (1 + exp(-x))``, applied to ``asarray(x)``."""
    x = asarray(x)
    unit = lax._const(x, 1)
    denominator = lax.add(unit, lax.exp(lax.neg(x)))
    return lax.div(unit, denominator)
예제 #30
0
파일: special.py 프로젝트: GregCT/jax
def logit(x):
    """Log-odds, ``log(x / (1 - x))``, applied to ``asarray(x)``."""
    x = asarray(x)
    complement = lax.sub(lax._const(x, 1), x)
    odds = lax.div(x, complement)
    return lax.log(odds)