Esempio n. 1
0
def test_shapes(backend, moves, nwalkers=32, ndim=3, nsteps=10, seed=1234):
    """Run a short chain and verify the shape of every sampler output.

    Parameters
    ----------
    backend
        Callable returning a context manager; the object it yields is passed
        to ``EnsembleSampler`` as ``backend``.
    moves
        Forwarded unchanged as the sampler's ``moves`` argument.
    nwalkers, ndim, nsteps
        Ensemble size, number of dimensions and number of steps to run.
    seed
        Seed for NumPy's global random state.
    """
    # Seed the global RNG so the starting ensemble is reproducible.
    np.random.seed(seed)

    # Shapes expected from each accessor.  The deprecated attributes are
    # checked walker-first, the getter methods step-first.
    legacy_chain_shape = (nwalkers, nsteps, ndim)
    legacy_log_prob_shape = (nwalkers, nsteps)
    chain_shape = (nsteps, nwalkers, ndim)
    log_prob_shape = (nsteps, nwalkers)
    flat_chain_shape = (nsteps * nwalkers, ndim)
    flat_log_prob_shape = (nsteps * nwalkers,)

    with backend() as be:
        # Build the starting ensemble and the sampler under test.
        start = np.random.randn(nwalkers, ndim)
        sampler = EnsembleSampler(
            nwalkers, ndim, normal_log_prob, moves=moves, backend=be
        )

        # Advance the chain.
        sampler.run_mcmc(start, nsteps)

        stored = sampler.get_chain()
        assert len(stored) == nsteps, "wrong number of steps"

        autocorr = sampler.get_autocorr_time(quiet=True)
        assert autocorr.shape == (ndim,)

        # Deprecated attributes: each access is expected to warn.
        with pytest.warns(DeprecationWarning):
            legacy_chain = sampler.chain
            assert (
                legacy_chain.shape == legacy_chain_shape
            ), "incorrect coordinate dimensions"
        with pytest.warns(DeprecationWarning):
            legacy_log_prob = sampler.lnprobability
            assert (
                legacy_log_prob.shape == legacy_log_prob_shape
            ), "incorrect probability dimensions"

        # Getter methods.
        current_chain = sampler.get_chain()
        assert current_chain.shape == chain_shape, "incorrect coordinate dimensions"
        current_log_prob = sampler.get_log_prob()
        assert (
            current_log_prob.shape == log_prob_shape
        ), "incorrect probability dimensions"

        accepted = sampler.acceptance_fraction
        assert accepted.shape == (
            nwalkers,
        ), "incorrect acceptance fraction dimensions"

        # Flattened views merge the step and walker axes.
        flat_chain = sampler.get_chain(flat=True)
        assert flat_chain.shape == flat_chain_shape, "incorrect coordinate dimensions"
        flat_log_prob = sampler.get_log_prob(flat=True)
        assert (
            flat_log_prob.shape == flat_log_prob_shape
        ), "incorrect probability dimensions"
Esempio n. 2
0
def test_sampler_generator():
    """Check that identically seeded generators give identical runs.

    Two samplers are built one after the other, each receiving a fresh
    ``np.random.default_rng(1)`` through the ``seed`` argument and the same
    starting coordinates.  Their chains and log-probabilities must agree.
    """
    nwalkers, ndim, nsteps = 32, 3, 5

    # The starting ensemble comes from the global RNG, not from the
    # generators handed to the samplers.
    np.random.seed(456)
    coords = np.random.randn(nwalkers, ndim)

    finished = []
    for _ in range(2):
        generator = np.random.default_rng(1)
        sampler = EnsembleSampler(nwalkers, ndim, normal_log_prob, seed=generator)
        sampler.run_mcmc(coords, nsteps)
        finished.append(sampler)

    first, second = finished
    np.testing.assert_allclose(first.get_chain(), second.get_chain())
    np.testing.assert_allclose(first.get_log_prob(), second.get_log_prob())
