def test_save_load(self):
    """Check that an LdaVowpalWabbit model survives a save/load round trip.

    A model is built from the fixture corpus, written to a temporary file,
    read back with ``LdaVowpalWabbit.load``, and then the public attributes
    and the reported topics of both instances are compared.
    """
    if not self.vw_path:  # for python 2.6
        return

    original = LdaVowpalWabbit(
        self.vw_path,
        corpus=self.corpus,
        passes=10,
        chunksize=256,
        id2word=self.dictionary,
        cleanup_files=True,
        alpha=0.1,
        eta=0.1,
        num_topics=len(TOPIC_WORDS),
        random_seed=1
    )

    # Public attributes that must come back unchanged after loading.
    field_names = (
        'alpha', 'chunksize', 'cleanup_files', 'decay', 'eta',
        'gamma_threshold', 'id2word', 'num_terms', 'num_topics',
        'passes', 'random_seed', 'vw_path',
    )

    with tempfile.NamedTemporaryFile() as tmp_file:
        original.save(tmp_file.name)
        restored = LdaVowpalWabbit.load(tmp_file.name)

        # ensure public fields are saved/loaded correctly
        before = [getattr(original, name) for name in field_names]
        after = [getattr(restored, name) for name in field_names]
        self.assertEqual(before, after)

        # ensure topic matrices are saved/loaded correctly
        topics_before = original.show_topics(num_topics=5, num_words=10)
        topics_after = restored.show_topics(num_topics=5, num_words=10)
        self.assertEqual(topics_after, topics_before)
def test_perplexity(self):
    """Check that the log perplexity on the fixture corpus is in range.

    The value is not fixed from run to run, so the test only requires it
    to lie strictly between -5 and -1.
    """
    if not self.vw_path:  # for python 2.6
        return

    settings = dict(
        corpus=self.corpus,
        passes=10,
        chunksize=256,
        id2word=self.dictionary,
        cleanup_files=True,
        alpha=0.1,
        eta=0.1,
        num_topics=len(TOPIC_WORDS),
        random_seed=1
    )
    model = LdaVowpalWabbit(self.vw_path, **settings)

    # varies, but should be between -1 and -5
    upper_bound, lower_bound = -1, -5
    log_perplexity = model.log_perplexity(self.corpus)
    self.assertTrue(log_perplexity < upper_bound)
    self.assertTrue(log_perplexity > lower_bound)
def test_topic_coherence(self):
    """Check that enough learned topics line up with the known word groups.

    For every topic of the model, the 20 requested words are matched
    against the word groups in ``TOPIC_WORDS``. A topic counts as coherent
    when at least 6 of those matches fall into one and the same group; the
    test requires at least 3 coherent topics.
    """
    if not self.vw_path:  # for python 2.6
        return

    corpus, dictionary = get_corpus()
    model = LdaVowpalWabbit(
        self.vw_path,
        corpus=corpus,
        passes=10,
        chunksize=256,
        id2word=dictionary,
        cleanup_files=True,
        alpha=0.1,
        eta=0.1,
        num_topics=len(TOPIC_WORDS),
        random_seed=1
    )
    model.print_topics(5, 10)

    # Map each known word group to its position in TOPIC_WORDS.
    known_groups = dict(
        (frozenset(group), position)
        for position, group in enumerate(TOPIC_WORDS)
    )

    coherent_total = 0
    for topic_no in range(model.num_topics):
        # The word sits at index 1 of every entry returned by show_topic.
        terms = [entry[1] for entry in model.show_topic(topic_no, topn=20)]

        # Tally how often each known group is hit by this topic's words.
        hits = defaultdict(int)
        for term in terms:
            for group, group_id in six.iteritems(known_groups):
                if term in group:
                    hits[group_id] += 1

        # A topic with no hits at all has a best score of zero.
        best_hits = max(six.itervalues(hits)) if hits else 0
        if best_hits >= 6:
            coherent_total += 1

    # not 100% deterministic, but should always get 3+ coherent topics
    self.assertTrue(coherent_total >= 3)
