def test_string_overrides_work(self):
    """Overrides keyed by variable *name* (str) are honoured.

    Overrides may name either a random variable ("A", "B") or a transformed
    value variable ("C_log__"); all of them must show up in the returned
    transformed initial point.
    """
    with pm.Model() as pmodel:
        # Only registration in the model matters; the returned RVs are unused.
        pm.Flat("A", initval=10)
        pm.HalfFlat("B", initval=10)
        pm.HalfFlat("C", initval=10)
    fn = make_initial_point_fn(
        model=pmodel,
        # ``jitter_rvs`` is a collection of RVs to jitter; an empty *set*
        # (``{}`` would be an empty dict) disables jitter.
        jitter_rvs=set(),
        return_transformed=True,
        overrides={
            "A": 1,
            "B": 1,
            "C_log__": 0,
        },
    )
    iv = fn(0)
    assert iv["A"] == 1
    # "B" is overridden on its untransformed scale, so its log value is log(1) == 0.
    assert np.isclose(iv["B_log__"], 0)
    assert iv["C_log__"] == 0
def test_untransformed_initial_point(self):
    """With ``return_transformed=False`` initial values use RV names and the untransformed scale."""
    with pm.Model() as pmodel:
        pm.Flat("A", initval="moment")
        pm.HalfFlat("B", initval="moment")
    # Empty set (not ``{}``, an empty dict): no RV is jittered.
    fn = make_initial_point_fn(model=pmodel, jitter_rvs=set(), return_transformed=False)
    iv = fn(0)
    # Moments on the untransformed scale: Flat -> 0, HalfFlat -> 1.
    assert iv["A"] == 0
    assert iv["B"] == 1
def test_model_not_drawable_prior(self):
    """Prior predictive sampling fails but posterior predictive sampling works.

    Prior predictive sampling is expected to raise ``NotImplementedError``
    with a message containing "Cannot sample". Presumably this is because
    the improper ``HalfFlat`` prior cannot be drawn from. Posterior
    predictive sampling driven by an existing trace must still succeed.
    """
    observed = np.random.poisson(lam=10, size=200)
    model = pm.Model()
    with model:
        rate = pm.HalfFlat("sigma")
        pm.Poisson("foo", mu=rate, observed=observed)
        trace = pm.sample(tune=1000)

    with model:
        with pytest.raises(NotImplementedError) as err:
            pm.sample_prior_predictive(50)
        assert "Cannot sample" in str(err.value)

        # 40 posterior predictive draws of the 200 observations.
        ppc = pm.sample_posterior_predictive(trace, 40, return_inferencedata=False)
        assert ppc["foo"].shape == (40, 200)
def test_respects_overrides(self):
    """Overrides keyed by RV objects take precedence over the RVs' ``initval``."""
    with pm.Model() as pmodel:
        A = pm.Flat("A", initval="moment")
        B = pm.HalfFlat("B", initval=4)
        C = pm.Normal("C", mu=A + B, initval="moment")
    fn = make_initial_point_fn(
        model=pmodel,
        # Empty set (not ``{}``, an empty dict): no RV is jittered.
        jitter_rvs=set(),
        return_transformed=True,
        overrides={
            # A tensor-valued (integer) override is accepted.
            A: at.as_tensor(2, dtype=int),
            # Replaces B's initval of 4; given on the untransformed scale.
            B: 3,
            C: 5,
        },
    )
    iv = fn(0)
    assert iv["A"] == 2
    # B is log-transformed, so the override of 3 appears as log(3).
    assert np.isclose(iv["B_log__"], np.log(3))
    assert iv["C"] == 5
def test_adds_jitter(self):
    """Only RVs in ``jitter_rvs`` are jittered, and the jitter follows the seed."""
    with pm.Model() as pmodel:
        A = pm.Flat("A", initval="moment")
        B = pm.HalfFlat("B", initval="moment")
        pm.Normal("C", mu=A + B, initval="moment")
    fn = make_initial_point_fn(model=pmodel, jitter_rvs={B}, return_transformed=True)
    point = fn(0)

    # A is not jittered; the moment of a Flat is 0.
    assert point["A"] == 0

    # B's moment is 1, but HalfFlat is log-transformed by default, so its
    # transformed initial value is 0 plus a jitter strictly between -1 and 1.
    b_log = point["B_log__"]
    b_value = transform_back(B, b_log)
    assert b_log != 0
    assert -1 < b_log < 1

    # C is centered on 0 + the untransformed (jittered) initial value of B.
    expected_c = np.array(0 + b_value, dtype=aesara.config.floatX)
    assert np.isclose(point["C"], expected_c)

    # Same seed -> same point; different seed -> different point.
    assert fn(0) == fn(0)
    assert fn(0) != fn(1)