def test_or_and():
    """Check `&` applied to a weighted mixture and a leaf over another symbol."""
    # A mixture whose components are over different symbols (X and Y)
    # is expected to be rejected with ValueError.
    with pytest.raises(ValueError):
        (0.3*(X >> norm()) | 0.7*(Y >> gamma(a=1))) & (Z >> norm())

    # With both components over X, the result is a product of the
    # mixture and the Z leaf.
    product = (0.3*(X >> norm()) | 0.7*(X >> gamma(a=1))) & (Z >> norm())
    assert isinstance(product, ProductSPE)
    mixture_child = product.children[0]
    assert isinstance(mixture_child, SumSPE)
    leaf_child = product.children[1]
    assert isinstance(leaf_child, ContinuousLeaf)
def test_cache_simple_sum_of_product():
    """Equal W leaves in the two products become one shared object after caching."""
    first_product = (W >> norm(loc=0, scale=1)) & (Y >> norm(loc=0, scale=1))
    second_product = (W >> norm(loc=0, scale=1)) & (Y >> norm(loc=0, scale=2))
    spe = 0.3 * first_product | 0.7 * second_product

    spe_cached = spe_cache_duplicate_subtrees(spe, {})

    # Identity (not mere equality) is what the cache is checked for.
    w_in_first = spe_cached.children[0].children[0]
    w_in_second = spe_cached.children[1].children[0]
    assert w_in_first is w_in_second
def test_logpdf_lexicographic_both():
    """Check logpdf/constrain for an assignment to which both branches contribute."""
    branch_a = X >> norm() & Y >> atomic(loc=0) & Z >> discrete({1:.2, 2:.8})
    branch_b = X >> discrete({1:.5, 2:.5}) & Y >> norm() & Z >> atomic(loc=2)
    spe = .75*(branch_a) | .25*(branch_b)

    # Lexicographic, Mix
    assignment = {X:1, Y:0, Z:2}
    expected_a = log(.75) + norm().dist.logpdf(1) + log(1) + log(.8)
    expected_b = log(.25) + log(.5) + norm().dist.logpdf(0) + log(1)
    assert allclose(
        spe.logpdf(assignment),
        logsumexp([expected_a, expected_b]))

    # Both branches survive, so constraining keeps a sum.
    constrained = spe.constrain(assignment)
    assert isinstance(constrained, SumSPE)
def test_product_leaf():
    """Check `&` on leaves: invalid operands raise, distinct symbols form a product."""
    # A weighted (partial) term on either side of `&` raises TypeError.
    with pytest.raises(TypeError):
        0.3*(X >> gamma(a=1)) & (X >> norm())
    with pytest.raises(TypeError):
        (X >> norm()) & 0.3*(X >> gamma(a=1))

    # Two leaves over the same symbol raise ValueError.
    with pytest.raises(ValueError):
        (X >> norm()) & (X >> gamma(a=1))

    product = (X >> norm()) & (Y >> gamma(a=1)) & (Z >> norm())
    assert isinstance(product, ProductSPE)
    assert len(product.children) == 3
    assert product.get_symbols() == frozenset((X, Y, Z))
def test_mutual_information_four_clusters(memo):
    """Check mutual information of threshold events in a four-component mixture.

    The mixture places equal weight on four products of normals located at
    the combinations of 0 and 5 for X and Y.
    """
    X = Id('X')
    Y = Id('Y')
    spe = (
        0.25*(X >> norm(loc=0, scale=0.5) & Y >> norm(loc=0, scale=0.5))
        | 0.25*(X >> norm(loc=5, scale=0.5) & Y >> norm(loc=0, scale=0.5))
        | 0.25*(X >> norm(loc=0, scale=0.5) & Y >> norm(loc=5, scale=0.5))
        | 0.25*(X >> norm(loc=5, scale=0.5) & Y >> norm(loc=5, scale=0.5))
    )

    event_x = X > 2
    event_y = Y > 2
    # The draw is kept even though its result is unused: it consumes prng.
    samples = spe.sample(100, prng)
    mi_joint = spe.mutual_information(event_x, event_y, memo=memo)
    assert allclose(mi_joint, 0)
    check_mi_properties(spe, event_x, event_y, memo)

    # Condition on exactly one of X, Y exceeding the threshold.
    event = ((X>2) & (Y<2) | ((X<2) & (Y>2)))
    spe_condition = spe.condition(event)
    samples = spe_condition.sample(100, prng)
    assert all(event.evaluate(sample) for sample in samples)
    mi_conditioned = spe_condition.mutual_information(X > 2, Y > 2)
    assert allclose(mi_conditioned, log(2))

    check_mi_properties(spe, (X>1) | (Y<1), (Y>2), memo)
    check_mi_properties(spe, (X>1) | (Y<1), (X>1.5) & (Y>2), memo)
