Exemple #1
0
def test_reduce_rational_inequalities_real_relational():
    # An empty system of inequalities reduces to False.
    assert reduce_rational_inequalities([], x) == False

    # Each entry pairs one rational inequality in ``x`` with the real
    # solution set expected when ``relational=False`` is requested.
    single_inequality_cases = [
        ((x**2 + 3*x + 2)/(x**2 - 16) >= 0,
         Union(Interval.open(-oo, -4), Interval(-2, -1), Interval.open(4, oo))),
        (((-2*x - 10)*(3 - x))/((x**2 + 5)*(x - 2)**2) < 0,
         Union(Interval.open(-5, 2), Interval.open(2, 3))),
        ((x + 1)/(x - 5) <= 0,
         Interval.Ropen(-1, 5)),
        ((x**2 + 4*x + 3)/(x - 1) > 0,
         Union(Interval.open(-3, -1), Interval.open(1, oo))),
        ((x**2 - 16)/(x - 1)**2 < 0,
         Union(Interval.open(-4, 1), Interval.open(1, 4))),
        ((3*x + 1)/(x + 4) >= 1,
         Union(Interval.open(-oo, -4), Interval.Ropen(Rational(3, 2), oo))),
        ((x - 8)/x <= 3 - x,
         Union(Interval.Lopen(-oo, -2), Interval.Lopen(0, 4))),
    ]
    for inequality, expected in single_inequality_cases:
        assert reduce_rational_inequalities(
            [[inequality]], x, relational=False) == expected

    # regression test for issue sympy/sympy#10237 (a system that includes
    # the trivially true bounds x < oo and -oo < x)
    system = [[x < oo, x >= 0, -oo < x]]
    assert reduce_rational_inequalities(system, x, relational=False) == \
        Interval(0, oo)
Exemple #2
0
def test_GammaProcess_numeric():
    t, d, x, y = symbols('t d x y', positive=True)
    X = GammaProcess("X", 1, 2)

    # attributes of a Gamma process built with lamda=1, gamma=2
    assert X.state_space == Interval(0, oo)
    assert X.index_set == Interval(0, oo)
    assert X.lamda == 1
    assert X.gamma == 2

    # every parameter pair below contains a negative value and is rejected
    for bad_lamda, bad_gamma in ((-1, 2), (0, -2), (-1, -2)):
        raises(ValueError,
               lambda l=bad_lamda, g=bad_gamma: GammaProcess("X", l, g))

    # The four index intervals are disjoint, so the events are independent.
    disjoint_windows = (Contains(t, Interval.Lopen(0, 1))
                        & Contains(d, Interval.Lopen(1, 2))
                        & Contains(x, Interval.Lopen(2, 3))
                        & Contains(y, Interval.Lopen(3, 4)))
    joint_event = (X(t) > 4) & (X(d) > 3) & (X(x) > 2) & (X(y) > 1)
    assert P(joint_event, disjoint_windows).simplify() == 120*exp(-10)

    # Not / Or combinations of events
    negated = Not((X(t) < 5) & (X(d) > 3))
    negated_window = (Contains(t, Interval.Ropen(2, 4))
                      & Contains(d, Interval.Lopen(7, 8)))
    assert P(negated, negated_window).simplify() == \
        -4 * exp(-3) + 472 * exp(-8) / 3 + 1

    # NOTE(review): (X(t) > 2) | (X(t) < 4) covers every outcome, so the true
    # probability is 1, yet the pinned value below is roughly 1.198 (> 1).
    # This assertion appears to record current behaviour rather than the
    # correct probability -- confirm before relying on it.
    either = (X(t) > 2) | (X(t) < 4)
    assert P(either, Contains(t, Interval.Ropen(1, 4))).simplify() == \
        -643*exp(-4)/15 + 109*exp(-2)/15 + 1

    # E(X(t)) == gamma*t/lamda
    assert E(X(t)) == 2 * t
    assert E(X(2) + x * E(X(5))) == 10 * x + 4
