def test_sum_simplify_nested_sum_2():
    """Nested sums of products are spliced into one flat six-way sum.

    The root mixes two inner sums and one bare product. After construction,
    the root's children are the inner sums' children inserted in order, and
    each spliced weight is the outer log-weight plus the inner log-weight.
    """
    X = Id('X')
    W = Id('W')
    inner_first = SumSPE([
        (X >> norm(loc=0, scale=1)) & (W >> norm(loc=0, scale=2)),
        (X >> norm(loc=0, scale=2)) & (W >> norm(loc=0, scale=1)),
    ], [log(0.9), log(0.1)])
    middle_product = (
        (X >> norm(loc=0, scale=4)) & (W >> norm(loc=0, scale=10)))
    inner_last = SumSPE([
        (X >> norm(loc=0, scale=1)) & (W >> norm(loc=0, scale=2)),
        (X >> norm(loc=0, scale=2)) & (W >> norm(loc=0, scale=1)),
        (X >> norm(loc=0, scale=8)) & (W >> norm(loc=0, scale=3)),
    ], [log(0.4), log(0.3), log(0.3)])
    mixture = SumSPE(
        [inner_first, middle_product, inner_last],
        [log(0.4), log(0.4), log(0.2)])

    # root + 6 two-leaf products (6 + 12 nodes) = 19 nodes.
    assert mixture.size() == 19

    first, last = inner_first.children, inner_last.children
    assert mixture.children == (
        first[0], first[1], middle_product, last[0], last[1], last[2])

    expected_weights = [
        log(0.4) + log(0.9),
        log(0.4) + log(0.1),
        log(0.4),
        log(0.2) + log(0.4),
        log(0.2) + log(0.3),
        log(0.2) + log(0.3),
    ]
    for position, expected in enumerate(expected_weights):
        assert allclose(mixture.weights[position], expected)
def test_logpdf_mixture_nominal():
    """logpdf of a real/nominal mixture matches the matching weighted component."""
    components = [X >> norm(), X >> choice({'a':.1, 'b':.9})]
    spe = SumSPE(components, [log(.4), log(.6)])

    # A real value is scored against the normal component (weight .4).
    actual_real = spe.logpdf({X: .5})
    assert allclose(actual_real, log(.4) + spe.children[0].logpdf({X: .5}))

    # A string value is scored against the nominal component (weight .6).
    actual_nominal = spe.logpdf({X: 'a'})
    assert allclose(
        actual_nominal, log(.6) + spe.children[1].logpdf({X: 'a'}))
def test_sum_simplify_product_complex():
    """spe_simplify_sum factors shared leaves out of a sum of products.

    D occurs in all three terms, and C (scale=1) occurs in the first two.
    The expected result is Product(Sum(Product(Sum(A1*B, A0*B1), C),
    A0*B0*C1), D).
    """
    def leaf(name, scale):
        return Id(name) >> norm(loc=0, scale=scale)

    A1, A0 = leaf('A', 1), leaf('A', 2)
    B, B1, B0 = leaf('B', 1), leaf('B', 2), leaf('B', 3)
    C, C1 = leaf('C', 1), leaf('C', 2)
    D = leaf('D', 1)

    terms = [ProductSPE(factors) for factors in (
        [A1, B, C, D],
        [A0, B1, C, D],
        [A0, B0, C1, D],
    )]
    mixture = SumSPE(terms, [log(0.4), log(0.4), log(0.2)])

    simplified = spe_simplify_sum(mixture)
    # Top level: the common leaf D is pulled out beside a sum.
    assert isinstance(simplified, ProductSPE)
    top_sum = simplified.children[0]
    assert isinstance(top_sum, SumSPE)
    assert simplified.children[1] == D

    # The third term shares nothing further and stays a plain product.
    assert isinstance(top_sum.children[1], ProductSPE)
    assert top_sum.children[1].children == (A0, B0, C1)

    # The first two terms share C, which is pulled out beside a nested sum.
    left_product = top_sum.children[0]
    assert isinstance(left_product, ProductSPE)
    assert left_product.children[1] == C

    nested_sum = left_product.children[0]
    assert isinstance(nested_sum, SumSPE)
    assert isinstance(nested_sum.children[0], ProductSPE)
    assert isinstance(nested_sum.children[1], ProductSPE)
    assert nested_sum.children[0].children == (A1, B)
    assert nested_sum.children[1].children == (A0, B1)
