 def is_right_unbounded(self):
     """Tell whether the right endpoint is positive infinity.

     The endpoint counts as unbounded when it is the ``S.Infinity`` singleton
     or when it compares equal to ``Float("+inf")``.
     """
     upper = self.right
     if upper is S.Infinity:
         return True
     return upper == Float("+inf")
def test_core_numbers():
    # Iterates over one sample value each of Integer, Rational and Float.
    # NOTE(review): the body of this loop is not present in the visible
    # source, so what is done with each value cannot be determined from here;
    # as written the ``for`` statement has no suite and will not compile.
    for c in (Integer(2), Rational(2, 3), Float("1.2")):
def sympify(a, locals=None, convert_xor=True, strict=False, rational=False,
        evaluate=None):
    """Converts an arbitrary expression to a type that can be used inside SymPy.

    For example, it will convert Python ints into instances of sympy.Integer,
    floats into instances of sympy.Float, etc. It is also able to coerce symbolic
    expressions which inherit from Basic. This can be useful in cooperation
    with SAGE.

    It currently accepts as arguments:
       - any object defined in sympy
       - standard numeric python types: int, long, float, Decimal
       - strings (like "0.09" or "2e-19")
       - booleans, including ``None`` (will leave ``None`` unchanged)
       - lists, sets or tuples containing any of the above

    .. warning::
        Note that this function uses ``eval``, and thus shouldn't be used on
        unsanitized input.

    If the argument is already a type that SymPy understands, it will do
    nothing but return that value. This can be used at the beginning of a
    function to ensure you are working with the correct type.

    >>> from sympy import sympify

    >>> sympify(2).is_integer
    True
    >>> sympify(2).is_real
    True

    >>> sympify(2.0).is_real
    True
    >>> sympify("2.0").is_real
    True
    >>> sympify("2e-45").is_real
    True

    If the expression could not be converted, a SympifyError is raised.

    >>> sympify("x***2")
    Traceback (most recent call last):
    ...
    SympifyError: SympifyError: "could not parse u'x***2'"

    Locals
    ------

    The sympification happens with access to everything that is loaded
    by ``from sympy import *``; anything used in a string that is not
    defined by that import will be converted to a symbol. In the following,
    the ``bitcount`` function is treated as a symbol and the ``O`` is
    interpreted as the Order object (used with series) and it raises
    an error when used improperly:

    >>> s = 'bitcount(42)'
    >>> sympify(s)
    bitcount(42)
    >>> sympify("O(x)")
    O(x)
    >>> sympify("O + 1")
    Traceback (most recent call last):
    ...
    TypeError: unbound method...

    In order to have ``bitcount`` be recognized it can be imported into a
    namespace dictionary and passed as locals:

    >>> from sympy.core.compatibility import exec_
    >>> ns = {}
    >>> exec_('from sympy.core.evalf import bitcount', ns)
    >>> sympify(s, locals=ns)
    6

    In order to have the ``O`` interpreted as a Symbol, identify it as such
    in the namespace dictionary. This can be done in a variety of ways; all
    three of the following are possibilities:

    >>> from sympy import Symbol
    >>> ns["O"] = Symbol("O")  # method 1
    >>> exec_('from sympy.abc import O', ns)  # method 2
    >>> ns.update(dict(O=Symbol("O")))  # method 3
    >>> sympify("O + 1", locals=ns)
    O + 1

    If you want *all* single-letter and Greek-letter variables to be symbols
    then you can use the clashing-symbols dictionaries that have been defined
    there as private variables: _clash1 (single-letter variables), _clash2
    (the multi-letter Greek names) or _clash (both single and multi-letter
    names that are defined in abc).

    >>> from sympy.abc import _clash1
    >>> _clash1
    {'C': C, 'E': E, 'I': I, 'N': N, 'O': O, 'Q': Q, 'S': S}
    >>> sympify('I & Q', _clash1)
    I & Q

    Strict
    ------

    If the option ``strict`` is set to ``True``, only the types for which an
    explicit conversion has been defined are converted. In the other
    cases, a SympifyError is raised.

    >>> print(sympify(None))
    None
    >>> sympify(None, strict=True)
    Traceback (most recent call last):
    ...
    SympifyError: SympifyError: None

    Evaluation
    ----------

    If the option ``evaluate`` is set to ``False``, then arithmetic and
    operators will be converted into their SymPy equivalents and the
    ``evaluate=False`` option will be added. Nested ``Add`` or ``Mul`` will
    be denested first. This is done via an AST transformation that replaces
    operators with their SymPy equivalents, so if an operand redefines any
    of those operations, the redefined operators will not be used.

    >>> sympify('2**2 / 3 + 5')
    19/3
    >>> sympify('2**2 / 3 + 5', evaluate=False)
    2**2/3 + 5

    Extending
    ---------

    To extend ``sympify`` to convert custom objects (not derived from ``Basic``),
    just define a ``_sympy_`` method to your class. You can do that even to
    classes that you do not own by subclassing or adding the method at runtime.

    >>> from sympy import Matrix
    >>> class MyList1(object):
    ...     def __iter__(self):
    ...         yield 1
    ...         yield 2
    ...         return
    ...     def __getitem__(self, i): return list(self)[i]
    ...     def _sympy_(self): return Matrix(self)
    >>> sympify(MyList1())
    Matrix([
    [1],
    [2]])

    If you do not have control over the class definition you could also use the
    ``converter`` global dictionary. The key is the class and the value is a
    function that takes a single argument and returns the desired SymPy
    object, e.g. ``converter[MyList] = lambda x: Matrix(x)``.

    >>> class MyList2(object):   # XXX Do not do this if you control the class!
    ...     def __iter__(self):  #     Use _sympy_!
    ...         yield 1
    ...         yield 2
    ...         return
    ...     def __getitem__(self, i): return list(self)[i]
    >>> from sympy.core.sympify import converter
    >>> converter[MyList2] = lambda x: Matrix(x)
    >>> sympify(MyList2())
    Matrix([
    [1],
    [2]])

    Notes
    =====

    Sometimes autosimplification during sympification results in expressions
    that are very different in structure than what was entered. Until such
    autosimplification is no longer done, the ``kernS`` function might be of
    some use. In the example below you can see how an expression reduces to
    -1 by autosimplification, but does not do so when ``kernS`` is used.

    >>> from sympy.core.sympify import kernS
    >>> from sympy.abc import x
    >>> -2*(-(-x + 1/x)/(x*(x - 1/x)**2) - 1/(x*(x - 1/x))) - 1
    -1
    >>> s = '-2*(-(-x + 1/x)/(x*(x - 1/x)**2) - 1/(x*(x - 1/x))) - 1'
    >>> sympify(s)
    -1
    >>> kernS(s)
    -2*(-(-x + 1/x)/(x*(x - 1/x)**2) - 1/(x*(x - 1/x))) - 1

    """
    # ``evaluate=None`` means "follow the global evaluate flag".
    if evaluate is None:
        if global_evaluate[0] is False:
            evaluate = global_evaluate[0]
        else:
            evaluate = True

    # Fast path: ``a`` itself is one of the registered SymPy classes.
    try:
        if a in sympy_classes:
            return a
    except TypeError: # Type of a is unhashable
        pass
    try:
        cls = a.__class__
    except AttributeError:  # a is probably an old-style class object
        cls = type(a)
    if cls in sympy_classes:
        return a
    if cls is type(None):
        if strict:
            raise SympifyError(a)
        else:
            return a

    # Support for basic numpy datatypes
    if type(a).__module__ == 'numpy':
        import numpy as np
        if np.isscalar(a):
            if not isinstance(a, np.floating):
                # ``np.asscalar`` is no longer provided by newer NumPy
                # releases; ``item()`` is the equivalent scalar extraction.
                return sympify(a.item())
            else:
                try:
                    from sympy.core.numbers import Float
                    prec = np.finfo(a).nmant
                    a = str(list(np.reshape(np.asarray(a),
                                            (1, np.size(a)))[0]))[1:-1]
                    return Float(a, precision=prec)
                except NotImplementedError:
                    raise SympifyError('Translation for numpy float : %s '
                                       'is not implemented' % a)

    # Registered converters: exact class first, then each class in the MRO.
    try:
        return converter[cls](a)
    except KeyError:
        for superclass in getmro(cls):
            try:
                return converter[superclass](a)
            except KeyError:
                continue

    if isinstance(a, CantSympify):
        raise SympifyError(a)

    # Objects may provide their own conversion hook.
    try:
        return a._sympy_()
    except AttributeError:
        pass

    # Non-string objects: try coercing through float, then int.
    if not isinstance(a, string_types):
        for coerce in (float, int):
            try:
                return sympify(coerce(a))
            except (TypeError, ValueError, AttributeError, SympifyError):
                continue

    # Everything below is a best-effort conversion, not allowed when strict.
    if strict:
        raise SympifyError(a)

    try:
        from ..tensor.array import Array
        return Array(a.flat, a.shape)  # works with e.g. NumPy arrays
    except AttributeError:
        pass

    if iterable(a):
        try:
            return type(a)([sympify(x, locals=locals, convert_xor=convert_xor,
                rational=rational) for x in a])
        except TypeError:
            # Not all iterables are rebuildable with their type.
            pass
    if isinstance(a, dict):
        try:
            return type(a)([sympify(x, locals=locals, convert_xor=convert_xor,
                rational=rational) for x in a.items()])
        except TypeError:
            # Not all iterables are rebuildable with their type.
            pass

    # At this point we were given an arbitrary expression
    # which does not inherit from Basic and doesn't implement
    # _sympy_ (which is a canonical and robust way to convert
    # anything to SymPy expression).
    # As a last chance, we try to take "a"'s normal form via unicode()
    # and try to parse it. If it fails, then we have no luck and
    # return an exception
    try:
        from .compatibility import unicode
        a = unicode(a)
    except Exception as exc:
        raise SympifyError(a, exc)

    from sympy.parsing.sympy_parser import (parse_expr, TokenError,
                                            standard_transformations)
    from sympy.parsing.sympy_parser import convert_xor as t_convert_xor
    from sympy.parsing.sympy_parser import rationalize as t_rationalize

    transformations = standard_transformations

    if rational:
        transformations += (t_rationalize,)
    if convert_xor:
        transformations += (t_convert_xor,)

    try:
        a = a.replace('\n', '')
        expr = parse_expr(a, local_dict=locals, transformations=transformations, evaluate=evaluate)
    except (TokenError, SyntaxError) as exc:
        raise SympifyError('could not parse %r' % a, exc)

    return expr
 def is_left_unbounded(self):
     """Tell whether the left endpoint is negative infinity.

     The endpoint counts as unbounded when it is the ``S.NegativeInfinity``
     singleton or when it compares equal to ``Float("-inf")``.
     """
     lower = self.left
     if lower is S.NegativeInfinity:
         return True
     return lower == Float("-inf")
