Example #1
def logpmf(k, p, loc=0):
    k, p, loc = jnp._promote_args_inexact("geom.logpmf", k, p, loc)
    zero = lax._const(k, 0)
    one = lax._const(k, 1)
    x = lax.sub(k, loc)
    log_probs = xlog1py(lax.sub(x, one), -p) + lax.log(p)
    return jnp.where(lax.le(x, zero), -jnp.inf, log_probs)
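As a quick sanity check, the function above is exposed publicly as jax.scipy.stats.geom.logpmf in current JAX releases; a minimal usage sketch (expected values noted as approximate comments):

from jax.scipy.stats import geom

# logpmf(k, p) = (k - 1) * log1p(-p) + log(p) for k >= 1, and -inf otherwise.
print(geom.logpmf(1, 0.5))  # ~ -0.6931 (= log(0.5))
print(geom.logpmf(0, 0.5))  # -inf, since k - loc <= 0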
Example #2
def cdf(x, loc=0, scale=1):
    x, loc, scale = _promote_args_inexact("laplace.cdf", x, loc, scale)
    half = lax._const(x, 0.5)
    one = lax._const(x, 1)
    zero = lax._const(x, 0)
    diff = lax.div(lax.sub(x, loc), scale)
    return lax.select(lax.le(diff, zero), lax.mul(half, lax.exp(diff)),
                      lax.sub(one, lax.mul(half, lax.exp(lax.neg(diff)))))
Example #3
def logpmf(k, p, loc=0):
  k, p, loc = jnp._promote_args_inexact("bernoulli.logpmf", k, p, loc)
  zero = lax._const(k, 0)
  one = lax._const(k, 1)
  x = lax.sub(k, loc)
  log_probs = xlogy(x, p) + xlog1py(lax.sub(one, x), -p)
  return jnp.where(jnp.logical_or(lax.lt(x, zero), lax.gt(x, one)),
                  -jnp.inf, log_probs)
Example #4
def multigammaln(a, d):
    d = core.concrete_or_error(int, d, "d argument of multigammaln")
    a, d_ = _promote_args_inexact("multigammaln", a, d)

    constant = lax.mul(
        lax.mul(lax.mul(lax._const(a, 0.25), d_),
                lax.sub(d_, lax._const(a, 1))), lax.log(lax._const(a, np.pi)))
    b = lax.div(jnp.arange(d, dtype=d_.dtype), lax._const(a, 2))
    res = jnp.sum(gammaln(
        jnp.expand_dims(a, axis=-1) -
        jnp.expand_dims(b, axis=tuple(range(a.ndim)))),
                  axis=-1)
    return res + constant
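The constant-plus-sum structure above can be checked against the d = 1 case, where the multivariate log-gamma reduces to an ordinary gammaln. A minimal sketch via the public jax.scipy.special entry point:

from jax.scipy.special import multigammaln

# For d == 1 the constant term vanishes and the sum has a single gammaln term.
print(multigammaln(3.0, 1))  # ~ 0.6931 (= gammaln(3.0) = log(2))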
Example #5
def logpdf(x, df, loc=0, scale=1):
    x, df, loc, scale = _promote_args_inexact("t.logpdf", x, df, loc, scale)
    two = lax._const(x, 2)
    scaled_x = lax.div(lax.sub(x, loc), scale)
    df_over_two = lax.div(df, two)
    df_plus_one_over_two = lax.add(df_over_two, lax._const(x, 0.5))
    normalize_term_const = lax.mul(lax.mul(scale, scale), lax._const(x, np.pi))
    normalize_term_tmp = lax.div(lax.log(lax.mul(normalize_term_const, df)),
                                 two)
    normalize_term = lax.sub(
        lax.add(lax.lgamma(df_over_two), normalize_term_tmp),
        lax.lgamma(df_plus_one_over_two))
    quadratic = lax.div(lax.mul(scaled_x, scaled_x), df)
    return lax.neg(
        lax.add(normalize_term,
                lax.mul(df_plus_one_over_two, lax.log1p(quadratic))))
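A quick numerical check of the formula above: with df = 1, loc = 0, scale = 1 the Student t density at zero reduces to the standard Cauchy value 1/pi. A minimal sketch through the public API:

from jax.scipy.stats import t

print(t.logpdf(0.0, 1.0))  # ~ -1.1447 (= -log(pi))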
Example #6
def logpdf(x, loc=0, scale=1):
    x, loc, scale = _promote_args_inexact("cauchy.logpdf", x, loc, scale)
    pi = lax._const(x, np.pi)
    scaled_x = lax.div(lax.sub(x, loc), scale)
    normalize_term = lax.log(lax.mul(pi, scale))
    return lax.neg(
        lax.add(normalize_term, lax.log1p(lax.mul(scaled_x, scaled_x))))
Example #7
def threefry_seed(seed: int) -> jnp.ndarray:
    """Create a single raw threefry PRNG key given an integer seed.

  Args:
    seed: a 64- or 32-bit integer used as the value of the key.

  Returns:
    The PRNG key contents, modeled as an array of shape (2,) and dtype
    uint32. The key is constructed from a 64-bit seed by effectively
    bit-casting to a pair of uint32 values (or from a 32-bit seed by
    first padding out with zeros).
  """
    # Avoid OverflowError in X32 mode by first converting ints to int64.
    # This breaks JIT invariance for large ints, but supports the common
    # use-case of instantiating with Python hashes in X32 mode.
    if isinstance(seed, int):
        seed_arr = jnp.asarray(np.int64(seed))
    else:
        seed_arr = jnp.asarray(seed)
    if seed_arr.shape:
        raise TypeError(f"PRNG key seed must be a scalar; got {seed!r}.")
    if not np.issubdtype(seed_arr.dtype, np.integer):
        raise TypeError(f"PRNG key seed must be an integer; got {seed!r}")

    convert = lambda k: lax.reshape(lax.convert_element_type(k, np.uint32),
                                    [1])
    k1 = convert(lax.shift_right_logical(seed_arr, lax._const(seed_arr, 32)))
    k2 = convert(jnp.bitwise_and(seed_arr, np.uint32(0xFFFFFFFF)))
    return lax.concatenate([k1, k2], 0)
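The construction described in the docstring is what jax.random.PRNGKey performs under the default threefry implementation with legacy raw uint32 keys: the 64-bit seed is split into its high and low 32-bit halves. A minimal sketch:

import jax

key = jax.random.PRNGKey(42)
print(key.shape, key.dtype)  # (2,) uint32
print(key)                   # [ 0 42] -- high word, then low word of the seed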
Example #8
def logpdf(x, df, loc=0, scale=1):
    x, df, loc, scale = _promote_args_inexact("chi2.logpdf", x, df, loc, scale)
    one = lax._const(x, 1)
    two = lax._const(x, 2)
    y = lax.div(lax.sub(x, loc), scale)
    df_on_two = lax.div(df, two)

    kernel = lax.sub(lax.mul(lax.sub(df_on_two, one), lax.log(y)),
                     lax.div(y, two))

    nrml_cnst = lax.neg(
        lax.add(lax.lgamma(df_on_two), lax.div(lax.mul(lax.log(two), df),
                                               two)))

    log_probs = lax.add(lax.sub(nrml_cnst, lax.log(scale)), kernel)
    return where(lax.lt(x, loc), -inf, log_probs)
Example #9
def logpdf(x, b, loc=0, scale=1):
    x, b, loc, scale = _promote_args_inexact("pareto.logpdf", x, b, loc, scale)
    one = lax._const(x, 1)
    scaled_x = lax.div(lax.sub(x, loc), scale)
    normalize_term = lax.log(lax.div(scale, b))
    log_probs = lax.neg(
        lax.add(normalize_term, lax.mul(lax.add(b, one), lax.log(scaled_x))))
    return where(lax.lt(x, lax.add(loc, scale)), -inf, log_probs)
Example #10
def _eval_expint_k(A, B, x):
    # helper function for all subsequent intervals
    A, B = [jnp.array(U, dtype=x.dtype) for U in [A, B]]
    one = lax._const(x, 1.0)
    w = one / x
    f = jnp.polyval(A, w) / jnp.polyval(B, w)
    f = w * f + one
    return jnp.exp(x) * w * f
Example #11
def logpdf(x, a, loc=0, scale=1):
    x, a, loc, scale = _promote_args_inexact("gamma.logpdf", x, a, loc, scale)
    one = lax._const(x, 1)
    y = lax.div(lax.sub(x, loc), scale)
    log_linear_term = lax.sub(xlogy(lax.sub(a, one), y), y)
    shape_terms = lax.add(gammaln(a), lax.log(scale))
    log_probs = lax.sub(log_linear_term, shape_terms)
    return where(lax.lt(x, loc), -inf, log_probs)
Example #12
def logpmf(k, n, a, b, loc=0):
    """JAX implementation of scipy.stats.betabinom.logpmf."""
    k, n, a, b, loc = _promote_args_inexact("betabinom.logpmf", k, n, a, b,
                                            loc)
    y = lax.sub(lax.floor(k), loc)
    one = lax._const(y, 1)
    zero = lax._const(y, 0)
    combiln = lax.neg(
        lax.add(lax.log1p(n),
                betaln(lax.add(lax.sub(n, y), one), lax.add(y, one))))
    beta_lns = lax.sub(betaln(lax.add(y, a), lax.add(lax.sub(n, y), b)),
                       betaln(a, b))
    log_probs = lax.add(combiln, beta_lns)
    y_cond = logical_or(lax.lt(y, lax.neg(loc)), lax.gt(y, lax.sub(n, loc)))
    log_probs = where(y_cond, -inf, log_probs)
    n_a_b_cond = logical_or(logical_or(lax.lt(n, one), lax.lt(a, zero)),
                            lax.lt(b, zero))
    return where(n_a_b_cond, nan, log_probs)
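A quick check of the branches above: with a = b = 1 the beta-binomial collapses to a uniform distribution on {0, ..., n}, so every log-probability equals -log(n + 1). A minimal sketch, assuming the module is exposed as jax.scipy.stats.betabinom:

from jax.scipy.stats import betabinom

print(betabinom.logpmf(2, 4, 1.0, 1.0))  # ~ -1.6094 (= -log(5))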
Example #13
def logpmf(k, n, p, loc=0):
    """JAX implementation of scipy.stats.nbinom.logpmf."""
    k, n, p, loc = _promote_args_inexact("nbinom.logpmf", k, n, p, loc)
    one = lax._const(k, 1)
    y = lax.sub(k, loc)
    comb_term = lax.sub(lax.sub(gammaln(lax.add(y, n)), gammaln(n)),
                        gammaln(lax.add(y, one)))
    log_linear_term = lax.add(xlogy(n, p), xlogy(y, lax.sub(one, p)))
    log_probs = lax.add(comb_term, log_linear_term)
    return where(lax.lt(k, loc), -inf, log_probs)
Example #14
def logpdf(x, a, b, loc=0, scale=1):
    x, a, b, loc, scale = _promote_args_inexact("beta.logpdf", x, a, b, loc,
                                                scale)
    one = lax._const(x, 1)
    shape_term = lax.neg(betaln(a, b))
    y = lax.div(lax.sub(x, loc), scale)
    log_linear_term = lax.add(xlogy(lax.sub(a, one), y),
                              xlog1py(lax.sub(b, one), lax.neg(y)))
    log_probs = lax.sub(lax.add(shape_term, log_linear_term), lax.log(scale))
    return where(logical_or(lax.gt(x, lax.add(loc, scale)), lax.lt(x, loc)),
                 -inf, log_probs)
Example #15
def logpdf(x, alpha):
  x, alpha = _promote_dtypes_inexact(x, alpha)
  if alpha.ndim != 1:
    raise ValueError(
      f"`alpha` must be one-dimensional; got alpha.shape={alpha.shape}"
    )
  if x.shape[0] not in (alpha.shape[0], alpha.shape[0] - 1):
    raise ValueError(
      "`x` must have either the same number of entries as `alpha` "
      f"or one entry fewer; got x.shape={x.shape}, alpha.shape={alpha.shape}"
    )
  one = lax._const(x, 1)
  if x.shape[0] != alpha.shape[0]:
    x = jnp.concatenate([x, lax.sub(one, x.sum(0, keepdims=True))], axis=0)
  normalize_term = jnp.sum(gammaln(alpha)) - gammaln(jnp.sum(alpha))
  if x.ndim > 1:
    alpha = lax.broadcast_in_dim(alpha, alpha.shape + (1,) * (x.ndim - 1), (0,))
  log_probs = lax.sub(jnp.sum(xlogy(lax.sub(alpha, one), x), axis=0), normalize_term)
  return jnp.where(_is_simplex(x), log_probs, -jnp.inf)
Example #16
def logpdf(x, loc=0, scale=1):
    x, loc, scale = _promote_args_inexact("norm.logpdf", x, loc, scale)
    scale_sqrd = lax.square(scale)
    log_normalizer = lax.log(lax.mul(lax._const(x, 2 * np.pi), scale_sqrd))
    quadratic = lax.div(lax.square(lax.sub(x, loc)), scale_sqrd)
    return lax.div(lax.add(log_normalizer, quadratic), lax._const(x, -2))
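A one-line check of the normalizer: at x = loc = 0 with scale = 1 the expression reduces to -0.5 * log(2 * pi):

from jax.scipy.stats import norm

print(norm.logpdf(0.0))  # ~ -0.9189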
Example #17
def entr(x):
    x, = _promote_args_inexact("entr", x)
    return lax.select(lax.lt(x, lax._const(x, 0)), lax.full_like(x, -np.inf),
                      lax.neg(xlogy(x, x)))
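The select above implements entr(x) = -x * log(x) on the nonnegative half-line and -inf for negative inputs; a minimal sketch:

from jax.scipy.special import entr

print(entr(1.0))   # 0.0 (= -1 * log(1))
print(entr(0.5))   # ~ 0.3466 (= 0.5 * log(2))
print(entr(-1.0))  # -inf for negative inputs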
Example #18
def deriv_prop(prim, deriv, primals_in, series_in):
  x, = primals_in
  series, = series_in
  primal_out = prim.bind(x)
  c0, cs = jet(deriv, primals_in, series_in)
  c = [c0] + cs
  u = [x] + series
  v = [primal_out] + [None] * len(series)
  for k in range(1, len(v)):
    v[k] = fact(k-1) * sum(_scale(k, j) * c[k-j] * u[j] for j in range(1, k + 1))
  primal_out, *series_out = v
  return primal_out, series_out


def_deriv(lax.erf_p, lambda x: lax.mul(lax._const(x, 2. / np.sqrt(np.pi)), lax.exp(lax.neg(lax.square(x)))))


def def_comp(prim, comp):
  """
  Define the jet rule for a primitive in terms of a composition of simpler primitives.
  """
  jet_rules[prim] = partial(jet, comp)


def_comp(lax.expm1_p, lambda x: lax.exp(x) - 1)
def_comp(lax.log1p_p, lambda x: lax.log(1 + x))
def_comp(lax.sqrt_p, lambda x: x ** 0.5)
def_comp(lax.rsqrt_p, lambda x: x ** -0.5)
def_comp(lax.asinh_p, lambda x: lax.log(x + lax.sqrt(lax.square(x) + 1)))
def_comp(lax.acosh_p, lambda x: lax.log(x + lax.sqrt(lax.square(x) - 1)))
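The def_deriv/def_comp registrations above feed the Taylor-mode machinery exposed as jax.experimental.jet.jet, which propagates a truncated Taylor series through a function. A minimal sketch of that public entry point (the argument layout follows the jet documentation; treat it as illustrative):

import jax.numpy as jnp
from jax.experimental.jet import jet

# Push a three-term Taylor series of the input through sin; the outputs are the
# primal value and the Taylor coefficients of the composition.
primal_out, series_out = jet(jnp.sin, (0.0,), [[1.0, 0.0, 0.0]])
print(primal_out)     # 0.0 (= sin(0))
print(series_out[0])  # ~ 1.0 (first-order coefficient, = cos(0))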
Example #19
def expit(x):
    x = asarray(x)
    one = lax._const(x, 1)
    return lax.div(one, lax.add(one, lax.exp(lax.neg(x))))
Example #20
@_wraps(osp_special.erfinv)
def erfinv(x):
    x, = _promote_args_inexact("erfinv", x)
    return lax.erf_inv(x)


@api.custom_jvp
@_wraps(osp_special.logit, update_doc=False)
def logit(x):
    x = asarray(x)
    return lax.log(lax.div(x, lax.sub(lax._const(x, 1), x)))


logit.defjvps(
    lambda g, ans, x: lax.div(g, lax.mul(x, lax.sub(lax._const(x, 1), x))))


@api.custom_jvp
@_wraps(osp_special.expit, update_doc=False)
def expit(x):
    x = asarray(x)
    one = lax._const(x, 1)
    return lax.div(one, lax.add(one, lax.exp(lax.neg(x))))


expit.defjvps(lambda g, ans, x: g * ans * (lax._const(ans, 1) - ans))


@_wraps(osp_special.logsumexp)
def logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False):
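The custom_jvp registrations above give logit and expit closed-form derivative rules; for expit the rule means grad(expit)(x) evaluates to expit(x) * (1 - expit(x)). A minimal check through the public API:

import jax
from jax.scipy.special import expit

print(jax.grad(expit)(0.0))  # 0.25 (= 0.5 * (1 - 0.5))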
Example #21
def _norm_logpdf(x):
    neg_half = lax._const(x, -0.5)
    log_normalizer = lax._const(x, _norm_logpdf_constant)
    return lax.sub(lax.mul(neg_half, lax.square(x)), log_normalizer)
Example #22
File: jet.py Project: 0x0is1/jax
def deriv_prop(prim, deriv, primals_in, series_in):
    x, = primals_in
    series, = series_in
    primal_out = prim.bind(x)
    c0, cs = jet(deriv, primals_in, series_in)
    c = [c0] + cs
    u = [x] + series
    v = [primal_out] + [None] * len(series)
    for k in range(1, len(v)):
        v[k] = fact(k - 1) * sum(
            _scale(k, j) * c[k - j] * u[j] for j in range(1, k + 1))
    primal_out, *series_out = v
    return primal_out, series_out


def_deriv(
    lax.erf_p, lambda x: lax.mul(lax._const(x, 2. / np.sqrt(np.pi)),
                                 lax.exp(lax.neg(lax.square(x)))))


def def_comp(prim, comp):
    """
  Define the jet rule for a primitive in terms of a composition of simpler primitives.
  """
    jet_rules[prim] = partial(jet, comp)


def_comp(lax.expm1_p, lambda x: lax.exp(x) - 1)
def_comp(lax.log1p_p, lambda x: lax.log(1 + x))
def_comp(lax.sqrt_p, lambda x: x**0.5)
def_comp(lax.rsqrt_p, lambda x: x**-0.5)
def_comp(lax.asinh_p, lambda x: lax.log(x + lax.sqrt(lax.square(x) + 1)))
Example #23
def logpmf(k, mu, loc=0):
    k, mu, loc = jnp._promote_args_inexact("poisson.logpmf", k, mu, loc)
    zero = lax._const(k, 0)
    x = lax.sub(k, loc)
    log_probs = xlogy(x, mu) - gammaln(x + 1) - mu
    return jnp.where(lax.lt(x, zero), -jnp.inf, log_probs)
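A quick check of the formula above: at k = 0 the Poisson log-pmf is simply -mu, and values below the support map to -inf:

from jax.scipy.stats import poisson

print(poisson.logpmf(0, 3.0))   # -3.0
print(poisson.logpmf(-1, 3.0))  # -inf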
Example #24
def expn_jvp(n, primals, tangents):
    (x, ), (x_dot, ) = primals, tangents
    return expn(n, x), lax.mul(lax.neg(x_dot),
                               expn(lax.sub(n, lax._const(n, 1)), x))
Example #25
def logit(x):
    x = asarray(x)
    return lax.log(lax.div(x, lax.sub(lax._const(x, 1), x)))
Example #26
def cdf(k, mu, loc=0):
    k, mu, loc = jnp._promote_args_inexact("poisson.cdf", k, mu, loc)
    zero = lax._const(k, 0)
    x = lax.sub(k, loc)
    p = gammaincc(jnp.floor(1 + x), mu)
    return jnp.where(lax.lt(x, zero), zero, p)
Example #27
def norm(x,
         ord=None,
         axis: Union[None, Tuple[int, ...], int] = None,
         keepdims=False):
    x = _promote_arg_dtypes(jnp.asarray(x))
    x_shape = jnp.shape(x)
    ndim = len(x_shape)

    if axis is None:
        # NumPy has an undocumented behavior that admits arbitrary rank inputs if
        # `ord` is None: https://github.com/numpy/numpy/issues/14215
        if ord is None:
            return jnp.sqrt(
                jnp.sum(jnp.real(x * jnp.conj(x)), keepdims=keepdims))
        axis = tuple(range(ndim))
    elif isinstance(axis, tuple):
        axis = tuple(canonicalize_axis(x, ndim) for x in axis)
    else:
        axis = (canonicalize_axis(axis, ndim), )

    num_axes = len(axis)
    if num_axes == 1:
        if ord is None or ord == 2:
            return jnp.sqrt(
                jnp.sum(jnp.real(x * jnp.conj(x)),
                        axis=axis,
                        keepdims=keepdims))
        elif ord == jnp.inf:
            return jnp.amax(jnp.abs(x), axis=axis, keepdims=keepdims)
        elif ord == -jnp.inf:
            return jnp.amin(jnp.abs(x), axis=axis, keepdims=keepdims)
        elif ord == 0:
            return jnp.sum(x != 0,
                           dtype=jnp.finfo(lax.dtype(x)).dtype,
                           axis=axis,
                           keepdims=keepdims)
        elif ord == 1:
            # Numpy has a special case for ord == 1 as an optimization. We don't
            # really need the optimization (XLA could do it for us), but the Numpy
            # code has slightly different type promotion semantics, so we need a
            # special case too.
            return jnp.sum(jnp.abs(x), axis=axis, keepdims=keepdims)
        else:
            abs_x = jnp.abs(x)
            ord = lax._const(abs_x, ord)
            out = jnp.sum(abs_x**ord, axis=axis, keepdims=keepdims)
            return jnp.power(out, 1. / ord)

    elif num_axes == 2:
        row_axis, col_axis = cast(Tuple[int, ...], axis)
        if ord is None or ord in ('f', 'fro'):
            return jnp.sqrt(
                jnp.sum(jnp.real(x * jnp.conj(x)),
                        axis=axis,
                        keepdims=keepdims))
        elif ord == 1:
            if not keepdims and col_axis > row_axis:
                col_axis -= 1
            return jnp.amax(jnp.sum(jnp.abs(x),
                                    axis=row_axis,
                                    keepdims=keepdims),
                            axis=col_axis,
                            keepdims=keepdims)
        elif ord == -1:
            if not keepdims and col_axis > row_axis:
                col_axis -= 1
            return jnp.amin(jnp.sum(jnp.abs(x),
                                    axis=row_axis,
                                    keepdims=keepdims),
                            axis=col_axis,
                            keepdims=keepdims)
        elif ord == jnp.inf:
            if not keepdims and row_axis > col_axis:
                row_axis -= 1
            return jnp.amax(jnp.sum(jnp.abs(x),
                                    axis=col_axis,
                                    keepdims=keepdims),
                            axis=row_axis,
                            keepdims=keepdims)
        elif ord == -jnp.inf:
            if not keepdims and row_axis > col_axis:
                row_axis -= 1
            return jnp.amin(jnp.sum(jnp.abs(x),
                                    axis=col_axis,
                                    keepdims=keepdims),
                            axis=row_axis,
                            keepdims=keepdims)
        elif ord in ('nuc', 2, -2):
            x = jnp.moveaxis(x, axis, (-2, -1))
            if ord == 2:
                reducer = jnp.amax
            elif ord == -2:
                reducer = jnp.amin
            else:
                reducer = jnp.sum
            y = reducer(svd(x, compute_uv=False), axis=-1)
            if keepdims:
                result_shape = list(x_shape)
                result_shape[axis[0]] = 1
                result_shape[axis[1]] = 1
                y = jnp.reshape(y, result_shape)
            return y
        else:
            raise ValueError("Invalid order '{}' for matrix norm.".format(ord))
    else:
        raise ValueError(
            "Invalid axis values ({}) for jnp.linalg.norm.".format(axis))
Example #28
def logpdf(x, loc=0, scale=1):
    x, loc, scale = _promote_args_inexact("laplace.logpdf", x, loc, scale)
    two = lax._const(x, 2)
    linear_term = lax.div(lax.abs(lax.sub(x, loc)), scale)
    return lax.neg(lax.add(linear_term, lax.log(lax.mul(two, scale))))
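A final sanity check: for the standard Laplace distribution the expression reduces to logpdf(x) = -|x| - log(2):

from jax.scipy.stats import laplace

print(laplace.logpdf(0.0))  # ~ -0.6931 (= -log(2))
print(laplace.logpdf(1.0))  # ~ -1.6931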