Ejemplo n.º 1
def multigammaln(a, d):
    """Log of the multivariate gamma function, evaluated elementwise over ``a``.

    Computes
        log Gamma_d(a) = d * (d - 1) / 4 * log(pi) + sum_{j=0}^{d-1} gammaln(a - j / 2).

    Args:
      a: array-like of arguments.
      d: dimension of the multivariate gamma function; must be a concrete
        Python ``int`` because it sizes an ``arange``.

    Returns:
      An array with the (promoted, inexact) shape of ``a``.
    """
    d = core.concrete_or_error(int, d, "d argument of multigammaln")
    a, d_ = _promote_args_inexact("multigammaln", a, d)

    # d * (d - 1) / 4 * log(pi)
    constant = lax.mul(
        lax.mul(lax.mul(_constant_like(a, 0.25), d_),
                lax.sub(d_, _constant_like(a, 1))),
        lax.log(_constant_like(a, np.pi)))
    # Broadcast `a` against the d offsets j / 2 along a new trailing axis, then
    # reduce that axis away so the result keeps the shape of `a`.
    res = jnp.sum(gammaln(
        jnp.expand_dims(a, axis=-1) -
        lax.div(jnp.arange(d, dtype=d_.dtype), _constant_like(a, 2))),
                  axis=-1)
    return res + constant
Ejemplo n.º 2
def logpdf(x, df, loc=0, scale=1):
    """Log of the chi-square density, mirroring ``scipy.stats.chi2.logpdf``.

    With ``y = (x - loc) / scale`` this evaluates
        (df/2 - 1) * log(y) - y/2 - lgamma(df/2) - (df/2) * log(2) - log(scale)
    and returns ``-inf`` where ``x < loc``.
    """
    # Function-scope import: the original block did not have xlogy in scope.
    from jax.scipy.special import xlogy

    x, df, loc, scale = _promote_args_inexact("chi2.logpdf", x, df, loc, scale)
    one = _constant_like(x, 1)
    two = _constant_like(x, 2)
    y = lax.div(lax.sub(x, loc), scale)
    df_on_two = lax.div(df, two)

    # xlogy treats 0 * log(0) as 0, so df == 2 at x == loc stays finite.
    # The plain product (df/2 - 1) * log(y) would evaluate 0 * -inf = nan.
    kernel = lax.sub(xlogy(lax.sub(df_on_two, one), y), lax.div(y, two))

    nrml_cnst = lax.neg(
        lax.add(lax.lgamma(df_on_two), lax.div(lax.mul(lax.log(two), df), two)))

    log_probs = lax.add(lax.sub(nrml_cnst, lax.log(scale)), kernel)
    return where(lax.lt(x, loc), -inf, log_probs)
Ejemplo n.º 3
def logpdf(x, loc=0, scale=1):
    """Cauchy log-density: ``-log(pi * scale) - log1p(((x - loc) / scale) ** 2)``."""
    x, loc, scale = _promote_args_inexact("cauchy.logpdf", x, loc, scale)
    z = lax.div(lax.sub(x, loc), scale)
    log_tail = lax.log1p(lax.mul(z, z))
    log_norm = lax.log(lax.mul(_constant_like(x, np.pi), scale))
    return lax.neg(lax.add(log_norm, log_tail))