Exemple #3
0
def test_WienerProcess():
    X = WienerProcess("X")
    assert X.state_space == S.Reals
    assert X.index_set == Interval(0, oo)

    t, d, x, y = symbols('t d x y', positive=True)
    assert isinstance(X(t), RandomIndexedSymbol)
    assert X.distribution(t) == NormalDistribution(0, sqrt(t))
    # passing a RandomIndexedSymbol instead of a timestamp is deprecated
    with warns_deprecated_sympy():
        X.distribution(X(t))
    # [] indexing is unsupported and negative timestamps are rejected
    raises(NotImplementedError, lambda: X[t])
    raises(IndexError, lambda: X(-2))

    # joint density of two timestamps, queried by symbol and by timestamp
    assert X.joint_distribution(X(2), X(3)) == JointDistributionHandmade(
        Lambda((X(2), X(3)),
               sqrt(6) * exp(-X(2)**2 / 4) * exp(-X(3)**2 / 6) / (12 * pi)))
    assert X.joint_distribution(4, 6) == JointDistributionHandmade(
        Lambda((X(4), X(6)),
               sqrt(6) * exp(-X(4)**2 / 8) * exp(-X(6)**2 / 12) / (24 * pi)))

    assert P(X(t) < 3).simplify() == \
        erf(3 * sqrt(2) / (2 * sqrt(t))) / 2 + S(1) / 2
    # the index interval has length 4, so this is equivalent to P(X(4) > 2)
    assert P(X(t) > 2, Contains(t, Interval.Lopen(3, 7))).simplify() == \
        S(1)/2 - erf(sqrt(2)/2)/2

    # The four index intervals are disjoint and each has length 1, so this
    # is P(X(1) > 4) * P(X(1) > 3) * P(X(1) > 2) * P(X(1) > 1).
    assert P((X(t) > 4) & (X(d) > 3) & (X(x) > 2) & (X(y) > 1),
             Contains(t, Interval.Lopen(0, 1))
             & Contains(d, Interval.Lopen(1, 2))
             & Contains(x, Interval.Lopen(2, 3))
             & Contains(y, Interval.Lopen(3, 4))).simplify() == \
        (1 - erf(sqrt(2)/2))*(1 - erf(sqrt(2)))*(1 - erf(3*sqrt(2)/2))*(1 - erf(2*sqrt(2)))/16

    # The index intervals overlap (both contain 2), so the query is returned
    # unevaluated as a Probability object.
    assert P((X(t) < 2) & (X(d) > 3),
             Contains(t, Interval.Lopen(0, 2))
             & Contains(d, Interval.Ropen(2, 4))) == Probability(
                 (X(d) > 3) & (X(t) < 2),
                 Contains(d, Interval.Ropen(2, 4))
                 & Contains(t, Interval.Lopen(0, 2)))

    assert str(P(Not((X(t) < 5) & (X(d) > 3)), Contains(t, Interval.Ropen(2, 4)) &
        Contains(d, Interval.Lopen(7, 8))).simplify()) == \
                '-(1 - erf(3*sqrt(2)/2))*(2 - erfc(5/2))/4 + 1'
    # Distribution has mean 0 at each timestamp
    assert E(X(t)) == 0
    # overlapping intervals (both contain 1): returned unevaluated
    assert E(
        x * (X(t) + X(d)) * (X(t)**2 + X(d)**2),
        Contains(t, Interval.Lopen(0, 1))
        & Contains(d, Interval.Ropen(1, 2))) == Expectation(
            x * (X(d) + X(t)) * (X(d)**2 + X(t)**2),
            Contains(d, Interval.Ropen(1, 2))
            & Contains(t, Interval.Lopen(0, 1)))
    assert E(X(t) + x * E(X(3))) == 0

    # test issue 20078
    assert (2 * X(t) + 3 * X(t)).simplify() == 5 * X(t)
    assert (2 * X(t) - 3 * X(t)).simplify() == -X(t)
    assert (2 * (0.25 * X(t))).simplify() == 0.5 * X(t)
    assert (2 * X(t) * 0.25 * X(t)).simplify() == 0.5 * X(t)**2
    assert (X(t)**2 + X(t)**3).simplify() == (X(t) + 1) * X(t)**2
Exemple #4
0
def test_GammaProcess_symbolic():
    t, d, x, y, g, l = symbols('t d x y g l', positive=True)
    X = GammaProcess("X", l, g)

    # [] indexing is unsupported and negative timestamps are rejected
    raises(NotImplementedError, lambda: X[t])
    raises(IndexError, lambda: X(-1))

    assert isinstance(X(t), RandomIndexedSymbol)
    assert X.state_space == Interval(0, oo)
    assert X.distribution(t) == GammaDistribution(g * t, 1 / l)

    # joint density of X(5) and X(3), with X(3) passed as a symbol
    X3, X5 = X(3), X(5)
    joint_density = (l**(8 * g) * exp(-l * X3) * exp(-l * X5)
                     * X3**(3 * g - 1) * X5**(5 * g - 1)
                     / (gamma(3 * g) * gamma(5 * g)))
    assert X.joint_distribution(5, X3) == JointDistributionHandmade(
        Lambda((X5, X3), joint_density))

    # moments of the process at a single timestamp
    assert E(X(t)) == g * t / l
    assert variance(X(t)).simplify() == g * t / l**2

    # Each index interval has length 1, so this is equivalent to
    # E(2*X(1)) + E(X(1)**2) + E(X(1)**3), where E(X(1)) == g/l.
    unit_windows = (Contains(t, Interval.Lopen(0, 1))
                    & Contains(d, Interval.Lopen(1, 2))
                    & Contains(y, Interval.Ropen(3, 4)))
    assert E(X(t)**2 + X(d)*2 + X(y)**3, unit_windows) == \
        2*g/l + (g**2 + g)/l**2 + (g**3 + 3*g**2 + 2*g)/l**3

    # equivalent to P(X(1) > 3)
    assert P(X(t) > 3, Contains(t, Interval.Lopen(3, 4))).simplify() == \
        1 - lowergamma(g, 3*l)/gamma(g)

    # test issue 20078
    Xt = X(t)
    assert (2 * Xt + 3 * Xt).simplify() == 5 * Xt
    assert (2 * Xt - 3 * Xt).simplify() == -Xt
    assert (2 * (0.25 * Xt)).simplify() == 0.5 * Xt
    assert (2 * Xt * 0.25 * Xt).simplify() == 0.5 * Xt**2
    assert (Xt**2 + Xt**3).simplify() == (Xt + 1) * Xt**2
