예제 #1
0
def test_xreplace():
    # Replacement is simultaneous: swapping b1 and b2 does not chain.
    swapped = Basic(b1, b2).xreplace({b1: b2, b2: b1})
    assert swapped == Basic(b2, b1)

    # Whole subtrees are replaced; untouched nodes are kept.
    assert b21.xreplace({b2: b1}) == Basic(b1, b1)
    assert b21.xreplace({b2: b21}) == Basic(b21, b1)
    assert b3.xreplace({b2: b1}) == b2

    # An Atom is not traversed: only the Atom itself can be replaced.
    atom = Atom(b1)
    assert atom.xreplace({b1: b2}) == Atom(b1)
    assert atom.xreplace({Atom(b1): b2}) == b2

    # A missing mapping or a non-mapping argument is rejected.
    pytest.raises(TypeError, b1.xreplace)
    pytest.raises(TypeError, lambda: b1.xreplace([b1, b2]))
예제 #2
0
def test_sympyissue_6079():
    # x + 2.0 compares equal to x + 2, so structural identity has to be
    # checked with _aresame instead of ==.
    identical_pairs = [
        ((x + 2.0).subs({2: 3}), x + 2.0),
        ((x + 2.0).subs({2.0: 3}), x + 3),
        (cos, cos),
    ]
    for left, right in identical_pairs:
        assert _aresame(left, right)

    distinct_pairs = [
        (x + 2, x + 2.0),
        (Basic(cos, 1), Basic(cos, 1.)),
        (1, Integer(1)),
        (x, symbols('x', positive=True)),
    ]
    for left, right in distinct_pairs:
        assert not _aresame(left, right)
예제 #3
0
def test_matches_basic():
    candidates = (Basic(b1, b1, b2), Basic(b1, b2, b1), Basic(b2, b1, b1),
                  Basic(b1, b2), Basic(b2, b1), b2, b1)
    # None of these patterns contain wildcards: an expression matches only
    # its identical pattern (empty substitution), otherwise match is None.
    for pattern_pos, pattern in enumerate(candidates):
        for expr_pos, expr in enumerate(candidates):
            outcome = expr.match(pattern)
            if pattern_pos != expr_pos:
                assert outcome is None
            else:
                assert outcome == {}
    assert b1.match(b1) == {}
예제 #4
0
def test_equality():
    instances = [b1, b2, b3, b21, Basic(b1, b1, b1), Basic]
    # Each instance equals only itself; both == and != are checked.
    for i, left in enumerate(instances):
        for j, right in enumerate(instances):
            same_slot = i == j
            assert (left == right) == same_slot
            assert (left != right) == (not same_slot)

    # Comparing against non-Basic objects yields "not equal".
    empty = Basic()
    for other in ([], 0):
        assert empty != other
        assert not empty == other
예제 #5
0
def test_equality():
    # pylint: disable=unneeded-not
    instances = [b1, b2, b3, b21, Basic(b1, b1, b1), Basic]
    # Each instance equals only itself; both == and != are checked.
    for i, left in enumerate(instances):
        for j, right in enumerate(instances):
            same_slot = i == j
            assert (left == right) == same_slot
            assert (left != right) == (not same_slot)

    empty = Basic()
    # An argument-less Basic is still truthy ...
    assert empty
    # ... and is not equal to the integer 0.
    assert empty != 0
    assert not empty == 0
