Example #1
0
    def test_pymc3(self):
        """Check the pymc3 model and distribution conversion against refnx.

        First, the log-probability of the model built by ``pymc3_model`` is
        compared with ``self.objective.logl()`` at the current values of
        ``self.p[0]`` and ``self.p[1]``. Then ``_to_pymc3_distribution`` is
        exercised on parameters with interval bounds and with
        ``PDF``-wrapped scipy.stats distributions. The resulting pymc3
        log-probabilities must match ``Parameter.logp``.

        The test is reported as skipped (not silently passed) when pymc3 or
        the refnx pymc3 helpers cannot be imported.
        """
        import unittest

        # Keep every pymc3-dependent import inside the guard. Besides
        # ImportError, AttributeError is treated as "unavailable" because
        # some pymc3/theano installs presumably fail that way (the sampling
        # test guards against it too).
        try:
            import pymc3 as pm
            from refnx.analysis import pymc3_model
            from refnx.analysis.objective import _to_pymc3_distribution
        except (ImportError, AttributeError) as err:
            # Raising SkipTest (rather than returning) makes the test show up
            # as skipped instead of as a misleading pass.
            raise unittest.SkipTest("pymc3/theano not installed") from err

        logl = self.objective.logl()

        # log-probability of the pymc3 model at the current parameter values
        mod = pymc3_model(self.objective)
        with mod:
            pymc_logl = mod.logp({
                "p0": self.p[0].value,
                "p1": self.p[1].value
            })

        assert_allclose(logl, pymc_logl)

        # now check some of the distributions; the named pymc3 variables are
        # created inside an active model context
        with pm.Model():
            # interval bounds: matching logp inside, -inf outside
            p = Parameter(1, bounds=(1, 10))
            d = _to_pymc3_distribution("a", p)
            assert_almost_equal(d.distribution.logp(2).eval(), p.logp(2))
            assert_(np.isneginf(d.distribution.logp(-1).eval()))

            # frozen uniform PDF over the same interval
            q = Parameter(1, bounds=PDF(stats.uniform(1, 9)))
            d = _to_pymc3_distribution("b", q)
            assert_almost_equal(d.distribution.logp(2).eval(), q.logp(2))
            assert_(np.isneginf(d.distribution.logp(-1).eval()))

            # unfrozen (default-parameter) uniform PDF
            p = Parameter(1, bounds=PDF(stats.uniform))
            d = _to_pymc3_distribution("c", p)
            assert_almost_equal(d.distribution.logp(0.5).eval(), p.logp(0.5))

            # unfrozen (default-parameter) normal PDF
            p = Parameter(1, bounds=PDF(stats.norm))
            d = _to_pymc3_distribution("d", p)
            assert_almost_equal(d.distribution.logp(2).eval(), p.logp(2))

            # frozen normal PDF with explicit loc/scale
            p = Parameter(1, bounds=PDF(stats.norm(1, 10)))
            d = _to_pymc3_distribution("e", p)
            assert_almost_equal(d.distribution.logp(2).eval(), p.logp(2))
Example #2
0
    def test_pymc3_sample(self):
        """Smoke-test NUTS sampling of the pymc3 model built from the objective.

        Draws 200 samples after 100 tuning steps with a fixed seed. Only
        checks that sampling runs to completion; no property of the draws
        is asserted.

        The test is reported as skipped (not silently passed) when pymc3 or
        the refnx pymc3 helpers cannot be imported.
        """
        import unittest

        try:
            import pymc3 as pm
            from refnx.analysis import pymc3_model
        except (ImportError, AttributeError) as err:
            # can't run test if pymc3/theano not installed. ImportError
            # already covers ModuleNotFoundError. Raising SkipTest makes the
            # test show up as skipped instead of as a misleading pass.
            raise unittest.SkipTest("pymc3/theano not installed") from err

        with pymc3_model(self.objective):
            step = pm.NUTS()
            pm.sample(
                200,
                tune=100,
                step=step,
                discard_tuned_samples=True,
                # convergence diagnostics are not what this test checks
                compute_convergence_checks=False,
                # fixed seed so the run is reproducible
                random_seed=1,
            )