Exemplo n.º 1
0
def test_Max():
    # Exercises Max: evaluation, argument validation, derivatives and
    # assumption inference.
    from sympy.abc import x, y, z
    neg = Symbol('n', negative=True)
    neg2 = Symbol('n_', negative=True)
    nonneg = Symbol('nn', nonnegative=True)
    pos = Symbol('p', positive=True)
    pos2 = Symbol('p_', positive=True)
    real = Symbol('r', real=True)

    assert Max(5, 4) == 5

    # lists

    assert Max() is S.NegativeInfinity

    # (actual, expected) pairs, compared with structural equality
    equalities = [
        (Max(x), x),
        (Max(x, y), Max(y, x)),
        (Max(x, y, z), Max(z, y, x)),
        (Max(x, Max(y, z)), Max(z, y, x)),
        (Max(x, Min(y, oo)), Max(x, y)),
        (Max(neg, -oo, neg2, pos, 2), Max(pos, 2)),
        (Max(neg, -oo, neg2, pos), pos),
        (Max(2, x, pos, neg, -oo, S.NegativeInfinity, neg2, pos, 2),
         Max(2, x, pos)),
        (Max(0, x, 1, y), Max(1, x, y)),
        (Max(real, real + 1, real - 1), 1 + real),
        (Max(1000, 100, -100, x, pos, neg), Max(pos, x, 1000)),
        (Max(cos(x), sin(x)), Max(sin(x), cos(x))),
        (Max(cos(x), sin(x)).subs(x, 1), sin(1)),
        (Max(cos(x), sin(x)).subs(x, S.Half), cos(S.Half)),
    ]
    for actual, expected in equalities:
        assert actual == expected

    # arguments for which Max must raise ValueError
    rejected = (
        lambda: Max(cos(x), sin(x)).subs(x, I),
        lambda: Max(I),
        lambda: Max(I, x),
        lambda: Max(S.ComplexInfinity, 1),
    )
    for build in rejected:
        raises(ValueError, build)

    for actual, expected in [
            (Max(neg, -oo, neg2, pos, 2), Max(pos, 2)),
            (Max(neg, -oo, neg2, pos, 1000), Max(pos, 1000))]:
        assert actual == expected

    # derivatives are expressed through Heaviside
    step = Heaviside(x - 1)
    assert Max(1, x).diff(x) == step
    assert Max(x, 1).diff(x) == step
    assert Max(x**2, 1 + x, 1).diff(x) == (
        2*x*Heaviside(x**2 - Max(1, x + 1))
        + Heaviside(x - Max(1, x**2) + 1))

    # numerical evaluation keeps the arguments
    unevaluated = Max(0, x)
    assert unevaluated.n().args == (0, x)

    # issue 8643
    # rows: (expression, is_positive, is_nonnegative, is_negative)
    assumption_table = [
        (Max(pos, pos2, neg, real), True, True, False),
        (Max(neg, neg2), False, False, True),
        (Max(neg, neg2, real), None, None, None),
        (Max(neg, nonneg, real), None, True, False),
    ]
    for m, positive, nonnegative, negative in assumption_table:
        assert m.is_positive is positive
        assert m.is_nonnegative is nonnegative
        assert m.is_negative is negative
Exemplo n.º 2
0
def test_factor_terms():
    # Exercises factor_terms on sums, powers, containers, radicals,
    # fractions, sign handling and Sum/Integral arguments.
    noncomm = Symbol('A', commutative=False)
    assert factor_terms(9*(x + x*y + 1) + (3*x + 3)**(2 + 2*x)) == (
        9*x*y + 9*x + _keep_coeff(S(3), x + 1)**_keep_coeff(S(2), x + 1) + 9)
    assert factor_terms(9*(x + x*y + 1) + (3)**(2 + 2*x)) == (
        _keep_coeff(S(9), 3**(2*x) + x*y + x + 1))
    assert factor_terms(3**(2 + 2*x) + a*3**(2 + 2*x)) == (
        9*3**(2*x)*(a + 1))
    assert factor_terms(x + x*noncomm) == (
        x*(1 + noncomm))
    assert factor_terms(sin(x + x*noncomm)) == (
        sin(x*(1 + noncomm)))
    assert factor_terms((3*x + 3)**((2 + 2*x)/3)) == (
        _keep_coeff(S(3), x + 1)**_keep_coeff(Rational(2, 3), x + 1))
    assert factor_terms(x + (x*y + x)**(3*x + 3)) == (
        x + (x*(y + 1))**_keep_coeff(S(3), x + 1))
    assert factor_terms(a*(x + x*y) + b*(x*2 + y*x*2)) == (
        x*(a + 2*b)*(y + 1))
    integral = Integral(x, (x, 0, oo))
    assert factor_terms(integral) == integral

    assert factor_terms(x / 2 + y) == x / 2 + y
    # fraction doesn't apply to integer denominators
    assert factor_terms(x / 2 + y, fraction=True) == x / 2 + y
    # clear *does* apply to the integer denominators
    cleared = Mul(S.Half, x + 2 * y, evaluate=False)
    assert factor_terms(x / 2 + y, clear=True) == cleared

    # check radical extraction
    radicals = sqrt(2) + sqrt(10)
    assert factor_terms(radicals) == radicals
    assert factor_terms(radicals, radical=True) == sqrt(2) * (1 + sqrt(5))
    cube_roots = root(-6, 3) + root(6, 3)
    assert factor_terms(cube_roots, radical=True) == (
        6**(S.One / 3) * (1 + (-1)**(S.One / 3)))

    # containers are processed element-wise and keep their type
    unfactored = [x + x * y]
    factored = [x * (y + 1)]
    for container in [list, tuple, set]:
        assert factor_terms(container(unfactored)) == container(factored)
    assert factor_terms(Tuple(x + x * y)) == Tuple(x * (y + 1))
    assert factor_terms(Interval(0, 1)) == Interval(0, 1)
    inv_root = 1 / sqrt(a / 2 + 1)
    assert factor_terms(inv_root, clear=False) == 1 / sqrt(a / 2 + 1)
    assert factor_terms(inv_root, clear=True) == sqrt(2) / sqrt(a + 2)

    rational_sum = x / (x + 1 / x) + 1 / (x**2 + 1)
    assert factor_terms(rational_sum, fraction=False) == rational_sum
    assert factor_terms(rational_sum, fraction=True) == 1

    assert factor_terms((1/(x**3 + x**2) + 2/x**2)*y) == (
        y*(2 + 1/(x + 1))/x**2)

    # if not True, then processing for this in factor_terms is not necessary
    assert gcd_terms(-x - y) == -x - y
    assert factor_terms(-x - y) == Mul(-1, x + y, evaluate=False)

    # if not True, then "special" processing in factor_terms is not necessary
    assert gcd_terms(exp(Mul(-1, x + 1))) == exp(-x - 1)
    with_exp = exp(-x - 2) + x
    assert factor_terms(with_exp) == exp(Mul(-1, x + 2, evaluate=False)) + x
    assert factor_terms(with_exp, sign=False) == with_exp
    assert factor_terms(exp(-4 * x - 2) - x) == (
        -x + exp(Mul(-2, 2 * x + 1, evaluate=False)))

    # sum/integral tests
    limits = (y, 1, 10)
    for op in (Sum, Integral):
        assert factor_terms(op(x, limits)) == x * op(1, limits)
        assert factor_terms(op(x, limits) + x) == x * (1 + op(1, limits))
        assert factor_terms(op(x * y + x * y**2, limits)) == (
            x * op(y * (y + 1), limits))

    # expressions involving Pow terms with base 0
    assert factor_terms(0**(x - 2) - 1) == 0**(x - 2) - 1
    assert factor_terms(0**(x + 2) - 1) == 0**(x + 2) - 1
    assert factor_terms((0**(x + 2) - 1).subs(x, -2)) == 0
