示例#1
0
def varBound(modelState, queryState, X, W, lnVocab=None, XAT=None, XTX=None, scaledWordCounts=None, VTV=None, UTU=None):
    """
    Evaluate the variational lower bound (ELBO) on the model evidence for the
    given model state and query (per-document) state.

    Params:
    modelState - tuple of model parameters; this reads K, Q, F, P, T, A, varA,
                 Y, omY, sigY, sigT, U, V, vocab and the last four variance
                 hyperparameters (used here as alphaSq, kappaSq, tauSq)
    queryState - tuple of per-document variational parameters: expLmda (the
                 exponent of the topic means), nu, lxi, s and docLen
    X          - the DxF matrix of document features
    W          - the DxT matrix of word counts (may be sparse)
    lnVocab    - unused; retained for interface compatibility
    XAT, XTX, scaledWordCounts, VTV, UTU
               - optional pre-computed products; any that are None are
                 (re)computed here

    Returns:
    The value of the variational bound as a single float.

    Raises:
    ValueError if the shapes of W and X disagree with T, F or each other.
    """
    #
    # TODO Standardise hyperparameter handling so we can eliminate this copy and paste
    #

    # Unpack the model and query state tuples for ease of use and maybe speed improvements
    K, Q, F, P, T, A, varA, Y, omY, sigY, sigT, U, V, vocab, _, alphaSq, kappaSq, tauSq = (
        modelState.K,
        modelState.Q,
        modelState.F,
        modelState.P,
        modelState.T,
        modelState.A,
        modelState.varA,
        modelState.Y,
        modelState.omY,
        modelState.sigY,
        modelState.sigT,
        modelState.U,
        modelState.V,
        modelState.vocab,
        modelState.topicVar,
        modelState.featVar,
        modelState.lowTopicVar,
        modelState.lowFeatVar,
    )
    (expLmda, nu, lxi, s, docLen) = (queryState.expLmda, queryState.nu, queryState.lxi, queryState.s, queryState.docLen)

    lmda = np.log(expLmda)
    isigT = la.inv(sigT)
    # BUGFIX: this quantity is used below in Gaussian log-density terms of the
    # form -0.5 * (d * LOG_2PI + ln|Sigma| + ...), so it must be the *log* of
    # the determinant. It was previously just la.det(sigT).
    lnDetSigT = log(la.det(sigT))
    sigmaSq = 1  # A bit of a hack till hyperparameter handling is standardised

    # Get the number of samples from the shape. Ensure that the shapes are consistent
    # with the model parameters.
    (D, Tcheck) = W.shape
    if Tcheck != T:
        raise ValueError(
            "The shape of the DxT document matrix W is invalid, T is %d but the matrix W has shape (%d, %d)"
            % (T, D, Tcheck)
        )

    (Dcheck, Fcheck) = X.shape
    if Dcheck != D:
        raise ValueError("Inconsistent sizes between the matrices X and W, X has %d rows but W has %d" % (Dcheck, D))
    if Fcheck != F:
        raise ValueError(
            "The shape of the DxF feature matrix X is invalid. F is %d but the matrix X has shape (%d, %d)"
            % (F, Dcheck, Fcheck)
        )

    # We'll need the original xi for this and also Z, the 3D tensor of which for each document D
    # and term T gives the strength of topic K. We'll also need the log of the vocab dist
    xi = deriveXi(lmda, nu, s)

    # If not already provided, we'll also need the following products
    #
    if XAT is None:
        XAT = X.dot(A.T)
    if XTX is None:
        XTX = X.T.dot(X)
    if V is not None and VTV is None:
        VTV = V.T.dot(V)
    if U is not None and UTU is None:
        UTU = U.T.dot(U)

    # also need one over the usual variances
    overSsq, overAsq, overKsq, overTsq = 1.0 / sigmaSq, 1.0 / alphaSq, 1.0 / kappaSq, 1.0 / tauSq
    overTkSq = overTsq * overKsq
    overAsSq = overAsq * overSsq

    # <ln p(Y)>
    #
    trSigY = 1 if sigY is None else np.trace(sigY)
    trOmY = K  # Basically it's the trace of the identity matrix as the posterior and prior cancel out
    lnP_Y = -0.5 * (
        Q * P * LOG_2PI + P * lnDetSigT + overTkSq * trSigY * trOmY + overTkSq * np.trace(isigT.dot(Y).dot(Y.T))
    )

    # <ln P(A|Y)>
    # TODO it looks like I should take the trace of omA \otimes I_K here.
    # TODO Need to check re-arranging sigY and omY is sensible.
    halfKF = 0.5 * K * F

    # Horrible, but varBound can be called by two implementations, one with Y as a matrix-variate
    # where sigY is QxQ and one with Y as a multi-varate, where sigY is a QPxQP.
    # NOTE(review): sigY.shape is read unconditionally here although sigY may be
    # None (it is guarded above and at the entropy term) — confirm callers always
    # pass a non-None sigY on this path.
    A_from_Y = Y.dot(U.T) if V is None else U.dot(Y).dot(V.T)
    A_diff = A - A_from_Y
    varFactorU = np.trace(sigY.dot(np.kron(VTV, UTU))) if sigY.shape[0] == Q * P else np.sum(sigY * UTU)
    # VTV is guaranteed to be populated above whenever V is not None, so reuse it
    # rather than recomputing V.T.dot(V).
    varFactorV = 1 if V is None else np.sum(omY * VTV)
    lnP_A = (
        -halfKF * LOG_2PI
        - halfKF * log(alphaSq)
        - F / 2.0 * lnDetSigT
        - 0.5 * (overAsSq * varFactorV * varFactorU + np.trace(XTX.dot(varA)) * K + np.sum(isigT.dot(A_diff) * A_diff))
    )

    # <ln p(Theta|A,X)
    #
    lmdaDiff = lmda - XAT
    lnP_Theta = (
        -0.5 * D * LOG_2PI
        - 0.5 * D * lnDetSigT
        - 0.5 / sigmaSq * (np.sum(nu) + D * K * np.sum(XTX * varA) + np.sum(lmdaDiff.dot(isigT) * lmdaDiff))
    )
    # Why is order of sigT reversed? It's 'cause we've not been consistent. A is KxF but lmda is DxK, and
    # note that the distribution of lmda transpose has the same covariances, just in different positions
    #  (i.e. row is col and vice-versa)

    # <ln p(Z|Theta)
    #
    docLenLmdaLxi = docLen[:, np.newaxis] * lmda * lxi
    scaledWordCounts = sparseScalarQuotientOfDot(W, expLmda, vocab, out=scaledWordCounts)

    lnP_Z = 0.0
    lnP_Z -= np.sum(docLenLmdaLxi * lmda)
    lnP_Z -= np.sum(docLen[:, np.newaxis] * nu * nu * lxi)
    lnP_Z += 2 * np.sum(s[:, np.newaxis] * docLenLmdaLxi)
    lnP_Z -= 0.5 * np.sum(docLen[:, np.newaxis] * lmda)
    lnP_Z += np.sum(
        lmda * expLmda * (scaledWordCounts.dot(vocab.T))
    )  # n(d,k) = expLmda * (scaledWordCounts.dot(vocab.T))
    lnP_Z -= np.sum(docLen[:, np.newaxis] * lxi * ((s ** 2)[:, np.newaxis] - xi ** 2))
    lnP_Z += 0.5 * np.sum(docLen[:, np.newaxis] * (s[:, np.newaxis] + xi))
    lnP_Z -= np.sum(docLen[:, np.newaxis] * safe_log_one_plus_exp_of(xi))
    lnP_Z -= np.sum(docLen * s)

    # <ln p(W|Z, vocab)>
    #
    lnP_w_dt = sparseScalarProductOfDot(scaledWordCounts, expLmda, vocab * safe_log(vocab))
    lnP_W = np.sum(lnP_w_dt.data)

    # H[q(Y)]
    lnDetOmY = 0 if omY is None else log(la.det(omY))
    lnDetSigY = 0 if sigY is None else log(max(la.det(sigY), sys.float_info.min))  # TODO FIX THIS
    ent_Y = 0.5 * (P * K * LOG_2PI_E + Q * lnDetOmY + P * lnDetSigY)

    # H[q(A|Y)]
    #
    # A few things - omA is fixed so long as tau and sigma are, so there's no benefit in
    # recalculating this every time.
    #
    # However in a recent test, la.det(omA) = 0
    # this is very strange as omA is the inverse of (s*I + t*XTX)
    #
    #    ent_A = 0.5 * (F * K * LOG_2PI_E + K * log (la.det(omA)) + F * K * log (tau2))\
    ent_A = 0

    # H[q(Theta|A)]
    ent_Theta = 0.5 * (K * LOG_2PI_E + np.sum(np.log(nu * nu)))

    # H[q(Z|\Theta)
    #
    # So Z_dtk \propto expLmda_dt * vocab_tk. We let N here be the normalizer (which is
    # \sum_j expLmda_dt * vocab_tj, which implies N is DxT. We need to evaluate
    # Z_dtk * log Z_dtk. We can pull out the normalizer of the first term, but it has
    # to stay in the log Z_dtk expression, hence the third term in the sum. We can however
    # take advantage of the ability to mix dot and element-wise products for the different
    # components of Z_dtk in that three-term sum, which we denote as S
    #   Finally we use np.sum to sum over d and t
    #
    ent_Z = 0  # entropyOfDot(expLmda, vocab)

    result = lnP_Y + lnP_A + lnP_Theta + lnP_Z + lnP_W + ent_Y + ent_A + ent_Theta + ent_Z

    return result