def test_sum_of_sums():
    """Check flattening and weights when weighted mixtures are nested with `|`."""
    # Outer weights 0.3 and 0.7: the result is a SumSPE with four children
    # whose log-weights are the sums of the outer and inner log-weights.
    left = 0.3*(0.4*(X >> norm()) | 0.6*(X >> norm()))
    right = 0.7*(0.1*(X >> norm()) | 0.9*(X >> norm()))
    w = left | right
    assert isinstance(w, SumSPE)
    assert len(w.children) == 4
    expected_pairs = [(0.3, 0.4), (0.3, 0.6), (0.7, 0.1), (0.7, 0.9)]
    for idx, (outer, inner) in enumerate(expected_pairs):
        assert allclose(float(w.weights[idx]), log(outer) + log(inner))

    # Outer weights 0.3 and 0.2: the result is a PartialSumSPE that
    # keeps the two outer weights as given.
    left = 0.3*(0.4*(X >> norm()) | 0.6*(X >> norm()))
    right = 0.2*(0.1*(X >> norm()) | 0.9*(X >> norm()))
    w = left | right
    assert isinstance(w, PartialSumSPE)
    assert allclose(float(w.weights[0]), 0.3)
    assert allclose(float(w.weights[1]), 0.2)

    # Adding a 0.5-weighted term yields a SumSPE with five children.
    a = w | 0.5*(X >> gamma(a=1))
    assert isinstance(a, SumSPE)
    assert len(a.children) == 5
    expected_pairs = [(0.3, 0.4), (0.3, 0.6), (0.2, 0.1), (0.2, 0.9)]
    for idx, (outer, inner) in enumerate(expected_pairs):
        assert allclose(float(a.weights[idx]), log(outer) + log(inner))
    assert allclose(float(a.weights[4]), log(0.5))

    # Wrong symbol.
    with pytest.raises(ValueError):
        w | 0.4*(Y >> gamma(a=1))
def test_sum_simplify_product_complex():
    """Check the nested structure spe_simplify_sum builds for three products.

    The products share D everywhere and C in the first two components.
    """
    A1 = Id('A') >> norm(loc=0, scale=1)
    A0 = Id('A') >> norm(loc=0, scale=2)
    B = Id('B') >> norm(loc=0, scale=1)
    B1 = Id('B') >> norm(loc=0, scale=2)
    B0 = Id('B') >> norm(loc=0, scale=3)
    C = Id('C') >> norm(loc=0, scale=1)
    C1 = Id('C') >> norm(loc=0, scale=2)
    D = Id('D') >> norm(loc=0, scale=1)
    components = [
        ProductSPE([A1, B, C, D]),
        ProductSPE([A0, B1, C, D]),
        ProductSPE([A0, B0, C1, D]),
    ]
    spe = SumSPE(components, [log(0.4), log(0.4), log(0.2)])

    simplified = spe_simplify_sum(spe)

    # Top level: (sum over the A, B, C parts) & D.
    assert isinstance(simplified, ProductSPE)
    assert isinstance(simplified.children[0], SumSPE)
    assert simplified.children[1] == D

    # Outer sum: second child is the product without shared factors.
    outer_sum = simplified.children[0]
    assert isinstance(outer_sum.children[1], ProductSPE)
    assert outer_sum.children[1].children == (A0, B0, C1)
    # First child: (sum over the A, B parts) & C.
    assert isinstance(outer_sum.children[0], ProductSPE)
    assert outer_sum.children[0].children[1] == C

    inner_sum = outer_sum.children[0].children[0]
    assert isinstance(inner_sum, SumSPE)
    assert isinstance(inner_sum.children[0], ProductSPE)
    assert isinstance(inner_sum.children[1], ProductSPE)
    assert inner_sum.children[0].children == (A1, B)
    assert inner_sum.children[1].children == (A0, B1)
def test_transform_real_leaf_logprob():
    """Check env bookkeeping and logprob after chained transforms of a real leaf."""
    X = Id('X')
    Y = Id('Y')
    Z = Id('Z')
    spe = (X >> norm(loc=0, scale=1))

    # Both calls are expected to fail an assertion: the first uses Y, which
    # the leaf does not define; the second re-targets X itself.
    with pytest.raises(AssertionError):
        spe.transform(Z, Y**2)
    with pytest.raises(AssertionError):
        spe.transform(X, X**2)

    # Z := X**2
    spe = spe.transform(Z, X**2)
    assert spe.env == {X:X, Z:X**2}
    assert spe.get_symbols() == {X, Z}
    assert spe.logprob(Z < 1) == spe.logprob(X**2 < 1)
    assert (
        spe.logprob((Z < 1) | ((X + 1) < 3))
        == spe.logprob((X**2 < 1) | ((X+1) < 3))
    )

    # Y := 2*Z, defined in terms of the previously derived symbol.
    spe = spe.transform(Y, 2*Z)
    assert spe.env == {X:X, Z:X**2, Y:2*Z}
    assert (
        spe.logprob(Y**(1,3) < 10)
        == spe.logprob((2*Z)**(1,3) < 10)
        == spe.logprob((2*(X**2))**(1,3) < 10)
    )

    # W := (X > 1), an event-valued transform.
    W = Id('W')
    spe = spe.transform(W, X > 1)
    assert allclose(spe.logprob(W), spe.logprob(X > 1))
def test_serialize_env(transform):
    """Round-trip a transformed leaf through spe_to_dict, JSON and spe_from_dict."""
    spe = (X >> norm()).transform(Y, transform)

    # dict -> JSON text -> dict -> SPE
    encoded = json.dumps(spe_to_dict(spe))
    decoded = json.loads(encoded)
    restored = spe_from_dict(decoded)

    assert restored == spe
