def test_mutual_information_four_clusters(memo):
    """Mutual information on an equal-weight mixture of four Gaussian clusters.

    The four components place (X, Y) around (0, 0), (5, 0), (0, 5) and
    (5, 5).  The test asserts that the events ``X > 2`` and ``Y > 2`` have
    zero mutual information under the mixture, and ``log(2)`` once the
    mixture is conditioned on ``(X>2 & Y<2) | (X<2 & Y>2)``.
    """
    X = Id('X')
    Y = Id('Y')
    spe = (
        0.25*(X >> norm(loc=0, scale=0.5) & Y >> norm(loc=0, scale=0.5))
        | 0.25*(X >> norm(loc=5, scale=0.5) & Y >> norm(loc=0, scale=0.5))
        | 0.25*(X >> norm(loc=0, scale=0.5) & Y >> norm(loc=5, scale=0.5))
        | 0.25*(X >> norm(loc=5, scale=0.5) & Y >> norm(loc=5, scale=0.5))
    )

    x_high = X > 2
    y_high = Y > 2

    # Sampling from the unconditioned mixture must succeed; the draws
    # themselves are not inspected.
    spe.sample(100, prng)

    mi_unconditioned = spe.mutual_information(x_high, y_high, memo=memo)
    assert allclose(mi_unconditioned, 0)
    check_mi_properties(spe, x_high, y_high, memo)

    # Keep only the two off-diagonal quadrants.
    event = ((X > 2) & (Y < 2)) | ((X < 2) & (Y > 2))
    spe_condition = spe.condition(event)
    draws = spe_condition.sample(100, prng)
    assert all(event.evaluate(draw) for draw in draws)
    mi_conditioned = spe_condition.mutual_information(X > 2, Y > 2)
    assert allclose(mi_conditioned, log(2))

    # The entropy identities must also hold for compound events.
    check_mi_properties(spe, (X > 1) | (Y < 1), (Y > 2), memo)
    check_mi_properties(spe, (X > 1) | (Y < 1), (X > 1.5) & (Y > 2), memo)
def test_solver_21__ci_():
    """Solve ``1 <= log(x**3 - 3*x + 3) < 5``.

    Can only be solved by numerical approximation of roots.  Reference:
    https://www.wolframalpha.com/input/?i=1+%3C%3D+log%28x**3+-+3x+%2B+3%29+%3C+5
    """
    solution = Union(
        Interval(
            -1.777221448430427630375448631016427343692,
            0.09418455242255462832154474245589911789464),
        Interval.Ropen(
            1.683036896007873002053903888560528225797,
            5.448658707897512189124586716091172798465))

    expr = Log(Y**3 - 3*Y + 3)
    event = (1 <= expr) & (expr < 5)
    answer = event.solve()
    assert isinstance(answer, Union)
    assert len(answer.args) == 2

    # The order of the two intervals inside the Union is not assumed:
    # pick them out by the sign of the left endpoint of the first one.
    lead, trail = answer.args[0], answer.args[1]
    first = lead if lead.a < 0 else trail
    second = lead if lead.a > 0 else trail

    # First interval: closed on both sides.
    expected_first = solution.args[0]
    assert not first.left_open
    assert not first.right_open
    assert allclose(float(first.a), float(expected_first.a))
    assert allclose(float(first.b), float(expected_first.b))

    # Second interval: closed on the left, open on the right.
    expected_second = solution.args[1]
    assert not second.left_open
    assert second.right_open
    assert allclose(float(second.a), float(expected_second.a))
    assert allclose(float(second.b), float(expected_second.b))
def test_solve_poly_equality_quadratic_one():
    """Solving ``expr_quadratic == 1`` yields the two expected roots."""
    roots = solve_poly_equality(expr_quadratic, 1)
    assert len(roots) == 2
    # SymPy is not smart enough to simplify irrational roots symbolically,
    # so compare the symbolic roots numerically.
    for expected in (xe1_quad0, xe1_quad1):
        assert any(allclose(float(x), float(expected)) for x in roots)
