Exemplo n.º 1
0
    def _run_pure_online(self, d, k, mode, nframes):
        """Fit a mixture with purely online EM, one pass over the frames.

        A ``GM(d, k, mode)`` model is wrapped in an ``OnGMM`` trainer
        ('kmean' init mode) and initialised on the first ``nframes // 20``
        rows of ``self.data``. Each row ``self.data[t]`` for
        ``t < nframes`` is then fed once to the frame-wise EM update, with
        step size ``nu[t]`` taken from a forgetting-factor schedule.

        :param d: forwarded to ``GM(d, k, mode)``.
        :param k: forwarded to ``GM(d, k, mode)``.
        :param mode: forwarded to ``GM(d, k, mode)``.
        :param nframes: number of rows of ``self.data`` to process.
        :returns: ``ogmm.gm`` with its parameters set to the final online
            estimates ``(cw, cmu, cva)``.
        """
        #++++++++++++++++++++++++++++++++++++++++
        # Approximate the models with online EM
        #++++++++++++++++++++++++++++++++++++++++
        ogm = GM(d, k, mode)
        ogmm = OnGMM(ogm, 'kmean')
        # Floor division keeps the slice bound an integer under Python 3 (or
        # with ``from __future__ import division``); on Python 2 ints it is
        # identical to the former ``/``.
        init_data = self.data[0:nframes // 20, :]
        ogmm.init(init_data)

        # Forgetting param: lamb[t] = 1 - 1 / (ku * (t - 1) + t0), which grows
        # towards 1 as t increases. The per-frame step size follows the
        # recursion nu[t] = 1 / (1 + lamb[t] / nu[t-1]), starting at nu0.
        ku = 0.005
        t0 = 200
        lamb = 1 - 1/(N.arange(-1, nframes-1) * ku + t0)
        nu0 = 0.2
        nu = N.zeros((len(lamb), 1))
        nu[0] = nu0
        for i in range(1, len(lamb)):
            nu[i] = 1./(1 + lamb[i] / nu[i-1])

        # object version of online EM
        for t in range(nframes):
            # In pure online mode the "previous" (p*) and "current" (c*)
            # parameters must be the very same objects: these asserts catch
            # any unintended copy of the parameter arrays.
            assert ogmm.pw is ogmm.cw
            assert ogmm.pmu is ogmm.cmu
            assert ogmm.pva is ogmm.cva
            ogmm.compute_sufficient_statistics_frame(self.data[t], nu[t])
            ogmm.update_em_frame()

        ogmm.gm.set_param(ogmm.cw, ogmm.cmu, ogmm.cva)

        return ogmm.gm
Exemplo n.º 2
0
    def _check(self, d, k, mode, nframes, emiter):
        """Check that epoch-wise online EM matches the offline EM result.

        Steps:

        - Train an ``OnGMM`` on all of ``self.data``, using ``KM_ITER``
          k-means iterations for init.
        - Verify that its initial parameters equal ``self.gm0``.
        - Run ``emiter`` epochs of frame-wise updates over the first
          ``nframes`` rows. The "previous" parameters (p*) are refreshed
          only once per epoch, which the original author notes is
          equivalent to one full iteration of classic EM.
        - Compare the final parameters to ``self.gm`` at increasing
          decimal precision.

        :param d: forwarded to ``GM(d, k, mode)``.
        :param k: forwarded to ``GM(d, k, mode)``.
        :param mode: forwarded to ``GM(d, k, mode)``.
        :param nframes: number of rows of ``self.data`` used per epoch.
        :param emiter: number of epochs over those rows.
        :raises AssertionError: if the init differs from ``self.gm0``, or if
            the precision check fails below ``AR_AS_PREC``. The original
            mismatch report is preserved in that case.
        """
        #++++++++++++++++++++++++++++++++++++++++
        # Approximate the models with online EM
        #++++++++++++++++++++++++++++++++++++++++
        # Learn the model with Online EM
        ogm = GM(d, k, mode)
        ogmm = OnGMM(ogm, 'kmean')
        init_data = self.data
        ogmm.init(init_data, niter=KM_ITER)

        # Check that online kmean init is the same as kmean offline init
        ogm0 = copy.copy(ogm)
        assert_array_equal(ogm0.w, self.gm0.w)
        assert_array_equal(ogm0.mu, self.gm0.mu)
        assert_array_equal(ogm0.va, self.gm0.va)

        # Forgetting param: lamb is 1 everywhere except lamb[0] = 0, and
        # nu0 = 1, so the recursion below yields nu[t] = 1 / (t + 1)
        # (presumably a running-average step size).
        lamb = N.ones((nframes, 1))
        lamb[0] = 0
        nu0 = 1.0
        nu = N.zeros((len(lamb), 1))
        nu[0] = nu0
        for i in range(1, len(lamb)):
            nu[i] = 1./(1 + lamb[i] / nu[i-1])

        def snapshot_params():
            # Store independent copies (not aliases) of the current c*
            # parameters as the "previous" p* parameters.
            ogmm.pw = ogmm.cw.copy()
            ogmm.pmu = ogmm.cmu.copy()
            ogmm.pva = ogmm.cva.copy()

        # object version of online EM: the p* arguments are updated only at
        # each epoch, which is equivalent to one full EM iteration of the
        # classic EM algorithm
        snapshot_params()
        for e in range(emiter):
            for t in range(nframes):
                ogmm.compute_sufficient_statistics_frame(self.data[t], nu[t])
                ogmm.update_em_frame()

            # Change p* args only at each epoch
            snapshot_params()

        # For equivalence between off and on, we allow a margin of error,
        # because of round-off errors.
        print(" Checking precision of equivalence with offline EM trainer ")
        maxtestprec = 18
        try:
            for i in range(maxtestprec):
                assert_array_almost_equal(self.gm.w, ogmm.pw, decimal=i)
                assert_array_almost_equal(self.gm.mu, ogmm.pmu, decimal=i)
                assert_array_almost_equal(self.gm.va, ogmm.pva, decimal=i)
        except AssertionError:
            # NOTE(review): here ``i`` is the first decimal count that failed
            # (agreement was verified up to i - 1 decimals); the threshold
            # comparison keeps the original semantics.
            if i < AR_AS_PREC:
                print("\t !!NOT OK: Precision up to %d decimals only, \n"
                      "                    outside the allowed range (%d) !! "
                      % (i, AR_AS_PREC))
                # Re-raise the original error so the detailed mismatch
                # report from assert_array_almost_equal is not lost.
                raise
            else:
                print("\t !!OK: Precision up to %d decimals !! " % i)
        else:
            print("\t !! Precision up to %d decimals !! " % i)