Exemplo n.º 3
0
def rsolve_poly(coeffs, f, n, **hints):
    r"""
    Given linear recurrence operator `\operatorname{L}` of order
    `k` with polynomial coefficients and inhomogeneous equation
    `\operatorname{L} y = f`, where `f` is a polynomial, we seek
    all polynomial solutions over field `K` of characteristic zero.

    The algorithm performs two basic steps:

        (1) Compute degree `N` of the general polynomial solution.
        (2) Find all polynomials of degree `N` or less
            of `\operatorname{L} y = f`.

    There are two methods for computing the polynomial solutions.
    If the degree bound is relatively small, i.e. it's smaller than
    or equal to the order of the recurrence, then naive method of
    undetermined coefficients is being used. This gives system
    of algebraic equations with `N+1` unknowns.

    In the other case, the algorithm performs transformation of the
    initial equation to an equivalent one, for which the system of
    algebraic equations has only `r` indeterminates. This method is
    quite sophisticated (in comparison with the naive one) and was
    invented together by Abramov, Bronstein and Petkovsek.

    It is possible to generalize the algorithm implemented here to
    the case of linear q-difference and differential equations.

    Parameters
    ==========

    coeffs : sequence
        Coefficients of the operator; ``coeffs[i]`` multiplies
        ``y(n + i)``. Each entry is converted with ``Poly(coeff, n)``.
    f : expression
        Right-hand side; it is sympified first.
    n : Symbol
        The recurrence variable.
    hints
        ``symbols`` -- when true, a pair ``(solution, constants)`` is
        returned instead of just the solution, where ``constants`` is
        the list of arbitrary constants ``C0, C1, ...`` left in it.

    Returns
    =======

    The polynomial solution (or the pair described above), ``S.Zero``
    when the homogeneous equation has only the trivial polynomial
    solution, or ``None`` when ``f`` is not a polynomial in ``n`` or no
    polynomial solution was found.

    Examples
    ========

    Let's say that we would like to compute `m`-th Bernoulli polynomial
    up to a constant. For this we can use `b(n+1) - b(n) = m n^{m-1}`
    recurrence, which has solution `b(n) = B_m + C`. For example:

    >>> from sympy import Symbol, rsolve_poly
    >>> n = Symbol('n', integer=True)

    >>> rsolve_poly([-1, 1], 4*n**3, n)
    C0 + n**4 - 2*n**3 + n**2

    References
    ==========

    .. [1] S. A. Abramov, M. Bronstein and M. Petkovsek, On polynomial
           solutions of linear operator equations, in: T. Levelt, ed.,
           Proc. ISSAC '95, ACM Press, New York, 1995, 290-296.

    .. [2] M. Petkovsek, Hypergeometric solutions of linear recurrences
           with polynomial coefficients, J. Symbolic Computation,
           14 (1992), 243-264.

    .. [3] M. Petkovsek, H. S. Wilf, D. Zeilberger, A = B, 1996.

    """
    f = sympify(f)

    if not f.is_polynomial(n):
        return None

    homogeneous = f.is_zero

    # order of the recurrence
    r = len(coeffs) - 1

    coeffs = [Poly(coeff, n) for coeff in coeffs]

    # polys[i] = sum_{j >= i} coeffs[j]*binomial(j, i); terms[i] holds the
    # (leading coefficient, degree) of polys[i], with degree -oo for zero.
    polys = [Poly(0, n)] * (r + 1)
    terms = [(S.Zero, S.NegativeInfinity)] * (r + 1)

    for i in xrange(0, r + 1):
        for j in xrange(i, r + 1):
            polys[i] += coeffs[j] * binomial(j, i)

        if not polys[i].is_zero:
            (exp, ), coeff = polys[i].LT()
            terms[i] = (coeff, exp)

    # d: largest degree among polys; b: largest value of (degree - index)
    d = b = terms[0][1]

    for i in xrange(1, r + 1):
        if terms[i][1] > d:
            d = terms[i][1]

        if terms[i][1] - i > b:
            b = terms[i][1] - i

    d, b = int(d), int(b)

    x = Dummy('x')

    # polynomial whose nonnegative integer roots bound the solution degree
    degree_poly = S.Zero

    for i in xrange(0, r + 1):
        if terms[i][1] - i == b:
            degree_poly += terms[i][0] * FallingFactorial(x, i)

    nni_roots = list(
        roots(degree_poly, x, filter='Z',
              predicate=lambda root: root >= 0).keys())

    if nni_roots:
        N = [max(nni_roots)]
    else:
        N = []

    if homogeneous:
        N += [-b - 1]
    else:
        N += [f.as_poly(n).degree() - b, -b - 1]

    # degree bound of the general polynomial solution
    N = int(max(N))

    if N < 0:
        if homogeneous:
            if hints.get('symbols', False):
                return (S.Zero, [])
            else:
                return S.Zero
        else:
            return None

    if N <= r:
        # small degree bound: method of undetermined coefficients
        C = []
        y = E = S.Zero

        for i in xrange(0, N + 1):
            C.append(Symbol('C' + str(i)))
            y += C[i] * n**i

        for i in xrange(0, r + 1):
            E += coeffs[i].as_expr() * y.subs(n, n + i)

        solutions = solve_undetermined_coeffs(E - f, C, n)

        if solutions is not None:
            # constants that were not determined stay arbitrary
            C = [c for c in C if (c not in solutions)]
            result = y.subs(solutions)
        else:
            return None  # TBD
    else:
        # large degree bound: transformed system with only `r` unknowns
        A = r
        U = N + A + b + 1

        nni_roots = list(
            roots(polys[r], filter='Z',
                  predicate=lambda root: root >= 0).keys())

        if nni_roots:
            a = max(nni_roots) + 1
        else:
            a = S.Zero

        def _zero_vector(k):
            return [S.Zero] * k

        def _one_vector(k):
            return [S.One] * k

        def _delta(p, k):
            # alternating binomial combination of p at n = a + k, ..., a
            B = S.One
            D = p.subs(n, a + k)

            for i in xrange(1, k + 1):
                B *= -Rational(k - i + 1, i)
                D += B * p.subs(n, a + k - i)

            return D

        alpha = {}

        for i in xrange(-A, d + 1):
            I = _one_vector(d + 1)

            for k in xrange(1, d + 1):
                I[k] = I[k - 1] * (x + i - k + 1) / k

            alpha[i] = S.Zero

            for j in xrange(0, A + 1):
                for k in xrange(0, d + 1):
                    B = binomial(k, i + j)
                    D = _delta(polys[j].as_expr(), k)

                    alpha[i] += I[k] * B * D

        # the first A rows form an identity; the rest are filled recursively
        V = Matrix(U, A, lambda i, j: int(i == j))

        if homogeneous:
            for i in xrange(A, U):
                v = _zero_vector(A)

                for k in xrange(1, A + b + 1):
                    if i - k < 0:
                        break

                    B = alpha[k - A].subs(x, i - k)

                    for j in xrange(0, A):
                        v[j] += B * V[i - k, j]

                denom = alpha[-A].subs(x, i)

                for j in xrange(0, A):
                    V[i, j] = -v[j] / denom
        else:
            G = _zero_vector(U)

            for i in xrange(A, U):
                v = _zero_vector(A)
                g = S.Zero

                for k in xrange(1, A + b + 1):
                    if i - k < 0:
                        break

                    B = alpha[k - A].subs(x, i - k)

                    for j in xrange(0, A):
                        v[j] += B * V[i - k, j]

                    g += B * G[i - k]

                denom = alpha[-A].subs(x, i)

                for j in xrange(0, A):
                    V[i, j] = -v[j] / denom

                G[i] = (_delta(f, i - A) - g) / denom

        P, Q = _one_vector(U), _zero_vector(A)

        for i in xrange(1, U):
            P[i] = (P[i - 1] * (n - a - i + 1) / i).expand()

        for i in xrange(0, A):
            Q[i] = Add(*[(v * p).expand() for v, p in zip(V[:, i], P)])

        if not homogeneous:
            h = Add(*[(g * p).expand() for g, p in zip(G, P)])

        C = [Symbol('C' + str(i)) for i in xrange(0, A)]

        g = lambda i: Add(*[c * _delta(q, i) for c, q in zip(C, Q)])

        if homogeneous:
            E = [g(i) for i in xrange(N + 1, U)]
        else:
            E = [g(i) + _delta(h, i) for i in xrange(N + 1, U)]

        if E:
            solutions = solve(E, *C)

            if not solutions:
                if homogeneous:
                    if hints.get('symbols', False):
                        return (S.Zero, [])
                    else:
                        return S.Zero
                else:
                    return None
        else:
            solutions = {}

        if homogeneous:
            result = S.Zero
        else:
            result = h

        # iterate over a copy because solved constants are removed from C
        for c, q in list(zip(C, Q)):
            if c in solutions:
                s = solutions[c] * q
                C.remove(c)
            else:
                s = c * q

            result += s.expand()

    if hints.get('symbols', False):
        return (result, C)
    else:
        return result