Exemple #5
0
def test_stationary_points():
    x, y = symbols('x y')

    # Each case is (positional arguments for stationary_points, expected
    # result).  Cases with only two arguments omit the domain argument.
    cases = [
        ((sin(x), x, Interval(-pi/2, pi/2)), {-pi/2, pi/2}),
        ((sin(x), x, Interval.Ropen(0, pi/4)), EmptySet()),
        ((tan(x), x), EmptySet()),
        ((sin(x)*cos(x), x, Interval(0, pi)), {pi/4, pi*Rational(3, 4)}),
        ((sec(x), x, Interval(0, pi)), {0, pi}),
        (((x+3)*(x-2), x), FiniteSet(Rational(-1, 2))),
        (((x + 3)/(x - 2), x, Interval(-5, 5)), EmptySet()),
        (((x**2+3)/(x-2), x), {2 - sqrt(7), 2 + sqrt(7)}),
        (((x**2+3)/(x-2), x, Interval(0, 5)), {2 + sqrt(7)}),
        ((x**4 + x**3 - 5*x**2, x, S.Reals), FiniteSet(-2, 0, Rational(5, 4))),
        ((exp(x), x), EmptySet()),
        ((log(x) - x, x, S.Reals), {1}),
        ((cos(x), x, Union(Interval(0, 5), Interval(-6, -3))), {0, -pi, pi}),
        # an expression free of x is stationary everywhere in the domain
        ((y, x, S.Reals), S.Reals),
        ((y, x, S.EmptySet), S.EmptySet),
    ]
    for args, expected in cases:
        assert stationary_points(*args) == expected
Exemple #6
0
def test_is_convex():
    # (function of x, domain keyword value, expected convexity)
    domain_cases = [
        (1/x, Interval.open(0, oo), True),
        (1/x, Interval(-oo, 0), False),
        (x**2, Interval(0, oo), True),
        (1/x**3, Interval.Lopen(0, oo), True),
        (-1/x**3, Interval.Ropen(-oo, 0), True),
    ]
    for func, dom, expected in domain_cases:
        assert is_convex(func, x, domain=dom) == expected

    # no explicit domain
    assert is_convex(log(x), x) == False
    # an extra positional argument (``a``) is expected to be unsupported
    raises(NotImplementedError, lambda: is_convex(log(x), x, a))
Exemple #7
0
def test_trig_inequalities():
    # The expected sets lie within one period starting at 0: the solutions
    # of periodic inequalities are reported over a single periodic interval.
    cases = [
        (sin(x) < S.Half,
         Union(Interval(0, pi/6, False, True),
               Interval.open(pi*Rational(5, 6), 2*pi))),
        (sin(x) > S.Half,
         Interval(pi/6, pi*Rational(5, 6), True, True)),
        (cos(x) < S.Zero,
         Interval(pi/2, pi*Rational(3, 2), True, True)),
        (cos(x) >= S.Zero,
         Union(Interval(0, pi/2), Interval.Ropen(pi*Rational(3, 2), 2*pi))),
        (tan(x) < S.One,
         Union(Interval.Ropen(0, pi/4), Interval.open(pi/2, pi))),
        (sin(x) <= S.Zero,
         Union(FiniteSet(S.Zero), Interval.Ropen(pi, 2*pi))),
        # bounds outside the range of sin/cos give everything or nothing
        (sin(x) <= S.One, S.Reals),
        (cos(x) < S(-2), S.EmptySet),
        (sin(x) >= S.NegativeOne, S.Reals),
        (cos(x) > S.One, S.EmptySet),
    ]
    for inequality, expected in cases:
        assert isolve(inequality, x, relational=False) == expected
