def test_solver_21__ci_():
    """Solve 1 <= log(x**3 - 3*x + 3) < 5, whose endpoints are only
    available by numerical approximation of the cubic's roots.

    Reference values:
    https://www.wolframalpha.com/input/?i=1+%3C%3D+log%28x**3+-+3x+%2B+3%29+%3C+5
    """
    expected = Union(
        Interval(
            -1.777221448430427630375448631016427343692,
            0.09418455242255462832154474245589911789464),
        Interval.Ropen(
            1.683036896007873002053903888560528225797,
            5.448658707897512189124586716091172798465))
    log_cubic = Log(Y**3 - 3*Y + 3)
    result = ((1 <= log_cubic) & (log_cubic < 5)).solve()
    assert isinstance(result, Union)
    assert len(result.args) == 2

    lhs, rhs = result.args
    # Union does not promise an ordering; pick pieces by sign of left end.
    negative_piece = lhs if lhs.a < 0 else rhs
    positive_piece = lhs if lhs.a > 0 else rhs

    # Negative piece is closed on both ends.
    assert not negative_piece.left_open
    assert not negative_piece.right_open
    # Positive piece is closed on the left, open on the right.
    assert not positive_piece.left_open
    assert positive_piece.right_open

    # Endpoints are irrational, so compare numerically.
    for got, want in ((negative_piece, expected.args[0]),
                      (positive_piece, expected.args[1])):
        assert allclose(float(got.a), float(want.a))
        assert allclose(float(got.b), float(want.b))
def test_parse_2_closed():
    """Parse a conjunction of two non-strict inequalities into EventAnd."""
    # (log(x) >= 2) & (x <= exp(2))
    expr = (Log(X) >= 2) & (X <= sympy.exp(2))
    event = EventAnd([
        EventInterval(Log(Y), Interval(2, oo)),
        EventInterval(Y, Interval(-oo, sympy.exp(2))),
    ])
    assert expr == event
def test_solver_13():
    """2*sqrt(|x|**2) - 3 > 10 holds exactly when |x| > 13/2."""
    cutoff = Rat(13, 2)
    expected = Union(
        Interval.open(-oo, -cutoff),
        Interval.open(cutoff, oo))
    assert ((2*Sqrt(abs(Y)**2) - 3) > 10).solve() == expected
def test_solver_14():
    """x**2 > 10 holds strictly outside [-sqrt(10), sqrt(10)]."""
    root = sympy.sqrt(10)
    expected = Union(Interval.open(-oo, -root), Interval.open(root, oo))
    assert (Y**2 > 10).solve() == expected
def test_parse_2_open():
    """Parse a conjunction of two strict inequalities into EventAnd."""
    # (log(x) > 2) & (x < exp(2))
    expr = (Log(X) > 2) & (X < sympy.exp(2))
    event = EventAnd([
        EventInterval(Log(Y), Interval.open(2, oo)),
        EventInterval(Y, Interval.open(-oo, sympy.exp(2))),
    ])
    assert expr == event
def test_solver_20():
    """log(x**2 - 3) < 5 requires 3 < x**2 < 3 + exp(5)."""
    inner = sympy.sqrt(3)
    outer = sympy.sqrt(3 + sympy.exp(5))
    expected = Union(
        Interval.open(-outer, -inner),
        Interval.open(inner, outer))
    assert (Log(Y**2 - 3) < 5).solve() == expected
def test_solver_6():
    """x**2 - 2*x > 10 holds strictly outside [1 - sqrt(11), 1 + sqrt(11)]."""
    offset = sympy.sqrt(11)
    expected = Union(
        Interval.open(-oo, 1 - offset),
        Interval.open(1 + offset, oo))
    assert ((Y**2 - 2*Y) > 10).solve() == expected
def test_parse_5_lopen():
    """Parse a conjunction mixing a non-strict and a strict inequality."""
    # (2*x + 10 <= 4) & (x + 10 > 3)
    expr = ((2*X + 10) <= 4) & (X + 10 > 3)
    event = EventAnd([
        EventInterval(Poly(Y, [10, 2]), Interval(-oo, 4)),
        EventInterval(Poly(Y, [10, 1]), Interval.open(3, oo)),
    ])
    assert expr == event