def test_sum_simplify_nested_sum_1():
    """An inner sum of leaves is spliced into its parent sum on construction."""
    X = Id('X')
    inner = SumSPE(
        [X >> norm(loc=0, scale=1), X >> norm(loc=0, scale=2)],
        [log(0.4), log(0.6)])
    sibling = X >> gamma(loc=0, a=1)
    mixture = SumSPE([inner, sibling], [log(0.7), log(0.3)])

    # After flattening: the root plus three leaves.
    assert mixture.size() == 4
    assert mixture.children == (inner.children[0], inner.children[1], sibling)

    expected_weights = [log(0.7) + log(0.4), log(0.7) + log(0.6), log(0.3)]
    for position, expected in enumerate(expected_weights):
        assert allclose(mixture.weights[position], expected)
def test_sum_simplify_product_collapse():
    """A sum of structurally identical products collapses to one product."""
    def standard(name):
        return Id(name) >> norm(loc=0, scale=1)

    A1, A0 = standard('A'), standard('A')
    B, B1, B0 = standard('B'), standard('B'), standard('B')
    C, C1 = standard('C'), standard('C')
    D = standard('D')

    # Every term is N(0, 1) on A, B, C, D, so all three are equal.
    terms = [
        ProductSPE([A1, B, C, D]),
        ProductSPE([A0, B1, C, D]),
        ProductSPE([A0, B0, C1, D]),
    ]
    mixture = SumSPE(terms, [log(0.4), log(0.4), log(0.2)])
    assert spe_simplify_sum(mixture) == ProductSPE([A1, B, C, D])
def test_sum_simplify_leaf():
    """spe_simplify_sum merges identical leaves and adds their weights."""
    def leaf(scale):
        return Id('X') >> norm(loc=0, scale=scale)

    # Distinct leaves: simplification is a no-op.
    Xd0, Xd1, Xd2 = leaf(1), leaf(2), leaf(3)
    mixture = SumSPE([Xd0, Xd1, Xd2], [log(0.5), log(0.1), log(.4)])
    assert mixture.size() == 4
    assert spe_simplify_sum(mixture) == mixture

    # All leaves identical: the sum collapses to a single leaf.
    Xd0, Xd1, Xd2 = leaf(1), leaf(1), leaf(1)
    mixture = SumSPE([Xd0, Xd1, Xd2], [log(0.5), log(0.1), log(.4)])
    assert spe_simplify_sum(mixture) == Xd0

    # Two copies each of two distinct leaves: merge into two children
    # with weights 0.5 + 0.3 = 0.8 and 0.1 + 0.1 = 0.2.
    Xd3 = leaf(2)
    mixture = SumSPE(
        [Xd0, Xd3, Xd1, Xd3], [log(0.5), log(0.1), log(.3), log(.1)])
    merged = spe_simplify_sum(mixture)
    assert len(merged.children) == 2
    assert merged.children[0] == Xd0
    assert merged.children[1] == Xd3
    assert allclose(merged.weights[0], log(0.8))
    assert allclose(merged.weights[1], log(0.2))
def test_sum_normal_gamma():
    """Mixture of a standard normal and a gamma(loc=0, a=1) on one variable."""
    X = Id('X')
    components = [
        X >> norm(loc=0, scale=1),
        X >> gamma(loc=0, a=1),
    ]
    spe = SumSPE(components, [log(Fraction(2, 3)), log(Fraction(1, 3))])

    def by_hand(event):
        # Mixture log-probability assembled manually from both components.
        return logsumexp([
            spe.weights[0] + spe.children[0].logprob(event),
            spe.weights[1] + spe.children[1].logprob(event),
        ])

    assert spe.logprob(X > 0) == by_hand(X > 0)
    # Only the normal branch contributes below zero: 2/3 * 1/2.
    assert spe.logprob(X < 0) == log(Fraction(2, 3)) + log(Fraction(1, 2))

    draws = spe.sample(100, prng=numpy.random.RandomState(1))
    assert all(draw[X] for draw in draws)

    # The lambda parameter names are significant and are left as written.
    spe.sample_func(lambda X: abs(X**3), 100)
    # NOTE(review): presumably rejected because `Y` is not a model variable.
    with pytest.raises(ValueError):
        spe.sample_func(lambda Y: abs(X**3), 100)

    # Conditioning on X < 0 leaves a single conditioned continuous leaf.
    conditioned = spe.condition(X < 0)
    assert isinstance(conditioned, ContinuousLeaf)
    assert conditioned.conditioned
    assert conditioned.logprob(X < 0) == 0
    assert all(draw[X] < 0 for draw in conditioned.sample(100))

    assert spe.logprob(X < 0) == by_hand(X < 0)