def test_model_update(self):
    """Check that an existing LdaVowpalWabbit model can be updated.

    The model is first built from only the first document of the fixture
    corpus and then updated with the remaining documents. Afterwards its
    log perplexity on the full corpus must lie strictly between -5 and -1.
    """
    if not self.vw_path:  # for python 2.6
        return

    first_document = self.corpus[0]
    settings = dict(
        corpus=[first_document],
        passes=10,
        chunksize=256,
        id2word=self.dictionary,
        cleanup_files=True,
        alpha=0.1,
        eta=0.1,
        num_topics=len(TOPIC_WORDS),
        random_seed=1
    )
    model = LdaVowpalWabbit(self.vw_path, **settings)

    # Feed in everything that was left out of the initial build.
    model.update(self.corpus[1:])

    upper_bound, lower_bound = -1, -5
    log_perplexity = model.log_perplexity(self.corpus)
    self.assertTrue(log_perplexity < upper_bound)
    self.assertTrue(log_perplexity > lower_bound)
def testvwmodel2ldamodel(self):
    """Test copying of VWModel to LdaModel.

    Builds a two-topic LdaVowpalWabbit model, converts it with
    ``ldavowpalwabbit.vwmodel2ldamodel`` and, for every document in the
    fixture corpus, compares the first entry returned by each model. Each
    entry is unpacked into two elements: the first pair is compared with
    the default tolerance, the second pair to 5 decimal places.

    NOTE(review): a later method in this file has the same name. If both
    live in the same class, the later definition replaces this one and
    this test never runs -- confirm, then rename or remove one of them.
    """
    if not self.vw_path:
        return

    tm1 = LdaVowpalWabbit(
        vw_path=self.vw_path,
        corpus=self.corpus,
        num_topics=2,
        id2word=self.dictionary
    )
    tm2 = ldavowpalwabbit.vwmodel2ldamodel(tm1)

    for document in self.corpus:
        element1_1, element1_2 = tm1[document][0]
        element2_1, element2_2 = tm2[document][0]
        self.assertAlmostEqual(element1_1, element2_1)
        self.assertAlmostEqual(element1_2, element2_2, 5)
        logging.debug('%d %d', element1_1, element2_1)
        # The second elements are compared to 5 decimal places, so they are
        # presumably fractional; '%d' would cut off the fractional part and
        # make the log line useless, hence '%s'.
        logging.debug('%s %s', element1_2, element2_2)
def testvwmodel2ldamodel(self):
    """Test copying of VWModel to LdaModel.

    Builds a two-topic LdaVowpalWabbit model, converts it with
    ``ldavowpalwabbit.vwmodel2ldamodel`` and, for every document in the
    fixture corpus, requires the first two entries returned by both models
    to be equal.

    NOTE(review): an earlier method in this file has the same name. If both
    live in the same class, this later definition is the one that runs and
    the earlier one is silently discarded -- confirm, then rename or remove
    one of them.
    """
    if not self.vw_path:
        return

    tm1 = LdaVowpalWabbit(
        vw_path=self.vw_path,
        corpus=self.corpus,
        num_topics=2,
        id2word=self.dictionary
    )
    tm2 = ldavowpalwabbit.vwmodel2ldamodel(tm1)

    for document in self.corpus:
        # Query each model once per document so the values that are logged
        # are exactly the values that were compared.
        result1 = tm1[document]
        result2 = tm2[document]

        self.assertEqual(result1[0], result2[0])
        self.assertEqual(result1[1], result2[1])

        # The entries are compared as whole objects, not as plain numbers,
        # so they are formatted with '%s'; '%d' cannot render a non-numeric
        # entry and the log record would fail to format.
        logging.debug('%s %s', result1[0], result2[0])
        logging.debug('%s %s', result1[1], result2[1])