Exemplo n.º 4
0
def test_multiplication():
    # Exercises matrix*matrix, the ``@`` operator, element-wise products
    # and scalar multiplication for ArithmeticOnlyMatrix.
    a = ArithmeticOnlyMatrix((
        (1, 2),
        (3, 1),
        (0, 6),
    ))

    b = ArithmeticOnlyMatrix((
        (1, 2),
        (3, 0),
    ))

    raises(ShapeError, lambda: b*a)
    raises(TypeError, lambda: a*{})

    c = a*b
    assert c[0, 0] == 7
    assert c[0, 1] == 2
    assert c[1, 0] == 6
    assert c[1, 1] == 6
    assert c[2, 0] == 18
    assert c[2, 1] == 0

    # eval() only accepts expressions: the previous ``eval('c = a @ b')``
    # raised SyntaxError on every interpreter, so the assertions below were
    # never executed.  Evaluate the expression and assign the result
    # instead; SyntaxError is still tolerated for interpreters that do not
    # know the ``@`` operator.
    try:
        c = eval('a @ b')
    except SyntaxError:
        pass
    else:
        assert c[0, 0] == 7
        assert c[0, 1] == 2
        assert c[1, 0] == 6
        assert c[1, 1] == 6
        assert c[2, 0] == 18
        assert c[2, 1] == 0

    h = a.multiply_elementwise(c)
    assert h == matrix_multiply_elementwise(a, c)
    assert h[0, 0] == 7
    assert h[0, 1] == 4
    assert h[1, 0] == 18
    assert h[1, 1] == 6
    assert h[2, 0] == 0
    assert h[2, 1] == 0
    raises(ShapeError, lambda: a.multiply_elementwise(b))

    c = b * Symbol("x")
    assert isinstance(c, ArithmeticOnlyMatrix)
    assert c[0, 0] == x
    assert c[0, 1] == 2*x
    assert c[1, 0] == 3*x
    assert c[1, 1] == 0

    c2 = x * b
    assert c == c2

    c = 5 * b
    assert isinstance(c, ArithmeticOnlyMatrix)
    assert c[0, 0] == 5
    assert c[0, 1] == 2*5
    assert c[1, 0] == 3*5
    assert c[1, 1] == 0

    # NOTE(review): as above, eval() of an assignment statement always
    # raises SyntaxError, so the ``else`` branch is never reached.  It is
    # left unchanged because whether ``scalar @ matrix`` is supported
    # depends on the SymPy version -- confirm the intended behaviour before
    # turning this into a live check.
    try:
        eval('c = 5 @ b')
    except SyntaxError:
        pass
    else:
        assert isinstance(c, ArithmeticOnlyMatrix)
        assert c[0, 0] == 5
        assert c[0, 1] == 2*5
        assert c[1, 0] == 3*5
        assert c[1, 1] == 0
Exemplo n.º 5
0
from sympy.core.symbol import Symbol
from sympy.codegen.ast import Type
from sympy.codegen.cxxnodes import using
from sympy.printing.cxxcode import cxxcode

# Module-level symbol; not referenced by the test visible in this snippet.
x = Symbol('x')


def test_using():
    # cxxcode output for a ``using`` node, without and with an alias.
    vector_type = Type('std::vector')

    plain = using(vector_type)
    assert cxxcode(plain) == 'using std::vector'

    aliased = using(vector_type, 'vec')
    assert cxxcode(aliased) == 'using vec = std::vector'
Exemplo n.º 6
0
def test_exp_fdiff():
    # fdiff with argument index 2 must raise ArgumentIndexError for exp.
    arg = Symbol('x')

    def differentiate_second_argument():
        return exp(arg).fdiff(2)

    raises(ArgumentIndexError, differentiate_second_argument)
Exemplo n.º 7
0
def test_issue_9116():
    # log of a positive integer symbol is reported as nonnegative.
    positive_integer = Symbol('n', positive=True, integer=True)
    nonnegative = log(positive_integer).is_nonnegative
    assert nonnegative is True
def test_core_power():
    # Run check() on the Pow class and on a Pow instance.
    base = Symbol("x")
    subjects = [Pow, Pow(base, 4)]
    for subject in subjects:
        check(subject)
def test_core_function():
    # Run check() on the function-related classes and a Derivative instance.
    var = Symbol("x")
    subjects = [
        Derivative,
        Derivative(var),
        Function,
        FunctionClass,
        Lambda,
        WildFunction,
    ]
    for subject in subjects:
        check(subject)
def test_core_add():
    # Run check() on the Add class and on an Add instance.
    term = Symbol("x")
    subjects = [Add, Add(term, 4)]
    for subject in subjects:
        check(subject)
def test_core_mul():
    # Run check() on the Mul class and on a Mul instance.
    factor_sym = Symbol("x")
    subjects = [Mul, Mul(factor_sym, 4)]
    for subject in subjects:
        check(subject)
Exemplo n.º 12
0
def test_rubi_printer():
    # 14819: a logical negation is printed as ``Not(...)``
    operand = Symbol('a')
    printed = rubi_printer(Not(operand))
    assert printed == 'Not(a)'