def test_sum_normal_nominal():
    """Mixture of a standard normal and a two-level nominal on one variable."""
    X = Id('X')
    numeric = X >> norm(loc=0, scale=1)
    nominal = X >> choice({
        'low': Fraction(3, 10),
        'high': Fraction(7, 10)
    })
    spe = SumSPE(
        [numeric, nominal], [log(Fraction(4, 7)), log(Fraction(3, 7))])
    lp = spe.logprob

    assert allclose(lp(X < 0), log(Fraction(4, 7)) + log(Fraction(1, 2)))
    assert allclose(
        lp(X << {'low'}), log(Fraction(3, 7)) + log(Fraction(3, 10)))

    # Negation stays within strings: ~(X << {'low'}) means
    # (X << String and X != 'low').
    assert allclose(lp(~(X << {'low'})), lp((X << {'high'})))
    assert allclose(
        lp((X << FiniteNominal(b=True)) & ~(X << {'low'})),
        lp((X << FiniteNominal(b=True)) & (X << {'high'})))

    # A real-valued event and a string event never co-occur.
    assert isinf_neg(lp((X < 0) & (X << {'low'})))
    assert allclose(
        lp((X < 0) | (X << {'low'})),
        logsumexp([lp(X < 0), lp(X << {'low'})]))

    assert isinf_neg(lp(X << {'a'}))
    assert allclose(lp(~(X << {'a'})), lp(X << {'low', 'high'}))
    assert allclose(
        lp(X**2 < 9),
        log(Fraction(4, 7)) + spe.children[0].logprob(X**2 < 9))

    # A purely numeric condition discards the nominal branch.
    conditioned = spe.condition(X**2 < 9)
    assert isinstance(conditioned, ContinuousLeaf)
    assert conditioned.support == Interval.open(-3, 3)

    # A mixed condition keeps both branches.
    conditioned = spe.condition((X**2 < 9) | (X << {'low'}))
    assert isinstance(conditioned, SumSPE)
    assert conditioned.children[0].support == Interval.open(-3, 3)
    assert conditioned.children[1].support == FiniteNominal('low', 'high')
    assert isinf_neg(conditioned.children[1].logprob(X << {'high'}))
    assert conditioned == spe.condition((X**2 < 9) | ~(X << {'high'}))

    assert allclose(lp((X < oo) | ~(X << {'1'})), 0)
def test_cache_complex_sum_of_product():
    """spe_cache_duplicate_subtrees replaces equal subtrees with one instance.

    Test case adapted from the SPE generated by
    test_repeat.make_model_repeat(n=2).
    """
    def build_duplicate():
        # Fresh, structurally identical subtree on every call.
        x0_leaf = X[0] >> bernoulli(p=.1)
        z0_y_mixture = SumSPE([
            (Z[0] >> bernoulli(p=.5)) & (Y >> choice({'0': .1, '1': .9})),
            (Z[0] >> bernoulli(p=.1)) & (Y >> choice({'0': .9, '1': .1})),
        ], weights=[log(.730), log(.270)])
        return SumSPE([
            ProductSPE([x0_leaf, z0_y_mixture]),
            ProductSPE([
                Z[0] >> bernoulli(p=.1),
                Y >> choice({'0': .9, '1': .1}),
                X[0] >> bernoulli(p=.5),
            ]),
        ], weights=[log(.925), log(.075)])

    duplicate_subtrees = [build_duplicate() for _ in range(2)]
    # Equal by structure, but distinct objects before caching.
    assert duplicate_subtrees[0] == duplicate_subtrees[1]
    assert duplicate_subtrees[0] is not duplicate_subtrees[1]

    yxz_mixture = SumSPE([
        (Y >> choice({'0': .3, '1': .7}))
            & (X[0] >> bernoulli(p=.1))
            & (Z[0] >> bernoulli(p=.1)),
        (Y >> choice({'0': .7, '1': .3}))
            & (X[0] >> bernoulli(p=.5))
            & (Z[0] >> bernoulli(p=.5)),
    ], weights=[log(.9), log(.1)])
    left_subtree = ProductSPE([
        X[1] >> bernoulli(p=.5),
        SumSPE([
            ProductSPE([duplicate_subtrees[0], Z[1] >> bernoulli(p=.5)]),
            ProductSPE([Z[1] >> bernoulli(p=.7), yxz_mixture]),
        ], weights=[log(.783), log(.217)]),
    ])
    right_subtree = ProductSPE([
        Z[1] >> bernoulli(p=.8),
        X[1] >> bernoulli(p=.1),
        duplicate_subtrees[1],
    ])
    spe = (.92 * left_subtree) | (.08 * right_subtree)

    spe_cached = spe_cache_duplicate_subtrees(spe, {})
    # Both positions now hold the very same object: duplicate_subtrees[0].
    left_copy = spe_cached.children[0].children[1].children[0].children[0]
    assert left_copy is duplicate_subtrees[0]
    assert spe_cached.children[1].children[2] is duplicate_subtrees[0]