예제 #6
0
def test_gcd_terms():
    expr = (2*(x + 1)*(x + 4)/(5*x**2 + 5) +
            (2*x + 2)*(x + 5)/(x**2 + 1)/5 +
            (2*x + 2)*(x + 6)/(5*x**2 + 5))
    terms = Add.make_args(expr)

    # _gcd_terms returns (common factor, remaining sum, denominator).
    gcd_parts = (Rational(6, 5)*((1 + x)/(1 + x**2)), 5 + x, 1)
    assert _gcd_terms(expr) == gcd_parts
    assert _gcd_terms(terms) == gcd_parts

    expected = Rational(6, 5)*((1 + x)*(5 + x)/(1 + x**2))
    assert gcd_terms(expr) == expected
    # Non-Basic sequences of terms are treated as the terms of an Add ...
    for container in (list, tuple, set):
        assert gcd_terms(container(terms)) == expected
    # ... but a Basic sequence is treated as a container.
    assert gcd_terms(Tuple(*terms)) != expected
    assert gcd_terms(Basic(Tuple(1, 3*y + 3*x*y), Tuple(1, 3))) == \
        Basic((1, 3*y*(x + 1)), (1, 3))
    # Dictionary keys are left alone, otherwise some entries may be lost.
    assert gcd_terms(Dict((x*(1 + y), 2), (x + x*y, y + x*y))) == \
        Dict({x*(y + 1): 2, x + x*y: y*(1 + x)})

    assert gcd_terms((2*x + 2)**3 + (2*x + 2)**2) == 4*(x + 1)**2*(2*x + 3)

    for unchanged in (0, 1, x):
        assert gcd_terms(unchanged) == unchanged
    assert gcd_terms(2 + 2*x) == Mul(2, 1 + x, evaluate=False)
    product = x*(2*x + 4*y)
    factored = 2*x*(x + 2*y)
    assert gcd_terms(product) == factored
    assert gcd_terms(sin(product)) == sin(factored)

    # issue sympy/sympy#6139-like
    alpha, alpha1, alpha2, alpha3 = symbols('alpha:4')
    numer = alpha**2 - alpha*x**2 + alpha + x**3 - x*(alpha + 1)
    replacement = {
        alpha: (1 + sqrt(5))/2 + alpha1*x + alpha2*x**2 + alpha3*x**3
    }
    expansion = (numer/(x - alpha)).subs(replacement).series(x, 0, 1)
    assert simplify(collect(expansion, x)) == -sqrt(5)/2 - Rational(3, 2) + O(x)

    # issue sympy/sympy#5917
    assert _gcd_terms([Integer(0), Integer(0)]) == (0, 0, 1)
    assert _gcd_terms([2*x + 4]) == (2, x + 2, 1)

    nested = x/(x + 1/x)
    assert gcd_terms(nested, fraction=False) == nested
예제 #7
0
def test_Singleton():
    global instantiated
    instantiated = 0

    class MyNewSingleton(Basic, metaclass=Singleton):
        def __new__(cls):
            # Count every real construction in the module-level counter.
            global instantiated
            instantiated += 1
            return Basic.__new__(cls)

    # Defining the class does not construct an instance.
    assert instantiated == 0
    MyNewSingleton()  # force instantiation
    assert instantiated == 1
    cached = MyNewSingleton()
    assert cached is not Basic()
    assert cached is MyNewSingleton()
    assert S.MyNewSingleton is MyNewSingleton()
    # Repeated calls and registry lookups reuse the first instance.
    assert instantiated == 1

    class MySingletonSub(MyNewSingleton):
        pass

    # A subclass gets its own singleton, built once on first call.
    assert instantiated == 1
    MySingletonSub()
    assert instantiated == 2
    sub = MySingletonSub()
    assert sub is not MyNewSingleton()
    assert sub is MySingletonSub()
예제 #8
0
    def __new__(cls, mat):
        """Create a trace expression for the matrix ``mat``.

        Raises
        ======

        TypeError
            If ``mat`` is not a matrix.  Objects that lack an ``is_Matrix``
            attribute altogether (e.g. plain Python numbers) are rejected
            the same way.
        ShapeError
            If ``mat`` is not square.
        """
        # getattr with a default: a non-Basic argument has no ``is_Matrix``
        # and would otherwise escape as AttributeError rather than TypeError.
        if not getattr(mat, 'is_Matrix', False):
            raise TypeError("input to Trace, %s, is not a matrix" % str(mat))

        if not mat.is_square:
            raise ShapeError("Trace of a non-square matrix")

        return Basic.__new__(cls, mat)
예제 #9
0
    def __new__(cls, mat):
        """Create a determinant expression for ``mat``.

        The argument is sympified first.  A TypeError is raised when the
        result is not a matrix, and a ShapeError when it is not square.
        """
        matrix = sympify(mat)
        if not matrix.is_Matrix:
            raise TypeError("Input to Determinant, %s, not a matrix" % str(matrix))
        if not matrix.is_square:
            raise ShapeError("Det of a non-square matrix")
        return Basic.__new__(cls, matrix)
예제 #10
0
def test_styleof():
    basic_style = {'color': 'blue', 'shape': 'ellipse'}
    expr_style = {'color': 'black'}
    styles = [(Basic, basic_style), (Expr, expr_style)]

    # A plain Basic picks up only the Basic entry.
    assert styleof(Basic(1), styles) == {'color': 'blue', 'shape': 'ellipse'}

    # An Expr matches both entries: 'color' comes from the Expr entry and
    # 'shape' from the Basic entry.
    assert styleof(x + 1, styles) == {'color': 'black', 'shape': 'ellipse'}
