コード例 #1
0
def test_solver_21__ci_():
    """Solve 1 <= log(x**3 - 3*x + 3) < 5 and compare with reference endpoints.

    Per the original author's note this can only be solved by numerical
    approximation of roots, so the endpoints are compared with ``allclose``
    rather than by exact equality.
    """
    # Reference values:
    # https://www.wolframalpha.com/input/?i=1+%3C%3D+log%28x**3+-+3x+%2B+3%29+%3C+5
    expected = Union(
        Interval(
            -1.777221448430427630375448631016427343692,
            0.09418455242255462832154474245589911789464),
        Interval.Ropen(
            1.683036896007873002053903888560528225797,
            5.448658707897512189124586716091172798465))

    log_cubic = Log(Y**3 - 3*Y + 3)
    solved = ((1 <= log_cubic) & (log_cubic < 5)).solve()
    assert isinstance(solved, Union)
    assert len(solved.args) == 2

    # The order of the two pieces is not relied upon: pick them by the sign
    # of the left endpoint of the first piece.
    head, tail = solved.args[0], solved.args[1]
    lower = head if head.a < 0 else tail
    upper = head if head.a > 0 else tail

    # Lower piece: closed on both sides.
    assert not lower.left_open
    assert not lower.right_open
    assert allclose(float(lower.a), float(expected.args[0].a))
    assert allclose(float(lower.b), float(expected.args[0].b))

    # Upper piece: closed on the left, open on the right.
    assert not upper.left_open
    assert upper.right_open
    assert allclose(float(upper.a), float(expected.args[1].a))
    assert allclose(float(upper.b), float(expected.args[1].b))
コード例 #2
0
def test_parse_2_closed():
    """A conjunction of non-strict inequalities parses to an EventAnd."""
    # (log(x) >= 2) & (x <= exp(2))
    parsed = (Log(X) >= 2) & (X <= sympy.exp(2))
    log_at_least_two = EventInterval(Log(Y), Interval(2, oo))
    at_most_exp_two = EventInterval(Y, Interval(-oo, sympy.exp(2)))
    assert parsed == EventAnd([log_at_least_two, at_most_exp_two])
コード例 #3
0
def test_solver_13():
    """Solve 2*sqrt(|x|**2) - 3 > 10."""
    bound = Rat(13, 2)
    expected = Union(
        Interval.open(-oo, -bound),
        Interval.open(bound, oo))
    solved = ((2*Sqrt(abs(Y)**2) - 3) > 10).solve()
    assert solved == expected
コード例 #4
0
def test_solver_14():
    """Solve x**2 > 10."""
    root = sympy.sqrt(10)
    expected = Union(
        Interval.open(-oo, -root),
        Interval.open(root, oo))
    solved = (Y**2 > 10).solve()
    assert solved == expected
コード例 #5
0
def test_parse_2_open():
    """A conjunction of strict inequalities parses to an EventAnd."""
    # (log(x) > 2) & (x < exp(2))
    parsed = (Log(X) > 2) & (X < sympy.exp(2))
    log_above_two = EventInterval(Log(Y), Interval.open(2, oo))
    below_exp_two = EventInterval(Y, Interval.open(-oo, sympy.exp(2)))
    assert parsed == EventAnd([log_above_two, below_exp_two])
コード例 #6
0
def test_solver_20():
    """Solve log(x**2 - 3) < 5."""
    outer = sympy.sqrt(3 + sympy.exp(5))
    inner = sympy.sqrt(3)
    expected = Union(
        Interval.open(-outer, -inner),
        Interval.open(inner, outer))
    solved = (Log(Y**2 - 3) < 5).solve()
    assert solved == expected
コード例 #7
0
def test_solver_6():
    """Solve (x**2 - 2*x) > 10."""
    root = sympy.sqrt(11)
    expected = Union(
        Interval.open(-oo, 1 - root),
        Interval.open(1 + root, oo))
    solved = ((Y**2 - 2*Y) > 10).solve()
    assert solved == expected
