def logpdf(x, a, b, loc=0, scale=1):
    """Log-density of the beta distribution with shapes ``a`` and ``b``.

    ``x`` is standardised as ``(x - loc) / scale`` before the density terms are
    evaluated.  Points with ``x < loc`` or ``x > loc + scale`` are assigned
    ``-inf``.
    """
    x, a, b, loc, scale = _promote_args_inexact(
        "beta.logpdf", x, a, b, loc, scale)
    unit = lax._const(x, 1)
    standardized = lax.div(lax.sub(x, loc), scale)

    # Normalising term: the negated log-beta function of the shapes.
    neg_log_beta = lax.neg(betaln(a, b))

    # Kernel: xlogy(a - 1, y) + xlog1py(b - 1, -y).
    left_part = xlogy(lax.sub(a, unit), standardized)
    right_part = xlog1py(lax.sub(b, unit), lax.neg(standardized))
    kernel = lax.add(left_part, right_part)

    # The log(scale) subtraction accounts for the change of variables.
    inside = lax.sub(lax.add(neg_log_beta, kernel), lax.log(scale))

    above = lax.gt(x, lax.add(loc, scale))
    below = lax.lt(x, loc)
    return where(logical_or(above, below), -inf, inside)
def triangular_solve_jvp_rule_a(g_a, ans, a, b, left_side, lower, transpose_a,
                                conjugate_a, unit_diagonal):
    """JVP rule of ``triangular_solve`` for the tangent ``g_a`` of ``a``.

    ``ans`` is the primal solution.  The tangent is masked to the triangle that
    is actually used (excluding the diagonal when ``unit_diagonal`` is set),
    negated, and optionally transposed / conjugated.  One solve and one matrix
    product are then combined in whichever order is cheaper.
    """
    rows, cols = b.shape[-2:]
    diag_offset = 1 if unit_diagonal else 0
    if lower:
        g_a = np.tril(g_a, k=-diag_offset)
    else:
        g_a = np.triu(g_a, k=diag_offset)
    g_a = lax.neg(g_a)
    if transpose_a:
        g_a = np.swapaxes(g_a, -1, -2)
    if conjugate_a:
        g_a = np.conj(g_a)

    matmul = lax.dot if g_a.ndim == 2 else lax.batch_matmul
    dot = partial(matmul, precision=lax.Precision.HIGHEST)

    def a_inverse(rhs):
        # Apply the same triangular solve as the primal, to an arbitrary rhs.
        return triangular_solve(a, rhs, left_side, lower, transpose_a,
                                conjugate_a, unit_diagonal)

    # triangular_solve is about the same cost as matrix multiplication (~n^2
    # FLOPs for matrix/vector inputs), so apply the two operations in the
    # order that keeps the intermediate smaller.
    if left_side:
        assert g_a.shape[-2:] == a.shape[-2:] == (rows, rows) and ans.shape[-2:] == (
            rows, cols)
        if rows > cols:
            return a_inverse(dot(g_a, ans))  # A^{-1} (dA X)
        return dot(a_inverse(g_a), ans)  # (A^{-1} dA) X

    assert g_a.shape[-2:] == a.shape[-2:] == (cols, cols) and ans.shape[-2:] == (
        rows, cols)
    if rows < cols:
        return a_inverse(dot(ans, g_a))  # (X dA) A^{-1}
    return dot(ans, a_inverse(g_a))  # X (dA A^{-1})
def logpdf(x, loc=0, scale=1):
    """Log-density of the normal distribution with mean ``loc`` and std ``scale``.

    Computed as ``-(log(2*pi*scale**2) + (x - loc)**2 / scale**2) / 2``.
    """
    x, loc, scale = _promote_args_inexact("norm.logpdf", x, loc, scale)
    exponent = _constant_like(x, 2)
    variance = lax.pow(scale, exponent)

    two_pi = _constant_like(x, 2 * np.pi)
    log_norm = lax.log(lax.mul(two_pi, variance))

    squared_dev = lax.pow(lax.sub(x, loc), exponent)
    scaled_sq_dev = lax.div(squared_dev, variance)

    total = lax.add(log_norm, scaled_sq_dev)
    return lax.div(lax.neg(total), exponent)
