def test_event_to_disjoint_union_five():
    """Smoke-test dnf_to_disjoint_union on a union of five rectangles.

    The result is not inspected; the test passes as long as the
    conversion completes without raising.
    """
    # The five input rectangles can be drawn with matplotlib:
    #   fig, ax = plt.subplots()
    #   ax.add_patch(Rectangle((2, 5), 6, 3, fill=False))
    #   ax.add_patch(Rectangle((3, 7), 1, 4, fill=False))
    #   ax.add_patch(Rectangle((3.5, 2), 2, 7, fill=False))
    #   ax.add_patch(Rectangle((4.5, 1), 2, 5, fill=False))
    #   ax.add_patch(Rectangle((5, 7), 2, 3, fill=False))
    #
    # Clauses recorded with this test (presumably the disjoint pieces
    # obtained for the input above):
    #   ((2 < X < 8) & (5 < Y < 8))
    #   ((3 < X < 4) & (8 <= Y < 11))
    #   ((3.5 < X < 5.5) & (2 < Y <= 5))
    #   ((5.5 <= X < 6.5) & (1 < Y <= 5))
    #   ((4.5 < X < 5.5) & (1 < Y <= 2))
    #   ((5 < X < 7) & (8 <= Y < 10))
    #
    # and the drawing of those pieces:
    #   fig, ax = plt.subplots()
    #   ax.add_patch(Rectangle((2, 5), 6, 3, fill=False))
    #   ax.add_patch(Rectangle((3, 8), 1, 3, fill=False))
    #   ax.add_patch(Rectangle((3.5, 2), 2, 3, fill=False))
    #   ax.add_patch(Rectangle((5.5, 1), 1, 4, fill=False))
    #   ax.add_patch(Rectangle((4.5, 1), 1, 1, fill=False))
    #   ax.add_patch(Rectangle((5, 8), 2, 2, fill=False))
    X = Id('X')
    Y = Id('Y')
    box_a = ((2 < X) < 8) & ((5 < Y) < 8)
    box_b = ((3 < X) < 4) & ((7 < Y) < 11)
    box_c = ((3.5 < X) < 5.5) & ((2 < Y) < 7)
    box_d = ((4.5 < X) < 6.5) & ((1 < Y) < 6)
    box_e = ((5 < X) < 7) & ((7 < Y) < 10)
    union = box_a | box_b | box_c | box_d | box_e
    dnf_to_disjoint_union(union)
def test_dnf_non_disjoint_clauses():
    """Check which DNF clauses dnf_non_disjoint_clauses reports as overlapping.

    The expected values map a clause index to the list of earlier clause
    indexes it overlaps with; an empty (falsy) result means all clauses
    are pairwise disjoint.
    """
    X = Id('X')
    Y = Id('Y')
    Z = Id('Z')

    # Clauses over different symbols overlap.
    event = (X > 0) | (Y < 0)
    overlaps = dnf_non_disjoint_clauses(event)
    assert overlaps == {1: [0]}

    # X > 0 and X < 0 cannot hold together.
    event = (X > 0) | ((X < 0) & (Y < 0))
    overlaps = dnf_non_disjoint_clauses(event)
    assert not overlaps

    # Only the third clause (X > 1) overlaps the first (X > 0, Z < 0).
    event = ((X > 0) & (Z < 0)) | ((X < 0) & (Y < 0)) | ((X > 1))
    overlaps = dnf_non_disjoint_clauses(event)
    assert overlaps == {2: [0]}

    # Adding Z > 1 to the third clause separates it from Z < 0.
    event = ((X > 0) & (Z < 0)) | ((X < 0) & (Y < 0)) | ((X > 1) & (Z > 1))
    overlaps = dnf_non_disjoint_clauses(event)
    assert not overlaps

    # X**2 < 9 is -3 < X < 3, which overlaps 1 < X.
    event = ((X**2 < 9)) | (1 < X)
    overlaps = dnf_non_disjoint_clauses(event)
    assert overlaps == {1: [0]}

    # The two-sided bound is spelled ((0 < X) < 1), matching the other
    # tests.  Python expands the chained form (0 < X < 1) into
    # `(0 < X) and (X < 1)`, which does not build the interval 0 < X < 1.
    event = ((X**2 < 9) & ((0 < X) < 1)) | (1 < X)
    overlaps = dnf_non_disjoint_clauses(event)
    assert not overlaps
def test_mutual_information_four_clusters(memo):
    """Mutual information of threshold events on a four-component mixture."""
    X = Id('X')
    Y = Id('Y')
    # Equal-weight mixture of Normal components located at
    # (0, 0), (5, 0), (0, 5) and (5, 5).
    spe = (
        0.25*(X >> norm(loc=0, scale=0.5) & Y >> norm(loc=0, scale=0.5))
        | 0.25*(X >> norm(loc=5, scale=0.5) & Y >> norm(loc=0, scale=0.5))
        | 0.25*(X >> norm(loc=0, scale=0.5) & Y >> norm(loc=5, scale=0.5))
        | 0.25*(X >> norm(loc=5, scale=0.5) & Y >> norm(loc=5, scale=0.5))
    )

    x_high = X > 2
    y_high = Y > 2
    # The draw itself is not inspected; the call is kept so that the test
    # performs the same sampling calls on prng as before.
    spe.sample(100, prng)
    info = spe.mutual_information(x_high, y_high, memo=memo)
    assert allclose(info, 0)
    check_mi_properties(spe, x_high, y_high, memo)

    # Restrict to the two off-diagonal quadrants.
    off_diagonal = ((X>2) & (Y<2) | ((X<2) & (Y>2)))
    conditioned = spe.condition(off_diagonal)
    draws = conditioned.sample(100, prng)
    assert all(off_diagonal.evaluate(draw) for draw in draws)
    info = conditioned.mutual_information(X > 2, Y > 2)
    assert allclose(info, log(2))

    check_mi_properties(spe, (X>1) | (Y<1), (Y>2), memo)
    check_mi_properties(spe, (X>1) | (Y<1), (X>1.5) & (Y>2), memo)