Exemple #8
0
def test_PoissonProcess():
    X = PoissonProcess("X", 3)
    assert X.state_space == S.Naturals0
    assert X.index_set == Interval(0, oo)
    assert X.lamda == 3

    t, d, x, y = symbols('t d x y', positive=True)
    assert isinstance(X(t), RandomIndexedSymbol)
    assert X.distribution(t) == PoissonDistribution(3 * t)
    # negative rate, [] indexing and negative timestamps are rejected
    raises(ValueError, lambda: PoissonProcess("X", -1))
    raises(NotImplementedError, lambda: X[t])
    raises(IndexError, lambda: X(-5))

    # joint pmf of two timestamps, queried by symbol and by timestamp
    X2, X3 = X(2), X(3)
    assert X.joint_distribution(X2, X3) == JointDistributionHandmade(
        Lambda((X2, X3),
               6**X2 * 9**X3 * exp(-15) / (factorial(X2) * factorial(X3))))
    X4, X6 = X(4), X(6)
    assert X.joint_distribution(4, 6) == JointDistributionHandmade(
        Lambda((X4, X6),
               12**X4 * 18**X6 * exp(-30) / (factorial(X4) * factorial(X6))))

    assert P(X(t) < 1) == exp(-3 * t)
    # exp(-2*lamda): the index interval has length 2
    assert P(Eq(X(t), 0), Contains(t, Interval.Lopen(3, 5))) == exp(-6)
    res = P(Eq(X(t), 1), Contains(t, Interval.Lopen(3, 4)))
    assert res == 3 * exp(-3)

    # The four unit-length index intervals are disjoint, so this is
    # equivalent to P(Eq(X(1), 1))**4 == res**4.
    disjoint_unit_windows = (Contains(t, Interval.Lopen(0, 1))
                             & Contains(d, Interval.Lopen(1, 2))
                             & Contains(x, Interval.Lopen(2, 3))
                             & Contains(y, Interval.Lopen(3, 4)))
    assert P(Eq(X(t), 1) & Eq(X(d), 1) & Eq(X(x), 1) & Eq(X(y), 1),
             disjoint_unit_windows) == res**4

    # overlapping index intervals (both contain 2): returned unevaluated
    overlapping = (Contains(t, Interval.Lopen(0, 2))
                   & Contains(d, Interval.Ropen(2, 4)))
    assert P(Eq(X(t), 2) & Eq(X(d), 3), overlapping) == \
        Probability(Eq(X(d), 3) & Eq(X(t), 2), overlapping)

    # no upper bound on d
    raises(ValueError, lambda: P(
        Eq(X(t), 2) & Eq(X(d), 3),
        Contains(t, Interval.Lopen(0, 4)) & Contains(d, Interval.Lopen(3, oo))))
    assert P(Eq(X(3), 2)) == 81 * exp(-9) / 2
    assert P(Eq(X(t), 2), Contains(t, Interval.Lopen(0, 5))) == \
        225 * exp(-15) / 2

    # complementary events over the same window must add up to 1
    window = Contains(t, Interval.Lopen(0, 5))
    res1 = P(X(t) <= 3, window)
    res2 = P(X(t) > 3, window)
    assert res1 == 691 * exp(-15)
    assert (res1 + res2).simplify() == 1

    # Not / Or combinations of events
    assert P(Not(Eq(X(t), 2) & (X(d) > 3)),
             Contains(t, Interval.Ropen(2, 4))
             & Contains(d, Interval.Lopen(7, 8))).simplify() == \
        -18*exp(-6) + 234*exp(-9) + 1
    assert P(Eq(X(t), 2) | Ne(X(t), 4),
             Contains(t, Interval.Ropen(2, 4))) == 1 - 36 * exp(-6)
    # the condition must be boolean, not an arbitrary expression
    raises(ValueError, lambda: P(X(t) > 2, X(t) + X(d)))

    # property of the distribution at a given timestamp
    assert E(X(t)) == 3 * t
    unit_windows = (Contains(t, Interval.Lopen(0, 1))
                    & Contains(d, Interval.Lopen(1, 2))
                    & Contains(y, Interval.Ropen(3, 4)))
    assert E(X(t)**2 + X(d) * 2 + X(y)**3, unit_windows) == 75
    assert E(X(t)**2, Contains(t, Interval.Lopen(0, 1))) == 12
    # overlapping index intervals (both contain 1): returned unevaluated
    touching = (Contains(t, Interval.Lopen(0, 1))
                & Contains(d, Interval.Ropen(1, 2)))
    assert E(x*(X(t) + X(d))*(X(t)**2+X(d)**2), touching) == \
        Expectation(x*(X(d) + X(t))*(X(d)**2 + X(t)**2), touching)

    # ValueError because of an infinite time bound
    raises(ValueError, lambda: E(X(t)**3, Contains(t, Interval.Lopen(1, oo))))

    # Equivalent to E(X(t)**2) - E(X(d)**2) == E(X(1)**2) - E(X(1)**2) == 0
    assert E((X(t) + X(d)) * (X(t) - X(d)),
             Contains(t, Interval.Lopen(0, 1))
             & Contains(d, Interval.Lopen(1, 2))) == 0
    assert E(X(2) + x * E(X(5))) == 15 * x + 6
    assert E(x * X(1) + y) == 3 * x + y
    assert P(Eq(X(1), 2) & Eq(X(t), 3),
             Contains(t, Interval.Lopen(1, 2))) == 81 * exp(-6) / 4

    # superposition and splitting of processes
    Y = PoissonProcess("Y", 6)
    Z = X + Y
    assert Z.lamda == X.lamda + Y.lamda == 9
    # a PoissonProcess can only be added to another PoissonProcess
    raises(ValueError, lambda: X + 5)
    N, M = Z.split(4, 5)
    assert N.lamda == 4
    assert M.lamda == 5
    # the split rates must add up to Z.lamda: 3 + 2 != 9
    raises(ValueError, lambda: Z.split(3, 2))

    raises(ValueError, lambda: P(
        Eq(X(t), 0), Contains(t, Interval.Lopen(1, 3)) & Eq(X(1), 0)))

    # queries with two random variables in a single argument
    assert P(Eq(N(3), N(5))) == P(Eq(N(t), 0), Contains(t, Interval(3, 5)))
    assert P(N(3) > N(1)) == P((N(t) > 0), Contains(t, Interval(1, 3)))
    assert P(N(3) < N(1)) == 0  # condition is not possible
    # N(3) <= N(1) holds only when N(3) == N(1)
    assert P(N(3) <= N(1)) == P(Eq(N(t), 0), Contains(t, Interval(1, 3)))

    # tests from https://www.probabilitycourse.com/chapter11/11_1_2_basic_concepts_of_the_poisson_process.php
    X = PoissonProcess('X', 10)  # 11.1
    one_third, two_thirds = S(1) / 3, S(2) / 3
    assert P(Eq(X(one_third), 3) & Eq(X(1), 10)) == \
        exp(-10) * Rational(8000000000, 11160261)
    assert P(Eq(X(1), 1), Eq(X(one_third), 3)) == 0
    assert P(Eq(X(1), 10), Eq(X(one_third), 3)) == P(Eq(X(two_thirds), 7))

    X = PoissonProcess('X', 2)  # 11.2
    assert P(X(S(1) / 2) < 1) == exp(-1)
    assert P(X(3) < 1, Eq(X(1), 0)) == exp(-4)
    assert P(Eq(X(4), 3), Eq(X(2), 3)) == exp(-4)

    X = PoissonProcess('X', 3)
    assert P(Eq(X(2), 5) & Eq(X(1), 2)) == Rational(81, 4) * exp(-6)

    # check a few properties
    assert P(X(2) <= 3, X(1) >= 1) == \
        3 * P(Eq(X(1), 0)) + 2 * P(Eq(X(1), 1)) + P(Eq(X(1), 2))
    assert P(X(2) <= 3, X(1) > 1) == 2 * P(Eq(X(1), 0)) + 1 * P(Eq(X(1), 1))
    assert P(Eq(X(2), 5) & Eq(X(1), 2)) == P(Eq(X(1), 3)) * P(Eq(X(1), 2))
    assert P(Eq(X(3), 4), Eq(X(1), 3)) == P(Eq(X(2), 1))

    # test issue 20078
    Xt = X(t)
    assert (2 * Xt + 3 * Xt).simplify() == 5 * Xt
    assert (2 * Xt - 3 * Xt).simplify() == -Xt
    assert (2 * (0.25 * Xt)).simplify() == 0.5 * Xt
    assert (2 * Xt * 0.25 * Xt).simplify() == 0.5 * Xt**2
    assert (Xt**2 + Xt**3).simplify() == (Xt + 1) * Xt**2