def arccosh(x):
    """Inverse hyperbolic cosine of ``x`` after promotion to an inexact dtype.

    arccosh is multi-valued for complex input and ``lax.acosh`` picks a
    different convention than ``np.arccosh``; complex results with a negative
    real part are therefore negated.
    """
    promoted = _promote_args_inexact("arccosh", x)
    result = lax.acosh(*promoted)
    if not dtypes.issubdtype(result.dtype, np.complexfloating):
        return result
    # Flip the sign wherever the real part is negative.
    return _where(real(result) < 0, lax.neg(result), result)
def logpdf(x, loc=0, scale=1):
    """Log-density of the Cauchy distribution.

    Computed as ``-(log(pi * scale) + log1p(z**2))`` with
    ``z = (x - loc) / scale``.
    """
    x, loc, scale = _promote_args_inexact("cauchy.logpdf", x, loc, scale)
    z = lax.div(lax.sub(x, loc), scale)
    log_pi_scale = lax.log(lax.mul(_constant_like(x, np.pi), scale))
    log_kernel = lax.log1p(lax.mul(z, z))
    return lax.neg(lax.add(log_pi_scale, log_kernel))
def cdf(x, loc=0, scale=1):
    """Cumulative distribution function of the Laplace distribution.

    With ``z = (x - loc) / scale`` the result is ``0.5 * exp(z)`` where
    ``z <= 0`` and ``1 - 0.5 * exp(-z)`` elsewhere.
    """
    x, loc, scale = _promote_args_inexact("laplace.cdf", x, loc, scale)
    z = lax.div(lax.sub(x, loc), scale)
    half = _constant_like(x, 0.5)

    lower_branch = lax.mul(half, lax.exp(z))
    upper_branch = lax.sub(_constant_like(x, 1),
                           lax.mul(half, lax.exp(lax.neg(z))))

    non_positive = lax.le(z, _constant_like(x, 0))
    return lax.select(non_positive, lower_branch, upper_branch)
def logpdf(x, b, loc=0, scale=1):
    """Log-density of the Pareto distribution with shape ``b``.

    Computed as ``-(log(scale / b) + (b + 1) * log(z))`` with
    ``z = (x - loc) / scale``; points with ``x < loc + scale`` get ``-inf``.
    """
    x, b, loc, scale = _promote_args_inexact("pareto.logpdf", x, b, loc, scale)
    z = lax.div(lax.sub(x, loc), scale)

    log_scale_over_b = lax.log(lax.div(scale, b))
    exponent = lax.add(b, _constant_like(x, 1))
    tail = lax.mul(exponent, lax.log(z))
    inside = lax.neg(lax.add(log_scale_over_b, tail))

    outside_support = lax.lt(x, lax.add(loc, scale))
    return where(outside_support, -inf, inside)
def logpmf(k, n, a, b, loc=0):
    """JAX implementation of scipy.stats.betabinom.logpmf.

    ``k`` is floored and shifted by ``loc``.  The result is ``-inf`` where the
    shifted count fails the range check below, and ``nan`` where ``n < 1``,
    ``a < 0`` or ``b < 0``.
    """
    k, n, a, b, loc = _promote_args_inexact("betabinom.logpmf", k, n, a, b, loc)
    shifted = lax.sub(lax.floor(k), loc)
    one = _lax_const(shifted, 1)
    zero = _lax_const(shifted, 0)

    # Log binomial coefficient expressed through the log-beta function.
    log_choose = lax.neg(
        lax.add(lax.log1p(n),
                betaln(lax.add(lax.sub(n, shifted), one),
                       lax.add(shifted, one))))
    # Ratio of beta functions contributed by the mixing distribution.
    log_beta_ratio = lax.sub(
        betaln(lax.add(shifted, a), lax.add(lax.sub(n, shifted), b)),
        betaln(a, b))
    result = lax.add(log_choose, log_beta_ratio)

    too_small = lax.lt(shifted, lax.neg(loc))
    too_large = lax.gt(shifted, lax.sub(n, loc))
    result = where(logical_or(too_small, too_large), -inf, result)

    bad_n = lax.lt(n, one)
    bad_a = lax.lt(a, zero)
    bad_b = lax.lt(b, zero)
    invalid_params = logical_or(logical_or(bad_n, bad_a), bad_b)
    return where(invalid_params, nan, result)