Exemplo n.º 13
0
def test_minmax_assumptions():
    # For every ordered pair of the real symbols below, Max and Min must
    # inherit a property shared by both arguments and stay undecided
    # (None) otherwise.
    r = Symbol('r', real=True)
    a = Symbol('a', real=True, algebraic=True)
    t = Symbol('t', real=True, transcendental=True)
    q = Symbol('q', rational=True)
    p = Symbol('p', irrational=True)
    n = Symbol('n', rational=True, integer=False)
    i = Symbol('i', integer=True)
    o = Symbol('o', odd=True)
    e = Symbol('e', even=True)
    k = Symbol('k', prime=True)
    reals = [r, a, t, q, p, n, i, o, e, k]

    # (property, complementary property): if both arguments have the first
    # the result has it; if both have the second the result has the second.
    complementary = (
        ('is_algebraic', 'is_transcendental'),  # Algebraic?
        ('is_rational', 'is_irrational'),       # Rational?
        ('is_integer', 'is_noninteger'),        # Integer?
    )
    # properties checked against their own True / False / None value
    three_valued = ('is_odd', 'is_even', 'is_prime')

    for ext in (Max, Min):
        for lhs, rhs in it.product(reals, repeat=2):

            # Must be real
            assert ext(lhs, rhs).is_real

            for prop, opposite in complementary:
                if getattr(lhs, prop) and getattr(rhs, prop):
                    assert getattr(ext(lhs, rhs), prop)
                elif getattr(lhs, opposite) and getattr(rhs, opposite):
                    assert getattr(ext(lhs, rhs), opposite)
                else:
                    assert getattr(ext(lhs, rhs), prop) is None

            for prop in three_valued:
                if getattr(lhs, prop) and getattr(rhs, prop):
                    assert getattr(ext(lhs, rhs), prop)
                elif (getattr(lhs, prop) is False
                        and getattr(rhs, prop) is False):
                    assert getattr(ext(lhs, rhs), prop) is False
                else:
                    assert getattr(ext(lhs, rhs), prop) is None
Exemplo n.º 14
0
def test_Min():
    # Exercises Min: pairwise evaluation against signed symbols and
    # infinities, argument lists, validation, derivatives and assumptions.
    from sympy.abc import x, y, z
    neg = Symbol('n', negative=True)
    neg2 = Symbol('n_', negative=True)
    nonneg = Symbol('nn', nonnegative=True)
    nonneg2 = Symbol('nn_', nonnegative=True)
    pos = Symbol('p', positive=True)
    pos2 = Symbol('p_', positive=True)
    nonpos = Symbol('np', nonpositive=True)
    nonpos2 = Symbol('np_', nonpositive=True)
    real = Symbol('r', real=True)

    assert Min(5, 4) == 4

    # -oo is returned whatever it is paired with, in either order
    assert Min(-oo, -oo) is -oo
    for other in (neg, nonpos, 0, nonneg, pos, oo):
        assert Min(-oo, other) is -oo
        assert Min(other, -oo) is -oo

    # negative symbol
    assert Min(neg, neg) == neg
    assert unchanged(Min, neg, nonpos)
    assert Min(nonpos, neg) == Min(neg, nonpos)
    for other in (0, nonneg, pos, oo):
        assert Min(neg, other) == neg
        assert Min(other, neg) == neg

    # nonpositive symbol
    assert Min(nonpos, nonpos) == nonpos
    for other in (0, nonneg, pos, oo):
        assert Min(nonpos, other) == nonpos
        assert Min(other, nonpos) == nonpos

    # zero
    assert Min(0, 0) == 0
    for other in (nonneg, pos, oo):
        assert Min(0, other) == 0
        assert Min(other, 0) == 0

    # nonnegative and positive symbols, +oo
    assert Min(nonneg, nonneg) == nonneg
    assert unchanged(Min, nonneg, pos)
    assert Min(pos, nonneg) == Min(nonneg, pos)
    assert Min(nonneg, oo) == nonneg
    assert Min(oo, nonneg) == nonneg
    assert Min(pos, pos) == pos
    assert Min(pos, oo) == pos
    assert Min(oo, pos) == pos
    assert Min(oo, oo) is oo

    # two different symbols with the same sign assumption stay a Min
    same_sign_pairs = (
        (neg, neg2),
        (nonneg, nonneg2),
        (nonpos, nonpos2),
        (pos, pos2),
    )
    for first, second in same_sign_pairs:
        assert Min(first, second).func is Min

    # lists
    assert Min() is S.Infinity
    for actual, expected in [
            (Min(x), x),
            (Min(x, y), Min(y, x)),
            (Min(x, y, z), Min(z, y, x)),
            (Min(x, Min(y, z)), Min(z, y, x)),
            (Min(x, Max(y, -oo)), Min(x, y)),
            (Min(pos, oo, neg, pos, pos, pos2), neg),
            (Min(pos2, neg2, pos), neg2),
            (Min(neg, oo, -7, pos, pos, 2), Min(neg, -7)),
            (Min(2, x, pos, neg, oo, neg2, pos, 2, -2, -2),
             Min(-2, x, neg, neg2)),
            (Min(0, x, 1, y), Min(0, x, y)),
            (Min(1000, 100, -100, x, pos, neg), Min(neg, x, -100))]:
        assert actual == expected
    assert unchanged(Min, sin(x), cos(x))
    assert Min(sin(x), cos(x)) == Min(cos(x), sin(x))
    assert Min(cos(x), sin(x)).subs(x, 1) == cos(1)
    assert Min(cos(x), sin(x)).subs(x, S.Half) == sin(S.Half)

    # arguments for which Min must raise ValueError
    rejected = (
        lambda: Min(cos(x), sin(x)).subs(x, I),
        lambda: Min(I),
        lambda: Min(I, x),
        lambda: Min(S.ComplexInfinity, x),
    )
    for build in rejected:
        raises(ValueError, build)

    # derivatives are expressed through Heaviside
    step = Heaviside(1 - x)
    assert Min(1, x).diff(x) == step
    assert Min(x, 1).diff(x) == step
    assert Min(0, -x, 1 - 2*x).diff(x) == (
        -Heaviside(x + Min(0, -2*x + 1))
        - 2*Heaviside(2*x + Min(0, -x) - 1))

    # issue 7619
    f = Function('f')
    assert Min(1, 2 * Min(f(1), 2))  # doesn't fail

    # issue 7233
    unevaluated = Min(0, x)
    assert unevaluated.n().args == (0, x)

    # issue 8643
    # rows: (expression, is_positive, is_nonnegative, is_negative)
    assumption_table = [
        (Min(neg, pos2, neg2, real), False, False, True),
        (Min(pos, pos2), True, True, False),
        (Min(pos, nonneg2, pos2), None, True, False),
        (Min(nonneg, pos, real), None, None, None),
    ]
    for m, positive, nonnegative, negative in assumption_table:
        assert m.is_positive is positive
        assert m.is_nonnegative is nonnegative
        assert m.is_negative is negative
