Пример #1
0
def test_get_jaxified_logp():
    """The JAX-compiled model logp must be numerically stabilized.

    The potential ``log(exp(x) + exp(y))`` is evaluated at x = y = 5000,
    where a naive ``exp`` overflows to ``inf`` in floating point. The result
    can only be finite if the graph was rewritten (optimized) into a stable
    form before being compiled with JAX.
    """
    with pm.Model() as m:
        x = pm.Flat("x")
        y = pm.Flat("y")
        pm.Potential("pot", at.log(at.exp(x) + at.exp(y)))

    jax_fn = get_jaxified_logp(m)
    # Without the stabilizing rewrite exp(5000) would overflow to inf.
    # Check isfinite rather than ``not isinf`` so that a NaN result also fails.
    assert np.isfinite(jax_fn((np.array(5000.0), np.array(5000.0))))
Пример #2
0
def mv_prior_simple():
    """Build a tiny GP-regression model whose posterior is known analytically.

    Three evenly spaced inputs on [0, 1] get an ExpQuad prior covariance, and
    three fixed observations are modelled with isotropic noise of variance
    ``noise``. The exact posterior mean and standard deviation are computed
    with Cholesky solves so tests can compare sampled results against them.

    Returns
    -------
    tuple
        ``(initial_point, model, (K, L, mu_post, std_post, noise))`` where
        ``K`` is the prior covariance, ``L`` its Cholesky factor, and
        ``mu_post`` / ``std_post`` the analytic posterior mean / std.
    """
    n = 3
    noise = 0.1
    inputs = np.linspace(0, 1, n)[:, None]

    prior_cov = pm.gp.cov.ExpQuad(1, 1)(inputs).eval()
    prior_chol = np.linalg.cholesky(prior_cov)
    observed = floatX_array([-0.1, 0.5, 1.1])

    # Cholesky factor of the observation covariance K + noise * I.
    noisy_chol = np.linalg.cholesky(prior_cov + noise * np.eye(n))

    # Posterior mean: K^T (K + noise I)^{-1} y, via two triangular-style solves.
    weights = np.linalg.solve(noisy_chol.T, np.linalg.solve(noisy_chol, observed))
    mu_post = np.dot(prior_cov.T, weights)

    # Posterior std: sqrt(diag(K - K^T (K + noise I)^{-1} K)).
    whitened = np.linalg.solve(noisy_chol, prior_cov)
    std_post = (prior_cov - np.dot(whitened.T, whitened)).diagonal() ** 0.5

    with pm.Model() as model:
        latent = pm.Flat("x", size=n)
        pm.MvNormal("x_obs", observed=observed, mu=latent, cov=noise * np.eye(n))

    reference = (prior_cov, prior_chol, mu_post, std_post, noise)
    return model.compute_initial_point(), model, reference
Пример #3
0
def test_point_logps_potential():
    """``Model.point_logps`` includes the value of a ``Potential`` term.

    With ``x`` initialised at 1, the potential ``2 * x`` must report 2.
    """
    with pm.Model() as model:
        flat = pm.Flat("x", initval=1)
        pm.Potential("y", flat * 2)

    potential_logp = model.point_logps()["y"]
    assert np.isclose(potential_logp, 2)
Пример #4
0
 def test_untransformed_initial_point(self):
     """Initial points can be requested on the untransformed scale.

     With ``return_transformed=False`` the ``HalfFlat`` RV is reported under
     its own name ``B`` instead of its transformed counterpart.
     """
     with pm.Model() as pmodel:
         pm.Flat("A", initval="moment")
         pm.HalfFlat("B", initval="moment")
     fn = make_initial_point_fn(
         model=pmodel,
         # No RVs are jittered; an empty *set* matches the set form used
         # when jitter is requested (e.g. ``jitter_rvs={B}``).
         jitter_rvs=set(),
         return_transformed=False,
     )
     iv = fn(0)
     # Moments: Flat -> 0, HalfFlat -> 1 (reported on the original scale).
     assert iv["A"] == 0
     assert iv["B"] == 1
Пример #5
0
def simple_normal(bounded_prior=False):
    """Build a one-observation normal model for MLE / MAP tests.

    Probes issue #2482.

    Parameters
    ----------
    bounded_prior : bool, default False
        If True, ``mu_i`` gets a ``Uniform(9, 12)`` prior; otherwise a
        ``Flat`` prior.

    Returns
    -------
    tuple
        ``(model.initial_point, model, None)``.
    """
    observation = 10.0
    sigma = 1.0
    # Bounds for the uniform prior; they must be non-symmetric around the
    # observation to reproduce the issue.
    lower, upper = (9, 12)

    with pm.Model(rng_seeder=2482) as model:
        mu_i = (
            pm.Uniform("mu_i", lower, upper)
            if bounded_prior
            else pm.Flat("mu_i")
        )
        pm.Normal("X_obs", mu=mu_i, sigma=sigma, observed=observation)

    return model.initial_point, model, None
Пример #6
0
 def test_respects_overrides(self):
     """Explicit ``overrides`` take precedence over declared initvals.

     Override values are given on the untransformed scale, so the
     log-transformed ``HalfFlat`` ``B`` is reported as ``log(3)`` under
     ``B_log__`` when transformed values are returned.
     """
     with pm.Model() as pmodel:
         a = pm.Flat("A", initval="moment")
         b = pm.HalfFlat("B", initval=4)
         c = pm.Normal("C", mu=a + b, initval="moment")
     overrides = {
         a: at.as_tensor(2, dtype=int),
         b: 3,
         c: 5,
     }
     fn = make_initial_point_fn(
         model=pmodel,
         jitter_rvs={},
         return_transformed=True,
         overrides=overrides,
     )
     iv = fn(0)
     assert iv["A"] == 2
     assert np.isclose(iv["B_log__"], np.log(3))
     assert iv["C"] == 5
