示例#1
0
    def _run_pure_online(self, d, k, mode, nframes):
        """Train a GMM on ``self.data`` with a single pass of online EM.

        A ``GM(d, k, mode)`` model is initialised with k-means on the first
        ``nframes // 20`` frames of ``self.data``. Then each of the first
        ``nframes`` frames is fed once, in order, to the online EM update,
        using the step sizes ``nu`` derived below.

        Returns the ``GM`` instance held by the online trainer (``ogmm.gm``),
        with its parameters set to the final current estimates
        (``cw``, ``cmu``, ``cva``).
        """
        # ++++++++++++++++++++++++++++++++++++++++
        # Approximate the models with online EM
        # ++++++++++++++++++++++++++++++++++++++++
        ogm = GM(d, k, mode)
        ogmm = OnGMM(ogm, "kmean")
        # Floor division: a slice bound must be an integer. Plain ``/`` yields
        # a float under Python 3 (or ``from __future__ import division``).
        # NOTE(review): nframes < 20 gives an empty init slice.
        init_data = self.data[0 : nframes // 20, :]
        ogmm.init(init_data)

        # Forgetting param: lamb[t] = 1 - 1 / (ku * (t - 1) + t0), for
        # t = 0 .. nframes - 1. It starts close to 1 - 1/t0 and increases
        # toward 1 as t grows.
        ku = 0.005
        t0 = 200
        lamb = 1 - 1 / (N.arange(-1, nframes - 1) * ku + t0)
        # Step sizes: nu[0] = nu0, then nu[i] = 1 / (1 + lamb[i] / nu[i - 1]).
        # lamb[0] itself is never used by this recursion.
        nu0 = 0.2
        nu = N.zeros((len(lamb), 1))
        nu[0] = nu0
        for i in range(1, len(lamb)):
            nu[i] = 1.0 / (1 + lamb[i] / nu[i - 1])

        # object version of online EM
        for t in range(nframes):
            # Pure online mode: the "previous" parameters must stay the very
            # same objects as the current ones. These asserts catch any
            # unintended copy of the parameters.
            assert ogmm.pw is ogmm.cw
            assert ogmm.pmu is ogmm.cmu
            assert ogmm.pva is ogmm.cva
            ogmm.compute_sufficient_statistics_frame(self.data[t], nu[t])
            ogmm.update_em_frame()

        ogmm.gm.set_param(ogmm.cw, ogmm.cmu, ogmm.cva)

        return ogmm.gm
示例#2
0
    def _check(self, d, k, mode, nframes, emiter):
        """Check that epoch-wise online EM reproduces the offline EM result.

        A ``GM(d, k, mode)`` model is k-means initialised on all of
        ``self.data``, using KM_ITER iterations. Its initial weights, means and
        variances must equal those of ``self.gm0`` exactly.

        Online EM is then run for ``emiter`` epochs over the first ``nframes``
        frames. The "previous" parameters (pw, pmu, pva) are refreshed only at
        the end of each epoch, so that one epoch corresponds to one full
        iteration of classic EM. Finally the result is compared with
        ``self.gm`` at increasing decimal precision, and the precision reached
        is printed.

        Raises AssertionError if the k-means init differs from ``self.gm0``,
        or if the online/offline agreement falls below the AR_AS_PREC gate.
        """
        # ++++++++++++++++++++++++++++++++++++++++
        # Approximate the models with online EM
        # ++++++++++++++++++++++++++++++++++++++++
        # Learn the model with Online EM
        ogm = GM(d, k, mode)
        ogmm = OnGMM(ogm, "kmean")
        ogmm.init(self.data, niter=KM_ITER)

        # Check that online kmean init is the same than kmean offline init
        ogm0 = copy.copy(ogm)
        assert_array_equal(ogm0.w, self.gm0.w)
        assert_array_equal(ogm0.mu, self.gm0.mu)
        assert_array_equal(ogm0.va, self.gm0.va)

        # Forgetting param: lamb = 1 everywhere except lamb[0] = 0. With
        # nu[0] = 1, the recursion below gives nu[t] = 1 / (t + 1). The
        # sufficient statistics are therefore a plain running mean over the
        # frames of one epoch.
        lamb = N.ones((nframes, 1))
        lamb[0] = 0
        nu0 = 1.0
        nu = N.zeros((len(lamb), 1))
        nu[0] = nu0
        for i in range(1, len(lamb)):
            nu[i] = 1.0 / (1 + lamb[i] / nu[i - 1])

        def _snapshot_params():
            # Freeze the current estimates as the "previous" parameters. The
            # copies ensure pw/pmu/pva do not alias cw/cmu/cva while the next
            # epoch updates the current estimates.
            ogmm.pw = ogmm.cw.copy()
            ogmm.pmu = ogmm.cmu.copy()
            ogmm.pva = ogmm.cva.copy()

        # object version of online EM: the p* arguments are updated only at
        # each epoch, which is equivalent to one full iteration of the
        # classic EM algorithm
        _snapshot_params()
        for _ in range(emiter):
            for t in range(nframes):
                ogmm.compute_sufficient_statistics_frame(self.data[t], nu[t])
                ogmm.update_em_frame()
            # Change p* args only at each epoch
            _snapshot_params()

        # For equivalence between off and on, we allow a margin of error,
        # because of round-off errors.
        print(" Checking precision of equivalence with offline EM trainer ")
        maxtestprec = 18
        try:
            for i in range(maxtestprec):
                assert_array_almost_equal(self.gm.w, ogmm.pw, decimal=i)
                assert_array_almost_equal(self.gm.mu, ogmm.pmu, decimal=i)
                assert_array_almost_equal(self.gm.va, ogmm.pva, decimal=i)
            print("\t !! Precision up to %d decimals !! " % i)
        except AssertionError:
            # ``i`` is the first decimal count that FAILED. The precision
            # actually reached is the last count that passed, i.e. ``i - 1``.
            reached = i - 1
            # NOTE(review): the gate is kept as originally written, on the
            # first failing decimal count. It therefore accepts runs that
            # agree to only AR_AS_PREC - 1 decimals. Confirm whether
            # AR_AS_PREC was meant as the minimum precision reached (which
            # would be ``reached < AR_AS_PREC``).
            if i < AR_AS_PREC:
                print("""\t !!NOT OK: Precision up to %d decimals only, 
                    outside the allowed range (%d) !! """ % (
                    reached,
                    AR_AS_PREC,
                ))
                # Re-raise the original error so numpy's mismatch report
                # is not lost.
                raise
            else:
                print("\t !!OK: Precision up to %d decimals !! " % reached)