Exemplo n.º 15
0
def kernS(s):
    """Use a hack to try to keep autosimplification from joining Integer or
    minus sign into an Add of a Mul; this modification doesn't
    prevent the 2-arg Mul from becoming an Add, however.

    The string ``s`` is rewritten so that a placeholder symbol (the "kern")
    is multiplied into ``digits*(...)`` and ``-(...)`` patterns, the result
    is passed to sympify, and the placeholder is then replaced by 1.

    Raises SympifyError if the numbers of "(" and ")" in ``s`` differ.

    Examples
    ========

    >>> from sympy.core.sympify import kernS
    >>> from sympy.abc import x, y, z

    The 2-arg Mul allows a leading Integer to be distributed but kernS will
    prevent that:

    >>> 2*(x + y)
    2*x + 2*y
    >>> kernS('2*(x + y)')
    2*(x + y)

    If use of the hack fails, the un-hacked string will be passed to sympify...
    and you get what you get.

    XXX This hack should not be necessary once issue 1497 has been resolved.
    """
    import re
    from sympy.core.symbol import Symbol

    hit = False
    if '(' in s:
        if s.count('(') != s.count(")"):
            raise SympifyError('unmatched left parenthesis')

        # pick a placeholder name that does not occur in the input
        kern = '_kern'
        while kern in s:
            kern += "_"
        olds = s
        # digits*( -> digits*kern*(
        s = re.sub(r'(\d+)( *\* *)\(', r'\1*%s\2(' % kern, s)
        # negated parenthetical
        kern2 = kern + "2"
        while kern2 in s:
            kern2 += "_"
        # step 1:  -(...)  -->  kern-kern*(...)
        target = r'%s-%s*(' % (kern, kern)
        s = re.sub(r'- *\(', target, s)
        # step 2: double the matching closing parenthesis
        # kern-kern*(...)  -->  kern-kern*(...)kern2
        i = nest = 0
        while True:
            j = s.find(target, i)
            if j == -1:
                break
            # Start counting at the opening parenthesis that ends this
            # target.  Searching for the first "(" of the whole string
            # instead would pair the marker with an unrelated parenthesis
            # whenever one precedes the negated group.
            j += len(target) - 1
            for j in range(j, len(s)):
                if s[j] == "(":
                    nest += 1
                elif s[j] == ")":
                    nest -= 1
                if nest == 0:
                    break
            s = s[:j] + kern2 + s[j:]
            i = j
        # step 3: put in the parentheses
        # kern-kern*(...)kern2  -->  (-kern*(...))
        s = s.replace(target, target.replace(kern, "(", 1))
        s = s.replace(kern2, ')')
        hit = kern in s

    for i in range(2):
        try:
            expr = sympify(s)
            break
        # the kern might cause unknown errors, so catch broadly -- but let
        # KeyboardInterrupt/SystemExit propagate
        except Exception:
            if hit:
                s = olds  # maybe it didn't like the kern; use un-kerned s
                hit = False
                continue
            expr = sympify(s)  # let original error raise

    if not hit:
        return expr

    rep = {Symbol(kern): 1}

    def _clear(expr):
        # replace the placeholder by 1, descending into plain containers
        if isinstance(expr, (list, tuple, set)):
            return type(expr)([_clear(e) for e in expr])
        if hasattr(expr, 'subs'):
            return expr.subs(rep, hack2=True)
        return expr

    expr = _clear(expr)
    # hope that kern is not there anymore
    return expr
def test_integrals():
    # Run check() on the Integral class and on an Integral instance.
    integrand = Symbol("x")
    subjects = [Integral, Integral(integrand)]
    for subject in subjects:
        check(subject)
Exemplo n.º 17
0
def roots_quintic(f):
    """
    Calculate exact roots of a solvable quintic

    ``f`` must provide ``all_coeffs()`` returning six coefficients.  An
    empty list is returned whenever the exact roots cannot be produced
    here: the x**4 coefficient is nonzero, the normalized coefficients are
    not all rational, ``f`` is reducible, the resolvent ``f20`` has no
    linear factor, no consistent choice of radicals is found, or the
    computed roots are not numerically distinct.  Otherwise a list of five
    root expressions is returned.
    """
    result = []
    coeff_5, coeff_4, p, q, r, s = f.all_coeffs()

    # Eqn must be of the form x^5 + px^3 + qx^2 + rx + s
    if coeff_4:
        return result

    if coeff_5 != 1:
        normalized = [p / coeff_5, q / coeff_5, r / coeff_5, s / coeff_5]
        if not all(coeff.is_Rational for coeff in normalized):
            return result
        f = Poly(f / coeff_5)
    quintic = PolyQuintic(f)

    # Eqn standardized. Algo for solving starts here
    if not f.is_irreducible:
        return result

    f20 = quintic.f20
    # Check if f20 has linear factors over domain Z
    if f20.is_irreducible:
        return result

    # Now, we know that f is solvable
    for _factor in f20.factor_list()[1]:
        if _factor[0].is_linear:
            theta = _factor[0].root(0)
            break
    else:
        # reducible but without a linear factor: theta is undefined, so
        # give up and let the caller fall back to its usual solver
        return result
    d = discriminant(f)
    delta = sqrt(d)
    # zeta = a fifth root of unity
    zeta1, zeta2, zeta3, zeta4 = quintic.zeta
    T = quintic.T(theta, d)
    tol = S(1e-10)
    alpha = T[1] + T[2] * delta
    alpha_bar = T[1] - T[2] * delta
    beta = T[3] + T[4] * delta
    beta_bar = T[3] - T[4] * delta

    disc = alpha**2 - 4 * beta
    disc_bar = alpha_bar**2 - 4 * beta_bar

    l0 = quintic.l0(theta)

    l1 = _quintic_simplify((-alpha + sqrt(disc)) / S(2))
    l4 = _quintic_simplify((-alpha - sqrt(disc)) / S(2))

    l2 = _quintic_simplify((-alpha_bar + sqrt(disc_bar)) / S(2))
    l3 = _quintic_simplify((-alpha_bar - sqrt(disc_bar)) / S(2))

    order = quintic.order(theta, d)
    test = (order * delta.n()) - ((l1.n() - l4.n()) * (l2.n() - l3.n()))
    # Comparing floats
    if not comp(test, 0, tol):
        l2, l3 = l3, l2

    # Now we have correct order of l's
    R1 = l0 + l1 * zeta1 + l2 * zeta2 + l3 * zeta3 + l4 * zeta4
    R2 = l0 + l3 * zeta1 + l1 * zeta2 + l4 * zeta3 + l2 * zeta4
    R3 = l0 + l2 * zeta1 + l4 * zeta2 + l1 * zeta3 + l3 * zeta4
    R4 = l0 + l4 * zeta1 + l3 * zeta2 + l2 * zeta3 + l1 * zeta4

    # Index 0 is unused so that Res[k] lines up with Rk
    Res = [None, [None] * 5, [None] * 5, [None] * 5, [None] * 5]
    Res_n = [None, [None] * 5, [None] * 5, [None] * 5, [None] * 5]
    sol = Symbol('sol')

    # Simplifying improves performance a lot for exact expressions
    R1 = _quintic_simplify(R1)
    R2 = _quintic_simplify(R2)
    R3 = _quintic_simplify(R3)
    R4 = _quintic_simplify(R4)

    # Solve imported here. Causing problems if imported as 'solve'
    # and hence the changed name
    from sympy.solvers.solvers import solve as _solve
    a, b = symbols('a b', cls=Dummy)
    _sol = _solve(sol**5 - a - I * b, sol)
    for i in range(5):
        _sol[i] = factor(_sol[i])
    R1 = R1.as_real_imag()
    R2 = R2.as_real_imag()
    R3 = R3.as_real_imag()
    R4 = R4.as_real_imag()

    for i, root in enumerate(_sol):
        Res[1][i] = _quintic_simplify(root.subs({a: R1[0], b: R1[1]}))
        Res[2][i] = _quintic_simplify(root.subs({a: R2[0], b: R2[1]}))
        Res[3][i] = _quintic_simplify(root.subs({a: R3[0], b: R3[1]}))
        Res[4][i] = _quintic_simplify(root.subs({a: R4[0], b: R4[1]}))

    for i in range(1, 5):
        for j in range(5):
            Res_n[i][j] = Res[i][j].n()
            Res[i][j] = _quintic_simplify(Res[i][j])
    r1 = Res[1][0]
    r1_n = Res_n[1][0]

    for i in range(5):
        if comp(im(r1_n * Res_n[4][i]), 0, tol):
            r4 = Res[4][i]
            break
    else:
        # no candidate passed the numerical test; fall back
        return []

    # Now we have various Res values. Each will be a list of five
    # values. We have to pick one r value from those five for each Res
    u, v = quintic.uv(theta, d)
    testplus = (u + v * delta * sqrt(5)).n()
    testminus = (u - v * delta * sqrt(5)).n()

    # Evaluated numbers suffixed with _n
    # We will use evaluated numbers for calculation. Much faster.
    r4_n = r4.n()
    r2 = r3 = None

    for i in range(5):
        r2temp_n = Res_n[2][i]
        for j in range(5):
            # Again storing away the exact number and using
            # evaluated numbers in computations
            r3temp_n = Res_n[3][j]
            if (comp((r1_n * r2temp_n**2 + r4_n * r3temp_n**2 - testplus).n(),
                     0, tol) and comp(
                         (r3temp_n * r1_n**2 + r2temp_n * r4_n**2 -
                          testminus).n(), 0, tol)):
                r2 = Res[2][i]
                r3 = Res[3][j]
                break
        # identity test: a match whose value happens to be falsy must
        # still end the search
        if r2 is not None:
            break

    if r2 is None:
        # no (r2, r3) pair passed both numerical tests; fall back
        return []

    # Now, we have r's so we can get roots
    x1 = (r1 + r2 + r3 + r4) / 5
    x2 = (r1 * zeta4 + r2 * zeta3 + r3 * zeta2 + r4 * zeta1) / 5
    x3 = (r1 * zeta3 + r2 * zeta1 + r3 * zeta4 + r4 * zeta2) / 5
    x4 = (r1 * zeta2 + r2 * zeta4 + r3 * zeta1 + r4 * zeta3) / 5
    x5 = (r1 * zeta1 + r2 * zeta2 + r3 * zeta3 + r4 * zeta4) / 5
    result = [x1, x2, x3, x4, x5]

    # Now check if solutions are distinct

    saw = set()
    for candidate in result:
        approx = candidate.n(2)
        if approx in saw:
            # Roots were identical. Abort, return []
            # and fall back to usual solve
            return []
        saw.add(approx)
    return result