Ejemplo n.º 4
Archivo: t.py Proyecto: yashk2810/jax
def logpdf(x, df, loc=0, scale=1):
    """Log of Student's t density, mirroring ``scipy.stats.t.logpdf``.

    With ``z = (x - loc) / scale`` this evaluates
        lgamma((df + 1) / 2) - lgamma(df / 2) - 0.5 * log(pi * df * scale**2)
        - (df + 1) / 2 * log1p(z**2 / df).
    """
    x, df, loc, scale = _promote_args_inexact("t.logpdf", x, df, loc, scale)
    two = _constant_like(x, 2)
    scaled_x = lax.div(lax.sub(x, loc), scale)
    df_over_two = lax.div(df, two)
    df_plus_one_over_two = lax.add(df_over_two, _constant_like(x, 0.5))
    normalize_term_const = lax.mul(lax.mul(scale, scale),
                                   _constant_like(x, np.pi))
    # 0.5 * log(pi * df * scale**2)
    normalize_term_tmp = lax.div(lax.log(lax.mul(normalize_term_const, df)),
                                 two)
    # Negative log-normalizer: lgamma(df/2) + 0.5*log(...) - lgamma((df+1)/2).
    normalize_term = lax.sub(
        lax.add(lax.lgamma(df_over_two), normalize_term_tmp),
        lax.lgamma(df_plus_one_over_two))
    quadratic = lax.div(lax.mul(scaled_x, scaled_x), df)
    return lax.neg(
        lax.add(normalize_term,
                lax.mul(df_plus_one_over_two, lax.log1p(quadratic))))
Ejemplo n.º 5
def _eval_expint_k(A, B, x):
    """Rational-approximation helper used for the later expint intervals.

    Evaluates ``exp(x) / x * (1 + P(1/x) / (x * Q(1/x)))``. Here P and Q are the
    polynomials whose coefficients (in ``jnp.polyval`` order) are ``A`` and ``B``.
    """
    coeffs_num = jnp.array(A, dtype=x.dtype)
    coeffs_den = jnp.array(B, dtype=x.dtype)
    one = _constant_like(x, 1.0)
    inv_x = one / x
    ratio = jnp.polyval(coeffs_num, inv_x) / jnp.polyval(coeffs_den, inv_x)
    correction = inv_x * ratio + one
    return jnp.exp(x) * inv_x * correction
Ejemplo n.º 6
def logpdf(x, b, loc=0, scale=1):
    """Pareto log-density: ``-(log(scale / b) + (b + 1) * log((x - loc) / scale))``.

    Returns ``-inf`` below the support, i.e. where ``x < loc + scale``.
    """
    x, b, loc, scale = _promote_args_inexact("pareto.logpdf", x, b, loc, scale)
    z = lax.div(lax.sub(x, loc), scale)
    tail = lax.mul(lax.add(b, _constant_like(x, 1)), lax.log(z))
    log_norm = lax.log(lax.div(scale, b))
    below_support = lax.lt(x, lax.add(loc, scale))
    return where(below_support, -inf, lax.neg(lax.add(log_norm, tail)))
Ejemplo n.º 7
def logpdf(x, a, loc=0, scale=1):
    """Gamma log-density with shape ``a``.

    With ``y = (x - loc) / scale`` this returns
    ``xlogy(a - 1, y) - y - gammaln(a) - log(scale)``, and ``-inf`` where
    ``x < loc``.
    """
    x, a, loc, scale = _promote_args_inexact("gamma.logpdf", x, a, loc, scale)
    y = lax.div(lax.sub(x, loc), scale)
    a_minus_one = lax.sub(a, _constant_like(x, 1))
    log_linear_term = lax.sub(xlogy(a_minus_one, y), y)
    log_probs = lax.sub(log_linear_term, lax.add(gammaln(a), lax.log(scale)))
    return where(lax.lt(x, loc), -inf, log_probs)
Ejemplo n.º 8
def logpmf(k, n, a, b, loc=0):
    """JAX implementation of scipy.stats.betabinom.logpmf.

    With ``y = floor(k) - loc`` this evaluates
        log C(n, y) + betaln(y + a, n - y + b) - betaln(a, b),
    using ``log C(n, y) = -log(n + 1) - betaln(n - y + 1, y + 1)``.
    Returns ``-inf`` outside the support ``0 <= y <= n``. Returns ``nan``
    where ``n < 1``, ``a < 0`` or ``b < 0``.
    """
    k, n, a, b, loc = _promote_args_inexact("betabinom.logpmf", k, n, a, b,
                                            loc)
    y = lax.sub(lax.floor(k), loc)
    one = _constant_like(y, 1)
    zero = _constant_like(y, 0)
    combiln = lax.neg(
        lax.add(lax.log1p(n),
                betaln(lax.add(lax.sub(n, y), one), lax.add(y, one))))
    beta_lns = lax.sub(betaln(lax.add(y, a), lax.add(lax.sub(n, y), b)),
                       betaln(a, b))
    log_probs = lax.add(combiln, beta_lns)
    # y is already shifted by loc, so the support is 0 <= y <= n whatever loc
    # is. Comparing y against -loc and n - loc would test floor(k) instead.
    y_cond = logical_or(lax.lt(y, zero), lax.gt(y, n))
    log_probs = where(y_cond, -inf, log_probs)
    n_a_b_cond = logical_or(logical_or(lax.lt(n, one), lax.lt(a, zero)),
                            lax.lt(b, zero))
    return where(n_a_b_cond, nan, log_probs)
