def varBound(modelState, queryState, X, W, lnVocab=None, XAT=None, XTX=None, scaledWordCounts=None, VTV=None, UTU=None):
    #
    # TODO Standardise hyperparameter handling so we can eliminate this copy and paste
    #
    # Unpack the model and query state tuples for ease of use and maybe speed improvements
    (K, Q, F, P, T, A, varA, Y, omY, sigY, sigT, U, V, vocab, _, alphaSq, kappaSq, tauSq) = (
        modelState.K, modelState.Q, modelState.F, modelState.P, modelState.T,
        modelState.A, modelState.varA, modelState.Y, modelState.omY, modelState.sigY,
        modelState.sigT, modelState.U, modelState.V, modelState.vocab,
        modelState.topicVar, modelState.featVar, modelState.lowTopicVar, modelState.lowFeatVar,
    )
    (expLmda, nu, lxi, s, docLen) = (queryState.expLmda, queryState.nu, queryState.lxi, queryState.s, queryState.docLen)

    lmda = np.log(expLmda)
    isigT = la.inv(sigT)
    lnDetSigT = log(la.det(sigT))
    sigmaSq = 1  # A bit of a hack till hyperparameter handling is standardised

    # Get the number of samples from the shape. Ensure that the shapes are consistent
    # with the model parameters.
    (D, Tcheck) = W.shape
    if Tcheck != T:
        raise ValueError(
            "The shape of the DxT document matrix W is invalid, T is %d but the matrix W has shape (%d, %d)"
            % (T, D, Tcheck)
        )

    (Dcheck, Fcheck) = X.shape
    if Dcheck != D:
        raise ValueError("Inconsistent sizes between the matrices X and W, X has %d rows but W has %d" % (Dcheck, D))
    if Fcheck != F:
        raise ValueError(
            "The shape of the DxF feature matrix X is invalid. F is %d but the matrix X has shape (%d, %d)"
            % (F, Dcheck, Fcheck)
        )

    # We'll need the original xi for this, and also Z, the 3D tensor which for each document d
    # and term t gives the strength of topic k. We'll also need the log of the vocab dist.
    xi = deriveXi(lmda, nu, s)

    # If not already provided, we'll also need the following products
    #
    if XAT is None:
        XAT = X.dot(A.T)
    if XTX is None:
        XTX = X.T.dot(X)
    if V is not None and VTV is None:
        VTV = V.T.dot(V)
    if U is not None and UTU is None:
        UTU = U.T.dot(U)

    # We also need one over the usual variances
    overSsq, overAsq, overKsq, overTsq = 1.0 / sigmaSq, 1.0 / alphaSq, 1.0 / kappaSq, 1.0 / tauSq
    overTkSq = overTsq * overKsq
    overAsSq = overAsq * overSsq

    # <ln p(Y)>
    #
    trSigY = 1 if sigY is None else np.trace(sigY)
    trOmY = K  # Basically it's the trace of the identity matrix, as the posterior and prior cancel out
    lnP_Y = -0.5 * (Q * P * LOG_2PI + P * lnDetSigT + overTkSq * trSigY * trOmY + overTkSq * np.trace(isigT.dot(Y).dot(Y.T)))

    # <ln P(A|Y)>
    # TODO it looks like I should take the trace of omA \otimes I_K here.
    # TODO Need to check re-arranging sigY and omY is sensible.
    halfKF = 0.5 * K * F

    # Horrible, but varBound can be called by two implementations: one with Y as a matrix-variate,
    # where sigY is QxQ, and one with Y as a multi-variate, where sigY is QPxQP.
    A_from_Y = Y.dot(U.T) if V is None else U.dot(Y).dot(V.T)
    A_diff = A - A_from_Y
    varFactorU = np.trace(sigY.dot(np.kron(VTV, UTU))) if sigY.shape[0] == Q * P else np.sum(sigY * UTU)
    varFactorV = 1 if V is None else np.sum(omY * V.T.dot(V))
    lnP_A = (
        -halfKF * LOG_2PI
        - halfKF * log(alphaSq)
        - F / 2.0 * lnDetSigT
        - 0.5 * (overAsSq * varFactorV * varFactorU + np.trace(XTX.dot(varA)) * K + np.sum(isigT.dot(A_diff) * A_diff))
    )

    # <ln p(Theta|A,X)>
    #
    lmdaDiff = lmda - XAT
    lnP_Theta = (
        -0.5 * D * LOG_2PI
        - 0.5 * D * lnDetSigT
        - 0.5 / sigmaSq * (np.sum(nu) + D * K * np.sum(XTX * varA) + np.sum(lmdaDiff.dot(isigT) * lmdaDiff))
    )
    # Why is the order of sigT reversed? It's because we've not been consistent:
    # A is KxF but lmda is DxK, and note that the distribution of lmda transpose has the same
    # covariances, just in different positions (i.e. row is col and vice-versa).

    # <ln p(Z|Theta)>
    #
    docLenLmdaLxi = docLen[:, np.newaxis] * lmda * lxi
    scaledWordCounts = sparseScalarQuotientOfDot(W, expLmda, vocab, out=scaledWordCounts)

    lnP_Z = 0.0
    lnP_Z -= np.sum(docLenLmdaLxi * lmda)
    lnP_Z -= np.sum(docLen[:, np.newaxis] * nu * nu * lxi)
    lnP_Z += 2 * np.sum(s[:, np.newaxis] * docLenLmdaLxi)
    lnP_Z -= 0.5 * np.sum(docLen[:, np.newaxis] * lmda)
    lnP_Z += np.sum(lmda * expLmda * (scaledWordCounts.dot(vocab.T)))  # n(d,k) = expLmda * (scaledWordCounts.dot(vocab.T))
    lnP_Z -= np.sum(docLen[:, np.newaxis] * lxi * ((s ** 2)[:, np.newaxis] - xi ** 2))
    lnP_Z += 0.5 * np.sum(docLen[:, np.newaxis] * (s[:, np.newaxis] + xi))
    lnP_Z -= np.sum(docLen[:, np.newaxis] * safe_log_one_plus_exp_of(xi))
    lnP_Z -= np.sum(docLen * s)

    # <ln p(W|Z, vocab)>
    #
    lnP_w_dt = sparseScalarProductOfDot(scaledWordCounts, expLmda, vocab * safe_log(vocab))
    lnP_W = np.sum(lnP_w_dt.data)

    # H[q(Y)]
    lnDetOmY = 0 if omY is None else log(la.det(omY))
    lnDetSigY = 0 if sigY is None else log(max(la.det(sigY), sys.float_info.min))  # TODO FIX THIS
    ent_Y = 0.5 * (P * K * LOG_2PI_E + Q * lnDetOmY + P * lnDetSigY)

    # H[q(A|Y)]
    #
    # A few things: omA is fixed so long as tau and sigma are, so there's no benefit in
    # recalculating this every time.
    #
    # However, in a recent test la.det(omA) = 0, which is very strange as omA is the
    # inverse of (s*I + t*XTX).
    #
    # ent_A = 0.5 * (F * K * LOG_2PI_E + K * log(la.det(omA)) + F * K * log(tau2))
    ent_A = 0

    # H[q(Theta|A)]
    ent_Theta = 0.5 * (K * LOG_2PI_E + np.sum(np.log(nu * nu)))

    # H[q(Z|Theta)]
    #
    # So Z_dtk \propto expLmda_dk * vocab_kt. We let N here be the normalizer (which is
    # \sum_j expLmda_dj * vocab_jt), which implies N is DxT. We need to evaluate
    # Z_dtk * log Z_dtk. We can pull out the normalizer of the first term, but it has
    # to stay in the log Z_dtk expression, hence the third term in the sum. We can however
    # take advantage of the ability to mix dot and element-wise products for the different
    # components of Z_dtk in that three-term sum, which we denote as S.
    # Finally we use np.sum to sum over d and t.
    #
    ent_Z = 0  # entropyOfDot(expLmda, vocab)

    result = lnP_Y + lnP_A + lnP_Theta + lnP_Z + lnP_W + ent_Y + ent_A + ent_Theta + ent_Z

    return result
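
# The disabled ent_Z term above refers to entropyOfDot. As a reference for the three-term
# decomposition described in the H[q(Z|Theta)] comment, here is a minimal sketch of what such a
# helper could look like. This is an assumption, not the repo's actual entropyOfDot: it works on
# dense arrays only and uses plain np.log where the module might prefer safe_log.
def _entropyOfDotSketch(expLmda, vocab):
    # N[d,t] = \sum_k expLmda[d,k] * vocab[k,t], the DxT normalizer of Z_dtk
    N = expLmda.dot(vocab)
    # S[d,t] = \sum_k expLmda[d,k] * vocab[k,t] * (ln expLmda[d,k] + ln vocab[k,t]),
    # mixing dot and element-wise products over the components of Z_dtk, as the comment describes
    S = (expLmda * np.log(expLmda)).dot(vocab) + expLmda.dot(vocab * np.log(vocab))
    # -\sum_{d,t,k} Z_dtk ln Z_dtk = -\sum_{d,t} (S[d,t] / N[d,t] - ln N[d,t])
    return -np.sum(S / N - np.log(N))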
def _quickPrintElbo(updateMsg, iteration, X, W, K, Q, F, P, T, A, varA, Y, omY, sigY, sigT, U, V, vocab,
                    sigmaSq, alphaSq, kappaSq, tauSq, lmda, expLmda, nu, lxi, s, docLen):
    """
    Checks that none of the matrix parameters contain a NaN or an Inf value, then
    calculates the variational bound and prints it to stdout with the given update
    message.

    A tremendously inefficient method, for debugging only.
    """
    # These rely on .data, which exists on both numpy arrays and scipy sparse matrices
    def _has_nans(X):
        return np.isnan(X.data).any()
    def _has_infs(X):
        return np.isinf(X.data).any()
    def _nan(varName):
        print(str(varName) + " has NaNs")
    def _inf(varName):
        print(str(varName) + " has infs")

    assert not (lmda is not None and expLmda is not None), \
        "We can't have both lmda and expLmda be not-None, as we assume we're only ever given one of them."

    # NaN tests
    if _has_nans(Y):
        _nan("Y")
    if omY is not None and _has_nans(omY):
        _nan("omY")
    if sigY is not None and _has_nans(sigY):
        _nan("sigY")
    if _has_nans(A):
        _nan("A")
    if _has_nans(varA):
        _nan("varA")
    if expLmda is not None and _has_nans(expLmda):
        _nan("expLmda")
    if lmda is not None and _has_nans(lmda):
        _nan("lmda")
    if sigT is not None and _has_nans(sigT):
        _nan("sigT")
    if _has_nans(nu):
        _nan("nu")
    if U is not None and _has_nans(U):
        _nan("U")
    if V is not None and _has_nans(V):
        _nan("V")
    if _has_nans(vocab):
        _nan("vocab")

    # Inf tests
    if _has_infs(Y):
        _inf("Y")
    if omY is not None and _has_infs(omY):
        _inf("omY")
    if sigY is not None and _has_infs(sigY):
        _inf("sigY")
    if _has_infs(A):
        _inf("A")
    if _has_infs(varA):
        _inf("varA")
    if expLmda is not None and _has_infs(expLmda):
        _inf("expLmda")
    if lmda is not None and _has_infs(lmda):
        _inf("lmda")
    if sigT is not None and _has_infs(sigT):
        _inf("sigT")
    if _has_infs(nu):
        _inf("nu")
    if U is not None and _has_infs(U):
        _inf("U")
    if V is not None and _has_infs(V):
        _inf("V")
    if _has_infs(vocab):
        _inf("vocab")

    wasPassedExpLmda = expLmda is not None
    if expLmda is None:
        expLmda = np.exp(lmda, out=lmda)

    elbo = varBound(
        VbSideTopicModelState(K, Q, F, P, T, A, varA, Y, omY, sigY, sigT, U, V, vocab, sigmaSq, alphaSq, kappaSq, tauSq),
        VbSideTopicQueryState(expLmda, nu, lxi, s, docLen),
        X, W,
    )

    lmda = np.log(expLmda, out=expLmda)
    xi = deriveXi(lmda, nu, s)

    diff = _quickPrintElbo.last - elbo
    diffStr = "   " if diff <= 0 else "(!)"

    print(
        "\t Update %5d: %-30s ELBO : %12.3f %s lmda.mean=%f \tlmda.max=%f \tlmda.min=%f \tnu.mean=%f \txi.mean=%f \ts.mean=%f"
        % (iteration, updateMsg, elbo, diffStr, lmda.mean(), lmda.max(), lmda.min(), nu.mean(), xi.mean(), s.mean())
    )

    if wasPassedExpLmda:
        np.exp(expLmda, out=expLmda)
    _quickPrintElbo.last = elbo

# Seed the "previous ELBO" attribute so the first call has something to diff against
_quickPrintElbo.last = 0
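
# Both functions above rely on deriveXi, which is defined elsewhere in this module. For
# orientation, here is a minimal sketch under the assumption (not confirmed by this section)
# that it computes the optimal parameter of Bouchard's softmax bound,
#     xi_dk^2 = E_q[(lmda_dk - s_d)^2] = lmda_dk^2 - 2 s_d lmda_dk + s_d^2 + nu_dk^2
# Check the real deriveXi before relying on this.
def _deriveXiSketch(lmda, nu, s):
    # lmda and nu are DxK, s is a length-D vector; broadcasting s across topics yields a DxK xi
    return np.sqrt(lmda ** 2 - 2 * lmda * s[:, np.newaxis] + (s ** 2)[:, np.newaxis] + nu ** 2)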