def test_parse_27_piecewise_many():
    """Indicator-weighted branches parse into a two-piece Piecewise."""
    parsed = (Y < 0)*(Y**2) + (0 <= Y)*Y**(1, 2)
    branches = [
        Poly(Y, [0, 0, 1]),
        Radical(Y, 2),
    ]
    conditions = [
        EventInterval(Y, Interval.open(-oo, 0)),
        EventInterval(Y, Interval(0, oo)),
    ]
    assert parsed == Piecewise(branches, conditions)
def test_parse_6():
    """Parse (u**2 - 2*u) > 10 for u = x and for u = exp(x)."""
    # Pairs of (base used to build the expression, subexpr expected in Poly).
    for base, subexpr in [(X, Y), (Exp(X), Exp(X))]:
        parsed = (base**2 - 2*base) > 10
        expected = EventInterval(
            Poly(subexpr, [0, -2, 1]), Interval.open(10, oo))
        assert parsed == expected
def test_solver_23_reciprocal_lte():
    """Solve c / x against positive and negative thresholds, for c in {1, 3}."""
    for c in [1, 3]:
        # Positive thresholds: every negative x qualifies, as does every
        # positive x at or beyond c / threshold.
        # c / x < 10
        solution = Interval.Ropen(-oo, 0) | Interval.Lopen(Rat(c, 10), oo)
        event = (c / Y) < 10
        assert event.solve() == solution
        # c / x <= 10
        solution = Interval.Ropen(-oo, 0) | Interval(Rat(c, 10), oo)
        event = (c / Y) <= 10
        assert event.solve() == solution
        # c / x <= sqrt(2)
        solution = Interval.Ropen(-oo, 0) | Interval(c / sympy.sqrt(2), oo)
        event = (c / Y) <= sympy.sqrt(2)
        assert event.solve() == solution

        # Negative thresholds: only negative x between -c / |threshold|
        # and 0 qualify.
        # c / x < -10
        solution = Interval.open(-Rat(c, 10), 0)
        event = (c / Y) < -10
        assert event.solve() == solution
        # c / x <= -10
        solution = Interval.Ropen(-Rat(c, 10), 0)
        event = (c / Y) <= -10
        assert event.solve() == solution
        # c / x <= -sqrt(2)
        solution = Interval.Ropen(-c / sympy.sqrt(2), 0)
        event = (c / Y) <= -sympy.sqrt(2)
        assert event.solve() == solution
def test_solver_18():
    """3*(x**(1/7))**4 - 3*(x**(1/7))**2 <= 9, and its complement.

    With w = x**(2/7) the bound is w**2 - w - 3 <= 0, so
    w <= 1/2 + sqrt(13)/2 and x <= w**(7/2).
    """
    upper = (Rat(1, 2) + sympy.sqrt(13)/2)**(Rat(7, 2))
    expected = Interval(0, upper)
    seventh_root = Y**(Rat(1, 7))
    quartic = 3*seventh_root**4 - 3*seventh_root**2
    bounded = quartic <= 9
    assert bounded.solve() == expected
    # The complement is everything strictly above the right endpoint.
    assert (~bounded).solve() == Interval.open(expected.right, oo)
def test_Interval_in():
    """Interval membership respects endpoint openness and infinities."""
    # A reversed interval is rejected at construction.
    with pytest.raises(Exception):
        Interval(3, 1)
    members = [
        (1, Interval(0, 1)),
        (1, Interval.Lopen(0, 1)),
        (0, Interval(0, 1)),
        (0, Interval.Ropen(0, 1)),
        (10, Interval(-inf, inf)),
    ]
    non_members = [
        (1, Interval.Ropen(0, 1)),
        (0, Interval.Lopen(0, 1)),
        # Infinite endpoints are never members themselves.
        (inf, Interval(-inf, inf)),
        (-inf, Interval(-inf, 0)),
    ]
    for value, interval in members:
        assert value in interval
    for value, interval in non_members:
        assert value not in interval
def test_simple_parse_real():
    """A weighted mixture of real distributions parses into a SumSPE."""
    assert isinstance(.3 * bernoulli(p=.1), DistributionMix)
    mixture = .3 * bernoulli(p=.1) | .5 * norm() | .2 * poisson(mu=7)
    spe = mixture(X)
    assert isinstance(spe, SumSPE)
    # Weights are stored in log space.
    assert allclose(spe.weights, [log(.3), log(.5), log(.2)])
    expected_leaves = [
        (DiscreteLeaf, Interval(0, 1)),
        (ContinuousLeaf, Interval(-oo, oo)),
        (DiscreteLeaf, Interval(0, oo)),
    ]
    for index, (leaf_type, support) in enumerate(expected_leaves):
        child = spe.children[index]
        assert isinstance(child, leaf_type)
        assert child.support == support