Ejemplo n.º 9
def logpdf(x, alpha):
    """Dirichlet log-density over the last axis.

    Evaluates
        sum((alpha - 1) * log(x), -1) - (sum(gammaln(alpha), -1) - gammaln(sum(alpha, -1)))
    and returns ``-inf`` where ``_is_simplex(x)`` is false.
    """
    # Pick the result dtype by evaluating osp_stats.dirichlet.logpdf on tiny
    # dummy arrays of the input dtypes (osp_stats presumably scipy.stats).
    args = (np.ones((0, ), lax.dtype(x)), np.ones((1, ), lax.dtype(alpha)))
    to_dtype = lax.dtype(osp_stats.dirichlet.logpdf(*args))
    x, alpha = [lax.convert_element_type(arg, to_dtype) for arg in (x, alpha)]
    one = jnp._constant_like(x, 1)
    normalize_term = jnp.sum(gammaln(alpha), axis=-1) - gammaln(
        jnp.sum(alpha, axis=-1))
    log_probs = lax.sub(jnp.sum(xlogy(lax.sub(alpha, one), x), axis=-1),
                        normalize_term)
    return jnp.where(_is_simplex(x), log_probs, -jnp.inf)
Ejemplo n.º 10
def logpmf(k, n, p, loc=0):
    """JAX implementation of scipy.stats.nbinom.logpmf.

    With ``y = k - loc`` this evaluates
        gammaln(y + n) - gammaln(n) - gammaln(y + 1) + n*log(p) + y*log(1 - p)
    and returns ``-inf`` where ``k < loc``.
    """
    k, n, p, loc = _promote_args_inexact("nbinom.logpmf", k, n, p, loc)
    one = _constant_like(k, 1)
    y = lax.sub(k, loc)
    # log of the binomial-like coefficient Gamma(y + n) / (Gamma(n) * y!).
    comb_term = lax.sub(
        lax.sub(gammaln(lax.add(y, n)), gammaln(n)), gammaln(lax.add(y, one)))
    log_linear_term = lax.add(xlogy(n, p), xlogy(y, lax.sub(one, p)))
    log_probs = lax.add(comb_term, log_linear_term)
    return where(lax.lt(k, loc), -inf, log_probs)
Ejemplo n.º 11
def logpdf(x, a, b, loc=0, scale=1):
    """Beta log-density, mirroring ``scipy.stats.beta.logpdf``.

    With ``y = (x - loc) / scale`` this evaluates
        (a - 1) * log(y) + (b - 1) * log(1 - y) - betaln(a, b) - log(scale)
    and returns ``-inf`` outside ``[loc, loc + scale]``.
    """
    x, a, b, loc, scale = _promote_args_inexact("beta.logpdf", x, a, b, loc,
                                                scale)
    one = _constant_like(x, 1)
    shape_term = lax.neg(betaln(a, b))
    y = lax.div(lax.sub(x, loc), scale)
    # xlogy / xlog1py give 0 rather than nan for a zero coefficient at the
    # endpoints y == 0 and y == 1.
    log_linear_term = lax.add(xlogy(lax.sub(a, one), y),
                              xlog1py(lax.sub(b, one), lax.neg(y)))
    log_probs = lax.sub(lax.add(shape_term, log_linear_term), lax.log(scale))
    return where(logical_or(lax.gt(x, lax.add(loc, scale)), lax.lt(x, loc)),
                 -inf, log_probs)
