def _pdf_gamma(a, b, axes=None, scale=4, color="k"):
    """
    Plot the probability density function of a gamma distribution.

    The distribution uses the shape/rate parameterisation (mean ``a / b``,
    standard deviation ``sqrt(a) / b``). The density is evaluated on 100
    evenly spaced points spanning ``mean +/- scale * std``, clipped at zero
    from below.

    Parameters
    ----------
    a : scalar
        Shape parameter; must be positive.
    b : scalar
        Rate parameter; must be positive.
    axes : optional
        Matplotlib axes to draw on. Defaults to ``plt.gca()``.
    scale : float, optional
        Half-width of the plotted interval, in standard deviations.
    color : optional
        Line color passed to ``axes.plot``.

    Returns
    -------
    The value returned by ``axes.plot``.

    Raises
    ------
    ValueError
        If ``a`` or ``b`` is not a scalar, or is not positive (NaN included).
    """
    # Validate before touching pyplot so invalid input never creates a figure.
    if np.size(a) != 1 or np.size(b) != 1:
        raise ValueError("Parameters must be scalars")
    if not (np.all(np.asarray(a) > 0) and np.all(np.asarray(b) > 0)):
        raise ValueError("Parameters must be positive")
    if axes is None:
        axes = plt.gca()

    mean = a / b
    v = scale * np.sqrt(a / b ** 2)
    m = max(0, mean - v)
    n = mean + v
    x = np.linspace(m, n, num=100)

    # The grid may start exactly at x = 0, where log(x) = -inf and the
    # log-density is undefined (NaN). Silence those floating-point warnings
    # here; the x = 0 value is replaced with the analytic limit below.
    with np.errstate(divide="ignore", invalid="ignore"):
        logx = np.log(x)
        lpdf = random.gamma_logpdf(b * x, logx, a * logx, a * np.log(b), special.gammaln(a))
        p = np.exp(lpdf)

    # Limit of b**a * x**(a-1) * exp(-b*x) / Gamma(a) as x -> 0+.
    at_zero = x == 0
    if np.any(at_zero):
        if a > 1:
            limit = 0.0
        elif a == 1:
            limit = b
        else:
            limit = np.inf
        p = np.where(at_zero, limit, p)

    return axes.plot(x, p, color=color)