def test_sum_simplify_nested_sum_2():
    """A SumSPE built from nested sums of products exposes flattened children."""
    X = Id('X')
    W = Id('W')

    def joint(scale_x, scale_w):
        # Product of zero-mean Normal leaves for X and W.
        return (X >> norm(loc=0, scale=scale_x)) & (W >> norm(loc=0, scale=scale_w))

    first_sum = SumSPE(
        [joint(1, 2), joint(2, 1)],
        [log(0.9), log(0.1)])
    middle_product = joint(4, 10)
    last_sum = SumSPE(
        [joint(1, 2), joint(2, 1), joint(8, 3)],
        [log(0.4), log(0.3), log(0.3)])
    children = [first_sum, middle_product, last_sum]

    spe = SumSPE(children, [log(0.4), log(0.4), log(0.2)])
    assert spe.size() == 19
    # Each entry below is a product of two leaves.
    assert spe.children == (
        children[0].children[0],
        children[0].children[1],
        children[1],
        children[2].children[0],
        children[2].children[1],
        children[2].children[2],
    )

    # Outer weight plus inner weight (in log space) for every flattened child.
    expected_weights = [
        log(0.4) + log(0.9),
        log(0.4) + log(0.1),
        log(0.4),
        log(0.2) + log(0.4),
        log(0.2) + log(0.3),
        log(0.2) + log(0.3),
    ]
    for index, expected in enumerate(expected_weights):
        assert allclose(spe.weights[index], expected)
def test_transform_real_leaf_logprob():
    """logprob on a transformed real leaf agrees with the substituted event."""
    X = Id('X')
    Y = Id('Y')
    Z = Id('Z')
    leaf = (X >> norm(loc=0, scale=1))

    # Rejected: the expression uses Y, which the leaf does not define.
    with pytest.raises(AssertionError):
        leaf.transform(Z, Y**2)
    # Rejected: X is already defined by the leaf.
    with pytest.raises(AssertionError):
        leaf.transform(X, X**2)

    with_z = leaf.transform(Z, X**2)
    assert with_z.env == {X:X, Z:X**2}
    assert with_z.get_symbols() == {X, Z}
    assert with_z.logprob(Z < 1) == with_z.logprob(X**2 < 1)
    lhs = with_z.logprob((Z < 1) | ((X + 1) < 3))
    rhs = with_z.logprob((X**2 < 1) | ((X+1) < 3))
    assert lhs == rhs

    with_y = with_z.transform(Y, 2*Z)
    assert with_y.env == {X:X, Z:X**2, Y:2*Z}
    via_y = with_y.logprob(Y**(1,3) < 10)
    via_z = with_y.logprob((2*Z)**(1,3) < 10)
    via_x = with_y.logprob((2*(X**2))**(1,3) < 10)
    assert via_y == via_z == via_x

    # A transform may also bind a symbol to an event.
    W = Id('W')
    with_w = with_y.transform(W, X > 1)
    assert allclose(with_w.logprob(W), with_w.logprob(X > 1))
def test_event_to_disjoint_union_nominal():
    """dnf_to_disjoint_union on events over nominal (string) values."""
    X = Id('X')
    Y = Id('Y')

    # The first clause is contained in the second.
    nested = (X << {'1'}) | (X << {'1', '2'})
    result = dnf_to_disjoint_union(nested)
    assert result == X << {'1', '2'}

    # Clauses over different symbols: the second clause is intersected
    # with the negation of the first.
    mixed = (X << {'1'}) | ~(Y << {'1'})
    expected = EventOr([
        (X << {'1'}),
        ~(Y << {'1'}) & ~(X << {'1'})
    ])
    result = dnf_to_disjoint_union(mixed)
    assert result == expected
def test_transform_real_leaf_sample():
    """Samples from a transformed leaf respect the transform definitions."""
    X = Id('X')
    Z = Id('Z')
    Y = Id('Y')
    # Z = X + 1 and Y = Z - 1, so Y equals X.
    spe = (X >> poisson(loc=-1, mu=1)).transform(Z, X+1).transform(Y, Z-1)

    draws = spe.sample(100)
    assert any(draw[X] == -1 for draw in draws)
    assert all(draw[Z] >= 0 for draw in draws)
    assert all(draw[Y] == draw[X] for draw in draws)

    def identity_holds(X, Y, Z):
        # Parameter names are kept as X, Y, Z, exactly as in the lambda
        # this replaces.
        return X-Y+Z==Z

    assert all(spe.sample_func(identity_holds, 100))
    subset_draws = spe.sample_subset([X, Y], 100)
    assert all(set(draw) == {X,Y} for draw in subset_draws)
def test_event_to_disjiont_union_numerical():
    """dnf_to_disjoint_union leaves no overlapping clauses for numeric events."""
    X = Id('X')
    Y = Id('Y')
    Z = Id('Z')
    cases = (
        (X > 0) | (X < 3),
        (X > 0) | (Y < 3),
        ((X > 0) & (Y < 1)) | ((X < 1) & (Y < 3)) | (Z < 0),
        ((X > 0) & (Y < 1)) | ((X < 1) & (Y < 3)) | (Z < 0) | ~(X << {1, 3}),
    )
    for event in cases:
        disjoint = dnf_to_disjoint_union(event)
        # An empty overlap report means every clause pair is disjoint.
        assert not dnf_non_disjoint_clauses(disjoint)