def test_solver_27_piecewise_many():
    """Solve events over a two-branch piecewise expression."""
    piecewise = (Y < 0)*(Y**2) + (0 <= Y)*Y**(Rat(1, 2))
    # Each branch contributes one preimage of 3.
    assert sorted((piecewise << {3}).solve()) == [-sympy.sqrt(3), 9]
    # Positive everywhere except at the breakpoint 0.
    assert (0 < piecewise).solve() == Union(
        Interval.open(-oo, 0),
        Interval.open(0, oo))
    # TODO: Consider banning the restriction of a function
    # to a segment outside of its domain.
    restricted = (Y < 0)*Y**(Rat(1, 2))
    assert (restricted < 1).solve() is EmptySet
def test_solver_24_negative_power_Rat():
    """Solve one- and two-sided bounds on x**(-1/3)."""
    def inv_cbrt():
        # Build a fresh x**(-1/3) for each event.
        return Y**Rat(-1, 3)

    cases = [
        # Case 1.
        (inv_cbrt() < 6, Interval.Lopen(Rat(1, 216), oo)),
        # Case 2: the lower bound -1 is implied on the domain.
        ((-1 < inv_cbrt()) < 6, Interval.Lopen(Rat(1, 216), oo)),
        # Case 3.
        (5 <= inv_cbrt(), Interval.Lopen(0, Rat(1, 125))),
        # Case 4.
        ((5 <= inv_cbrt()) < 6, Interval.Lopen(Rat(1, 216), Rat(1, 125))),
    ]
    for event, expected in cases:
        assert event.solve() == expected
def test_parse_26_piecewise_one_expr_compound_event():
    """An expression times a compound event becomes a one-piece Piecewise."""
    punctured = (Y**2)*((Y < 0) | (0 < Y))
    assert punctured == Piecewise(
        [Poly(Y, [0, 0, 1])],
        [EventOr([
            EventInterval(Y, Interval.open(-oo, 0)),
            EventInterval(Y, Interval.open(0, oo)),
        ])])

    # The complement of (3, 4] is (-oo, 3] | (4, oo).
    outside = (Y**2)*(~((3 < Y) <= 4))
    assert outside == Piecewise(
        [Poly(Y, [0, 0, 1])],
        [EventOr([
            EventInterval(Y, Interval(-oo, 3)),
            EventInterval(Y, Interval.open(4, oo)),
        ])])
def test_sum_normal_nominal():
    """Mix a standard normal and a nominal choice over the same variable."""
    X = Id('X')
    normal_leaf = X >> norm(loc=0, scale=1)
    nominal_leaf = X >> choice({
        'low': Fraction(3, 10),
        'high': Fraction(7, 10),
    })
    spe = SumSPE(
        [normal_leaf, nominal_leaf],
        [log(Fraction(4, 7)), log(Fraction(3, 7))])

    # Real-valued queries are answered by the normal component (weight 4/7).
    assert allclose(
        spe.logprob(X < 0),
        log(Fraction(4, 7)) + log(Fraction(1, 2)))
    # Nominal queries are answered by the choice component (weight 3/7).
    assert allclose(
        spe.logprob(X << {'low'}),
        log(Fraction(3, 7)) + log(Fraction(3, 10)))
    # The semantics of ~(X<<{'low'}) are (X << String and X != 'low')
    assert allclose(
        spe.logprob(~(X << {'low'})),
        spe.logprob((X << {'high'})))
    assert allclose(
        spe.logprob((X << FiniteNominal(b=True)) & ~(X << {'low'})),
        spe.logprob((X << FiniteNominal(b=True)) & (X << {'high'})))
    # A real event and a nominal event cannot hold together.
    assert isinf_neg(spe.logprob((X < 0) & (X << {'low'})))
    assert allclose(
        spe.logprob((X < 0) | (X << {'low'})),
        logsumexp([spe.logprob(X < 0), spe.logprob(X << {'low'})]))
    # A string outside the choice support has zero probability.
    assert isinf_neg(spe.logprob(X << {'a'}))
    assert allclose(
        spe.logprob(~(X << {'a'})),
        spe.logprob(X << {'low', 'high'}))
    assert allclose(
        spe.logprob(X**2 < 9),
        log(Fraction(4, 7)) + spe.children[0].logprob(X**2 < 9))

    # Conditioning on a purely real event keeps only the normal leaf.
    conditioned = spe.condition(X**2 < 9)
    assert isinstance(conditioned, ContinuousLeaf)
    assert conditioned.support == Interval.open(-3, 3)

    # Conditioning on a mixed event keeps both components.
    conditioned = spe.condition((X**2 < 9) | (X << {'low'}))
    assert isinstance(conditioned, SumSPE)
    assert conditioned.children[0].support == Interval.open(-3, 3)
    assert conditioned.children[1].support == FiniteNominal('low', 'high')
    assert isinf_neg(conditioned.children[1].logprob(X << {'high'}))
    assert conditioned == spe.condition((X**2 < 9) | ~(X << {'high'}))

    # This disjunction covers the whole space.
    assert allclose(spe.logprob((X < oo) | ~(X << {'1'})), 0)