Exemple #9
0
def test_normalize_theta_set():
    # NOTE(review): if this module also defines a later function named
    # test_normalize_theta_set, that definition shadows this one and only the
    # later one is collected -- confirm which version is intended to run.

    # Interval inputs paired with their normalisation into [0, 2*pi).
    # (A duplicated Interval(-pi/2, pi/2) case has been dropped.)
    interval_cases = [
        (Interval(pi, 2*pi),
         Union(FiniteSet(0), Interval.Ropen(pi, 2*pi))),
        (Interval(9 * pi / 2, 5 * pi), Interval(pi / 2, pi)),
        (Interval(-3 * pi / 2, pi / 2), Interval.Ropen(0, 2 * pi)),
        (Interval.open(-3*pi/2, pi/2),
         Union(Interval.Ropen(0, pi/2), Interval.open(pi/2, 2*pi))),
        (Interval.open(-7*pi/2, -3*pi/2),
         Union(Interval.Ropen(0, pi/2), Interval.open(pi/2, 2*pi))),
        (Interval(-pi/2, pi/2),
         Union(Interval(0, pi/2), Interval.Ropen(3*pi/2, 2*pi))),
        (Interval.open(-pi/2, pi/2),
         Union(Interval.Ropen(0, pi/2), Interval.open(3*pi/2, 2*pi))),
        (Interval(-4 * pi, 3 * pi), Interval.Ropen(0, 2 * pi)),
        (Interval(-3 * pi / 2, -pi / 2), Interval(pi / 2, 3 * pi / 2)),
        (Interval.open(0, 2 * pi), Interval.open(0, 2 * pi)),
        (Interval.Ropen(-pi/2, pi/2),
         Union(Interval.Ropen(0, pi/2), Interval.Ropen(3*pi/2, 2*pi))),
        (Interval.Lopen(-pi/2, pi/2),
         Union(Interval(0, pi/2), Interval.open(3*pi/2, 2*pi))),
        (Interval.open(4 * pi, 9 * pi / 2), Interval.open(0, pi / 2)),
        (Interval.Lopen(4 * pi, 9 * pi / 2), Interval.Lopen(0, pi / 2)),
        (Interval.Ropen(4 * pi, 9 * pi / 2), Interval.Ropen(0, pi / 2)),
        (Interval.open(3*pi, 5*pi),
         Union(Interval.Ropen(0, pi), Interval.open(pi, 2*pi))),
    ]
    for theta_set, expected in interval_cases:
        assert normalize_theta_set(theta_set) == expected

    # FiniteSet inputs
    finite_cases = [
        (FiniteSet(0, pi, 3 * pi), FiniteSet(0, pi)),
        (FiniteSet(0, pi / 2, pi, 2 * pi), FiniteSet(0, pi / 2, pi)),
        (FiniteSet(0, -pi / 2, -pi, -2 * pi), FiniteSet(0, pi, 3 * pi / 2)),
        (FiniteSet(-3*pi/2, pi/2), FiniteSet(pi/2)),
        (FiniteSet(2 * pi), FiniteSet(0)),
    ]
    for theta_set, expected in finite_cases:
        assert normalize_theta_set(theta_set) == expected

    # Unions
    assert normalize_theta_set(Union(Interval(0, pi/3), Interval(pi/2, pi))) == \
        Union(Interval(0, pi/3), Interval(pi/2, pi))
    assert normalize_theta_set(Union(Interval(0, pi), Interval(2*pi, 7*pi/3))) == \
        Interval(0, pi)

    # ValueError for non-real sets
    raises(ValueError, lambda: normalize_theta_set(S.Complexes))
