Example #1
0
    def test_mcmc(self):
        """Exercise MCMC sampling through ``sample`` on two fitters.

        First samples with the fixture fitter ``self.mcfitter`` and checks
        ``nvary``, the corner plot, ``_ntemps`` and the sampler type.  Then
        builds a fresh ``CurveFitter`` on ``self.objective``, samples with
        ``pool=2`` and checks the ACF output and the chain shapes.  Finally
        checks that ``sample`` raises ``RuntimeError`` after a parameter's
        ``vary`` flag is switched off, and that calling ``make_sampler``
        again lets sampling proceed.
        """
        self.mcfitter.sample(steps=50, nthin=1, verbose=False)

        assert_equal(self.mcfitter.nvary, 2)

        # smoke test for corner plot
        self.mcfitter.objective.corner()

        # we're not doing Parallel Tempering here.
        assert_(self.mcfitter._ntemps == -1)
        assert_(isinstance(self.mcfitter.sampler, emcee.EnsembleSampler))

        # should be able to sample with a pool (pool=2)
        mcfitter = CurveFitter(self.objective, nwalkers=50)
        res = mcfitter.sample(steps=33, nthin=2, verbose=False, pool=2)

        # check that the autocorrelation function at least runs
        acfs = mcfitter.acf(nburn=10)
        assert_equal(acfs.shape[-1], mcfitter.nvary)

        # check chain shape
        assert_equal(mcfitter.chain.shape, (33, 50, 2))
        # NOTE(review): the following check is disabled in the original;
        # confirm whether it should be re-enabled or removed.
        # assert_equal(mcfitter._lastpos, mcfitter.chain[:, -1, :])
        assert_equal(res[0].chain.shape, (33, 50))

        # if the number of varying parameters changes there should be an
        # Exception raised
        from pytest import raises

        # Do the setup outside the ``raises`` block so that only the
        # ``sample`` call can satisfy the RuntimeError expectation.
        self.p[0].vary = False
        with raises(RuntimeError):
            self.mcfitter.sample(1)

        # can fix by making the sampler again
        self.mcfitter.make_sampler()
        self.mcfitter.sample(1)
Example #2
0
    def test_mcmc(self):
        """Smoke-test MCMC sampling on the fixture fitter and a pooled fitter.

        Samples with ``self.mcfitter`` and checks ``_ntemps`` and the sampler
        type, then samples with a new ``CurveFitter`` using ``pool=2`` and
        checks the ACF output and the resulting chain shapes.
        """
        fixture_fitter = self.mcfitter
        fixture_fitter.sample(steps=50, nthin=1, verbose=False)

        # Parallel Tempering is not being used, so expect the -1 marker and
        # a plain emcee ensemble sampler.
        assert_(fixture_fitter._ntemps == -1)
        assert_(isinstance(fixture_fitter.sampler, emcee.EnsembleSampler))

        # a fresh fitter should be able to sample with a pool
        pooled_fitter = CurveFitter(self.objective, nwalkers=50)
        sample_kwargs = dict(steps=33, nthin=2, verbose=False, pool=2)
        results = pooled_fitter.sample(**sample_kwargs)

        # the autocorrelation function should at least run; its last axis
        # must match the number of varying parameters
        autocorr = pooled_fitter.acf(nburn=10)
        assert_equal(autocorr.shape[-1], pooled_fitter.nvary)

        # chain shape on the fitter, then on the first entry of the results
        assert_equal(pooled_fitter.chain.shape, (33, 50, 2))
        # assert_equal(mcfitter._lastpos, mcfitter.chain[:, -1, :])
        first_result = results[0]
        assert_equal(first_result.chain.shape, (33, 50))