Ejemplo n.º 12
def _lu_jvp_rule(primals, tangents):
    """JVP rule for the LU decomposition primitive ``lu_p``.

    Returns the primal outputs ``(lu, pivots, permutation)`` and their tangents.
    Only the packed ``lu`` factor gets a real tangent. ``pivots`` and
    ``permutation`` get symbolic zero tangents.
    """
    a, = primals
    a_dot, = tangents
    lu, pivots, permutation = lu_p.bind(a)

    a_shape = jnp.shape(a)
    m, n = a_shape[-2:]
    dtype = lax.dtype(a)
    k = min(m, n)

    # Apply the row permutation to the tangent (P . A'), independently for
    # every batch element.
    batch_dims = a_shape[:-2]
    iotas = jnp.ix_(*(lax.iota(jnp.int32, b) for b in batch_dims + (1, )))
    x = a_dot[iotas[:-1] + (permutation, slice(None))]

    # Differentiation of Matrix Functionals Using Triangular Factorization
    # F. R. De Hoog, R. S. Anderssen, and M. A. Lukas
    #     LU = A
    # ==> L'U + LU' = A'
    # ==> inv(L) . L' + U' . inv(U) = inv(L) A' inv(U)
    # ==> L' = L . tril(inv(L) . A' . inv(U), -1)
    #     U' = triu(inv(L) . A' . inv(U)) . U

    ndims = len(a_shape)
    # L (m x m): strict lower triangle of the first k columns of the packed
    # factor, zero-padded on the right, plus the implicit unit diagonal.
    l_padding = [(0, 0, 0)] * ndims
    l_padding[-1] = (0, m - k, 0)
    zero = jnp._constant_like(lu, 0)
    l = lax.pad(jnp.tril(lu[..., :, :k], -1), zero, l_padding)
    l = l + jnp.eye(m, m, dtype=dtype)

    # U (n x n): upper triangle of the first k rows of the packed factor,
    # zero-padded below. The bottom-right (n - k) x (n - k) block is set to the
    # identity so the right solve below is well defined when m < n.
    u_eye = lax.pad(jnp.eye(n - k, n - k, dtype=dtype), zero,
                    ((k, 0, 0), (k, 0, 0)))
    u_padding = [(0, 0, 0)] * ndims
    u_padding[-2] = (0, n - k, 0)
    u = lax.pad(jnp.triu(lu[..., :k, :]), zero, u_padding) + u_eye

    # la = inv(L) . (P A')  (left solve, unit lower triangular)
    la = triangular_solve(l, x, left_side=True, transpose_a=False, lower=True,
                          unit_diagonal=True)
    # lau = la . inv(U)     (right solve, upper triangular)
    lau = triangular_solve(u, la, left_side=False, transpose_a=False,
                           lower=False)

    l_dot = jnp.matmul(l, jnp.tril(lau, -1))
    u_dot = jnp.matmul(jnp.triu(lau), u)
    lu_dot = l_dot + u_dot
    return (lu, pivots, permutation), (lu_dot, ad_util.Zero.from_value(pivots),
                                       ad_util.Zero.from_value(permutation))
