示例#1
0
    def test_valid_run(self):
        """Smoke-test a small LDAEstimator run for well-formed probabilities.

        Fits an estimator on ``test.SMALL_DEL_FILE`` and checks that:

        * each vector returned by ``prob_items``, ``prob_items_given_tag``,
          ``prob_items_given_user`` and ``prob_items_given_user_tag`` sums to
          1 (within 1e-5) and has every entry strictly inside (0, 1);
        * every value returned by ``chain_likelihood()`` is truthy (non-zero);
        * the user-topic, topic-document and topic-term matrices have every
          entry in [0, 1] and are not all zero.
        """

        def isvalid(probs):
            # Sum once; tolerate rounding error of up to 1e-5 around 1.
            total = probs.sum()
            return (0.99999 <= total <= 1.00001
                    and (probs > 0).all()
                    and (probs < 1).all())

        annots = self.create_annots(test.SMALL_DEL_FILE)
        estimator = LDAEstimator(annots, 200, .001, .002, .003, 100, 50, 5, 0)

        gamma = np.arange(5)
        # Evaluated in this order, same as the individual calls would be.
        distributions = [
            ('prob_items', estimator.prob_items(gamma)),
            ('prob_items_given_tag', estimator.prob_items_given_tag(0, gamma)),
            ('prob_items_given_user',
             estimator.prob_items_given_user(0, gamma)),
            ('prob_items_given_user_tag',
             estimator.prob_items_given_user_tag(0, 0, gamma)),
        ]
        for name, probs in distributions:
            self.assertTrue(isvalid(probs),
                            '%s is not a valid distribution' % name)

        # Assumes chain_likelihood() returns a NumPy array: .all() is then
        # true only when no returned value is zero.
        self.assertTrue(estimator.chain_likelihood().all())

        # Fetch each matrix once rather than recomputing it for every check.
        matrices = [
            ('user_topic', estimator._get_user_topic_prb()),
            ('topic_document', estimator._get_topic_document_prb()),
            ('topic_term', estimator._get_topic_term_prb()),
        ]
        for name, prb in matrices:
            self.assertTrue((prb >= 0).all(), '%s has negative entries' % name)
        for name, prb in matrices:
            self.assertTrue((prb <= 1).all(), '%s has entries above 1' % name)
        for name, prb in matrices:
            self.assertTrue(prb.any(), '%s is all zeros' % name)
示例#2
0
    def test_gibbs_sample_with_same_sample_seed(self):
        """Two estimators built with identical arguments and seed must agree.

        Fits two LDAEstimator instances on ``test.DELICIOUS_FILE`` with the
        same parameters and checks that their user-topic, topic-document and
        topic-term matrices are identical.
        """
        annots = self.create_annots(test.DELICIOUS_FILE)

        # Last two parameters -> sample_every=1, seed=1
        estimator_seed_one_a = LDAEstimator(annots, 10, .5, .5, .5, 5, 2, 1, 1)
        estimator_seed_one_b = LDAEstimator(annots, 10, .5, .5, .5, 5, 2, 1, 1)

        ut_1a = estimator_seed_one_a._get_user_topic_prb()
        td_1a = estimator_seed_one_a._get_topic_document_prb()
        tt_1a = estimator_seed_one_a._get_topic_term_prb()

        ut_1b = estimator_seed_one_b._get_user_topic_prb()
        td_1b = estimator_seed_one_b._get_topic_document_prb()
        tt_1b = estimator_seed_one_b._get_topic_term_prb()

        # Unlike subtracting and calling .any() (which silently broadcasts
        # mismatched shapes), this also fails on a shape mismatch and reports
        # the differing cells.
        np.testing.assert_array_equal(ut_1a, ut_1b)
        np.testing.assert_array_equal(td_1a, td_1b)
        np.testing.assert_array_equal(tt_1a, tt_1b)
示例#3
0
    def test_gibbs_sample_with_sample_user_every(self):
        """Changing ``sample_every`` (same seed) must change the estimates.

        Fits two LDAEstimator instances on ``test.DELICIOUS_FILE`` that differ
        only in ``sample_every`` (1 vs 3) and checks that each of the
        user-topic, topic-document and topic-term matrices differs in at
        least one cell.
        """
        annots = self.create_annots(test.DELICIOUS_FILE)

        # Last two parameters -> sample_every=1, seed=1
        estimator_seed_one_a = LDAEstimator(annots, 10, .5, .5, .5, 5, 2, 1, 1)

        # Last two parameters -> sample_every=3, seed=1
        estimator_seed_one_b = LDAEstimator(annots, 10, .5, .5, .5, 5, 2, 3, 1)

        ut_1a = estimator_seed_one_a._get_user_topic_prb()
        td_1a = estimator_seed_one_a._get_topic_document_prb()
        tt_1a = estimator_seed_one_a._get_topic_term_prb()

        ut_1b = estimator_seed_one_b._get_user_topic_prb()
        td_1b = estimator_seed_one_b._get_topic_document_prb()
        tt_1b = estimator_seed_one_b._get_topic_term_prb()

        # Compare cell by cell. A non-zero *sum* of differences proves the
        # matrices differ, but a zero sum does not prove they are equal:
        # positive and negative differences cancel (and if rows are normalised
        # distributions the sum is ~0 however the cells differ).
        self.assertFalse(np.array_equal(ut_1a, ut_1b))
        self.assertFalse(np.array_equal(td_1a, td_1b))
        self.assertFalse(np.array_equal(tt_1a, tt_1b))
示例#4
0
    def test_gibbs_sample(self):
        """End-to-end Gibbs sampling run on the larger Delicious dataset.

        Checks that the fitted estimator reports iteration 4 and that the
        user-topic, topic-document and topic-term matrices are not all zero
        and have every entry within [0, 1].
        """
        estimator = LDAEstimator(self.create_annots(test.DELICIOUS_FILE),
                                 10, .5, .5, .5, 5, 2, 1, 0)

        self.assertEqual(estimator.get_iter(), 4)

        prbs = (estimator._get_user_topic_prb(),
                estimator._get_topic_document_prb(),
                estimator._get_topic_term_prb())

        # Each matrix has at least one non-zero entry.
        for prb in prbs:
            self.assertTrue(prb.any())

        # Every entry is a valid probability: lower bound, then upper bound.
        for prb in prbs:
            self.assertTrue((prb >= 0).all())
        for prb in prbs:
            self.assertTrue((prb <= 1).all())