def test_sum_normal_gamma():
    """Check logprob, sampling and conditioning for a normal/gamma mixture over X."""
    X = Id('X')
    weights = [log(Fraction(2, 3)), log(Fraction(1, 3))]
    components = [
        X >> norm(loc=0, scale=1),
        X >> gamma(loc=0, a=1),
    ]
    spe = SumSPE(components, weights)

    # The mixture log-probability is the log-sum-exp over weighted children.
    assert spe.logprob(X > 0) == logsumexp([
        weight + child.logprob(X > 0)
        for weight, child in zip(spe.weights, spe.children)
    ])
    assert spe.logprob(X < 0) == log(Fraction(2, 3)) + log(Fraction(1, 2))

    samples = spe.sample(100, prng=numpy.random.RandomState(1))
    assert all(s[X] for s in samples)

    spe.sample_func(lambda X: abs(X**3), 100)
    # A lambda whose parameter is not a symbol of the SPE is rejected.
    with pytest.raises(ValueError):
        spe.sample_func(lambda Y: abs(X**3), 100)

    # Conditioning on X < 0 yields a single conditioned continuous leaf.
    spe_condition = spe.condition(X < 0)
    assert isinstance(spe_condition, ContinuousLeaf)
    assert spe_condition.conditioned
    assert spe_condition.logprob(X < 0) == 0
    samples = spe_condition.sample(100)
    assert all(s[X] < 0 for s in samples)

    # The original mixture still answers queries after conditioning.
    assert spe.logprob(X < 0) == logsumexp([
        weight + child.logprob(X < 0)
        for weight, child in zip(spe.weights, spe.children)
    ])
def test_logpdf_mixture_nominal():
    """Check logpdf of a real/nominal mixture at a real and at a nominal value."""
    real_leaf = X >> norm()
    nominal_leaf = X >> choice({'a':.1, 'b':.9})
    spe = SumSPE([real_leaf, nominal_leaf], [log(.4), log(.6)])

    # A real value: only the first child's term appears in the expectation.
    expected_real = log(.4) + spe.children[0].logpdf({X: .5})
    assert allclose(spe.logpdf({X: .5}), expected_real)

    # A nominal value: only the second child's term appears.
    expected_nominal = log(.6) + spe.children[1].logpdf({X: 'a'})
    assert allclose(spe.logpdf({X: 'a'}), expected_nominal)
def test_logpdf_mixture_real_continuous_continuous():
    """Check logpdf of a norm/gamma mixture as the log-sum-exp of weighted children."""
    spe = X >> (.3*norm() | .7*gamma(a=1))

    term_norm = log(.3) + spe.children[0].logpdf({X: 0.5})
    term_gamma = log(.7) + spe.children[1].logpdf({X: 0.5})

    assert allclose(
        spe.logpdf({X: .5}),
        logsumexp([term_norm, term_gamma]))
def test_sum_simplify_nested_sum_1():
    """Check that a SumSPE child of a SumSPE is flattened on construction."""
    X = Id('X')
    nested_sum = SumSPE(
        [X >> norm(loc=0, scale=1), X >> norm(loc=0, scale=2)],
        [log(0.4), log(0.6)])
    gamma_leaf = X >> gamma(loc=0, a=1)
    children = [nested_sum, gamma_leaf]

    spe = SumSPE(children, [log(0.7), log(0.3)])

    assert spe.size() == 4
    # The nested sum's two children are lifted next to the gamma leaf.
    assert spe.children == (
        children[0].children[0],
        children[0].children[1],
        children[1]
    )
    # Lifted weights are the outer weight plus the inner weight (log space).
    expected_weights = [
        log(0.7) + log(0.4),
        log(0.7) + log(0.6),
        log(0.3),
    ]
    for idx, expected in enumerate(expected_weights):
        assert allclose(spe.weights[idx], expected)
def test_logpdf_mixture_real_continuous_discrete():
    """Exercise logpdf on a mixture of a continuous and a discrete real leaf.

    NOTE(review): the final unconditional failure means this test can never
    pass as written; presumably it is marked as an expected failure by a
    decorator that is not visible in this view -- confirm.
    """
    spe = X >> (.3*norm() | .7*poisson(mu=1))

    term_continuous = log(.3) + spe.children[0].logpdf({X: 0.5})
    term_discrete = log(.7) + spe.children[1].logpdf({X: 0.5})
    assert allclose(
        spe.logpdf(X << {.5}),
        logsumexp([term_continuous, term_discrete]))

    assert False, 'Invalid base measure addition'
def test_if_else_transform():
    """Check per-branch envs and probabilities when IfElse branches transform Z."""
    program = Sequence(
        Sample(X, norm(loc=0, scale=1)),
        IfElse(
            X > 0, Transform(Z, X**2),
            Otherwise, Transform(Z, X)))
    model = program.interpret()

    branch_pos = model.children[0]
    branch_neg = model.children[1]

    # Each branch records its own definition of Z.
    assert branch_pos.env == {X: X, Z: X**2}
    assert branch_neg.env == {X: X, Z: X}

    # Z > 0 is certain in the first branch and impossible in the second,
    # so the overall probability is one half.
    assert allclose(branch_pos.logprob(Z > 0), 0)
    assert allclose(branch_neg.logprob(Z > 0), -float('inf'))
    assert allclose(model.logprob(Z > 0), -log(2))