Ejemplo n.º 13
def logpdf(x, alpha):
  """Dirichlet log-density of ``x`` given a 1-D concentration vector ``alpha``.

  Components run along axis 0 of ``x``, and any further axes of ``x`` are
  broadcast against ``alpha``. If ``x`` has one entry fewer than ``alpha``,
  the last component is filled in as ``1 - x.sum(0)``. Returns ``-inf`` where
  ``_is_simplex(x)`` is false.

  Raises:
    ValueError: if ``alpha`` is not one-dimensional, or if ``x.shape[0]`` is
      neither ``alpha.shape[0]`` nor ``alpha.shape[0] - 1``.
  """
  x, alpha = _promote_dtypes_inexact(x, alpha)
  if alpha.ndim != 1:
    raise ValueError(
      f"`alpha` must be one-dimensional; got alpha.shape={alpha.shape}"
    )
  if x.shape[0] not in (alpha.shape[0], alpha.shape[0] - 1):
    raise ValueError(
      "`x` must have either the same number of entries as `alpha` "
      f"or one entry fewer; got x.shape={x.shape}, alpha.shape={alpha.shape}"
    )
  one = jnp._constant_like(x, 1)
  if x.shape[0] != alpha.shape[0]:
    # Complete the implicit last coordinate of the simplex point.
    x = jnp.concatenate([x, lax.sub(one, x.sum(0, keepdims=True))], axis=0)
  normalize_term = jnp.sum(gammaln(alpha)) - gammaln(jnp.sum(alpha))
  if x.ndim > 1:
    # Append singleton axes so alpha broadcasts over x's trailing axes.
    alpha = lax.broadcast_in_dim(alpha, alpha.shape + (1,) * (x.ndim - 1), (0,))
  log_probs = lax.sub(jnp.sum(xlogy(lax.sub(alpha, one), x), axis=0), normalize_term)
  return jnp.where(_is_simplex(x), log_probs, -jnp.inf)
Ejemplo n.º 14
def _norm_logpdf(x):
    """Return ``-0.5 * x**2 - _norm_logpdf_constant`` elementwise.

    This is the standard-normal log-density, presuming the module-level
    ``_norm_logpdf_constant`` is ``log(sqrt(2 * pi))``.
    """
    half_sq = lax.mul(_constant_like(x, -0.5), lax.square(x))
    return lax.sub(half_sq, _constant_like(x, _norm_logpdf_constant))
Ejemplo n.º 15
def entr(x):
    """Elementwise entropy term ``-xlogy(x, x)``; ``-inf`` where ``x < 0``."""
    x, = _promote_args_inexact("entr", x)
    is_negative = lax.lt(x, _constant_like(x, 0))
    neg_inf = lax.full_like(x, -np.inf)
    return lax.select(is_negative, neg_inf, lax.neg(xlogy(x, x)))
Ejemplo n.º 16
 def phi(X):
     """Lower triangle of ``X`` with its diagonal entries halved."""
     # 1 off the diagonal, 2 on it.
     denom = jnp._constant_like(X, 1) + jnp.eye(X.shape[-1], dtype=X.dtype)
     return jnp.tril(X) / denom
Ejemplo n.º 17
def logpdf(x, loc=0, scale=1):
    """Normal log-density: ``(log(2*pi*scale**2) + (x - loc)**2 / scale**2) / -2``."""
    x, loc, scale = _promote_args_inexact("norm.logpdf", x, loc, scale)
    variance = lax.square(scale)
    two_pi = _constant_like(x, 2 * np.pi)
    log_normalizer = lax.log(lax.mul(two_pi, variance))
    centered_sq = lax.square(lax.sub(x, loc))
    quadratic = lax.div(centered_sq, variance)
    minus_two = _constant_like(x, -2)
    return lax.div(lax.add(log_normalizer, quadratic), minus_two)
Ejemplo n.º 18
def logpmf(k, mu, loc=0):
    """Poisson log-pmf ``xlogy(y, mu) - gammaln(y + 1) - mu`` with ``y = k - loc``.

    Returns ``-inf`` where ``y < 0``.
    """
    k, mu, loc = jnp._promote_args_inexact("poisson.logpmf", k, mu, loc)
    y = lax.sub(k, loc)
    log_probs = xlogy(y, mu) - gammaln(y + 1) - mu
    below_support = lax.lt(y, jnp._constant_like(k, 0))
    return jnp.where(below_support, -jnp.inf, log_probs)
