def test_dirichlet_samples():
    """Check that Dirichlet draws put their mass on the heavily weighted component.

    Every row of the concentration matrix puts a weight of 100 on a different
    component. The test covers:

    - a batched (2-D) parameter, with and without ``size``;
    - each parameter row on its own (1-D), with and without ``size``;
    - the output shapes returned by calling ``dirichlet.rng_fn`` directly with
      a seeded ``RandomState``.
    """
    peaked = np.array(
        [[100, 1, 1], [1, 100, 1], [1, 1, 100]], dtype=config.floatX
    )

    # Batched parameters. `np.diag(draw)` broadcasts across rows, so each
    # diagonal entry must be at least every entry in its column.
    draw = get_test_value(dirichlet(peaked))
    assert np.all(np.diag(draw) >= draw)

    draws = get_test_value(dirichlet(peaked, size=2))
    assert draws.shape == (2, 3, 3)
    assert all(np.all(np.diag(d) >= d) for d in draws)

    # One parameter vector at a time: component `idx` must beat all others.
    for idx, alpha_vec in enumerate(peaked):
        vec = get_test_value(dirichlet(alpha_vec))
        assert np.all(vec[idx] > np.delete(vec, [idx]))

        vecs = get_test_value(dirichlet(alpha_vec, size=2))
        assert vecs.shape == (2, 3)
        assert all(np.all(v[idx] > np.delete(v, [idx])) for v in vecs)

    # Call the sampler implementation directly and check only output shapes.
    rng_state = np.random.RandomState(
        np.random.MT19937(np.random.SeedSequence(1234))
    )
    sharp = np.array(
        [[1000, 1, 1], [1, 1000, 1], [1, 1, 1000]], dtype=config.floatX
    )
    base_shape = sharp.shape
    assert dirichlet.rng_fn(rng_state, sharp, None).shape == base_shape
    assert (
        dirichlet.rng_fn(rng_state, sharp, size=10).shape == (10,) + base_shape
    )
    assert (
        dirichlet.rng_fn(rng_state, sharp, size=(10, 2)).shape
        == (10, 2) + base_shape
    )
def test_dirichlet_infer_shape():
    """Compare ``dirichlet._infer_shape`` against the shape of an actual draw.

    For several symbolic parameter shapes and ``size`` values (including
    ``None``), the inferred shape and the drawn sample's shape are evaluated
    through their test values. The two shapes must agree.

    NOTE(review): another function named ``test_dirichlet_infer_shape``
    appears later in this view. If both live in one module, only the later
    definition is collected by pytest; confirm.
    """
    M_aet = iscalar("M")
    M_aet.tag.test_value = 3

    # (shape of the all-ones concentration parameter, size passed to dirichlet)
    cases = [
        ((M_aet,), None),
        ((M_aet,), (M_aet + 1,)),
        ((M_aet,), (2, M_aet)),
        ((M_aet, M_aet + 1), None),
        ((M_aet, M_aet + 1), (M_aet + 2,)),
        ((M_aet, M_aet + 1), (2, M_aet + 2, M_aet + 3)),
    ]

    for alpha_shape, size in cases:
        alpha = aet.ones(alpha_shape)
        rv = dirichlet(alpha, size=size)
        inferred = tuple(dirichlet._infer_shape(size or (), [alpha], None))
        expected = tuple(get_test_value(rv).shape)
        assert tuple(get_test_value(inferred)) == expected
def test_dirichlet_infer_shape(M, size):
    """Compile and compare ``dirichlet._infer_shape`` with a real draw's shape.

    Parameters
    ----------
    M
        Concentration parameter passed to ``dirichlet``.
    size
        ``size`` argument passed to ``dirichlet``; ``None`` is accepted and
        treated as an empty size.

    NOTE(review): ``M`` and ``size`` are presumably supplied by a
    ``pytest.mark.parametrize`` decorator that is not visible here; confirm.
    """
    rv = dirichlet(M, size=size)
    rv_shape = list(dirichlet._infer_shape(size or (), [M], None))

    # Normalise `size` the same way as above so a `None` size does not make
    # the tuple concatenation raise `TypeError`.
    all_args = (M,) + tuple(size or ())

    # The non-constant, non-shared graph inputs of the parameters become the
    # compiled function's inputs.
    fn_inputs = [
        i
        for i in graph_inputs([a for a in all_args if isinstance(a, Variable)])
        if not isinstance(i, (Constant, SharedVariable))
    ]
    aesara_fn = function(
        fn_inputs, [at.as_tensor(o) for o in rv_shape + [rv]], mode=py_mode
    )

    # `fn_inputs` already excludes constants and shared variables, so every
    # entry can be fed its test value directly.
    *rv_shape_val, rv_val = aesara_fn(*[i.tag.test_value for i in fn_inputs])

    assert tuple(rv_shape_val) == tuple(rv_val.shape)
def test_dirichlet_ShapeFeature():
    """Make sure `RandomVariable.infer_shape` works with `ShapeFeature`.

    Builds a ``FunctionGraph`` with a ``ShapeFeature`` around a Dirichlet draw
    whose parameter has symbolic shape ``(M, N)``. It then checks that the two
    inferred shape entries depend on ``M`` and ``N`` respectively.

    NOTE(review): another function named ``test_dirichlet_ShapeFeature``
    appears later in this view. If both live in one module, only the later
    definition is collected by pytest; confirm.
    """
    M_at = iscalar("M")
    M_at.tag.test_value = 2
    N_at = iscalar("N")
    N_at.tag.test_value = 3

    # Label the variable after the distribution under test.
    d_rv = dirichlet(at.ones((M_at, N_at)), name="Dirichlet")

    fg = FunctionGraph(
        outputs=[d_rv],
        clone=False,
        features=[ShapeFeature()],
    )

    # Each inferred dimension must trace back to the scalar it was built from.
    s1, s2 = fg.shape_feature.shape_of[d_rv]
    assert M_at in graph_inputs([s1])
    assert N_at in graph_inputs([s2])
def test_dirichlet_ShapeFeature():
    """Make sure `RandomVariable.infer_shape` works with `ShapeFeature`.

    Builds a ``FunctionGraph`` with a ``ShapeFeature`` around a Dirichlet draw
    whose parameter has symbolic shape ``(M, N)``. It then checks that the two
    inferred shape entries depend on ``M`` and ``N`` respectively.
    """
    M_aet = iscalar("M")
    M_aet.tag.test_value = 2
    N_aet = iscalar("N")
    N_aet.tag.test_value = 3

    # Label the variable after the distribution under test.
    d_rv = dirichlet(aet.ones((M_aet, N_aet)), name="Dirichlet")

    # This `FunctionGraph` call takes the inputs explicitly: every
    # non-constant graph input of the draw.
    fg_inputs = [i for i in graph_inputs([d_rv]) if not isinstance(i, Constant)]
    fg = FunctionGraph(
        fg_inputs,
        [d_rv],
        clone=False,
        features=[ShapeFeature()],
    )

    # Each inferred dimension must trace back to the scalar it was built from.
    s1, s2 = fg.shape_feature.shape_of[d_rv]
    assert M_aet in graph_inputs([s1])
    assert N_aet in graph_inputs([s2])