コード例 #8
0
def test_parse_5_lopen():
    """A non-strict upper bound and a strict lower bound parse to an EventAnd."""
    # (2*x + 10 <= 4) & (x + 10 > 3)
    parsed = ((2*X + 10) <= 4) & (X + 10 > 3)
    at_most_four = EventInterval(Poly(Y, [10, 2]), Interval(-oo, 4))
    above_three = EventInterval(Poly(Y, [10, 1]), Interval.open(3, oo))
    assert parsed == EventAnd([at_most_four, above_three])
コード例 #9
0
def test_parse_27_piecewise_many():
    """A sum of event-restricted transforms parses to a two-branch Piecewise."""
    parsed = (Y < 0)*(Y**2) + (0 <= Y)*Y**((1, 2))
    branches = [Poly(Y, [0, 0, 1]), Radical(Y, 2)]
    guards = [
        EventInterval(Y, Interval.open(-oo, 0)),
        EventInterval(Y, Interval(0, oo)),
    ]
    assert parsed == Piecewise(branches, guards)
コード例 #10
0
def test_parse_6():
    """Parse a quadratic inequality over the identity and over Exp(X)."""
    # (x**2 - 2*x) > 10
    parsed = (X**2 - 2*X) > 10
    assert parsed == EventInterval(
        Poly(Y, [0, -2, 1]),
        Interval.open(10, oo))

    # (exp(x)**2 - 2*exp(x)) > 10
    parsed = (Exp(X)**2 - 2*Exp(X)) > 10
    assert parsed == EventInterval(
        Poly(Exp(X), [0, -2, 1]),
        Interval.open(10, oo))
コード例 #11
0
def test_solver_23_reciprocal_lte():
    """Solve upper-bounded reciprocal events c / Y for c in {1, 3}."""
    for c in [1, 3]:
        # Each entry pairs an event with the set its solution must equal.
        cases = [
            # Positive bounds.
            # c / Y < 10
            ((c / Y) < 10,
             Interval.Ropen(-oo, 0) | Interval.Lopen(Rat(c, 10), oo)),
            # c / Y <= 10
            ((c / Y) <= 10,
             Interval.Ropen(-oo, 0) | Interval(Rat(c, 10), oo)),
            # c / Y <= sqrt(2)
            ((c / Y) <= sympy.sqrt(2),
             Interval.Ropen(-oo, 0) | Interval(c / sympy.sqrt(2), oo)),
            # Negative bounds.
            # c / Y < -10
            ((c / Y) < -10,
             Interval.open(-Rat(c, 10), 0)),
            # c / Y <= -10
            ((c / Y) <= -10,
             Interval.Ropen(-Rat(c, 10), 0)),
            # c / Y <= -sqrt(2)
            ((c / Y) <= -sympy.sqrt(2),
             Interval.Ropen(-c / sympy.sqrt(2), 0)),
        ]
        for event, expected in cases:
            assert event.solve() == expected
コード例 #12
0
def test_solver_18():
    """Solve 3*(x**(1/7))**4 - 3*(x**(1/7))**2 <= 9 and its negation."""
    upper = (Rat(1, 2) + sympy.sqrt(13)/2)**(Rat(7, 2))
    expected = Interval(0, upper)

    seventh_root = Y**(Rat(1, 7))
    event = (3*seventh_root**4 - 3*seventh_root**2) <= 9
    assert event.solve() == expected

    # The negated event must solve to everything strictly above the bound.
    complement = (~event).solve()
    assert complement == Interval.open(expected.right, oo)