def test_condition():
    """Condition the no-latents student model on events over GPA."""
    model = model_no_latents()
    GPA = Id('GPA')
    model_condition = model.condition((GPA << {4}) | (GPA << {10}))
    assert len(model_condition.children) == 2
    assert model_condition.children[0].support == Interval.Ropen(4, 5)
    assert model_condition.children[1].support == Interval.Ropen(10, 11)

    # Build 0 < GPA < 4 explicitly. A Python chained comparison
    # `0 < GPA < 4` expands to `(0 < GPA) and (GPA < 4)`, which evaluates
    # to just `GPA < 4` and silently drops the lower bound.
    model_condition = model.condition((0 < GPA) < 4)
    assert len(model_condition.children) == 2
    assert model_condition.children[0].support \
        == model_condition.children[1].support
    assert allclose(
        model_condition.children[0].logprob(GPA < 1),
        model_condition.children[1].logprob(GPA < 1))
def test_event_containment_union():
    """X << (union of sets) expands into a disjunction of containment events."""
    assert (X << (Interval(0, 1) | Interval(2, 3))) \
        == (((0 <= X) <= 1) | ((2 <= X) <= 3))
    assert (X << (FiniteReal(0, 1) | Interval(2, 3))) \
        == ((X << {0, 1}) | ((2 <= X) <= 3))
    assert (X << FiniteNominal('a', b=True)) \
        == EventFiniteNominal(X, FiniteNominal('a', b=True))
    assert (X << EmptySet) == EventFiniteReal(X, EmptySet)

    # Ordering of subexprs is not guaranteed, so only test membership.
    mixed = X << (Interval(0, 1) | (FiniteReal(1.5) | FiniteNominal('a')))
    assert len(mixed.subexprs) == 3
    expected_parts = [
        EventInterval(X, Interval(0, 1)),
        EventFiniteReal(X, FiniteReal(1.5)),
        EventFiniteNominal(X, FiniteNominal('a')),
    ]
    for part in expected_parts:
        assert part in mixed.subexprs
def test_parse_18():
    """Parse polynomials in x**(1/7) and in 3*|x**(1/7)|."""
    # 3*(x**(1/7))**4 - 3*(x**(1/7))**2 <= 9
    Z = X**Rat(1, 7)
    expr = 3*Z**4 - 3*Z**2
    expr_prime = Poly(Radical(Y, 7), [0, 0, -3, 0, 3])
    assert expr == expr_prime

    event = EventInterval(expr_prime, Interval(-oo, 9))
    assert (expr <= 9) == event

    event_not = EventInterval(expr_prime, Interval.open(9, oo))
    assert ~(expr <= 9) == event_not

    # (3*|x**(1/7)|)**4 - (3*|x**(1/7)|)**2 is P**4 - P**2 with
    # P = 3*|Z|, so the outer coefficients are [0, 0, -1, 0, 1]
    # (the leading coefficient is 1, not 3).
    expr = (3*Abs(Z))**4 - (3*Abs(Z))**2
    expr_prime = Poly(Poly(Abs(Z), [0, 3]), [0, 0, -1, 0, 1])
    assert expr == expr_prime