def _real_to_rational(expr, tolerance=None, rational_conversion='base10'):
    """
    Replace all reals in expr with rationals.

    ``tolerance`` selects how each Float is rounded: a value below 1 limits
    the denominator to ``ceiling(1/tolerance)``; a value of 1 or more rounds
    to a multiple of the tolerance; ``None`` defers to ``rational_conversion``.
    A ``ValueError`` is raised when ``rational_conversion`` is neither
    ``'base10'`` nor ``'exact'``.

    Examples
    ========

    >>> from sympy import Rational
    >>> from sympy.simplify.simplify import _real_to_rational
    >>> from sympy.abc import x

    >>> _real_to_rational(.76 + .1*x**.5)
    sqrt(x)/10 + 19/25

    If rational_conversion='base10', this uses the base-10 string. If
    rational_conversion='exact', the exact, base-2 representation is used.

    >>> _real_to_rational(0.333333333333333, rational_conversion='exact')
    6004799503160655/18014398509481984
    >>> _real_to_rational(0.333333333333333)
    1/3

    """
    expr = _sympify(expr)
    inf = Float('inf')
    p = expr
    reps = {}
    reduce_num = None
    if tolerance is not None and tolerance < 1:
        reduce_num = ceiling(1/tolerance)
    for fl in p.atoms(Float):
        # ``fl`` may be negated below, so remember the original atom as key.
        key = fl
        if reduce_num is not None:
            r = Rational(fl).limit_denominator(reduce_num)
        elif (tolerance is not None and tolerance >= 1 and
                fl.is_Integer is False):
            r = Rational(tolerance*round(fl/tolerance)
                ).limit_denominator(int(tolerance))
        else:
            if rational_conversion == 'exact':
                r = Rational(fl)
                reps[key] = r
                # Exact conversion is final; skip the base-10 handling below.
                continue
            elif rational_conversion != 'base10':
                raise ValueError("rational_conversion must be 'base10' or 'exact'")

            r = nsimplify(fl, rational=False)
            # e.g. log(3).n() -> log(3) instead of a Rational
            if fl and not r:
                r = Rational(fl)
            elif not r.is_Rational:
                if fl == inf or fl == -inf:
                    r = S.ComplexInfinity
                elif fl < 0:
                    fl = -fl
                    d = Pow(10, int((mpmath.log(fl)/mpmath.log(10))))
                    r = -Rational(str(fl/d))*d
                elif fl > 0:
                    d = Pow(10, int((mpmath.log(fl)/mpmath.log(10))))
                    r = Rational(str(fl/d))*d
                else:
                    r = Integer(0)
        reps[key] = r
    return p.subs(reps, simultaneous=True)
 def to_sympy(self, element):
     """Build a SymPy number from a complex ``element``.

     The real and imaginary parts are each wrapped in a ``Float`` created with
     ``self.dps`` as second argument and combined as ``re + I*im``.
     """
     digits = self.dps
     real_part = Float(element.real, digits)
     imag_part = Float(element.imag, digits)
     return real_part + I*imag_part
 def to_sympy(self, element):
     """Build a SymPy ``Float`` from ``element``.

     ``self.dps`` is passed as the second argument of ``Float``.
     """
     digits = self.dps
     return Float(element, digits)
def nsimplify(expr, constants=(), tolerance=None, full=False, rational=None,
    rational_conversion='base10'):
    """
    Find a simple representation for a number or, if there are free symbols or
    if rational=True, then replace Floats with their Rational equivalents. If
    no change is made and rational is not False then Floats will at least be
    converted to Rationals.

    For numerical expressions, a simple formula that numerically matches the
    given numerical expression is sought (and the input should be possible
    to evalf to a precision of at least 30 digits).

    Optionally, a list of (rationally independent) constants to
    include in the formula may be given.

    A lower tolerance may be set to find less exact matches. If no tolerance
    is given then the least precise value will set the tolerance (e.g. Floats
    default to 15 digits of precision, so would be tolerance=10**-15).

    With full=True, a more extensive search is performed
    (this is useful to find simpler numbers when the tolerance
    is set low).

    When converting to rational, if rational_conversion='base10' (the default), then
    convert floats to rationals using their base-10 (string) representation.
    When rational_conversion='exact' it uses the exact, base-2 representation.

    A ``ValueError`` is raised if any of ``constants`` does not evaluate to a
    Float.

    Examples
    ========

    >>> from sympy import nsimplify, sqrt, GoldenRatio, exp, I, exp, pi
    >>> nsimplify(4/(1+sqrt(5)), [GoldenRatio])
    -2 + 2*GoldenRatio
    >>> nsimplify((1/(exp(3*pi*I/5)+1)))
    1/2 - I*sqrt(sqrt(5)/10 + 1/4)
    >>> nsimplify(I**I, [pi])
    exp(-pi/2)
    >>> nsimplify(pi, tolerance=0.01)
    22/7

    >>> nsimplify(0.333333333333333, rational=True, rational_conversion='exact')
    6004799503160655/18014398509481984
    >>> nsimplify(0.333333333333333, rational=True)
    1/3

    See Also
    ========
    sympy.core.function.nfloat

    """
    # Integer-like input needs no further work.
    try:
        return sympify(as_int(expr))
    except (TypeError, ValueError):
        pass
    # Map Float infinities onto the symbolic singletons.
    expr = sympify(expr).xreplace({
        Float('inf'): S.Infinity,
        Float('-inf'): S.NegativeInfinity,
        })
    if expr is S.Infinity or expr is S.NegativeInfinity:
        return expr
    if rational or expr.free_symbols:
        return _real_to_rational(expr, tolerance, rational_conversion)

    # SymPy's default tolerance for Rationals is 15; other numbers may have
    # lower tolerances set, so use them to pick the largest tolerance if None
    # was given
    if tolerance is None:
        tolerance = 10**-min([15] +
             [mpmath.libmp.libmpf.prec_to_dps(n._prec)
             for n in expr.atoms(Float)])
    # XXX should prec be set independent of tolerance or should it be computed
    # from tolerance?
    prec = 30
    bprec = int(prec*3.33)

    constants_dict = {}
    for constant in constants:
        constant = sympify(constant)
        v = constant.evalf(prec)
        if not v.is_Float:
            raise ValueError("constants must be real-valued")
        constants_dict[str(constant)] = v._to_mpmath(bprec)

    exprval = expr.evalf(prec, chop=True)
    re, im = exprval.as_real_imag()

    # safety check to make sure that this evaluated to a number
    if not (re.is_Number and im.is_Number):
        return expr

    def nsimplify_real(x):
        # Identify one real component; raises ValueError when no acceptable
        # closed form is found.
        orig = mpmath.mp.dps
        xv = x._to_mpmath(bprec)
        try:
            # We'll be happy with low precision if a simple fraction
            if not (tolerance or full):
                mpmath.mp.dps = 15
                rat = mpmath.pslq([xv, 1])
                if rat is not None:
                    return Rational(-int(rat[1]), int(rat[0]))
            mpmath.mp.dps = prec
            newexpr = mpmath.identify(xv, constants=constants_dict,
                tol=tolerance, full=full)
            if not newexpr:
                raise ValueError
            if full:
                newexpr = newexpr[0]
            expr = sympify(newexpr)
            if x and not expr:  # don't let x become 0
                raise ValueError
            if expr.is_finite is False and not xv in [mpmath.inf, mpmath.ninf]:
                raise ValueError
            return expr
        finally:
            # even though there are returns above, this is executed
            # before leaving
            mpmath.mp.dps = orig
    try:
        if re:
            re = nsimplify_real(re)
        if im:
            im = nsimplify_real(im)
    except ValueError:
        if rational is None:
            return _real_to_rational(expr, rational_conversion=rational_conversion)
        return expr

    rv = re + im*S.ImaginaryUnit
    # if there was a change or rational is explicitly not wanted
    # return the value, else return the Rational representation
    if rv != expr or rational is False:
        return rv
    return _real_to_rational(expr, rational_conversion=rational_conversion)