コード例 #13
0
def test_Interval_in():
    """Check Interval construction validation and endpoint membership."""
    # Endpoints given in decreasing order must raise.
    with pytest.raises(Exception):
        Interval(3, 1)

    contained = [
        (1, Interval(0, 1)),
        (1, Interval.Lopen(0, 1)),
        (0, Interval(0, 1)),
        (0, Interval.Ropen(0, 1)),
        (10, Interval(-inf, inf)),
    ]
    for value, interval in contained:
        assert value in interval

    # Open endpoints and infinities are not members.
    excluded = [
        (1, Interval.Ropen(0, 1)),
        (0, Interval.Lopen(0, 1)),
        (inf, Interval(-inf, inf)),
        (-inf, Interval(-inf, 0)),
    ]
    for value, interval in excluded:
        assert value not in interval
コード例 #14
0
def test_simple_parse_real():
    """A weighted mix of three distributions applied to X builds a SumSPE."""
    assert isinstance(.3 * bernoulli(p=.1), DistributionMix)

    mixture = .3 * bernoulli(p=.1) | .5 * norm() | .2 * poisson(mu=7)
    spe = mixture(X)
    assert isinstance(spe, SumSPE)
    assert allclose(spe.weights, [log(.3), log(.5), log(.2)])

    # Expected leaf type and support for each child, in mixture order.
    expected_leaves = [
        (DiscreteLeaf, Interval(0, 1)),
        (ContinuousLeaf, Interval(-oo, oo)),
        (DiscreteLeaf, Interval(0, oo)),
    ]
    for index, (leaf_type, support) in enumerate(expected_leaves):
        child = spe.children[index]
        assert isinstance(child, leaf_type)
        assert child.support == support
コード例 #15
0
def test_solver_27_piecewise_many():
    """Solve events over the piecewise transform Y**2 (Y < 0), Y**(1/2) (0 <= Y)."""
    piecewise = (Y < 0)*(Y**2) + (0 <= Y)*Y**(Rat(1, 2))

    preimage_of_three = (piecewise << {3}).solve()
    assert sorted(preimage_of_three) == [-sympy.sqrt(3), 9]

    strictly_positive = (0 < piecewise).solve()
    assert strictly_positive == Union(
        Interval.open(-oo, 0),
        Interval.open(0, oo))

    # TODO: Consider banning the restriction of a function
    # to a segment outside of its domain.
    restricted = (Y < 0)*Y**(Rat(1, 2))
    assert (restricted < 1).solve() is EmptySet
コード例 #16
0
def test_solver_24_negative_power_Rat():
    """Solve bounds on Y**(-1/3), a negative rational power."""
    cases = [
        # Case 1: upper bound only.
        (Y**Rat(-1, 3) < 6,
         Interval.Lopen(Rat(1, 216), oo)),
        # Case 2: a negative lower bound gives the same answer as case 1.
        ((-1 < Y**Rat(-1, 3)) < 6,
         Interval.Lopen(Rat(1, 216), oo)),
        # Case 3: lower bound only.
        (5 <= Y**Rat(-1, 3),
         Interval.Lopen(0, Rat(1, 125))),
        # Case 4: both bounds.
        ((5 <= Y**Rat(-1, 3)) < 6,
         Interval.Lopen(Rat(1, 216), Rat(1, 125))),
    ]
    for event, expected in cases:
        assert event.solve() == expected
コード例 #17
0
def test_parse_26_piecewise_one_expr_compound_event():
    """A transform times a compound event parses to a one-branch Piecewise."""
    # Disjunction of two strict inequalities.
    parsed = (Y**2)*((Y < 0) | (0 < Y))
    guard = EventOr([
        EventInterval(Y, Interval.open(-oo, 0)),
        EventInterval(Y, Interval.open(0, oo)),
    ])
    assert parsed == Piecewise([Poly(Y, [0, 0, 1])], [guard])

    # Negation of the range 3 < Y <= 4.
    parsed = (Y**2)*(~((3 < Y) <= 4))
    guard = EventOr([
        EventInterval(Y, Interval(-oo, 3)),
        EventInterval(Y, Interval.open(4, oo)),
    ])
    assert parsed == Piecewise([Poly(Y, [0, 0, 1])], [guard])