def test_sum_simplify_product_complex():
    """spe_simplify_sum factors shared leaves out of a sum of products."""

    def gauss(name, scale):
        # Zero-mean Normal leaf over a freshly created symbol.
        return Id(name) >> norm(loc=0, scale=scale)

    A1 = gauss('A', 1)
    A0 = gauss('A', 2)
    B = gauss('B', 1)
    B1 = gauss('B', 2)
    B0 = gauss('B', 3)
    C = gauss('C', 1)
    C1 = gauss('C', 2)
    D = gauss('D', 1)
    spe = SumSPE([
        ProductSPE([A1, B, C, D]),
        ProductSPE([A0, B1, C, D]),
        ProductSPE([A0, B0, C1, D]),
    ], [log(0.4), log(0.4), log(0.2)])

    simplified = spe_simplify_sum(spe)
    # D is common to all three products and is pulled to the top level.
    assert isinstance(simplified, ProductSPE)
    assert isinstance(simplified.children[0], SumSPE)
    assert simplified.children[1] == D

    outer_sum = simplified.children[0]
    assert isinstance(outer_sum.children[1], ProductSPE)
    assert outer_sum.children[1].children == (A0, B0, C1)
    # C is common to the first two products and is factored one level down.
    assert isinstance(outer_sum.children[0], ProductSPE)
    assert outer_sum.children[0].children[1] == C

    inner_sum = outer_sum.children[0].children[0]
    assert isinstance(inner_sum, SumSPE)
    assert isinstance(inner_sum.children[0], ProductSPE)
    assert isinstance(inner_sum.children[1], ProductSPE)
    assert inner_sum.children[0].children == (A1, B)
    assert inner_sum.children[1].children == (A0, B1)
def test_product_condition_or_probabilithy_zero():
    """Condition a product on events that contain probability-zero clauses.

    Covers events whose total probability is zero (conditioning raises
    ValueError) and disjunctions where only some clauses have
    probability zero.
    """
    X = Id('X')
    Y = Id('Y')
    spe = ProductSPE([X >> norm(loc=0, scale=1), Y >> gamma(a=1)])

    # Condition on event which has probability zero.
    event = (X > 2) & (X < 2)
    with pytest.raises(ValueError):
        spe.condition(event)
    assert spe.logprob(event) == -float('inf')

    # Condition on event which has probability zero.
    event = (Y < 0) | (Y < -1)
    with pytest.raises(ValueError):
        spe.condition(event)
    assert spe.logprob(event) == -float('inf')

    # Condition on an event where one clause has probability
    # zero, yielding a single product.
    spe_condition = spe.condition((Y < 0) | ((Log(X) >= 0) & (1 <= Y)))
    assert isinstance(spe_condition, ProductSPE)
    assert spe_condition.children[0].symbol == X
    assert spe_condition.children[0].conditioned
    assert spe_condition.children[0].support == Interval(1, oo)
    assert spe_condition.children[1].symbol == Y
    assert spe_condition.children[1].conditioned
    # Check the Y leaf here (index 1); this line previously repeated the
    # X-leaf check, leaving the support of Y under (1 <= Y) unverified.
    assert spe_condition.children[1].support == Interval(1, oo)

    # We have (X < 2) & ~(1 < exp(|3X**2|)) is empty.
    # Thus Y remains unconditioned,
    # and X is partitioned into (-oo, 0) U (0, oo) with equal weight.
    event = (Exp(abs(3 * X**2)) > 1) | ((Log(Y) < 0.5) & (X < 2))
    spe_condition = spe.condition(event)
    #
    # The most concise representation of spe_condition is:
    #   (Product (Sum [.5 .5] X|X<0 X|X>0) Y)
    assert isinstance(spe_condition, ProductSPE)
    assert isinstance(spe_condition.children[0], SumSPE)
    assert spe_condition.children[0].weights == (-log(2), -log(2))
    assert spe_condition.children[0].children[0].conditioned
    assert spe_condition.children[0].children[1].conditioned
    # The order of the two half-lines is not fixed, so accept either
    # order but require the two supports to differ.
    assert spe_condition.children[0].children[0].support \
        in [Interval.Ropen(-oo, 0), Interval.Lopen(0, oo)]
    assert spe_condition.children[0].children[1].support \
        in [Interval.Ropen(-oo, 0), Interval.Lopen(0, oo)]
    assert spe_condition.children[0].children[0].support \
        != spe_condition.children[0].children[1].support
    assert spe_condition.children[1].symbol == Y
    assert not spe_condition.children[1].conditioned
def test_substitute_transitive(case):
    """Substitution through an intermediate symbol gives the direct result.

    `case` is a triple (expr, env, expr_prime); only environments with
    exactly one binding are exercised.
    """
    expr, env, expr_prime = case
    if len(env) != 1:
        return
    ((source, target),) = env.items()
    # Route source -> bridge -> target instead of source -> target.
    bridge = Id('s2')
    chained_env = {source: bridge, bridge: target}
    assert expr.substitute(chained_env) == expr_prime
def test_sum_normal_gamma():
    """logprob, sampling and conditioning of a Normal/Gamma mixture over X."""
    X = Id('X')
    weights = [log(Fraction(2, 3)), log(Fraction(1, 3))]
    components = [
        X >> norm(loc=0, scale=1),
        X >> gamma(loc=0, a=1),
    ]
    spe = SumSPE(components, weights)

    def weighted_terms(event):
        # Per-component log weight plus component logprob, in child order.
        return [
            spe.weights[0] + spe.children[0].logprob(event),
            spe.weights[1] + spe.children[1].logprob(event),
        ]

    assert spe.logprob(X > 0) == logsumexp(weighted_terms(X > 0))
    assert spe.logprob(X < 0) == log(Fraction(2, 3)) + log(Fraction(1, 2))

    draws = spe.sample(100, prng=numpy.random.RandomState(1))
    assert all(draw[X] for draw in draws)

    def cube_magnitude(X):
        return abs(X**3)

    def wrong_argument(Y):
        # Y is not a symbol of the mixture; X refers to the enclosing Id.
        return abs(X**3)

    spe.sample_func(cube_magnitude, 100)
    with pytest.raises(ValueError):
        spe.sample_func(wrong_argument, 100)

    # Conditioning on X < 0 yields a single conditioned continuous leaf.
    negative = spe.condition(X < 0)
    assert isinstance(negative, ContinuousLeaf)
    assert negative.conditioned
    assert negative.logprob(X < 0) == 0
    draws = negative.sample(100)
    assert all(draw[X] < 0 for draw in draws)

    assert spe.logprob(X < 0) == logsumexp(weighted_terms(X < 0))