def test_logpdf_mixture_nominal():
    """logpdf of a real/nominal mixture picks up only the matching child."""
    components = [X >> norm(), X >> choice({'a': .1, 'b': .9})]
    log_weights = [log(.4), log(.6)]
    spe = SumSPE(components, log_weights)

    # A real value is scored by the first child only.
    actual_real = spe.logpdf({X: .5})
    expected_real = log(.4) + spe.children[0].logpdf({X: .5})
    assert allclose(actual_real, expected_real)

    # A string value is scored by the second child only.
    actual_nominal = spe.logpdf({X: 'a'})
    expected_nominal = log(.6) + spe.children[1].logpdf({X: 'a'})
    assert allclose(actual_nominal, expected_nominal)
def test_solve_poly_equality_cubic_irrat_one():
    """Solving ``expr_cubic_irrat == 1`` yields the three expected roots."""
    # This expression is too slow to solve symbolically.
    # The 5s timeout will trigger a numerical approximation.
    roots = solve_poly_equality(expr_cubic_irrat, 1)
    assert len(roots) == 3
    expected_roots = (xe1_cubic_irrat0, xe1_cubic_irrat1, xe1_cubic_irrat2)
    for expected in expected_roots:
        assert any(allclose(float(x), expected) for x in roots)
def test_solve_poly_equality_cubic_irrat_zero():
    """Solving ``expr_cubic_irrat == 0`` returns symbolic roots."""
    roots = solve_poly_equality(expr_cubic_irrat, 0)
    # Confirm that roots contains symbolic elements (no timeout).
    assert -Rational(10, 7) in roots
    # SymPy is not smart enough to simplify irrational roots symbolically,
    # so compare the symbolic roots numerically.
    irrational_roots = (SymSqrt(2) / 10, SymSqrt(5))
    for expected in irrational_roots:
        assert any(allclose(float(x), float(expected)) for x in roots)
def test_if_else_transform():
    """IfElse with a different Transform of Z in each branch."""
    command = Sequence(
        Sample(X, norm(loc=0, scale=1)),
        IfElse(
            X > 0, Transform(Z, X**2),
            Otherwise, Transform(Z, X)))
    model = command.interpret()

    positive_branch = model.children[0]
    other_branch = model.children[1]

    # Each branch records its own definition of Z.
    assert positive_branch.env == {X: X, Z: X**2}
    assert other_branch.env == {X: X, Z: X}

    # Z > 0 is certain in the first branch and impossible in the second.
    assert allclose(positive_branch.logprob(Z > 0), 0)
    assert allclose(other_branch.logprob(Z > 0), -float('inf'))
    assert allclose(model.logprob(Z > 0), -log(2))
def test_simple_model_enumerate():
    """Switch over ``enumerate(range(0, 4))`` passes (index, value) pairs."""
    command_switch = Sequence(
        Sample(Y, randint(low=0, high=4)),
        Switch(Y, enumerate(range(0, 4)), lambda i, j:
            Sample(X, bernoulli(p=1 / (i + j + 1)))))
    model = command_switch.interpret()
    # For Y = k the lambda receives i = j = k, hence p = 1 / (k + k + 1);
    # each asserted joint probability carries a factor of .25 for Y = k.
    for k in range(4):
        joint = model.prob(Y << {k} & (X << {1}))
        assert allclose(joint, .25 * 1 / (k + k + 1))
def test_condition_nominal():
    """Conditioning a nominal variable renormalises the kept categories."""
    prior = choice({'a': .1, 'b': .1, 'c': .8})
    command = Sequence(
        Sample(Y, prior),
        Condition(Y << {'a', 'b'}))
    model = command.interpret()
    expected_probs = (('a', .5), ('b', .5), ('c', 0))
    for label, expected in expected_probs:
        assert allclose(model.prob(Y << {label}), expected)