コード例 #18
0
def test_sum_normal_nominal():
    """Exercise logprob and condition on a SumSPE of a normal and a nominal leaf."""
    X = Id('X')
    normal_leaf = X >> norm(loc=0, scale=1)
    nominal_leaf = X >> choice({
        'low': Fraction(3, 10),
        'high': Fraction(7, 10)
    })
    log_w_normal = log(Fraction(4, 7))
    log_w_nominal = log(Fraction(3, 7))

    spe = SumSPE([normal_leaf, nominal_leaf], [log_w_normal, log_w_nominal])

    # Events that touch only one of the two leaves.
    assert allclose(spe.logprob(X < 0),
                    log_w_normal + log(Fraction(1, 2)))
    assert allclose(spe.logprob(X << {'low'}),
                    log_w_nominal + log(Fraction(3, 10)))

    # The semantics of ~(X<<{'low'}) are (X << String and X != 'low')
    lp_not_low = spe.logprob(~(X << {'low'}))
    lp_high = spe.logprob((X << {'high'}))
    assert allclose(lp_not_low, lp_high)

    lp_nominal_not_low = spe.logprob(
        (X << FiniteNominal(b=True)) & ~(X << {'low'}))
    lp_nominal_high = spe.logprob(
        (X << FiniteNominal(b=True)) & (X << {'high'}))
    assert allclose(lp_nominal_not_low, lp_nominal_high)

    # A real constraint and a nominal constraint cannot both hold.
    assert isinf_neg(spe.logprob((X < 0) & (X << {'low'})))

    lp_union = spe.logprob((X < 0) | (X << {'low'}))
    lp_parts = logsumexp([spe.logprob(X < 0),
                          spe.logprob(X << {'low'})])
    assert allclose(lp_union, lp_parts)

    # 'a' is not one of the nominal outcomes.
    assert isinf_neg(spe.logprob(X << {'a'}))
    assert allclose(spe.logprob(~(X << {'a'})),
                    spe.logprob(X << {'low', 'high'}))

    assert allclose(spe.logprob(X**2 < 9),
                    log_w_normal + spe.children[0].logprob(X**2 < 9))

    # Conditioning on a purely real event leaves only the continuous leaf.
    conditioned = spe.condition(X**2 < 9)
    assert isinstance(conditioned, ContinuousLeaf)
    assert conditioned.support == Interval.open(-3, 3)

    # Conditioning on a mixed event keeps both children.
    conditioned = spe.condition((X**2 < 9) | (X << {'low'}))
    assert isinstance(conditioned, SumSPE)
    real_child, nominal_child = conditioned.children[0], conditioned.children[1]
    assert real_child.support == Interval.open(-3, 3)
    assert nominal_child.support == FiniteNominal('low', 'high')
    assert isinf_neg(nominal_child.logprob(X << {'high'}))

    assert conditioned == spe.condition((X**2 < 9) | ~(X << {'high'}))
    assert allclose(spe.logprob((X < oo) | ~(X << {'1'})), 0)
コード例 #19
0
ファイル: test_indian_gpa.py プロジェクト: probcomp/sppl
def test_condition():
    """Condition the no-latents GPA model on two atoms and on an open range."""
    model = model_no_latents()
    GPA = Id('GPA')

    # `<<` binds tighter than `|`; the parentheses only make that explicit.
    model_condition = model.condition((GPA << {4}) | (GPA << {10}))
    assert len(model_condition.children) == 2
    assert model_condition.children[0].support == Interval.Ropen(4, 5)
    assert model_condition.children[1].support == Interval.Ropen(10, 11)

    # The range event is built as (0 < GPA) < 4, the nested form used by the
    # other range events in this test suite. A chained `0 < GPA < 4` is
    # evaluated by Python as `(0 < GPA) and (GPA < 4)`, so the lower bound
    # would only be truth-tested and never become part of the event.
    model_condition = model.condition((0 < GPA) < 4)
    assert len(model_condition.children) == 2
    assert model_condition.children[0].support \
        == model_condition.children[1].support
    assert allclose(model_condition.children[0].logprob(GPA < 1),
                    model_condition.children[1].logprob(GPA < 1))
