def generate_synth_data(V=10, D=10, T=5, N=100, alpha_beta=10., alpha_theta=10.,
                        plot=False, train_frac=0.5):
    # Sample a ground-truth model from which to generate the corpus
    # true_lda = StandardLDA(T, V, alpha_beta=alpha_beta, alpha_theta=alpha_theta)
    true_lda = LogisticNormalCorrelatedLDA(T, V, alpha_beta=alpha_beta)
    print("Sigma: ", true_lda.Sigma)
    # true_lda = StickbreakingCorrelatedLDA(T, V, alpha_beta=alpha_beta)

    # Generate D documents of N words each from the true model
    data = np.zeros((D, V), dtype=int)
    for d in xrange(D):
        doc = true_lda.generate(N=N, keep=True)
        data[d, :] = doc.w

    if plot:
        plt.figure()
        plt.imshow(data, interpolation="none")
        plt.xlabel("Vocabulary")
        plt.ylabel("Documents")
        plt.colorbar()
        plt.show()

    # Split each document's words into a training half and a held-out test half
    train_data = np.zeros_like(data)
    test_data = np.zeros_like(data)
    for d, w in enumerate(data):
        # Get vector where i is repeated w[i] times
        wcnt = ibincount(w)

        # Randomly assign each word to the training or test set
        train_inds = np.random.rand(wcnt.size) < train_frac
        train_data[d] = np.bincount(wcnt[train_inds], minlength=V)
        test_data[d] = np.bincount(wcnt[~train_inds], minlength=V)
        assert np.allclose(train_data[d] + test_data[d], w)

    return true_lda, train_data, test_data
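# Hedged sketch (not part of the original script): `ibincount`, used in
# generate_synth_data above, is assumed to be the inverse of np.bincount,
# expanding a count vector w into a vector of token indices in which index i
# appears w[i] times. The hypothetical helper below, named `_ibincount_sketch`
# to avoid shadowing the real import, only illustrates that assumed behavior.
def _ibincount_sketch(counts):
    counts = np.asarray(counts, dtype=int)
    # e.g. counts = [2, 0, 3]  ->  [0, 0, 2, 2, 2]
    return np.repeat(np.arange(counts.size), counts)
# Round trip: np.bincount(_ibincount_sketch(w), minlength=V) recovers w, which
# is exactly the property the train/test split above relies on.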
std_results = \
    train(std_model, thetas=true_lda.thetas if init_to_true else None)

std_collapsed_model = StandardLDA(T, V, alpha_beta, alpha_theta)
std_collapsed_model.beta = true_lda.beta if init_to_true else std_collapsed_model.beta
std_collapsed_results = \
    train(std_collapsed_model, method='resample_model_collapsed',
          thetas=true_lda.thetas if init_to_true else None)

sb_model = StickbreakingCorrelatedLDA(T, V, alpha_beta)
sb_model.beta = true_lda.beta if init_to_true else sb_model.beta
sb_results = \
    train(sb_model, thetas=true_lda.thetas if init_to_true else None)

ln_model = LogisticNormalCorrelatedLDA(T, V, alpha_beta)
ln_model.beta = true_lda.beta if init_to_true else ln_model.beta
ln_results = \
    train(ln_model, thetas=true_lda.thetas if init_to_true else None)

all_results = [sb_results, ln_results, std_results, std_collapsed_results]
all_labels = ["SB Corr. LDA", "LN Corr. LDA", "Std. LDA", "Collapsed LDA"]

# all_results = [std_results, std_collapsed_results]
# all_labels = ["Std. LDA", "Collapsed LDA"]
# all_results = [ln_results]
# all_labels = ["LN Corr. LDA"]

plt.figure()

# Plot log likelihood vs iteration
plt.subplot(121)
for ind, (results, label) in enumerate(zip(all_results, all_labels)):