Пример #1
0
    def testAverage(self):
        """Check ``average()`` on a mixture of repeated, identical clusters.

        Eight clusters with weight 0.5 are built, alternating between
        ``self.mu1`` and ``self.mu2`` and all sharing ``self.sig``. They are
        wrapped in an ``OrderedDPMixture`` with ``niter=4`` and an identity
        lookup table. The averaged result is expected to contain exactly two
        clusters whose means, covariances and weights equal the inputs.
        """
        # Same ordering as clst1..clst8: mu1, mu2, mu1, mu2, ...
        clusters = [
            DPCluster(0.5, mu, self.sig)
            for _ in range(4)
            for mu in (self.mu1, self.mu2)
        ]
        # Identity lookup table: position i maps to i.
        lt = {i: i for i in range(8)}

        mix = OrderedDPMixture(clusters, lt, niter=4)
        avg = mix.average()

        assert len(avg.clusters) == 2
        assert all(avg.mus[0] == self.mu1)
        assert all(avg.mus[1] == self.mu2)
        assert all(avg.sigmas[0] == self.sig)
        assert all(avg.sigmas[1] == self.sig)
        # ``pis`` is indexed directly in the conditions, so the messages must
        # index it the same way (not call it) and report the checked entry.
        assert avg.pis[0] == 0.5, 'pis should be 0.5 but is %f' % avg.pis[0]
        assert avg.pis[1] == 0.5, 'pis should be 0.5 but is %f' % avg.pis[1]
Пример #2
0
    def testLast(self):
        """Check that ``last(n)`` returns the trailing clusters of a mixture.

        Eight clusters with distinct means are wrapped in an
        ``OrderedDPMixture`` with ``niter=4`` and an identity lookup table.
        ``last()`` is expected to return the final two clusters, ``last(2)``
        the final four, and ``last(10)`` is expected to raise ``ValueError``.
        """
        clst1 = DPCluster(0.5, self.mu1, self.sig)
        clst2 = DPCluster(0.5, self.mu2 + 2, self.sig)
        clst3 = DPCluster(0.5, self.mu1 + 3, self.sig)
        clst4 = DPCluster(0.5, self.mu2 + 4, self.sig)
        clst5 = DPCluster(0.5, self.mu1 + 5, self.sig)
        clst6 = DPCluster(0.5, self.mu2 + 6, self.sig)
        clst7 = DPCluster(0.5, self.mu1 + 7, self.sig)
        clst8 = DPCluster(0.5, self.mu2 + 8, self.sig)

        # Identity lookup table: position i maps to i.
        lt = {i: i for i in range(8)}
        mix = OrderedDPMixture(
            [clst1, clst2, clst3, clst4, clst5, clst6, clst7, clst8],
            lt,
            niter=4)

        new_r = mix.last()
        assert len(new_r.clusters) == 2
        assert all(new_r.clusters[0].mu == clst7.mu)
        assert all(new_r.clusters[1].mu == clst8.mu)

        new_r = mix.last(2)
        assert len(new_r.clusters) == 4
        assert all(new_r.clusters[0].mu == clst5.mu)
        assert all(new_r.clusters[1].mu == clst6.mu)
        assert all(new_r.clusters[2].mu == clst7.mu)
        assert all(new_r.clusters[3].mu == clst8.mu)

        # Fail explicitly if no ValueError is raised; without the else
        # branch the test would pass whether or not last(10) raises.
        # NOTE(review): presumably last() rejects n larger than niter --
        # confirm against the implementation.
        try:
            mix.last(10)
        except ValueError:
            pass
        else:
            raise AssertionError(
                'last(10) should raise ValueError when niter is 4')
Пример #3
0
    def setUp(self):
        """Build the fixture shared by the tests in this class.

        Sets two 3-element mean vectors, a 3x3 identity covariance, two
        clusters with weight ``.5 / 3``, an alternating list of those two
        clusters repeated three times, a lookup table, and an
        ``OrderedDPMixture`` over them with ``niter=3`` and
        ``identified=True``.
        """
        self.mu1 = array([0, 0, 0])
        self.mu2 = array([5, 5, 5])
        self.sig = eye(3)

        weight = .5 / 3
        self.clust1 = DPCluster(weight, self.mu1, self.sig)
        self.clust2 = DPCluster(weight, self.mu2, self.sig)

        # The same two cluster objects, alternating, three times over.
        self.clusters = [self.clust1, self.clust2] * 3
        # Maps positions 0..5 to 1, 2, 3, 4, 5, 5 respectively.
        self.lookup = dict(enumerate((1, 2, 3, 4, 5, 5)))

        self.mix = OrderedDPMixture(
            self.clusters, self.lookup, niter=3, identified=True)