def test_sum_leaf():
    """Exercise ``weight * leaf`` and ``|`` when building mixtures of leaves.

    Covers the rejected combinations (TypeError / ValueError) and the two
    accepted outcomes: a PartialSumSPE holding raw weights, and a SumSPE
    holding log weights.
    """
    # Cannot sum leaves without weights.
    with pytest.raises(TypeError):
        (X >> norm()) | (X >> gamma(a=1))
    # Cannot sum a partial sum (left) with a bare leaf (right).
    with pytest.raises(TypeError):
        0.3*(X >> norm()) | (X >> gamma(a=1))
    # Cannot sum a bare leaf (left) with a partial sum (right).
    with pytest.raises(TypeError):
        (X >> norm()) | 0.3*(X >> gamma(a=1))
    # Wrong symbol.
    with pytest.raises(ValueError):
        0.4*(X >> norm()) | 0.6*(Y >> gamma(a=1))
    # Sum exceeds one.
    # NOTE(review): the operands also use different symbols (X and Y), so
    # this case does not isolate the weight check -- confirm which raises.
    with pytest.raises(ValueError):
        0.4*(X >> norm()) | 0.7*(Y >> gamma(a=1))

    # Two terms, weights below one: still partial, raw weights kept.
    y = 0.4*(X >> norm()) | 0.3*(X >> gamma(a=1))
    assert isinstance(y, PartialSumSPE)
    assert len(y.weights) == 2
    for k, weight in enumerate((0.4, 0.3)):
        assert allclose(float(y.weights[k]), weight)

    # Two terms, weights totalling one: a SumSPE with log weights.
    y = 0.4*(X >> norm()) | 0.6*(X >> gamma(a=1))
    assert isinstance(y, SumSPE)
    assert len(y.weights) == 2
    for k, weight in enumerate((0.4, 0.6)):
        assert allclose(float(y.weights[k]), log(weight))
    # Sum exceeds one: y is already a SumSPE here.
    with pytest.raises(TypeError):
        y | 0.7 * (X >> norm())

    # Three terms, weights below one.
    y = 0.4*(X >> norm()) | 0.3*(X >> gamma(a=1)) | 0.1*(X >> norm())
    assert isinstance(y, PartialSumSPE)
    assert len(y.weights) == 3
    for k, weight in enumerate((0.4, 0.3, 0.1)):
        assert allclose(float(y.weights[k]), weight)

    # Three terms, weights totalling one.
    y = 0.4*(X >> norm()) | 0.3*(X >> gamma(a=1)) | 0.3*(X >> norm())
    assert isinstance(y, SumSPE)
    assert len(y.weights) == 3
    for k, weight in enumerate((0.4, 0.3, 0.3)):
        assert allclose(float(y.weights[k]), log(weight))

    # A partial sum cannot be rescaled, from either side.
    with pytest.raises(TypeError):
        (0.3)*(0.3*(X >> norm()))
    with pytest.raises(TypeError):
        (0.3*(X >> norm())) * (0.3)
    with pytest.raises(TypeError):
        0.3*(0.3*(X >> norm()) | 0.5*(X >> norm()))

    # Weights totalling one give a SumSPE, which can be weighted again.
    w = 0.3*(0.4*(X >> norm()) | 0.6*(X >> norm()))
    assert isinstance(w, PartialSumSPE)
def test_product_condition_simplify_a():
    """Conditioning ``spe_abcd`` on a disjunction over A alone."""
    spe = spe_abcd.condition((A > 1) | (A < -1))
    assert isinstance(spe, ProductSPE)

    # The factors at indices 1, 2 and 3 are carried over unchanged.
    for position in (1, 2, 3):
        assert spe_abcd.children[position] in spe.children

    # Exactly one factor becomes a mixture.
    idx_sum = [i for i, c in enumerate(spe.children) if isinstance(c, SumSPE)]
    assert len(idx_sum) == 1
    mixture = spe.children[idx_sum[0]]

    # Both mixture components have weight 1/2 and are conditioned leaves.
    assert allclose(mixture.weights[0], -log(2))
    assert allclose(mixture.weights[1], -log(2))
    assert mixture.children[0].conditioned
    assert mixture.children[1].conditioned