Пример #7
0
    def test_string_overrides_work(self):
        """Overrides may be keyed by variable name instead of by RV object.

        A name can refer either to the untransformed variable (``"B"``) or
        directly to its transformed counterpart (``"C_log__"``).
        """
        with pm.Model() as pmodel:
            pm.Flat("A", initval=10)
            pm.HalfFlat("B", initval=10)
            pm.HalfFlat("C", initval=10)

        name_overrides = {"A": 1, "B": 1, "C_log__": 0}
        fn = make_initial_point_fn(
            model=pmodel,
            jitter_rvs={},
            return_transformed=True,
            overrides=name_overrides,
        )
        iv = fn(0)
        assert iv["A"] == 1
        # B was overridden with 1 on the original scale, i.e. log(1) == 0.
        assert np.isclose(iv["B_log__"], 0)
        assert iv["C_log__"] == 0
Пример #8
0
def test_set_initval():
    """User-provided initvals are recorded per variable in ``model.initial_values``.

    Checks that explicit initvals are kept, including for variables whose
    parameters depend on other RVs. A variable without an initval must map to
    ``None``. A model containing a ``Flat`` RV must still populate
    ``initial_values``.
    """
    # Make sure the dependencies between variables are maintained when
    # generating initial values
    rng = np.random.RandomState(392)

    with pm.Model(rng_seeder=rng) as model:
        eta = pm.Uniform("eta", 1.0, 2.0, size=(1, 1))
        # ``sigma`` is the spelling used elsewhere (``sd`` is a legacy alias).
        mu = pm.Normal("mu", sigma=eta, initval=[[100]])
        alpha = pm.HalfNormal("alpha", initval=100)
        value = pm.NegativeBinomial("value", mu=mu, alpha=alpha)

    # Use the numpy testing helper for both checks so failures show a diff.
    np.testing.assert_array_equal(model.initial_values[mu], np.array([[100.0]]))
    np.testing.assert_array_equal(model.initial_values[alpha], np.array(100))
    assert model.initial_values[value] is None

    # `Flat` cannot be sampled, so let's make sure that doesn't break initial
    # value computations
    with pm.Model() as model:
        x = pm.Flat("x")
        y = pm.Normal("y", x, 1)

    assert y in model.initial_values
Пример #9
0
 def test_adds_jitter(self):
     """Jitter is applied only to the requested RVs and depends on the seed.

     Only ``B`` is listed in ``jitter_rvs``. ``A`` therefore keeps its exact
     moment, while ``C`` (whose mean is ``A + B``) follows the jittered ``B``.
     """
     with pm.Model() as pmodel:
         a = pm.Flat("A", initval="moment")
         b = pm.HalfFlat("B", initval="moment")
         pm.Normal("C", mu=a + b, initval="moment")
     fn = make_initial_point_fn(
         model=pmodel,
         jitter_rvs={b},
         return_transformed=True,
     )
     iv = fn(0)

     # The moment of a Flat is 0, and A is not jittered.
     assert iv["A"] == 0

     # The moment of a HalfFlat is 1, but B is log-transformed by default, so
     # its transformed initial value is log(1) == 0 plus a jitter in (-1, 1).
     b_log = iv["B_log__"]
     b_value = transform_back(b, b_log)
     assert b_log != 0
     assert -1 < b_log < 1

     # C's mean is A + B, so its moment is 0 + the untransformed value of B.
     expected_c = np.array(0 + b_value, dtype=aesara.config.floatX)
     assert np.isclose(iv["C"], expected_c)

     # The jitter is reproducible for a fixed seed and differs across seeds.
     assert fn(0) == fn(0)
     assert fn(0) != fn(1)
Пример #10
0
                                     getParameters=True,
                                     getEnergy=True)


class Step:
    """Custom step that sets one variable from an external propagator.

    Parameters
    ----------
    var
        The model variable this step updates; only its ``name`` is stored.

    Notes
    -----
    ``step`` uses the names ``propagate``, ``simulation``, ``temperature`` and
    ``epsilon``, which are defined elsewhere in the script (not shown here;
    presumably at module level).
    """

    def __init__(self, var):
        self.var = var.name

    def step(self, point):
        """Return a copy of ``point`` with ``self.var`` set to a propagated value.

        The new value comes from ``propagate`` applied to the current
        ``'state'`` and ``'sigma'`` entries of ``point``. The input mapping
        itself is not modified, because the update is made on a shallow copy.
        """
        new = point.copy()
        state = point['state']
        sigma = point['sigma']
        new[self.var] = propagate(simulation, state, temperature, sigma,
                                  epsilon)
        return new


with pymc.Model() as model:
    sigma = pymc.Uniform("sigma", 0.535, 0.565)
    state = pymc.Flat('state')

    # NUTS updates only ``sigma``; the custom ``Step`` updates only ``state``.
    step1 = pymc.step_methods.NUTS(vars=[sigma])
    step2 = Step(state)
    # NOTE(review): ``Step`` is a plain object, not a subclass of a PyMC
    # step-method base class; confirm the sampler accepts it (it may expect
    # attributes such as ``generates_stats``).

    # Pass the step methods by keyword: ``step`` is a keyword parameter of
    # ``sample`` in every PyMC version, and recent releases make all
    # arguments after ``draws`` keyword-only.
    trace = pymc.sample(10, step=[step1, step2])

pymc.traceplot(trace)
# ``show`` is presumably matplotlib's ``pyplot.show`` imported elsewhere.
show()