def hypsum(expr, n, start, prec):
    """
    Sum a rapidly convergent infinite hypergeometric series with
    given general term, e.g. e = hypsum(1/factorial(n), n). The
    quotient between successive terms must be a quotient of integer
    polynomials.

    ``start`` shifts the summation index (``n`` is replaced by ``n + start``)
    and ``prec`` is the working precision in bits.

    Raises ``NotImplementedError`` for infinite ``prec``, when ``hypersimp``
    cannot simplify the term ratio, or when the initial term is not Rational;
    raises ``ValueError`` when the convergence check indicates divergence.
    """
    from sympy import Float, hypersimp, lambdify

    if prec == float('inf'):
        raise NotImplementedError('does not support inf prec')

    if start:
        expr = expr.subs(n, n + start)
    hs = hypersimp(expr, n)
    if hs is None:
        raise NotImplementedError("a hypergeometric series is required")
    num, den = hs.as_numer_denom()

    func1 = lambdify(n, num)
    func2 = lambdify(n, den)

    h, g, p = check_convergence(num, den, n)

    if h < 0:
        raise ValueError("Sum diverges like (n!)^%i" % (-h))

    term = expr.subs(n, 0)
    if not term.is_Rational:
        raise NotImplementedError("Non rational term functionality is not implemented.")

    # Direct summation if geometric or faster
    if h > 0 or (h == 0 and abs(g) > 1):
        # Fixed-point arithmetic: terms are integers scaled by 2**prec.
        term = (MPZ(term.p) << prec) // term.q
        s = term
        k = 1
        while abs(term) > 5:
            term *= MPZ(func1(k - 1))
            term //= MPZ(func2(k - 1))
            s += term
            k += 1
        return from_man_exp(s, -prec)
    else:
        alt = g < 0
        if abs(g) < 1:
            raise ValueError("Sum diverges like (%i)^n" % abs(1/g))
        if p < 1 or (p == 1 and not alt):
            raise ValueError("Sum diverges like n^%i" % (-p))
        # We have polynomial convergence: use Richardson extrapolation
        vold = None
        ndig = prec_to_dps(prec)
        while True:
            # Need to use at least quad precision because a lot of cancellation
            # might occur in the extrapolation process; we check the answer to
            # make sure that the desired precision has been reached, too.
            prec2 = 4*prec
            term0 = (MPZ(term.p) << prec2) // term.q

            # The mutable default deliberately carries the running term
            # between successive calls made by ``nsum``.
            def summand(k, _term=[term0]):
                if k:
                    k = int(k)
                    _term[0] *= MPZ(func1(k - 1))
                    _term[0] //= MPZ(func2(k - 1))
                return make_mpf(from_man_exp(_term[0], -prec2))

            with workprec(prec):
                v = nsum(summand, [0, mpmath_inf], method='richardson')
            vf = Float(v, ndig)
            # Stop once two successive precisions give the same value.
            if vold is not None and vold == vf:
                break
            prec += prec  # double precision each time
            vold = vf

        return v._mpf_
def feq(a, b):
    """Report whether ``a`` and ``b`` are 'equal' as floating point values.

    The two values are treated as equal when their difference lies strictly
    between ``-1.0E-10`` and ``1.0E-10``.
    """
    tolerance = Float("1.0E-10")
    difference = a - b
    return -tolerance < difference < tolerance
def test_BernoulliProcess():
    """Exercise construction, probability and expectation queries, joint
    distributions, summation and simplification for ``BernoulliProcess``."""

    B = BernoulliProcess("B", p=0.6, success=1, failure=0)
    assert B.state_space == FiniteSet(0, 1)
    assert B.index_set == S.Naturals0
    assert B.success == 1
    assert B.failure == 0

    X = BernoulliProcess("X", p=Rational(1, 3), success='H', failure='T')
    assert X.state_space == FiniteSet('H', 'T')
    H, T = symbols("H,T")
    assert E(X[1] + X[2] * X[3]
             ) == H**2 / 9 + 4 * H * T / 9 + H / 3 + 4 * T**2 / 9 + 2 * T / 3

    t, x = symbols('t, x', positive=True, integer=True)
    assert isinstance(B[t], RandomIndexedSymbol)

    # A probability outside [0, 1] must be rejected at construction time.
    raises(ValueError,
           lambda: BernoulliProcess("X", p=1.1, success=1, failure=0))
    raises(NotImplementedError, lambda: B(t))

    raises(IndexError, lambda: B[-3])
    assert B.joint_distribution(B[3], B[9]) == JointDistributionHandmade(
        Lambda(
            (B[3], B[9]),
            Piecewise((0.6, Eq(B[3], 1)), (0.4, Eq(B[3], 0)),
                      (0, True)) * Piecewise((0.6, Eq(B[9], 1)),
                                             (0.4, Eq(B[9], 0)), (0, True))))

    assert B.joint_distribution(2, B[4]) == JointDistributionHandmade(
        Lambda(
            (B[2], B[4]),
            Piecewise((0.6, Eq(B[2], 1)), (0.4, Eq(B[2], 0)),
                      (0, True)) * Piecewise((0.6, Eq(B[4], 1)),
                                             (0.4, Eq(B[4], 0)), (0, True))))

    # Test for the sum distribution of Bernoulli Process RVs
    Y = B[1] + B[2] + B[3]
    assert P(Eq(Y, 0)).round(2) == Float(0.06, 1)
    assert P(Eq(Y, 2)).round(2) == Float(0.43, 2)
    assert P(Eq(Y, 4)).round(2) == 0
    assert P(Gt(Y, 1)).round(2) == Float(0.65, 2)
    # Test for independency of each Random Indexed variable
    assert P(Eq(B[1], 0) & Eq(B[2], 1) & Eq(B[3], 0)
             & Eq(B[4], 1)).round(2) == Float(0.06, 1)

    assert E(2 * B[1] + B[2]).round(2) == Float(1.80, 3)
    assert E(2 * B[1] + B[2] + 5).round(2) == Float(6.80, 3)
    assert E(B[2] * B[4] + B[10]).round(2) == Float(0.96, 2)
    assert E(B[2] > 0, Eq(B[1], 1) & Eq(B[2], 1)).round(2) == Float(0.60, 2)
    assert E(B[1]) == 0.6
    assert P(B[1] > 0).round(2) == Float(0.60, 2)
    assert P(B[1] < 1).round(2) == Float(0.40, 2)
    assert P(B[1] > 0, B[2] <= 1).round(2) == Float(0.60, 2)
    assert P(B[12] * B[5] > 0).round(2) == Float(0.36, 2)
    assert P(B[12] * B[5] > 0, B[4] < 1).round(2) == Float(0.36, 2)
    assert P(Eq(B[2], 1), B[2] > 0) == 1
    assert P(Eq(B[5], 3)) == 0
    assert P(Eq(B[1], 1), B[1] < 0) == 0
    assert P(B[2] > 0, Eq(B[2], 1)) == 1
    assert P(B[2] < 0, Eq(B[2], 1)) == 0
    assert P(B[2] > 0, B[2] == 7) == 0
    assert P(B[5] > 0, B[5]) == BernoulliDistribution(0.6, 0, 1)
    raises(ValueError, lambda: P(3))
    raises(ValueError, lambda: P(B[3] > 0, 3))

    # test issue 19456
    expr = Sum(B[t], (t, 0, 4))
    expr2 = Sum(B[t], (t, 1, 3))
    expr3 = Sum(B[t]**2, (t, 1, 3))
    assert expr.doit() == B[0] + B[1] + B[2] + B[3] + B[4]
    assert expr2.doit() == Y
    assert expr3.doit() == B[1]**2 + B[2]**2 + B[3]**2
    assert B[2 * t].free_symbols == {B[2 * t], t}
    assert B[4].free_symbols == {B[4]}
    assert B[x * t].free_symbols == {B[x * t], x, t}

    #test issue 20078
    assert (2 * B[t] + 3 * B[t]).simplify() == 5 * B[t]
    assert (2 * B[t] - 3 * B[t]).simplify() == -B[t]
    assert (2 * (0.25 * B[t])).simplify() == 0.5 * B[t]
    assert (2 * B[t] * 0.25 * B[t]).simplify() == 0.5 * B[t]**2
    assert (B[t]**2 + B[t]**3).simplify() == (B[t] + 1) * B[t]**2