def triangular_solve_jvp_rule_a(g_a, ans, a, b, left_side, lower, transpose_a,
                                conjugate_a):
    """JVP rule of ``triangular_solve`` for the tangent ``g_a`` of ``a``.

    The negated (and optionally transposed) tangent is solved against ``a``
    and then multiplied with the primal solution ``ans`` on the side selected
    by ``left_side``.
    """
    tangent = lax.neg(g_a)
    if transpose_a:
        tangent = np.swapaxes(tangent, -1, -2)

    solved = triangular_solve(a, tangent, left_side, lower, transpose_a,
                              conjugate_a)

    # Plain dot for 2-D tangents, batched matmul for anything else.
    matmul = lax.dot if tangent.ndim == 2 else lax.batch_matmul
    if left_side:
        return matmul(solved, ans)
    return matmul(ans, solved)
def logaddexp(x1, x2):
    """Compute ``log(exp(x1) + exp(x2))`` without overflowing.

    Floating inputs use ``max + log1p(exp(-|x1 - x2|))``.  Other (non-floating)
    inexact inputs take a separate path whose imaginary part is passed through
    ``_wrap_between(..., np.pi)``.
    """
    x1, x2 = _promote_args_inexact("logaddexp", x1, x2)
    larger = lax.max(x1, x2)

    if not dtypes.issubdtype(x1.dtype, np.floating):
        # x1 + x2 - 2 * max equals min - max, i.e. the non-positive gap.
        twice_larger = lax.mul(larger, _constant_like(larger, 2))
        gap = lax.sub(lax.add(x1, x2), twice_larger)
        combined = lax.add(larger, lax.log1p(lax.exp(gap)))
        wrapped_imag = _wrap_between(lax.imag(combined), np.pi)
        return lax.complex(lax.real(combined), wrapped_imag)

    diff = lax.sub(x1, x2)
    stable = lax.add(larger, lax.log1p(lax.exp(lax.neg(lax.abs(diff)))))
    # A NaN difference means NaNs or infinities of the same sign; x1 + x2
    # then propagates the right value.
    fallback = lax.add(x1, x2)
    return lax.select(lax_internal._isnan(diff), fallback, stable)
def logpdf(x, df, loc=0, scale=1):
    """Log-density of Student's t distribution with ``df`` degrees of freedom.

    Evaluated on ``z = (x - loc) / scale`` as the negated sum of a log
    normaliser and ``(df / 2 + 0.5) * log1p(z**2 / df)``.
    """
    x, df, loc, scale = _promote_args_inexact("t.logpdf", x, df, loc, scale)
    two = _lax_const(x, 2)
    z = lax.div(lax.sub(x, loc), scale)

    half_df = lax.div(df, two)
    half_df_plus_half = lax.add(half_df, _lax_const(x, 0.5))

    # Normaliser: lgamma(df/2) + log(pi * scale^2 * df) / 2 - lgamma(df/2 + 1/2)
    pi_scale_sq = lax.mul(lax.mul(scale, scale), _lax_const(x, np.pi))
    half_log_term = lax.div(lax.log(lax.mul(pi_scale_sq, df)), two)
    log_norm = lax.sub(lax.add(lax.lgamma(half_df), half_log_term),
                       lax.lgamma(half_df_plus_half))

    z_sq_over_df = lax.div(lax.mul(z, z), df)
    kernel = lax.mul(half_df_plus_half, lax.log1p(z_sq_over_df))
    return lax.neg(lax.add(log_norm, kernel))
def logpdf(x, df, loc=0, scale=1):
    """Log-density of the chi-squared distribution with ``df`` degrees of freedom.

    Evaluated on ``y = (x - loc) / scale``; points with ``x < loc`` get
    ``-inf``.
    """
    x, df, loc, scale = _promote_args_inexact("chi2.logpdf", x, df, loc, scale)
    one = _constant_like(x, 1)
    two = _constant_like(x, 2)
    standardized = lax.div(lax.sub(x, loc), scale)
    half_df = lax.div(df, two)

    # Kernel: (df/2 - 1) * log(y) - y / 2
    power_term = lax.mul(lax.sub(half_df, one), lax.log(standardized))
    kernel = lax.sub(power_term, lax.div(standardized, two))

    # Normaliser: -(lgamma(df/2) + df * log(2) / 2)
    log_two_power = lax.div(lax.mul(lax.log(two), df), two)
    neg_log_norm = lax.neg(lax.add(lax.lgamma(half_df), log_two_power))

    inside = lax.add(lax.sub(neg_log_norm, lax.log(scale)), kernel)
    return where(lax.lt(x, loc), -inf, inside)