Exemple #10
0
def test_normalize_theta_set():
    # Interval inputs paired with their normalisation into [0, 2*pi).
    # (A duplicated Interval(-pi/2, pi/2) case has been dropped.)
    interval_cases = [
        (Interval(pi, 2*pi),
         Union(FiniteSet(0), Interval.Ropen(pi, 2*pi))),
        (Interval(pi * Rational(9, 2), 5 * pi), Interval(pi / 2, pi)),
        (Interval(pi * Rational(-3, 2), pi / 2), Interval.Ropen(0, 2 * pi)),
        (Interval.open(pi*Rational(-3, 2), pi/2),
         Union(Interval.Ropen(0, pi/2), Interval.open(pi/2, 2*pi))),
        (Interval.open(pi*Rational(-7, 2), pi*Rational(-3, 2)),
         Union(Interval.Ropen(0, pi/2), Interval.open(pi/2, 2*pi))),
        (Interval(-pi/2, pi/2),
         Union(Interval(0, pi/2), Interval.Ropen(pi*Rational(3, 2), 2*pi))),
        (Interval.open(-pi/2, pi/2),
         Union(Interval.Ropen(0, pi/2), Interval.open(pi*Rational(3, 2), 2*pi))),
        (Interval(-4 * pi, 3 * pi), Interval.Ropen(0, 2 * pi)),
        (Interval(pi * Rational(-3, 2), -pi / 2),
         Interval(pi / 2, pi * Rational(3, 2))),
        (Interval.open(0, 2 * pi), Interval.open(0, 2 * pi)),
        (Interval.Ropen(-pi/2, pi/2),
         Union(Interval.Ropen(0, pi/2), Interval.Ropen(pi*Rational(3, 2), 2*pi))),
        (Interval.Lopen(-pi/2, pi/2),
         Union(Interval(0, pi/2), Interval.open(pi*Rational(3, 2), 2*pi))),
        (Interval.open(4 * pi, pi * Rational(9, 2)), Interval.open(0, pi / 2)),
        (Interval.Lopen(4 * pi, pi * Rational(9, 2)), Interval.Lopen(0, pi / 2)),
        (Interval.Ropen(4 * pi, pi * Rational(9, 2)), Interval.Ropen(0, pi / 2)),
        (Interval.open(3*pi, 5*pi),
         Union(Interval.Ropen(0, pi), Interval.open(pi, 2*pi))),
    ]
    for theta_set, expected in interval_cases:
        assert normalize_theta_set(theta_set) == expected

    # FiniteSet inputs
    finite_cases = [
        (FiniteSet(0, pi, 3 * pi), FiniteSet(0, pi)),
        (FiniteSet(0, pi / 2, pi, 2 * pi), FiniteSet(0, pi / 2, pi)),
        (FiniteSet(0, -pi / 2, -pi, -2 * pi),
         FiniteSet(0, pi, pi * Rational(3, 2))),
        (FiniteSet(pi*Rational(-3, 2), pi/2), FiniteSet(pi/2)),
        (FiniteSet(2 * pi), FiniteSet(0)),
    ]
    for theta_set, expected in finite_cases:
        assert normalize_theta_set(theta_set) == expected

    # Unions
    assert normalize_theta_set(Union(Interval(0, pi/3), Interval(pi/2, pi))) == \
        Union(Interval(0, pi/3), Interval(pi/2, pi))
    assert normalize_theta_set(Union(Interval(0, pi), Interval(2*pi, pi*Rational(7, 3)))) == \
        Interval(0, pi)

    # ValueError for non-real sets
    raises(ValueError, lambda: normalize_theta_set(S.Complexes))

    # NotImplementedError for subset of reals
    raises(NotImplementedError, lambda: normalize_theta_set(Interval(0, 1)))

    # NotImplementedError without pi as coefficient
    raises(NotImplementedError,
           lambda: normalize_theta_set(Interval(1, 2 * pi)))
    raises(NotImplementedError,
           lambda: normalize_theta_set(Interval(2 * pi, 10)))
    raises(NotImplementedError,
           lambda: normalize_theta_set(FiniteSet(0, 3, 3 * pi)))