def test_transform_sum():
    """Transforms applied to a sum reach every child of the sum."""
    X = Id('X')
    Z = Id('Z')
    Y = Id('Y')

    with_nominal = (
        0.3*(X >> norm(loc=0, scale=1))
        | 0.7*(X >> choice({'0': 0.4, '1': 0.6}))
    )
    with pytest.raises(Exception):
        # Cannot transform Nominal variate.
        with_nominal.transform(Z, X**2)

    numeric = (
        0.3*(X >> norm(loc=0, scale=1))
        | 0.7*(X >> poisson(mu=2))
    )
    squared = numeric.transform(Z, X**2)
    assert squared.logprob(Z < 1) == squared.logprob(X**2 < 1)
    assert squared.children[0].env == squared.children[1].env

    halved = squared.transform(Y, Z/2)
    first_env = halved.children[0].env
    second_env = halved.children[1].env
    expected_env = {X:X, Z:X**2, Y:Z/2}
    assert first_env == second_env == expected_env
def test_product_disjoint_union_numerical():
    """Clause probabilities of a disjoint union add up to the event's probability."""
    X = Id('X')
    Y = Id('Y')
    Z = Id('Z')
    leaves = [
        X >> norm(loc=0, scale=1),
        Y >> norm(loc=0, scale=2),
        Z >> norm(loc=0, scale=2),
    ]
    spe = ProductSPE(leaves)
    cases = (
        (1 / X < 4) | (X > 7),
        (2 * X - 3 > 0) | (Log(Y) < 3),
        ((X > 0) & (Y < 1)) | ((X < 1) & (Y < 3)) | ~(X << {1, 2}),
        ((X > 0) & (Y < 1)) | ((X < 1) & (Y < 3)) | (Z < 0),
        ((X > 0) & (Y < 1)) | ((X < 1) & (Y < 3)) | (Z < 0) | ~(X << {1, 3}),
    )
    for event in cases:
        disjoint = dnf_to_disjoint_union(event)
        clause_logps = [spe.logprob(clause) for clause in disjoint.subexprs]
        total = logsumexp(clause_logps)
        assert allclose(total, spe.logprob(event))
def test_product_distribution_normal_gamma_basic():
    """Structure and sampling of a product that contains a nested product."""
    X1 = Id('X1')
    X2 = Id('X2')
    X3 = Id('X3')
    X4 = Id('X4')
    nested = ProductSPE([
        X1 >> norm(loc=0, scale=1),
        X4 >> norm(loc=10, scale=1),
    ])
    children = [nested, X2 >> gamma(loc=0, a=1), X3 >> norm(loc=2, scale=3)]
    spe = ProductSPE(children)

    # The nested product's leaves appear directly among spe.children.
    assert spe.children == (
        children[0].children[0],
        children[0].children[1],
        children[1],
        children[2],
    )
    assert spe.get_symbols() == frozenset([X1, X2, X3, X4])
    assert spe.size() == 5

    full_draws = spe.sample(2)
    assert len(full_draws) == 2
    for draw in full_draws:
        assert len(draw) == 4
        assert all(symbol in draw for symbol in (X1, X2, X3, X4))

    subset_draws = spe.sample_subset((X1, X2), 10)
    assert len(subset_draws) == 10
    for draw in subset_draws:
        assert len(draw) == 2
        assert X1 in draw
        assert X2 in draw

    def nested_tuple(X1, X2, X3):
        return (X1, (X2**2, X3))

    func_draws = spe.sample_func(nested_tuple, 1)
    assert len(func_draws) == 1
    assert len(func_draws[0]) == 2
    assert len(func_draws[0][1]) == 2

    def unknown_argument(X1, X5):
        # X5 is not a symbol of the product; X4 is the enclosing Id.
        return X1 + X4

    with pytest.raises(ValueError):
        spe.sample_func(unknown_argument, 1)
def test_nominal_distribution():
    """logprob, sampling and conditioning of a three-valued nominal leaf."""
    X = Id('X')
    probabilities = {
        'a': Fraction(1, 5),
        'b': Fraction(1, 5),
        'c': Fraction(3, 5),
    }
    spe = X >> choice(probabilities)

    # (event, expected log probability) pairs.
    expectations = [
        (X << {'a'}, log(Fraction(1, 5))),
        (X << {'b'}, log(Fraction(1, 5))),
        (X << {'a', 'c'}, log(Fraction(4, 5))),
        ((X << {'a'}) & ~(X << {'b'}), log(Fraction(1, 5))),
        ((X << {'a', 'b'}) & ~(X << {'b'}), log(Fraction(1, 5))),
    ]
    for event, expected in expectations:
        assert allclose(spe.logprob(event), expected)
    # A value outside the support and the empty set both have probability zero.
    assert spe.logprob((X << {'d'})) == -float('inf')
    assert spe.logprob((X << ())) == -float('inf')

    draws = spe.sample(100)
    assert all(draw[X] in spe.support for draw in draws)

    draws = spe.sample_subset([X], 100)
    assert all(len(draw) == 1 and draw[X] in spe.support for draw in draws)

    with pytest.raises(Exception):
        spe.sample_subset(['f'], 100)

    def in_support(X):
        return (X in {'a', 'b'}) or X in {'c'}

    draws = spe.sample_func(in_support, 100)
    assert all(draws)

    def outside_support(X):
        return (not (X in {'a', 'b'})) and (not (X in {'c'}))

    draws = spe.sample_func(outside_support, 100)
    assert not any(draws)

    def one_if_a(X):
        return 1 if X in {'a'} else None

    draws = spe.sample_func(one_if_a, 100, prng=numpy.random.RandomState(1))
    assert sum(1 for value in draws if value == 1) > 12
    assert sum(1 for value in draws if value is None) > 70

    def unknown_symbol(Y):
        return Y

    with pytest.raises(ValueError):
        spe.sample_func(unknown_symbol, 100)

    # Conditioning keeps the three-value support; 'c' gets probability zero.
    restricted = spe.condition(X << {'a', 'b'})
    assert restricted.support == FiniteNominal('a', 'b', 'c')
    assert allclose(restricted.logprob(X << {'a'}), -log(2))
    assert allclose(restricted.logprob(X << {'b'}), -log(2))
    assert restricted.logprob(X << {'c'}) == -float('inf')

    assert isinf_neg(restricted.logprob(X**2 << {1}))

    with pytest.raises(ValueError):
        spe.condition(X << {'python'})
    assert spe.condition(~(X << {'python'})) == spe