def test_product_disjoint_union_numerical():
    """Check that clause probabilities of the disjoint union sum to the event's."""
    X = Id('X')
    Y = Id('Y')
    Z = Id('Z')
    spe = ProductSPE([
        X >> norm(loc=0, scale=1),
        Y >> norm(loc=0, scale=2),
        Z >> norm(loc=0, scale=2),
    ])
    events = [
        (1 / X < 4) | (X > 7),
        (2 * X - 3 > 0) | (Log(Y) < 3),
        ((X > 0) & (Y < 1)) | ((X < 1) & (Y < 3)) | ~(X << {1, 2}),
        ((X > 0) & (Y < 1)) | ((X < 1) & (Y < 3)) | (Z < 0),
        ((X > 0) & (Y < 1)) | ((X < 1) & (Y < 3)) | (Z < 0) | ~(X << {1, 3}),
    ]
    for event in events:
        disjoint = dnf_to_disjoint_union(event)
        clause_logps = [spe.logprob(clause) for clause in disjoint.subexprs]
        assert allclose(logsumexp(clause_logps), spe.logprob(event))
def test_transform_sum():
    """Check transform on mixtures: nominal components fail, real ones share an env."""
    X = Id('X')
    Z = Id('Z')
    Y = Id('Y')

    nominal_mixture = (
        0.3*(X >> norm(loc=0, scale=1))
        | 0.7*(X >> choice({'0': 0.4, '1': 0.6}))
    )
    with pytest.raises(Exception):
        # Cannot transform Nominal variate.
        nominal_mixture.transform(Z, X**2)

    spe = (
        0.3*(X >> norm(loc=0, scale=1))
        | 0.7*(X >> poisson(mu=2))
    )
    spe = spe.transform(Z, X**2)
    assert spe.logprob(Z < 1) == spe.logprob(X**2 < 1)
    assert spe.children[0].env == spe.children[1].env

    # A second transform, defined through Z, reaches both children too.
    spe = spe.transform(Y, Z/2)
    assert (
        spe.children[0].env
        == spe.children[1].env
        == {X:X, Z:X**2, Y:Z/2}
    )
def test_product_distribution_normal_gamma_basic():
    """Check flattening, symbols, size and the sampling APIs of a nested product."""
    X1 = Id('X1')
    X2 = Id('X2')
    X3 = Id('X3')
    X4 = Id('X4')
    nested_product = ProductSPE([
        X1 >> norm(loc=0, scale=1),
        X4 >> norm(loc=10, scale=1),
    ])
    children = [
        nested_product,
        X2 >> gamma(loc=0, a=1),
        X3 >> norm(loc=2, scale=3)
    ]
    spe = ProductSPE(children)

    # The nested product's children are lifted into the outer product.
    assert spe.children == (
        children[0].children[0],
        children[0].children[1],
        children[1],
        children[2],
    )
    assert spe.get_symbols() == frozenset([X1, X2, X3, X4])
    assert spe.size() == 5

    # Full samples assign every symbol.
    samples = spe.sample(2)
    assert len(samples) == 2
    for sample in samples:
        assert len(sample) == 4
        assert all(symbol in sample for symbol in (X1, X2, X3, X4))

    # Subset samples assign only the requested symbols.
    samples = spe.sample_subset((X1, X2), 10)
    assert len(samples) == 10
    for sample in samples:
        assert len(sample) == 2
        assert X1 in sample
        assert X2 in sample

    # sample_func returns whatever structure the function builds.
    samples = spe.sample_func(lambda X1, X2, X3: (X1, (X2**2, X3)), 1)
    assert len(samples) == 1
    assert len(samples[0]) == 2
    assert len(samples[0][1]) == 2

    # X5 is not a symbol of the product.
    with pytest.raises(ValueError):
        spe.sample_func(lambda X1, X5: X1 + X4, 1)
def test_simple_parse_real():
    """Check that a weighted mix of distributions applied to X builds a SumSPE."""
    assert isinstance(.3 * bernoulli(p=.1), DistributionMix)

    mixture = .3 * bernoulli(p=.1) | .5 * norm() | .2 * poisson(mu=7)
    spe = mixture(X)
    assert isinstance(spe, SumSPE)
    assert allclose(spe.weights, [log(.3), log(.5), log(.2)])

    # Leaf kinds, in the order the distributions were mixed.
    bernoulli_leaf, normal_leaf, poisson_leaf = (
        spe.children[0], spe.children[1], spe.children[2])
    assert isinstance(bernoulli_leaf, DiscreteLeaf)
    assert isinstance(normal_leaf, ContinuousLeaf)
    assert isinstance(poisson_leaf, DiscreteLeaf)

    # Supports of the three leaves.
    assert bernoulli_leaf.support == Interval(0, 1)
    assert normal_leaf.support == Interval(-oo, oo)
    assert poisson_leaf.support == Interval(0, oo)
def test_sum_simplify_product_collapse():
    """Check that a sum of three equal products simplifies to one product."""
    # Every leaf for a given symbol has the same parameters, so the three
    # products below are equal component by component.
    A1 = Id('A') >> norm(loc=0, scale=1)
    A0 = Id('A') >> norm(loc=0, scale=1)
    B = Id('B') >> norm(loc=0, scale=1)
    B1 = Id('B') >> norm(loc=0, scale=1)
    B0 = Id('B') >> norm(loc=0, scale=1)
    C = Id('C') >> norm(loc=0, scale=1)
    C1 = Id('C') >> norm(loc=0, scale=1)
    D = Id('D') >> norm(loc=0, scale=1)
    products = [
        ProductSPE([A1, B, C, D]),
        ProductSPE([A0, B1, C, D]),
        ProductSPE([A0, B0, C1, D]),
    ]
    spe = SumSPE(products, [log(0.4), log(0.4), log(0.2)])

    simplified = spe_simplify_sum(spe)
    assert simplified == ProductSPE([A1, B, C, D])