def check_mi_properties(spe, A, B, memo):
    """Assert identities linking mutual information and entropy for A and B.

    Checks that MI(A, A) equals entropy(A), MI(B, B) equals entropy(B), and
    that MI(A, B) equals both ``entropy(A) - entropyc(A, B)`` and
    ``entropy(B) - entropyc(B, A)``.  ``entropyc`` is presumably the
    conditional entropy of its first event given its second; confirm
    against its definition.
    """
    # Calls are issued in a fixed order because they all share `memo`.
    mi_ab = spe.mutual_information(A, B, memo=memo)
    mi_aa = spe.mutual_information(A, A, memo=memo)
    mi_bb = spe.mutual_information(B, B, memo=memo)
    ent_a = entropy(spe, A, memo=memo)
    ent_b = entropy(spe, B, memo=memo)
    ent_a_given_b = entropyc(spe, A, B, memo=memo)
    ent_b_given_a = entropyc(spe, B, A, memo=memo)

    assert allclose(mi_aa, ent_a)
    assert allclose(mi_bb, ent_b)
    assert allclose(mi_ab, ent_a - ent_a_given_b)
    assert allclose(mi_ab, ent_b - ent_b_given_a)
def test_render_sppl():
    """Render a model to source, recompile it, and compare log-probabilities."""
    model = get_model()
    sppl_code = render_sppl(model)
    compiler = SPPL_Compiler(sppl_code.getvalue())
    namespace = compiler.execute_module()
    X, Y = namespace.X, namespace.Y

    # Every category '0'..'4' of Y must score the same as '0', both in the
    # original model and in the recompiled one.
    for i in range(5):
        category = str(i)
        reference = model.logprob(Y << {'0'})
        assert allclose(reference, [
            model.logprob(Y << {category}),
            namespace.model.logprob(Y << {category}),
        ])

    # X must score identically in both models for each value 0..3.
    for i in range(4):
        original_lp = model.logprob(X << {i})
        recompiled_lp = namespace.model.logprob(X << {i})
        assert allclose(original_lp, recompiled_lp)
def test_product_condition_simplify_ab():
    """Conditioning ``spe_abcd`` on a disjunction that couples A and B.

    The result must stay a ProductSPE in which the factors not mentioned by
    the condition are reused as-is, and exactly one factor is a two-way
    mixture of products whose weights are the normalised probabilities of
    the disjoint cases ``A > 1`` and ``(B < 0) & ~(A > 1)``.
    """
    spe = spe_abcd.condition((A > 1) | (B < 0))
    assert isinstance(spe, ProductSPE)

    # The condition mentions only A and B, so the factors at indices 2 and 3
    # must survive untouched (mirrors test_product_condition_simplify_a,
    # which checks every index its condition does not involve).  The
    # original asserted index 2 twice and never checked index 3.
    assert spe_abcd.children[2] in spe.children
    assert spe_abcd.children[3] in spe.children

    # Exactly one factor is a mixture, and both its components are products.
    idx_sum = [i for i, c in enumerate(spe.children) if isinstance(c, SumSPE)]
    assert len(idx_sum) == 1
    spe_sum = spe.children[idx_sum[0]]
    assert isinstance(spe_sum.children[0], ProductSPE)
    assert isinstance(spe_sum.children[1], ProductSPE)

    # Mixture weights are the normalised log-probabilities of the two cases.
    lp0 = spe_abcd.logprob(A > 1)
    lp1 = spe_abcd.logprob((B < 0) & ~(A > 1))
    weights = lognorm([lp0, lp1])
    assert allclose(spe_sum.weights[0], weights[0])
    assert allclose(spe_sum.weights[1], weights[1])
def test_logpdf_lexicographic_either():
    """logpdf of a two-branch mixture where one branch accounts for the point.

    For each assignment the asserted density contains the terms of a single
    branch only, and constraining on the assignment yields a ProductSPE.
    """
    spe = (
        .75*(X >> norm() & Y >> atomic(loc=0) & Z >> discrete({1:.1, 2:.9}))
        | .25*(X >> atomic(loc=0) & Y >> norm() & Z >> norm())
    )

    # Lexicographic, Branch 1
    assignment = {X: 0, Y: 0, Z: 2}
    actual = spe.logpdf(assignment)
    expected = log(.75) + norm().dist.logpdf(0) + log(1) + log(.9)
    assert allclose(actual, expected)
    assert isinstance(spe.constrain(assignment), ProductSPE)

    # Lexicographic, Branch 2
    assignment = {X: 0, Y: 0, Z: 0}
    actual = spe.logpdf(assignment)
    expected = (
        log(.25) + log(1) + norm().dist.logpdf(0) + norm().dist.logpdf(0))
    assert allclose(actual, expected)
    assert isinstance(spe.constrain(assignment), ProductSPE)