def test_product_inclusion_exclusion_basic():
    """Probability identities on a product, with A = (X > 0.1), B = (Y < 0.5)."""
    X = Id('X')
    Y = Id('Y')
    spe = ProductSPE([X >> norm(loc=0, scale=1), Y >> gamma(a=1)])

    lp_a = spe.logprob(X > 0.1)
    lp_b = spe.logprob(Y < 0.5)
    lp_a_and_b = spe.logprob((X > 0.1) & (Y < 0.5))
    lp_a_or_b = spe.logprob((X > 0.1) | (Y < 0.5))
    lp_a_or_b_not_a = spe.logprob((X > 0.1) | ((Y < 0.5) & ~(X > 0.1)))
    lp_not_a = spe.logprob(~(X > 0.1))
    lp_b_not_a = spe.logprob((Y < 0.5) & ~(X > 0.1))

    # Single-symbol events agree with the corresponding child.
    assert allclose(lp_a, spe.children[0].logprob(X > 0.1))
    assert allclose(lp_b, spe.children[1].logprob(Y < 0.5))

    # Pr[A and B] = Pr[A] * Pr[B]
    assert allclose(lp_a_and_b, lp_a + lp_b)
    # Pr[A or B] = Pr[A] + Pr[B] - Pr[AB]
    assert allclose(lp_a_or_b, logdiffexp(logsumexp([lp_a, lp_b]), lp_a_and_b))
    # Pr[A or (B & ~A)] = Pr[A or B]
    assert allclose(lp_a_or_b_not_a, lp_a_or_b)
    # Pr[B and ~A] = Pr[B] * Pr[~A]
    assert allclose(lp_b_not_a, lp_b + lp_not_a)
    # Pr[A or (B & ~A)] = Pr[A] + Pr[B & ~A]
    assert allclose(lp_a_or_b_not_a, logsumexp([lp_a, lp_b + lp_not_a]))

    # (A => B) => Pr[A or B] = Pr[B]
    # i.e., (X > 1) => (X > 0).
    implied = spe.logprob((X > 0) | (X > 1))
    assert allclose(implied, spe.logprob(X > 0))

    # Positive probability event.
    # Pr[A] = 1 - Pr[~A]
    event = ((0 < X) < 0.5) | ((Y < 0) & (1 < X))
    complement = logdiffexp(0, spe.logprob(~event))
    assert allclose(spe.logprob(event), complement)

    # Probability zero event.
    event = ((0 < X) < 0.5) & ((Y < 0) | (1 < X))
    assert isinf_neg(spe.logprob(event))
    assert allclose(spe.logprob(~event), 0)
def test_sum_normal_gamma_exposed():
    """ExposedSumSPE: the selector symbol W can be queried and conditioned on.

    W takes value '0' (weight 2/3, Normal child) or '1' (weight 1/3,
    Gamma child).
    """
    X = Id('X')
    W = Id('W')
    weights = W >> choice({
        '0': Fraction(2, 3),
        '1': Fraction(1, 3),
    })
    children = {
        '0': X >> norm(loc=0, scale=1),
        '1': X >> gamma(loc=0, a=1),
    }
    spe = ExposedSumSPE(children, weights)

    assert spe.logprob(W << {'0'}) == log(Fraction(2, 3))
    assert spe.logprob(W << {'1'}) == log(Fraction(1, 3))
    assert allclose(spe.logprob((W << {'0'}) | (W << {'1'})), 0)
    assert spe.logprob((W << {'0'}) & (W << {'1'})) == -float('inf')

    # Allowing every value of W places no constraint on X.
    assert allclose(
        spe.logprob((W << {'0', '1'}) & (X < 1)),
        spe.logprob(X < 1))

    assert allclose(
        spe.logprob((W << {'0'}) & (X < 1)),
        spe.weights[0] + spe.children[0].logprob(X < 1))

    spe_condition = spe.condition((W << {'1'}) | (W << {'0'}))
    assert isinstance(spe_condition, SumSPE)
    assert len(spe_condition.weights) == 2
    # The two weights must be log(2/3) and log(1/3), in either order.
    # The previous form compared only against log(2/3) and repeated
    # weights[0], so the log(1/3) weight was never checked.
    log_two_thirds = log(Fraction(2, 3))
    log_one_third = log(Fraction(1, 3))
    first_weight = spe_condition.weights[0]
    second_weight = spe_condition.weights[1]
    assert (
        (allclose(first_weight, log_two_thirds)
            and allclose(second_weight, log_one_third))
        or
        (allclose(second_weight, log_two_thirds)
            and allclose(first_weight, log_one_third))
    )

    # Fixing W to '1' leaves a product of the W leaf and the Gamma child.
    spe_condition = spe.condition((W << {'1'}))
    assert isinstance(spe_condition, ProductSPE)
    assert isinstance(spe_condition.children[0], NominalLeaf)
    assert isinstance(spe_condition.children[1], ContinuousLeaf)
    assert spe_condition.logprob(X < 5) == spe.children[1].logprob(X < 5)