Esempio n. 3
0
def test_shapes(backend, moves, nwalkers=32, ndim=3, nsteps=10, seed=1234):
    """Run a short chain and verify the shape of every sampler output.

    Parameters
    ----------
    backend
        Callable returning a context manager; the object it yields is passed
        to ``EnsembleSampler`` as ``backend``.
    moves
        Forwarded unchanged as the sampler's ``moves`` argument.
    nwalkers, ndim, nsteps
        Ensemble size, number of dimensions and number of steps to run.
    seed
        Seed for NumPy's global random state.
    """
    # Set up the random number generator.
    np.random.seed(seed)

    with backend() as be:
        # Initialize the ensemble, moves and sampler.
        coords = np.random.randn(nwalkers, ndim)
        sampler = EnsembleSampler(nwalkers, ndim, normal_log_prob,
                                  moves=moves, backend=be)

        # Run the sampler.
        sampler.run_mcmc(coords, nsteps)
        chain = sampler.get_chain()
        assert len(chain) == nsteps, "wrong number of steps"

        tau = sampler.get_autocorr_time(quiet=True)
        assert tau.shape == (ndim,)

        # Legacy attributes are walker-major: (nwalkers, nsteps[, ndim]).
        # ``lnprobability`` follows the same layout as ``chain``; expecting
        # (nsteps, nwalkers) here only passes when nsteps == nwalkers.
        assert sampler.chain.shape == (nwalkers, nsteps, ndim), \
            "incorrect coordinate dimensions"
        assert sampler.lnprobability.shape == (nwalkers, nsteps), \
            "incorrect probability dimensions"

        # Getter methods are step-major: (nsteps, nwalkers[, ndim]).
        assert sampler.get_chain().shape == (nsteps, nwalkers, ndim), \
            "incorrect coordinate dimensions"
        assert sampler.get_log_prob().shape == (nsteps, nwalkers), \
            "incorrect probability dimensions"

        assert sampler.acceptance_fraction.shape == (nwalkers,), \
            "incorrect acceptance fraction dimensions"

        # Check the shape of the flattened coords.
        assert sampler.get_chain(flat=True).shape == \
            (nsteps * nwalkers, ndim), "incorrect coordinate dimensions"
        assert sampler.get_log_prob(flat=True).shape == \
            (nsteps * nwalkers,), "incorrect probability dimensions"
Esempio n. 4
0
def test_errors(backend, nwalkers=32, ndim=3, nsteps=5, seed=1234):
    """Exercise the sampler's error and warning paths.

    Covers: reading results before any run, reading the chain after a run
    with ``store=False``, restarting with coordinates of the wrong shape, and
    the initial-state ``RuntimeWarning`` (raised, suppressed, and not needed).

    Parameters
    ----------
    backend
        Callable returning a context manager; the object it yields is passed
        to ``EnsembleSampler`` as ``backend``.
    nwalkers, ndim, nsteps
        Ensemble size, number of dimensions and number of steps per run.
    seed
        Seed for NumPy's global random state.
    """
    # Local import keeps this block self-contained.
    import warnings

    # Set up the random number generator.
    np.random.seed(seed)

    with backend() as be:
        # Initialize the ensemble, proposal, and sampler.
        coords = np.random.randn(nwalkers, ndim)
        sampler = EnsembleSampler(nwalkers, ndim, normal_log_prob, backend=be)

        # Test for not running.
        with pytest.raises(AttributeError):
            sampler.get_chain()
        with pytest.raises(AttributeError):
            sampler.get_log_prob()

        # What about not storing the chain.
        sampler.run_mcmc(coords, nsteps, store=False)
        with pytest.raises(AttributeError):
            sampler.get_chain()

        # Now what about if we try to continue using the sampler with an
        # ensemble of a different shape.
        sampler.run_mcmc(coords, nsteps, store=False)

        coords2 = np.random.randn(nwalkers, ndim + 1)
        with pytest.raises(ValueError):
            list(sampler.run_mcmc(coords2, nsteps))

        # Ensure that a warning is logged if the initial coords don't allow
        # the chain to explore all of parameter space, and that one is not
        # if we explicitly disable it, or the initial coords can.
        with pytest.warns(RuntimeWarning) as recorded_warnings:
            sampler.run_mcmc(np.ones((nwalkers, ndim)), nsteps)
        assert len(recorded_warnings) == 1

        # ``pytest.warns(None)`` is deprecated in pytest 7 and rejected by
        # pytest 8, so record with the standard library instead.  The
        # "always" filter stops the once-per-location rule from hiding a
        # warning that already fired above.
        with warnings.catch_warnings(record=True) as recorded_warnings:
            warnings.simplefilter("always")
            sampler.run_mcmc(
                np.ones((nwalkers, ndim)),
                nsteps,
                skip_initial_state_check=True,
            )
            sampler.run_mcmc(np.random.randn(nwalkers, ndim), nsteps)
        assert len(recorded_warnings) == 0