def test_condition_non_contiguous():
    """Conditioning a Poisson leaf on non-contiguous sets yields a mixture."""
    X = Id('X')
    spe = X >> poisson(mu=5)

    # FiniteSet.  The extra members -1 and 'z' do not change the expected
    # components: one certain to be 0, the other certain to lie in {2, 3}.
    finite_sets = [{0, 2, 3}, {-1, 0, 2, 3}, {-1, 0, 2, 3, 'z'}]
    for values in finite_sets:
        conditioned = spe.condition((X << values))
        assert isinstance(conditioned, SumSPE)
        assert allclose(0, conditioned.children[0].logprob(X << {0}))
        assert allclose(0, conditioned.children[1].logprob(X << {2, 3}))

    # FiniteSet or Interval.
    conditioned = spe.condition((X << {-1, 'x', 0, 2, 3}) | (X > 7))
    assert isinstance(conditioned, SumSPE)
    assert len(conditioned.children) == 3
    certain_events = (X << {0}, X << {2, 3}, X > 7)
    for position, certain in enumerate(certain_events):
        assert allclose(0, conditioned.children[position].logprob(certain))
def test_transform_real_leaf_logprob():
    """Derived variables on a real leaf score like their defining expressions."""
    X = Id('X')
    Y = Id('Y')
    Z = Id('Z')
    spe = X >> norm(loc=0, scale=1)

    # Rejected: the expression mentions Y, which the leaf does not define.
    with pytest.raises(AssertionError):
        spe.transform(Z, Y**2)
    # Rejected: the target X is the leaf's own symbol.
    with pytest.raises(AssertionError):
        spe.transform(X, X**2)

    # Z := X**2
    spe = spe.transform(Z, X**2)
    assert spe.env == {X: X, Z: X**2}
    assert spe.get_symbols() == {X, Z}
    assert spe.logprob(Z < 1) == spe.logprob(X**2 < 1)
    assert (
        spe.logprob((Z < 1) | ((X + 1) < 3))
        == spe.logprob((X**2 < 1) | ((X+1) < 3))
    )

    # Y := 2*Z, a transform of a transform.
    spe = spe.transform(Y, 2*Z)
    assert spe.env == {X: X, Z: X**2, Y: 2*Z}
    assert (
        spe.logprob(Y**(1,3) < 10)
        == spe.logprob((2*Z)**(1,3) < 10)
        == spe.logprob((2*(X**2))**(1,3) < 10)
    )

    # W := (X > 1), a transform whose expression is an event.
    W = Id('W')
    spe = spe.transform(W, X > 1)
    assert allclose(spe.logprob(W), spe.logprob(X > 1))
def test_simple_model_lte():
    """A Switch over ``binspace`` and a spelled-out IfElse chain must agree."""
    command_switch = Sequence(
        Sample(Y, beta(a=2, b=3)),
        Switch(Y, binspace(0, 1, 5), lambda i:
            Sample(X, bernoulli(p=i.right))))
    model_switch = command_switch.interpret()

    command_ifelse = Sequence(
        Sample(Y, beta(a=2, b=3)),
        IfElse(
            Y <= 0,     Sample(X, bernoulli(p=0)),
            Y <= 0.25,  Sample(X, bernoulli(p=.25)),
            Y <= 0.50,  Sample(X, bernoulli(p=.50)),
            Y <= 0.75,  Sample(X, bernoulli(p=.75)),
            Y <= 1,     Sample(X, bernoulli(p=1)),
        ))
    model_ifelse = command_ifelse.interpret()

    # Consecutive grid points delimit the bins (lo, hi]; within a bin the
    # expected probability of X = 1 is the bin's right edge hi.
    grid = [float(point) for point in linspace(0, 1, 5)]
    for model in (model_switch, model_ifelse):
        assert model.get_symbols() == {X, Y}
        actual = model.logprob(X << {1})
        expected = logsumexp([
            model.logprob((lo < Y) <= hi) + log(hi)
            for lo, hi in zip(grid[:-1], grid[1:])
        ])
        assert allclose(actual, expected)