def test_product_disjoint_union_nominal():
    """Probabilities of hand-built disjoint events on two nominal variables."""
    N = Id('N')
    P = Id('P')
    nationality = N >> choice({'India': 0.5, 'USA': 0.5})
    perfect = P >> choice({'Imperfect': 0.99, 'Perfect': 0.01})
    student = nationality & perfect

    def cell(nation, quality):
        # Event for one (nationality, perfect) combination.
        return (N << {nation}) & (P << {quality})

    condition_1 = cell('India', 'Imperfect')
    condition_2 = cell('India', 'Perfect')
    condition_3 = cell('USA', 'Imperfect')
    condition_4 = cell('USA', 'Perfect')

    # Each event excludes every earlier condition.
    expectations = [
        (condition_1, 0.5 * 0.99),
        (condition_2 & ~condition_1, 0.5 * 0.01),
        (condition_3 & ~condition_2 & ~condition_1, 0.5 * 0.99),
        (condition_4 & ~condition_3 & ~condition_2 & ~condition_1, 0.5 * 0.01),
    ]
    for event, expected in expectations:
        assert allclose(student.prob(event), expected)
def test_sum_simplify_product_collapse():
    """A sum of identical products simplifies to one product."""

    def unit_gauss(name):
        # Normal(0, 1) leaf over a freshly created symbol.
        return Id(name) >> norm(loc=0, scale=1)

    A1 = unit_gauss('A')
    A0 = unit_gauss('A')
    B = unit_gauss('B')
    B1 = unit_gauss('B')
    B0 = unit_gauss('B')
    C = unit_gauss('C')
    C1 = unit_gauss('C')
    D = unit_gauss('D')
    products = [
        ProductSPE([A1, B, C, D]),
        ProductSPE([A0, B1, C, D]),
        ProductSPE([A0, B0, C1, D]),
    ]
    spe = SumSPE(products, [log(0.4), log(0.4), log(0.2)])
    expected = ProductSPE([A1, B, C, D])
    assert spe_simplify_sum(spe) == expected
def test_sum_normal_nominal():
    """Mixture over X with one real (Normal) and one nominal component."""
    X = Id('X')
    real_leaf = X >> norm(loc=0, scale=1)
    nominal_leaf = X >> choice({
        'low': Fraction(3, 10),
        'high': Fraction(7, 10)
    })
    children = [real_leaf, nominal_leaf]
    weights = [log(Fraction(4, 7)), log(Fraction(3, 7))]
    spe = SumSPE(children, weights)

    assert allclose(
        spe.logprob(X < 0),
        log(Fraction(4, 7)) + log(Fraction(1, 2)))
    assert allclose(
        spe.logprob(X << {'low'}),
        log(Fraction(3, 7)) + log(Fraction(3, 10)))

    # The semantics of ~(X<<{'low'}) are (X << String and X != 'low')
    assert allclose(
        spe.logprob(~(X << {'low'})),
        spe.logprob((X << {'high'})))
    assert allclose(
        spe.logprob((X << FiniteNominal(b=True)) & ~(X << {'low'})),
        spe.logprob((X << FiniteNominal(b=True)) & (X << {'high'})))

    # A real constraint and a nominal constraint cannot hold together.
    assert isinf_neg(spe.logprob((X < 0) & (X << {'low'})))
    either_logp = logsumexp([spe.logprob(X < 0), spe.logprob(X << {'low'})])
    assert allclose(spe.logprob((X < 0) | (X << {'low'})), either_logp)

    assert isinf_neg(spe.logprob(X << {'a'}))
    assert allclose(
        spe.logprob(~(X << {'a'})),
        spe.logprob(X << {'low', 'high'}))

    assert allclose(
        spe.logprob(X**2 < 9),
        log(Fraction(4, 7)) + spe.children[0].logprob(X**2 < 9))

    # A purely real event leaves only the continuous component.
    real_only = spe.condition(X**2 < 9)
    assert isinstance(real_only, ContinuousLeaf)
    assert real_only.support == Interval.open(-3, 3)

    mixed = spe.condition((X**2 < 9) | (X << {'low'}))
    assert isinstance(mixed, SumSPE)
    assert mixed.children[0].support == Interval.open(-3, 3)
    assert mixed.children[1].support == FiniteNominal('low', 'high')
    assert isinf_neg(mixed.children[1].logprob(X << {'high'}))

    assert mixed == spe.condition((X**2 < 9) | ~(X << {'high'}))
    assert allclose(spe.logprob((X < oo) | ~(X << {'1'})), 0)
def test_condition_non_contiguous():
    """Conditioning a Poisson leaf on non-contiguous sets yields a SumSPE."""
    X = Id('X')
    spe = X >> poisson(mu=5)

    # FiniteSet.
    finite_sets = ({0, 2, 3}, {-1, 0, 2, 3}, {-1, 0, 2, 3, 'z'})
    for values in finite_sets:
        conditioned = spe.condition((X << values))
        assert isinstance(conditioned, SumSPE)
        # One child is concentrated on {0}, the next on {2, 3}.
        first, second = conditioned.children[0], conditioned.children[1]
        assert allclose(0, first.logprob(X << {0}))
        assert allclose(0, second.logprob(X << {2, 3}))

    # FiniteSet or Interval.
    conditioned = spe.condition((X << {-1, 'x', 0, 2, 3}) | (X > 7))
    assert isinstance(conditioned, SumSPE)
    assert len(conditioned.children) == 3
    expected_events = (X << {0}, X << {2, 3}, X > 7)
    for index, event in enumerate(expected_events):
        assert allclose(0, conditioned.children[index].logprob(event))
def test_sum_simplify_nested_sum_1():
    """A SumSPE built with a nested sum child exposes flattened children."""
    X = Id('X')
    inner_sum = SumSPE(
        [X >> norm(loc=0, scale=1), X >> norm(loc=0, scale=2)],
        [log(0.4), log(0.6)])
    gamma_leaf = X >> gamma(loc=0, a=1)
    children = [inner_sum, gamma_leaf]

    spe = SumSPE(children, [log(0.7), log(0.3)])
    assert spe.size() == 4
    assert spe.children == (
        children[0].children[0],
        children[0].children[1],
        children[1]
    )
    # Inner weights are combined with the outer weight in log space.
    expected_weights = [
        log(0.7) + log(0.4),
        log(0.7) + log(0.6),
        log(0.3),
    ]
    for index, expected in enumerate(expected_weights):
        assert allclose(spe.weights[index], expected)