예제 #11
0
 def _new(cls, *args, **kwargs):
     # Let the class helper turn the raw arguments into (shape, entries).
     shape, flat_list = cls._handle_ndarray_creation_inputs(*args, **kwargs)
     shape = Tuple(*[_sympify(dim) for dim in shape])
     entries = Tuple(*flatten(flat_list))
     self = Basic.__new__(cls, entries, shape, **kwargs)

     # Cache derived data on the new instance.
     self._shape = shape
     self._array = tuple(entries)
     self._rank = len(shape)
     # Product of the dimensions; an empty shape gives 0.
     # NOTE(review): an empty product is conventionally 1 (one element for
     # a rank-0 array) -- confirm that 0 is intended here.
     if shape:
         self._loop_size = functools.reduce(lambda a, b: a * b, shape)
     else:
         self._loop_size = 0
     return self
예제 #12
0
 def __new__(cls, parent, rowslice, colslice):
     """Create a slice of ``parent`` selected by ``rowslice``/``colslice``.

     Both slices are normalized against the parent's dimensions.  An
     IndexError is raised when either does not normalize to three parts
     (presumably ``(start, stop, step)`` -- confirm against ``normalize``), or
     when a bound is provably outside the parent.  Slicing an existing
     MatrixSlice is delegated to ``mat_slice_of_slice``.
     """
     rowslice = normalize(rowslice, parent.shape[0])
     colslice = normalize(colslice, parent.shape[1])
     if not (len(rowslice) == len(colslice) == 3):
         raise IndexError()

     def definitely(cond):
         # A comparison may evaluate to S.true, to a plain Python bool (when
         # both operands are Python ints, e.g. if ``normalize`` hands back
         # ints), or to an undecided relational.  Only a definite truth
         # counts; checking ``is S.true`` alone misses the Python-bool case.
         return cond is True or cond is S.true

     if (definitely(0 > rowslice[0])
             or definitely(parent.shape[0] < rowslice[1])
             or definitely(0 > colslice[0])
             or definitely(parent.shape[1] < colslice[1])):
         raise IndexError()
     if isinstance(parent, MatrixSlice):
         return mat_slice_of_slice(parent, rowslice, colslice)
     return Basic.__new__(cls, parent, Tuple(*rowslice), Tuple(*colslice))
예제 #13
0
    def __new__(cls, domain, condition):
        """Restrict ``domain`` by ``condition``.

        If ``condition`` is the Python ``True`` object, ``domain`` is returned
        unchanged.  Otherwise the condition is passed through ``rv_subs``; a
        ValueError is raised if the result mentions symbols that are not free
        in ``domain``.
        """
        if condition is True:
            return domain
        cond = rv_subs(condition)
        # A condition like ``die1 == z``, with ``z`` unknown to the domain,
        # could never be decided by iterating over the domain.
        unknown = cond.free_symbols - domain.free_symbols
        if unknown:
            raise ValueError(
                'Condition "%s" contains foreign symbols \n%s.\n' %
                (condition, tuple(unknown)) +
                "Will be unable to iterate using this condition")
        return Basic.__new__(cls, domain, cond)
예제 #14
0
def test_Idx_construction():
    i, a = symbols('i a', integer=True)
    assert Idx(i) != Idx(i, 1)
    # A single bound n is shorthand for the range (0, n - 1); oo stays oo.
    assert Idx(i, a) == Idx(i, (0, a - 1))
    assert Idx(i, oo) == Idx(i, (0, oo))

    # Invalid labels, bounds and ranges, checked in this order.
    invalid = [
        (TypeError, (x,)),
        (TypeError, (0.5,)),
        (TypeError, (i, x)),
        (TypeError, (i, 0.5)),
        (TypeError, (i, (x, 5))),
        (TypeError, (i, (2, x))),
        (TypeError, (i, (2, 3.5))),
        (ValueError, (i, (1, 2, 3))),
        (TypeError, (i, Basic())),
    ]
    for expected_exc, call_args in invalid:
        pytest.raises(expected_exc, Idx, *call_args)