コード例 #20
0
def test_event_containment_union():
    """Containment (``<<``) in a union of sets parses to the matching events."""
    two_intervals = X << (Interval(0, 1) | Interval(2, 3))
    assert two_intervals == (((0 <= X) <= 1) | ((2 <= X) <= 3))

    finite_or_interval = X << (FiniteReal(0, 1) | Interval(2, 3))
    assert finite_or_interval == ((X << {0, 1}) | ((2 <= X) <= 3))

    nominal_complement = X << FiniteNominal('a', b=True)
    assert nominal_complement \
        == EventFiniteNominal(X, FiniteNominal('a', b=True))

    assert (X << EmptySet) == EventFiniteReal(X, EmptySet)

    # Ordering is not guaranteed, so check membership instead of position.
    mixed = X << (Interval(0,1) | (FiniteReal(1.5) | FiniteNominal('a')))
    assert len(mixed.subexprs) == 3
    expected_parts = [
        EventInterval(X, Interval(0,1)),
        EventFiniteReal(X, FiniteReal(1.5)),
        EventFiniteNominal(X, FiniteNominal('a')),
    ]
    for part in expected_parts:
        assert part in mixed.subexprs
コード例 #21
0
def test_parse_18():
    """Parse 3*(x**(1/7))**4 - 3*(x**(1/7))**2 <= 9 and its negation."""
    seventh_root = X**Rat(1, 7)
    expr = 3*seventh_root**4 - 3*seventh_root**2
    expr_prime = Poly(Radical(Y, 7), [0, 0, -3, 0, 3])
    assert expr == expr_prime

    at_most_nine = EventInterval(expr_prime, Interval(-oo, 9))
    assert (expr <= 9) == at_most_nine

    above_nine = EventInterval(expr_prime, Interval.open(9, oo))
    assert ~(expr <= 9) == above_nine

    # NOTE(review): the two values below are constructed but never compared;
    # confirm whether an assertion is missing (or was cut off) here.
    expr = (3*Abs(seventh_root))**4 - (3*Abs(seventh_root))**2
    expr_prime = Poly(Poly(Abs(seventh_root), [0, 3]), [0, 0, -1, 0, 3])
コード例 #22
0
def test_solver_finite_symbolic():
    """Solve containment events whose value sets include nominal (string) items."""
    # Identity in a nominal set.
    assert (Y << {'a', 'b'}).solve() == FiniteNominal('a', 'b')
    # Complement of the identity event.
    assert (~(Y << {'a', 'b'})).solve() == FiniteNominal('a', 'b', b=True)

    # Transform can never be symbolic.
    assert (Y**2 << {'a', 'b'}).solve() is EmptySet
    # Complement of the transform event.
    assert (~(Y**2 << {'a', 'b'})).solve() == FiniteNominal(b=True)

    # Identity with mixed real and nominal values.
    mixed_identity = (Y << {9, 'a', '7'}).solve()
    assert mixed_identity == Union(
        FiniteReal(9),
        FiniteNominal('a', '7'))
    # Transform with mixed real and nominal values.
    assert (Y**2 << {9, 'a', 'b'}).solve() == FiniteReal(-3, 3)

    # Disjunction.
    disjunction = (Y << {'a', 'b'}) | (Y << {'c'})
    assert disjunction.solve() == FiniteNominal('a', 'b', 'c')
    # Conjunction with a common element.
    overlapping = (Y << {'a', 'b'}) & (Y << {'b', 'c'})
    assert overlapping.solve() == FiniteNominal('b')
    # Conjunction with no common element.
    disjoint = (Y << {'a', 'b'}) & (Y << {'c'})
    assert disjoint.solve() is EmptySet
    # Conjunction with a complement.
    and_complement = (Y << {'a', 'b'}) & ~(Y << {'c'})
    assert and_complement.solve() == FiniteNominal('a', 'b')
    # Disjunction with a complement.
    or_complement = (Y << {'a', 'b'}) | ~(Y << {'c'})
    assert or_complement.solve() == FiniteNominal('c', b=True)

    # Union of interval and symbolic.
    interval_or_nominal = (Y**2 <= 9) | (Y << {'a'})
    assert interval_or_nominal.solve() == Interval(-3, 3) | FiniteNominal('a')
    # Union of interval and not symbolic.
    interval_or_not_nominal = (Y**2 <= 9) | ~(Y << {'a'})
    assert interval_or_not_nominal.solve() \
        == Interval(-3, 3) | FiniteNominal('a', b=True)
    # Intersection of interval and symbolic.
    interval_and_nominal = (Y**2 <= 9) & (Y << {'a'})
    assert interval_and_nominal.solve() is EmptySet
    # Intersection of interval and not symbolic.
    interval_and_not_nominal = (Y**2 <= 9) & ~(Y << {'a'})
    assert interval_and_not_nominal.solve() == EmptySet