Ejemplo n.º 19
def cdf(k, mu, loc=0):
    """Poisson CDF ``P(K <= k)``, mirroring ``scipy.stats.poisson.cdf``.

    Uses ``P(K <= k) = gammaincc(floor(k - loc) + 1, mu)`` and returns 0 where
    ``k < loc``.
    """
    # Pass this function's own name; it previously said "poisson.logpmf".
    k, mu, loc = jnp._promote_args_inexact("poisson.cdf", k, mu, loc)
    zero = jnp._constant_like(k, 0)
    x = lax.sub(k, loc)
    p = gammaincc(jnp.floor(1 + x), mu)
    return jnp.where(lax.lt(x, zero), zero, p)
Ejemplo n.º 20
def logpdf(x, loc=0, scale=1):
    """Laplace log-density: ``-(|x - loc| / scale + log(2 * scale))``."""
    x, loc, scale = _promote_args_inexact("laplace.logpdf", x, loc, scale)
    abs_dev = lax.abs(lax.sub(x, loc))
    log_norm = lax.log(lax.mul(_constant_like(x, 2), scale))
    return lax.neg(lax.add(lax.div(abs_dev, scale), log_norm))
Ejemplo n.º 21
def expn_jvp(n, primals, tangents):
    """JVP of ``expn(n, x)`` in ``x``, using ``d/dx E_n(x) = -E_{n-1}(x)``."""
    x, = primals
    x_dot, = tangents
    n_minus_one = lax.sub(n, _constant_like(n, 1))
    tangent_out = lax.mul(lax.neg(x_dot), expn(n_minus_one, x))
    return expn(n, x), tangent_out
Ejemplo n.º 22
 def make_constant(c):
     """Wrap ``c`` via ``_constant_like`` against the enclosing ``tensor``, then cast.

     ``tensor`` and ``dtype`` are free variables taken from the enclosing scope.
     """
     const = lax_numpy._constant_like(tensor, c)
     return const.astype(dtype)
