def test_save_load(self):
        """Round-trip an LdaVowpalWabbit model through save() and load().

        Verifies that the public configuration attributes and the rendered top
        words of the first five topics are identical after reloading.
        """
        import os
        import shutil

        if not self.vw_path:  # no VW binary configured (python 2.6 lacks skipTest)
            return
        lda = LdaVowpalWabbit(
            self.vw_path, corpus=self.corpus, passes=10, chunksize=256,
            id2word=self.dictionary, cleanup_files=True, alpha=0.1,
            eta=0.1, num_topics=len(TOPIC_WORDS), random_seed=1
        )

        # Save into a fresh temporary directory instead of onto the name of a
        # still-open NamedTemporaryFile: re-opening that name is not allowed on
        # Windows. Removing the whole directory also cleans up any extra files
        # save() may write next to the target path.
        tmpdir = tempfile.mkdtemp()
        try:
            fname = os.path.join(tmpdir, 'ldavw.model')
            lda.save(fname)
            lda2 = LdaVowpalWabbit.load(fname)

            # ensure public fields are saved/loaded correctly
            public_fields = (
                'alpha', 'chunksize', 'cleanup_files',
                'decay', 'eta', 'gamma_threshold',
                'id2word', 'num_terms', 'num_topics',
                'passes', 'random_seed', 'vw_path',
            )
            saved_fields = [getattr(lda, name) for name in public_fields]
            loaded_fields = [getattr(lda2, name) for name in public_fields]
            self.assertEqual(saved_fields, loaded_fields)

            # ensure topic matrices are saved/loaded correctly
            saved_topics = lda.show_topics(num_topics=5, num_words=10)
            loaded_topics = lda2.show_topics(num_topics=5, num_words=10)
            self.assertEqual(loaded_topics, saved_topics)
        finally:
            shutil.rmtree(tmpdir, ignore_errors=True)
    def test_perplexity(self):
        """Check that a freshly trained model's log-perplexity lies in (-5, -1)."""
        vw_path = self.vw_path
        if not vw_path:  # nothing to run without a VW binary (python 2.6 compat)
            return
        model_kwargs = dict(
            corpus=self.corpus,
            passes=10,
            chunksize=256,
            id2word=self.dictionary,
            cleanup_files=True,
            alpha=0.1,
            eta=0.1,
            num_topics=len(TOPIC_WORDS),
            random_seed=1,
        )
        model = LdaVowpalWabbit(vw_path, **model_kwargs)

        # the exact value varies between runs, so only a loose band is checked
        log_perp = model.log_perplexity(self.corpus)
        upper_bound, lower_bound = -1, -5
        self.assertTrue(log_perp < upper_bound)
        self.assertTrue(log_perp > lower_bound)
    def test_perplexity(self):
        """Train on the test corpus and assert log-perplexity is in (-5, -1)."""
        vw_path = self.vw_path
        if not vw_path:  # no VW binary configured (python 2.6 compat)
            return
        settings = dict(
            corpus=self.corpus,
            passes=10,
            chunksize=256,
            id2word=self.dictionary,
            cleanup_files=True,
            alpha=0.1,
            eta=0.1,
            num_topics=len(TOPIC_WORDS),
            random_seed=1,
        )
        trained = LdaVowpalWabbit(vw_path, **settings)

        # result is not exact run-to-run; accept anything strictly inside the band
        perplexity = trained.log_perplexity(self.corpus)
        ceiling, floor = -1, -5
        self.assertTrue(perplexity < ceiling)
        self.assertTrue(perplexity > floor)
