Example #1
import sys
from math import e, log, pi

import numpy as np
import numpy.linalg as la  # or scipy.linalg in the original module

# NOTE: the helpers used below (deriveXi, sparseScalarQuotientOfDot,
# sparseScalarProductOfDot, safe_log, safe_log_one_plus_exp_of) are assumed
# to be provided by the surrounding module; the two constants are simply:
LOG_2PI = log(2 * pi)
LOG_2PI_E = log(2 * pi * e)

def varBound(modelState, queryState, X, W, lnVocab=None, XAT=None, XTX=None, scaledWordCounts=None, VTV=None, UTU=None):
    #
    # TODO Standardise hyperparameter handling so we can eliminate this copy and paste
    #

    # Unpack the model and query state tuples for ease of use and possibly a slight speed improvement
    K, Q, F, P, T, A, varA, Y, omY, sigY, sigT, U, V, vocab, _, alphaSq, kappaSq, tauSq = (
        modelState.K,
        modelState.Q,
        modelState.F,
        modelState.P,
        modelState.T,
        modelState.A,
        modelState.varA,
        modelState.Y,
        modelState.omY,
        modelState.sigY,
        modelState.sigT,
        modelState.U,
        modelState.V,
        modelState.vocab,
        modelState.topicVar,
        modelState.featVar,
        modelState.lowTopicVar,
        modelState.lowFeatVar,
    )
    (expLmda, nu, lxi, s, docLen) = (queryState.expLmda, queryState.nu, queryState.lxi, queryState.s, queryState.docLen)

    lmda = np.log(expLmda)
    isigT = la.inv(sigT)
    lnDetSigT = log(la.det(sigT))
    sigmaSq = 1  # A bit of a hack till hyperparameter handling is standardised
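    # (isigT and lnDetSigT are reused by nearly every term of the bound below,
    # hence computing them once up front.)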

    # Get the number of samples from the shape. Ensure that the shapes are consistent
    # with the model parameters.
    (D, Tcheck) = W.shape
    if Tcheck != T:
        raise ValueError(
            "The shape of the DxT document matrix W is invalid, T is %d but the matrix W has shape (%d, %d)"
            % (T, D, Tcheck)
        )

    (Dcheck, Fcheck) = X.shape
    if Dcheck != D:
        raise ValueError("Inconsistent sizes between the matrices X and W, X has %d rows but W has %d" % (Dcheck, D))
    if Fcheck != F:
        raise ValueError(
            "The shape of the DxF feature matrix X is invalid. F is %d but the matrix X has shape (%d, %d)"
            % (F, Dcheck, Fcheck)
        )

    # We'll need the original xi for this, and also Z, the 3D tensor which for
    # each document d and term t gives the strength of topic k. We'll also need
    # the log of the vocabulary distribution.
    xi = deriveXi(lmda, nu, s)
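    # (xi is the variational parameter of the quadratic bound on the
    # log-sum-exp normaliser of the softmax; lxi, taken from the query state,
    # is presumably the corresponding lambda(xi) multiplier.)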

    # If not already provided, we'll also need the following products
    #
    if XAT is None:
        XAT = X.dot(A.T)
    if XTX is None:
        XTX = X.T.dot(X)
    if V is not None and VTV is None:
        VTV = V.T.dot(V)
    if U is not None and UTU is None:
        UTU = U.T.dot(U)
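    # (These Gram matrices depend only on X, U and V, so callers that evaluate
    # the bound repeatedly can compute them once and pass them in.)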

    # We also need the reciprocals of the usual variances
    overSsq, overAsq, overKsq, overTsq = 1.0 / sigmaSq, 1.0 / alphaSq, 1.0 / kappaSq, 1.0 / tauSq
    overTkSq = overTsq * overKsq
    overAsSq = overAsq * overSsq

    # <ln p(Y)>
    #
    trSigY = 1 if sigY is None else np.trace(sigY)
    trOmY = K  # Basically it's the trace of the identity matrix as the posterior and prior cancel out
    lnP_Y = -0.5 * (
        Q * P * LOG_2PI + P * lnDetSigT + overTkSq * trSigY * trOmY + overTkSq * np.trace(isigT.dot(Y).dot(Y.T))
    )

    # <ln P(A|Y)>
    # TODO it looks like I should take the trace of omA \otimes I_K here.
    # TODO Need to check re-arranging sigY and omY is sensible.
    halfKF = 0.5 * K * F

    # Horrible, but varBound can be called by two implementations: one where Y
    # is matrix-variate, so sigY is QxQ, and one where Y is multivariate, so
    # sigY is QPxQP.
    A_from_Y = Y.dot(U.T) if V is None else U.dot(Y).dot(V.T)
    A_diff = A - A_from_Y
    varFactorU = np.trace(sigY.dot(np.kron(VTV, UTU))) if sigY.shape[0] == Q * P else np.sum(sigY * UTU)
    varFactorV = 1 if V is None else np.sum(omY * V.T.dot(V))
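    # (np.sum(M1 * M2) equals np.trace(M1.dot(M2)) for symmetric M1 and M2,
    # which is why the QxQ branch can avoid forming a Kronecker product --
    # see the checks at the bottom of this example.)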
    lnP_A = (
        -halfKF * LOG_2PI
        - halfKF * log(alphaSq)
        - F / 2.0 * lnDetSigT
        - 0.5 * (overAsSq * varFactorV * varFactorU + np.trace(XTX.dot(varA)) * K + np.sum(isigT.dot(A_diff) * A_diff))
    )

    # <ln p(Theta|A,X)>
    #
    lmdaDiff = lmda - XAT
    lnP_Theta = (
        -0.5 * D * LOG_2PI
        - 0.5 * D * lnDetSigT
        - 0.5 / sigmaSq * (np.sum(nu) + D * K * np.sum(XTX * varA) + np.sum(lmdaDiff.dot(isigT) * lmdaDiff))
    )
    # Why is the order of sigT reversed? Because we've not been consistent:
    # A is KxF but lmda is DxK, and the distribution of lmda's transpose has
    # the same covariances, just in transposed positions (i.e. rows become
    # columns and vice versa).
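    # The lmdaDiff term itself is just a batch of per-document quadratic forms:
    # np.sum(lmdaDiff.dot(isigT) * lmdaDiff) == sum_d diff_d^T . isigT . diff_d
    # (see the checks at the bottom of this example).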

    # <ln p(Z|Theta)>
    #
    docLenLmdaLxi = docLen[:, np.newaxis] * lmda * lxi
    scaledWordCounts = sparseScalarQuotientOfDot(W, expLmda, vocab, out=scaledWordCounts)

    lnP_Z = 0.0
    lnP_Z -= np.sum(docLenLmdaLxi * lmda)
    lnP_Z -= np.sum(docLen[:, np.newaxis] * nu * nu * lxi)
    lnP_Z += 2 * np.sum(s[:, np.newaxis] * docLenLmdaLxi)
    lnP_Z -= 0.5 * np.sum(docLen[:, np.newaxis] * lmda)
    lnP_Z += np.sum(
        lmda * expLmda * (scaledWordCounts.dot(vocab.T))
    )  # n(d,k) = expLmda * (scaledWordCounts.dot(vocab.T))
    lnP_Z -= np.sum(docLen[:, np.newaxis] * lxi * ((s ** 2)[:, np.newaxis] - xi ** 2))
    lnP_Z += 0.5 * np.sum(docLen[:, np.newaxis] * (s[:, np.newaxis] + xi))
    lnP_Z -= np.sum(docLen[:, np.newaxis] * safe_log_one_plus_exp_of(xi))
    lnP_Z -= np.sum(docLen * s)
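    # (safe_log_one_plus_exp_of presumably evaluates log(1 + e^xi) without
    # overflow, in the style of np.logaddexp(0, xi) -- see the check at the
    # bottom of this example.)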

    # <ln p(W|Z, vocab)>
    #
    lnP_w_dt = sparseScalarProductOfDot(scaledWordCounts, expLmda, vocab * safe_log(vocab))
    lnP_W = np.sum(lnP_w_dt.data)
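    # (safe_log presumably guards the zero entries of vocab so that
    # 0 * log(0) contributes 0 rather than NaN.)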

    # H[q(Y)]
    lnDetOmY = 0 if omY is None else log(la.det(omY))
    lnDetSigY = 0 if sigY is None else log(max(la.det(sigY), sys.float_info.min))  # TODO FIX THIS
    ent_Y = 0.5 * (P * K * LOG_2PI_E + Q * lnDetOmY + P * lnDetSigY)

    # H[q(A|Y)]
    #
    # A few things: omA is fixed so long as tau and sigma are, so there's no
    # benefit in recalculating this every time.
    #
    # However, in a recent test la.det(omA) came out as 0, which is very
    # strange, as omA is the inverse of (s*I + t*XTX).
    #
    #   ent_A = 0.5 * (F * K * LOG_2PI_E + K * log(la.det(omA)) + F * K * log(tau2))
    ent_A = 0

    # H[q(Theta|A)]
    ent_Theta = 0.5 * (K * LOG_2PI_E + np.sum(np.log(nu * nu)))

    # H[q(Z|Theta)]
    #
    # Z_dtk \propto expLmda_dk * vocab_kt. We let N be the normalizer,
    # N_dt = \sum_j expLmda_dj * vocab_jt, which implies N is DxT. We need to
    # evaluate Z_dtk * log Z_dtk. We can pull the normalizer out of the first
    # term, but it has to stay inside the log Z_dtk expression, hence the
    # third term in the sum. We can, however, take advantage of the ability to
    # mix dot and element-wise products for the different components of Z_dtk
    # in that three-term sum, which we denote as S.
    # Finally we use np.sum to sum over d and t.
    ent_Z = 0  # entropyOfDot(expLmda, vocab)

    result = lnP_Y + lnP_A + lnP_Theta + lnP_Z + lnP_W + ent_Y + ent_A + ent_Theta + ent_Z

    return result
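

# ---------------------------------------------------------------------------
# A minimal self-contained sketch (not part of the original function) that
# checks the numpy identities the bound above leans on. Everything below is
# illustrative: the array names are made up, and the last check only mirrors
# what safe_log_one_plus_exp_of is assumed to do.
if __name__ == "__main__":
    rng = np.random.RandomState(0)

    # 1. np.sum(M1 * M2) == np.trace(M1.dot(M2)) for symmetric matrices --
    #    the shortcut behind np.sum(sigY * UTU) and np.sum(omY * V.T.dot(V)).
    M1 = rng.randn(5, 5)
    M1 = M1 + M1.T
    M2 = rng.randn(5, 5)
    M2 = M2 + M2.T
    assert np.allclose(np.sum(M1 * M2), np.trace(M1.dot(M2)))

    # 2. np.sum(Diff.dot(Prec) * Diff) == sum_d diff_d^T . Prec . diff_d --
    #    the batched quadratic form behind the lmdaDiff and A_diff terms.
    Diff = rng.randn(10, 5)
    Prec = np.eye(5) + M1.dot(M1.T)  # any symmetric positive-definite matrix
    looped = sum(Diff[d].dot(Prec).dot(Diff[d]) for d in range(10))
    assert np.allclose(np.sum(Diff.dot(Prec) * Diff), looped)

    # 3. log(1 + e^x) via np.logaddexp(0, x) stays finite where the naive
    #    form overflows -- the assumed behaviour of safe_log_one_plus_exp_of.
    x = np.array([-50.0, 0.0, 50.0, 1000.0])
    assert np.isfinite(np.logaddexp(0.0, x)).all()
    assert np.allclose(np.logaddexp(0.0, x[:3]), np.log1p(np.exp(x[:3])))

    print("All identity checks passed.")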