コード例 #23
0
def test_solver_10():
    """Solve exp(sqrt(log(x))) > -5."""
    # Original author's note: Sympy hangs on this test.
    expected = Interval(1, oo)
    solved = (Exp(Sqrt(Log(Y))) > -5).solve()
    assert solved == expected
コード例 #24
0
def test_event_containment_real():
    """Containment (``<<``) in real sets parses to interval / finite-real events."""
    in_interval = X << Interval(0, 10)
    assert in_interval == EventInterval(X, Interval(0, 10))

    # A FiniteReal, a list and a set must all parse to the same event.
    for container in (FiniteReal(0, 10), [0, 10], {0, 10}):
        assert (X << container) == EventFiniteReal(X, FiniteReal(0, 10))
    # with pytest.raises(ValueError):
    #     X << {1, None}
    assert (X << {1, 2}) == EventFiniteReal(X, {1, 2})

    # Negating a finite set yields the three open gaps around its points.
    negated = ~(X << {1, 2})
    assert negated == EventOr([
        EventInterval(X, Interval.Ropen(-oo, 1)),
        EventInterval(X, Interval.open(1, 2)),
        EventInterval(X, Interval.Lopen(2, oo)),
    ])

    # https://github.com/probcomp/sum-product-dsl/issues/22
    # and of EventBasic does not yet perform simplifications.
    double_negated = ~(~(X << {1, 2}))
    assert double_negated == \
        ((1 <= X) & ((X <= 1) | (2 <= X)) & (X <= 2))
コード例 #25
0
def test_FiniteNominal_and():
    """Check ``&`` between FiniteNominal sets and other set types."""
    # Intersections that must be the EmptySet singleton (identity check).
    empty_cases = [
        (FN('a', 'b'), EmptySet),
        (FN('a', 'b'), FN('c')),
        (FN('a', 'b', 'c', b=True), FN('a')),
        # FiniteReal
        (FN('a'), FR(1)),
        (FN('a', b=True), FR(1)),
        # Interval
        (FN('a'), Interval(0, 1)),
    ]
    for left, right in empty_cases:
        assert left & right is EmptySet

    # Intersections compared by equality: (left, right, expected).
    equal_cases = [
        (FN('a', 'b', 'c'), FN('a'), FN('a')),
        (FN('a', 'b', 'c'), FN(b=True), FN('a', 'b', 'c')),
        # Same case as the first entry; it appears twice in the original test.
        (FN('a', 'b', 'c'), FN('a'), FN('a')),
        (FN('a', 'b', 'c', b=True), FN('d', 'a', 'b'), FN('d')),
        (FN('a', 'b', 'c', b=True), FN('d'), FN('d')),
        (FN('a', 'b', 'c'), FN('a', b=True), FN('b', 'c')),
        (FN('a', 'b', 'c'), FN('d', 'a', 'b', b=True), FN('c')),
        (FN('a', 'b', 'c'), FN('d', b=True), FN('a', 'b', 'c')),
        (FN('a', 'b', 'c', b=True), FN('d', b=True),
         FN('a', 'b', 'c', 'd', b=True)),
        (FN('a', 'b', 'c', b=True), FN('a', b=True),
         FN('a', 'b', 'c', b=True)),
        (FN(b=True), FN(b=True), FN(b=True)),
    ]
    for left, right, expected in equal_cases:
        assert left & right == expected
