    def testCrossValPerplexityOnRealDataWithLdaGibbsInc(self):
        """
        Cross-validates Gibbs-sampled LDA on the ACL dataset over a range of
        topic counts.

        For every topic count K, and for each of the first ``ActiveFolds``
        folds (out of ``NumFolds``), a model is trained on the training split
        and its perplexity on that split is recorded; the query split is then
        divided by ``doc_completion_split()`` into an estimation part (used to
        infer the query state) and an evaluation part (used to measure
        perplexity). For each K one CSV-style line per segment is printed,
        listing the per-fold perplexities followed by their mean.
        """
        ActiveFolds = 3
        dtype = np.float64 # DTYPE

        data = DataSet.from_files(words_file=AclWordPath, links_file=AclCitePath)

        data.convert_to_dtype(np.int32) # Gibbs expects integers as input, regardless of model dtype
        data.prune_and_shuffle(min_doc_len=MinDocLen, min_link_count=MinLinkCount)

        # Training setup
        TrainSamplesPerTopic = 10
        QuerySamplesPerTopic = 2
        Thin = 2
        Debug = False

        # Start running experiments
        topicCounts = [5, 10, 15, 20, 25, 30, 35, 40, 45, 50]
        for K in topicCounts:
            # The number of sampling iterations scales with the topic count
            trainPlan = lda_gibbs.newTrainPlan(K * TrainSamplesPerTopic, thin=Thin, debug=Debug)
            queryPlan = lda_gibbs.newTrainPlan(K * QuerySamplesPerTopic, thin=Thin, debug=Debug)

            trainPerps = []
            queryPerps = []
            for fold in range(ActiveFolds): # range(NumFolds):
                trainData, queryData = data.cross_valid_split(fold, NumFolds)
                estData, evalData = queryData.doc_completion_split()

                model = lda_gibbs.newModelAtRandom(trainData, K, dtype=dtype)
                query = lda_gibbs.newQueryState(trainData, model)

                # Train the model and record the perplexity on the training split
                model, trainResult, (_, _, _) = lda_gibbs.train(trainData, model, query, trainPlan)

                like = lda_gibbs.log_likelihood(trainData, model, trainResult)
                trainPerps.append(perplexity_from_like(like, trainData.word_count))

                # Infer the query state on the estimation part, score on the evaluation part
                query = lda_gibbs.newQueryState(estData, model)
                _, queryResult = lda_gibbs.query(estData, model, query, queryPlan)

                like = lda_gibbs.log_likelihood(evalData, model, queryResult)
                queryPerps.append(perplexity_from_like(like, evalData.word_count))

            # The final column of each printed row is the mean across the folds
            trainPerps.append(sum(trainPerps) / ActiveFolds)
            queryPerps.append(sum(queryPerps) / ActiveFolds)
            print("K=%d,Segment=Train,%s" % (K, ",".join([str(p) for p in trainPerps])))
            print("K=%d,Segment=Query,%s" % (K, ",".join([str(p) for p in queryPerps])))
    def testPerplexityOnRealData(self):
        """
        Trains a 50-topic ``mtm`` model on the full ACL dataset, plots the
        training bound and likelihood plus the topic covariance, and prints
        the prior, the training perplexity and the top ten words of each topic.
        """
        topicCount = 50
        wordsPerTopic = 10

        corpus = DataSet.from_files(words_file=AclWordPath, links_file=AclCitePath)
        with open(AclDictPath, "rb") as dictFile:
            wordDict = pkl.load(dictFile)

        corpus.prune_and_shuffle(min_doc_len=MinDocLen, min_link_count=MinLinkCount)

        # Inverse term frequencies (computed here but not used further below)
        termFreq = np.squeeze(np.asarray(corpus.words.sum(axis=0)))
        invFreq = np.reciprocal(1 + termFreq)

        # Model, query state and training schedule
        model = mtm.newModelAtRandom(corpus, topicCount, topicCount - 1, dtype=np.float64)
        initialQuery = mtm.newQueryState(corpus, model)
        plan = mtm.newTrainPlan(iterations=200, logFrequency=10, fastButInaccurate=False, debug=True)

        model, query, (bndItrs, bndVals, bndLikes) = mtm.train(corpus, model, initialQuery, plan)

        # Bound (left axis) and likelihood (right axis) against iteration
        boundFig, boundAxis = plt.subplots()
        boundAxis.plot(bndItrs, bndVals, 'b-')
        boundAxis.set_ylabel('Bound', color='b')

        likeAxis = boundAxis.twinx()
        likeAxis.plot(bndItrs, bndLikes, 'r-')
        likeAxis.set_ylabel('Likelihood', color='r')

        # Topic covariance rendered as a greyscale image
        covFig, covAxis = plt.subplots()
        covAxis.imshow(model.topicCov, interpolation="nearest", cmap=cm.Greys_r)

        # Indices of the highest-ranked words of every topic
        vocab = mtm.wordDists(model)
        topWordsByTopic = [self.topWordInds(vocab[k,:], wordsPerTopic) for k in range(topicCount)]

        logLike = mtm.log_likelihood(corpus, model, query)
        perp = perplexity_from_like(logLike, corpus.word_count)

        print ("Prior %s" % (str(model.topicPrior)))
        print ("Perplexity: %f\n\n" % perp)

        for k in range(model.K):
            print("\nTopic %d\n=============================" % k)
            rows = []
            for c in range(wordsPerTopic):
                wordInd = topWordsByTopic[k][c]
                rows.append("%-20s\t%0.4f" % (wordDict[wordInd], vocab[k][wordInd]))
            print("\n".join(rows))
def train(data, model, query, plan):
    """
    Trains the model by Gibbs sampling, periodically re-estimating the
    feature weights that determine each document's topic prior.

    data  - the dataset; ``data.feats`` is the document-feature matrix and,
            in debug mode, ``data.word_count`` is used to report perplexity
    model - the initial model state (must use 64-bit floats)
    query - the query state holding the word list, the topic-assignment
            list and the document lengths
    plan  - the training plan (iterations, burnIn, thin,
            weightUpdateInterval, debug)

    Returns a new ModelState, a new QueryState, and a tuple of three
    single-element zero arrays standing in for the (iterations, bounds,
    likelihoods) logs that the other trainers return.
    """
    iterations, burnIn, thin, weightUpdateInterval, debug = \
        plan.iterations, plan.burnIn, plan.thin, plan.weightUpdateInterval, plan.debug
    w_list, z_list, docLens = query.w_list, query.z_list, query.docLens
    K, T, weights, topicPrior, vocabPrior, dtype, name = \
        model.K, model.T, model.weights, model.topicPrior, model.vocabPrior, model.dtype, model.name

    assert model.dtype == np.float64, "This is only implemented for 64-bit floats"
    D = docLens.shape[0]
    X = data.feats
    # Per-document topic counts are stored as uint16, hence the length limit
    assert docLens.max() < 65536, \
        "This only works for documents with fewer than 65,536 words"

    # Sufficient statistics: document-topic, topic-word and per-topic counts
    ndk = np.zeros((D, K), dtype=np.uint16)
    nkv = np.zeros((K, T), dtype=np.int32)
    nk = np.zeros((K, ), dtype=np.int32)

    num_samples = (iterations - burnIn) // thin
    n_dk_samples = np.zeros((D, K, num_samples), dtype=np.uint16)
    topicSum = np.zeros((D, K), dtype=dtype)
    vocabSum = np.zeros((K, T), dtype=dtype)

    compiled.sumSuffStats(w_list, z_list, docLens, ndk, nkv, nk)

    # Burn in
    # NOTE(review): thin is passed as burnIn + 1 here, presumably so that no
    # samples are retained during burn-in -- confirm against compiled.sample
    alphas = X.dot(weights.T)
    if debug: print("Burning")
    compiled.sample(burnIn, burnIn + 1, w_list, z_list, docLens,
                    alphas, ndk, nkv, nk, n_dk_samples, topicSum, vocabSum,
                    vocabPrior, False, debug)

    # True samples
    if debug: print("Training")
    sample_count = 0
    for _ in range(0, iterations - burnIn, weightUpdateInterval):
        # Recompute the per-document priors from the latest weights
        alphas[:, :] = X.dot(weights.T)
        sample_count += compiled.sample(weightUpdateInterval, thin, w_list, z_list, docLens,
                                        alphas, ndk, nkv, nk, n_dk_samples, topicSum, vocabSum,
                                        vocabPrior, False, debug)

        if debug:  # Print out the perplexity so far
            likely = log_likelihood(
                data,
                ModelState(K, T, weights, topicPrior, vocabPrior, n_dk_samples, topicSum, vocabSum, sample_count, dtype, name),
                QueryState(w_list, z_list, docLens, topicSum, sample_count))
            # perplexity_from_like takes the word count, as at every other
            # call site, not the dataset itself
            perp = perplexity_from_like(likely, data.word_count)
            print("Sample-Count = %3d  Perplexity = %7.2f" %
                  (sample_count, perp))

        updateWeights(n_dk_samples, sample_count, X, weights, debug)

    return \
        ModelState(K, T, weights, topicPrior, vocabPrior, n_dk_samples, topicSum, vocabSum, sample_count, dtype, name), \
        QueryState(w_list, z_list, docLens, topicSum, sample_count), \
        (np.zeros(1), np.zeros(1), np.zeros(1))
def query(data, modelState, queryState, queryPlan):
    """
    Given a _trained_ model, attempts to predict the topics for each of
    the inputs.

    Params:
    data - the dataset of words, features and links of which only words are used in this model
    modelState - the _trained_ model
    queryState - the query state generated for the query dataset
    queryPlan  - used in this case as we need to tighten up the approx

    Return:
    The model state and query state, in that order. The model state is
    unchanged, the query is.
    """
    iterations, diagonalPriorCov, debug = queryPlan.iterations, queryPlan.fastButInaccurate, queryPlan.debug
    means, expMeans, varcs, n = queryState.means, queryState.expMeans, queryState.varcs, queryState.docLens
    K, topicMean, sigT, vocab, vocabPrior, A, dtype = modelState.K, modelState.topicMean, modelState.sigT, modelState.vocab, modelState.vocabPrior, modelState.A, modelState.dtype

    debugFn = _debug_with_bound if debug else _debug_with_nothing
    W = data.words
    D = W.shape[0]

    # Necessary temp variables (notably the count of topic to word assignments
    # per topic per doc)
    isigT = la.inv(sigT)

    # Update the Variances
    varcs = 1./((n * (K-1.)/K)[:,np.newaxis] + isigT.flat[::K+1])
    debugFn (0, varcs, "varcs", W, K, topicMean, sigT, vocab, vocabPrior, dtype, means, varcs, A, n)

    lastPerp = 1E+300 if dtype is np.float64 else 1E+30
    R = W.copy()
    for itr in range(iterations):
        # Subtracting the row maximum keeps the exponentials from overflowing
        expMeans = np.exp(means - means.max(axis=1)[:,np.newaxis], out=expMeans)
        R = sparseScalarQuotientOfDot(W, expMeans, vocab, out=R)
        V = expMeans * R.dot(vocab.T)

        # Update the Means
        rhs = V.copy()
        rhs += n[:,np.newaxis] * means.dot(A) + isigT.dot(topicMean)
        # NB: the softmax overwrites `means` in place; both branches below
        # replace every row of it before it is read again
        rhs -= n[:,np.newaxis] * rowwise_softmax(means, out=means)
        if diagonalPriorCov:
            means = varcs * rhs
        else:
            for d in range(D):
                means[d,:] = la.inv(isigT + n[d] * A).dot(rhs[d,:])
        debugFn (itr, means, "means", W, K, topicMean, sigT, vocab, vocabPrior, dtype, means, varcs, A, n)

        # Stop once perplexity improves by less than one unit per iteration
        like = log_likelihood(data, modelState, QueryState(means, expMeans, varcs, n))
        perp = perplexity_from_like(like, data.word_count)
        if itr > 20 and lastPerp - perp < 1:
            break
        lastPerp = perp

    # `varcs` (always) and `means` (in the diagonal case) are rebound to new
    # arrays above, so the incoming queryState does not hold the results;
    # package the updated values into a fresh QueryState for the caller.
    return modelState, QueryState(means, expMeans, varcs, n)
    def testCrossValPerplexityOnRealDataWithLdaInc(self):
        """
        Cross-validates the ``lda`` model on the ACL dataset over a range of
        topic counts.

        For every topic count K, and for each of the first ``ActiveFolds``
        folds (out of ``NumFolds``), a model is trained on the training split
        and its perplexity on that split is recorded; the query split is then
        divided by ``doc_completion_split()`` into an estimation part (used to
        infer the query state) and an evaluation part (used to measure
        perplexity). For each K one CSV-style line per segment is printed,
        listing the per-fold perplexities followed by their mean.
        """
        ActiveFolds = 3
        dtype = np.float64 # DTYPE

        data = DataSet.from_files(words_file=AclWordPath, links_file=AclCitePath)

        data.prune_and_shuffle(min_doc_len=MinDocLen, min_link_count=MinLinkCount)

        # Training and query schedules, shared by every topic count and fold
        trainPlan = lda.newTrainPlan(iterations=800, logFrequency=10, fastButInaccurate=False, debug=False)
        queryPlan = lda.newTrainPlan(iterations=50, logFrequency=5, fastButInaccurate=False, debug=False)

        topicCounts = [30, 35, 40, 45, 50] # [5, 10, 15, 20, 25, 30, 35, 40, 45, 50]
        for K in topicCounts:
            trainPerps = []
            queryPerps = []
            for fold in range(ActiveFolds): # range(NumFolds):
                trainData, queryData = data.cross_valid_split(fold, NumFolds)

                model = lda.newModelAtRandom(trainData, K, dtype=dtype)
                query = lda.newQueryState(trainData, model)

                # Train the model and record the perplexity on the training split
                model, trainResult, (_, _, _) = lda.train(trainData, model, query, trainPlan)

                like = lda.log_likelihood(trainData, model, trainResult)
                trainPerps.append(perplexity_from_like(like, trainData.word_count))

                # Infer the query state on the estimation part, score on the evaluation part
                estData, evalData = queryData.doc_completion_split()
                query = lda.newQueryState(estData, model)
                model, queryResult = lda.query(estData, model, query, queryPlan)

                like = lda.log_likelihood(evalData, model, queryResult)
                queryPerps.append(perplexity_from_like(like, evalData.word_count))

            # The final column of each printed row is the mean across the folds
            trainPerps.append(sum(trainPerps) / ActiveFolds)
            queryPerps.append(sum(queryPerps) / ActiveFolds)
            print("K=%d,Segment=Train,%s" % (K, ",".join([str(p) for p in trainPerps])))
            print("K=%d,Segment=Query,%s" % (K, ",".join([str(p) for p in queryPerps])))
    def testPerplexityOnRealDataWithLdaInc(self):
        """
        Trains the ``lda`` model on the full ACL dataset for a range of topic
        counts, prints the training perplexity obtained for each, and plots
        perplexity against topic count.
        """
        dtype = np.float64 # DTYPE

        data = DataSet.from_files(words_file=AclWordPath, links_file=AclCitePath)
        with open(AclDictPath, "rb") as f:
            d = pkl.load(f)

        data.prune_and_shuffle(min_doc_len=MinDocLen, min_link_count=MinLinkCount)

        # IDF frequency for when we print out the vocab later
        freq = np.squeeze(np.asarray(data.words.sum(axis=0)))
        scale = np.reciprocal(1 + freq)

        # One full training run per topic count
        topicCounts = [5, 10, 15, 20, 25, 30, 35, 40, 45, 50]
        perps = []
        for K in topicCounts:
            model      = lda.newModelAtRandom(data, K, dtype=dtype)
            queryState = lda.newQueryState(data, model)
            trainPlan  = lda.newTrainPlan(iterations=800, logFrequency=10, fastButInaccurate=False, debug=False)

            model, query, (bndItrs, bndVals, bndLikes) = lda.train (data, model, queryState, trainPlan)

            like = lda.log_likelihood(data, model, query)
            perp = perplexity_from_like(like, data.word_count)
            # Keep one perplexity per topic count so the plot below has a
            # y-value for every x-value
            perps.append(perp)

            print ("K = %2d : Perplexity = %f\n\n" % (K, perp))

        # Plot perplexity as a function of the number of topics
        fig, ax1 = plt.subplots()
        ax1.plot(topicCounts, perps, 'b-')
        ax1.set_xlabel('Topic Count')
        ax1.set_ylabel('Perplexity', color='b')