def test_simple_model_eq():
    """A Switch over ``range`` and an IfElse on point events must agree."""
    command_switch = Sequence(
        Sample(Y, randint(low=0, high=4)),
        Switch(Y, range(0, 4), lambda i:
            Sample(X, bernoulli(p=1 / (i + 1)))))
    model_switch = command_switch.interpret()

    command_ifelse = Sequence(
        Sample(Y, randint(low=0, high=4)),
        IfElse(
            Y << {0},  Sample(X, bernoulli(p=1 / (0 + 1))),
            Y << {1},  Sample(X, bernoulli(p=1 / (1 + 1))),
            Y << {2},  Sample(X, bernoulli(p=1 / (2 + 1))),
            Y << {3},  Sample(X, bernoulli(p=1 / (3 + 1))),
        ))
    model_ifelse = command_ifelse.interpret()

    for model in (model_switch, model_ifelse):
        assert model.get_symbols() == {X, Y}
        # Expected P(X = 1) = sum over i in 0..3 of (1/4) * 1/(i + 1).
        actual = model.logprob(X << {1})
        expected = logsumexp([-log(4) - log(i + 1) for i in range(4)])
        assert allclose(actual, expected)
def test_logpdf_mixture_real_continuous_continuous():
    """logpdf of a mixture of two continuous leaves is a weighted logsumexp."""
    spe = X >> (.3*norm() | .7*gamma(a=1))
    actual = spe.logpdf({X: .5})
    weighted_terms = [
        log(weight) + spe.children[position].logpdf({X: 0.5})
        for position, weight in enumerate((.3, .7))
    ]
    assert allclose(actual, logsumexp(weighted_terms))
def test_logpdf_mixture_real_continuous_discrete():
    """logpdf of a mixture of a continuous and a discrete leaf over X.

    The final ``assert False`` makes this test fail unconditionally, with
    the message 'Invalid base measure addition'.  It is presumably paired
    with an expected-failure marker outside this block -- confirm.
    """
    spe = X >> (.3*norm() | .7*poisson(mu=1))
    # logpdf is queried with an assignment dict here, as in every other
    # logpdf call in this file (including the per-child calls below); the
    # original passed the event ``X << {.5}`` instead.
    assert allclose(
        spe.logpdf({X: .5}),
        logsumexp([
            log(.3) + spe.children[0].logpdf({X: 0.5}),
            log(.7) + spe.children[1].logpdf({X: 0.5}),
        ]))
    # Deliberate unconditional failure; see the docstring.
    assert False, 'Invalid base measure addition'
def test_randint():
    """Bounds of a randint leaf and conditioning on a complement set."""
    X = Id('X')
    spe = X >> randint(low=0, high=5)
    # The asserted upper bound is high - 1.
    assert spe.xl == 0
    assert spe.xu == 4
    assert spe.logprob(X < 5) == spe.logprob(X <= 4) == 0

    # (X + 1) in {1, 4} means X in {0, 3}; the negation leaves {1, 2, 4}.
    spe_condition = spe.condition(~((X + 1) << {1, 4}))
    assert isinstance(spe_condition, SumSPE)

    # The order of the two components is not assumed: locate the one that
    # starts at 1.
    if spe_condition.children[0].xl == 1:
        idx_low, idx_high = 0, 1
    else:
        idx_low, idx_high = 1, 0
    low_part = spe_condition.children[idx_low]
    high_part = spe_condition.children[idx_high]

    assert low_part.xl == 1
    assert low_part.xu == 2
    assert high_part.xl == 4
    assert high_part.xu == 4
    assert allclose(low_part.logprob(X << {1, 2}), 0)
    assert allclose(high_part.logprob(X << {4}), 0)
def test_sum_simplify_nested_sum_1():
    """A SumSPE nested inside a SumSPE is flattened into its parent."""
    X = Id('X')
    inner = SumSPE(
        [X >> norm(loc=0, scale=1), X >> norm(loc=0, scale=2)],
        [log(0.4), log(0.6)])
    leaf = X >> gamma(loc=0, a=1)
    children = [inner, leaf]
    spe = SumSPE(children, [log(0.7), log(0.3)])

    assert spe.size() == 4
    # The inner mixture's leaves are promoted to direct children ...
    assert spe.children == (
        inner.children[0],
        inner.children[1],
        leaf,
    )
    # ... with the outer and inner log weights added together.
    assert allclose(spe.weights[0], log(0.7) + log(0.4))
    assert allclose(spe.weights[1], log(0.7) + log(0.6))
    assert allclose(spe.weights[2], log(0.3))