def test_pickling_polys_errors():
    # Run check() on every polys exception class that can be created
    # without arguments, and on an instance of each.
    from sympy.polys.polyerrors import (
        ExactQuotientFailed, OperationNotSupported, HeuristicGCDFailed,
        HomomorphismFailed, IsomorphismFailed, ExtraneousFactors,
        EvaluationFailed, RefinementFailed, CoercionFailed, NotInvertible,
        NotReversible, NotAlgebraic, DomainError, PolynomialError,
        UnificationFailed, GeneratorsError, GeneratorsNeeded,
        ComputationFailed, UnivariatePolynomialError,
        MultivariatePolynomialError, PolificationFailed, OptionError,
        FlagError)

    # only referenced by the disabled checks listed below
    x = Symbol('x')

    # Disabled checks, kept as notes:
    #
    # TODO: TypeError: __init__() takes at least 3 arguments (1 given)
    # for c in (ExactQuotientFailed, ExactQuotientFailed(x, 3*x, ZZ)):
    #    check(c)
    #
    # TODO: TypeError: can't pickle instancemethod objects
    # for c in (OperationNotSupported, OperationNotSupported(Poly(x), Poly.gcd)):
    #    check(c)
    #
    # TODO: PicklingError: Can't pickle <function <lambda> at 0x38578c0>: it's not found as __main__.<lambda>
    # for c in (ComputationFailed, ComputationFailed(lambda t: t, 3, None)):
    #    check(c)
    #
    # TODO: TypeError: __init__() takes at least 3 arguments (1 given)
    # for c in (PolificationFailed, PolificationFailed({}, x, x, False)):
    #    check(c)

    checked_errors = (
        HeuristicGCDFailed,
        HomomorphismFailed,
        IsomorphismFailed,
        ExtraneousFactors,
        EvaluationFailed,
        RefinementFailed,
        CoercionFailed,
        NotInvertible,
        NotReversible,
        NotAlgebraic,
        DomainError,
        PolynomialError,
        UnificationFailed,
        GeneratorsError,
        GeneratorsNeeded,
        UnivariatePolynomialError,
        MultivariatePolynomialError,
        OptionError,
        FlagError,
    )
    for error_type in checked_errors:
        instance = error_type()
        check(error_type)
        check(instance)
Exemplo n.º 19
0
def test_log_fdiff():
    """Asking ``log(x)`` for the derivative w.r.t. argument 2 must raise."""
    sym = Symbol('x')

    def differentiate_second_argument():
        # built inside the callable so that any error surfaces within raises()
        return log(sym).fdiff(2)

    raises(ArgumentIndexError, differentiate_second_argument)
def test_series():
    """Run the shared ``check`` helper on Limit/Order classes and instances."""
    e = Symbol("e")
    x = Symbol("x")
    candidates = (Limit, Limit(e, x, 1), Order, Order(e))
    for candidate in candidates:
        check(candidate)
Exemplo n.º 21
0
def test_values():
    """``values()`` of the two sample matrices omits their zero entries."""
    numeric = PropertiesOnlyMatrix(2, 2, [0, 1, 2, 3])
    assert set(numeric.values()) == {1, 2, 3}

    x = Symbol('x', real=True)
    symbolic = PropertiesOnlyMatrix(2, 2, [x, 0, 0, 1])
    assert set(symbolic.values()) == {x, 1}
def test_concrete():
    """Run the shared ``check`` helper on Product/Sum classes and instances."""
    x = Symbol("x")
    limits = (x, 2, 4)
    candidates = (Product, Product(x, limits), Sum, Sum(x, limits))
    for candidate in candidates:
        check(candidate)
Exemplo n.º 23
0
from sympy.functions.elementary.integers import (ceiling, floor)
from sympy.functions.elementary.miscellaneous import (root, sqrt)
from sympy.functions.elementary.trigonometric import (asin, cos, csc, sec, sin, tan)
from sympy.integrals.integrals import Integral
from sympy.series.limits import Limit

from sympy.core.relational import Eq, Ne, Lt, Le, Gt, Ge
from sympy.physics.quantum.state import Bra, Ket
from sympy.abc import x, y, z, a, b, c, t, k, n
# Try to load the optional antlr4 runtime.
# NOTE(review): import_module is defined outside this chunk; presumably it
# returns a falsy value instead of raising when the package is missing.
antlr4 = import_module("antlr4")

# disable tests if antlr4-python*-runtime is not present
# (`disabled` is only bound in that case; it stays undefined otherwise --
# presumably the test runner looks for this module-level flag, confirm)
if not antlr4:
    disabled = True

# Module-level symbol and undefined function available to the rest of the file.
theta = Symbol('theta')
f = Function('f')


# shorthand definitions
def _Add(a, b):
    """Build ``Add(a, b)`` with automatic evaluation switched off."""
    unevaluated_sum = Add(a, b, evaluate=False)
    return unevaluated_sum


def _Mul(a, b):
    """Build ``Mul(a, b)`` with automatic evaluation switched off."""
    unevaluated_product = Mul(a, b, evaluate=False)
    return unevaluated_product


def _Pow(a, b):
    """Build ``Pow(a, b)`` with automatic evaluation switched off."""
    unevaluated_power = Pow(a, b, evaluate=False)
    return unevaluated_power
Exemplo n.º 24
0
from sympy.core.symbol import (Symbol, symbols)
from sympy.functions.elementary.miscellaneous import sqrt
from sympy.functions.elementary.trigonometric import (acos, cos, sin)
from sympy.sets import EmptySet
from sympy.simplify.simplify import simplify
from sympy.functions.elementary.trigonometric import tan
from sympy.geometry import (Circle, GeometryError, Line, Point, Ray,
    Segment, Triangle, intersection, Point3D, Line3D, Ray3D, Segment3D,
    Point2D, Line2D)