def test_ContinuousMarkovChain():
    """Exercise limiting distribution, transition probabilities and
    probability/expectation queries for ``ContinuousMarkovChain``."""
    # Generator matrix: every row sums to zero.
    T1 = Matrix([[S(-2), S(2), S.Zero], [S.Zero, S.NegativeOne, S.One],
                 [Rational(3, 2), Rational(3, 2),
                  S(-3)]])
    C1 = ContinuousMarkovChain('C', [0, 1, 2], T1)
    assert C1.limiting_distribution() == ImmutableMatrix(
        [[Rational(3, 19), Rational(12, 19),
          Rational(4, 19)]])

    T2 = Matrix([[-S.One, S.One, S.Zero], [S.One, -S.One, S.Zero],
                 [S.Zero, S.One, -S.One]])
    C2 = ContinuousMarkovChain('C', [0, 1, 2], T2)
    A, t = C2.generator_matrix, symbols('t', positive=True)
    # Transition matrix: every row sums to one.
    assert C2.transition_probabilities(A)(t) == Matrix(
        [[S.Half + exp(-2 * t) / 2, S.Half - exp(-2 * t) / 2, 0],
         [S.Half - exp(-2 * t) / 2, S.Half + exp(-2 * t) / 2, 0],
         [
             S.Half - exp(-t) + exp(-2 * t) / 2, S.Half - exp(-2 * t) / 2,
             exp(-t)
         ]])
    with ignore_warnings(
            UserWarning):  ### TODO: Restore tests once warnings are removed
        assert P(Eq(C2(1), 1), Eq(C2(0), 1),
                 evaluate=False) == Probability(Eq(C2(1), 1), Eq(C2(0), 1))
    assert P(Eq(C2(1), 1), Eq(C2(0), 1)) == exp(-2) / 2 + S.Half
    assert P(
        Eq(C2(1), 0) & Eq(C2(2), 1) & Eq(C2(3), 1),
        Eq(P(Eq(C2(1), 0)),
           S.Half)) == (Rational(1, 4) - exp(-2) / 4) * (exp(-2) / 2 + S.Half)
    assert P(
        Not(Eq(C2(1), 0) & Eq(C2(2), 1) & Eq(C2(3), 2)) |
        (Eq(C2(1), 0) & Eq(C2(2), 1) & Eq(C2(3), 2)),
        Eq(P(Eq(C2(1), 0)), Rational(1, 4))
        & Eq(P(Eq(C2(1), 1)), Rational(1, 4))) is S.One
    assert E(C2(Rational(3, 2)),
             Eq(C2(0), 2)) == -exp(-3) / 2 + 2 * exp(Rational(-3, 2)) + S.Half
    assert variance(C2(Rational(3, 2)), Eq(
        C2(0),
        1)) == ((S.Half - exp(-3) / 2)**2 * (exp(-3) / 2 + S.Half) +
                (Rational(-1, 2) - exp(-3) / 2)**2 * (S.Half - exp(-3) / 2))
    raises(KeyError, lambda: P(Eq(C2(1), 0), Eq(P(Eq(C2(1), 1)), S.Half)))
    assert P(Eq(C2(1), 0), Eq(P(Eq(C2(5), 1)),
                              S.Half)) == Probability(Eq(C2(1), 0))
    TS1 = MatrixSymbol('G', 3, 3)
    CS1 = ContinuousMarkovChain('C', [0, 1, 2], TS1)
    A = CS1.generator_matrix
    assert CS1.transition_probabilities(A)(t) == exp(t * A)

    C3 = ContinuousMarkovChain(
        'C', [Symbol('0'), Symbol('1'), Symbol('2')], T2)
    assert P(Eq(C3(1), 1), Eq(C3(0), 1)) == exp(-2) / 2 + S.Half
    assert P(Eq(C3(1), Symbol('1')), Eq(C3(0),
                                        Symbol('1'))) == exp(-2) / 2 + S.Half

    #test probability queries
    G = Matrix([[-S(1), Rational(1, 10),
                 Rational(9, 10)], [Rational(2, 5), -S(1),
                                    Rational(3, 5)],
                [Rational(1, 2), Rational(1, 2), -S(1)]])
    C = ContinuousMarkovChain('C', state_space=[0, 1, 2], gen_mat=G)
    assert P(Eq(C(7.385), C(3.19)), Eq(C(0.862),
                                       0)).round(5) == Float(0.35469, 5)
    assert P(Gt(C(98.715), C(19.807)), Eq(C(11.314),
                                          2)).round(5) == Float(0.32452, 5)
    assert P(Le(C(5.9), C(10.112)), Eq(C(4), 1)).round(6) == Float(0.675214, 6)
    assert Float(P(Eq(C(7.32), C(2.91)), Eq(C(2.63), 1)),
                 14) == Float(1 - P(Ne(C(7.32), C(2.91)), Eq(C(2.63), 1)), 14)
    assert Float(P(Gt(C(3.36), C(1.101)), Eq(C(0.8), 2)),
                 14) == Float(1 - P(Le(C(3.36), C(1.101)), Eq(C(0.8), 2)), 14)
    assert Float(P(Lt(C(4.9), C(2.79)), Eq(C(1.61), 0)),
                 14) == Float(1 - P(Ge(C(4.9), C(2.79)), Eq(C(1.61), 0)), 14)
    assert P(Eq(C(5.243), C(10.912)), Eq(C(2.174),
                                         1)) == P(Eq(C(10.912), C(5.243)),
                                                  Eq(C(2.174), 1))
    assert P(Gt(C(2.344), C(9.9)), Eq(C(1.102),
                                      1)) == P(Lt(C(9.9), C(2.344)),
                                               Eq(C(1.102), 1))
    assert P(Ge(C(7.87), C(1.008)), Eq(C(0.153),
                                       1)) == P(Le(C(1.008), C(7.87)),
                                                Eq(C(0.153), 1))

    #test symbolic queries
    a, b, c, d = symbols('a b c d')
    query = P(Eq(C(a), b), Eq(C(c), d))
    assert query.subs({
        a: 3.65,
        b: 2,
        c: 1.78,
        d: 1
    }).evalf().round(10) == P(Eq(C(3.65), 2), Eq(C(1.78), 1)).round(10)
    query_gt = P(Gt(C(a), b), Eq(C(c), d))
    query_le = P(Le(C(a), b), Eq(C(c), d))
    assert query_gt.subs({
        a: 13.2,
        b: 0,
        c: 3.29,
        d: 2
    }).evalf() + query_le.subs({
        a: 13.2,
        b: 0,
        c: 3.29,
        d: 2
    }).evalf() == 1
    query_ge = P(Ge(C(a), b), Eq(C(c), d))
    query_lt = P(Lt(C(a), b), Eq(C(c), d))
    assert query_ge.subs({
        a: 7.43,
        b: 1,
        c: 1.45,
        d: 0
    }).evalf() + query_lt.subs({
        a: 7.43,
        b: 1,
        c: 1.45,
        d: 0
    }).evalf() == 1

    #test issue 20078
    assert (2 * C(1) + 3 * C(1)).simplify() == 5 * C(1)
    assert (2 * C(1) - 3 * C(1)).simplify() == -C(1)
    assert (2 * (0.25 * C(1))).simplify() == 0.5 * C(1)
    assert (2 * C(1) * 0.25 * C(1)).simplify() == 0.5 * C(1)**2
    assert (C(1)**2 + C(1)**3).simplify() == (C(1) + 1) * C(1)**2