def _debug_with_bound(itr, var_value, var_name, data, K, topicMean, topicCov,
                      vocab, dtype, means, varcs, A, n):
    """
    Debug hook called after a variable has been updated during inference.

    Warns on stderr if the updated value contains NaNs or infinities, or has
    a dtype other than the model's, then recomputes the variational bound and
    reports it together with its change since the previous call and the
    current perplexity. The report goes to stderr when the bound has dropped
    and to stdout otherwise.

    The previous bound is kept in the function attribute
    ``_debug_with_bound.old_bound``, which callers must initialise (a value
    of 0 means "no previous bound").

    itr       - the current iteration number, used only in the report
    var_value - the freshly updated array to sanity-check
    var_name  - the name of the updated variable, used in the messages
    remaining arguments - the pieces from which the ModelState and
                QueryState are rebuilt in order to evaluate the bound
    """
    if np.isnan(var_value).any():
        printStderr("WARNING: " + var_name + " contains NaNs")
    if np.isinf(var_value).any():
        printStderr("WARNING: " + var_name + " contains INFs")
    if var_value.dtype != dtype:
        # NOTE(review): the tail of this message was missing from the source
        # and has been reconstructed as the offending dtype -- confirm
        printStderr("WARNING: dtype(" + var_name + ") = " +
                    str(var_value.dtype))

    model = ModelState(K, topicMean, topicCov, vocab, A, dtype, MODEL_NAME)
    query = QueryState(means, varcs, n)

    old_bound = _debug_with_bound.old_bound
    bound = var_bound(data, model, query)
    diff = "" if old_bound == 0 else "%15.4f" % (bound - old_bound)
    _debug_with_bound.old_bound = bound

    addendum = ""
    if var_name == "topicCov":
        # Best effort only: a failed determinant must not abort inference
        try:
            addendum = "det(topicCov) = %g" % (la.det(topicCov))
        except Exception:
            addendum = "det(topicCov) = <undefined>"

    if isnan(bound):
        printStderr("Bound is NaN")
    else:
        # NOTE(review): the word-count argument was missing from the source;
        # data.word_count is assumed, matching the other call sites -- confirm
        perp = perplexity_from_like(log_likelihood(data, model, query),
                                    data.word_count)
        if int(bound - old_bound) < 0:
            # A falling bound indicates a problem, so report it on stderr
            printStderr(
                "Iter %3d Update %-15s Bound %22f (%15s) (%5.0f)     %s" %
                (itr, var_name, bound, diff, perp, addendum))
        else:
            print("Iter %3d Update %-15s Bound %22f (%15s) (%5.0f)  %s" %
                  (itr, var_name, bound, diff, perp, addendum))