def test_solver_finite_symbolic():
    """Solve containment events mixing nominal (string) and real values."""
    # Solve a symbolic Identity.
    event = Y << {'a', 'b'}
    assert event.solve() == FiniteNominal('a', 'b')
    # Complement the Identity.
    event = ~(Y << {'a', 'b'})
    assert event.solve() == FiniteNominal('a', 'b', b=True)
    # Transform can never be symbolic.
    event = Y**2 << {'a', 'b'}
    assert event.solve() is EmptySet
    # Complement the Transform: every string satisfies it.
    event = ~(Y**2 << {'a', 'b'})
    assert event.solve() == FiniteNominal(b=True)
    # Solve Identity mixed.
    event = Y << {9, 'a', '7'}
    assert event.solve() == Union(
        FiniteReal(9),
        FiniteNominal('a', '7'))
    # Solve Transform mixed: the strings are dropped.
    event = Y**2 << {9, 'a', 'b'}
    assert event.solve() == FiniteReal(-3, 3)
    # Solve a disjunction.
    event = (Y << {'a', 'b'}) | (Y << {'c'})
    assert event.solve() == FiniteNominal('a', 'b', 'c')
    # Solve a conjunction with intersection.
    event = (Y << {'a', 'b'}) & (Y << {'b', 'c'})
    assert event.solve() == FiniteNominal('b')
    # Solve a conjunction with no intersection.
    event = (Y << {'a', 'b'}) & (Y << {'c'})
    assert event.solve() is EmptySet
    # Solve a conjunction with complement.
    event = (Y << {'a', 'b'}) & ~(Y << {'c'})
    assert event.solve() == FiniteNominal('a', 'b')
    # Solve a disjunction with complement.
    event = (Y << {'a', 'b'}) | ~(Y << {'c'})
    assert event.solve() == FiniteNominal('c', b=True)
    # Union of interval and symbolic.
    event = (Y**2 <= 9) | (Y << {'a'})
    assert event.solve() == Interval(-3, 3) | FiniteNominal('a')
    # Union of interval and not symbolic.
    event = (Y**2 <= 9) | ~(Y << {'a'})
    assert event.solve() == Interval(-3, 3) | FiniteNominal('a', b=True)
    # Intersection of interval and symbolic.
    event = (Y**2 <= 9) & (Y << {'a'})
    assert event.solve() is EmptySet
    # Intersection of interval and not symbolic.
    event = (Y**2 <= 9) & ~(Y << {'a'})
    assert event.solve() == EmptySet
def test_solver_10():
    """exp(sqrt(log(x))) > -5 holds on the whole domain x >= 1."""
    # Sympy hangs on this test.
    always_true = Exp(Sqrt(Log(Y))) > -5
    assert always_true.solve() == Interval(1, oo)
def test_event_containment_real():
    """X << (real set) builds interval/finite events; complements split the line."""
    assert (X << Interval(0, 10)) == EventInterval(X, Interval(0, 10))
    # FiniteReal, list and set inputs are all accepted as finite reals.
    for values in (FiniteReal(0, 10), [0, 10], {0, 10}):
        assert (X << values) == EventFiniteReal(X, FiniteReal(0, 10))
    # with pytest.raises(ValueError):
    #     X << {1, None}
    assert (X << {1, 2}) == EventFiniteReal(X, {1, 2})
    # Removing the points 1 and 2 leaves three pieces of the real line.
    assert ~(X << {1, 2}) == EventOr([
        EventInterval(X, Interval.Ropen(-oo, 1)),
        EventInterval(X, Interval.open(1, 2)),
        EventInterval(X, Interval.Lopen(2, oo)),
    ])
    # https://github.com/probcomp/sum-product-dsl/issues/22
    # and of EventBasic does not yet perform simplifications.
    double_negation = ~(~(X << {1, 2}))
    assert double_negation == \
        ((1 <= X) & ((X <= 1) | (2 <= X)) & (X <= 2))