예제 #15
0
def test_sympyissue_6100():
    x_pow_one = x**1.0
    # A 1.0 exponent is kept, so x**1.0 is distinct from x (in both
    # comparison directions) and is not the Python True object.
    assert x_pow_one != x
    assert x != x_pow_one
    assert true != x_pow_one
    assert x_pow_one is not True
    assert x is not True
    assert x*y != (x*y)**1.0
    assert x_pow_one**1.0 != x
    assert x_pow_one**2.0 == x**2

    bare = Basic()
    assert Pow(bare, 1.0, evaluate=False) != bare
    # if the following gets distributed as a Mul (x**1.0*y**1.0 then
    # __eq__ methods could be added to Symbol and Pow to detect the
    # power-of-1.0 case.
    assert isinstance((x*y)**1.0, Pow)
예제 #16
0
def test_subs():
    assert b21.subs({b2: b1}) == Basic(b1, b1)
    assert b21.subs({b2: b21}) == Basic(b21, b1)
    assert b3.subs({b2: b1}) == b2

    # Every one of these rule containers turns Basic(b2, b1) into
    # Basic(b2, b2).
    both_b2 = Basic(b2, b2)
    rule_sets = [
        [(b2, b1), (b1, b2)],
        {b1: b2, b2: b1},
        collections.ChainMap({b1: b2}, {b2: b1}),
        collections.OrderedDict([(b2, b1), (b1, b2)]),
    ]
    for rules in rule_sets:
        assert b21.subs(rules) == both_b2

    # Malformed arguments.
    pytest.raises(ValueError, lambda: b21.subs('bad arg'))
    pytest.raises(ValueError, lambda: b21.subs(b1, b2, b3))
예제 #17
0
    def __new__(cls, *domains):
        """Create the product of ``domains``.

        Factors that are themselves product domains are replaced by their
        own ``domains``, and the collected factors are wrapped in a
        ``FiniteSet``.  If every factor ``is_Finite`` the instance is built
        as a ``ProductFiniteDomain``; if every factor ``is_Continuous`` it is
        built as a ``ProductContinuousDomain``.  The continuous check runs
        second, so it wins when both hold, which includes an empty product
        because ``all`` of nothing is True.
        """
        # Flatten any product of products
        factors = []
        for domain in domains:
            if domain.is_ProductDomain:
                factors.extend(domain.domains)
            else:
                factors.append(domain)
        factors = FiniteSet(*factors)

        if all(factor.is_Finite for factor in factors):
            from diofant.stats.frv import ProductFiniteDomain
            cls = ProductFiniteDomain
        if all(factor.is_Continuous for factor in factors):
            from diofant.stats.crv import ProductContinuousDomain
            cls = ProductContinuousDomain

        return Basic.__new__(cls, *factors)
예제 #18
0
def test_preorder_traversal():
    tree = Basic(b21, b3)
    assert list(preorder_traversal(tree)) == \
        [tree, b21, b2, b1, b1, b3, b2, b1]

    # Plain nested tuples are walked as well.
    nested = ('abc', ('d', 'ef'))
    assert list(preorder_traversal(nested)) == \
        [nested, 'abc', ('d', 'ef'), 'd', 'ef']

    # skip() prunes the children of the node that was just yielded.
    walker = preorder_traversal(tree)
    visited = []
    for node in walker:
        visited.append(node)
        if node == b2:
            walker.skip()
    assert visited == [tree, b21, b2, b1, b3, b2]

    # The keys argument orders the children of each node.
    expr = z + w*(x + y)
    assert list(preorder_traversal([expr], keys=default_sort_key)) == \
        [[w*(x + y) + z], w*(x + y) + z, z, w*(x + y), w, x + y, x, y]
    assert list(preorder_traversal((x + y)*z, keys=True)) == \
        [z*(x + y), z, x + y, x, y]