def train (data, modelState, queryState, trainPlan):
    """
    Infers the topic distributions in general, and specifically for
    each individual datapoint.

    Params:
    data - the dataset of words, features and links of which only words and
           features are used in this model
    modelState - the actual CTM model
    queryState - the query results - essentially all the "local" variables
                 matched to the given observations
    trainPlan  - how to execute the training process (e.g. iterations,
                 log-interval etc.)

    Return:
    A new model object with the updated model (note parameters are
    updated in place, so make a defensive copy if you want it)
    A new query object with the update query parameters
    A tuple of three arrays: the iterations at which measurements were
    logged, the bound at each, and the log-likelihood at each
    """
    W, X = data.words, data.feats

    assert W.dtype == modelState.dtype
    assert X.dtype == modelState.dtype
    D,_ = W.shape

    # Unpack the the structs, for ease of access and efficiency
    iterations, epsilon, logFrequency, fastButInaccurate, debug = trainPlan.iterations, trainPlan.epsilon, trainPlan.logFrequency, trainPlan.fastButInaccurate, trainPlan.debug
    means, expMeans, varcs, lxi, s, n = queryState.means, queryState.expMeans, queryState.varcs, queryState.lxi, queryState.s, queryState.docLens
    F, P, K, A, R_A, fv, Y, R_Y, lfv, V, sigT, vocab, vocabPrior, dtype = modelState.F, modelState.P, modelState.K, modelState.A, modelState.R_A, modelState.fv, modelState.Y, modelState.R_Y, modelState.lfv, modelState.V, modelState.sigT, modelState.vocab, modelState.vocabPrior, modelState.dtype

    # Book-keeping for logs. Measurements are taken at itr = 0, logFrequency,
    # 2 * logFrequency, ... so ceil(iterations / logFrequency) slots are
    # needed; a non-positive logFrequency disables logging altogether.
    numLogs = -(-iterations // logFrequency) if logFrequency > 0 else 0
    boundIters  = np.zeros(shape=(numLogs,))
    boundValues = np.zeros(shape=(numLogs,))
    likeValues  = np.zeros(shape=(numLogs,))
    bvIdx = 0

    _debug_with_bound.old_bound = 0
    debugFn = _debug_with_bound if debug else _debug_with_nothing

    # Initialize some working variables
    isigT = la.inv(sigT)
    R = W.copy()
    sigT_regularizer = 0.001

    aI_P = 1./lfv  * ssp.eye(P, dtype=dtype)
    tI_F = 1./fv * ssp.eye(F, dtype=dtype)

    print("Creating posterior covariance of A, this will take some time...")
    XTX = X.T.dot(X)
    if ssp.issparse(XTX):
        R_A = XTX.todense()  # dense inverse typically as fast or faster than sparse inverse
    else:
        R_A = XTX.copy()     # copy, so the in-place update below cannot corrupt XTX
    R_A.flat[::F+1] += 1./fv # and the result is usually dense in any case
    R_A = la.inv(R_A)
    print("Covariance matrix calculated, launching inference")

    # Iterate over parameters
    for itr in range(iterations):
        # We start with the M-Step, so the parameters are consistent with our
        # initialisation of the RVs when we do the E-Step

        # Update the covariance of the prior
        diff_a_yv = (A-Y.dot(V))
        diff_m_xa = (means-X.dot(A.T))

        sigT  = 1./lfv * (Y.dot(Y.T))
        sigT += 1./fv * diff_a_yv.dot(diff_a_yv.T)
        sigT += diff_m_xa.T.dot(diff_m_xa)
        sigT.flat[::K+1] += varcs.sum(axis=0)
        sigT /= (P+F+D)
        sigT.flat[::K+1] += sigT_regularizer

        # Diagonalize it
        sigT = np.diag(sigT.flat[::K+1])
        # and invert it.
        isigT = np.diag(np.reciprocal(sigT.flat[::K+1]))
        debugFn (itr, sigT, "sigT", W, X, XTX, F, P, K, A, R_A, fv, Y, R_Y, lfv, V, sigT, vocab, vocabPrior, dtype, means, varcs, lxi, s, n)

        # Building Blocks: exp(means), shifted by the row maximum for stability
        expMeans = np.exp(means - means.max(axis=1)[:,np.newaxis], out=expMeans)
        R = sparseScalarQuotientOfDot(W, expMeans, vocab, out=R)
        S = expMeans * R.dot(vocab.T)

        # Update the vocabulary
        vocab *= (R.T.dot(expMeans)).T # Awkward order to maintain sparsity (R is sparse, expMeans is dense)
        vocab += vocabPrior
        vocab = normalizerows_ip(vocab)

        # Log effect of vocab update
        debugFn (itr, vocab, "vocab", W, X, XTX, F, P, K, A, R_A, fv, Y, R_Y, lfv, V, sigT, vocab, vocabPrior, dtype, means, varcs, lxi, s, n)

        # Finally update the parameter V
        V = la.inv(R_Y + Y.T.dot(isigT).dot(Y)).dot(Y.T.dot(isigT).dot(A))
        debugFn (itr, V, "V", W, X, XTX, F, P, K, A, R_A, fv, Y, R_Y, lfv, V, sigT, vocab, vocabPrior, dtype, means, varcs, lxi, s, n)

        # And now this is the E-Step, though it's followed by updates for the
        # parameters also that handle the log-sum-exp approximation.

        # Update the distribution on the latent space
        R_Y_base = aI_P + 1/fv * V.dot(V.T)
        R_Y = la.inv(R_Y_base)
        debugFn (itr, R_Y, "R_Y", W, X, XTX, F, P, K, A, R_A, fv, Y, R_Y, lfv, V, sigT, vocab, vocabPrior, dtype, means, varcs, lxi, s, n)

        Y = 1./fv * A.dot(V.T).dot(R_Y)
        debugFn (itr, Y, "Y", W, X, XTX, F, P, K, A, R_A, fv, Y, R_Y, lfv, V, sigT, vocab, vocabPrior, dtype, means, varcs, lxi, s, n)

        # Update the mapping from the features to topics
        A = (1./fv * (Y).dot(V) + (X.T.dot(means)).T).dot(R_A)
        debugFn (itr, A, "A", W, X, XTX, F, P, K, A, R_A, fv, Y, R_Y, lfv, V, sigT, vocab, vocabPrior, dtype, means, varcs, lxi, s, n)

        # Update the Means
        vMat   = (s[:,np.newaxis] * lxi - 0.5) * n[:,np.newaxis] + S
        rhsMat = vMat + X.dot(A.T).dot(isigT) # TODO Verify this
        lhsMat = np.reciprocal(np.diag(isigT)[np.newaxis,:] + n[:,np.newaxis] *  lxi)  # inverse of D diagonal matrices...
        means = lhsMat * rhsMat # as LHS is a diagonal matrix for all d, it's equivalent
                                # do doing a hadamard product for all d
        debugFn (itr, means, "means", W, X, XTX, F, P, K, A, R_A, fv, Y, R_Y, lfv, V, sigT, vocab, vocabPrior, dtype, means, varcs, lxi, s, n)

        # Update the Variances
        varcs = 1./(n[:,np.newaxis] * lxi + isigT.flat[::K+1])
        debugFn (itr, varcs, "varcs", W, X, XTX, F, P, K, A, R_A, fv, Y, R_Y, lfv, V, sigT, vocab, vocabPrior, dtype, means, varcs, lxi, s, n)

        # Update the approximation parameters
        lxi = 2 * ctm.negJakkolaOfDerivedXi(means, varcs, s)
        debugFn (itr, lxi, "lxi", W, X, XTX, F, P, K, A, R_A, fv, Y, R_Y, lfv, V, sigT, vocab, vocabPrior, dtype, means, varcs, lxi, s, n)

        # s can sometimes grow unboundedly
        # Follow Bouchard's suggested approach of fixing it at zero

        if logFrequency > 0 and itr % logFrequency == 0:
            modelState = ModelState(F, P, K, A, R_A, fv, Y, R_Y, lfv, V, sigT, vocab, vocabPrior, dtype, MODEL_NAME)
            queryState = QueryState(means, expMeans, varcs, lxi, s, n)

            boundValues[bvIdx] = var_bound(data, modelState, queryState, XTX)
            likeValues[bvIdx]  = log_likelihood(data, modelState, queryState)
            boundIters[bvIdx]  = itr
            perp = perplexity_from_like(likeValues[bvIdx], n.sum())

            print (time.strftime('%X') + " : Iteration %d: Perplexity %4.2f  bound %f" % (itr, perp, boundValues[bvIdx]))
            if bvIdx > 0 and  boundValues[bvIdx - 1] > boundValues[bvIdx]:
                printStderr ("ERROR: bound degradation: %f > %f" % (boundValues[bvIdx - 1], boundValues[bvIdx]))

            # Check to see if the improvment in the likelihood has fallen below the threshold
            if bvIdx > 1 and boundIters[bvIdx] > 50:
                lastPerp = perplexity_from_like(likeValues[bvIdx - 1], n.sum())
                if lastPerp - perp < 1:
                    # Assign the clamped likelihoods back to likeValues so that
                    # all three returned arrays are clamped consistently
                    boundIters, boundValues, likeValues = clamp (boundIters, boundValues, likeValues, bvIdx)
                    return modelState, queryState, (boundIters, boundValues, likeValues)
            bvIdx += 1

    return \
        ModelState(F, P, K, A, R_A, fv, Y, R_Y, lfv, V, sigT, vocab, vocabPrior, dtype, MODEL_NAME), \
        QueryState(means, expMeans, varcs, lxi, s, n), \
        (boundIters, boundValues, likeValues)
    def testPerplexityOnRealDataWithMtm2(self):
        """
        Trains a 30-topic ``mtm2`` model on the full ACL dataset, plots the
        training bound and likelihood plus the topic covariance, prints the
        training perplexity, the top ten words of each topic and the document
        with the largest mean for each topic, and finally pickles the model
        and query state to a file on the desktop.
        """
        K = 30 # TopicCount
        wordsPerTopic = 10

        corpus = DataSet.from_files(words_file=AclWordPath, links_file=AclCitePath)
        with open(AclDictPath, "rb") as dictFile:
            wordDict = pkl.load(dictFile)

        corpus.prune_and_shuffle(min_doc_len=MinDocLen, min_link_count=MinLinkCount)

        # Inverse term frequencies (computed here but not used further below)
        termFreq = np.squeeze(np.asarray(corpus.words.sum(axis=0)))
        invFreq = np.reciprocal(1 + termFreq)

        # Model, query state and training schedule
        model = mtm2.newModelAtRandom(corpus, K, dtype=np.float64)
        initialQuery = mtm2.newQueryState(corpus, model)
        plan = mtm2.newTrainPlan(iterations=200, logFrequency=10, fastButInaccurate=False, debug=False)

        model, query, (bndItrs, bndVals, bndLikes) = mtm2.train(corpus, model, initialQuery, plan)

        # Bound (left axis) and likelihood (right axis) against iteration
        boundFig, boundAxis = plt.subplots()
        boundAxis.plot(bndItrs, bndVals, 'b-')
        boundAxis.set_ylabel('Bound', color='b')

        likeAxis = boundAxis.twinx()
        likeAxis.plot(bndItrs, bndLikes, 'r-')
        likeAxis.set_ylabel('Likelihood', color='r')

        # Topic covariance rendered as a greyscale image
        covFig, covAxis = plt.subplots()
        covAxis.imshow(model.topicCov, interpolation="nearest", cmap=cm.Greys_r)

        # Indices of the highest-ranked words of every topic
        vocab = mtm2.wordDists(model)
        topWordsByTopic = [self.topWordInds(vocab[k,:], wordsPerTopic) for k in range(K)]

        logLike = mtm2.log_likelihood(corpus, model, query)
        perp = perplexity_from_like(logLike, corpus.word_count)

        print("Perplexity: %f\n\n" % perp)

        for k in range(model.K):
            print("\nTopic %d\n=============================" % k)
            rows = []
            for c in range(wordsPerTopic):
                wordInd = topWordsByTopic[k][c]
                rows.append("%-20s\t%0.4f" % (wordDict[wordInd], vocab[k][wordInd]))
            print("\n".join(rows))

        print ("Most likely documents for each topic")
        print ("====================================")
        with open ("/Users/bryanfeeney/iCloud/Datasets/ACL/ACL.100/doc_ids.pkl", 'rb') as idFile:
            fileIds = pkl.load (idFile)
        # Document IDs re-ordered via data.order so they can be indexed by row
        orderedDocIds = [fileIds[fi] for fi in corpus.order]

        for k in range(model.K):
            bestRow = np.argmax(query.means[:, k])
            print("K=%2d  Document ID = %s (found at %d)" % (k, orderedDocIds[bestRow], bestRow))

        print ("Done")

        with open ("/Users/bryanfeeney/Desktop/mtm2-" + str(K) + ".pkl", "wb") as outFile:
            pkl.dump((model, query), outFile)
def train(data, model: ModelState, query: QueryState, plan: TrainPlan, updateVocab=True):
    """
    Infers the topic distributions in general, and specifically for
    each individual datapoint,

    Params:
    data - the training data, we just use the DxT document-term matrix
    model - the initial model configuration. This is MUTATED IN-PLACE
    query - the query results - essentially all the "local" variables
            matched to the given observations. Also MUTATED IN-PLACE
    plan  - how to execute the training process (e.g. iterations,
            log-interval etc.)
    updateVocab - accepted for interface compatibility; not read by this
            implementation

    Return:
    The updated model object (note parameters are updated in place, so make a
    defensive copy if you want it)
    The query object with the update query parameters
    A tuple of three arrays: the iterations at which measurements were
    logged, the bound at each, and the log-likelihood at each

    Raises ValueError if any document has no words.
    """
    iterations, logFrequency, debug = plan.iterations, plan.logFrequency, plan.debug
    docLens = query.docLens
    topicDist = topicDists(query)
    K, topicPrior, vocabPrior, wordDistParam, corpusTopicDistParam, dtype = \
        model.K, model.topicPrior, model.vocabPrior, wordDistsDirichletParam(model), corpusTopicDistDirichletParam(model), model.dtype

    W = data.words

    iters, bnds, likes = [], [], []

    # Quick sanity check
    if np.any(docLens < 1):
        raise ValueError("Input document-term matrix contains at least one document with no words")
    assert dtype == np.float64, "Only implemented for 64-bit floats"

    # Digamma differences of the two Dirichlet parameters, kept up to date
    # in place throughout the loop
    lnCorpusTopicDist = fns.digamma(corpusTopicDistParam) - fns.digamma(corpusTopicDistParam.sum())
    lnWordDist = fns.digamma(wordDistParam) - fns.digamma(wordDistParam.sum(axis=1))[:, np.newaxis]
    oldTopicDist = np.ndarray(topicDist.shape, dtype=topicDist.dtype)

    for itr in range(iterations):
        oldTopicDist[:, :] = topicDist[:, :]

        # Per-document topic distributions: a softmax over topics, shifted by
        # the row maximum so the exponentials cannot overflow
        topicDist[:, :] = (W @ lnWordDist.T)
        topicDist[:, :] += lnCorpusTopicDist[np.newaxis, :]
        topicDist -= topicDist.max(axis=1)[:, np.newaxis]
        np.exp(topicDist, out=topicDist)
        topicDist /= topicDist.sum(axis=1)[:, np.newaxis]

        # Stop once the mean absolute change per entry is below the tolerance
        if np.abs(oldTopicDist - topicDist).sum() < CHANGE_TOLERANCE * topicDist.shape[0] * topicDist.shape[1]:
            logging.info(f"Stopping train after {itr + 1} iterations as change in topic distibution is minimal")
            break

        corpusTopicDistParam = topicDist.sum(axis=0) + model.topicPrior
        fns.digamma(corpusTopicDistParam, out=lnCorpusTopicDist)
        lnCorpusTopicDist -= fns.digamma(corpusTopicDistParam.sum())

        # Derive new parameter estimates
        wordDistParam = (W.T @ topicDist).T \
                      + model.vocabPrior[np.newaxis, :]
        fns.digamma(wordDistParam, out=lnWordDist)
        lnWordDist -= fns.digamma(wordDistParam.sum(axis=1))[:, np.newaxis]

        if debug or (logFrequency > 0 and itr % logFrequency == 0):
            m = ModelState(K, topicPrior, vocabPrior, wordDistParam, corpusTopicDistParam, True, dtype, model.name)
            q = QueryState(query.docLens, topicDist, True)

            # Record the iteration alongside its measurements so that the
            # three returned arrays line up and the stopping test can fire
            iters.append(itr)
            bnds.append(var_bound(data, m, q))
            likes.append(log_likelihood_point(data, m, q))

            wordCount = W.sum()
            perp = perplexity_from_like(likes[-1], wordCount)
            print("Iteration %d : Train Perp = %4.0f  Bound = %.3f" % (itr, perp, bnds[-1]))

            # After iteration 50, stop when perplexity improves by less than
            # one unit between consecutive measurements
            if len(iters) > 2 and iters[-1] > 50:
                lastPerp = perplexity_from_like(likes[-2], wordCount)
                if lastPerp - perp < 1:
                    break

    return ModelState(K, topicPrior, vocabPrior, wordDistParam, corpusTopicDistParam, True, dtype, model.name), \
           QueryState(query.docLens, topicDist, True), \
           (np.array(iters, dtype=np.int32), np.array(bnds), np.array(likes))
def train(data, modelState, queryState, trainPlan):
    # NOTE(review): the prose lines that follow describe the function (purpose,
    # arguments, return values) but have no docstring quotes around them in
    # this copy of the file, so they are not valid Python as written. Several
    # statements further down are also cut short (flagged individually).
    Infers the topic distributions in general, and specifically for
    each individual datapoint.
    W - the DxT document-term matrix
    X - The DxF document-feature matrix, which is IGNORED in this case
    modelState - the actual CTM model
    queryState - the query results - essentially all the "local" variables
                 matched to the given observations
    trainPlan  - how to execute the training process (e.g. iterations,
                 log-interval etc.)
    A new model object with the updated model (note parameters are
    updated in place, so make a defensive copy if you want itr)
    A new query object with the update query parameters
    W, X = data.words, data.feats
    D, T = W.shape
    F = X.shape[1]

    # Commented-out scratch test of sparseScalarQuotientOfDot, kept verbatim:
    # tmpNumDense = np.array([
    #     4	, 8	, 2	, 0	, 0,
    #     0	, 6	, 0	, 17, 0,
    #     12	, 13	, 1	, 7	, 8,
    #     0	, 5	, 0	, 0	, 0,
    #     0	, 6	, 0	, 0	, 44,
    #     0	, 7	, 2	, 0	, 0], dtype=np.float64).reshape((6,5))
    # tmpNum = ssp.csr_matrix(tmpNumDense)
    # tmpDenomleft = (rd.random((tmpNum.shape[0], 12)) * 5).astype(np.int32).astype(np.float64) / 10
    # tmpDenomRight = (rd.random((12, tmpNum.shape[1])) * 5).astype(np.int32).astype(np.float64)
    # tmpResult = tmpNum.copy()
    # tmpResult = sparseScalarQuotientOfDot(tmpNum, tmpDenomleft, tmpDenomRight)
    # print (str(tmpNum.todense()))
    # print (str(tmpDenomleft.dot(tmpDenomRight)))
    # print (str(tmpResult.todense()))

    # Unpack the structs, for ease of access and efficiency.
    # Note that diagonalPriorCov is read from trainPlan.fastButInaccurate.
    iterations, epsilon, logFrequency, diagonalPriorCov, debug = trainPlan.iterations, trainPlan.epsilon, trainPlan.logFrequency, trainPlan.fastButInaccurate, trainPlan.debug
    means, docLens = queryState.means, queryState.docLens
    K, A, U, Y,  V, covA, tv, ltv, fv, lfv, vocab, vocabPrior, dtype = \
        modelState.K, modelState.A, modelState.U, modelState.Y,  modelState.V, modelState.covA, modelState.tv, modelState.ltv, modelState.fv, modelState.lfv, modelState.vocab, modelState.vocabPrior, modelState.dtype

    tp, fp, ltp, lfp = 1. / tv, 1. / fv, 1. / ltv, 1. / lfv  # turn variances into precisions

    # FIXME Use passed in hypers
    # (the vocabPrior unpacked from modelState above is discarded and replaced
    # with an all-ones vector of length T)
    print("tp = %f tv=%f" % (tp, tv))
    vocabPrior = np.ones(shape=(T, ), dtype=modelState.dtype)

    # FIXME undo truncation
    # (F is hard-coded to 363 here; A, X and U are cut down to that many
    # features and a fresh DataSet is built from the truncated X)
    F = 363
    A = A[:F, :]
    X = X[:, :F]
    U = U[:F, :]
    data = DataSet(words=W, feats=X)

    # Book-keeping for logs
    # NOTE(review): nothing visible in this block ever appends to boundIters,
    # so the iteration array returned at the end stays empty — confirm.
    boundIters, boundValues, likelyValues = [], [], []

    debugFn = _debug_with_bound if debug else _debug_with_nothing

    # Initialize some working variables
    # covA is only computed when the model did not supply one:
    # covA = inv(fp * I_F + X^T X)
    if covA is None:
        precA = (fp * ssp.eye(F) +
                 X.T.dot(X)).todense()  # As the inverse is almost always dense
        covA = la.inv(precA,
                      overwrite_a=True)  # it's faster to densify in advance
    uniqLens = np.unique(docLens)

    debugFn(-1, covA, "covA", W, X, means, docLens, K, A, U, Y, V, covA, tv,
            ltv, fv, lfv, vocab, vocabPrior)

    # H = 0.5 * (I_K - (1/K) * ones(K, K))
    H = 0.5 * (np.eye(K) - np.ones((K, K), dtype=dtype) / K)

    # exp(means), shifted by each row's max before exponentiating
    expMeans = means.copy()
    expMeans = np.exp(means - means.max(axis=1)[:, np.newaxis], out=expMeans)
    R = sparseScalarQuotientOfDot(W, expMeans, vocab, out=W.copy())

    # Pre-allocated work buffers, overwritten in place inside the loop
    lhs = H.copy()
    rhs = expMeans.copy()
    Y_rhs = Y.copy()

    # Iterate over parameters
    for itr in range(iterations):

        # Update U, V given A
        # NOTE(review): both try_solve_sym_pos calls below are cut off after
        # their first argument (no second argument, no closing parenthesis) —
        # the remainder must be recovered from the original file.
        V = try_solve_sym_pos(Y.T.dot(U.T).dot(U).dot(Y),
        V /= V[0, 0]
        U = try_solve_sym_pos(Y.dot(V.T).dot(V).dot(Y.T),

        # Update Y given U, V, A
        Y_rhs[:, :] = U.T.dot(A).dot(V)

        # Eigendecompositions of V^T V and U^T U
        Sv, Uv = la.eigh(V.T.dot(V), overwrite_a=True)
        Su, Uu = la.eigh(U.T.dot(U), overwrite_a=True)

        # s = 1 / (outer(Sv, Su) + ltv * lfv), flattened
        s = np.outer(Sv, Su).flatten()
        s += ltv * lfv
        np.reciprocal(s, out=s)

        M = Uu.T.dot(Y_rhs).dot(Uv)
        M *= unvec(s, row_count=M.shape[0])

        Y = Uu.dot(M).dot(Uv.T)
        debugFn(itr, Y, "Y", W, X, means, docLens, K, A, U, Y, V, covA, tv,
                ltv, fv, lfv, vocab, vocabPrior)

        A = covA.dot(fp * U.dot(Y).dot(V.T) + X.T.dot(means))
        debugFn(itr, A, "A", W, X, means, docLens, K, A, U, Y, V, covA, tv,
                ltv, fv, lfv, vocab, vocabPrior)

        # And now this is the E-Step, though it's followed by updates for the
        # parameters also that handle the log-sum-exp approximation.

        # TODO One big sort by size, plus batch it.

        # Update the Means

        rhs[:, :] = expMeans
        rhs *= R.dot(vocab.T)
        rhs += X.dot(A) * tp
        rhs += docLens[:, np.newaxis] * means.dot(H)
        # rowwise_softmax is called with out=means, i.e. it is asked to write
        # its result over `means`; the loop below then reassigns those rows.
        rhs -= docLens[:, np.newaxis] * rowwise_softmax(means, out=means)
        # One inverse per distinct document length: inv(l * H + tp * I)
        for l in uniqLens:
            inds = np.where(docLens == l)[0]
            lhs[:, :] = l * H
            lhs[np.diag_indices_from(lhs)] += tp
            lhs[:, :] = la.inv(lhs)
            # NOTE(review): the argument of .dot( is missing in this copy —
            # presumably the inverted `lhs`; confirm against the original.
            means[inds, :] = rhs[inds, :].dot(
            )  # left and right got switched going from vectors to matrices :-/

        debugFn(itr, means, "means", W, X, means, docLens, K, A, U, Y, V, covA,
                tv, ltv, fv, lfv, vocab, vocabPrior)

        # Standard deviation
        # DK        = means.shape[0] * means.shape[1]
        # newTp     = np.sum(means)
        # newTp     = (-newTp * newTp)
        # rhs[:,:]  = means
        # rhs      *= means
        # newTp     = DK * np.sum(rhs) - newTp
        # newTp    /= DK * (DK - 1)
        # newTp     = min(max(newTp, 1E-36), 1E+36)
        # tp        = 1 / newTp
        # if itr % logFrequency == 0:
        #     print ("Iter %3d stdev = %f, prec = %f, np.std^2=%f, np.mean=%f" % (itr, sqrt(newTp), tp, np.std(means.reshape((D*K,))) ** 2, np.mean(means.reshape((D*K,)))))

        # Update the vocabulary
        # NOTE(review): the np.exp( call is cut off after its first argument
        # (the same expression before the loop passes out=expMeans) — confirm.
        expMeans = np.exp(means - means.max(axis=1)[:, np.newaxis],
        R = sparseScalarQuotientOfDot(W, expMeans, vocab, out=R)

        # NOTE(review): the expression inside the parentheses is missing here.
        vocab *= (
        ).T  # Awkward order to maintain sparsity (R is sparse, expMeans is dense)
        vocab += vocabPrior
        vocab = normalizerows_ip(vocab)

        debugFn(itr, vocab, "vocab", W, X, means, docLens, K, A, U, Y, V, covA,
                tv, ltv, fv, lfv, vocab, vocabPrior)
        # print ("Iter %3d Vocab.min = %f" % (itr, vocab.min()))

        # Update the vocab prior
        # vocabPrior = estimate_dirichlet_param (vocab, vocabPrior)
        # print ("Iter %3d VocabPrior.(min, max) = (%f, %f) VocabPrior.mean=%f" % (itr, vocabPrior.min(), vocabPrior.max(), vocabPrior.mean()))

        # Every logFrequency iterations, snapshot the states and record the
        # variational bound and the log likelihood.
        if logFrequency > 0 and itr % logFrequency == 0:
            modelState = ModelState(K, A, U, Y, V, covA, tv, ltv, fv, lfv,
                                    vocab, vocabPrior, dtype, modelState.name)
            # NOTE(review): QueryState is built with two arguments here but
            # with three (means, expMeans, docLens) in the return statement.
            queryState = QueryState(means, docLens)

            boundValues.append(var_bound(data, modelState, queryState))
            likelyValues.append(log_likelihood(data, modelState, queryState))

            # NOTE(review): the line that opens this statement (the call that
            # receives the formatted message) is missing in this copy.
                time.strftime('%X') +
                " : Iteration %d: bound %f \t Perplexity: %.2f" %
                (itr, boundValues[-1],
                 perplexity_from_like(likelyValues[-1], docLens.sum())))
            if len(boundValues) > 1:
                if boundValues[-2] > boundValues[-1]:
                    if debug:
                        printStderr("ERROR: bound degradation: %f > %f" %
                                    (boundValues[-2], boundValues[-1]))

                # Check to see if the improvement in the bound has fallen below the threshold
                # NOTE(review): this `if` has no body in this copy — presumably
                # an early exit from the training loop; confirm.
                if itr > 100 and len(likelyValues) > 3 \
                    and abs(perplexity_from_like(likelyValues[-1], docLens.sum()) - perplexity_from_like(likelyValues[-2], docLens.sum())) < 1.0:

    # Returns (model state, query state, (iterations, bounds, likelihoods))
    return \
        ModelState(K, A, U, Y,  V, covA, tv, ltv, fv, lfv, vocab, vocabPrior, dtype, modelState.name), \
        QueryState(means, expMeans, docLens), \
        (np.array(boundIters), np.array(boundValues), np.array(likelyValues))
def train(data, modelState, queryState, trainPlan):
    # NOTE(review): the prose lines that follow describe the function (purpose,
    # arguments, return values) but have no docstring quotes around them in
    # this copy of the file, so they are not valid Python as written. Many
    # statements further down are also cut short (flagged individually).
    Infers the topic distributions in general, and specifically for
    each individual datapoint.
    W - the DxT document-term matrix
    X - The DxF document-feature matrix, which is IGNORED in this case
    modelState - the actual CTM model
    queryState - the query results - essentially all the "local" variables
                 matched to the given observations
    trainPlan  - how to execute the training process (e.g. iterations,
                 log-interval etc.)

    A new model object with the updated model (note parameters are
    updated in place, so make a defensive copy if you want itr)
    A new query object with the update query parameters
    # W: words, L: links, LT: CSR copy of the transposed links, X: features
    W, L, LT, X = data.words, data.links, ssp.csr_matrix(
        data.links.T), data.feats
    D, _ = W.shape
    # Row sums of the link matrix, flattened to a 1-D array
    out_links = np.squeeze(np.asarray(data.links.sum(axis=1)))

    # Unpack the structs, for ease of access and efficiency.
    # Note that diagonalPriorCov is read from trainPlan.fastButInaccurate.
    iterations, epsilon, logFrequency, diagonalPriorCov, debug = trainPlan.iterations, trainPlan.epsilon, trainPlan.logFrequency, trainPlan.fastButInaccurate, trainPlan.debug
    means, varcs, docLens = queryState.means, queryState.varcs, queryState.docLens
    K, topicMean, topicCov, vocab, A, dtype = modelState.K, modelState.topicMean, modelState.topicCov, modelState.vocab, modelState.A, modelState.dtype

    # Per-document count of words plus outgoing links
    emit_counts = docLens + out_links

    # Book-keeping for logs
    # NOTE(review): nothing visible in this block ever appends to boundIters.
    boundIters, boundValues, likelyValues = [], [], []

    # NOTE(review): as written, debugFn is reassigned to _debug_with_nothing at
    # the end of the `if debug:` body and is never bound when debug is false —
    # an `else:` line appears to be missing before that assignment; confirm.
    if debug:
        debugFn = _debug_with_bound

        initLikely = log_likelihood(data, modelState, queryState)
        initPerp = perplexity_from_like(initLikely, data.word_count)
        print("Initial perplexity is: %.2f" % initPerp)
        debugFn = _debug_with_nothing

    # Initialize some working variables
    W_weight = W.copy()
    L_weight = L.copy()
    LT_weight = LT.copy()

    pseudoObsMeans = K + NIW_PSEUDO_OBS_MEAN
    pseudoObsVar = K + NIW_PSEUDO_OBS_VAR
    # NOTE(review): np.ndarray allocates without initialising and no fill is
    # visible in this block before priorSigT_diag is read below — confirm.
    priorSigT_diag = np.ndarray(shape=(K, ), dtype=dtype)

    # Iterate over parameters
    for itr in range(iterations):

        # We start with the M-Step, so the parameters are consistent with our
        # initialisation of the RVs when we do the E-Step

        # Update the mean and covariance of the prior
        topicMean = means.sum(axis = 0) / (D + pseudoObsMeans) \
                  if USE_NIW_PRIOR \
                  else means.mean(axis=0)
        debugFn(itr, topicMean, "topicMean", data, K, topicMean, topicCov,
                vocab, dtype, means, varcs, A, docLens)

        # NOTE(review): the np.cov assignment below sits inside the
        # USE_NIW_PRIOR branch and overwrites the value just computed; an
        # `else:` line appears to be missing, and the second np.cov( call is
        # cut off before its arguments — confirm against the original.
        if USE_NIW_PRIOR:
            diff = means - topicMean[np.newaxis, :]
            topicCov = diff.T.dot(diff) \
                 + pseudoObsVar * np.outer(topicMean, topicMean)
            topicCov += np.diag(varcs.mean(axis=0) + priorSigT_diag)
            topicCov /= (D + pseudoObsVar - K)
            topicCov = np.cov(
                means.T) if topicCov.dtype == np.float64 else np.cov(
            topicCov += np.diag(varcs.mean(axis=0))

        # NOTE(review): itopicCov is assigned twice in a row here and is never
        # bound when diagonalPriorCov is false — an `else:` line appears to be
        # missing before the la.inv assignment; confirm.
        if diagonalPriorCov:
            diag = np.diag(topicCov)
            topicCov = np.diag(diag)
            itopicCov = np.diag(1. / diag)
            itopicCov = la.inv(topicCov)

        debugFn(itr, topicCov, "topicCov", data, K, topicMean, topicCov, vocab,
                dtype, means, varcs, A, docLens)
        #        print("                topicCov.det = " + str(la.det(topicCov)))

        # Building Blocks - temporarily replaces means with exp(means)
        # expMeansCol shifts by each column's max (axis=0); expMeansRow further
        # down shifts by each row's max (axis=1).
        expMeansCol = np.exp(means - means.max(axis=0)[np.newaxis, :])
        lse_at_k = np.sum(expMeansCol, axis=0)
        F = 0.5 * means \
          - (1. / (2*D + 2)) * means.sum(axis=0) \
          - expMeansCol / lse_at_k[np.newaxis, :]

        expMeansRow = np.exp(means - means.max(axis=1)[:, np.newaxis])
        # NOTE(review): call cut off after its first argument.
        W_weight = sparseScalarQuotientOfDot(W,

        # Update the vocabularies

        # NOTE(review): the expression inside the parentheses is missing here.
        vocab *= (
        ).T  # Awkward order to maintain sparsity (R is sparse, expMeans is dense)
        # NOTE(review): `VocabPrior` (capitalised) is not defined in this
        # block — presumably a module-level constant; confirm.
        vocab += VocabPrior
        vocab = normalizerows_ip(vocab)

        docVocab = (
            expMeansCol /
            lse_at_k[np.newaxis, :]).T  # FIXME Duplicates a term in the definition of F

        # Recalculate w_top_sums with the new vocab and log vocab improvement
        # NOTE(review): call cut off after its first argument.
        W_weight = sparseScalarQuotientOfDot(W,
        w_top_sums = W_weight.dot(vocab.T) * expMeansRow

        debugFn(itr, vocab, "vocab", data, K, topicMean, topicCov, vocab,
                dtype, means, varcs, A, docLens)

        # Now do likewise for the links, do it twice to model in-counts (first) and
        # out-counts (Second). The difference is the transpose
        # NOTE(review): both sparseScalarQuotientOfDot calls below are cut off
        # after their first argument.
        LT_weight = sparseScalarQuotientOfDot(LT,
        l_intop_sums = LT_weight.dot(docVocab.T) * expMeansRow
        in_counts = l_intop_sums.sum(axis=0)

        L_weight = sparseScalarQuotientOfDot(L,
        l_outtop_sums = L_weight.dot(docVocab.T) * expMeansRow

        # Reset the means and use them to calculate the weighted sum of means
        meanSum = means.sum(axis=0) * in_counts

        # And now this is the E-Step, though it's followed by updates for the
        # parameters also that handle the log-sum-exp approximation.

        # Update the Variances: var_d = (2 N_d * A + itopicCov)^{-1}
        # NOTE(review): the np.reciprocal( expression is cut off after the `+`.
        varcs = np.reciprocal(docLens[:, np.newaxis] * (0.5 - 1. / K) +
        debugFn(itr, varcs, "varcs", data, K, topicMean, topicCov, vocab,
                dtype, means, varcs, A, docLens)

        # Update the Means
        rhs = w_top_sums.copy()
        rhs += l_intop_sums
        rhs += l_outtop_sums
        rhs += itopicCov.dot(topicMean)
        # NOTE(review): expression cut off after `means.dot(A) -`.
        rhs += emit_counts[:, np.newaxis] * (means.dot(A) -
        rhs += in_counts[np.newaxis, :] * F
        # NOTE(review): as written, the per-document loop follows the `raise`
        # inside the same branch and so can never run; an `else:` line appears
        # to be missing. The two NaN/Inf checks inside the loop also have no
        # bodies in this copy — confirm against the original.
        if diagonalPriorCov:
            raise ValueError("Not implemented")
            for d in range(D):
                rhs_ = rhs[d, :] + (1. /
                                    (4 * D + 4)) * (meanSum -
                                                    in_counts * means[d, :])
                means[d, :] = la.inv(itopicCov + emit_counts[d] * A +
                                     np.diag(D * in_counts /
                                             (2 * D + 2))).dot(rhs_)
                if np.any(np.isnan(means[d, :])) or np.any(
                        np.isinf(means[d, :])):

                if np.any(np.isnan(
                        np.exp(means[d, :] - means[d, :].max()))) or np.any(
                            np.isinf(np.exp(means[d, :] - means[d, :].max()))):

        debugFn(itr, means, "means", data, K, topicMean, topicCov, vocab,
                dtype, means, varcs, A, docLens)

        # Every logFrequency iterations, snapshot the states and record the
        # variational bound and the log likelihood.
        if logFrequency > 0 and itr % logFrequency == 0:
            # NOTE(review): the ModelState( call is cut off before its final
            # argument(s) and closing parenthesis.
            modelState = ModelState(K, topicMean, topicCov, vocab, A, dtype,
            queryState = QueryState(means, varcs, docLens)

            boundValues.append(var_bound(data, modelState, queryState))
            likelyValues.append(log_likelihood(data, modelState, queryState))

            # NOTE(review): the line that opens this statement (the call that
            # receives the formatted message) is missing in this copy.
                time.strftime('%X') +
                " : Iteration %d: bound %f \t Perplexity: %.2f" %
                (itr, boundValues[-1],
                 perplexity_from_like(likelyValues[-1], docLens.sum())))
            if len(boundValues) > 1:
                if boundValues[-2] > boundValues[-1]:
                    printStderr("ERROR: bound degradation: %f > %f" %
                                (boundValues[-2], boundValues[-1]))

                # Check to see if the improvement in the bound has fallen below the threshold
                # The leading `False and` disables this check; the `if` also
                # has no body in this copy.
                if False and itr > 100 and abs(
                        perplexity_from_like(likelyValues[-1], docLens.sum()) -
                        perplexity_from_like(likelyValues[-2], docLens.sum())
                ) < 1.0:

    # Returns (model state, query state, (iterations, bounds, likelihoods))
    return \
        ModelState(K, topicMean, topicCov, vocab, A, dtype, MODEL_NAME), \
        QueryState(means, varcs, docLens), \
        (np.array(boundIters), np.array(boundValues), np.array(likelyValues))
def train (data, modelState, queryState, trainPlan):
    """
    Infers the topic distributions in general, and specifically for
    each individual datapoint.

    Params:
    data       - the dataset; only data.words (the DxT document-term matrix)
                 is read by this model
    modelState - the actual CTM model
    queryState - the query results - essentially all the "local" variables
                 matched to the given observations
    trainPlan  - how to execute the training process (e.g. iterations,
                 log-interval etc.)

    Return:
    A new model object with the updated model (note parameters are
    updated in place, so make a defensive copy if you want it)
    A new query object with the update query parameters
    A tuple (boundIters, boundValues, likelyValues) of arrays holding the
    iterations at which measurements were logged, and the bound and
    log-likelihood recorded at each of them
    """
    W = data.words
    D, _ = W.shape

    # Unpack the structs, for ease of access and efficiency. Note that the
    # plan's fastButInaccurate flag selects the diagonal-covariance shortcut.
    iterations, epsilon, logFrequency, diagonalPriorCov, debug = \
        trainPlan.iterations, trainPlan.epsilon, trainPlan.logFrequency, \
        trainPlan.fastButInaccurate, trainPlan.debug
    means, expMeans, varcs, docLens = \
        queryState.means, queryState.expMeans, queryState.varcs, queryState.docLens
    K, topicMean, sigT, vocab, vocabPrior, A, dtype = \
        modelState.K, modelState.topicMean, modelState.sigT, modelState.vocab, \
        modelState.vocabPrior, modelState.A, modelState.dtype

    # Book-keeping for logs: the three lists are appended to in lock-step
    boundIters, boundValues, likelyValues = [], [], []
    debugFn = _debug_with_bound if debug else _debug_with_nothing

    # Initialize some working variables
    isigT = la.inv(sigT)
    R = W.copy()
    pseudoObsMeans = K + NIW_PSEUDO_OBS_MEAN
    pseudoObsVar   = K + NIW_PSEUDO_OBS_VAR
    priorSigT_diag = np.full((K,), NIW_PSI, dtype=dtype)

    # Iterate over parameters
    for itr in range(iterations):
        # We start with the M-Step, so the parameters are consistent with our
        # initialisation of the RVs when we do the E-Step

        # Update the mean and covariance of the prior
        if USE_NIW_PRIOR:
            topicMean = means.sum(axis=0) / (D + pseudoObsMeans)
        else:
            topicMean = means.mean(axis=0)
        debugFn (itr, topicMean, "topicMean", W, K, topicMean, sigT, vocab, vocabPrior, dtype, means, varcs, A, docLens)

        if USE_NIW_PRIOR:
            diff = means - topicMean[np.newaxis, :]
            sigT = diff.T.dot(diff) \
                 + pseudoObsVar * np.outer(topicMean, topicMean)
            sigT += np.diag(varcs.mean(axis=0) + priorSigT_diag)
            sigT /= (D + pseudoObsVar - K)
        else:
            # Plain sample covariance; np.cov returns float64, so cast back
            # when the model uses another dtype.
            sigT = np.cov(means.T) if sigT.dtype == np.float64 else np.cov(means.T).astype(dtype)
            sigT += np.diag(varcs.mean(axis=0))

        if diagonalPriorCov:
            # Keep only the diagonal, whose inverse is the reciprocal
            diag = np.diag(sigT)
            sigT = np.diag(diag)
            isigT = np.diag(1. / diag)
        else:
            isigT = la.inv(sigT)

        # FIXME Undo debug
        # (the covariance estimated above is currently discarded in favour of
        # the identity matrix)
        sigT  = np.eye(K)
        isigT = la.inv(sigT)
        debugFn (itr, sigT, "sigT", W, K, topicMean, sigT, vocab, vocabPrior, dtype, means, varcs, A, docLens)

        # Building Blocks - exp(means), shifted by each row's max
        expMeans = np.exp(means - means.max(axis=1)[:, np.newaxis], out=expMeans)
        R = sparseScalarQuotientOfDot(W, expMeans, vocab, out=R)

        # Update the vocabulary
        vocab *= (R.T.dot(expMeans)).T # Awkward order to maintain sparsity (R is sparse, expMeans is dense)
        vocab += vocabPrior
        vocab = normalizerows_ip(vocab)

        # Recompute R with the new vocabulary, and log effect of vocab update
        R = sparseScalarQuotientOfDot(W, expMeans, vocab, out=R)
        V = expMeans * R.dot(vocab.T)
        debugFn (itr, vocab, "vocab", W, K, topicMean, sigT, vocab, vocabPrior, dtype, means, varcs, A, docLens)

        # And now this is the E-Step, though it's followed by updates for the
        # parameters also that handle the log-sum-exp approximation.

        # Update the Variances: var_d = (2 N_d * A + isigT)^{-1}
        varcs = np.reciprocal(docLens[:, np.newaxis] * (K - 1.) / K + np.diagonal(sigT))
        debugFn (itr, varcs, "varcs", W, K, topicMean, sigT, vocab, vocabPrior, dtype, means, varcs, A, docLens)

        # Update the Means
        rhs = V.copy()
        rhs += docLens[:, np.newaxis] * means.dot(A) + isigT.dot(topicMean)
        # rowwise_softmax is asked to write over `means`; every row of means
        # is reassigned immediately below.
        rhs -= docLens[:, np.newaxis] * rowwise_softmax(means, out=means)
        if diagonalPriorCov:
            means = varcs * rhs
        else:
            for d in range(D):
                means[d, :] = la.inv(isigT + docLens[d] * A).dot(rhs[d, :])
        debugFn (itr, means, "means", W, K, topicMean, sigT, vocab, vocabPrior, dtype, means, varcs, A, docLens)

        if logFrequency > 0 and itr % logFrequency == 0:
            modelState = ModelState(K, topicMean, sigT, vocab, vocabPrior, A, dtype, MODEL_NAME)
            queryState = QueryState(means, expMeans, varcs, docLens)

            boundIters.append(itr)
            boundValues.append(var_bound(data, modelState, queryState))
            likelyValues.append(log_likelihood(data, modelState, queryState))

            wordCount = docLens.sum()
            perp = perplexity_from_like(likelyValues[-1], wordCount)
            print (time.strftime('%X') + " : Iteration %d: bound %f \t Perplexity: %.2f" % (itr, boundValues[-1], perp))

            if len(boundValues) > 1:
                if boundValues[-2] > boundValues[-1]:
                    if debug: printStderr ("ERROR: bound degradation: %f > %f" % (boundValues[-2], boundValues[-1]))

                # Stop once the perplexity has moved by less than one unit
                # between two consecutive measurements
                if itr > 100 and len(likelyValues) > 3 \
                    and abs(perp - perplexity_from_like(likelyValues[-2], wordCount)) < 1.0:
                    break

    return \
        ModelState(K, topicMean, sigT, vocab, vocabPrior, A, dtype, MODEL_NAME), \
        QueryState(means, expMeans, varcs, docLens), \
        (np.array(boundIters), np.array(boundValues), np.array(likelyValues))
    def testMapOnRealData(self):
        # Trains the rtm model on the ACL data for several pseudo-negative
        # counts, plots bound and likelihood against iteration, prints the
        # training perplexity and then the mean average precision of the
        # predicted links.
        # NOTE(review): several statements in this copy are cut short (calls
        # with no closing parenthesis, an unclosed list, and an indented pair
        # of lines with no enclosing try/except); they are flagged below and
        # need restoring from the original file.
        dtype = np.float64  # DTYPE

        # NOTE(review): DataSet.from_files( is cut off after its first keyword.
        data = DataSet.from_files(words_file=AclWordPath,
        with open(AclDictPath, "rb") as f:
            dic = pkl.load(f)


        trainData, testData = data.doc_completion_split()

        for pseudoNegCount in (5, 10, 25, 50, 100):

            # Initialise the model
            # NOTE(review): the newModelAtRandom( and newTrainPlan( calls are
            # both cut off before their closing parenthesis.
            K = TopicCount
            model = rtm.newModelAtRandom(trainData,
                                         pseudoNegCount=data.doc_count *
            queryState = rtm.newQueryState(trainData, model)
            trainPlan = rtm.newTrainPlan(iterations=50,

            # Train the model, and then immediately save the result to a file for subsequent inspection
            # NOTE(review): the rtm.train( call is cut off before its last
            # argument(s); the save itself is commented out.
            model, topics, (bndItrs, bndVals,
                            bndLikes) = rtm.train(trainData, model, queryState,
            #        with open(newModelFileFromModel(model), "wb") as f:
            #            pkl.dump ((model, query, (bndItrs, bndVals, bndLikes)), f)

            # Plot the evolution of the bound during training.
            # Bound on the first axis in blue, likelihood on a twin axis in red.
            fig, ax1 = plt.subplots()
            ax1.plot(bndItrs, bndVals, 'b-')
            ax1.set_ylabel('Bound', color='b')

            ax2 = ax1.twinx()
            ax2.plot(bndItrs, bndLikes, 'r-')
            ax2.set_ylabel('Likelihood', color='r')


            # Print out the most likely topic words
            # scale = np.reciprocal(1 + np.squeeze(np.array(data.words.sum(axis=0))))
            # NOTE(review): the list below has no closing bracket in this copy.
            vocab = rtm.wordDists(model)
            topWordCount = 10
            kTopWordInds = [
                self.topWordInds(vocab[k, :], topWordCount) for k in range(K)

            like = rtm.log_likelihood(trainData, model, topics)
            perp = perplexity_from_like(like, trainData.word_count)

            # print ("Prior %s" % (str(model.topicPrior)))
            print("Pseudo Neg-Count: %d " % pseudoNegCount)
            print("\tTrain Perplexity: %f\n\n" % perp)

            # for k in range(model.K):
            #     print ("\nTopic %d\n=============================" % k)
            #     print ("\n".join("%-20s\t%0.4f" % (dic[kTopWordInds[k][c]], vocab[k][kTopWordInds[k][c]]) for c in range(topWordCount)))

            # Score the held-out links.
            # NOTE(review): the two indented lines below presumably sat inside
            # a try/except whose header lines are missing; the local name
            # `map` also shadows the builtin.
            min_probs = rtm.min_link_probs(model, topics, testData.links)
            link_probs = rtm.link_probs(model, topics, min_probs)
                map = mean_average_prec(testData.links, link_probs)
                print("Unexpected error")

            print("\tThe Mean-Average-Precision is %.3f" % map)
def train (data, modelState, queryState, trainPlan):
    """
    Infers the topic distributions in general, and specifically for
    each individual datapoint.

    Params:
    data - the dataset of words, features and links of which only words and
           features are used in this model
    modelState - the actual CTM model
    queryState - the query results - essentially all the "local" variables
                 matched to the given observations
    trainPlan  - how to execute the training process (e.g. iterations,
                 log-interval etc.)

    Return:
    A new model object with the updated model (note parameters are
    updated in place, so make a defensive copy if you want it)
    A new query object with the update query parameters
    A tuple (boundIters, boundValues, boundLikes) of arrays holding the
    iterations at which measurements were logged, and the bound and
    log-likelihood recorded at each of them
    """
    W, X = data.words, data.feats
    D, _ = W.shape

    # Unpack the structs, for ease of access and efficiency
    iterations, epsilon, logFrequency, fastButInaccurate, debug = trainPlan.iterations, trainPlan.epsilon, trainPlan.logFrequency, trainPlan.fastButInaccurate, trainPlan.debug
    means, expMeans, varcs, docLens = queryState.means, queryState.expMeans, queryState.varcs, queryState.docLens
    F, P, K, A, R_A, fv, Y, R_Y, lfv, V, sigT, vocab, vocabPrior, Ab, dtype = modelState.F, modelState.P, modelState.K, modelState.A, modelState.R_A, modelState.fv, modelState.Y, modelState.R_Y, modelState.lfv, modelState.V, modelState.sigT, modelState.vocab, modelState.vocabPrior, modelState.Ab, modelState.dtype

    # Book-keeping for logs. Measurements are taken at itr = 0, logFrequency,
    # 2 * logFrequency, ... so ceil(iterations / logFrequency) slots are
    # needed; no slots at all when logging is switched off.
    logCount = -(-iterations // logFrequency) if logFrequency > 0 else 0
    boundIters  = np.zeros(shape=(logCount,))
    boundValues = np.zeros(shape=(logCount,))
    boundLikes  = np.zeros(shape=(logCount,))
    bvIdx = 0

    debugFn = _debug_with_bound if debug else _debug_with_nothing
    _debug_with_bound.old_bound = 0

    # For efficient inference, we need a separate covariance for every unique
    # document length. For products to execute quickly, the doc-term matrix
    # therefore needs to be ordered in ascending terms of document length
    originalDocLens = docLens
    sortIdx = np.argsort(docLens, kind=STABLE_SORT_ALG) # sort needs to be stable in order to be reversible
    W = W[sortIdx,:] # deep sorted copy
    X = X[sortIdx,:]
    means, varcs = means[sortIdx,:], varcs[sortIdx,:]

    docLens = originalDocLens[sortIdx]
    # lens[i] is the i-th distinct length; documents of that length occupy
    # rows inds[i]:inds[i + 1] of the sorted matrices
    lens, inds = np.unique(docLens, return_index=True)
    inds = np.append(inds, [W.shape[0]])

    # Initialize some working variables
    aI_P = 1./lfv  * ssp.eye(P, dtype=dtype)

    print("Creating posterior covariance of A, this will take some time...")
    XTX = X.T.dot(X)
    R_A = XTX
    R_A = R_A.todense()      # dense inverse typically as fast or faster than sparse inverse
    R_A.flat[::F+1] += 1./fv # and the result is usually dense in any case
    R_A = la.inv(R_A)
    print("Covariance matrix calculated, launching inference")

    diff_m_xa = (means-X.dot(A.T))
    means_cov_with_x_a = diff_m_xa.T.dot(diff_m_xa)

    # Scratch buffers holding one batch at a time: they are indexed by the
    # position within the batch ([:span]), never by the corpus offset.
    # NOTE(review): the BatchSize x K expMeans buffer is what gets handed to
    # QueryState below, rather than a D x K array — confirm that is intended.
    expMeans = np.zeros((BatchSize, K), dtype=dtype)
    S = np.zeros((BatchSize, K), dtype=dtype)
    vocabScale = np.ones(vocab.shape, dtype=dtype)

    # Iterate over parameters
    batchIter = 0
    for itr in range(iterations):
        # We start with the M-Step, so the parameters are consistent with our
        # initialisation of the RVs when we do the E-Step

        # Update the covariance of the prior
        diff_a_yv = (A-Y.dot(V))
        sigT  = 1./lfv * (Y.dot(Y.T))
        sigT += 1./fv * diff_a_yv.dot(diff_a_yv.T)
        sigT += means_cov_with_x_a
        sigT.flat[::K+1] += varcs.sum(axis=0)

        # As small numbers lead to instable inverse estimates, we use the
        # fact that for a scalar a, (a .* X)^-1 = 1/a * X^-1 and use these
        # scales whenever we use the inverse of the unscaled covariance
        sigScale  = 1. / (P+D+F)
        isigScale = 1. / sigScale

        isigT = la.inv(sigT)
        debugFn (itr, sigT, "sigT", W, X, XTX, F, P, K, A, R_A, fv, Y, R_Y, lfv, V, sigT, vocab, vocabPrior, dtype, means, varcs, Ab, docLens)

        # Update the vocabulary
        vocab *= vocabScale
        vocab += vocabPrior
        vocab = normalizerows_ip(vocab)
        debugFn (itr, vocab, "vocab", W, X, XTX, F, P, K, A, R_A, fv, Y, R_Y, lfv, V, sigT, vocab, vocabPrior, dtype, means, varcs, Ab, docLens)

        # Finally update the parameter V
        V = la.inv(sigScale * R_Y + Y.T.dot(isigT).dot(Y)).dot(Y.T.dot(isigT).dot(A))
        debugFn (itr, V, "V", W, X, XTX, F, P, K, A, R_A, fv, Y, R_Y, lfv, V, sigT, vocab, vocabPrior, dtype, means, varcs, Ab, docLens)

        # And now this is the E-Step
        # Update the distribution on the latent space
        R_Y_base = aI_P + 1/fv * V.dot(V.T)
        R_Y = la.inv(R_Y_base)
        debugFn (itr, R_Y, "R_Y", W, X, XTX, F, P, K, A, R_A, fv, Y, R_Y, lfv, V, sigT, vocab, vocabPrior, dtype, means, varcs, Ab, docLens)

        Y = 1./fv * A.dot(V.T).dot(R_Y)
        debugFn (itr, Y, "Y", W, X, XTX, F, P, K, A, R_A, fv, Y, R_Y, lfv, V, sigT, vocab, vocabPrior, dtype, means, varcs, Ab, docLens)

        # Update the mapping from the features to topics
        A = (1./fv * Y.dot(V) + (X.T.dot(means)).T).dot(R_A)
        debugFn (itr, A, "A", W, X, XTX, F, P, K, A, R_A, fv, Y, R_Y, lfv, V, sigT, vocab, vocabPrior, dtype, means, varcs, Ab, docLens)

        # Update the Variances
        varcs = 1./((docLens * (K-1.)/K)[:,np.newaxis] + isigScale * isigT.flat[::K+1])
        debugFn (itr, varcs, "varcs", W, X, XTX, F, P, K, A, R_A, fv, Y, R_Y, lfv, V, sigT, vocab, vocabPrior, dtype, means, varcs, Ab, docLens)

        # Update the means one batch at a time, accumulating the statistics
        # used by the next M-Step as we go
        vocabScale[:,:] = 0
        means_cov_with_x_a[:,:] = 0
        for nd, start, end in zip(lens, inds[:-1], inds[1:]):
            # One inverse is shared by all documents of length nd
            lhs = la.inv(isigT + sigScale * nd * Ab) * sigScale

            for d in range(start, end, BatchSize):
                end_d = min(d + BatchSize, end)
                span  = end_d - d

                expMeans[:span,:] = np.exp(means[d:end_d,:] - means[d:end_d,:].max(axis=1)[:span,np.newaxis], out=expMeans[:span,:])
                # The batch lives in rows [:span] of the scratch buffer;
                # slicing it with the corpus offsets d:end_d picks the wrong
                # (or no) rows once d >= BatchSize.
                R = sparseScalarQuotientOfDot(W[d:end_d,:], expMeans[:span,:], vocab)
                S[:span,:] = expMeans[:span, :] * R.dot(vocab.T)

                # Convert expMeans to a softmax(means)
                expMeans[:span,:] /= expMeans[:span,:].sum(axis=1)[:span,np.newaxis]

                mu   = X[d:end_d,:].dot(A.T)
                rhs  = mu.dot(isigT) * isigScale
                rhs += S[:span,:]
                rhs += docLens[d:end_d,np.newaxis] * means[d:end_d,:].dot(Ab)
                rhs -= docLens[d:end_d,np.newaxis] * expMeans[:span,:] # here expMeans is actually softmax(means)

                means[d:end_d,:] = rhs.dot(lhs) # huh?! Left and right refer to eqn for a single mean: once we're talking a DxK matrix it gets swapped

                expMeans[:span,:] = np.exp(means[d:end_d,:] - means[d:end_d,:].max(axis=1)[:span,np.newaxis], out=expMeans[:span,:])
                R = sparseScalarQuotientOfDot(W[d:end_d,:], expMeans[:span,:], vocab, out=R)

                # Step size for a stochastic vocabulary update; it is computed
                # but not currently applied (the vocabulary is rescaled once
                # per pass at the top of the outer loop instead).
                stepSize = (Tau + batchIter) ** -Kappa
                batchIter += 1

                # Accumulate this batch's contribution to the vocab update
                vocabScale += (R.T.dot(expMeans[:span,:])).T

                diff = (means[d:end_d,:] - mu)
                means_cov_with_x_a += diff.T.dot(diff)

        debugFn (itr, means, "means", W, X, XTX, F, P, K, A, R_A, fv, Y, R_Y, lfv, V, sigT, vocab, vocabPrior, dtype, means, varcs, Ab, docLens)

        if logFrequency > 0 and itr % logFrequency == 0:
            modelState = ModelState(F, P, K, A, R_A, fv, Y, R_Y, lfv, V, sigT * sigScale, vocab, vocabPrior, Ab, dtype, MODEL_NAME)
            queryState = QueryState(means, expMeans, varcs, docLens)
            sortedData = DataSet(W, feats=X)

            boundValues[bvIdx] = var_bound(sortedData, modelState, queryState, XTX)
            boundLikes[bvIdx]  = log_likelihood(sortedData, modelState, queryState)
            boundIters[bvIdx]  = itr
            perp = perplexity_from_like(boundLikes[bvIdx], docLens.sum())
            print (time.strftime('%X') + " : Iteration %d: Perplexity %4.0f bound %f" % (itr, perp, boundValues[bvIdx]))
            if bvIdx > 0 and  boundValues[bvIdx - 1] > boundValues[bvIdx]:
                printStderr ("ERROR: bound degradation: %f > %f" % (boundValues[bvIdx - 1], boundValues[bvIdx]))

            # Check to see if the improvement in the likelihood has fallen below the threshold
            if bvIdx > 1 and boundIters[bvIdx] > 20:
                lastPerp = perplexity_from_like(boundLikes[bvIdx - 1], docLens.sum())
                if lastPerp - perp < 1:
                    # Converged: cut the log arrays down to the measurements
                    # actually taken (clamp is presumed to do that — confirm)
                    # and stop, so no later write lands past their new end.
                    boundIters, boundValues, boundLikes = clamp (boundIters, boundValues, boundLikes, bvIdx)
                    break
            bvIdx += 1

    # Undo the sort by document length so rows line up with the input again
    revert_sort = np.argsort(sortIdx, kind=STABLE_SORT_ALG)
    means       = means[revert_sort,:]
    varcs       = varcs[revert_sort,:]
    docLens     = docLens[revert_sort]

    return \
        ModelState(F, P, K, A, R_A, fv, Y, R_Y, lfv, V, sigT * sigScale, vocab, vocabPrior, Ab, dtype, MODEL_NAME), \
        QueryState(means, expMeans, varcs, docLens), \
        (boundIters, boundValues, boundLikes)
def query(data, modelState, queryState, queryPlan):
    """
    Given a _trained_ model, attempts to predict the topics for each of
    the inputs.

    Params:
    data - the dataset of words, features and links of which only words and
           features are used in this model
    modelState - the _trained_ model
    queryState - the query state generated for the query dataset
    queryPlan  - used in this case as we need to tighten up the approx

    Returns:
    The model state and query state, in that order. The model state is
    unchanged, the query is (its means, expMeans and varcs arrays are
    overwritten in place).
    """
    W, X = data.words, data.feats
    D, _ = W.shape

    # Unpack the structs, for ease of access and efficiency
    iterations, debug = queryPlan.iterations, queryPlan.debug
    means, expMeans, varcs, n = queryState.means, queryState.expMeans, queryState.varcs, queryState.docLens
    F, P, K, A, R_A, fv, Y, R_Y, lfv, V, sigT, vocab, vocabPrior, Ab, dtype = modelState.F, modelState.P, modelState.K, modelState.A, modelState.R_A, modelState.fv, modelState.Y, modelState.R_Y, modelState.lfv, modelState.V, modelState.sigT, modelState.vocab, modelState.vocabPrior, modelState.Ab, modelState.dtype

    # TODO Get rid of this via a command-line param
    iterations = max(iterations, 100)

    # Debugging
    debugFn = _debug_with_bound if debug else _debug_with_nothing
    _debug_with_bound.old_bound = 0

    # Necessary values
    isigT = la.inv(sigT)

    # One inverse is needed per distinct document length. Neither isigT nor
    # Ab changes during a query, so the cache is shared by all iterations.
    inverses = dict()

    lastPerp = 1E+300 if dtype is np.float64 else 1E+30
    for itr in range(iterations):
        # Counts of topic assignments
        expMeans = np.exp(means - means.max(axis=1)[:, np.newaxis],
                          out=expMeans)
        R = sparseScalarQuotientOfDot(W, expMeans, vocab)
        S = expMeans * R.dot(vocab.T)

        # the variance
        varcs[:] = 1. / ((n *
                          (K - 1.) / K)[:, np.newaxis] + isigT.flat[::K + 1])
        debugFn(itr, varcs, "query-varcs", W, X, None, F, P, K, A, R_A, fv, Y,
                R_Y, lfv, V, sigT, vocab, vocabPrior, dtype, means, varcs, Ab,
                n)

        # Update the Means
        rhs = X.dot(A.T).dot(isigT)
        rhs += S
        rhs += n[:, np.newaxis] * means.dot(Ab)
        # NB: this overwrites means with its row-wise softmax; every row is
        # then re-assigned in the loop below.
        rhs -= n[:, np.newaxis] * rowwise_softmax(means, out=means)

        # Long version
        for d in range(D):
            nd = n[d]
            if nd not in inverses:
                inverses[nd] = la.inv(isigT + nd * Ab)
            means[d, :] = inverses[nd].dot(rhs[d, :])
        debugFn(itr, means, "query-means", W, X, None, F, P, K, A, R_A, fv, Y,
                R_Y, lfv, V, sigT, vocab, vocabPrior, dtype, means, varcs, Ab,
                n)

        # Stop once perplexity improves by less than one unit per iteration
        like = log_likelihood(data, modelState,
                              QueryState(means, expMeans, varcs, n))
        perp = perplexity_from_like(like, data.word_count)
        if itr > 20 and lastPerp - perp < 1:
            break
        lastPerp = perp

    return modelState, queryState  # query vars altered in-place
def train(data, modelState, queryState, trainPlan):
    """
    Infers the topic distributions in general, and specifically for
    each individual datapoint.

    Params:
    data - the dataset of words, features and links of which only words and
           features are used in this model
    modelState - the actual CTM model
    queryState - the query results - essentially all the "local" variables
                 matched to the given observations
    trainPlan  - how to execute the training process (e.g. iterations,
                 log-interval etc.)

    Return:
    A new model object with the updated model (note parameters are
    updated in place, so make a defensive copy if you want it)
    A new query object with the update query parameters
    A tuple of three lists: the iterations at which measurements were taken,
    and the bound and log-likelihood values measured at those iterations
    """
    W, X = data.words, data.feats
    D, _ = W.shape

    # Unpack the structs, for ease of access and efficiency
    iterations, logFrequency, debug = trainPlan.iterations, trainPlan.logFrequency, trainPlan.debug
    means, expMeans, varcs, docLens = queryState.means, queryState.expMeans, queryState.varcs, queryState.docLens
    F, P, K, A, R_A, fv, Y, R_Y, lfv, V, sigT, vocab, vocabPrior, Ab, dtype = modelState.F, modelState.P, modelState.K, modelState.A, modelState.R_A, modelState.fv, modelState.Y, modelState.R_Y, modelState.lfv, modelState.V, modelState.sigT, modelState.vocab, modelState.vocabPrior, modelState.Ab, modelState.dtype

    # Book-keeping for logs
    boundIters, boundValues, boundLikes = [], [], []
    debugFn = _debug_with_bound if debug else _debug_with_nothing
    _debug_with_bound.old_bound = 0

    # For efficient inference, we need a separate covariance for every unique
    # document length. For products to execute quickly, the doc-term matrix
    # therefore needs to be ordered in ascending terms of document length
    originalDocLens = docLens
    # sort needs to be stable in order to be reversible
    sortIdx = np.argsort(docLens, kind=STABLE_SORT_ALG)
    W = W[sortIdx, :]  # deep sorted copy
    X = X[sortIdx, :]
    # Every per-document array must follow the same row order as W, and that
    # includes expMeans, which is read before it is first recomputed below.
    means, varcs = means[sortIdx, :], varcs[sortIdx, :]
    expMeans = expMeans[sortIdx, :]

    docLens = originalDocLens[sortIdx]

    # lens holds each distinct length, inds the first row at which it occurs;
    # the appended sentinel closes the final run of rows.
    lens, inds = np.unique(docLens, return_index=True)
    inds = np.append(inds, [W.shape[0]])

    # Initialize some working variables
    R = W.copy()

    # Only referenced by the disabled V / Y updates further down
    aI_P = 1. / lfv * ssp.eye(P, dtype=dtype)

    print("Creating posterior covariance of A, this will take some time...")
    XTX = X.T.dot(X)
    # NOTE(review): when XTX is dense R_A aliases it, so the in-place ridge
    # added below also lands in the XTX later handed to var_bound; in the
    # sparse case it does not - confirm which was intended.
    R_A = XTX
    leastSquares = lambda feats, targets: la.lstsq(
        feats, targets, lapack_driver="gelsy")[0].T
    if ssp.issparse(
            R_A):  # dense inverse typically as fast or faster than sparse
        R_A = to_dense_array(
            R_A)  # inverse and the result is usually dense in any case
        leastSquares = lambda feats, targets: np.array(
            [ssp.linalg.lsqr(feats, targets[:, k])[0] for k in range(K)])
    R_A.flat[::F + 1] += 1. / fv
    R_A = la.inv(R_A)
    print("Covariance matrix calculated, launching inference")

    # Iterate over parameters
    for itr in range(iterations):
        A = leastSquares(X, means)
        diff_a_yv = (A - Y.dot(V))

        for _ in range(10):  #(50 if itr == 0 else 1):
            # Update the covariance of the prior
            diff_m_xa = (means - X.dot(A.T))

            sigT = 1. / lfv * (Y.dot(Y.T))
            sigT += 1. / fv * diff_a_yv.dot(diff_a_yv.T)
            sigT += diff_m_xa.T.dot(diff_m_xa)
            sigT.flat[::K + 1] += varcs.sum(axis=0)

            # As small numbers lead to instable inverse estimates, we use the
            # fact that for a scalar a, (a .* X)^-1 = 1/a * X^-1 and use these
            # scales whenever we use the inverse of the unscaled covariance
            sigScale = 1. / (P + D + F)
            isigScale = 1. / sigScale

            isigT = la.inv(sigT)
            debugFn(itr, sigT, "sigT", W, X, XTX, F, P, K, A, R_A, fv, Y, R_Y,
                    lfv, V, sigT, vocab, vocabPrior, dtype, means, varcs, Ab,
                    docLens)

            # Update the vocabulary
            # Awkward order to maintain sparsity (R is sparse, expMeans is dense)
            vocab *= (R.T.dot(expMeans)).T
            vocab += vocabPrior
            vocab = normalizerows_ip(vocab)

            # Reset the means to their original form, and log effect of vocab update
            R = sparseScalarQuotientOfDot(W, expMeans, vocab, out=R)
            S = expMeans * R.dot(vocab.T)
            debugFn(itr, vocab, "vocab", W, X, XTX, F, P, K, A, R_A, fv, Y,
                    R_Y, lfv, V, sigT, vocab, vocabPrior, dtype, means, varcs,
                    Ab, docLens)

            # Update the Variances
            varcs = 1. / ((docLens * (K - 1.) / K)[:, np.newaxis] +
                          isigScale * isigT.flat[::K + 1])
            debugFn(itr, varcs, "varcs", W, X, XTX, F, P, K, A, R_A, fv, Y,
                    R_Y, lfv, V, sigT, vocab, vocabPrior, dtype, means, varcs,
                    Ab, docLens)

            # Update the Means
            rhs = X.dot(A.T).dot(isigT) * isigScale
            rhs += S
            rhs += docLens[:, np.newaxis] * means.dot(Ab)
            # NB: this overwrites means with its row-wise softmax; every row
            # is then re-assigned in the loop below.
            rhs -= docLens[:, np.newaxis] * rowwise_softmax(means, out=means)

            # Faster version? One inverse per distinct document length,
            # applied to the whole contiguous run of rows with that length.
            for lenIdx, nd in enumerate(lens):
                start, end = inds[lenIdx], inds[lenIdx + 1]
                lhs = la.inv(isigT + sigScale * nd * Ab) * sigScale

                # huh?! Left and right refer to eqn for a single mean: once
                # we're talking a DxK matrix it gets swapped
                means[start:end, :] = rhs[start:end, :].dot(lhs)

    #       print("Vec-Means: %f, %f, %f, %f" % (means.min(), means.mean(), means.std(), means.max()))
            debugFn(itr, means, "means", W, X, XTX, F, P, K, A, R_A, fv, Y,
                    R_Y, lfv, V, sigT, vocab, vocabPrior, dtype, means, varcs,
                    Ab, docLens)

            expMeans = np.exp(means - means.max(axis=1)[:, np.newaxis],
                              out=expMeans)

        # for _ in range(150):
        #     # Finally update the parameter V
        #     V = la.inv(sigScale * R_Y + Y.T.dot(isigT).dot(Y)).dot(Y.T.dot(isigT).dot(A))
        #     debugFn(itr, V, "V", W, X, XTX, F, P, K, A, R_A, fv, Y, R_Y, lfv, V, sigT, vocab, vocabPrior, dtype, means,
        #             varcs, Ab, docLens)
        #     # Update the distribution on the latent space
        #     R_Y_base = aI_P + 1 / fv * V.dot(V.T)
        #     R_Y = la.inv(R_Y_base)
        #     debugFn(itr, R_Y, "R_Y", W, X, XTX, F, P, K, A, R_A, fv, Y, R_Y, lfv, V, sigT, vocab, vocabPrior, dtype,
        #             means, varcs, Ab, docLens)
        #     Y = 1. / fv * A.dot(V.T).dot(R_Y)
        #     debugFn(itr, Y, "Y", W, X, XTX, F, P, K, A, R_A, fv, Y, R_Y, lfv, V, sigT, vocab, vocabPrior, dtype, means,
        #             varcs, Ab, docLens)
        #     # Update the mapping from the features to topics
        #     A = (1. / fv * Y.dot(V) + (X.T.dot(means)).T).dot(R_A)
        #     debugFn(itr, A, "A", W, X, XTX, F, P, K, A, R_A, fv, Y, R_Y, lfv, V, sigT, vocab, vocabPrior, dtype, means,
        #             varcs, Ab, docLens)

        if logFrequency > 0 and itr % logFrequency == 0:
            modelState = ModelState(F, P, K, A, R_A, fv, Y, R_Y, lfv, V,
                                    sigT * sigScale, vocab, vocabPrior, Ab,
                                    dtype, MODEL_NAME)
            queryState = QueryState(means, expMeans, varcs, docLens)

            boundIters.append(itr)
            boundValues.append(
                var_bound(DataSet(W, feats=X), modelState, queryState, XTX))
            boundLikes.append(
                log_likelihood(DataSet(W, feats=X), modelState, queryState))
            perp = perplexity_from_like(boundLikes[-1], docLens.sum())
            print(
                time.strftime('%X') +
                " : Iteration %d: Perplexity %4.0f bound %f" %
                (itr, perp, boundValues[-1]))
            if len(boundIters) >= 2 and boundValues[-2] > boundValues[-1]:
                printStderr("ERROR: bound degradation: %f > %f" %
                            (boundValues[-2], boundValues[-1]))

#           print ("Means: min=%f, avg=%f, max=%f\n\n" % (means.min(), means.mean(), means.max()))

            # Check to see if the improvement in the likelihood has fallen below the threshold
            if len(boundIters) > 2 and boundIters[-1] > 20:
                lastPerp = perplexity_from_like(boundLikes[-2], docLens.sum())
                if lastPerp - perp < 1:
                    break

    # Put every per-document array back into the caller's document order
    revert_sort = np.argsort(sortIdx, kind=STABLE_SORT_ALG)
    means = means[revert_sort, :]
    expMeans = expMeans[revert_sort, :]
    varcs = varcs[revert_sort, :]
    docLens = docLens[revert_sort]

    return \
        ModelState(F, P, K, A, R_A, fv, Y, R_Y, lfv, V, sigT * sigScale, vocab, vocabPrior, Ab, dtype, MODEL_NAME), \
        QueryState(means, expMeans, varcs, docLens), \
        (boundIters, boundValues, boundLikes)
def train(data, modelState, queryState, trainPlan, query=False):
    """
    Infers the topic distributions in general, and specifically for
    each individual datapoint.

    Params:
    data - the dataset of words, features and links of which only words are used in this model
    modelState - the actual LDA model. In a training run (query = False) this
                 will be mutated in place, and then returned.
    queryState - the query results - essentially all the "local" variables
                 matched to the given observations. This will be mutated in-place
                 and then returned.
    trainPlan  - how to execute the training process (e.g. iterations,
                 log-interval etc.)
    query      - if True, the model's counts are removed from the query
                 counts again before returning and the model's own counts
                 are left as they were; if False the model counts are
                 updated (or created) from the query counts.

    Return:
    The updated model object (note parameters are updated in place, so make a
    defensive copy if you want it)
    The query object with the update query parameters
    A tuple of three lists: the iterations at which measurements were taken,
    and the bound and log-likelihood values measured at those iterations
    """
    iterations, logFrequency, debug = \
        trainPlan.iterations, trainPlan.logFrequency, trainPlan.debug
    W_list, docLens, q_n_dk, q_n_kt, q_n_k, z_dnk = \
        queryState.W_list, queryState.docLens, queryState.n_dk, queryState.n_kt, queryState.n_k, queryState.z_dnk
    K, topicPrior_, vocabPrior, m_n_dk, m_n_kt, m_n_k = \
        modelState.K, modelState.topicPrior, modelState.vocabPrior, modelState.n_dk, modelState.n_kt, modelState.n_k
    topicPrior = topicPrior_.mean()

    D_train = 0 if m_n_dk is None else m_n_dk.shape[0]
    D_query = q_n_dk.shape[0]
    W = data.words
    T = W.shape[1]

    # Quick sanity check
    if np.any(docLens < 1):
        raise ValueError(
            "Input document-term matrix contains at least one document with no words"
        )

    # Book-keeping for logs. There is always at least one log point (the
    # final batch); without the floor, logFrequency > iterations would give
    # zero log points and a remainder larger than the requested iterations.
    logPoints = 1 if logFrequency == 0 else max(1, iterations // logFrequency)
    boundIters = []
    boundValues = []
    likelyValues = []

    # Early stopping check
    finishedTraining = False

    # Add the model counts (essentially the learnt model parameters) to those for
    # the query, assuming the model has been trained previously
    if m_n_dk is not None:
        np.add(q_n_kt, m_n_kt, out=q_n_kt)  # q_n_kt += m_n_kt
        np.add(q_n_k, m_n_k, out=q_n_k)  # q_n_k  += m_n_k

#     print ("Topic prior : " + str(topicPrior))

    # Select the training iterations function appropriate for the dtype
    if debug: print("Starting Training")
    do_iterations = compiled.iterate_f32 \
                    if modelState.dtype == np.float32 \
                    else compiled.iterate_f64

    # Iterate in segments, pausing to take measures of the bound / likelihood
    segIters = logFrequency
    remainder = iterations - segIters * (logPoints - 1)
    for segment in range(logPoints - 1):
        do_iterations(segIters, D_query, D_train, K, T,
                      W_list, docLens,
                      q_n_dk, q_n_kt, q_n_k, z_dnk,
                      topicPrior, vocabPrior)

        # Measure and record the improvement to the bound and log-likely
        # NOTE(review): the trailing q_n_k argument of the two *_intermediate
        # calls is inferred from context - confirm against their signatures.
        boundIters.append(segment * segIters)
        boundValues.append(
            var_bound_intermediate(data, modelState, queryState, q_n_kt,
                                   q_n_k))
        likelyValues.append(
            log_likely_intermediate(data, modelState, queryState, q_n_kt,
                                    q_n_k))

        # Check to see if the improvement in the bound has fallen below the threshold
        perp = perplexity_from_like(likelyValues[-1], W.sum())
        print("Iteration %d : Train Perp = %4.0f  Bound = %.3f" %
              (segment * segIters, perp, boundValues[-1]))

        if len(boundIters) > 2 and (boundIters[-1] > 50):
            lastPerp = perplexity_from_like(likelyValues[-2], W.sum())
            if lastPerp - perp < 1:
                finishedTraining = True
                print("Converged, existing early")
                break

    # Final scheduled batch of iterations if we haven't already converged.
    if not finishedTraining:
        do_iterations(remainder, D_query, D_train, K, T,
                      W_list, docLens,
                      q_n_dk, q_n_kt, q_n_k, z_dnk,
                      topicPrior, vocabPrior)

        boundIters.append(iterations - 1)
        boundValues.append(
            var_bound_intermediate(data, modelState, queryState, q_n_kt,
                                   q_n_k))
        likelyValues.append(
            log_likely_intermediate(data, modelState, queryState, q_n_kt,
                                    q_n_k))

    # Now return the results
    if query:  # Model is unchanged, query is changed
        if m_n_dk is not None:
            np.subtract(q_n_kt, m_n_kt, out=q_n_kt)  # q_n_kt -= m_n_kt
            np.subtract(q_n_k, m_n_k, out=q_n_k)  # q_n_k  -= m_n_k
    else:  # train # Model is changed (or flat-out created). Query is changed
        if m_n_dk is not None:  # Amend existing
            m_n_dk = np.vstack((m_n_dk, q_n_dk))
            m_n_kt[:, :] = q_n_kt
            m_n_k[:] = q_n_k
        else:  # Create from scratch
            m_n_dk = q_n_dk.copy()
            m_n_kt = q_n_kt.copy()
            m_n_k = q_n_k.copy()

    return ModelState(K, topicPrior, vocabPrior, m_n_dk, m_n_kt, m_n_k, modelState.dtype, modelState.name), \
           QueryState(W_list, docLens, q_n_dk, q_n_kt, q_n_k, z_dnk), \
           (boundIters, boundValues, likelyValues)
def train(data, modelState, queryState, trainPlan):
    """
    Infers the topic distributions in general, and specifically for
    each individual datapoint.

    Params:
    data - the dataset of words, features and links of which only words are used in this model
    modelState - the actual LDA model. Its wordDists are handed to the
                 compiled iteration routine together with the query's
                 topicDists.
    queryState - the query results - essentially all the "local" variables
                 matched to the given observations.
    trainPlan  - how to execute the training process (e.g. iterations,
                 log-interval, batch size, learning-rate settings etc.)

    Return:
    The updated model object (note parameters are updated in place, so make a
    defensive copy if you want it)
    The query object with the update query parameters
    A tuple of three arrays: the iterations at which measurements were taken,
    and the bound and log-likelihood values measured at those iterations
    """
    iterations, epsilon, logFrequency, batchSize, rate_retardation, forgetting_rate = \
        trainPlan.iterations, trainPlan.epsilon, trainPlan.logFrequency, \
        trainPlan.batchSize, trainPlan.rate_retardation, trainPlan.forgetting_rate
    W_list, docLens, topicDists = \
        queryState.W_list, queryState.docLens, queryState.topicDists
    K, topicPrior, vocabPrior, wordDists, dtype = \
        modelState.K, modelState.topicPrior, modelState.vocabPrior, modelState.wordDists, modelState.dtype

    W = data.words
    D, T = W.shape

    # Quick sanity check
    if np.any(docLens < 1):
        raise ValueError(
            "Input document-term matrix contains at least one document with no words"
        )

    # Book-keeping for logs. There is always at least one log point (the
    # final batch); without the floor, logFrequency > iterations would give
    # zero-length arrays and an IndexError when the final batch is recorded.
    logPoints = 1 if logFrequency == 0 else max(1, iterations // logFrequency)
    boundIters = np.zeros(shape=(logPoints, ))
    boundValues = np.zeros(shape=(logPoints, ))
    likelyValues = np.zeros(shape=(logPoints, ))
    bvIdx = 0

    # Instead of storing the full topic assignments for every individual word, we
    # re-estimate from scratch. I.e for the memberships z which is DxNxT in dimension,
    # we only store a 1xNxT = NxT part.
    z_dnk = np.empty((docLens.max(), K), dtype=dtype, order='F')

    # Select the training iterations function appropriate for the dtype
    do_iterations = compiled.iterate_f32 \
                    if modelState.dtype == np.float32 \
                    else compiled.iterate_f64
    #    do_iterations = iterate # pure Python

    # Iterate in segments, pausing to take measures of the bound / likelihood
    segIters = logFrequency
    remainder = iterations - segIters * (logPoints - 1)
    totalItrs = 0
    for segment in range(logPoints - 1):
        start = int(time.time())  # whole seconds
        totalItrs += do_iterations(segIters,
                                   batchSize, segment * segIters, rate_retardation, forgetting_rate,
                                   D, K, T,
                                   W_list, docLens,
                                   topicPrior, vocabPrior,
                                   z_dnk, topicDists, wordDists)

        duration = int(time.time()) - start

        boundIters[bvIdx] = segment * segIters
        boundValues[bvIdx] = var_bound(data, modelState, queryState)
        likelyValues[bvIdx] = log_likelihood(data, modelState, queryState)
        perp = perplexity_from_like(likelyValues[bvIdx], W.sum())
        bvIdx += 1

        if converged(boundIters, boundValues, bvIdx, epsilon, minIters=20):
            boundIters, boundValues, likelyValues = clamp(
                boundIters, boundValues, likelyValues, bvIdx)
            return ModelState(K, topicPrior, vocabPrior, wordDists, modelState.dtype, modelState.name), \
                QueryState(W_list, docLens, topicDists), \
                (boundIters, boundValues, likelyValues)

        print(
            "Segment %d/%d Total Iterations %d Duration %d Perplexity %4.0f Bound %10.2f Likelihood %10.2f"
            % (segment, logPoints, totalItrs, duration, perp,
               boundValues[bvIdx - 1], likelyValues[bvIdx - 1]))

    # Final batch of iterations. It takes the same argument list as the
    # in-loop call, continuing from the iteration reached by the segments.
    # NOTE(review): the four batch / rate arguments are supplied by analogy
    # with the in-loop call - confirm against compiled.iterate_*'s signature.
    do_iterations(remainder,
                  batchSize, (logPoints - 1) * segIters, rate_retardation, forgetting_rate,
                  D, K, T,
                  W_list, docLens,
                  topicPrior, vocabPrior,
                  z_dnk, topicDists, wordDists)

    boundIters[bvIdx] = iterations - 1
    boundValues[bvIdx] = var_bound(data, modelState, queryState)
    likelyValues[bvIdx] = log_likelihood(data, modelState, queryState)

    return ModelState(K, topicPrior, vocabPrior, wordDists, modelState.dtype, modelState.name), \
           QueryState(W_list, docLens, topicDists), \
           (boundIters, boundValues, likelyValues)
def query(data, modelState, queryState, queryPlan):
    """
    Given a _trained_ model, attempts to predict the topics for each of
    the inputs.

    Params:
    data - the dataset of words, features and links of which only words and
           features are used in this model
    modelState - the _trained_ model
    queryState - the query state generated for the query dataset
    queryPlan  - used in this case as we need to tighten up the approx

    Returns:
    The model state and query state, in that order. The model state is
    unchanged; a new query state holding the updated values is returned.
    """
    iterations, debug = queryPlan.iterations, queryPlan.debug
    means, expMeans, varcs, lxi, s, n = queryState.means, queryState.expMeans, queryState.varcs, queryState.lxi, queryState.s, queryState.docLens
    F, P, K, A, R_A, fv, Y, R_Y, lfv, V, sigT, vocab, vocabPrior, dtype = modelState.F, modelState.P, modelState.K, modelState.A, modelState.R_A, modelState.fv, modelState.Y, modelState.R_Y, modelState.lfv, modelState.V, modelState.sigT, modelState.vocab, modelState.vocabPrior, modelState.dtype

    # Necessary temp variables (notably the count of topic to word assignments
    # per topic per doc)
    isigT = la.inv(sigT)
    W, X = data.words, data.feats

    # Enable logging or not. If enabled, we need the inner product of the feat matrix
    if debug:
        XTX = X.T.dot(X)
        debugFn = _debug_with_bound
    else:
        XTX = None
        debugFn = _debug_with_nothing

    # Iterate over parameters
    lastPerp = 1E+300 if dtype is np.float64 else 1E+30
    for itr in range(iterations):
        # Estimate Z_dvk
        expMeans = np.exp(means - means.max(axis=1)[:, np.newaxis], out=expMeans)
        R = sparseScalarQuotientOfDot(W, expMeans, vocab)
        S = expMeans * R.dot(vocab.T)

        # Update the Means
        vMat = (2 * s[:, np.newaxis] * lxi - 0.5) * n[:, np.newaxis] + S
        rhsMat = vMat + X.dot(A.T).dot(isigT)  # TODO Verify this
        # inverse of D diagonal matrices...
        lhsMat = np.reciprocal(np.diag(isigT)[np.newaxis, :] + n[:, np.newaxis] * 2 * lxi)
        # as LHS is a diagonal matrix for all d, it's equivalent
        # to doing a hadamard product for all d
        means = lhsMat * rhsMat
        debugFn(itr, means, "query-means", W, X, XTX, F, P, K, A, R_A, fv, Y, R_Y, lfv, V, sigT, vocab, vocabPrior, dtype, means, varcs, lxi, s, n)

        # Update the Variances
        varcs = 1. / (2 * n[:, np.newaxis] * lxi + isigT.flat[::K + 1])
        debugFn(itr, varcs, "query-varcs", W, X, XTX, F, P, K, A, R_A, fv, Y, R_Y, lfv, V, sigT, vocab, vocabPrior, dtype, means, varcs, lxi, s, n)

        # Update the approximation parameters
        lxi = ctm.negJakkolaOfDerivedXi(means, varcs, s)
        debugFn(itr, lxi, "query-lxi", W, X, XTX, F, P, K, A, R_A, fv, Y, R_Y, lfv, V, sigT, vocab, vocabPrior, dtype, means, varcs, lxi, s, n)

        # s can sometimes grow unboundedly
        # Follow Bouchard's suggested approach of fixing it at zero
#         s = (np.sum(lxi * means, axis=1) + 0.25 * K - 0.5) / np.sum(lxi, axis=1)
#         debugFn (itr, s, "s", W, X, XTX, F, P, K, A, R_A, fv, Y, R_Y, lfv, V, sigT, vocab, vocabPrior, dtype, means, varcs, lxi, s, n)

        # Stop once perplexity improves by less than one unit per iteration
        like = log_likelihood(data, modelState, QueryState(means, expMeans, varcs, lxi, s, n))
        perp = perplexity_from_like(like, data.word_count)
        if itr > 20 and lastPerp - perp < 1:
            break
        lastPerp = perp

    return modelState, QueryState(means, expMeans, varcs, lxi, s, n)
def _old_train(data, model, query, plan, updateVocab=True):
    """
    Infers the topic distributions in general, and specifically for
    each individual datapoint,

    Params:
    data - the training data, we just use the DxT document-term matrix
    model - the initial model configuration. This is MUTATED IN-PLACE
    query - the query results - essentially all the "local" variables
            matched to the given observations. Also MUTATED IN-PLACE
    plan  - how to execute the training process (e.g. iterations,
            log-interval etc.)
    updateVocab - if True the word distributions are re-estimated on every
            iteration; if False only the per-document topics are updated.

    Return:
    The updated model object (note parameters are updated in place, so make a
    defensive copy if you want it)
    The query object with the update query parameters
    A tuple of three arrays: the iterations at which measurements were taken,
    and the bound and log-likelihood values that were recorded
    """
    iterations, logFrequency, debug, batchSize = \
        plan.iterations, plan.logFrequency, plan.debug, plan.batchSize
    docLens, topicMeans = \
        query.docLens, query.topicDists
    K, topicPrior, vocabPrior, wordDists, dtype = \
        model.K, model.topicPrior, model.vocabPrior, model.wordDists, model.dtype

    # Quick sanity check
    if np.any(docLens < 1):
        raise ValueError(
            "Input document-term matrix contains at least one document with no words"
        )
    assert model.dtype == np.float64, "Only implemented for 64-bit floats"

    # Prepare the data for inference
    topicMeans = _convertDirichletParamToMeans(docLens, topicMeans, topicPrior)

    W = data.words
    D, T = W.shape

    iters, bnds, likes = [], [], []

    # A few parameters for handling adaptive step-sizes in SGD
    grad = 0
    grad_rate = 1
    # g and gg are handed to _step_sizes for every rate algorithm, so they
    # need a value even when the variance-based branch never assigns them.
    # NOTE(review): zero is a placeholder start value - confirm what
    # _step_sizes expects here.
    g, gg = 0, 0
    log_likely = 0  # complete dataset likelihood for gradient adjustments
    stepSize = np.array([1.] * K, dtype=model.dtype)

    # Work buffers for the digamma of the word distributions and of their
    # row sums, refreshed whenever wordDists changes.
    diWordDistSums = np.empty((K, ), dtype=dtype)
    diWordDists = np.empty(wordDists.shape, dtype=dtype)
    wordUpdates = wordDists.copy() if batchSize > 0 else None
    batchProcessCount = 0

    # Amend the name if batchSize > 0 implying we're using SGD
    modelName = "lda/svbp/%s" % _sgd_desc(plan) \
                if batchSize > 0 else model.name

    for itr in range(iterations):
        diWordDistSums[:] = wordDists.sum(axis=1)
        fns.digamma(diWordDistSums, out=diWordDistSums)
        fns.digamma(wordDists, out=diWordDists)

        if updateVocab:
            # Perform inference, updating the vocab
            if batchSize == 0:
                wordDists[:, :] = vocabPrior
            else:
                wordUpdates[:, :] = 0

            for d in range(D):
                batchProcessCount += 1
                #if debug and d % 100 == 0: printAndFlushNoNewLine(".")
                # NOTE(review): final argument inferred from context -
                # confirm against _update_topics_at_d's signature.
                wordIdx, z = _update_topics_at_d(d, data, docLens, topicMeans,
                                                 topicPrior, diWordDists,
                                                 diWordDistSums)
                wordDists[:, wordIdx] += W[d, :].data[np.newaxis, :] * z

                if plan.rate_algor == RateAlgorAmaria:
                    log_likely += 0
                elif plan.rate_algor == RateAlgorVariance:
                    g = wordDists.mean(axis=0) + vocabPrior
                    grad *= (1 - grad_rate)
                    grad += grad_rate * wordDists
                    grad += grad_rate * vocabPrior
                    gg += 0
                elif plan.rate_algor != RateAlgorTimeKappa:
                    raise ValueError("Unknown rate algorithm " +
                                     str(plan.rate_algor))

                if batchSize > 0 and batchProcessCount == batchSize:
                    batch_index = (
                        itr * D + d
                    ) / batchSize  #TODO  Will not be right if batchSize is not a multiple of D
                    stepSize = _step_sizes(stepSize, batch_index, g, gg,
                                           log_likely, plan)
                    wordDists *= (1 - stepSize)
                    wordDists += stepSize * vocabPrior

                    stepSize *= float(D) / batchSize
                    wordUpdates *= stepSize
                    wordDists += wordUpdates

                    diWordDistSums[:] = wordDists.sum(axis=1)
                    fns.digamma(diWordDistSums, out=diWordDistSums)
                    fns.digamma(wordDists, out=diWordDists)

                    wordUpdates[:, :] = 0
                    batchProcessCount = 0
                    log_likely = 0

                    if debug:
                        bnds.append(_var_bound_internal(data, model, query))
                        likes.append(
                            _log_likelihood_internal(data, model, query))

                        perp = perplexity_from_like(likes[-1], W.sum())
                        print(
                            "Iteration %d, after %d docs: Train Perp = %4.0f  Bound = %.3f"
                            % (itr, batchSize, perp, bnds[-1]))

            # Log bound and the determine if we can stop early. A log
            # frequency of zero disables periodic logging rather than
            # dividing by zero.
            if (logFrequency > 0 and itr % logFrequency == 0) or debug:
                iters.append(itr)
                bnds.append(_var_bound_internal(data, model, query))
                likes.append(_log_likelihood_internal(data, model, query))

                perp = perplexity_from_like(likes[-1], W.sum())
                print("Iteration %d : Train Perp = %4.0f  Bound = %.3f" %
                      (itr, perp, bnds[-1]))

                if len(iters) > 2 and (iters[-1] > 20 or
                                       (iters[-1] > 2 and batchSize > 0)):
                    lastPerp = perplexity_from_like(likes[-2], W.sum())
                    if lastPerp - perp < 1:
                        print("Converged, existing early")
                        break

            # Update hyperparameters (do this after bound, to make sure bound
            # calculation is internally consistent)
            if HyperUpdateEnabled and itr > 0 and itr % HyperParamUpdateInterval == 0:
                if debug: print("Topic Prior was " + str(topicPrior))
                _updateTopicHyperParamsFromMeans(model, query)
                if debug: print("Topic Prior is now " + str(topicPrior))
        else:
            # Vocabulary is held fixed: only refresh the per-document topics
            for d in range(D):
                _ = _update_topics_at_d(d, data, docLens, topicMeans,
                                        topicPrior, diWordDists,
                                        diWordDistSums)

    topicMeans = _convertMeansToDirichletParam(docLens, topicMeans, topicPrior)

    return ModelState(K, topicPrior, vocabPrior, wordDists, True, dtype, modelName), \
           QueryState(docLens, topicMeans, True), \
           (np.array(iters, dtype=np.int32), np.array(bnds), np.array(likes))
def train(data, model, query, plan, updateVocab=True):
    '''
    Infers the topic distributions in general, and specifically for
    each individual datapoint, by alternating calls to
    sample_memberships() and sample_dirichlet().

    Params:
    data  - the training data, we just use the DxT document-term matrix
            (data.words)
    model - the initial model configuration. Documented by the original
            author as MUTATED IN-PLACE; whether the sampling helpers write
            into model.wordDists cannot be confirmed from this function.
    query - the query results - essentially all the "local" variables
            matched to the given observations. Also documented as
            MUTATED IN-PLACE (same caveat as for model).
    plan  - how to execute the training process (e.g. iterations,
            log-interval etc.). The loop runs for
            plan.iterations + plan.burnIn steps.
    updateVocab - accepted for interface compatibility; not used by this
            implementation.

    Return:
    The updated model object (note parameters may be updated in place, so
    make a defensive copy if you want it)
    The query object with the updated query parameters
    A tuple of three arrays: the iterations at which diagnostics were
    logged, and the variational bound and log-likelihood at each of them.

    Raises:
    ValueError if any entry of query.docLens is below one.
    AssertionError if model.dtype is not np.float64.
    '''
    logFrequency, debug = plan.logFrequency, plan.debug
    docLens, topicDists = query.docLens, query.topicDists
    K, topicPrior, vocabPrior, wordDists, dtype = \
        model.K, model.topicPrior, model.vocabPrior, model.wordDists, model.dtype

    W = data.words
    D, T = W.shape

    # Quick sanity check
    if np.any(docLens < 1):
        raise ValueError(
            "Input document-term matrix contains at least one document with no words"
        )
    assert dtype == np.float64, "Only implemented for 64-bit floats"

    iters, bnds, likes = [], [], []

    # Running sums of the sampled distributions.
    # NOTE(review): these are accumulated but never used in the returned
    # state (the last sample is returned instead) - confirm whether an
    # average over samples was intended.
    sampleCount = 0
    wordDistSamples = np.zeros((K, T), dtype=np.float64)
    topicDistSamples = np.zeros((D, K), dtype=np.float64)

    for itr in range(plan.iterations + plan.burnIn):
        topicDists = sample_memberships(W, topicPrior, wordDists, topicDists)
        wordDists = sample_dirichlet(W, vocabPrior, topicDists, wordDists)

        if is_sampling_iteration(itr, plan):
            wordDistSamples += wordDists
            topicDistSamples += topicDists
            sampleCount += 1

        if debug or (logFrequency > 0 and itr % logFrequency == 0):
            # Wrap the current sample in state objects so that the public
            # bound / likelihood functions can evaluate it.
            m = ModelState(K, topicPrior, vocabPrior, wordDists, True, dtype,
                           model.name)
            q = QueryState(query.docLens, topicDists, True)

            # Record the iteration alongside its diagnostics so that the
            # three returned arrays stay aligned.
            iters.append(itr)
            bnds.append(var_bound(data, m, q))
            likes.append(log_likelihood(data, m, q))

            perp = perplexity_from_like(likes[-1], W.sum())
            print("Iteration %d : Train Perp = %4.0f  Bound = %.3f" %
                  (itr, perp, bnds[-1]))

    return ModelState(K, topicPrior, vocabPrior, wordDists, True, dtype, model.name), \
           QueryState(query.docLens, topicDists, True), \
           (np.array(iters, dtype=np.int32), np.array(bnds), np.array(likes))