Esempio n. 5
0
class Sampler:
    """
    Thin wrapper around emcee.EnsembleSampler.

    Keeps track of the most recent walker positions (``p_last``) so that
    successive calls to ``sample`` continue from where the previous call
    stopped, and adds helpers to write the results to disk.
    """
    def __init__(self, lnpost, p0, nwalkers=120, blobs_dtype=float):
        """
        Build the wrapped EnsembleSampler.

        Parameters
        ----------
        lnpost
            Log-posterior callable handed to EnsembleSampler.
        p0
            Initial walker positions; ``p0.shape[1]`` is used as the number
            of dimensions.
        nwalkers
            Number of walkers handed to EnsembleSampler.
        blobs_dtype
            dtype of the blobs.  Original author's notes: it must be given
            explicitly, otherwise an error happens, and a structured dtype
            must be a list of tuples such as ``[("lnlike", float)]``, not a
            tuple of tuples.
        """
        self.lnpost = lnpost
        self.ndim = p0.shape[1]
        self.sampler = EnsembleSampler(
            nwalkers, self.ndim, lnpost, blobs_dtype=blobs_dtype
        )
        self.p0 = p0
        self.p_last = p0

    def reset_sampler(self):
        """Call ``reset()`` on the wrapped sampler; ``p_last`` is kept."""
        self.sampler.reset()

    def sample(self, n_sample, burnin=False, use_pool=False):
        """
        Execute MCMC for the given number of iteration steps.

        Parameters
        ----------
        n_sample
            Number of iterations to run, starting from ``p_last``.
        burnin
            If true, the progress bar is labelled "burnin" and the wrapped
            sampler is reset once the run finishes.  ``p_last`` still holds
            the final positions.
        use_pool
            If true, run with a ``Pool()`` that lives only for this call.
        """
        desc = "burnin" if burnin else "sample"

        if use_pool:
            with Pool() as pool:
                self.sampler.pool = pool
                try:
                    self._advance(n_sample, desc)
                finally:
                    # The pool is terminated when the ``with`` block exits;
                    # do not leave the sampler holding a dead pool.
                    self.sampler.pool = None
        else:
            # Worker processes are only spawned when actually requested.
            self.sampler.pool = None
            self._advance(n_sample, desc)

        if burnin:
            self.reset_sampler()

    def _advance(self, n_sample, desc):
        """Iterate the wrapped sampler under a progress bar, updating ``p_last``."""
        iteration = tqdm(self.sampler.sample(self.p_last,
                                             iterations=n_sample),
                         total=n_sample,
                         desc=desc)
        for state in iteration:
            self.p_last = state.coords
            lnposts = state.log_prob
            iteration.set_postfix(lnpost_min=np.min(lnposts),
                                  lnpost_max=np.max(lnposts),
                                  lnpost_mean=np.mean(lnposts))

    def get_chain(self, **kwargs):
        """Forward to the wrapped sampler's ``get_chain``."""
        return self.sampler.get_chain(**kwargs)

    def get_log_prob(self, **kwargs):
        """Forward to the wrapped sampler's ``get_log_prob``."""
        return self.sampler.get_log_prob(**kwargs)

    def get_blobs(self, **kwargs):
        """Forward to the wrapped sampler's ``get_blobs``."""
        return self.sampler.get_blobs(**kwargs)

    def get_last_sample(self, **kwargs):
        """Forward to the wrapped sampler's ``get_last_sample``."""
        return self.sampler.get_last_sample(**kwargs)

    def _save(self, fname_base):
        """Write chain, log-probability and blobs to three ``.npy`` files."""
        np.save(fname_base + "_chain.npy", self.get_chain())
        np.save(fname_base + "_lnprob.npy", self.get_log_prob())
        np.save(fname_base + "_lnlike.npy", self.get_blobs())

    def save(self, fname_base):
        '''
        Save MCMC results into "<fname_base>_chain/lnprob/lnlike.npy".
        If fname_base is like "your_directory/your_prefix", "your_directory"
        (including any missing parent directories) is created before saving.
        '''
        dirname = os.path.dirname(fname_base)
        if dirname:
            # makedirs copes with nested paths and with the directory being
            # created concurrently, unlike an isdir() check plus mkdir().
            os.makedirs(dirname, exist_ok=True)
        self._save(fname_base)

    def save_pickle(self, fname_base, overwrite=False):
        '''
        Pickle this object and write it gzip-compressed to "<fname_base>_.gz".

        Raises RuntimeError if the file already exists and ``overwrite`` is
        false; with ``overwrite=True`` a warning is issued and the file is
        replaced.
        '''
        fname = fname_base + '_.gz'
        if os.path.exists(fname):
            if overwrite:
                warn(f"{fname} exsits already. It will be overwritten.")
            else:
                raise RuntimeError(
                    f"{fname} exsits already. If you want to overwrite it, set \"overwrite=True\"."
                )
        data = pickle.dumps(self)
        with gzip.open(fname, mode='wb') as fp:
            fp.write(data)