def test_randint():
    """Bounds and conditioning of a randint(low=0, high=5) leaf."""
    X = Id('X')
    spe = X >> randint(low=0, high=5)
    # The upper bound is exclusive: the largest value is 4.
    assert spe.xl == 0
    assert spe.xu == 4
    assert spe.logprob(X < 5) == spe.logprob(X <= 4) == 0

    # X + 1 is not in {1, 4}, i.e., X is not in {0, 3}.
    conditioned = spe.condition(~((X + 1) << {1, 4}))
    assert isinstance(conditioned, SumSPE)
    # The child order is not fixed; pick out the child starting at 1.
    if conditioned.children[0].xl == 1:
        low_index, high_index = 0, 1
    else:
        low_index, high_index = 1, 0
    low_part = conditioned.children[low_index]
    assert low_part.xl == 1
    assert low_part.xu == 2
    high_part = conditioned.children[high_index]
    assert high_part.xl == 4
    assert high_part.xu == 4
    assert allclose(conditioned.children[low_index].logprob(X << {1, 2}), 0)
    assert allclose(conditioned.children[high_index].logprob(X << {4}), 0)
def test_sum_simplify_leaf():
    """spe_simplify_sum on sums of leaves with and without duplicates."""

    def gauss(scale):
        # Zero-mean Normal leaf over a freshly created symbol X.
        return Id('X') >> norm(loc=0, scale=scale)

    three_weights = [log(0.5), log(0.1), log(.4)]

    # Three distinct leaves: simplification changes nothing.
    Xd0, Xd1, Xd2 = gauss(1), gauss(2), gauss(3)
    spe = SumSPE([Xd0, Xd1, Xd2], three_weights)
    assert spe.size() == 4
    assert spe_simplify_sum(spe) == spe

    # Three identical leaves: the sum collapses to one leaf.
    Xd0, Xd1, Xd2 = gauss(1), gauss(1), gauss(1)
    spe = SumSPE([Xd0, Xd1, Xd2], three_weights)
    assert spe_simplify_sum(spe) == Xd0

    # Two groups of duplicates: the weights of equal leaves are summed.
    Xd3 = gauss(2)
    spe = SumSPE(
        [Xd0, Xd3, Xd1, Xd3],
        [log(0.5), log(0.1), log(.3), log(.1)])
    merged = spe_simplify_sum(spe)
    assert len(merged.children) == 2
    assert merged.children[0] == Xd0
    assert merged.children[1] == Xd3
    assert allclose(merged.weights[0], log(0.8))
    assert allclose(merged.weights[1], log(0.2))
def test_transform_product():
    """Transforms on a product: each one may only use symbols of one child."""
    X = Id('X')
    Y = Id('Y')
    W = Id('W')
    Z = Id('Z')
    V = Id('V')
    spe = (
        (X >> norm(loc=0, scale=1))
        & (Y >> poisson(mu=10))
    )
    with pytest.raises(Exception):
        # Cannot use symbols from different transforms.
        spe.transform(W, (X > 0) | (Y << {'0'}))
    spe = spe.transform(W, (X**2 - 3*X)**(1,10))
    spe = spe.transform(Z, (W > 0) | (X**3 < 1))
    spe = spe.transform(V, Y/10)
    assert allclose(
        spe.logprob(W>1),
        spe.logprob((X**2 - 3*X)**(1,10) > 1))
    with pytest.raises(Exception):
        # V derives from Y and W derives from X, so this mixes children.
        # The call was previously misspelled `tarnsform`; the resulting
        # AttributeError satisfied pytest.raises(Exception) without ever
        # exercising transform.
        spe.transform(Id('R'), (V>1) | (W < 0))