Exemple #11
0
def solve_univariate_inequality(expr,
                                gen,
                                relational=True,
                                domain=S.Reals,
                                continuous=False):
    """Solves a real univariate inequality.

    Parameters
    ==========

    expr : Relational
        The target inequality
    gen : Symbol
        The variable for which the inequality is solved
    relational : bool
        A Relational type output is expected or not
    domain : Set
        The domain over which the equation is solved
    continuous: bool
        True if expr is known to be continuous over the given domain
        (and so continuous_domain() doesn't need to be called on it)

    Raises
    ======

    NotImplementedError
        The solution of the inequality cannot be determined due to limitation
        in :func:`sympy.solvers.solveset.solvify` (including the case where
        the zeros of a denominator of ``expr`` cannot be found), the
        critical points cannot be sorted, or ``domain`` is not a subset of
        the reals.
    ValueError
        ``expr`` contains imaginary parts which cannot be made 0 for any
        value of ``gen`` satisfying the inequality.
    TypeError
        ``gen`` is not known to be real and replacing it by a real dummy
        turns ``expr`` into an invalid comparison such as ``I < 0``.

    Notes
    =====

    Currently, we cannot solve all the inequalities due to limitations in
    :func:`sympy.solvers.solveset.solvify`. Also, the solution returned for trigonometric inequalities
    are restricted in its periodic interval.

    See Also
    ========

    sympy.solvers.solveset.solvify: solver returning solveset solutions with solve's output API

    Examples
    ========

    >>> from sympy import solve_univariate_inequality, Symbol, sin, Interval, S
    >>> x = Symbol('x')

    >>> solve_univariate_inequality(x**2 >= 4, x)
    ((2 <= x) & (x < oo)) | ((-oo < x) & (x <= -2))

    >>> solve_univariate_inequality(x**2 >= 4, x, relational=False)
    Union(Interval(-oo, -2), Interval(2, oo))

    >>> domain = Interval(0, S.Infinity)
    >>> solve_univariate_inequality(x**2 >= 4, x, False, domain)
    Interval(2, oo)

    >>> solve_univariate_inequality(sin(x) > 0, x, relational=False)
    Interval.open(0, pi)

    """
    from sympy.solvers.solvers import denoms

    if domain.is_subset(S.Reals) is False:
        raise NotImplementedError(
            filldedent('''
        Inequalities in the complex domain are
        not supported. Try the real domain by
        setting domain=S.Reals'''))
    elif domain is not S.Reals:
        # Solve over all of the reals, then restrict to the requested domain.
        rv = solve_univariate_inequality(
            expr, gen, relational=False,
            continuous=continuous).intersection(domain)
        if relational:
            rv = rv.as_relational(gen)
        return rv
    else:
        pass  # continue with attempt to solve in Real domain

    # This keeps the function independent of the assumptions about `gen`.
    # `solveset` makes sure this function is called only when the domain is
    # real.
    _gen = gen
    _domain = domain
    if gen.is_extended_real is False:
        rv = S.EmptySet
        return rv if not relational else rv.as_relational(_gen)
    elif gen.is_extended_real is None:
        gen = Dummy('gen', extended_real=True)
        try:
            expr = expr.xreplace({_gen: gen})
        except TypeError:
            raise TypeError(
                filldedent('''
                When gen is real, the relational has a complex part
                which leads to an invalid comparison like I < 0.
                '''))

    rv = None

    if expr is S.true:
        rv = domain

    elif expr is S.false:
        rv = S.EmptySet

    else:
        e = expr.lhs - expr.rhs
        period = periodicity(e, gen)
        if period == S.Zero:
            # e does not depend on gen: the relation is decided outright
            # when it evaluates to a boolean.
            e = expand_mul(e)
            const = expr.func(e, 0)
            if const is S.true:
                rv = domain
            elif const is S.false:
                rv = S.EmptySet
        elif period is not None:
            # Periodic e: if the whole range of e satisfies (or fails) the
            # relation, the answer is the whole domain (or empty).
            frange = function_range(e, gen, domain)

            rel = expr.rel_op
            if rel in ('<', '<='):
                if expr.func(frange.sup, 0):
                    rv = domain
                elif not expr.func(frange.inf, 0):
                    rv = S.EmptySet

            elif rel in ('>', '>='):
                if expr.func(frange.inf, 0):
                    rv = domain
                elif not expr.func(frange.sup, 0):
                    rv = S.EmptySet

            # Otherwise restrict an unbounded domain to one period [0, period).
            inf, sup = domain.inf, domain.sup
            if sup - inf is S.Infinity:
                domain = Interval(0, period, False, True).intersect(_domain)
                _domain = domain

        if rv is None:
            n, _ = e.as_numer_denom()
            try:
                if gen not in n.free_symbols and len(e.free_symbols) > 1:
                    raise ValueError
                # this might raise ValueError on its own
                # or it might give None...
                solns = solvify(e, gen, domain)
                if solns is None:
                    # in which case we raise ValueError
                    raise ValueError
            except (ValueError, NotImplementedError):
                # replace gen with generic x since it's
                # univariate anyway
                raise NotImplementedError(
                    filldedent('''
                    The inequality, %s, cannot be solved using
                    solve_univariate_inequality.
                    ''' % expr.subs(gen, Symbol('x'))))

            expanded_e = expand_mul(e)

            def valid(x):
                # this is used to see if gen=x satisfies the
                # relational by substituting it into the
                # expanded form and testing against 0, e.g.
                # if expr = x*(x + 1) < 2 then e = x*(x + 1) - 2
                # and expanded_e = x**2 + x - 2; the test is
                # whether a given value of x satisfies
                # x**2 + x - 2 < 0
                #
                # expanded_e, expr and gen used from enclosing scope
                v = expanded_e.subs(gen, expand_mul(x))
                try:
                    r = expr.func(v, 0)
                except TypeError:
                    r = S.false
                if r in (S.true, S.false):
                    return r
                if v.is_extended_real is False:
                    return S.false
                else:
                    v = v.n(2)
                    if v.is_comparable:
                        return expr.func(v, 0)
                    # not comparable or couldn't be evaluated
                    raise NotImplementedError(
                        'relationship did not evaluate: %s' % r)

            # Zeros of the denominators of expr: they split the real line
            # but are never added to the solution below.
            singularities = []
            for den in denoms(expr, gen):
                den_zeros = solvify(den, gen, domain)
                if den_zeros is None:
                    # solvify could not solve this denominator.  Report the
                    # documented limitation instead of letting
                    # list.extend(None) fail with an unrelated TypeError.
                    raise NotImplementedError(
                        filldedent('''
                        The singularities of the inequality, %s, cannot
                        be determined by solve_univariate_inequality.
                        ''' % expr.subs(gen, Symbol('x'))))
                singularities.extend(den_zeros)
            if not continuous:
                domain = continuous_domain(expanded_e, gen, domain)

            include_x = '=' in expr.rel_op and expr.rel_op != '!='

            try:
                discontinuities = set(domain.boundary -
                                      FiniteSet(domain.inf, domain.sup))
                # remove points that are not between inf and sup of domain
                critical_points = FiniteSet(
                    *(solns + singularities +
                      list(discontinuities))).intersection(
                          Interval(domain.inf, domain.sup, domain.inf
                                   not in domain, domain.sup not in domain))
                if all(r.is_number for r in critical_points):
                    reals = _nsort(critical_points, separated=True)[0]
                else:
                    sifted = sift(critical_points,
                                  lambda x: x.is_extended_real)
                    if sifted[None]:
                        # there were some roots that weren't known
                        # to be real
                        raise NotImplementedError
                    try:
                        reals = sifted[True]
                        if len(reals) > 1:
                            reals = list(sorted(reals))
                    except TypeError:
                        raise NotImplementedError
            except NotImplementedError:
                raise NotImplementedError(
                    'sorting of these roots is not supported')

            # If expr contains imaginary coefficients, only take real
            # values of x for which the imaginary part is 0
            make_real = S.Reals
            if im(expanded_e) != S.Zero:
                check = True
                im_sol = FiniteSet()
                try:
                    a = solveset(im(expanded_e), gen, domain)
                    if not isinstance(a, Interval):
                        for z in a:
                            if z not in singularities and valid(
                                    z) and z.is_extended_real:
                                im_sol += FiniteSet(z)
                    else:
                        start, end = a.inf, a.sup
                        for z in _nsort(critical_points + FiniteSet(end)):
                            valid_start = valid(start)
                            if start != end:
                                valid_z = valid(z)
                                pt = _pt(start, z)
                                if pt not in singularities and pt.is_extended_real and valid(
                                        pt):
                                    if valid_start and valid_z:
                                        im_sol += Interval(start, z)
                                    elif valid_start:
                                        im_sol += Interval.Ropen(start, z)
                                    elif valid_z:
                                        im_sol += Interval.Lopen(start, z)
                                    else:
                                        im_sol += Interval.open(start, z)
                            start = z
                        for s in singularities:
                            im_sol -= FiniteSet(s)
                except (TypeError):
                    im_sol = S.Reals
                    check = False

                if im_sol is S.EmptySet:
                    raise ValueError(
                        filldedent('''
                        %s contains imaginary parts which cannot be
                        made 0 for any value of %s satisfying the
                        inequality, leading to relations like I < 0.
                        ''' % (expr.subs(gen, _gen), _gen)))

                make_real = make_real.intersect(im_sol)

            # Sweep the sorted critical points left to right, testing the
            # open gap before each point and the point itself.
            sol_sets = [S.EmptySet]

            start = domain.inf
            if start in domain and valid(start) and start.is_finite:
                sol_sets.append(FiniteSet(start))

            for x in reals:
                end = x

                if valid(_pt(start, end)):
                    sol_sets.append(Interval(start, end, True, True))

                if x in singularities:
                    singularities.remove(x)
                else:
                    if x in discontinuities:
                        discontinuities.remove(x)
                        _valid = valid(x)
                    else:  # it's a solution
                        _valid = include_x
                    if _valid:
                        sol_sets.append(FiniteSet(x))

                start = end

            end = domain.sup
            if end in domain and valid(end) and end.is_finite:
                sol_sets.append(FiniteSet(end))

            if valid(_pt(start, end)):
                sol_sets.append(Interval.open(start, end))

            if im(expanded_e) != S.Zero and check:
                rv = (make_real).intersect(_domain)
            else:
                rv = Intersection((Union(*sol_sets)), make_real,
                                  _domain).subs(gen, _gen)

    return rv if not relational else rv.as_relational(_gen)