示例#2
0
def _quickPrintElbo(
    updateMsg,
    iteration,
    X,
    W,
    K,
    Q,
    F,
    P,
    T,
    A,
    varA,
    Y,
    omY,
    sigY,
    sigT,
    U,
    V,
    vocab,
    sigmaSq,
    alphaSq,
    kappaSq,
    tauSq,
    lmda,
    expLmda,
    nu,
    lxi,
    s,
    docLen,
):
    """
    This checks that none of the matrix parameters contain a NaN or an Inf
    value, then calculates the variational bound, and prints it to stdout with
    the given update message.

    Exactly one of lmda (log-space) and expLmda (exp-space) must be given; the
    other must be None. Whichever buffer is passed is converted in place and
    restored before returning, so callers see it unchanged.

    A tremendously inefficient method for debugging only.
    """

    def _has_nans(M):
        # .data works for both dense ndarrays (buffer) and sparse matrices
        return np.isnan(M.data).any()

    def _has_infs(M):
        return np.isinf(M.data).any()

    assert not (
        lmda is not None and expLmda is not None
    ), "We can't have both lmda and expLmda not be none, as we assume we only ever have one."

    # Table of every matrix parameter we sanity-check; None entries (optional
    # parameters) are simply skipped. All NaN reports are printed before any
    # inf reports, matching the original check order.
    checked = [
        ("Y", Y),
        ("omY", omY),
        ("sigY", sigY),
        ("A", A),
        ("varA", varA),
        ("expLmda", expLmda),
        ("lmda", lmda),
        ("sigT", sigT),
        ("nu", nu),
        ("U", U),
        ("V", V),
        ("vocab", vocab),
    ]

    # NaN tests
    for name, mat in checked:
        if mat is not None and _has_nans(mat):
            print(str(name) + " has NaNs")

    # Infs tests
    for name, mat in checked:
        if mat is not None and _has_infs(mat):
            print(str(name) + " has infs")

    # varBound needs exp-space values: convert in place if we were given
    # log-space ones (the lmda buffer then temporarily holds exp values).
    wasPassedExpLmda = expLmda is not None
    if expLmda is None:
        expLmda = np.exp(lmda, out=lmda)

    elbo = varBound(
        VbSideTopicModelState(
            K, Q, F, P, T, A, varA, Y, omY, sigY, sigT, U, V, vocab, sigmaSq, alphaSq, kappaSq, tauSq
        ),
        VbSideTopicQueryState(expLmda, nu, lxi, s, docLen),
        X,
        W,
    )

    # Back to log-space for the summary statistics. lmda is always non-None
    # after this assignment, so the old dead "else deriveXi(np.log(expLmda)...)"
    # branch has been removed.
    lmda = np.log(expLmda, out=expLmda)
    xi = deriveXi(lmda, nu, s)

    # NOTE(review): .last is presumably seeded at module level; default to elbo
    # (diff == 0) so the very first call cannot raise AttributeError.
    diff = getattr(_quickPrintElbo, "last", elbo) - elbo
    diffStr = "   " if diff <= 0 else "(!)"

    print(
        "\t Update %5d: %-30s  ELBO : %12.3f %s  lmda.mean=%f \tlmda.max=%f \tlmda.min=%f \tnu.mean=%f \txi.mean=%f \ts.mean=%f"
        % (iteration, updateMsg, elbo, diffStr, lmda.mean(), lmda.max(), lmda.min(), nu.mean(), xi.mean(), s.mean())
    )
    # Restore the caller's buffer to exp-space if that's what was passed in.
    if wasPassedExpLmda:
        np.exp(expLmda, out=expLmda)
    _quickPrintElbo.last = elbo