def triangular_solve_jvp_rule_a(g_a, ans, a, b, left_side, lower, transpose_a,
                                conjugate_a, unit_diagonal):
    """JVP rule of ``triangular_solve`` for the tangent ``g_a`` of ``a``.

    The tangent is masked to the triangle selected by ``lower`` (dropping the
    diagonal when ``unit_diagonal`` is set), negated, optionally transposed
    and conjugated, solved against ``a`` and finally multiplied with the
    primal solution ``ans`` on the side selected by ``left_side``.
    """
    diag_offset = 1 if unit_diagonal else 0
    if lower:
        tangent = np.tril(g_a, k=-diag_offset)
    else:
        tangent = np.triu(g_a, k=diag_offset)
    tangent = lax.neg(tangent)
    if transpose_a:
        tangent = np.swapaxes(tangent, -1, -2)
    if conjugate_a:
        tangent = np.conj(tangent)

    solved = triangular_solve(a, tangent, left_side, lower, transpose_a,
                              conjugate_a, unit_diagonal)

    # Plain dot for 2-D tangents, batched matmul for anything else.
    matmul = lax.dot if tangent.ndim == 2 else lax.batch_matmul
    if left_side:
        return matmul(solved, ans)
    return matmul(ans, solved)
def _random_poisson(rng_key, lmbda, shape):
    """Draw Poisson-distributed samples with Knuth's multiplication method.

    Uniform variates are multiplied together until the running product drops
    below ``exp(-lmbda)``.  The sample is the number of factors that kept the
    product at or above that threshold, i.e. one less than the number of
    uniforms drawn, so ``0`` is a possible outcome.

    Parameters
    ----------
    rng_key : PRNG key consumed by ``random.split`` / ``random.uniform``.
    lmbda : rate of the distribution.
    shape : shape of the returned array of counts.

    Returns
    -------
    Array of shape ``shape`` holding the sampled counts (floating dtype, as
    produced by ``np.zeros``).

    References
    ----------
    .. [1] Knuth, Donald E. Art of computer programming, volume 2:
       Seminumerical algorithms. Addison-Wesley Professional, 2014 (p 137).
    """
    threshold = lax.exp(lax.neg(lmbda))
    k = np.zeros(shape=shape)
    p = np.ones(shape=shape)
    is_done = p < threshold
    # Data-dependent Python loop: runs until every element has crossed the
    # threshold.
    while not is_done.all():
        # Carry one half of the split forward and sample with the other, so
        # no key is used both for sampling and for further splitting.
        rng_key, subkey = random.split(rng_key)
        u = random.uniform(subkey, shape=shape)
        p = np.where(is_done, p, u * p)
        # Refresh the mask *before* counting: the draw that pushes the product
        # below the threshold must not be counted (Knuth returns k - 1).
        is_done = p < threshold
        k = np.where(is_done, k, k + 1)
    return k
series, = series_in primal_out = prim.bind(x) c0, cs = jet(deriv, primals_in, series_in) c = [c0] + cs u = [x] + series v = [primal_out] + [None] * len(series) for k in range(1, len(v)): v[k] = fact(k - 1) * sum( _scale(k, j) * c[k - j] * u[j] for j in range(1, k + 1)) primal_out, *series_out = v return primal_out, series_out def_deriv( lax.erf_p, lambda x: lax.mul(lax._const(x, 2. / np.sqrt(np.pi)), lax.exp(lax.neg(lax.square(x))))) def def_comp(prim, comp): """ Define the jet rule for a primitive in terms of a composition of simpler primitives. """ jet_rules[prim] = partial(jet, comp) def_comp(lax.expm1_p, lambda x: lax.exp(x) - 1) def_comp(lax.log1p_p, lambda x: lax.log(1 + x)) def_comp(lax.sqrt_p, lambda x: x**0.5) def_comp(lax.rsqrt_p, lambda x: x**-0.5) def_comp(lax.asinh_p, lambda x: lax.log(x + lax.sqrt(lax.square(x) + 1))) def_comp(lax.acosh_p, lambda x: lax.log(x + lax.sqrt(lax.square(x) - 1)))
def eager_unary(state):
    """Benchmark body: eagerly negate a device scalar and wait for the result.

    One warm-up call is made first; afterwards the operation is repeated for
    as long as ``state`` is truthy (presumably a benchmark state object that
    turns falsy when timing is complete -- confirm against the caller).
    """
    operand = jax.device_put(1)
    # Warm-up call so one-time costs are paid before the timed loop.
    warmup = lax.neg(operand)
    warmup.block_until_ready()
    while state:
        result = lax.neg(operand)
        result.block_until_ready()