def test_product_condition_basic():
    """Condition a Normal/Gamma product on single, AND, and OR events."""
    X = Id('X')
    Y = Id('Y')
    spe = ProductSPE([X >> norm(loc=0, scale=1), Y >> gamma(a=1)])

    # Condition on (X > 0) and ((X > 0) | (Y < 0))
    # where the second clause reduces to first as Y < 0
    # has probability zero.
    for event in ((X > 0), (X > 0) | (Y < 0)):
        dX = spe.condition(event)
        assert isinstance(dX, ProductSPE)
        x_leaf = dX.children[0]
        assert x_leaf.symbol == Id('X')
        assert x_leaf.conditioned
        assert x_leaf.support == Interval.open(0, oo)
        y_leaf = dX.children[1]
        assert y_leaf.symbol == Id('Y')
        assert not y_leaf.conditioned
        assert y_leaf.Fl == 0
        assert y_leaf.Fu == 1

    # Condition on (Y < 0.5)
    dY = spe.condition(Y < 0.5)
    assert isinstance(dY, ProductSPE)
    x_leaf = dY.children[0]
    assert x_leaf.symbol == Id('X')
    assert not x_leaf.conditioned
    y_leaf = dY.children[1]
    assert y_leaf.symbol == Id('Y')
    assert y_leaf.conditioned
    assert y_leaf.support == Interval.Ropen(0, 0.5)

    # Condition on (X > 0) & (Y < 0.5)
    dXY_and = spe.condition((X > 0) & (Y < 0.5))
    assert isinstance(dXY_and, ProductSPE)
    x_leaf = dXY_and.children[0]
    assert x_leaf.symbol == Id('X')
    assert x_leaf.conditioned
    assert x_leaf.support == Interval.open(0, oo)
    y_leaf = dXY_and.children[1]
    assert y_leaf.symbol == Id('Y')
    assert y_leaf.conditioned
    assert y_leaf.support == Interval.Ropen(0, 0.5)

    # Condition on (X > 0) | (Y < 0.5)
    event = (X > 0) | (Y < 0.5)
    dXY_or = spe.condition((X > 0) | (Y < 0.5))
    assert isinstance(dXY_or, SumSPE)
    assert all(isinstance(child, ProductSPE) for child in dXY_or.children)
    assert allclose(dXY_or.logprob(X > 0), dXY_or.weights[0])
    draws = dXY_or.sample(100, prng=numpy.random.RandomState(1))
    assert all(event.evaluate(draw) for draw in draws)

    # Condition on a disjoint union with one term in second clause.
    dXY_disjoint_one = spe.condition((X > 0) & (Y < 0.5) | (X <= 0))
    assert isinstance(dXY_disjoint_one, SumSPE)
    component_0 = dXY_disjoint_one.children[0]
    x_leaf = component_0.children[0]
    assert x_leaf.symbol == Id('X')
    assert x_leaf.conditioned
    assert x_leaf.support == Interval.open(0, oo)
    y_leaf = component_0.children[1]
    assert y_leaf.symbol == Id('Y')
    assert y_leaf.conditioned
    assert y_leaf.support == Interval.Ropen(0, 0.5)
    component_1 = dXY_disjoint_one.children[1]
    x_leaf = component_1.children[0]
    assert x_leaf.symbol == Id('X')
    assert x_leaf.conditioned
    assert x_leaf.support == Interval(-oo, 0)
    y_leaf = component_1.children[1]
    assert y_leaf.symbol == Id('Y')
    assert not y_leaf.conditioned

    # Condition on a disjoint union with two terms in each clause
    dXY_disjoint_two = spe.condition(
        (X > 0) & (Y < 0.5) | ((X <= 0) & ~(Y < 3)))
    assert isinstance(dXY_disjoint_two, SumSPE)
    component_0 = dXY_disjoint_two.children[0]
    x_leaf = component_0.children[0]
    assert x_leaf.symbol == Id('X')
    assert x_leaf.conditioned
    assert x_leaf.support == Interval.open(0, oo)
    y_leaf = component_0.children[1]
    assert y_leaf.symbol == Id('Y')
    assert y_leaf.conditioned
    assert y_leaf.support == Interval.Ropen(0, 0.5)
    component_1 = dXY_disjoint_two.children[1]
    x_leaf = component_1.children[0]
    assert x_leaf.symbol == Id('X')
    assert x_leaf.conditioned
    assert x_leaf.support == Interval(-oo, 0)
    y_leaf = component_1.children[1]
    assert y_leaf.symbol == Id('Y')
    assert y_leaf.conditioned
    assert y_leaf.support == Interval(3, oo)

    # Some various conditioning; only checked for running to completion.
    spe.condition((X > 0) & (Y < 0.5) | ((X <= 1) | ~(Y < 3)))
    spe.condition((X > 0) & (Y < 0.5) | ((X <= 1) & (Y < 3)))
from sppl.distributions import discrete
from sppl.distributions import norm
from sppl.distributions import poisson
from sppl.distributions import rv_discrete
from sppl.distributions import uniformd
from sppl.math_util import allclose
from sppl.sets import FiniteNominal
from sppl.sets import Interval
from sppl.sets import inf as oo
from sppl.spe import ContinuousLeaf
from sppl.spe import DiscreteLeaf
from sppl.spe import NominalLeaf
from sppl.spe import SumSPE
from sppl.transforms import Id

X = Id('X')

def test_simple_parse_real():
    """A weighted mix of real-valued distributions applied to X gives a SumSPE."""
    assert isinstance(.3 * bernoulli(p=.1), DistributionMix)
    mixture = .3 * bernoulli(p=.1) | .5 * norm() | .2 * poisson(mu=7)
    spe = mixture(X)
    assert isinstance(spe, SumSPE)
    assert allclose(spe.weights, [log(.3), log(.5), log(.2)])

    # Children follow the order in which the mixture was written.
    expected_leaf_types = (DiscreteLeaf, ContinuousLeaf, DiscreteLeaf)
    for index, leaf_type in enumerate(expected_leaf_types):
        assert isinstance(spe.children[index], leaf_type)
    expected_supports = (Interval(0, 1), Interval(-oo, oo), Interval(0, oo))
    for index, support in enumerate(expected_supports):
        assert spe.children[index].support == support
from math import log import numpy from sppl.distributions import bernoulli from sppl.distributions import choice from sppl.distributions import norm from sppl.spe import ProductSPE from sppl.spe import SumSPE from sppl.spe import spe_cache_duplicate_subtrees from sppl.transforms import Id rng = numpy.random.RandomState(1) W = Id('W') Y = Id('Y') X = [Id('X[0]'), Id('X[1]')] Z = [Id('Z[0]'), Id('Z[1]')] def test_cache_simple_leaf(): spe = .5 * (W >> norm(loc=0, scale=1)) | .5 * (W >> norm(loc=0, scale=1)) assert spe.children[0] is not spe.children[1] spe_cached = spe_cache_duplicate_subtrees(spe, {}) assert spe_cached.children[0] is spe_cached.children[1] def test_cache_simple_sum_of_product(): spe \ = 0.3 * ((W >> norm(loc=0, scale=1)) & (Y >> norm(loc=0, scale=1))) \
condition_2 = (N << {'India'}) & (P << {'Perfect'}) condition_3 = (N << {'USA'}) & (P << {'Imperfect'}) condition_4 = (N << {'USA'}) & (P << {'Perfect'}) event_1 = condition_1 event_2 = condition_2 & ~condition_1 event_3 = condition_3 & ~condition_2 & ~condition_1 event_4 = condition_4 & ~condition_3 & ~condition_2 & ~condition_1 assert allclose(student.prob(event_1), 0.5 * 0.99) assert allclose(student.prob(event_2), 0.5 * 0.01) assert allclose(student.prob(event_3), 0.5 * 0.99) assert allclose(student.prob(event_4), 0.5 * 0.01) A = Id('A') B = Id('B') C = Id('C') D = Id('D') spe_abcd \ = norm(loc=0, scale=1)(A) \ & norm(loc=0, scale=1)(B) \ & norm(loc=0, scale=1)(C) \ & norm(loc=0, scale=1)(D) def test_product_condition_simplify_a(): spe = spe_abcd.condition((A > 1) | (A < -1)) assert isinstance(spe, ProductSPE) assert spe_abcd.children[1] in spe.children assert spe_abcd.children[2] in spe.children