def test_DiscreteMarkovChain():
    """Exercise ``DiscreteMarkovChain`` end to end.

    Covers construction (name only, name + state space, name + transition
    matrix, all three), conditional probability / expectation queries,
    structural properties (absorbing, regular, ergodic, canonical form,
    communication classes), custom state spaces, queries relating several
    time steps, symbolic queries and the regression for issue 20078.

    Several statements in this test had lost lines and no longer parsed;
    they are restored below. Restorations that could not be derived from a
    parallel statement in this same test carry a ``NOTE(review)`` comment.
    """

    # pass only the name
    X = DiscreteMarkovChain("X")
    assert isinstance(X.state_space, Range)
    assert X.index_set == S.Naturals0
    assert isinstance(X.transition_probabilities, MatrixSymbol)
    t = symbols('t', positive=True, integer=True)
    assert isinstance(X[t], RandomIndexedSymbol)
    assert E(X[0]) == Expectation(X[0])
    raises(TypeError, lambda: DiscreteMarkovChain(1))
    raises(NotImplementedError, lambda: X(t))
    raises(NotImplementedError, lambda: X.communication_classes())
    raises(NotImplementedError, lambda: X.canonical_form())
    raises(NotImplementedError, lambda: X.decompose())

    nz = Symbol('n', integer=True)
    TZ = MatrixSymbol('M', nz, nz)
    SZ = Range(nz)
    YZ = DiscreteMarkovChain('Y', SZ, TZ)
    assert P(Eq(YZ[2], 1), Eq(YZ[1], 0)) == TZ[0, 1]

    raises(ValueError, lambda: sample_stochastic_process(t))
    raises(ValueError, lambda: next(sample_stochastic_process(X)))
    # pass name and state_space
    # any hashable object should be a valid state
    # states should be valid as a tuple/set/list/Tuple/Range
    sym, rainy, cloudy, sunny = symbols('a Rainy Cloudy Sunny', real=True)
    state_spaces = [(1, 2, 3),
                    [Str('Hello'), sym,
                     DiscreteMarkovChain("Y", (1, 2, 3))],
                    Tuple(S(1), exp(sym), Str('World'), sympify=False),
                    Range(-1, 5, 2), [rainy, cloudy, sunny]]
    # Restored: the closing bracket of this comprehension was missing.
    chains = [
        DiscreteMarkovChain("Y", state_space) for state_space in state_spaces
    ]

    for i, Y in enumerate(chains):
        assert isinstance(Y.transition_probabilities, MatrixSymbol)
        # Restored: the FiniteSet(...) argument was missing.
        # NOTE(review): unpacking the i-th state space is the natural
        # counterpart of the left-hand comparison -- confirm.
        assert Y.state_space == state_spaces[i] or Y.state_space == FiniteSet(
            *state_spaces[i])
        assert Y.number_of_states == 3

        with ignore_warnings(
                UserWarning):  # TODO: Restore tests once warnings are removed
            assert P(Eq(Y[2], 1), Eq(Y[0], 2),
                     evaluate=False) == Probability(Eq(Y[2], 1), Eq(Y[0], 2))
        assert E(Y[0]) == Expectation(Y[0])

        raises(ValueError, lambda: next(sample_stochastic_process(Y)))

    # NOTE(review): dict((1, 1)) raises TypeError on its own (1 is not a
    # key/value pair), so the constructor is never reached here -- confirm
    # whether a real dict such as {1: 1} was intended.
    raises(TypeError, lambda: DiscreteMarkovChain("Y", dict((1, 1))))
    Y = DiscreteMarkovChain("Y", Range(1, t, 2))
    assert Y.number_of_states == ceiling((t - 1) / 2)

    # pass name and transition_probabilities
    # Restored: the opener of the third chain and the closing bracket of the
    # list were missing; the name "Y" follows the two sibling entries.
    chains = [
        DiscreteMarkovChain("Y", trans_probs=Matrix([[]])),
        DiscreteMarkovChain("Y", trans_probs=Matrix([[0, 1], [1, 0]])),
        DiscreteMarkovChain("Y",
                            trans_probs=Matrix([[pi, 1 - pi], [sym, 1 - sym]]))
    ]
    for Z in chains:
        assert Z.number_of_states == Z.transition_probabilities.shape[0]
        assert isinstance(Z.transition_probabilities, ImmutableMatrix)

    # pass name, state_space and transition_probabilities
    T = Matrix([[0.5, 0.2, 0.3], [0.2, 0.5, 0.3], [0.2, 0.3, 0.5]])
    TS = MatrixSymbol('T', 3, 3)
    Y = DiscreteMarkovChain("Y", [0, 1, 2], T)
    YS = DiscreteMarkovChain("Y", ['One', 'Two', 3], TS)
    assert Y.joint_distribution(1, Y[2],
                                3) == JointDistribution(Y[1], Y[2], Y[3])
    raises(ValueError, lambda: Y.joint_distribution(Y[1].symbol, Y[2].symbol))
    assert P(Eq(Y[3], 2), Eq(Y[1], 1)).round(2) == Float(0.36, 2)
    assert (P(Eq(YS[3], 2), Eq(YS[1], 1)) -
            (TS[0, 2] * TS[1, 0] + TS[1, 1] * TS[1, 2] +
             TS[1, 2] * TS[2, 2])).simplify() == 0
    assert P(Eq(YS[1], 1), Eq(YS[2], 2)) == Probability(Eq(YS[1], 1))
    # Restored: the first argument of the condition Eq(...) was missing.
    # NOTE(review): YS[1] mirrors the condition of the YS[3] query two
    # statements above, which has the same expected expression -- confirm.
    assert P(Eq(YS[3], 3), Eq(YS[1], 1)) == (
        TS[0, 2] * TS[1, 0] + TS[1, 1] * TS[1, 2] + TS[1, 2] * TS[2, 2])
    TO = Matrix([[0.25, 0.75, 0], [0, 0.25, 0.75], [0.75, 0, 0.25]])
    assert P(Eq(Y[3], 2),
             Eq(Y[1], 1) & TransitionMatrixOf(Y, TO)).round(3) == Float(
                 0.375, 3)
    with ignore_warnings(
            UserWarning):  ### TODO: Restore tests once warnings are removed
        assert E(Y[3], evaluate=False) == Expectation(Y[3])
        assert E(Y[3], Eq(Y[2], 1)).round(2) == Float(1.1, 3)
    TSO = MatrixSymbol('T', 4, 4)
    # Restored: the `raises(<exception>,` openers of the next three checks
    # were missing, leaving bare lambdas.
    # NOTE(review): the exception types (ValueError for the mismatched 4x4
    # and non-square 3x4 matrices, TypeError for a plain symbol) are not
    # derivable from this test alone -- confirm against the implementation.
    raises(
        ValueError,
        lambda: str(P(Eq(YS[3], 2),
                      Eq(YS[1], 1) & TransitionMatrixOf(YS, TSO))))
    raises(TypeError,
           lambda: DiscreteMarkovChain("Z", [0, 1, 2], symbols('M')))
    raises(
        ValueError,
        lambda: DiscreteMarkovChain("Z", [0, 1, 2], MatrixSymbol('T', 3, 4)))
    raises(ValueError, lambda: E(Y[3], Eq(Y[2], 6)))
    raises(ValueError, lambda: E(Y[2], Eq(Y[3], 1)))

    # extended tests for probability queries
    TO1 = Matrix([[Rational(1, 4), Rational(3, 4), 0],
                  [Rational(1, 3),
                   Rational(1, 3),
                   Rational(1, 3)], [0, Rational(1, 4),
                                     Rational(3, 4)]])
    assert P(
        And(Eq(Y[2], 1), Eq(Y[1], 1), Eq(Y[0], 0)),
        Eq(Probability(Eq(Y[0], 0)), Rational(1, 4))
        & TransitionMatrixOf(Y, TO1)) == Rational(1, 16)
    assert P(And(Eq(Y[2], 1), Eq(Y[1], 1), Eq(Y[0], 0)), TransitionMatrixOf(Y, TO1)) == \
            Probability(Eq(Y[0], 0))/4
    assert P(
        Lt(X[1], 2) & Gt(X[1], 0),
        Eq(X[0], 2) & StochasticStateSpaceOf(X, [0, 1, 2])
        & TransitionMatrixOf(X, TO1)) == Rational(1, 4)
    assert P(
        Lt(X[1], 2) & Gt(X[1], 0),
        Eq(X[0], 2) & StochasticStateSpaceOf(X, [S(0), '0', 1])
        & TransitionMatrixOf(X, TO1)) == Rational(1, 4)
    assert P(
        Ne(X[1], 2) & Ne(X[1], 1),
        Eq(X[0], 2) & StochasticStateSpaceOf(X, [0, 1, 2])
        & TransitionMatrixOf(X, TO1)) is S.Zero
    assert P(
        Ne(X[1], 2) & Ne(X[1], 1),
        Eq(X[0], 2) & StochasticStateSpaceOf(X, [S(0), '0', 1])
        & TransitionMatrixOf(X, TO1)) is S.Zero
    assert P(And(Eq(Y[2], 1), Eq(Y[1], 1), Eq(Y[0], 0)),
             Eq(Y[1], 1)) == 0.1 * Probability(Eq(Y[0], 0))

    # testing properties of Markov chain
    TO2 = Matrix([[S.One, 0, 0],
                  [Rational(1, 3),
                   Rational(1, 3),
                   Rational(1, 3)], [0, Rational(1, 4),
                                     Rational(3, 4)]])
    TO3 = Matrix([[Rational(1, 4), Rational(3, 4), 0],
                  [Rational(1, 3),
                   Rational(1, 3),
                   Rational(1, 3)], [0, Rational(1, 4),
                                     Rational(3, 4)]])
    Y2 = DiscreteMarkovChain('Y', trans_probs=TO2)
    Y3 = DiscreteMarkovChain('Y', trans_probs=TO3)
    assert Y3.fundamental_matrix() == ImmutableMatrix(
        [[176, 81, -132], [36, 141, -52], [-44, -39, 208]]) / 125
    assert Y2.is_absorbing_chain() == True
    assert Y3.is_absorbing_chain() == False
    assert Y2.canonical_form() == ([0, 1, 2], TO2)
    assert Y3.canonical_form() == ([0, 1, 2], TO3)
    assert Y2.decompose() == ([0, 1,
                               2], TO2[0:1, 0:1], TO2[1:3, 0:1], TO2[1:3, 1:3])
    assert Y3.decompose() == ([0, 1, 2], TO3, Matrix(0, 3,
                                                     []), Matrix(0, 0, []))
    TO4 = Matrix([[Rational(1, 5),
                   Rational(2, 5),
                   Rational(2, 5)], [Rational(1, 10), S.Half,
                                     Rational(2, 5)],
                  [Rational(3, 5),
                   Rational(3, 10),
                   Rational(1, 10)]])
    Y4 = DiscreteMarkovChain('Y', trans_probs=TO4)
    w = ImmutableMatrix([[Rational(11, 39),
                          Rational(16, 39),
                          Rational(4, 13)]])
    assert Y4.limiting_distribution == w
    assert Y4.is_regular() == True
    assert Y4.is_ergodic() == True
    TS1 = MatrixSymbol('T', 3, 3)
    Y5 = DiscreteMarkovChain('Y', trans_probs=TS1)
    assert Y5.limiting_distribution(w, TO4).doit() == True
    assert Y5.stationary_distribution(condition_set=True).subs(
        TS1, TO4).contains(w).doit() == S.true
    TO6 = Matrix([[S.One, 0, 0, 0, 0], [S.Half, 0, S.Half, 0, 0],
                  [0, S.Half, 0, S.Half, 0], [0, 0, S.Half, 0, S.Half],
                  [0, 0, 0, 0, 1]])
    Y6 = DiscreteMarkovChain('Y', trans_probs=TO6)
    assert Y6.fundamental_matrix() == ImmutableMatrix(
        [[Rational(3, 2), S.One, S.Half], [S.One, S(2), S.One],
         [S.Half, S.One, Rational(3, 2)]])
    assert Y6.absorbing_probabilities() == ImmutableMatrix(
        [[Rational(3, 4), Rational(1, 4)], [S.Half, S.Half],
         [Rational(1, 4), Rational(3, 4)]])
    with warns_deprecated_sympy():
        # Restored: the body of this `with` block was missing.
        # NOTE(review): presumably the deprecated, misspelled alias of
        # absorbing_probabilities() is what should emit the deprecation
        # warning here -- confirm the alias name against the implementation.
        Y6.absorbing_probabilites()
    TO7 = Matrix([[Rational(1, 2),
                   Rational(1, 4),
                   Rational(1, 4)], [Rational(1, 2), 0,
                                     Rational(1, 2)],
                  [Rational(1, 4),
                   Rational(1, 4),
                   Rational(1, 2)]])
    Y7 = DiscreteMarkovChain('Y', trans_probs=TO7)
    assert Y7.is_absorbing_chain() == False
    assert Y7.fundamental_matrix() == ImmutableMatrix(
        [[Rational(86, 75),
          Rational(1, 25),
          Rational(-14, 75)],
         [Rational(2, 25), Rational(21, 25),
          Rational(2, 25)],
         [Rational(-14, 75),
          Rational(1, 25),
          Rational(86, 75)]])

    # test for zero-sized matrix functionality
    X = DiscreteMarkovChain('X', trans_probs=Matrix([[]]))
    assert X.number_of_states == 0
    assert X.stationary_distribution() == Matrix([[]])
    assert X.communication_classes() == []
    assert X.canonical_form() == ([], Matrix([[]]))
    assert X.decompose() == ([], Matrix([[]]), Matrix([[]]), Matrix([[]]))
    assert X.is_regular() == False
    assert X.is_ergodic() == False

    # test communication_class
    # see https://drive.google.com/drive/folders/1HbxLlwwn2b3U8Lj7eb_ASIUb5vYaNIjg?usp=sharing
    # tutorial 2.pdf
    TO7 = Matrix([[0, 5, 5, 0, 0], [0, 0, 0, 10, 0], [5, 0, 5, 0, 0],
                  [0, 10, 0, 0, 0], [0, 3, 0, 3, 4]]) / 10
    Y7 = DiscreteMarkovChain('Y', trans_probs=TO7)
    tuples = Y7.communication_classes()
    classes, recurrence, periods = list(zip(*tuples))
    assert classes == ([1, 3], [0, 2], [4])
    assert recurrence == (True, False, False)
    assert periods == (2, 1, 1)

    TO8 = Matrix([[0, 0, 0, 10, 0, 0], [5, 0, 5, 0, 0, 0], [0, 4, 0, 0, 0, 6],
                  [10, 0, 0, 0, 0, 0], [0, 10, 0, 0, 0, 0], [0, 0, 0, 5, 5, 0]
                  ]) / 10
    Y8 = DiscreteMarkovChain('Y', trans_probs=TO8)
    tuples = Y8.communication_classes()
    classes, recurrence, periods = list(zip(*tuples))
    assert classes == ([0, 3], [1, 2, 5, 4])
    assert recurrence == (True, False)
    assert periods == (2, 2)

    TO9 = Matrix(
        [[2, 0, 0, 3, 0, 0, 3, 2, 0, 0], [0, 10, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 2, 2, 0, 0, 0, 0, 0, 3, 3], [0, 0, 0, 3, 0, 0, 6, 1, 0, 0],
         [0, 0, 0, 0, 5, 5, 0, 0, 0, 0], [0, 0, 0, 0, 0, 10, 0, 0, 0, 0],
         [4, 0, 0, 5, 0, 0, 1, 0, 0, 0], [2, 0, 0, 4, 0, 0, 2, 2, 0, 0],
         [3, 0, 1, 0, 0, 0, 0, 0, 4, 2], [0, 0, 4, 0, 0, 0, 0, 0, 3, 3]]) / 10
    Y9 = DiscreteMarkovChain('Y', trans_probs=TO9)
    tuples = Y9.communication_classes()
    classes, recurrence, periods = list(zip(*tuples))
    assert classes == ([0, 3, 6, 7], [1], [2, 8, 9], [5], [4])
    assert recurrence == (True, True, False, True, False)
    assert periods == (1, 1, 1, 1, 1)

    # test canonical form
    # see https://www.dartmouth.edu/~chance/teaching_aids/books_articles/probability_book/Chapter11.pdf
    # example 11.13
    # Restored: the last entry of the final row and the closing brackets were
    # missing. The row must be [0, 0, 0, 0, 1]: the assertions below expect
    # state 4 to be absorbing (A is the 2x2 identity over states 0 and 4).
    T = Matrix([[1, 0, 0, 0, 0],
                [S(1) / 2, 0, S(1) / 2, 0, 0],
                [0, S(1) / 2, 0, S(1) / 2, 0],
                [0, 0, S(1) / 2, 0, S(1) / 2],
                [0, 0, 0, 0, S(1)]])
    DW = DiscreteMarkovChain('DW', [0, 1, 2, 3, 4], T)
    states, A, B, C = DW.decompose()
    assert states == [0, 4, 1, 2, 3]
    assert A == Matrix([[1, 0], [0, 1]])
    assert B == Matrix([[S(1) / 2, 0], [0, 0], [0, S(1) / 2]])
    assert C == Matrix([[0, S(1) / 2, 0], [S(1) / 2, 0, S(1) / 2],
                        [0, S(1) / 2, 0]])
    states, new_matrix = DW.canonical_form()
    assert states == [0, 4, 1, 2, 3]
    assert new_matrix == Matrix([[1, 0, 0, 0, 0], [0, 1, 0, 0, 0],
                                 [S(1) / 2, 0, 0, S(1) / 2, 0],
                                 [0, 0, S(1) / 2, 0,
                                  S(1) / 2], [0, S(1) / 2, 0,
                                              S(1) / 2, 0]])

    # test regular and ergodic
    # https://www.dartmouth.edu/~chance/teaching_aids/books_articles/probability_book/Chapter11.pdf
    T = Matrix([[0, 4, 0, 0, 0], [1, 0, 3, 0, 0], [0, 2, 0, 2, 0],
                [0, 0, 3, 0, 1], [0, 0, 0, 4, 0]]) / 4
    X = DiscreteMarkovChain('X', trans_probs=T)
    assert not X.is_regular()
    assert X.is_ergodic()
    T = Matrix([[0, 1], [1, 0]])
    X = DiscreteMarkovChain('X', trans_probs=T)
    assert not X.is_regular()
    assert X.is_ergodic()
    # http://www.math.wisc.edu/~valko/courses/331/MC2.pdf
    T = Matrix([[2, 1, 1], [2, 0, 2], [1, 1, 2]]) / 4
    X = DiscreteMarkovChain('X', trans_probs=T)
    assert X.is_regular()
    assert X.is_ergodic()
    # https://docs.ufpr.br/~lucambio/CE222/1S2014/Kemeny-Snell1976.pdf
    T = Matrix([[1, 1], [1, 1]]) / 2
    X = DiscreteMarkovChain('X', trans_probs=T)
    assert X.is_regular()
    assert X.is_ergodic()

    # test is_absorbing_chain
    T = Matrix([[0, 1, 0], [1, 0, 0], [0, 0, 1]])
    X = DiscreteMarkovChain('X', trans_probs=T)
    assert not X.is_absorbing_chain()
    # https://en.wikipedia.org/wiki/Absorbing_Markov_chain
    T = Matrix([[1, 1, 0, 0], [0, 1, 1, 0], [1, 0, 0, 1], [0, 0, 0, 2]]) / 2
    X = DiscreteMarkovChain('X', trans_probs=T)
    assert X.is_absorbing_chain()
    T = Matrix([[2, 0, 0, 0, 0], [1, 0, 1, 0, 0], [0, 1, 0, 1, 0],
                [0, 0, 1, 0, 1], [0, 0, 0, 0, 2]]) / 2
    X = DiscreteMarkovChain('X', trans_probs=T)
    assert X.is_absorbing_chain()

    # test custom state space
    Y10 = DiscreteMarkovChain('Y', [1, 2, 3], TO2)
    tuples = Y10.communication_classes()
    classes, recurrence, periods = list(zip(*tuples))
    assert classes == ([1], [2, 3])
    assert recurrence == (True, False)
    assert periods == (1, 1)
    assert Y10.canonical_form() == ([1, 2, 3], TO2)
    # Restored: the end of the last slice and the closing parenthesis were
    # missing; the slices mirror the Y2.decompose() assertion on the same
    # matrix TO2 earlier in this test.
    assert Y10.decompose() == ([1, 2, 3], TO2[0:1, 0:1], TO2[1:3, 0:1],
                               TO2[1:3, 1:3])

    # testing miscellaneous queries
    T = Matrix([[S.Half, Rational(1, 4),
                 Rational(1, 4)], [Rational(1, 3), 0,
                                   Rational(2, 3)], [S.Half, S.Half, 0]])
    X = DiscreteMarkovChain('X', [0, 1, 2], T)
    assert P(
        Eq(X[1], 2) & Eq(X[2], 1) & Eq(X[3], 0),
        Eq(P(Eq(X[1], 0)), Rational(1, 4))
        & Eq(P(Eq(X[1], 1)), Rational(1, 4))) == Rational(1, 12)
    assert P(Eq(X[2], 1) | Eq(X[2], 2), Eq(X[1], 1)) == Rational(2, 3)
    assert P(Eq(X[2], 1) & Eq(X[2], 2), Eq(X[1], 1)) is S.Zero
    assert P(Ne(X[2], 2), Eq(X[1], 1)) == Rational(1, 3)
    assert E(X[1]**2, Eq(X[0], 1)) == Rational(8, 3)
    assert variance(X[1], Eq(X[0], 1)) == Rational(8, 9)
    raises(ValueError, lambda: E(X[1], Eq(X[2], 1)))
    raises(ValueError, lambda: DiscreteMarkovChain('X', [0, 1], T))

    # testing miscellaneous queries with different state space
    X = DiscreteMarkovChain('X', ['A', 'B', 'C'], T)
    assert P(
        Eq(X[1], 2) & Eq(X[2], 1) & Eq(X[3], 0),
        Eq(P(Eq(X[1], 0)), Rational(1, 4))
        & Eq(P(Eq(X[1], 1)), Rational(1, 4))) == Rational(1, 12)
    assert P(Eq(X[2], 1) | Eq(X[2], 2), Eq(X[1], 1)) == Rational(2, 3)
    assert P(Eq(X[2], 1) & Eq(X[2], 2), Eq(X[1], 1)) is S.Zero
    assert P(Ne(X[2], 2), Eq(X[1], 1)) == Rational(1, 3)
    a = X.state_space.args[0]
    c = X.state_space.args[2]
    assert (E(X[1]**2, Eq(X[0], 1)) -
            (a**2 / 3 + 2 * c**2 / 3)).simplify() == 0
    assert (variance(X[1], Eq(X[0], 1)) -
            (2 * (-a / 3 + c / 3)**2 / 3 +
             (2 * a / 3 - 2 * c / 3)**2 / 3)).simplify() == 0
    raises(ValueError, lambda: E(X[1], Eq(X[2], 1)))

    # testing queries with multiple RandomIndexedSymbols
    T = Matrix([[Rational(5, 10),
                 Rational(3, 10),
                 Rational(2, 10)],
                [Rational(2, 10),
                 Rational(7, 10),
                 Rational(1, 10)],
                [Rational(3, 10),
                 Rational(3, 10),
                 Rational(4, 10)]])
    Y = DiscreteMarkovChain("Y", [0, 1, 2], T)
    assert P(Eq(Y[7], Y[5]), Eq(Y[2], 0)).round(5) == Float(0.44428, 5)
    assert P(Gt(Y[3], Y[1]), Eq(Y[0], 0)).round(2) == Float(0.36, 2)
    assert P(Le(Y[5], Y[10]), Eq(Y[4], 2)).round(6) == Float(0.583120, 6)
    assert Float(P(Eq(Y[10], Y[5]), Eq(Y[4], 1)),
                 14) == Float(1 - P(Ne(Y[10], Y[5]), Eq(Y[4], 1)), 14)
    assert Float(P(Gt(Y[8], Y[9]), Eq(Y[3], 2)),
                 14) == Float(1 - P(Le(Y[8], Y[9]), Eq(Y[3], 2)), 14)
    assert Float(P(Lt(Y[1], Y[4]), Eq(Y[0], 0)),
                 14) == Float(1 - P(Ge(Y[1], Y[4]), Eq(Y[0], 0)), 14)
    assert P(Eq(Y[5], Y[10]), Eq(Y[2], 1)) == P(Eq(Y[10], Y[5]), Eq(Y[2], 1))
    assert P(Gt(Y[1], Y[2]), Eq(Y[0], 1)) == P(Lt(Y[2], Y[1]), Eq(Y[0], 1))
    assert P(Ge(Y[7], Y[6]), Eq(Y[4], 1)) == P(Le(Y[6], Y[7]), Eq(Y[4], 1))

    # test symbolic queries
    a, b, c, d = symbols('a b c d')
    T = Matrix([[Rational(1, 10),
                 Rational(4, 10),
                 Rational(5, 10)],
                [Rational(3, 10),
                 Rational(4, 10),
                 Rational(3, 10)],
                [Rational(7, 10),
                 Rational(2, 10),
                 Rational(1, 10)]])
    Y = DiscreteMarkovChain("Y", [0, 1, 2], T)
    query = P(Eq(Y[a], b), Eq(Y[c], d))
    assert query.subs({
        a: 10,
        b: 2,
        c: 5,
        d: 1
    }).evalf().round(4) == P(Eq(Y[10], 2), Eq(Y[5], 1)).round(4)
    assert query.subs({
        a: 15,
        b: 0,
        c: 10,
        d: 1
    }).evalf().round(4) == P(Eq(Y[15], 0), Eq(Y[10], 1)).round(4)
    # Complementary events: the two probabilities must add up to one.
    query_gt = P(Gt(Y[a], b), Eq(Y[c], d))
    query_le = P(Le(Y[a], b), Eq(Y[c], d))
    assert query_gt.subs({
        a: 5,
        b: 2,
        c: 1,
        d: 0
    }).evalf() + query_le.subs({
        a: 5,
        b: 2,
        c: 1,
        d: 0
    }).evalf() == 1
    query_ge = P(Ge(Y[a], b), Eq(Y[c], d))
    query_lt = P(Lt(Y[a], b), Eq(Y[c], d))
    assert query_ge.subs({
        a: 4,
        b: 1,
        c: 0,
        d: 2
    }).evalf() + query_lt.subs({
        a: 4,
        b: 1,
        c: 0,
        d: 2
    }).evalf() == 1

    # test issue 20078
    assert (2 * Y[1] + 3 * Y[1]).simplify() == 5 * Y[1]
    assert (2 * Y[1] - 3 * Y[1]).simplify() == -Y[1]
    assert (2 * (0.25 * Y[1])).simplify() == 0.5 * Y[1]
    assert ((2 * Y[1]) * (0.25 * Y[1])).simplify() == 0.5 * Y[1]**2
    assert (Y[1]**2 + Y[1]**3).simplify() == (Y[1] + 1) * Y[1]**2
