def _case1(zagf): z, alpha, _, flag = zagf # dz = - dCDF(z; a) / pdf(z; a) # pdf = z^(a-1) * e^(-z) / Gamma(a) # CDF(z; a) = IncompleteGamma(a, z) / Gamma(a) # dCDF(z; a) = (dIncompleteGamma - IncompleteGamma * Digamma(a)) / Gamma(a) # =: unnormalized_dCDF / Gamma(a) # IncompleteGamma ~ z^a [ 1/a - z/(a+1) + z^2/2!(a+2) - z^3/3!(a+3) + z^4/4!(a+4) - z^5/5!(a+5) ] # =: z^a * term1 # dIncompleteGamma ~ z^a * log(z) * term1 - z^a [1/a^2 - z/(a+1)^2 + z^2/2!(a+2)^2 # - z^3/3!(a+3)^2 + z^4/4!(a+4)^2 - z^5/5!(a+5)^2 ] # =: z^a * log(z) * term1 - z^a * term2 # unnormalized_dCDF = z^a { [log(z) - Digamma(a)] * term1 - term2 } zi = 1.0 update = zi / alpha term1 = update term2 = update / alpha for i in range(1, 6): zi = -zi * z / i update = zi / (alpha + i) term1 = term1 + update term2 = term2 + update / (alpha + i) unnormalized_cdf_dot = np.power(z, alpha) * ( (np.log(z) - lax.digamma(alpha)) * term1 - term2) unnormalized_pdf = np.power(z, alpha - 1) * np.exp(-z) grad = -unnormalized_cdf_dot / unnormalized_pdf return z, alpha, grad, ~flag
def digamma(x):
    """Elementwise digamma (psi) function, the logarithmic derivative of Gamma.

    ``x`` is first passed through ``_promote_args_inexact`` under the name
    ``"digamma"``. Judging by its name, that helper presumably promotes to an
    inexact (floating) dtype; confirm against its definition. The single
    promoted value is then evaluated with ``lax.digamma``.
    """
    (promoted,) = _promote_args_inexact("digamma", x)
    return lax.digamma(promoted)