from sympy.geometry.line import Undecidable
from sympy.geometry.polygon import _asa as asa
from sympy.utilities.iterables import cartes
from sympy.testing.pytest import raises, warns, warns_deprecated_sympy


# Module-level symbols, all declared real.
x, y, z, k = symbols('x,y,z,k', real=True)
x1, y1 = symbols('x1,y1', real=True)
t = symbols('t', real=True)
a, b = symbols('a,b', real=True)
m = symbols('m', real=True)


def test_object_from_equation():
    """A Line built from a linear expression equals the Line2D through two of its points."""
    # the sympy.abc symbols shadow the module-level ones inside this test
    from sympy.abc import x, y, a, b

    from_expr = Line(3*x + y + 18)
    through_points = Line2D(Point2D(0, -18), Point2D(1, -21))
    assert from_expr == through_points

    from_expr = Line(3*x + 5 * y + 1)
    through_points = Line2D(
        Point2D(0, Rational(-1, 5)), Point2D(1, Rational(-4, 5)))
    assert from_expr == through_points
Exemplo n.º 25
0
def kernS(s):
    """Use a hack to try keep autosimplification from distributing
    a number into an Add; this modification doesn't
    prevent the 2-arg Mul from becoming an Add, however.

    The string is rewritten so that a placeholder symbol (the "kern") turns
    2-arg Muls into 3-arg ones; after sympification the kern is substituted
    by 1.  Strings that contain no ``(``, that contain a quote character, or
    in which no kern had to be inserted are passed to sympify unchanged
    (apart from whitespace stripping in the latter case).

    Parameters
    ==========

    s : str
        The expression to convert.

    Returns
    =======

    The sympified expression, with the kern (if one was used) removed.

    Raises
    ======

    SympifyError
        If ``s`` contains ``(`` (and no quotes) and the numbers of ``(`` and
        ``)`` differ.

    Examples
    ========

    >>> from sympy.core.sympify import kernS
    >>> from sympy.abc import x, y, z

    The 2-arg Mul distributes a number (or minus sign) across the terms
    of an expression, but kernS will prevent that:

    >>> 2*(x + y), -(x + 1)
    (2*x + 2*y, -x - 1)
    >>> kernS('2*(x + y)')
    2*(x + y)
    >>> kernS('-(x + 1)')
    -(x + 1)

    If use of the hack fails, the un-hacked string will be passed to sympify...
    and you get what you get.

    XXX This hack should not be necessary once issue 4596 has been resolved.
    """
    import string
    from random import choice
    from sympy.core.symbol import Symbol
    hit = False
    quoted = '"' in s or "'" in s
    if '(' in s and not quoted:
        if s.count('(') != s.count(")"):
            raise SympifyError('unmatched left parenthesis')

        # strip all space from s
        s = ''.join(s.split())
        olds = s
        # now use space to represent a symbol that will later be
        # replaced by the kern
        # step 1. turn potential 2-arg Muls into 3-arg versions
        # 1a. *( -> * *(
        s = s.replace('*(', '* *(')
        # 1b. close up exponentials
        s = s.replace('** *', '**')
        # 2. handle the implied multiplication of a negated
        # parenthesized expression in two steps
        # 2a:  -(...)  -->  -( *(...)
        target = '-( *('
        s = s.replace('-(', target)
        # 2b: double the matching closing parenthesis
        # -( *(...)  -->  -( *(...))
        i = nest = 0
        assert target.endswith('(')  # assumption below
        while True:
            j = s.find(target, i)
            if j == -1:
                break
            # start scanning at the "(" that ends the target
            j += len(target) - 1
            for j in range(j, len(s)):
                if s[j] == "(":
                    nest += 1
                elif s[j] == ")":
                    nest -= 1
                if nest == 0:
                    break
            s = s[:j] + ")" + s[j:]
            i = j + 2  # the first char after 2nd )
        if ' ' in s:
            # get a unique kern
            kern = '_'
            while kern in s:
                kern += choice(string.ascii_letters + string.digits)
            s = s.replace(' ', kern)
            # `kern` is only bound on this path, so the check must live
            # here: an input such as "(2*x)/(x-1)" never gets a space
            # inserted and would otherwise hit an unbound local below.
            hit = kern in s

    for i in range(2):
        try:
            expr = sympify(s)
            break
        # XXX: What exception can be caught here? Broad except should not be
        # used without a clear reason. Running the test suite does not lead to
        # any errors at this point...
        except TypeError:  # the kern might cause unknown errors...
            if hit:
                s = olds  # maybe it didn't like the kern; use un-kerned s
                hit = False
                continue
            expr = sympify(s)  # let original error raise

    if not hit:
        return expr

    # the kern stood in for an extra factor; replacing it by 1 removes it
    rep = {Symbol(kern): 1}

    def _clear(expr):
        # recurse into plain containers, substitute in anything with subs()
        if isinstance(expr, (list, tuple, set)):
            return type(expr)([_clear(e) for e in expr])
        if hasattr(expr, 'subs'):
            return expr.subs(rep, hack2=True)
        return expr

    expr = _clear(expr)
    # hope that kern is not there anymore
    return expr
Exemplo n.º 26
0
def test_parameter_value():
    """parameter_value maps the line's second point to t=1 and rejects (0, 0)."""
    t = Symbol('t')
    start, end = Point(0, 1), Point(5, 6)
    line = Line(start, end)
    assert line.parameter_value((5, 6), t) == {t: 1}

    def lookup_origin():
        # (0, 0) is expected to be rejected with ValueError
        return line.parameter_value((0, 0), t)

    raises(ValueError, lookup_origin)
Exemplo n.º 27
0
    NUC1 = 3
    NUC2 = 4
    ADD = 5
    DSB = 6
    ATR = 7
    ATM = 8
    p53 = 9
    TLS = 10
    FAHRR = 11
    HRR2 = 12
    NHEJ = 13
    CHKREC = 14


# Force enum constants' names to namespace as logical symbols
# One Symbol is created per member name of P and then injected into the
# module globals, so the bare member names used below resolve to Symbols.
# NOTE(review): P is defined outside this chunk; presumably it is the enum
# whose members are listed just above -- confirm.
syms = {m: Symbol(m) for m in P.__members__}
globals().update(syms)
# Second name bound to the same object as FANCJBRCA1 (which is presumably
# one of the names injected by the globals() update above).
FANCJMLH1 = FANCJBRCA1