def test_sum_normal_nominal():
    """Check logprob and condition on a mixture of a normal and a nominal leaf."""
    X = Id('X')
    normal_leaf = X >> norm(loc=0, scale=1)
    nominal_leaf = X >> choice({
        'low': Fraction(3, 10),
        'high': Fraction(7, 10)
    })
    children = [normal_leaf, nominal_leaf]
    weights = [log(Fraction(4, 7)), log(Fraction(3, 7))]
    spe = SumSPE(children, weights)

    # Events that touch a single component.
    assert allclose(
        spe.logprob(X < 0),
        log(Fraction(4, 7)) + log(Fraction(1, 2)))
    assert allclose(
        spe.logprob(X << {'low'}),
        log(Fraction(3, 7)) + log(Fraction(3, 10)))

    # The semantics of ~(X<<{'low'}) are (X << String and X != 'low')
    assert allclose(
        spe.logprob(~(X << {'low'})),
        spe.logprob((X << {'high'})))
    assert allclose(
        spe.logprob((X << FiniteNominal(b=True)) & ~(X << {'low'})),
        spe.logprob((X << FiniteNominal(b=True)) & (X << {'high'})))

    # Conjunction / disjunction of a real and a nominal constraint.
    assert isinf_neg(spe.logprob((X < 0) & (X << {'low'})))
    assert allclose(
        spe.logprob((X < 0) | (X << {'low'})),
        logsumexp([spe.logprob(X < 0), spe.logprob(X << {'low'})]))

    # A nominal value outside the support.
    assert isinf_neg(spe.logprob(X << {'a'}))
    assert allclose(
        spe.logprob(~(X << {'a'})),
        spe.logprob(X << {'low', 'high'}))

    assert allclose(
        spe.logprob(X**2 < 9),
        log(Fraction(4, 7)) + spe.children[0].logprob(X**2 < 9))

    # Conditioning on a purely real event leaves only the continuous leaf.
    spe_condition = spe.condition(X**2 < 9)
    assert isinstance(spe_condition, ContinuousLeaf)
    assert spe_condition.support == Interval.open(-3, 3)

    # Conditioning on a real-or-nominal event keeps both components.
    spe_condition = spe.condition((X**2 < 9) | X << {'low'})
    assert isinstance(spe_condition, SumSPE)
    conditioned_real = spe_condition.children[0]
    conditioned_nominal = spe_condition.children[1]
    assert conditioned_real.support == Interval.open(-3, 3)
    assert conditioned_nominal.support == FiniteNominal('low', 'high')
    assert isinf_neg(conditioned_nominal.logprob(X << {'high'}))

    assert spe_condition == spe.condition((X**2 < 9) | ~(X << {'high'}))
    assert allclose(spe.logprob((X < oo) | ~(X << {'1'})), 0)
def test_product_condition_or_probabilithy_zero():
    """Check conditioning a product when some or all clauses have probability zero.

    The product is X ~ norm(loc=0, scale=1) and Y ~ gamma(a=1).
    """
    X = Id('X')
    Y = Id('Y')
    spe = ProductSPE([X >> norm(loc=0, scale=1), Y >> gamma(a=1)])

    # Condition on event which has probability zero.
    event = (X > 2) & (X < 2)
    with pytest.raises(ValueError):
        spe.condition(event)
    assert spe.logprob(event) == -float('inf')

    # Condition on event which has probability zero.
    event = (Y < 0) | (Y < -1)
    with pytest.raises(ValueError):
        spe.condition(event)
    assert spe.logprob(event) == -float('inf')

    # Condition on an event where one clause has probability
    # zero, yielding a single product.
    spe_condition = spe.condition((Y < 0) | ((Log(X) >= 0) & (1 <= Y)))
    assert isinstance(spe_condition, ProductSPE)
    assert spe_condition.children[0].symbol == X
    assert spe_condition.children[0].conditioned
    assert spe_condition.children[0].support == Interval(1, oo)
    assert spe_condition.children[1].symbol == Y
    assert spe_condition.children[1].conditioned
    # The surviving clause constrains Y to 1 <= Y, so the support of the
    # Y child (index 1) is checked here; the original repeated the
    # children[0] assertion and never checked Y's support.
    assert spe_condition.children[1].support == Interval(1, oo)

    # We have (X < 2) & ~(1 < exp(|3X**2|)) is empty.
    # Thus Y remains unconditioned,
    # and X is partitioned into (-oo, 0) U (0, oo) with equal weight.
    event = (Exp(abs(3 * X**2)) > 1) | ((Log(Y) < 0.5) & (X < 2))
    spe_condition = spe.condition(event)
    #
    # The most concise representation of spe_condition is:
    #   (Product (Sum [.5 .5] X|X<0 X|X>0) Y)
    assert isinstance(spe_condition, ProductSPE)
    assert isinstance(spe_condition.children[0], SumSPE)
    assert spe_condition.children[0].weights == (-log(2), -log(2))
    assert spe_condition.children[0].children[0].conditioned
    assert spe_condition.children[0].children[1].conditioned
    # The two halves may come in either order, but must be different.
    assert spe_condition.children[0].children[0].support \
        in [Interval.Ropen(-oo, 0), Interval.Lopen(0, oo)]
    assert spe_condition.children[0].children[1].support \
        in [Interval.Ropen(-oo, 0), Interval.Lopen(0, oo)]
    assert spe_condition.children[0].children[0].support \
        != spe_condition.children[0].children[1].support
    assert spe_condition.children[1].symbol == Y
    assert not spe_condition.children[1].conditioned