Ejemplo n.º 23
def softmax(attn_weights, norm_dims, dtype, softmax_hparams, quant_context):
    """Normalizes attention.

    One of three paths is chosen from ``softmax_hparams``:
      * ``None`` or ``SoftmaxHParams(None, None, None)``: exact softmax over
        ``norm_dims``, computed in the log domain.
      * ``quant_hparams`` set: ``lax.cond`` on ``quant_context.quantize_acts``
        selects between a quantized exp(x - max) / sum softmax and the exact
        one.
      * otherwise: max-subtracted softmax built from the ``exponential`` and
        ``reciprocal`` approximations configured by ``exp_hparams`` and
        ``reciprocal_hparams``.

    Args:
      attn_weights: attention logits to normalize.
      norm_dims: axes to normalize over.
      dtype: dtype of the returned array (every path ends in ``.astype(dtype)``).
      softmax_hparams: softmax configuration, or ``None``.
      quant_context: read only for its ``quantize_acts`` flag on the quantized
        path.

    Returns:
      The normalized weights, cast to ``dtype``.

    NOTE(review): several multi-line calls in this copy are truncated, and
    their remaining arguments are missing. Affected: ``QuantOps.FloatQuant``,
    ``QuantOps.create_symmetric_fp``, ``quantization.quantized_sum``, the
    ``sum_exp`` / ``inv_sum_exp`` quantizations and both ``lax.reduce`` calls.
    As shown, the function does not parse; restore those lines from upstream.
    """
    a = attn_weights

    # Exact softmax: exp(a - logsumexp(a)) over norm_dims.
    def unquantized_softmax(a):
        a = lax.exp(
            a - jax.scipy.special.logsumexp(a, axis=norm_dims, keepdims=True))
        return a.astype(dtype)

    # Quantize intermediate activations with QuantOps.
    # Currently only supports unscaled floating-point formats.
    def quantized_softmax(a):
        # We compute softmax as exp(x-max(x))/sum_i(exp(x_i-max(x))), quantizing
        # intermediate values. Note this differs from the log-domain
        # implementation of softmax used above.
        quant_hparams = softmax_hparams.quant_hparams
        # NOTE(review): the argument lists of the next two calls are truncated.
        fp_quant_config = QuantOps.FloatQuant(is_scaled=False,
        quant_ops = QuantOps.create_symmetric_fp(fp_quant=fp_quant_config,

        a = quant_ops.to_quantized(a, dtype=dtype)
        # Note that the max of a quantized vector is necessarily also quantized to
        # the same precision since the max of a vector must be an existing element
        # of the vector, so we don't need to explicitly insert a quantization
        # operator to the output of the max reduction.
        a_max = jnp.max(a, axis=norm_dims, keepdims=True)
        a_minus_max = quant_ops.to_quantized(a - a_max, dtype=dtype)
        a_exp = quant_ops.to_quantized(jnp.exp(a_minus_max), dtype=dtype)

        # NOTE(review): the next two calls are truncated; presumably the sum
        # reduces a_exp over norm_dims — confirm against upstream.
        sum_exp_quantized_reduction = quantization.quantized_sum(
        sum_exp = quant_ops.to_quantized(sum_exp_quantized_reduction,

        # NOTE(review): truncated call; its remaining arguments are missing.
        inv_sum_exp = quant_ops.to_quantized(jnp.reciprocal(sum_exp),
        a_softmax = quant_ops.to_quantized(a_exp * inv_sum_exp, dtype=dtype)

        return a_softmax.astype(dtype)

    # If no params, return accurate Softmax.
    if softmax_hparams == SoftmaxHParams(None, None,
                                         None) or softmax_hparams is None:
        return unquantized_softmax(a)

    # TODO(shivaniagrawal): Partial sum quantization (if enabled) will happen for
    # the entire training run, even before the global activation start step.
    if softmax_hparams.quant_hparams is not None:
        # Runtime switch: quantized path only when quantize_acts is set.
        return lax.cond(quant_context.quantize_acts, quantized_softmax,
                        unquantized_softmax, a)

    # Approximated Softmax
    exp_hparams = softmax_hparams.exp_hparams
    recip_hparams = softmax_hparams.reciprocal_hparams

    # Subtract max value from dimensions to be normalized.
    # `shape` is a's shape with every axis in norm_dims set to 1; dimadd
    # reshapes a reduced result back to that keepdims-style shape.
    shape = jax.util.subvals(onp.shape(a),
                             zip(norm_dims, (1, ) * len(norm_dims)))
    dimadd = lambda x: lax.reshape(x, shape)
    # pylint: disable=protected-access
    # NOTE(review): truncated lax.reduce call; presumably a max over norm_dims.
    amax = lax.reduce(a, lax_numpy._constant_like(a, -onp.inf), lax.max,
    # Replace non-finite maxima (e.g. an all -inf slice) with 0 so the
    # subtraction below does not produce nan.
    amax = lax.select(lax.is_finite(amax), amax, lax.full_like(amax, 0))
    amax_singletons = dimadd(amax)
    asubmax = lax.sub(a, amax_singletons)

    # Calculate approximated exponential
    approx_exp = exponential(asubmax, dtype, exp_hparams)

    # If sum_high_bound: Upper clip bound for sum(exp(x-M)).
    # NOTE(review): truncated lax.reduce call; presumably a sum over norm_dims.
    asumexp = dimadd(
        lax.reduce(approx_exp, lax_numpy._constant_like(a, 0), lax.add,

    if exp_hparams.sum_high_bound is not None and exp_hparams.sum_high_bound != 0:
        # Lower clip bound defaults to 1. It is relaxed to 1 - exp(low_bound)
        # when low_bound is nonzero and clip_and_subtract is set.
        sum_low_bound = 1.
        if (exp_hparams.low_bound != 0) and exp_hparams.clip_and_subtract:
            sum_low_bound = 1 - onp.exp(exp_hparams.low_bound)
        asumexp = jnp.clip(asumexp, sum_low_bound, exp_hparams.sum_high_bound)

    # Approximation of reciprocal.
    arecip = reciprocal(asumexp, dtype, recip_hparams)
    return lax.mul(approx_exp, arecip).astype(dtype)