def _compute_bound(self, R, logdet=None, inv=None, Q=None, gradient=False, terms=False):
    """
    Rotate q(X) and q(alpha) and compute the rotation-dependent bound terms.

    Assumed model::

        p(X|alpha) = prod_m N(x_m|0,diag(alpha))
        p(alpha)   = prod_d G(a_d,b_d)

    Parameters
    ----------
    R : array
        Rotation matrix for the variable dimension of X.
    logdet : optional
        Precomputed log|det(R)|. When None, both log|det(R)| and inv(R) are
        computed here with ``np.linalg``.
    inv : optional
        Precomputed inverse of R. Only used when ``logdet`` is given.
    Q : array, optional
        Plate rotation matrix. Required (ValueError otherwise) when
        ``self.plate_axis`` is not None; ignored when it is None.
    gradient : bool
        If True, also compute the gradient with respect to R and, when
        plates are rotated, with respect to Q.
    terms : bool
        If True, return the bound as a dict keyed by ``self.node_X`` (and
        ``self.node_alpha`` when alpha is updated) instead of a sum.
        Not supported together with ``gradient``.

    Returns
    -------
    bound
        When ``gradient`` is False.
    (bound, dR_bound)
        When ``gradient`` is True and ``self.plate_axis`` is None.
    (bound, dR_bound, dQ_bound)
        When ``gradient`` is True and plates are rotated with Q.

    Raises
    ------
    ValueError
        If plates should be rotated but ``Q`` is None.
    NotImplementedError
        If both ``gradient`` and ``terms`` are True.
    """

    #
    # Transform the distributions and moments
    #

    plates_alpha = self.plates_alpha
    plates_X = self.plates_X

    # Compute rotated second moment
    if self.plate_axis is not None:
        # The plate axis has been moved to be the last plate axis

        if Q is None:
            raise ValueError("Plates should be rotated but no Q given")

        # Transform covariance
        sumQ = np.sum(Q, axis=0)
        QCovQ = sumQ[:, None, None] ** 2 * self.CovX

        # Rotate plates
        if self.precompute:
            # Use the precomputed <X X'> and <X mu'> cross moments.
            QX_QX = np.einsum("...kalb,...ik,...il->...iab", self.X_X, Q, Q)
            XX = QX_QX + QCovQ
            XX = sum_to_plates(XX, plates_alpha[:-1], ndim=2)
            Xmu = np.einsum("...kaib,...ik->...iab", self.X_mu, Q)
            Xmu = sum_to_plates(Xmu, plates_alpha[:-1], ndim=2)
        else:
            # Rotate the raw moments <X> and sum them to the alpha plates.
            X = self.X
            mu = self.mu
            QX = np.einsum("...ik,...kj->...ij", Q, X)
            XX = sum_to_plates(QCovQ, plates_alpha[:-1], ndim=2) + sum_to_plates(
                linalg.outer(QX, QX), plates_alpha[:-1], ndim=2, plates_from=plates_X
            )
            Xmu = sum_to_plates(linalg.outer(QX, self.mu), plates_alpha[:-1], ndim=2, plates_from=plates_X)

        mu2 = self.mu2
        D = np.shape(XX)[-1]
        # log|det| contribution of the plate rotation to the entropy of X
        logdet_Q = D * np.log(np.abs(sumQ))

    else:
        XX = self.XX
        mu2 = self.mu2
        Xmu = self.Xmu
        logdet_Q = 0

    # Compute transformed moments
    RXmu = np.einsum("...ik,...ki->...i", R, Xmu)
    RXX = np.einsum("...ik,...kj->...ij", R, XX)
    RXXR = np.einsum("...ik,...ik->...i", RXX, R)

    # <(X-mu) * (X-mu)'>_R
    XmuXmu = RXXR - 2 * RXmu + mu2

    D = np.shape(R)[0]

    # Compute q(alpha)
    if self.update_alpha:
        # Parameters
        a0 = self.a0
        b0 = self.b0
        a = self.a
        b = b0 + 0.5 * sum_to_plates(XmuXmu, plates_alpha, plates_from=None, ndim=0)
        # Some expectations
        alpha = a / b
        logb = np.log(b)
        logalpha = -logb  # + const
        b0_alpha = b0 * alpha
        a0_logalpha = a0 * logalpha
    else:
        # alpha is fixed, so its log term does not depend on the rotation.
        alpha = self.alpha
        logalpha = 0

    #
    # Compute the cost
    #

    def sum_plates(V, *plates):
        # Total sum of V, scaled by node_X's broadcasting multiplier for the
        # combined plates (presumably accounts for plates V is broadcast over).
        full_plates = misc.broadcasted_shape(*plates)

        r = self.node_X.broadcasting_multiplier(full_plates, np.shape(V))
        return r * np.sum(V)

    XmuXmu_alpha = XmuXmu * alpha

    if logdet is None:
        logdet_R = np.linalg.slogdet(R)[1]
        inv_R = np.linalg.inv(R)
    else:
        logdet_R = logdet
        inv_R = inv

    # Compute entropy H(X)
    logH_X = random.gaussian_entropy(-2 * sum_plates(logdet_R + logdet_Q, plates_X), 0)

    # Compute <log p(X|alpha)>
    logp_X = random.gaussian_logpdf(
        sum_plates(XmuXmu_alpha, plates_alpha[:-1] + [D]), 0, 0, sum_plates(logalpha, plates_X + [D]), 0
    )

    if self.update_alpha:

        # Compute entropy H(alpha)
        # This cancels out with the log(alpha) term in log(p(alpha))
        logH_alpha = 0

        # Compute <log p(alpha)>
        logp_alpha = random.gamma_logpdf(
            sum_plates(b0_alpha, plates_alpha), 0, sum_plates(a0_logalpha, plates_alpha), 0, 0
        )
    else:
        logH_alpha = 0
        logp_alpha = 0

    # Compute the bound
    if terms:
        bound = {self.node_X: logp_X + logH_X}
        if self.update_alpha:
            bound.update({self.node_alpha: logp_alpha + logH_alpha})
    else:
        bound = 0 + logp_X + logp_alpha + logH_X + logH_alpha

    if not gradient:
        return bound

    #
    # Compute the gradient with respect to R
    #

    broadcasting_multiplier = self.node_X.broadcasting_multiplier

    def sum_plates(V, plates):
        # Reduce a gradient array V to R's shape, scaled by the broadcasting
        # multiplier of its plates. Presumably sums every axis except the two
        # trailing matrix axes -- TODO confirm against misc.sum_multiply.
        ones = np.ones(np.shape(R))
        r = broadcasting_multiplier(plates, np.shape(V)[:-2])
        return r * misc.sum_multiply(V, ones, axis=(-1, -2), sumaxis=False, keepdims=False)

    D_XmuXmu = 2 * RXX - 2 * gaussian.transpose_covariance(Xmu)

    DXmuXmu_alpha = np.einsum("...i,...ij->...ij", alpha, D_XmuXmu)
    if self.update_alpha:
        # alpha = a / b depends on R through b, so chain-rule terms appear.
        D_b = 0.5 * D_XmuXmu
        XmuXmu_Dalpha = np.einsum(
            "...i,...i,...i,...ij->...ij",
            sum_to_plates(XmuXmu, plates_alpha, plates_from=None, ndim=0),
            alpha,
            -1 / b,
            D_b,
        )
        D_b0_alpha = np.einsum("...i,...i,...i,...ij->...ij", b0, alpha, -1 / b, D_b)
        D_logb = np.einsum("...i,...ij->...ij", 1 / b, D_b)
        D_logalpha = -D_logb
        D_a0_logalpha = a0 * D_logalpha
    else:
        XmuXmu_Dalpha = 0
        D_logalpha = 0

    D_XmuXmu_alpha = DXmuXmu_alpha + XmuXmu_Dalpha

    # d/dR log|det(R)| = inv(R)'
    D_logR = inv_R.T

    # Compute dH(X)
    dlogH_X = random.gaussian_entropy(-2 * sum_plates(D_logR, plates_X), 0)

    # Compute d<log p(X|alpha)>
    dlogp_X = random.gaussian_logpdf(
        sum_plates(D_XmuXmu_alpha, plates_alpha[:-1]),
        0,
        0,
        (sum_plates(D_logalpha, plates_X) * broadcasting_multiplier((D,), plates_alpha[-1:])),
        0,
    )

    if self.update_alpha:

        # Compute dH(alpha)
        # This cancels out with the log(alpha) term in log(p(alpha))
        dlogH_alpha = 0

        # Compute d<log p(alpha)>
        dlogp_alpha = random.gamma_logpdf(
            sum_plates(D_b0_alpha, plates_alpha[:-1]), 0, sum_plates(D_a0_logalpha, plates_alpha[:-1]), 0, 0
        )
    else:
        dlogH_alpha = 0
        dlogp_alpha = 0

    if terms:
        # Per-node gradients are not supported; the code after the raise is
        # unreachable until this is implemented.
        raise NotImplementedError()
        dR_bound = {self.node_X: dlogp_X + dlogH_X}
        if self.update_alpha:
            dR_bound.update({self.node_alpha: dlogp_alpha + dlogH_alpha})
    else:
        # ``0 * dlogp_X`` keeps the sum array-shaped even if the other terms are scalar 0.
        dR_bound = 0 * dlogp_X + dlogp_X + dlogp_alpha + dlogH_X + dlogH_alpha

    if self.subset:
        # Restrict the gradient to the rotated sub-block of R.
        indices = np.ix_(self.subset, self.subset)
        dR_bound = dR_bound[indices]

    if self.plate_axis is None:
        return (bound, dR_bound)

    #
    # Compute the gradient with respect to Q (if Q given)
    #

    # Some pre-computations
    Q_RCovR = np.einsum("...ik,...kl,...il,...->...i", R, self.CovX, R, sumQ)
    if self.precompute:
        Xr_rX = np.einsum("...abcd,...jb,...jd->...jac", self.X_X, R, R)
        QXr_rX = np.einsum("...akj,...ik->...aij", Xr_rX, Q)
        RX_mu = np.einsum("...jk,...akbj->...jab", R, self.X_mu)
    else:
        RX = np.einsum("...ik,...k->...i", R, X)
        QXR = np.einsum("...ik,...kj->...ij", Q, RX)
        QXr_rX = np.einsum("...ik,...jk->...kij", QXR, RX)
        RX_mu = np.einsum("...ik,...jk->...kij", RX, mu)
        QXr_rX = sum_to_plates(QXr_rX, plates_alpha[:-2], ndim=3, plates_from=plates_X[:-1])
        RX_mu = sum_to_plates(RX_mu, plates_alpha[:-2], ndim=3, plates_from=plates_X[:-1])

    def psi(v):
        """
        Compute the gradient with respect to Q of
        1/2*trace(diag(v)*<(X-mu)*(X-mu)'>), i.e.

            Q*<X>'*R'*diag(v)*R*<X>
            + ones * Q diag( tr(R'*diag(v)*R*Cov) )
            - mu*diag(v)*R*<X>

        (the mu term is subtracted, as in the return expression).
        """

        # Precompute all terms to plates_alpha because v has shape
        # plates_alpha.

        # Gradient of 0.5*v*<x>*<x>
        v_QXrrX = np.einsum("...kij,...ik->...ij", QXr_rX, v)

        # Gradient of 0.5*v*Cov
        Q_tr_R_v_R_Cov = np.einsum("...k,...k->...", Q_RCovR, v)[..., None, :]

        # Gradient of mu*v*x
        mu_v_R_X = np.einsum("...ik,...kji->...ij", v, RX_mu)

        return v_QXrrX + Q_tr_R_v_R_Cov - mu_v_R_X

    def sum_plates(V, plates):
        # Same reduction as above, but to Q's shape.
        ones = np.ones(np.shape(Q))
        r = self.node_X.broadcasting_multiplier(plates, np.shape(V)[:-2])

        return r * misc.sum_multiply(V, ones, axis=(-1, -2), sumaxis=False, keepdims=False)

    if self.update_alpha:
        D_logb = psi(1 / b)
        XX_Dalpha = -psi(alpha / b * sum_to_plates(XmuXmu, plates_alpha))
        D_logalpha = -D_logb
    else:
        XX_Dalpha = 0
        D_logalpha = 0

    DXX_alpha = 2 * psi(alpha)
    D_XX_alpha = DXX_alpha + XX_Dalpha
    # Derivative of logdet_Q = D * log|sumQ|
    D_logdetQ = D / sumQ
    N = np.shape(Q)[-1]

    # Compute dH(X)
    dQ_logHX = random.gaussian_entropy(-2 * sum_plates(D_logdetQ, plates_X[:-1]), 0)

    # Compute d<log p(X|alpha)>
    dQ_logpX = random.gaussian_logpdf(
        sum_plates(D_XX_alpha, plates_alpha[:-2]),
        0,
        0,
        (sum_plates(D_logalpha, plates_X[:-1]) * broadcasting_multiplier((N, D), plates_alpha[-2:])),
        0,
    )

    if self.update_alpha:

        D_alpha = -psi(alpha / b)
        D_b0_alpha = b0 * D_alpha
        D_a0_logalpha = a0 * D_logalpha

        # Compute dH(alpha)
        # This cancels out with the log(alpha) term in log(p(alpha))
        dQ_logHalpha = 0

        # Compute d<log p(alpha)>
        dQ_logpalpha = random.gamma_logpdf(
            sum_plates(D_b0_alpha, plates_alpha[:-2]), 0, sum_plates(D_a0_logalpha, plates_alpha[:-2]), 0, 0
        )
    else:

        dQ_logHalpha = 0
        dQ_logpalpha = 0

    if terms:
        # Unreachable: ``terms`` with ``gradient`` already raised above.
        raise NotImplementedError()
        dQ_bound = {self.node_X: dQ_logpX + dQ_logHX}
        if self.update_alpha:
            dQ_bound.update({self.node_alpha: dQ_logpalpha + dQ_logHalpha})
    else:
        dQ_bound = 0 * dQ_logpX + dQ_logpX + dQ_logpalpha + dQ_logHX + dQ_logHalpha

    return (bound, dR_bound, dQ_bound)