def eager_unary_dispatch(state):
    """Benchmark body: eagerly negate a device scalar without waiting.

    Unlike a blocking benchmark, the result is never synchronised, so only
    issuing the call is exercised.  One warm-up call is made first;
    afterwards the operation is repeated for as long as ``state`` is truthy
    (presumably a benchmark state object -- confirm against the caller).
    """
    operand = jax.device_put(1)
    # Warm-up call so one-time costs are paid before the timed loop.
    _ = lax.neg(operand)
    while state:
        _ = lax.neg(operand)
def expit(x):
    """Logistic sigmoid ``1 / (1 + exp(-x))`` applied to ``asarray(x)``."""
    arr = asarray(x)
    unit = lax._const(arr, 1)
    denominator = lax.add(unit, lax.exp(lax.neg(arr)))
    return lax.div(unit, denominator)
def expn_jvp(n, primals, tangents):
    """JVP of ``expn(n, x)`` with respect to ``x``.

    Returns the primal value ``expn(n, x)`` together with the tangent
    ``-x_dot * expn(n - 1, x)``.
    """
    (x,) = primals
    (x_dot,) = tangents
    primal_out = expn(n, x)
    lower_order = expn(lax.sub(n, _lax_const(n, 1)), x)
    tangent_out = lax.mul(lax.neg(x_dot), lower_order)
    return primal_out, tangent_out
def logpdf(x):
    """Log-density of the standard logistic distribution.

    Computed as ``-2 * logaddexp(x / 2, -x / 2)``.
    """
    (x,) = _promote_args_inexact("logistic.logpdf", x)
    two = _lax_const(x, 2)
    half = lax.div(x, two)
    log_sum = jnp.logaddexp(half, lax.neg(half))
    return lax.mul(lax.neg(two), log_sum)
def logpdf(x, loc=0, scale=1):
    """Log-density of the exponential distribution.

    Computed as ``-((x - loc) / scale + log(scale))``; points with
    ``x < loc`` get ``-inf``.
    """
    x, loc, scale = _promote_args_inexact("expon.logpdf", x, loc, scale)
    standardized = lax.div(lax.sub(x, loc), scale)
    inside = lax.neg(lax.add(standardized, lax.log(scale)))
    below_support = lax.lt(x, loc)
    return where(below_support, -inf, inside)
def sf(x):
    """Survival function, evaluated as ``expit(-x)``."""
    negated = lax.neg(x)
    return expit(negated)
def sinh(x):
    """Hyperbolic sine, computed as ``(exp(x) - exp(-x)) / 2``.

    The input is first promoted with ``_promote_to_result_dtype`` using
    ``onp.sinh`` as the reference function.
    """
    (x,) = _promote_to_result_dtype(onp.sinh, x)
    growing = lax.exp(x)
    decaying = lax.exp(lax.neg(x))
    return lax.div(lax.sub(growing, decaying), _constant_like(x, 2))
def logpdf(x, loc=0, scale=1):
    """Log-density of the Laplace distribution.

    Computed as ``-(|x - loc| / scale + log(2 * scale))``.
    """
    x, loc, scale = _promote_args_inexact("laplace.logpdf", x, loc, scale)
    abs_dev = lax.div(lax.abs(lax.sub(x, loc)), scale)
    log_norm = lax.log(lax.mul(_constant_like(x, 2), scale))
    return lax.neg(lax.add(abs_dev, log_norm))
def fun(x, y, z):
    """Return ``y`` when ``x < 3`` and ``-z`` otherwise, via ``lax.cond``.

    ``lax.cond`` is called as ``cond(pred, true_fun, false_fun, operand)``.
    The older ``cond(pred, true_operand, true_fun, false_operand, false_fun)``
    form is rejected by current JAX, so both inputs travel in a single tuple
    operand and each branch picks the element it needs.
    """
    pred = lax.lt(x, 3)

    def keep_y(operands):
        kept, _ = operands
        return kept

    def negate_z(operands):
        _, other = operands
        return lax.neg(other)

    return lax.cond(pred, keep_y, negate_z, (y, z))