def test_transform_product():
    """Check transforms on a product of a normal (X) and a Poisson (Y) leaf.

    A transform whose expression uses symbols from different children is
    expected to raise; transforms confined to one child are chained and
    then queried through logprob.
    """
    X = Id('X')
    Y = Id('Y')
    W = Id('W')
    Z = Id('Z')
    V = Id('V')
    spe \
        = (X >> norm(loc=0, scale=1)) \
        & (Y >> poisson(mu=10))
    with pytest.raises(Exception):
        # Cannot use symbols from different transforms.
        spe.transform(W, (X > 0) | (Y << {'0'}))
    spe = spe.transform(W, (X**2 - 3*X)**(1,10))
    spe = spe.transform(Z, (W > 0) | (X**3 < 1))
    spe = spe.transform(V, Y/10)
    assert allclose(
        spe.logprob(W>1),
        spe.logprob((X**2 - 3*X)**(1,10) > 1))
    with pytest.raises(Exception):
        # V derives from Y and W derives from X, i.e. different children.
        # The method name was misspelled as `tarnsform`, so this block used
        # to pass on an AttributeError without ever calling transform.
        spe.transform(Id('R'), (V>1) | (W < 0))
def test_logpdf_lexicographic_either():
    """Check logpdf/constrain for assignments to which one branch contributes."""
    branch_1 = X >> norm() & Y >> atomic(loc=0) & Z >> discrete({1:.1, 2:.9})
    branch_2 = X >> atomic(loc=0) & Y >> norm() & Z >> norm()
    spe = .75*(branch_1) | .25*(branch_2)

    # Lexicographic, Branch 1
    assignment = {X:0, Y:0, Z:2}
    expected_branch_1 = log(.75) + norm().dist.logpdf(0) + log(1) + log(.9)
    assert allclose(spe.logpdf(assignment), expected_branch_1)
    assert isinstance(spe.constrain(assignment), ProductSPE)

    # Lexicographic, Branch 2
    assignment = {X:0, Y:0, Z:0}
    expected_branch_2 = (
        log(.25) + log(1) + norm().dist.logpdf(0) + norm().dist.logpdf(0))
    assert allclose(spe.logpdf(assignment), expected_branch_2)
    assert isinstance(spe.constrain(assignment), ProductSPE)
def test_sum_normal_gamma_exposed():
    """Check logprob and condition on an ExposedSumSPE indexed by nominal W.

    W selects between a normal ('0', weight 2/3) and a gamma ('1',
    weight 1/3) leaf over X.
    """
    X = Id('X')
    W = Id('W')
    weights = W >> choice({
        '0': Fraction(2, 3),
        '1': Fraction(1, 3),
    })
    children = {
        '0': X >> norm(loc=0, scale=1),
        '1': X >> gamma(loc=0, a=1),
    }
    spe = ExposedSumSPE(children, weights)

    # Marginal probabilities of the exposed selector.
    assert spe.logprob(W << {'0'}) == log(Fraction(2, 3))
    assert spe.logprob(W << {'1'}) == log(Fraction(1, 3))
    assert allclose(spe.logprob((W << {'0'}) | (W << {'1'})), 0)
    assert spe.logprob((W << {'0'}) & (W << {'1'})) == -float('inf')

    # Joint queries over the selector and X.
    assert allclose(
        spe.logprob((W << {'0', '1'}) & (X < 1)),
        spe.logprob(X < 1))
    assert allclose(
        spe.logprob((W << {'0'}) & (X < 1)),
        spe.weights[0] + spe.children[0].logprob(X < 1))

    # Conditioning on both selector values keeps both components.
    spe_condition = spe.condition((W << {'1'}) | (W << {'0'}))
    assert isinstance(spe_condition, SumSPE)
    assert len(spe_condition.weights) == 2
    # The two weights must be log(2/3) and log(1/3), in either order.
    # The original expression compared weights[0] with log(2/3) in every
    # term and never checked the log(1/3) weight.
    log_two_thirds = log(Fraction(2, 3))
    log_one_third = log(Fraction(1, 3))
    first, second = spe_condition.weights[0], spe_condition.weights[1]
    assert (
        (allclose(first, log_two_thirds) and allclose(second, log_one_third))
        or
        (allclose(second, log_two_thirds) and allclose(first, log_one_third))
    )

    # Conditioning on a single selector value yields a product of the
    # selector leaf and the selected X leaf.
    spe_condition = spe.condition((W << {'1'}))
    assert isinstance(spe_condition, ProductSPE)
    assert isinstance(spe_condition.children[0], NominalLeaf)
    assert isinstance(spe_condition.children[1], ContinuousLeaf)
    assert spe_condition.logprob(X < 5) == spe.children[1].logprob(X < 5)