예제 #19
0
    def __new__(cls, *spaces):
        """Create the product of the probability ``spaces``.

        Raises ValueError("Overlapping Random Variables") when the spaces
        share random variables.  If every space ``is_Finite`` the instance is
        built as a ``ProductFinitePSpace``; if every space ``is_Continuous``
        it is built as a ``ProductContinuousPSpace`` (checked second).
        """
        # Map each value to the space providing it; on a repeated value,
        # the later space overwrites the earlier one.
        owner = {value: space for space in spaces for value in space.values}
        symbols = FiniteSet(*[value.symbol for value in owner])

        # Fewer distinct symbols than the spaces contribute in total means
        # some random variable occurs in more than one space.
        if len(symbols) < sum(len(space.symbols) for space in spaces):
            raise ValueError("Overlapping Random Variables")

        if all(space.is_Finite for space in spaces):
            from diofant.stats.frv import ProductFinitePSpace
            cls = ProductFinitePSpace
        if all(space.is_Continuous for space in spaces):
            from diofant.stats.crv import ProductContinuousPSpace
            cls = ProductContinuousPSpace

        return Basic.__new__(cls, *FiniteSet(*spaces))
예제 #20
0
def test_unpack():
    single, pair = Basic(2), Basic(2, 3)
    # A one-argument Basic is unwrapped; a multi-argument one is kept.
    assert unpack(single) == 2
    assert unpack(pair) == Basic(2, 3)
예제 #21
0
def test_flatten():
    nested = Basic(1, 2, Basic(3, 4))
    # The inner Basic's arguments are spliced into the outer one.
    assert flatten(nested) == Basic(1, 2, 3, 4)
예제 #22
0
 def __new__(cls, rows, cols, lamda):
     """Build the instance from sympified ``rows``/``cols`` and ``lamda``."""
     return Basic.__new__(cls, sympify(rows), sympify(cols), lamda)
예제 #23
0
 def __new__(cls):
     # Record each real construction in the module-level counter.
     global instantiated
     instantiated = instantiated + 1
     return Basic.__new__(cls)
예제 #24
0
"""This tests the basic submodule with (ideally) no reference to subclasses."""

import collections

import pytest

from diofant import (Add, Atom, Basic, Function, I, Integral, Lambda, cos,
                     default_sort_key, exp, gamma, preorder_traversal, sin)
from diofant.abc import w, x, y, z
from diofant.core.singleton import S
from diofant.core.singleton import SingletonWithManagedProperties as Singleton

# This test module exports no public names.
__all__ = ()

# Small Basic trees shared by the tests below:
#   b1  -- leaf with no arguments
#   b2  -- b1 wrapped once
#   b3  -- b2 wrapped once (b1 nested twice)
#   b21 -- two children: b2 followed by b1
b1 = Basic()
b2 = Basic(b1)
b3 = Basic(b2)
b21 = Basic(b2, b1)


def test_structure():
    children = b21.args
    assert children == (b2, b1)
    # Rebuilding from func and args gives back an equal object.
    assert b21.func(*children) == b21
    # An argument-less Basic is still truthy.
    assert bool(b1)


def test_equality():
    # pylint: disable=unneeded-not
    instances = [b1, b2, b3, b21, Basic(b1, b1, b1), Basic]
    for i, b_i in enumerate(instances):
        for j, b_j in enumerate(instances):
예제 #25
0
 def __new__(cls, name, antisym, **kwargs):
     """Create the instance and keep ``name``/``antisym`` as attributes."""
     instance = Basic.__new__(cls, name, antisym, **kwargs)
     instance.name, instance.antisym = name, antisym
     return instance
예제 #26
0
def test_sort():
    by_str = sort(str)
    # Arguments are reordered by their string form.
    assert by_str(Basic(3, 1, 2)) == Basic(1, 2, 3)
예제 #27
0
def test_rm_id():
    drop_zeros = rm_id(lambda arg: arg == 0)
    cases = [
        (Basic(0, 1), Basic(1)),
        # When every argument is an identity, one of them is kept.
        (Basic(0, 0), Basic(0)),
        (Basic(2, 1), Basic(2, 1)),
    ]
    for before, after in cases:
        assert drop_zeros(before) == after
예제 #28
0
 def __new__(cls, *args):
     """Sympify every positional argument, then build the instance."""
     converted = [sympify(arg) for arg in args]
     return Basic.__new__(cls, *converted)
예제 #29
0
 def __new__(cls, density):
     # Wrap ``density`` in a Dict and store it as the sole argument.
     return Basic.__new__(cls, Dict(density))
예제 #30
0
 def __new__(cls, symbol, set):
     """Pair ``symbol`` with ``set``, coercing the latter to a FiniteSet.

     The parameter name ``set`` shadows the builtin but is kept because
     callers may pass it by keyword.
     """
     values = set if isinstance(set, FiniteSet) else FiniteSet(*set)
     return Basic.__new__(cls, symbol, values)