# Et voila,
rules_original = {
    ICL: ICL & ~DSB,
    FANCM: ICL & ~CHKREC,
    FAcore: FANCM & (ATR | ATM) & ~CHKREC,
    FANCD2I: FAcore & ((ATM | ATR) | (H2AX & DSB)) & ~USP1,
    MUS81: ICL,
    FANCJBRCA1: (ICL | ssDNARPA) & (ATM | ATR),
    XPF: MUS81,
    FAN1: MUS81 & FANCD2I,
    ADD: (ADD | (MUS81 & (FAN1 | XPF))) & ~PCNATLS,
    DSB: (DSB | FAN1 | XPF) & ~(NHEJ | HRR),
Exemplo n.º 28
0
def test_Min():
    """Check Min on pairs drawn from -oo, 0, oo and sign-constrained symbols."""
    n, n_ = Symbol('n', negative=True), Symbol('n_', negative=True)
    nn, nn_ = Symbol('nn', nonnegative=True), Symbol('nn_', nonnegative=True)
    p, p_ = Symbol('p', positive=True), Symbol('p_', positive=True)
    np, np_ = Symbol('np', nonpositive=True), Symbol('np_', nonpositive=True)

    assert Min(5, 4) == 4

    # (least, partners): Min of ``least`` with any partner, in either
    # argument order, must evaluate to ``least``.
    expected_least = (
        (-oo, (-oo, n, np, 0, nn, p, oo)),
        (n, (n, 0, nn, p, oo)),
        (np, (np, 0, nn, p, oo)),
        (0, (0, nn, p, oo)),
        (nn, (nn, oo)),
        (p, (p, oo)),
        (oo, (oo,)),
    )
    for least, partners in expected_least:
        for partner in partners:
            assert Min(least, partner) == least
            assert Min(partner, least) == least

    # These pairs are only compared against an identically built Min.
    for first, second in ((n, np), (np, n), (nn, p), (p, nn)):
        assert Min(first, second) == Min(first, second)

    # Two distinct symbols with the same assumption stay a Min object.
    for first, second in ((n, n_), (nn, nn_), (np, np_), (p, p_)):
        assert Min(first, second).func is Min
Exemplo n.º 29
0
def symarray(prefix, shape, **kwargs):  # pragma: no cover
    r"""Build a numpy object array whose entries are Symbols.

    Each symbol is named ``prefix_i1_i2_``... after its index in the array.
    Provide a non-empty prefix if symbols of different output arrays must be
    distinct, as SymPy symbols with identical names are the same object.

    Parameters
    ----------

    prefix : string
      A prefix prepended to the name of every symbol.

    shape : int or tuple
      Shape of the created array.  If an int, the array is one-dimensional; for
      more than one dimension the shape must be a tuple.

    \*\*kwargs : dict
      keyword arguments passed on to Symbol

    Examples
    ========
    These doctests require numpy.

    >>> from sympy import symarray
    >>> symarray('', 3)
    [_0 _1 _2]

    If you want multiple symarrays to contain distinct symbols, you *must*
    provide unique prefixes:

    >>> a = symarray('', 3)
    >>> b = symarray('', 3)
    >>> a[0] == b[0]
    True
    >>> a = symarray('a', 3)
    >>> b = symarray('b', 3)
    >>> a[0] == b[0]
    False

    Creating symarrays with a prefix:

    >>> symarray('a', 3)
    [a_0 a_1 a_2]

    For more than one dimension, the shape must be given as a tuple:

    >>> symarray('a', (2, 3))
    [[a_0_0 a_0_1 a_0_2]
     [a_1_0 a_1_1 a_1_2]]
    >>> symarray('a', (2, 3, 2))
    [[[a_0_0_0 a_0_0_1]
      [a_0_1_0 a_0_1_1]
      [a_0_2_0 a_0_2_1]]
    <BLANKLINE>
     [[a_1_0_0 a_1_0_1]
      [a_1_1_0 a_1_1_1]
      [a_1_2_0 a_1_2_1]]]

    For setting assumptions of the underlying Symbols:

    >>> [s.is_real for s in symarray('a', 2, real=True)]
    [True, True]
    """
    # numpy is imported lazily so it is only needed when this is called
    import numpy

    result = numpy.empty(shape, dtype=object)
    for position in numpy.ndindex(shape):
        suffix = '_'.join(str(axis_index) for axis_index in position)
        result[position] = Symbol('%s_%s' % (prefix, suffix), **kwargs)
    return result
Exemplo n.º 30
0
    def __new__(cls, sym, condition, base_set=S.UniversalSet):
        """Create the set of ``sym`` in ``base_set`` satisfying ``condition``.

        The arguments are sympified and the result is simplified where the
        code below can decide it: a false condition gives ``S.EmptySet``, a
        true condition or an empty base set gives ``base_set``, members of a
        ``FiniteSet`` base are tested individually, and a base set that is
        itself an instance of ``cls`` is merged into a single condition.

        Raises ``TypeError`` if ``base_set`` is not a ``Set`` and
        ``ValueError`` if a non-Symbol ``sym`` does not occur in
        ``condition``.  Neither check is reached when ``sym`` is a ``Tuple``.
        """
        # nonlinsolve uses ConditionSet to return an unsolved system
        # of equations (see _return_conditionset in solveset) so until
        # that is changed we do minimal checking of the args
        sym = _sympify(sym)
        base_set = _sympify(base_set)
        condition = _sympify(condition)

        if isinstance(condition, FiniteSet):
            # deprecated form: each element ``lhs`` is read as ``Eq(lhs, 0)``
            # and the equations are combined with And
            condition_orig = condition
            temp = (Eq(lhs, 0) for lhs in condition)
            condition = And(*temp)
            SymPyDeprecationWarning(
                feature="Using {} for condition".format(condition_orig),
                issue=17651,
                deprecated_since_version='1.5',
                useinstead="{} for condition".format(condition)).warn()

        # NOTE(review): as_Boolean is defined elsewhere; presumably it
        # rejects conditions that are not Boolean -- confirm.
        condition = as_Boolean(condition)

        if isinstance(sym, Tuple):  # unsolved eqns syntax
            return Basic.__new__(cls, sym, condition, base_set)

        if not isinstance(base_set, Set):
            raise TypeError('expecting set for base_set')

        # trivial conditions and an empty base set need no ConditionSet
        if condition is S.false:
            return S.EmptySet
        elif condition is S.true:
            return base_set
        if isinstance(base_set, EmptySet):
            return base_set

        # ``know`` holds the members of a finite base set that are already
        # known to satisfy the condition; it is unioned back in at the end
        know = None
        if isinstance(base_set, FiniteSet):
            # group the members by the outcome of the condition:
            # True, False, or None when it cannot be decided
            sifted = sift(base_set,
                          lambda _: fuzzy_bool(condition.subs(sym, _)))
            if sifted[None]:
                # keep only the undecided members as the remaining base set
                know = FiniteSet(*sifted[True])
                base_set = FiniteSet(*sifted[None])
            else:
                # every member was decided, so the answer is explicit
                return FiniteSet(*sifted[True])

        if isinstance(base_set, cls):
            # base set is itself built by this class: take its symbol,
            # condition and base set, and fold the two conditions into one
            s, c, base_set = base_set.args
            if sym == s:
                # same symbol in both: just conjoin the conditions
                condition = And(condition, c)
            elif sym not in c.free_symbols:
                # rewrite the inner condition in terms of the outer symbol
                condition = And(condition, c.xreplace({s: sym}))
            elif s not in condition.free_symbols:
                # rewrite the outer condition in terms of the inner symbol
                condition = And(condition.xreplace({sym: s}), c)
                sym = s
            else:
                # user will have to use cls.sym to get symbol
                # each symbol occurs free in the other condition, so both
                # are rewritten to a third symbol; a Dummy is used when a
                # symbol named 'lambda' already occurs in either condition
                dum = Symbol('lambda')
                if dum in condition.free_symbols or \
                        dum in c.free_symbols:
                    dum = Dummy(str(dum))
                condition = And(condition.xreplace({sym: dum}),
                                c.xreplace({s: dum}))
                sym = dum

        if not isinstance(sym, Symbol):
            # a non-Symbol ``sym`` must appear in the condition: replacing it
            # by a fresh Dummy has to leave that Dummy in the result
            s = Dummy('lambda')
            if s not in condition.xreplace({sym: s}).free_symbols:
                raise ValueError(
                    'non-symbol dummy not recognized in condition')

        rv = Basic.__new__(cls, sym, condition, base_set)
        # add back the finite members that were already known to qualify
        return rv if know is None else Union(know, rv)