def test_issue_12092():
    """Check that a nested call of an implemented function evaluates.

    The implemented function squares its argument, so applying it twice
    to 2 must numerically evaluate to 16.
    """
    square = implemented_function('f', lambda x: x**2)
    nested_call = square(square(2))
    expected = Float(16)
    assert nested_call.evalf() == expected
def test_gauss_opt():
    """Exercise the Gaussian-optics helpers.

    Covers ray transfer matrices and their named constructors, matrix
    products, geometric rays, numeric and symbolic ``BeamParameter``
    properties, and the conjugation / conversion utility functions.

    The ``FreeSpace(d) * GeometricRay(h, angle)`` comparison had lost its
    second matrix row and no longer parsed; it is restored below.
    """
    # generic ray transfer matrix and its A/B/C/D entries
    mat = RayTransferMatrix(1, 2, 3, 4)
    assert mat == Matrix([[1, 2], [3, 4]])
    assert mat == RayTransferMatrix(Matrix([[1, 2], [3, 4]]))
    assert [mat.A, mat.B, mat.C, mat.D] == [1, 2, 3, 4]

    # named optical elements compared against their explicit matrices
    d, f, h, n1, n2, R = symbols('d f h n1 n2 R')
    lens = ThinLens(f)
    assert lens == Matrix([[1, 0], [-1 / f, 1]])
    assert lens.C == -1 / f
    assert FreeSpace(d) == Matrix([[1, d], [0, 1]])
    assert FlatRefraction(n1, n2) == Matrix([[1, 0], [0, n1 / n2]])
    assert CurvedRefraction(R, n1, n2) == Matrix([[1, 0],
                                                  [(n1 - n2) / (R * n2),
                                                   n1 / n2]])
    assert FlatMirror() == Matrix([[1, 0], [0, 1]])
    assert CurvedMirror(R) == Matrix([[1, 0], [-2 / R, 1]])
    assert ThinLens(f) == Matrix([[1, 0], [-1 / f, 1]])

    # product of two elements agrees entry-wise with the plain matrix product
    mul = CurvedMirror(R) * FreeSpace(d)
    mul_mat = Matrix([[1, 0], [-2 / R, 1]]) * Matrix([[1, d], [0, 1]])
    assert mul.A == mul_mat[0, 0]
    assert mul.B == mul_mat[0, 1]
    assert mul.C == mul_mat[1, 0]
    assert mul.D == mul_mat[1, 1]

    # geometric rays
    angle = symbols('angle')
    assert GeometricRay(h, angle) == Matrix([[h], [angle]])
    # Restored: the second row `[angle]` and the closing brackets were
    # missing. The value is fixed by the `.height` and `.angle` assertions
    # on the same product two statements below.
    assert FreeSpace(d) * GeometricRay(h, angle) == Matrix([[angle * d + h],
                                                            [angle]])
    assert GeometricRay(Matrix(((h, ), (angle, )))) == Matrix([[h], [angle]])
    assert (FreeSpace(d) * GeometricRay(h, angle)).height == angle * d + h
    assert (FreeSpace(d) * GeometricRay(h, angle)).angle == angle

    # numeric beam parameter, compared through the `streq` helper
    p = BeamParameter(530e-9, 1, w=1e-3)
    assert streq(p.q, 1 + 1.88679245283019 * I * pi)
    assert streq(N(p.q), 1.0 + 5.92753330865999 * I)
    assert streq(N(p.w_0), Float(0.00100000000000000))
    assert streq(N(p.z_r), Float(5.92753330865999))
    fs = FreeSpace(10)
    p1 = fs * p
    assert streq(N(p.w), Float(0.00101413072159615))
    assert streq(N(p1.w), Float(0.00210803120913829))

    # waist <-> Rayleigh range conversions
    w, wavelen = symbols('w wavelen')
    assert waist2rayleigh(w, wavelen) == pi * w**2 / wavelen
    z_r, wavelen = symbols('z_r wavelen')
    assert rayleigh2waist(z_r, wavelen) == sqrt(wavelen * z_r) / sqrt(pi)

    # geometric conjugation, including the infinite-distance limits
    a, b, f = symbols('a b f')
    assert geometric_conj_ab(a, b) == a * b / (a + b)
    assert geometric_conj_af(a, f) == a * f / (a - f)
    assert geometric_conj_bf(b, f) == b * f / (b - f)
    assert geometric_conj_ab(oo, b) == b
    assert geometric_conj_ab(a, oo) == a

    # Gaussian conjugation: the three returned components
    s_in, z_r_in, f = symbols('s_in z_r_in f')
    assert gaussian_conj(s_in, z_r_in,
                         f)[0] == 1 / (-1 / (s_in + z_r_in**2 /
                                             (-f + s_in)) + 1 / f)
    assert gaussian_conj(
        s_in, z_r_in, f)[1] == z_r_in / (1 - s_in**2 / f**2 + z_r_in**2 / f**2)
    assert gaussian_conj(
        s_in, z_r_in, f)[2] == 1 / sqrt(1 - s_in**2 / f**2 + z_r_in**2 / f**2)

    l, w_i, w_o, f = symbols('l w_i w_o f')
    assert conjugate_gauss_beams(
        l, w_i, w_o, f=f)[0] == f * (-sqrt(w_i**2 / w_o**2 - pi**2 * w_i**4 /
                                           (f**2 * l**2)) + 1)
    assert factor(conjugate_gauss_beams(
        l, w_i, w_o,
        f=f)[1]) == f * w_o**2 * (w_i**2 / w_o**2 -
                                  sqrt(w_i**2 / w_o**2 - pi**2 * w_i**4 /
                                       (f**2 * l**2))) / w_i**2
    assert conjugate_gauss_beams(l, w_i, w_o, f=f)[2] == f

    # symbolic beam parameter properties
    # NOTE(review): the Python name `w` is bound to a symbol printed as 'r';
    # the assertions only compare expression objects, so they are unaffected,
    # but confirm the symbol name is intentional.
    z, l, w = symbols('z l r', positive=True)
    p = BeamParameter(l, z, w=w)
    assert p.radius == z * (pi**2 * w**4 / (l**2 * z**2) + 1)
    assert p.w == w * sqrt(l**2 * z**2 / (pi**2 * w**4) + 1)
    assert p.w_0 == w
    assert p.divergence == l / (pi * w)
    assert p.gouy == atan2(z, pi * w**2 / l)
    assert p.waist_approximation_limit == 2 * l / pi