def cosh(x):
    """Hyperbolic cosine, computed as ``(exp(x) + exp(-x)) / 2``.

    The input is first promoted with ``_promote_to_result_dtype`` using
    ``onp.cosh`` as the reference function.
    """
    (x,) = _promote_to_result_dtype(onp.cosh, x)
    growing = lax.exp(x)
    decaying = lax.exp(lax.neg(x))
    return lax.div(lax.add(growing, decaying), _constant_like(x, 2))
def entr(x):
    """Elementwise entropy term ``-xlogy(x, x)``, with ``-inf`` where ``x < 0``."""
    (x,) = _promote_args_inexact("entr", x)
    is_negative = lax.lt(x, _constant_like(x, 0))
    neg_infinity = lax.full_like(x, -np.inf)
    entropy_term = lax.neg(xlogy(x, x))
    return lax.select(is_negative, neg_infinity, entropy_term)
def expit(x):
    """Logistic sigmoid ``1 / (1 + exp(-x))`` on an inexact-promoted input."""
    (x,) = _promote_args_inexact("expit", x)
    unit = _lax_const(x, 1)
    denominator = lax.add(unit, lax.exp(lax.neg(x)))
    return lax.div(unit, denominator)
def deriv_prop(prim, deriv, primals_in, series_in):
    """Jet rule for a unary primitive whose derivative function is known.

    The series of ``deriv`` at the input is obtained with ``jet``; output
    coefficient ``k`` is then ``fact(k - 1)`` times the sum over ``j`` in
    ``1..k`` of ``_scale(k, j) * c[k - j] * u[j]``, where ``c`` is the
    derivative series and ``u`` the input series (each prefixed by its primal
    value).

    Returns ``(primal_out, series_out)`` with ``series_out`` a list of the
    same length as the input series.
    """
    (x,) = primals_in
    (series,) = series_in
    primal_out = prim.bind(x)

    c0, cs = jet(deriv, primals_in, series_in)
    deriv_coeffs = [c0] + cs
    input_coeffs = [x] + series

    series_out = [
        fact(k - 1) * sum(
            _scale(k, j) * deriv_coeffs[k - j] * input_coeffs[j]
            for j in range(1, k + 1))
        for k in range(1, len(series) + 1)
    ]
    return primal_out, series_out


# d/dx erf(x) = 2 / sqrt(pi) * exp(-x^2)
def_deriv(
    lax.erf_p,
    lambda x: lax.mul(lax._const(x, 2. / np.sqrt(np.pi)),
                      lax.exp(lax.neg(lax.square(x)))))


def def_comp(prim, comp):
    """
    Define the jet rule for a primitive in terms of a composition of simpler primitives.
    """
    jet_rules[prim] = partial(jet, comp)


# Primitives whose jet rule is obtained by rewriting them with simpler ones.
def_comp(lax.expm1_p, lambda x: lax.exp(x) - 1)
def_comp(lax.log1p_p, lambda x: lax.log(1 + x))
def_comp(lax.sqrt_p, lambda x: x ** 0.5)
def_comp(lax.rsqrt_p, lambda x: x ** -0.5)
def_comp(lax.asinh_p, lambda x: lax.log(x + lax.sqrt(lax.square(x) + 1)))
def_comp(lax.acosh_p, lambda x: lax.log(x + lax.sqrt(lax.square(x) - 1)))
def logpdf(self, x):
    """Log-probability of the count ``x`` under rate ``self.lmbda``.

    Evaluates ``x * log(lmbda) - lmbda - log(x!)``, with ``log(x!)`` computed
    as ``lgamma(x + 1)``.

    NOTE(review): the ``lmbda`` attribute and the shape of the formula
    indicate a Poisson distribution; the enclosing class is not visible here,
    so confirm.
    """
    # Multiplying by 1.0 turns integer counts into floats for lgamma / xlogy.
    x = x * 1.0
    # xlogy already applies the logarithm to its second argument, so the rate
    # is passed directly; wrapping it in lax.log would give log(log(lmbda)).
    rate_term = xlogy(x, self.lmbda)
    # log(x!) == lgamma(x + 1); lgamma(x) alone would be log((x - 1)!).
    log_factorial = lax.lgamma(x + 1.0)
    return lax.add(
        rate_term,
        lax.add(lax.neg(self.lmbda), lax.neg(log_factorial)),
    )