def test_product_inclusion_exclusion_basic():
    """Check probability identities on a product of X ~ norm and Y ~ gamma.

    A denotes the event X > 0.1 and B the event Y < 0.5.
    """
    X = Id('X')
    Y = Id('Y')
    spe = ProductSPE([X >> norm(loc=0, scale=1), Y >> gamma(a=1)])

    lp_a = spe.logprob(X > 0.1)
    lp_b = spe.logprob(Y < 0.5)
    lp_a_and_b = spe.logprob((X > 0.1) & (Y < 0.5))
    lp_a_or_b = spe.logprob((X > 0.1) | (Y < 0.5))
    lp_a_or_b_not_a = spe.logprob((X > 0.1) | ((Y < 0.5) & ~(X > 0.1)))
    lp_not_a = spe.logprob(~(X > 0.1))
    lp_b_and_not_a = spe.logprob((Y < 0.5) & ~(X > 0.1))

    # Single-symbol events agree with the corresponding child.
    assert allclose(lp_a, spe.children[0].logprob(X > 0.1))
    assert allclose(lp_b, spe.children[1].logprob(Y < 0.5))

    # Pr[A and B] = Pr[A] * Pr[B]
    assert allclose(lp_a_and_b, lp_a + lp_b)
    # Pr[A or B] = Pr[A] + Pr[B] - Pr[AB]
    assert allclose(lp_a_or_b, logdiffexp(logsumexp([lp_a, lp_b]), lp_a_and_b))
    # Pr[A or B] = Pr[A] + Pr[B & ~A]
    assert allclose(lp_a_or_b_not_a, lp_a_or_b)
    # Pr[B and ~A] = Pr[B] * Pr[~A]
    assert allclose(lp_b_and_not_a, lp_b + lp_not_a)
    # Pr[A or (B & ~A)] = Pr[A] + Pr[B & ~A]
    assert allclose(lp_a_or_b_not_a, logsumexp([lp_a, lp_b + lp_not_a]))

    # (A => B) => Pr[A or B] = Pr[B]
    # i.e., (X > 1) => (X > 0).
    assert allclose(spe.logprob((X > 0) | (X > 1)), spe.logprob(X > 0))

    # Positive probability event.
    # Pr[A] = 1 - Pr[~A]
    event = ((0 < X) < 0.5) | ((Y < 0) & (1 < X))
    assert allclose(
        spe.logprob(event),
        logdiffexp(0, spe.logprob(~event)))

    # Probability zero event.
    event = ((0 < X) < 0.5) & ((Y < 0) | (1 < X))
    assert isinf_neg(spe.logprob(event))
    assert allclose(spe.logprob(~event), 0)
def test_sum_simplify_leaf():
    """Check spe_simplify_sum on sums of leaves with and without duplicates."""
    # Three distinct leaves: nothing to merge.
    Xd0 = Id('X') >> norm(loc=0, scale=1)
    Xd1 = Id('X') >> norm(loc=0, scale=2)
    Xd2 = Id('X') >> norm(loc=0, scale=3)
    distinct_weights = [log(0.5), log(0.1), log(.4)]
    spe = SumSPE([Xd0, Xd1, Xd2], distinct_weights)
    assert spe.size() == 4
    assert spe_simplify_sum(spe) == spe

    # Three equal leaves: the sum collapses to a single leaf.
    Xd0 = Id('X') >> norm(loc=0, scale=1)
    Xd1 = Id('X') >> norm(loc=0, scale=1)
    Xd2 = Id('X') >> norm(loc=0, scale=1)
    spe = SumSPE([Xd0, Xd1, Xd2], [log(0.5), log(0.1), log(.4)])
    assert spe_simplify_sum(spe) == Xd0

    # Two groups of equal leaves: weights are pooled within each group.
    Xd3 = Id('X') >> norm(loc=0, scale=2)
    mixed_leaves = [Xd0, Xd3, Xd1, Xd3]
    mixed_weights = [log(0.5), log(0.1), log(.3), log(.1)]
    spe = SumSPE(mixed_leaves, mixed_weights)
    simplified = spe_simplify_sum(spe)
    assert len(simplified.children) == 2
    assert simplified.children[0] == Xd0
    assert simplified.children[1] == Xd3
    assert allclose(simplified.weights[0], log(0.8))
    assert allclose(simplified.weights[1], log(0.2))
# NOTE(review): the seven statements below are the tail of a definition whose
# header lies outside this view. They are reproduced unchanged and assumed to
# sit one level inside that definition -- confirm the enclosing indentation.
    event_2 = condition_2 & ~condition_1
    event_3 = condition_3 & ~condition_2 & ~condition_1
    event_4 = condition_4 & ~condition_3 & ~condition_2 & ~condition_1
    assert allclose(student.prob(event_1), 0.5 * 0.99)
    assert allclose(student.prob(event_2), 0.5 * 0.01)
    assert allclose(student.prob(event_3), 0.5 * 0.99)
    assert allclose(student.prob(event_4), 0.5 * 0.01)

# Module-level fixture: `&`-combination of four norm(loc=0, scale=1)
# distributions, one per symbol A, B, C, D.
A = Id('A')
B = Id('B')
C = Id('C')
D = Id('D')
spe_abcd \
    = norm(loc=0, scale=1)(A) \
    & norm(loc=0, scale=1)(B) \
    & norm(loc=0, scale=1)(C) \
    & norm(loc=0, scale=1)(D)

def test_product_condition_simplify_a():
    """Check conditioning spe_abcd on an event that only involves A.

    Asserts that the result is a ProductSPE that still contains children
    1-3 of spe_abcd and has exactly one SumSPE child whose first two
    weights are both -log(2).
    """
    spe = spe_abcd.condition((A > 1) | (A < -1))
    assert isinstance(spe, ProductSPE)
    # The children at indices 1, 2, 3 of the fixture appear in the result.
    assert spe_abcd.children[1] in spe.children
    assert spe_abcd.children[2] in spe.children
    assert spe_abcd.children[3] in spe.children
    # Locate the sum by type rather than by position.
    idx_sum = [i for i, c in enumerate(spe.children) if isinstance(c, SumSPE)]
    assert len(idx_sum) == 1
    assert allclose(spe.children[idx_sum[0]].weights[0], -log(2))
    assert allclose(spe.children[idx_sum[0]].weights[1], -log(2))