Пример #4
0
    def test_topic_coherence(self):
        """Check that trained topics recover the known source topics.

        A learned topic counts as coherent when at least 6 of its top-20 words
        belong to the same original topic in TOPIC_WORDS; at least 3 coherent
        topics are required.
        """
        if not self.vw_path:  # no VW binary configured (python 2.6 lacks skipTest)
            return
        corpus, dictionary = get_corpus()
        lda = LdaVowpalWabbit(self.vw_path,
                              corpus=corpus,
                              passes=10,
                              chunksize=256,
                              id2word=dictionary,
                              cleanup_files=True,
                              alpha=0.1,
                              eta=0.1,
                              num_topics=len(TOPIC_WORDS),
                              random_seed=1)
        lda.print_topics(5, 10)

        # map words in known topic to an ID
        topic_map = {}
        for i, words in enumerate(TOPIC_WORDS):
            topic_map[frozenset(words)] = i

        # Invert once, outside the per-topic loop: word -> ids of every source
        # topic containing it, so each word costs one dict lookup instead of a
        # scan over all source topics.
        word_to_src_ids = defaultdict(list)
        for src_topic_words, src_topic_id in six.iteritems(topic_map):
            for word in src_topic_words:
                word_to_src_ids[word].append(src_topic_id)

        n_coherent = 0
        for topic_id in range(lda.num_topics):
            topic = lda.show_topic(topic_id, topn=20)

            # NOTE(review): element 1 of each entry is taken as the word, i.e.
            # show_topic is assumed to return (probability, word) pairs --
            # confirm against the installed gensim version.
            topic_words = [w[1] for w in topic]

            # count how often each original topic is hit by this topic's words
            counts = defaultdict(int)
            for word in topic_words:
                for src_topic_id in word_to_src_ids.get(word, ()):
                    counts[src_topic_id] += 1
            max_count = max(counts.values()) if counts else 0

            # coherent if at least 6 of the top-20 words share one source topic
            if max_count >= 6:
                n_coherent += 1

        # not 100% deterministic, but should always get 3+ coherent topics
        self.assertTrue(n_coherent >= 3)
    def test_topic_coherence(self):
        """Check that trained topics recover the known source topics.

        A learned topic counts as coherent when at least 6 of its top-20 words
        belong to the same original topic in TOPIC_WORDS; at least 3 coherent
        topics are required.
        """
        if not self.vw_path:  # no VW binary configured (python 2.6 lacks skipTest)
            return
        corpus, dictionary = get_corpus()
        lda = LdaVowpalWabbit(self.vw_path,
                              corpus=corpus,
                              passes=10,
                              chunksize=256,
                              id2word=dictionary,
                              cleanup_files=True,
                              alpha=0.1,
                              eta=0.1,
                              num_topics=len(TOPIC_WORDS),
                              random_seed=1)
        lda.print_topics(5, 10)

        # map words in known topic to an ID
        topic_map = {}
        for i, words in enumerate(TOPIC_WORDS):
            topic_map[frozenset(words)] = i

        # Invert once, outside the per-topic loop: word -> ids of every source
        # topic containing it, so each word costs one dict lookup instead of a
        # scan over all source topics.
        word_to_src_ids = defaultdict(list)
        for src_topic_words, src_topic_id in six.iteritems(topic_map):
            for word in src_topic_words:
                word_to_src_ids[word].append(src_topic_id)

        n_coherent = 0
        for topic_id in range(lda.num_topics):
            topic = lda.show_topic(topic_id, topn=20)

            # NOTE(review): element 1 of each entry is taken as the word, i.e.
            # show_topic is assumed to return (probability, word) pairs --
            # confirm against the installed gensim version.
            topic_words = [w[1] for w in topic]

            # count how often each original topic is hit by this topic's words
            counts = defaultdict(int)
            for word in topic_words:
                for src_topic_id in word_to_src_ids.get(word, ()):
                    counts[src_topic_id] += 1
            max_count = max(counts.values()) if counts else 0

            # coherent if at least 6 of the top-20 words share one source topic
            if max_count >= 6:
                n_coherent += 1

        # not 100% deterministic, but should always get 3+ coherent topics
        self.assertTrue(n_coherent >= 3)
    def test_model_update(self):
        """Train on one document, update with the rest, and check that the
        full-corpus log-perplexity falls inside the loose (-5, -1) band."""
        vw_path = self.vw_path
        if not vw_path:  # nothing to run without a VW binary (python 2.6 compat)
            return
        seed_corpus = [self.corpus[0]]
        lda_kwargs = dict(
            corpus=seed_corpus, passes=10, chunksize=256,
            id2word=self.dictionary, cleanup_files=True,
            alpha=0.1, eta=0.1,
            num_topics=len(TOPIC_WORDS), random_seed=1,
        )
        lda = LdaVowpalWabbit(vw_path, **lda_kwargs)

        # feed every remaining document into the existing model
        remainder = self.corpus[1:]
        lda.update(remainder)

        score = lda.log_perplexity(self.corpus)
        upper_bound, lower_bound = -1, -5
        self.assertTrue(score < upper_bound)
        self.assertTrue(score > lower_bound)
    def test_model_update(self):
        """Bootstrap a model from the first document, update it with the
        remaining ones, and assert the log-perplexity lies in (-5, -1)."""
        vw_path = self.vw_path
        if not vw_path:  # no VW binary configured (python 2.6 compat)
            return
        initial_docs = [self.corpus[0]]
        options = dict(
            corpus=initial_docs, passes=10, chunksize=256,
            id2word=self.dictionary, cleanup_files=True,
            alpha=0.1, eta=0.1,
            num_topics=len(TOPIC_WORDS), random_seed=1,
        )
        model = LdaVowpalWabbit(vw_path, **options)

        # incremental training on everything after the first document
        later_docs = self.corpus[1:]
        model.update(later_docs)

        perplexity = model.log_perplexity(self.corpus)
        ceiling, floor = -1, -5
        self.assertTrue(perplexity < ceiling)
        self.assertTrue(perplexity > floor)
    def test_save_load(self):
        """Round-trip an LdaVowpalWabbit model through save() and load().

        Verifies that the public configuration attributes and the rendered top
        words of the first five topics are identical after reloading.
        """
        import os
        import shutil

        if not self.vw_path:  # no VW binary configured (python 2.6 lacks skipTest)
            return
        lda = LdaVowpalWabbit(
            self.vw_path, corpus=self.corpus, passes=10, chunksize=256,
            id2word=self.dictionary, cleanup_files=True, alpha=0.1,
            eta=0.1, num_topics=len(TOPIC_WORDS), random_seed=1
        )

        # Save into a fresh temporary directory instead of onto the name of a
        # still-open NamedTemporaryFile: re-opening that name is not allowed on
        # Windows. Removing the whole directory also cleans up any extra files
        # save() may write next to the target path.
        tmpdir = tempfile.mkdtemp()
        try:
            fname = os.path.join(tmpdir, 'ldavw.model')
            lda.save(fname)
            lda2 = LdaVowpalWabbit.load(fname)

            # ensure public fields are saved/loaded correctly
            public_fields = (
                'alpha', 'chunksize', 'cleanup_files',
                'decay', 'eta', 'gamma_threshold',
                'id2word', 'num_terms', 'num_topics',
                'passes', 'random_seed', 'vw_path',
            )
            saved_fields = [getattr(lda, name) for name in public_fields]
            loaded_fields = [getattr(lda2, name) for name in public_fields]
            self.assertEqual(saved_fields, loaded_fields)

            # ensure topic matrices are saved/loaded correctly
            saved_topics = lda.show_topics(num_topics=5, num_words=10)
            loaded_topics = lda2.show_topics(num_topics=5, num_words=10)
            self.assertEqual(loaded_topics, saved_topics)
        finally:
            shutil.rmtree(tmpdir, ignore_errors=True)
 def testvwmodel2ldamodel(self):
     """Test copying of VWModel to LdaModel.

     For every document, the first two-element entry returned by the VW model
     and by the model produced by vwmodel2ldamodel must agree: the first
     element via assertAlmostEqual at default precision, the second to 5
     decimal places (presumably (topic_id, probability) -- confirm).
     """
     if not self.vw_path:  # no VW binary configured (python 2.6 compat)
         return
     tm1 = LdaVowpalWabbit(vw_path=self.vw_path, corpus=self.corpus, num_topics=2, id2word=self.dictionary)
     tm2 = ldavowpalwabbit.vwmodel2ldamodel(tm1)
     for document in self.corpus:
         element1_1, element1_2 = tm1[document][0]
         element2_1, element2_2 = tm2[document][0]
         self.assertAlmostEqual(element1_1, element2_1)
         self.assertAlmostEqual(element1_2, element2_2, 5)
         # '%s' rather than '%d': the second elements are compared to 5
         # decimal places, i.e. they are fractional, and '%d' would truncate
         # them to integers in the log output.
         logging.debug('%s %s', element1_1, element2_1)
         logging.debug('%s %s', element1_2, element2_2)
Пример #10
0
 def testvwmodel2ldamodel(self):
     """Test copying of VWModel to LdaModel.

     For every document, the first two entries returned by the VW model must
     equal the first two entries returned by the model produced by
     vwmodel2ldamodel.
     """
     if not self.vw_path:  # no VW binary configured (python 2.6 compat)
         return
     tm1 = LdaVowpalWabbit(vw_path=self.vw_path,
                           corpus=self.corpus,
                           num_topics=2,
                           id2word=self.dictionary)
     tm2 = ldavowpalwabbit.vwmodel2ldamodel(tm1)
     for document in self.corpus:
         # Run inference once per model and document; indexing tm1/tm2
         # repeatedly re-ran it for every assertion and log line.
         vw_result = tm1[document]
         lda_result = tm2[document]
         self.assertEqual(vw_result[0], lda_result[0])
         self.assertEqual(vw_result[1], lda_result[1])
         # '%s', not '%d': the entries are presumably (topic_id, probability)
         # tuples, which '%d' cannot format -- with DEBUG enabled every call
         # would report a logging formatting error instead of the values.
         logging.debug('%s %s', vw_result[0], lda_result[0])
         logging.debug('%s %s', vw_result[1], lda_result[1])