def test_simple_parse_nominal():
    """A weighted mix of a discrete and a nominal distribution applied to X."""
    assert isinstance(.7 * choice({'a': .1, 'b': .9}), DistributionMix)

    mix = .3 * bernoulli(p=.1) | .7 * choice({'a': .1, 'b': .9})
    spe = mix(X)
    assert isinstance(spe, SumSPE)
    assert allclose(spe.weights, [log(.3), log(.7)])

    discrete_child = spe.children[0]
    nominal_child = spe.children[1]
    assert isinstance(discrete_child, DiscreteLeaf)
    assert isinstance(nominal_child, NominalLeaf)
    assert discrete_child.support == Interval(0, 1)
    assert nominal_child.support == FiniteNominal('a', 'b')
def test_logpdf_lexicographic_both():
    """logpdf of a two-branch mixture where both branches account for the point.

    The asserted density is the logsumexp of both branches' terms, and
    constraining on the assignment yields a SumSPE.
    """
    spe = (
        .75*(X >> norm() & Y >> atomic(loc=0) & Z >> discrete({1:.2, 2:.8}))
        | .25*(X >> discrete({1:.5, 2:.5}) & Y >> norm() & Z >> atomic(loc=2))
    )

    # Lexicographic, Mix
    assignment = {X: 1, Y: 0, Z: 2}
    actual = spe.logpdf(assignment)
    branch_terms = [
        log(.75) + norm().dist.logpdf(1) + log(1) + log(.8),
        log(.25) + log(.5) + norm().dist.logpdf(0) + log(1),
    ]
    assert allclose(actual, logsumexp(branch_terms))
    assert isinstance(spe.constrain(assignment), SumSPE)
def test_simple_parse_real():
    """A weighted mix of discrete and continuous distributions applied to X."""
    assert isinstance(.3 * bernoulli(p=.1), DistributionMix)

    mix = .3 * bernoulli(p=.1) | .5 * norm() | .2 * poisson(mu=7)
    spe = mix(X)
    assert isinstance(spe, SumSPE)
    assert allclose(spe.weights, [log(.3), log(.5), log(.2)])

    # Leaf types follow the order of the mixture terms.
    expected_types = (DiscreteLeaf, ContinuousLeaf, DiscreteLeaf)
    for position, leaf_type in enumerate(expected_types):
        assert isinstance(spe.children[position], leaf_type)

    assert spe.children[0].support == Interval(0, 1)
    assert spe.children[1].support == Interval(-oo, oo)
    assert spe.children[2].support == Interval(0, oo)
def test_complex_model():
    """A For loop whose body samples Z[i], then X[i] under an IfElse.

    Slow for larger number of repetitions:
    https://github.com/probcomp/sum-product-dsl/issues/43
    """
    def loop_body(i):
        # X[i] uses p = 1/(i+1) when Y names i or Z[i] is 0, else p = 0.1.
        return Sequence(
            Sample(Z[i], bernoulli(p=0.1)),
            IfElse(
                (Y << {str(i)}) | (Z[i] << {0}),
                    Sample(X[i], bernoulli(p=1/(i+1))),
                Otherwise,
                    Sample(X[i], bernoulli(p=0.1))))

    command = Sequence(
        Sample(Y, choice({'0': .2, '1': .2, '2': .2, '3': .2, '4': .2})),
        For(0, 3, loop_body))
    model = command.interpret()
    assert allclose(model.prob(Y << {'0'}), 0.2)
def test_complex_model_reorder():
    """Two For loops: one samples every Z[i], the next samples every X[i]."""
    def sample_z(i):
        return Sample(Z[i], bernoulli(p=0.1))

    def sample_x(i):
        # The first two branches use the same distribution for X[i].
        return IfElse(
            Y << {str(i)},
                Sample(X[i], bernoulli(p=1/(i+1))),
            Z[i] << {0},
                Sample(X[i], bernoulli(p=1/(i+1))),
            Otherwise,
                Sample(X[i], bernoulli(p=0.1)))

    command = Sequence(
        Sample(Y, choice({'0': .2, '1': .2, '2': .2, '3': .2, '4': .2})),
        For(0, 3, sample_z),
        For(0, 3, sample_x))
    model = command.interpret()
    assert allclose(model.prob(Y << {'0'}), 0.2)