def test_product_condition_basic():
    """Check conditioning a product of X ~ norm and Y ~ gamma on assorted events."""
    X = Id('X')
    Y = Id('Y')
    spe = ProductSPE([X >> norm(loc=0, scale=1), Y >> gamma(a=1)])

    # Condition on (X > 0) and ((X > 0) | (Y < 0))
    # where the second clause reduces to first as Y < 0
    # has probability zero.
    for event in [(X > 0), (X > 0) | (Y < 0)]:
        cond_x = spe.condition(event)
        assert isinstance(cond_x, ProductSPE)
        leaf_x, leaf_y = cond_x.children[0], cond_x.children[1]
        assert leaf_x.symbol == Id('X')
        assert leaf_x.conditioned
        assert leaf_x.support == Interval.open(0, oo)
        assert leaf_y.symbol == Id('Y')
        assert not leaf_y.conditioned
        assert leaf_y.Fl == 0
        assert leaf_y.Fu == 1

    # Condition on (Y < 0.5)
    cond_y = spe.condition(Y < 0.5)
    assert isinstance(cond_y, ProductSPE)
    assert cond_y.children[0].symbol == Id('X')
    assert not cond_y.children[0].conditioned
    assert cond_y.children[1].symbol == Id('Y')
    assert cond_y.children[1].conditioned
    assert cond_y.children[1].support == Interval.Ropen(0, 0.5)

    # Condition on (X > 0) & (Y < 0.5)
    cond_and = spe.condition((X > 0) & (Y < 0.5))
    assert isinstance(cond_and, ProductSPE)
    assert cond_and.children[0].symbol == Id('X')
    assert cond_and.children[0].conditioned
    assert cond_and.children[0].support == Interval.open(0, oo)
    assert cond_and.children[1].symbol == Id('Y')
    assert cond_and.children[1].conditioned
    assert cond_and.children[1].support == Interval.Ropen(0, 0.5)

    # Condition on (X > 0) | (Y < 0.5)
    event = (X > 0) | (Y < 0.5)
    cond_or = spe.condition((X > 0) | (Y < 0.5))
    assert isinstance(cond_or, SumSPE)
    assert all(isinstance(d, ProductSPE) for d in cond_or.children)
    assert allclose(cond_or.logprob(X > 0), cond_or.weights[0])
    samples = cond_or.sample(100, prng=numpy.random.RandomState(1))
    assert all(event.evaluate(sample) for sample in samples)

    # Condition on a disjoint union with one term in second clause.
    disjoint_one = spe.condition((X > 0) & (Y < 0.5) | (X <= 0))
    assert isinstance(disjoint_one, SumSPE)
    component_0 = disjoint_one.children[0]
    assert component_0.children[0].symbol == Id('X')
    assert component_0.children[0].conditioned
    assert component_0.children[0].support == Interval.open(0, oo)
    assert component_0.children[1].symbol == Id('Y')
    assert component_0.children[1].conditioned
    assert component_0.children[1].support == Interval.Ropen(0, 0.5)
    component_1 = disjoint_one.children[1]
    assert component_1.children[0].symbol == Id('X')
    assert component_1.children[0].conditioned
    assert component_1.children[0].support == Interval(-oo, 0)
    assert component_1.children[1].symbol == Id('Y')
    assert not component_1.children[1].conditioned

    # Condition on a disjoint union with two terms in each clause
    disjoint_two = spe.condition(
        (X > 0) & (Y < 0.5) | ((X <= 0) & ~(Y < 3)))
    assert isinstance(disjoint_two, SumSPE)
    component_0 = disjoint_two.children[0]
    assert component_0.children[0].symbol == Id('X')
    assert component_0.children[0].conditioned
    assert component_0.children[0].support == Interval.open(0, oo)
    assert component_0.children[1].symbol == Id('Y')
    assert component_0.children[1].conditioned
    assert component_0.children[1].support == Interval.Ropen(0, 0.5)
    component_1 = disjoint_two.children[1]
    assert component_1.children[0].symbol == Id('X')
    assert component_1.children[0].conditioned
    assert component_1.children[0].support == Interval(-oo, 0)
    assert component_1.children[1].symbol == Id('Y')
    assert component_1.children[1].conditioned
    assert component_1.children[1].support == Interval(3, oo)

    # Some various conditioning.
    # These calls only need to complete; their results are not inspected.
    spe.condition((X > 0) & (Y < 0.5) | ((X <= 1) | ~(Y < 3)))
    spe.condition((X > 0) & (Y < 0.5) | ((X <= 1) & (Y < 3)))
def test_cache_simple_leaf():
    """Equal leaves in a mixture become one shared object after caching."""
    first_leaf = W >> norm(loc=0, scale=1)
    second_leaf = W >> norm(loc=0, scale=1)
    spe = .5 * (first_leaf) | .5 * (second_leaf)
    # Before caching the two children are distinct objects.
    assert spe.children[0] is not spe.children[1]

    spe_cached = spe_cache_duplicate_subtrees(spe, {})
    cached_first, cached_second = spe_cached.children[0], spe_cached.children[1]
    assert cached_first is cached_second