def test_FiniteNominal_and():
    """Intersection (&) of FiniteNominal with other set types.

    FN(..., b=True) behaves as the complement of the listed strings.
    """
    # FiniteNominal & FiniteNominal.
    assert FN('a', 'b') & EmptySet is EmptySet
    assert FN('a', 'b') & FN('c') is EmptySet
    assert FN('a', 'b', 'c') & FN('a') == FN('a')
    assert FN('a', 'b', 'c') & FN(b=True) == FN('a', 'b', 'c')
    assert FN('a', 'b', 'c', b=True) & FN('a') is EmptySet
    assert FN('a', 'b', 'c', b=True) & FN('d', 'a', 'b') == FN('d')
    assert FN('a', 'b', 'c', b=True) & FN('d') == FN('d')
    assert FN('a', 'b', 'c') & FN('a', b=True) == FN('b', 'c')
    assert FN('a', 'b', 'c') & FN('d', 'a', 'b', b=True) == FN('c')
    assert FN('a', 'b', 'c') & FN('d', b=True) == FN('a', 'b', 'c')
    assert FN('a', 'b', 'c', b=True) & FN('d', b=True) \
        == FN('a', 'b', 'c', 'd', b=True)
    assert FN('a', 'b', 'c', b=True) & FN('a', b=True) \
        == FN('a', 'b', 'c', b=True)
    assert FN(b=True) & FN(b=True) == FN(b=True)
    # FiniteReal: strings and reals never overlap.
    assert FN('a') & FR(1) is EmptySet
    assert FN('a', b=True) & FR(1) is EmptySet
    # Interval
    assert FN('a') & Interval(0, 1) is EmptySet
def test_parse_12():
    """Parse 2*sqrt(|x|) - 3 > 10 into a Poly over Sqrt(Abs(.))."""
    parsed = 2*Sqrt(Abs(X)) - 3
    assert parsed == Poly(Sqrt(Abs(Y)), [-3, 2])
    assert (parsed > 10) == EventInterval(parsed, Interval.open(10, oo))
def test_parse_15():
    """Parse ((x**4)**(1/7)) < 9 into a Radical of a Pow."""
    parsed = (X**4)**Rat(1, 7)
    expected = Radical(Pow(Y, 4), 7)
    assert parsed == expected
    assert (parsed < 9) == EventInterval(expected, Interval.open(-oo, 9))
def test_solver_19():
    """Solve a disjunction of two bounds on a quartic in x**(1/7).

    Substituting w = x**(2/7) reduces each bound to a quadratic in w:
    3*w**2 - 3*w <= 9 gives w <= 1/2 + sqrt(13)/2, and
    3*w**2 - 3*w > 11 gives w > 1/2 + sqrt(141)/6. Raising to the 7/2
    power maps these thresholds back to x.
    """
    # (3*(x**(1/7))**4 - 3*(x**(1/7))**2 <= 9)
    #   or (3*(x**(1/7))**4 - 3*(x**(1/7))**2 > 11)
    solution = Union(
        Interval(0, (Rat(1, 2) + sympy.sqrt(13)/2)**(Rat(7, 2))),
        Interval.open((Rat(1, 2) + sympy.sqrt(141)/6)**(Rat(7, 2)), oo))
    Z = Y**(Rat(1, 7))
    expr = 3*Z**4 - 3*Z**2
    event = (expr <= 9) | (expr > 11)
    answer = event.solve()
    assert answer == solution
    # The complement is 9 < expr <= 11: open on the left, closed on the right.
    interval = (~event).solve()
    assert interval == Interval.Lopen(
        solution.args[0].right,
        solution.args[1].left)
def test_parse_16():
    """Parse (x**(1/7))**4 < 9 with both Rational and tuple exponents."""
    for seventh_root in (X**Rat(1, 7), X**(1, 7)):
        power = seventh_root**4
        assert power == Pow(Radical(Y, 7), 4)
        assert (power < 9) == EventInterval(power, Interval.open(-oo, 9))
def test_solver_23_reciprocal_range():
    """Solve two-sided bounds on reciprocals of linear and quadratic terms."""
    lo = -1 / sympy.sqrt(3)
    hi = 1 / sympy.sqrt(3)
    cases = [
        (((-3 < 1/Y) <= -1),
         Interval.Ropen(-1, -Rat(1, 3))),
        (((-3 < 1/(2*Y-1)) < -1),
         Interval.open(0, Rat(1, 3))),
        (((-3 < 1/(2*(abs(Y)**2)-1)) <= -1),
         Interval.open(lo, hi)),
        # A strict upper bound excludes x = 0, where the reciprocal is -1.
        (((-3 < 1/(2*(abs(Y)**2)-1)) < -1),
         Union(Interval.open(lo, 0), Interval.open(0, hi))),
    ]
    for event, expected in cases:
        assert event.solve() == expected