コード例 #26
0
def test_parse_12():
    """Parse 2*sqrt(|x|) - 3 > 10."""
    parsed = 2*Sqrt(Abs(X)) - 3
    assert parsed == Poly(Sqrt(Abs(Y)), [-3, 2])

    above_ten = EventInterval(parsed, Interval.open(10, oo))
    assert (parsed > 10) == above_ten
コード例 #27
0
def test_parse_15():
    """Parse ((x**4)**(1/7)) < 9."""
    parsed = ((X**4))**Rat(1, 7)
    target = Radical(Pow(Y, 4), 7)
    assert parsed == target

    below_nine = EventInterval(target, Interval.open(-oo, 9))
    assert (parsed < 9) == below_nine
コード例 #28
0
def test_solver_19():
    """Solve a disjunction of two bounds on 3*(x**(1/7))**4 - 3*(x**(1/7))**2.

    The event is (expr <= 9) | (expr > 11); its negation is solved as well.
    """
    low_cut = (Rat(1, 2) + sympy.sqrt(13)/2)**(Rat(7, 2))
    high_cut = (Rat(1,2) + sympy.sqrt(141)/6)**(Rat(7, 2))
    expected = Union(
        Interval(0, low_cut),
        Interval.open(high_cut, oo))

    seventh_root = Y**(Rat(1, 7))
    quartic = 3*seventh_root**4 - 3*seventh_root**2
    event = (quartic <= 9) | (quartic > 11)
    assert event.solve() == expected

    # The negation is the gap between the two pieces of the solution.
    gap = (~event).solve()
    assert gap == Interval.Lopen(
        expected.args[0].right,
        expected.args[1].left)
コード例 #29
0
def test_parse_16():
    """Parse (x**(1/7))**4 < 9 for both spellings of the 1/7 exponent.

    The exponent is written once as ``Rat(1, 7)`` and once as the tuple
    ``(1, 7)``; both must parse to the same transform and the same event.
    """
    # (x**(1/7))**4 < 9
    expr_prime = Pow(Radical(Y, 7), 4)
    for expr in [((X**Rat(1,7)))**4, (X**(1,7))**4]:
        assert expr == expr_prime

        # Checked inside the loop so that every spelling is covered, rather
        # than only the last value left in the loop variable.
        event = EventInterval(expr, Interval.open(-oo, 9))
        assert (expr < 9) == event
コード例 #30
0
def test_solver_23_reciprocal_range():
    """Solve two-sided bounds on reciprocal transforms."""
    inv_root3 = 1 / sympy.sqrt(3)
    cases = [
        # -3 < 1/Y <= -1
        ((-3 < 1/Y) <= -1,
         Interval.Ropen(-1, -Rat(1, 3))),
        # -3 < 1/(2*Y - 1) < -1
        ((-3 < 1/(2*Y-1)) < -1,
         Interval.open(0, Rat(1, 3))),
        # -3 < 1/(2*|Y|**2 - 1) <= -1
        ((-3 < 1/(2*(abs(Y)**2)-1)) <= -1,
         Interval.open(-inv_root3, inv_root3)),
        # -3 < 1/(2*|Y|**2 - 1) < -1: the strict bound also removes Y = 0.
        ((-3 < 1/(2*(abs(Y)**2)-1)) < -1,
         Union(
             Interval.open(-inv_root3, 0),
             Interval.open(0, inv_root3))),
    ]
    for event, expected in cases:
        assert event.solve() == expected