def test_sum_to_1():
    """Exercise the ``tr.sum_to_1`` transform.

    Runs ``check_vector_transform`` against 2- and 4-dimensional ``Simplex``
    domains, then ``check_jacobian_det`` over ``Vector(Unit, 2)`` with an
    ``aet.dvector`` constructor and a zero test point.
    """
    transform = tr.sum_to_1
    for dim in (2, 4):
        check_vector_transform(transform, Simplex(dim))

    def drop_last(x):
        # Keep all but the final coordinate before comparison; presumably the
        # last entry is implied by the sum-to-one constraint (TODO confirm).
        return x[:-1]

    check_jacobian_det(
        transform,
        Vector(Unit, 2),
        aet.dvector,
        np.array([0, 0]),
        drop_last,
    )
def test_stickbreaking_transforms():
    """Exercise the ``tr.stick_breaking`` transform on simplex domains.

    Runs ``check_vector_transform`` against 2- and 4-dimensional ``Simplex``
    domains and ``check_transform`` against ``MultiSimplex(3, 2)`` using an
    ``aet.dmatrix`` constructor and a 2x2 zero test point.

    This test previously shared the name ``test_stickbreaking`` with a later
    definition in this module. That later definition silently replaced it
    (pyflakes F811), so pytest never collected this one. It now has its own
    name. It also used an ``at`` alias where the rest of the module uses
    ``aet``.
    """
    for dim in (2, 4):
        check_vector_transform(tr.stick_breaking, Simplex(dim))
    check_transform(
        tr.stick_breaking,
        MultiSimplex(3, 2),
        constructor=aet.dmatrix,
        test=np.zeros((2, 2)),
    )
def test_stickbreaking():
    """Check the ``eps`` deprecation on ``tr.StickBreaking`` and the transform itself.

    First asserts that constructing ``tr.StickBreaking(eps=1e-9)`` emits a
    ``DeprecationWarning`` matching the expected message. Then runs
    ``check_vector_transform`` on 2- and 4-dimensional ``Simplex`` domains and
    ``check_transform`` on ``MultiSimplex(3, 2)`` with an ``aet.dmatrix``
    constructor and a 2x2 zero test point.
    """
    # Used as a regex by ``pytest.warns(match=...)``.
    deprecation_msg = "The argument `eps` is deprecated and will not be used."
    with pytest.warns(DeprecationWarning, match=deprecation_msg):
        tr.StickBreaking(eps=1e-9)

    transform = tr.stick_breaking
    for dim in (2, 4):
        check_vector_transform(transform, Simplex(dim))
    check_transform(
        transform,
        MultiSimplex(3, 2),
        constructor=aet.dmatrix,
        test=np.zeros((2, 2)),
    )
def test_dirichlet(self):
    """Run ``pymc3_random`` on ``Dirichlet`` for 2- and 3-dimensional ``a``.

    ``a`` is drawn from ``Vector(Rplus, dim)``. The value domain is
    ``Simplex(dim)``, ``size=100``, and ``st.dirichlet.rvs`` is supplied as
    the reference sampler.
    """
    # The reference sampler does not depend on the loop variable, so it is
    # defined once rather than rebuilt on every iteration.
    def reference_sampler(a=None, size=None):
        return st.dirichlet.rvs(a, size=size)

    for dim in (2, 3):
        pymc3_random(
            Dirichlet,
            {"a": Vector(Rplus, dim)},
            valuedomain=Simplex(dim),
            size=100,
            ref_rand=reference_sampler,
        )
def test_multinomial(self):
    """Run ``pymc3_random_discrete`` on ``Multinomial`` for 2 and 3 categories.

    ``p`` ranges over ``Simplex(k)`` and ``n`` over ``Nat``. The value domain
    is ``Vector(Nat, k)``, ``size=100``, and ``nr.multinomial`` is the
    reference sampler.
    """
    # Parameter names mirror the distribution's parameter names
    # (``n``, ``p``); presumably ``pymc3_random_discrete`` passes them by
    # keyword, so they must not be renamed (TODO confirm).
    def reference_sampler(n=None, p=None, size=None):
        return nr.multinomial(n, p, size=size)

    # ``k`` is the number of categories; it is named so as not to be confused
    # with the sampler's ``n`` (number of trials) argument.
    for k in (2, 3):
        pymc3_random_discrete(
            Multinomial,
            {"p": Simplex(k), "n": Nat},
            valuedomain=Vector(Nat, k),
            size=100,
            ref_rand=reference_sampler,
        )
def checks_categorical_random(self, s):
    """Run ``pymc3_random_discrete`` on ``Categorical`` with ``s`` categories.

    ``p`` ranges over ``Simplex(s)``. The reference sampler picks an index in
    ``0 .. len(p) - 1`` via ``nr.choice`` weighted by ``p``.

    Args:
        s: Number of categories passed to ``Simplex``.
    """
    def reference_sampler(size=None, p=None):
        # Explicit index array (not an int) so nr.choice returns elements of
        # this array; ``p`` is assumed to be a 1-D ndarray (it uses ``.shape``).
        outcomes = np.arange(p.shape[0])
        return nr.choice(outcomes, p=p, size=size)

    pymc3_random_discrete(Categorical, {"p": Simplex(s)}, ref_rand=reference_sampler)