def test_ifelse_zero_conditions():
    """IfElse with five branches yields a three-component model.

    The conditions test Y against -1, 0, 1, 2 and 3; the assertions expect
    only the components for Y in {0, 1, 2}, in that order.
    """
    command = Sequence(
        Sample(Y, randint(low=0, high=3)),
        IfElse(
            Y << {-1},  Transform(X, Y**(-1)),
            Y << {0},   Sample(X, bernoulli(p=1)),
            Y << {1},   Transform(X, Y),
            Y << {2},   Transform(X, Y**2),
            Y << {3},   Transform(X, Y**3),
        ))
    model = command.interpret()
    assert len(model.children) == 3
    assert len(model.weights) == 3
    # Component k is weighted by the probability that Y equals k.
    for k in range(3):
        assert allclose(model.weights[k], model.logprob(Y << {k}))
def test_sum_normal_nominal():
    """Queries and conditioning on a mixture of a normal and a nominal leaf.

    The mixture puts weight 4/7 on a standard normal and 3/7 on a nominal
    distribution over {'low': 3/10, 'high': 7/10}, both over X.
    """
    X = Id('X')
    children = [
        X >> norm(loc=0, scale=1),
        X >> choice({'low': Fraction(3, 10), 'high': Fraction(7, 10)}),
    ]
    weights = [log(Fraction(4, 7)), log(Fraction(3, 7))]
    spe = SumSPE(children, weights)

    # Events that touch only one of the two components.
    assert allclose(
        spe.logprob(X < 0),
        log(Fraction(4, 7)) + log(Fraction(1, 2)))
    assert allclose(
        spe.logprob(X << {'low'}),
        log(Fraction(3, 7)) + log(Fraction(3, 10)))

    # The semantics of ~(X<<{'low'}) are (X << String and X != 'low')
    assert allclose(
        spe.logprob(~(X << {'low'})),
        spe.logprob((X << {'high'})))
    assert allclose(
        spe.logprob((X << FiniteNominal(b=True)) & ~(X << {'low'})),
        spe.logprob((X << FiniteNominal(b=True)) & (X << {'high'})))

    # A real constraint and a nominal constraint cannot hold together,
    # while their disjunction adds the two probabilities.
    assert isinf_neg(spe.logprob((X < 0) & (X << {'low'})))
    assert allclose(
        spe.logprob((X < 0) | (X << {'low'})),
        logsumexp([spe.logprob(X < 0), spe.logprob(X << {'low'})]))

    # 'a' is not one of the nominal values.
    assert isinf_neg(spe.logprob(X << {'a'}))
    assert allclose(
        spe.logprob(~(X << {'a'})),
        spe.logprob(X << {'low', 'high'}))

    # A transformed real event is scored by the normal component only.
    assert allclose(
        spe.logprob(X**2 < 9),
        log(Fraction(4, 7)) + spe.children[0].logprob(X**2 < 9))

    # Conditioning on a purely real event leaves a single continuous leaf.
    spe_condition = spe.condition(X**2 < 9)
    assert isinstance(spe_condition, ContinuousLeaf)
    assert spe_condition.support == Interval.open(-3, 3)

    # Conditioning on a real-or-nominal event keeps both components.
    spe_condition = spe.condition((X**2 < 9) | X << {'low'})
    assert isinstance(spe_condition, SumSPE)
    real_part = spe_condition.children[0]
    nominal_part = spe_condition.children[1]
    assert real_part.support == Interval.open(-3, 3)
    assert nominal_part.support == FiniteNominal('low', 'high')
    assert isinf_neg(nominal_part.logprob(X << {'high'}))

    # Excluding 'high' is the same condition as selecting 'low'.
    assert spe_condition == spe.condition((X**2 < 9) | ~(X << {'high'}))

    assert allclose(spe.logprob((X < oo) | ~(X << {'1'})), 0)