예제 #1
0
def test_no_stdlib_collections():
    """Make sure ``fromlist`` imports yield matplotlib's own ``collections``.

    Each call asks ``import_module`` for ``matplotlib`` with ``collections``
    in the ``fromlist`` passed through ``__import__kwargs``; the resulting
    ``matplotlib.collections`` must differ from the module-level
    ``collections`` name (presumably the standard-library module -- confirm
    against the file's imports).
    """
    # ``collections`` requested as part of a larger fromlist.
    matplotlib = import_module(
        'matplotlib',
        __import__kwargs={'fromlist': ['cm', 'collections']},
        min_module_version='1.1.0',
        catch=(RuntimeError, ))
    if matplotlib:
        assert collections != matplotlib.collections

    # ``collections`` requested when it is not part of a larger list.
    matplotlib = import_module('matplotlib',
                               __import__kwargs={'fromlist': ['collections']},
                               min_module_version='1.1.0',
                               catch=(RuntimeError, ))
    if matplotlib:
        assert collections != matplotlib.collections

    # Same larger fromlist, but without catching RuntimeError.
    matplotlib = import_module(
        'matplotlib',
        __import__kwargs={'fromlist': ['cm', 'collections']},
        min_module_version='1.1.0')
    if matplotlib:
        assert collections != matplotlib.collections
예제 #2
0
def sin(x):
    """Evaluate the sine of an interval (or of a plain number).

    A number gives the degenerate interval ``[sin(x), sin(x)]``.  An interval
    that is not valid gives ``[-1, 1]`` carrying its ``is_valid`` flag.
    Otherwise the range of the endpoint values is widened to ``1`` / ``-1``
    when a maximum / minimum of sin lies inside the interval.
    """
    np = import_module('numpy')
    if isinstance(x, (int, float)):
        return interval(np.sin(x))
    if not isinstance(x, interval):  # pragma: no cover
        raise NotImplementedError

    if not x.is_valid:
        return interval(-1, 1, is_valid=x.is_valid)

    quarter = np.pi / 2.0
    # Index of the quarter period [q*pi/2, (q+1)*pi/2) holding each endpoint.
    q_lo = divmod(x.start, quarter)[0]
    q_hi = divmod(x.end, quarter)[0]
    s_lo, s_hi = np.sin(x.start), np.sin(x.end)
    lo, hi = min(s_lo, s_hi), max(s_lo, s_hi)

    if q_hi - q_lo > 4:
        # More than a full period apart: every value of sin is reached.
        return interval(-1, 1, is_valid=x.is_valid)
    if q_lo == q_hi:
        # Both ends share a quarter period, where sin is monotonic.
        return interval(lo, hi, is_valid=x.is_valid)
    # Crossing a boundary with index 1 (mod 4), pi/2 + 2*k*pi, hits a maximum.
    if (q_lo - 1) // 4 != (q_hi - 1) // 4:
        hi = 1
    # Crossing a boundary with index 3 (mod 4), 3*pi/2 + 2*k*pi, hits a minimum.
    if (q_lo - 3) // 4 != (q_hi - 3) // 4:
        lo = -1
    return interval(lo, hi)
예제 #3
0
def cos(x):
    """Evaluate the cosine of an interval (or of a plain number).

    A number ``x`` gives the degenerate interval ``[cos(x), cos(x)]``.  For an
    interval, non-finite endpoints give ``[-1, 1]``; otherwise the range of
    the endpoint values is widened to ``1`` when a maximum of cos
    (``2*k*pi``) lies inside the interval and to ``-1`` when a minimum
    (``(2*k + 1)*pi``) does.  The argument's ``is_valid`` flag is propagated.

    Raises ``NotImplementedError`` for any other argument type.
    """
    np = import_module('numpy')
    if isinstance(x, (int, float)):
        return interval(np.cos(x))
    elif isinstance(x, interval):
        if not (np.isfinite(x.start) and np.isfinite(x.end)):
            return interval(-1, 1, is_valid=x.is_valid)
        # Index of the quarter period (length pi/2) containing each endpoint.
        na, __ = divmod(x.start, np.pi / 2.0)
        nb, __ = divmod(x.end, np.pi / 2.0)
        start = min(np.cos(x.start), np.cos(x.end))
        end = max(np.cos(x.start), np.cos(x.end))
        if nb - na > 4:
            # differ more than 2*pi: a full period is covered
            return interval(-1, 1, is_valid=x.is_valid)
        elif na == nb:
            # in the same quadrant, where cos is monotonic
            return interval(start, end, is_valid=x.is_valid)
        else:
            if (na) // 4 != (nb) // 4:
                # crosses a boundary with index 0 (mod 4), i.e. 2*k*pi:
                # cos has max
                end = 1
            if (na - 2) // 4 != (nb - 2) // 4:
                # crosses a boundary with index 2 (mod 4), i.e.
                # (2*k + 1)*pi: cos has min
                start = -1
            return interval(start, end, is_valid=x.is_valid)
    else:  # pragma: no cover
        raise NotImplementedError
예제 #4
0
    def __init__(self, settings=None):
        """Set up the printer's effective settings.

        The effective settings start as a copy of ``_default_settings``; any
        ``_global_settings`` entry whose key is also a default overrides it,
        and the explicit ``settings`` mapping (if given) is applied last.
        Raises ``TypeError`` if ``settings`` introduces a key with no default.
        When numpy can be imported, its print options are set so that
        ``'numpystr'`` values are formatted with ``str`` (numpy's print
        options are process-wide, not per printer).
        """
        from diofant.external import import_module

        self._str = str

        self._settings = self._default_settings.copy()
        self._settings.update(
            (name, value) for name, value in self._global_settings.items()
            if name in self._default_settings)

        if settings is not None:
            self._settings.update(settings)
            # Keys only grow from the defaults, so a longer mapping means at
            # least one unknown key was supplied.
            if len(self._settings) > len(self._default_settings):
                for name in self._settings:
                    if name not in self._default_settings:
                        raise TypeError("Unknown setting '%s'." % name)

        # Number of nested self._print() calls in progress; see
        # StrPrinter._print_Float() for an example of code reading it.
        self._print_level = 0

        numpy = import_module("numpy")
        if numpy is None:
            return
        numpy.set_printoptions(formatter={'numpystr': str})
예제 #5
0
def test_no_stdlib_collections3():
    """Without ``catch``, matplotlib's collections must differ from ours.

    ``collections`` here is the module-level name this test compares with.
    """
    fromlist = ['cm', 'collections']
    mpl = import_module('matplotlib',
                        __import__kwargs={'fromlist': fromlist},
                        min_module_version='1.1.0')
    if mpl:
        assert collections != mpl.collections
예제 #6
0
    def _get_meshes_grid(self):
        """Sample the relation on a uniform grid for contour plotting.

        ``self.expr`` is turned into one expression whose sign encodes the
        relation (``lhs - rhs`` for equalities and ``>``/``>=``, ``rhs - lhs``
        for ``<``/``<=``).  It is evaluated on an ``nb_of_points`` x
        ``nb_of_points`` mesh and negative/positive values are clamped to
        -1/+1.  Returns ``(xarray, yarray, z_grid, kind)`` where ``kind`` is
        ``'contour'`` for an equality and ``'contourf'`` otherwise.

        Raises NotImplementedError for any other kind of expression.
        """
        rel = self.expr
        is_equality = isinstance(rel, Equality)
        if is_equality or isinstance(rel, (GreaterThan, StrictGreaterThan)):
            expr = rel.lhs - rel.rhs
        elif isinstance(rel, (LessThan, StrictLessThan)):
            expr = rel.rhs - rel.lhs
        else:
            raise NotImplementedError("The expression is not supported for "
                                      "plotting in uniform meshed plot.")

        np = import_module('numpy')
        xarray = np.linspace(self.start_x, self.end_x, self.nb_of_points)
        yarray = np.linspace(self.start_y, self.end_y, self.nb_of_points)
        x_grid, y_grid = np.meshgrid(xarray, yarray)

        evaluate = vectorized_lambdify((self.var_x, self.var_y), expr)
        z_grid = evaluate(x_grid, y_grid)
        # Keep only the sign: negatives become -1, positives +1.
        z_grid[np.ma.where(z_grid < 0)] = -1
        z_grid[np.ma.where(z_grid > 0)] = 1

        kind = 'contour' if is_equality else 'contourf'
        return xarray, yarray, z_grid, kind
예제 #7
0
def test_no_stdlib_collections3():
    """Check matplotlib's ``collections`` when no exception is caught."""
    kwargs = {'fromlist': ['cm', 'collections']}
    matplotlib = import_module('matplotlib', min_module_version='1.1.0',
                               __import__kwargs=kwargs)
    if not matplotlib:
        return
    assert collections != matplotlib.collections
예제 #8
0
def test_matplotlib():
    """Run the implicit-plot smoke tests when matplotlib can be imported."""
    mpl = import_module('matplotlib', min_module_version='1.1.0',
                        catch=(RuntimeError, ))
    if not mpl:
        skip("Matplotlib not the default backend")
    else:
        plot_and_save('test')
        test_line_color()
예제 #9
0
def exp(x):
    """Evaluate the exponential of an interval or a number.

    exp is increasing, so an interval maps to ``[exp(start), exp(end)]``
    with its validity flag carried over; a number maps to a degenerate
    interval.
    """
    np = import_module('numpy')
    if isinstance(x, (int, float)):
        value = np.exp(x)
        return interval(value, value)
    if isinstance(x, interval):
        lo, hi = np.exp(x.start), np.exp(x.end)
        return interval(lo, hi, is_valid=x.is_valid)
    raise NotImplementedError  # pragma: no cover
예제 #10
0
def tanh(x):
    """Evaluate the hyperbolic tangent of an interval or a number.

    tanh is increasing, so an interval maps to ``[tanh(start), tanh(end)]``
    with its validity flag carried over; a number maps to a degenerate
    interval.
    """
    np = import_module('numpy')
    if isinstance(x, (int, float)):
        y = np.tanh(x)
        return interval(y, y)
    if not isinstance(x, interval):  # pragma: no cover
        raise NotImplementedError
    lower = np.tanh(x.start)
    upper = np.tanh(x.end)
    return interval(lower, upper, is_valid=x.is_valid)
예제 #11
0
def test_no_stdlib_collections2():
    """Requesting only ``collections`` must still give matplotlib's module."""
    matplotlib = import_module(
        'matplotlib',
        min_module_version='1.1.0',
        catch=(RuntimeError,),
        __import__kwargs={'fromlist': ['collections']},
    )
    if matplotlib:
        assert collections != matplotlib.collections
예제 #12
0
def atan(x):
    """Evaluate the inverse tangent of an interval or a number."""
    np = import_module('numpy')
    if isinstance(x, (int, float)):
        return interval(np.arctan(x))
    if not isinstance(x, interval):  # pragma: no cover
        raise NotImplementedError
    # arctan is increasing on the whole real line: endpoints map directly.
    lo, hi = np.arctan(x.start), np.arctan(x.end)
    return interval(lo, hi, is_valid=x.is_valid)
예제 #13
0
def asinh(x):
    """Evaluate the inverse hyperbolic sine of an interval or a number.

    arcsinh is increasing on the whole real line, so an interval maps to
    ``[arcsinh(start), arcsinh(end)]`` with its ``is_valid`` flag kept.

    Raises ``NotImplementedError`` for any other argument type.
    """
    np = import_module('numpy')
    if isinstance(x, (int, float)):
        return interval(np.arcsinh(x))
    elif isinstance(x, interval):
        start = np.arcsinh(x.start)
        end = np.arcsinh(x.end)
        return interval(start, end, is_valid=x.is_valid)
    else:  # pragma: no cover
        # Raise the exception (returning the class would hand callers a
        # bogus non-interval result), as the other interval functions do.
        raise NotImplementedError
예제 #14
0
def test_no_stdlib_collections2():
    """``fromlist=['collections']`` yields ``matplotlib.collections``.

    The result is compared against the module-level ``collections`` name.
    """
    options = dict(min_module_version='1.1.0', catch=(RuntimeError, ))
    matplotlib = import_module('matplotlib',
                               __import__kwargs={'fromlist': ['collections']},
                               **options)
    if not matplotlib:
        return
    assert collections != matplotlib.collections
예제 #15
0
def cosh(x):
    """Evaluate the hyperbolic cosine of an interval or a number.

    A number gives a degenerate interval.  An interval containing 0 strictly
    inside maps to ``[1, max(cosh(start), cosh(end))]``; any other interval
    maps to ``(cosh(start), cosh(end))``.  The validity flag is carried over.
    """
    np = import_module('numpy')
    if isinstance(x, (int, float)):
        value = np.cosh(x)
        return interval(value, value)
    if not isinstance(x, interval):  # pragma: no cover
        raise NotImplementedError
    at_start, at_end = np.cosh(x.start), np.cosh(x.end)
    if x.start < 0 < x.end:
        # 0 lies inside, where cosh attains its minimum cosh(0) == 1.
        return interval(1, max(at_start, at_end), is_valid=x.is_valid)
    # No sign change: cosh is monotonic over the interval.
    # NOTE(review): for an interval below zero cosh is decreasing, so
    # at_start > at_end here; this relies on interval() ordering its
    # endpoints -- confirm.
    return interval(at_start, at_end, is_valid=x.is_valid)
예제 #16
0
def log10(x):
    """Evaluate the base-10 logarithm of an interval or a number.

    Non-positive numbers and intervals lying entirely at or below zero give
    an invalid ``[-inf, inf]``; intervals that are already not valid keep
    their flag; intervals only partly in the domain (``start <= 0 < end``)
    give ``is_valid=None``.
    """
    np = import_module('numpy')
    everything = (-np.inf, np.inf)
    if isinstance(x, (int, float)):
        if x <= 0:
            return interval(*everything, is_valid=False)
        return interval(np.log10(x))
    if not isinstance(x, interval):  # pragma: no cover
        raise NotImplementedError
    if not x.is_valid:
        return interval(*everything, is_valid=x.is_valid)
    if x.end <= 0:
        return interval(*everything, is_valid=False)
    if x.start <= 0:
        return interval(*everything, is_valid=None)
    return interval(np.log10(x.start), np.log10(x.end))
예제 #17
0
def floor(x):
    """Evaluate the floor of an interval or a number.

    A number gives the degenerate interval ``[floor(x), floor(x)]``.  An
    interval with ``is_valid is False`` gives an invalid ``[-inf, inf]``.
    Otherwise both endpoints are floored; if the results differ, floor jumps
    somewhere inside the interval, so the result has ``is_valid=None``.

    Raises ``NotImplementedError`` for any other argument type.
    """
    np = import_module('numpy')
    if isinstance(x, (int, float)):
        return interval(np.floor(x))
    elif isinstance(x, interval):
        if x.is_valid is False:
            return interval(-np.inf, np.inf, is_valid=False)
        else:
            start = np.floor(x.start)
            end = np.floor(x.end)
            # continuous over the argument
            if start == end:
                return interval(start, end, is_valid=x.is_valid)
            else:
                # not continuous over the interval
                return interval(start, end, is_valid=None)
    else:  # pragma: no cover
        # Raise, as the other interval functions do, instead of returning
        # the exception class.
        raise NotImplementedError
예제 #18
0
def sqrt(x):
    """Evaluate the square root of an interval or a number.

    Non-negative numbers give ``[sqrt(x), sqrt(x)]`` (so ``sqrt(0)`` is
    ``[0, 0]``); negative numbers give an invalid ``[-inf, inf]``.  An
    interval entirely below zero is invalid, one only partly below zero gives
    ``is_valid=None``, and otherwise ``[sqrt(start), sqrt(end)]`` keeps the
    argument's validity flag.

    Raises ``NotImplementedError`` for any other argument type.
    """
    np = import_module('numpy')
    if isinstance(x, (int, float)):
        # 0 is inside the domain, matching the interval branch below, which
        # accepts intervals with start == 0.
        if x >= 0:
            return interval(np.sqrt(x))
        else:
            return interval(-np.inf, np.inf, is_valid=False)
    elif isinstance(x, interval):
        # Outside the domain
        if x.end < 0:
            return interval(-np.inf, np.inf, is_valid=False)
        # Partially outside the domain
        elif x.start < 0:
            return interval(-np.inf, np.inf, is_valid=None)
        else:
            return interval(np.sqrt(x.start),
                            np.sqrt(x.end),
                            is_valid=x.is_valid)
    else:  # pragma: no cover
        raise NotImplementedError
예제 #19
0
def atanh(x):
    """Evaluate the inverse hyperbolic tangent of an interval or a number.

    The domain is the open interval ``(-1, 1)``.  Numbers with
    ``abs(x) >= 1``, intervals entirely outside the domain and intervals with
    ``is_valid is False`` give an invalid ``[-inf, inf]``; intervals only
    partly inside give ``is_valid=None``.

    Raises ``NotImplementedError`` for any other argument type.
    """
    np = import_module('numpy')
    if isinstance(x, (int, float)):
        # Outside the domain
        if abs(x) >= 1:
            return interval(-np.inf, np.inf, is_valid=False)
        else:
            return interval(np.arctanh(x))
    elif isinstance(x, interval):
        # outside the domain
        if x.is_valid is False or x.start >= 1 or x.end <= -1:
            return interval(-np.inf, np.inf, is_valid=False)
        # partly outside the domain
        elif x.start <= -1 or x.end >= 1:
            return interval(-np.inf, np.inf, is_valid=None)
        else:
            start = np.arctanh(x.start)
            end = np.arctanh(x.end)
            return interval(start, end, is_valid=x.is_valid)
    else:  # pragma: no cover
        # Raise instead of returning the exception class.
        raise NotImplementedError
예제 #20
0
def acosh(x):
    """Evaluate the inverse hyperbolic cosine of an interval or a number.

    The domain is ``[1, inf)``.  Numbers below 1 and intervals lying
    entirely below 1 give an invalid ``[-inf, inf]``; intervals only partly
    inside give ``is_valid=None``.

    Raises ``NotImplementedError`` for any other argument type.
    """
    np = import_module('numpy')
    if isinstance(x, (int, float)):
        # Outside the domain
        if x < 1:
            return interval(-np.inf, np.inf, is_valid=False)
        else:
            return interval(np.arccosh(x))
    elif isinstance(x, interval):
        # Outside the domain
        if x.end < 1:
            return interval(-np.inf, np.inf, is_valid=False)
        # Partly outside the domain
        elif x.start < 1:
            return interval(-np.inf, np.inf, is_valid=None)
        else:
            start = np.arccosh(x.start)
            end = np.arccosh(x.end)
            return interval(start, end, is_valid=x.is_valid)
    else:  # pragma: no cover
        # Raise instead of returning the exception class.
        raise NotImplementedError
예제 #21
0
def acos(x):
    """Evaluate the inverse cosine of an interval or a number.

    Numbers with ``abs(x) > 1``, intervals lying entirely outside ``[-1, 1]``
    and intervals with ``is_valid is False`` give an invalid ``[-inf, inf]``;
    intervals only partly outside the domain give ``is_valid=None``.
    """
    np = import_module('numpy')
    unbounded = (-np.inf, np.inf)
    if isinstance(x, (int, float)):
        if abs(x) > 1:
            return interval(*unbounded, is_valid=False)
        value = np.arccos(x)
        return interval(value, value)
    if not isinstance(x, interval):  # pragma: no cover
        raise NotImplementedError
    if x.is_valid is False or x.start > 1 or x.end < -1:
        return interval(*unbounded, is_valid=False)
    if x.start < -1 or x.end > 1:
        return interval(*unbounded, is_valid=None)
    # NOTE(review): arccos is decreasing, so arccos(x.start) >= arccos(x.end);
    # this relies on interval() ordering its endpoints -- confirm.
    first, second = np.arccos(x.start), np.arccos(x.end)
    return interval(first, second, is_valid=x.is_valid)
예제 #22
0
def imin(*args):
    """Evaluate the minimum of a list of intervals and/or numbers.

    Numbers, and intervals whose ``is_valid`` is truthy, take part in the
    minimum; other intervals are skipped.  If nothing remains, the result is
    ``[-inf, inf]`` with ``is_valid=False`` when every argument has
    ``is_valid is False``, and ``is_valid=None`` otherwise.

    Raises
    ------
    NotImplementedError
        If any argument is neither an int, a float nor an interval.
    """
    np = import_module('numpy')
    if not all(isinstance(arg, (int, float, interval))
               for arg in args):  # pragma: no cover
        # Raise, as the other interval functions do, instead of returning
        # the exception class.
        raise NotImplementedError

    new_args = [
        a for a in args if isinstance(a, (int, float)) or a.is_valid
    ]
    if len(new_args) == 0:
        if all(a.is_valid is False for a in args):
            return interval(-np.inf, np.inf, is_valid=False)
        else:
            return interval(-np.inf, np.inf, is_valid=None)
    # A number contributes itself as both its start and its end.
    start_array = [
        a if isinstance(a, (int, float)) else a.start for a in new_args
    ]
    end_array = [
        a if isinstance(a, (int, float)) else a.end for a in new_args
    ]
    return interval(min(start_array), min(end_array))
예제 #23
0
"""Tests that the IPython printing module is properly loaded. """

import pytest

from diofant.core import Float, Rational, Symbol
from diofant.external import import_module


__all__ = ()

# The IPython module (version 2.3.0 or newer), or None when it cannot be
# imported; the test below is skipped in the latter case.
ipython = import_module("IPython", min_module_version="2.3.0")


@pytest.mark.skipif(ipython is None, reason="no IPython")
def test_ipython_printing(monkeypatch):
    """Check Diofant printing inside an embedded IPython shell.

    Starts a TerminalIPythonApp without a banner, switches to its shell and
    runs setup cells that import IPython, warnings and Diofant names there.
    """
    app = ipython.terminal.ipapp.TerminalIPythonApp()
    app.display_banner = False
    app.initialize([])
    # From here on ``app`` is the interactive shell that executes cells.
    app = app.shell

    app.run_cell("ip = get_ipython()")
    app.run_cell("inst = ip.instance()")
    app.run_cell("format = inst.display_formatter.format")
    app.run_cell("import warnings")
    app.run_cell("import IPython")
    app.run_cell("import diofant")
    app.run_cell("from diofant import Float, Rational, Symbol, QQ, "
                 "factorial, sqrt, init_printing, pretty, "
                 "Matrix, sstrrepr")

    # Test IntegerDivisionWrapper
예제 #24
0
what you would get by applying a sequence of the ufuncs shipped with
numpy. [0]

You need to have numpy installed to run this example, as well as a
working fortran compiler.


[0]:
http://ojensen.wordpress.com/2010/08/10/fast-ufunc-ish-hydrogen-solutions/
"""

import sys

from diofant.external import import_module

# numpy and matplotlib.pyplot are hard requirements of this example: exit
# with a message as soon as either one cannot be imported.
np = import_module('numpy')
if not np:
    sys.exit("Cannot import numpy. Exiting.")
plt = import_module('matplotlib.pyplot')
if not plt:
    sys.exit("Cannot import matplotlib.pyplot. Exiting.")

import mpmath
from diofant.utilities.autowrap import ufuncify
from diofant.utilities.lambdify import implemented_function
from diofant import symbols, legendre, pprint


def main():
    """Run the example, starting by printing this module's docstring."""

    print(__doc__)
예제 #25
0
import warnings
from tempfile import NamedTemporaryFile

import pytest

from diofant import (And, Eq, I, Or, cos, exp, pi, plot_implicit, re, sin,
                     symbols, tan)
from diofant.abc import x, y
from diofant.external import import_module


__all__ = ()

# The matplotlib module (1.1.0 or newer), or None if it is unavailable.
# RuntimeError is passed via ``catch`` -- presumably so that an import-time
# backend failure counts as "unavailable" rather than an error; confirm.
matplotlib = import_module('matplotlib', min_module_version='1.1.0',
                           catch=(RuntimeError,))


def tmp_file(name=''):
    """Return a fresh, currently unused ``.png`` path in the temp directory.

    A named temporary file is created only to obtain a unique name.  It is
    closed, which deletes it, before returning, so the caller gets a path
    that does not exist yet.  ``name`` is accepted for call-site readability
    and is not used.
    """
    # Close the file explicitly instead of relying on garbage collection of
    # the discarded NamedTemporaryFile object.
    with NamedTemporaryFile(suffix='.png') as tmp:
        return tmp.name


def plot_and_save(name):
    # implicit plot tests
    plot_implicit(Eq(y, cos(x)), (x, -5, 5), (y, -2, 2), show=False).save(tmp_file(name))
    plot_implicit(Eq(y**2, x**3 - x), (x, -5, 5),
                  (y, -4, 4), show=False).save(tmp_file(name))
    plot_implicit(y > 1 / x, (x, -5, 5),
                  (y, -2, 2), show=False).save(tmp_file(name))
    plot_implicit(y < 1 / tan(x), (x, -5, 5),
                  (y, -2, 2), show=False).save(tmp_file(name))
    plot_implicit(y >= 2 * sin(x) * cos(x), (x, -5, 5),
예제 #26
0
def test_interface():
    """Exercise import_module's missing-module and version-gating paths."""
    # Missing module: warns when asked to, otherwise just returns None.
    with pytest.warns(UserWarning) as caught:
        import_module('spam_spam_spam', warn_not_installed=True)
    assert len(caught) == 1
    assert caught[0].message.args[0] == "spam_spam_spam module is not installed"
    assert import_module('spam_spam_spam') is None

    # Module older than the requested minimum version.
    with pytest.warns(UserWarning) as caught:
        import_module('re', warn_old_version=True, min_module_version="10.1")
    assert len(caught) == 1
    expected = ("re version is too old to use "
                "(10.1 or newer required)")
    assert caught[0].message.args[0] == expected
    assert import_module('re', warn_old_version=False,
                         min_module_version="10.1") is None
    assert import_module('re', min_module_version="0.1") is not None

    # Interpreter older than the requested minimum Python version.
    with pytest.warns(UserWarning) as caught:
        import_module('re', warn_old_version=True, min_python_version=(20, 10))
    assert len(caught) == 1
    expected = ("Python version is too old to use re "
                "(20.10 or newer required)")
    assert caught[0].message.args[0] == expected
    assert import_module('re', warn_old_version=False,
                         min_python_version=(20, 10)) is None
    assert import_module('re', warn_old_version=False,
                         min_python_version=(3, 3)) is not None

    if not HAS_GMPY:
        return
    assert import_module('gmpy2', min_module_version='2.0.0',
                         module_version_attr='version',
                         module_version_attr_call_args=(),
                         warn_old_version=False) is not None
예제 #27
0
# This test file checks Diofant <-> SciPy compatibility.

# Don't test any Diofant features here, just the pure interaction with SciPy.
# Always write regular Diofant tests for anything that can be tested in pure
# Python (without SciPy). Here we test everything that a user may need when
# using Diofant with SciPy.

from diofant import jn_zeros
from diofant.external import import_module


__all__ = ()

# The scipy module, or None when it cannot be imported.
scipy = import_module('scipy')
if not scipy:
    # ``disabled`` is a flag from the legacy test runner and is kept for
    # backward compatibility.  pytest does not act on it, so skip the whole
    # module explicitly; otherwise the scipy-backed tests below would still
    # run without SciPy.
    import pytest

    disabled = True
    pytest.skip("Couldn't import scipy.", allow_module_level=True)


def eq(a, b, tol=1e-6):
    """Return True if sequences *a* and *b* agree element-wise within *tol*.

    Sequences of different length never compare equal; ``zip`` alone would
    silently ignore the surplus elements and hide truncated results.  A NaN
    element makes the comparison fail, because ``abs(nan) < tol`` is False.
    """
    # Materialize so that arbitrary iterables (not just sequences) still work.
    a, b = list(a), list(b)
    if len(a) != len(b):
        return False
    return all(abs(x - y) < tol for x, y in zip(a, b))


def test_jn_zeros():
    """Compare scipy-backed jn_zeros with tabulated reference zeros."""
    expected = {
        0: [3.141592, 6.283185, 9.424777, 12.566370],
        1: [4.493409, 7.725251, 10.904121, 14.066193],
    }
    for order, zeros in expected.items():
        assert eq(jn_zeros(order, 4, method="scipy"), zeros)
예제 #28
0
def test_interface():
    """Check import_module's warnings and its None/module results."""
    def sole_warning(*args, **kwargs):
        # Call import_module expecting exactly one UserWarning; return its text.
        with pytest.warns(UserWarning) as caught:
            import_module(*args, **kwargs)
        assert len(caught) == 1
        return caught[0].message.args[0]

    text = sole_warning('spam_spam_spam', warn_not_installed=True)
    assert text == 'spam_spam_spam module is not installed'
    assert import_module('spam_spam_spam') is None

    text = sole_warning('re', warn_old_version=True, min_module_version='10.1')
    assert text == ('re version is too old to use '
                    '(10.1 or newer required)')
    assert import_module(
        're', warn_old_version=False, min_module_version='10.1') is None
    assert import_module('re', min_module_version='0.1') is not None

    text = sole_warning('re', warn_old_version=True,
                        min_python_version=(20, 10))
    assert text == ('Python version is too old to use re '
                    '(20.10 or newer required)')
    assert import_module(
        're', warn_old_version=False, min_python_version=(20, 10)) is None
    assert import_module('re', warn_old_version=False,
                         min_python_version=(3, 3)) is not None

    if HAS_GMPY:
        gmpy2 = import_module('gmpy2',
                              min_module_version='2.0.0',
                              module_version_attr='version',
                              module_version_attr_call_args=(),
                              warn_old_version=False)
        assert gmpy2 is not None
예제 #29
0
"""Tests of tools for setting up interactive IPython sessions. """

import ast
import sys

import pytest

from diofant.interactive.session import (init_ipython_session,
                                         IntegerWrapper)
from diofant.core import Symbol, Rational, Integer
from diofant.external import import_module

# IPython (2.3.0 or newer) and readline modules, or None when unavailable.
ipython = import_module("IPython", min_module_version="2.3.0")
readline = import_module("readline")

if not ipython:
    # Intended to stop the tests in this module from running.
    # NOTE(review): ``disabled`` is a flag from the legacy test runner;
    # pytest does not appear to read it, so tests here may still run without
    # IPython -- confirm, and consider pytest.skip(allow_module_level=True).
    disabled = True


@pytest.mark.skipif(sys.version_info >= (3, 5),
                    reason="XXX python3.5 api changes")
def test_IntegerWrapper():
    """IntegerWrapper rewrites the literals of '1/3' as Integer(...) calls.

    The expected dump has ``starargs``/``kwargs`` fields on ``Call`` nodes,
    which presumably explains the skip on Python 3.5 and newer.
    """
    source = '1/3'
    expected = ("Module(body=[Expr(value=BinOp(left=Call(func=Name(id='Integer', "
                "ctx=Load()), args=[Num(n=1)], keywords=[], starargs=None, "
                "kwargs=None), op=Div(), right=Call(func=Name(id='Integer', "
                "ctx=Load()), args=[Num(n=3)], keywords=[], starargs=None, "
                "kwargs=None)))])")
    transformed = IntegerWrapper().visit(ast.parse(source))
    assert ast.dump(transformed) == expected
예제 #30
0
import os
import tempfile

import pytest

import diofant
from diofant import Eq, symbols
from diofant.external import import_module
from diofant.tensor import Idx, IndexedBase
from diofant.utilities.autowrap import CodeWrapError, autowrap, ufuncify


__all__ = ()

# Optional dependencies: each name is the imported module or None.
numpy = import_module('numpy', min_module_version='1.6.1')
Cython = import_module('Cython', min_module_version='0.15.1')
f2py = import_module('numpy.f2py', __import__kwargs={'fromlist': ['f2py']})

# Importing f2py is not enough: try once, at import time, to build an 'f95'
# wrapper for a single symbol and record whether that succeeded.
f2pyworks = False
if f2py:
    try:
        autowrap(symbols('x'), 'f95', 'f2py')
    except (CodeWrapError, ImportError, OSError):
        f2pyworks = False
    else:
        f2pyworks = True

# Symbols and indexed objects shared by the tests in this module.
a, b, c = symbols('a b c')
n, m, d = symbols('n m d', integer=True)
A, B, C = symbols('A B C', cls=IndexedBase)
i = Idx('i', m)
예제 #31
0
# HAS_GMPY holds the major version of the gmpy library in use: 2 once gmpy2
# (2.0.0 or newer) has been imported below, 0 otherwise.  Only gmpy2 is
# probed by this code; the value 1 (old gmpy 1.x) is never set.

# DIOFANT_GROUND_TYPES (case-insensitive, default 'auto') selects the ground
# types: 'python' skips the gmpy2 import entirely, 'auto' picks 'gmpy' when
# gmpy2 is available and 'python' otherwise, and 'gmpy' falls back to
# 'python' with a warning if gmpy2 is missing.  Other values are left as-is.
GROUND_TYPES = os.getenv('DIOFANT_GROUND_TYPES', 'auto').lower()
HAS_GMPY = 0

if GROUND_TYPES != 'python':
    gmpy = import_module('gmpy2',
                         min_module_version='2.0.0',
                         module_version_attr='version',
                         module_version_attr_call_args=())
    if gmpy:
        HAS_GMPY = 2

if GROUND_TYPES == 'auto':
    if HAS_GMPY:
        GROUND_TYPES = 'gmpy'
    else:
        GROUND_TYPES = 'python'

if GROUND_TYPES == 'gmpy' and not HAS_GMPY:
    from warnings import warn
    warn("gmpy library is not installed, switching to 'python' ground types")
    GROUND_TYPES = 'python'
예제 #32
0
"""Tests that the IPython printing module is properly loaded. """

import pytest

from diofant.core import Float, Rational, Symbol
from diofant.external import import_module

__all__ = ()

# IPython 2.3.0 or newer if importable, else None (used to skip the test).
ipython = import_module("IPython", min_module_version="2.3.0")


@pytest.mark.skipif(ipython is None, reason="no IPython")
def test_ipython_printing(monkeypatch):
    """Check Diofant printing in an embedded IPython shell.

    A banner-less TerminalIPythonApp is initialized; its shell then runs
    setup cells importing IPython, warnings and several Diofant names.
    """
    app = ipython.terminal.ipapp.TerminalIPythonApp()
    app.display_banner = False
    app.initialize([])
    # Work with the shell object, which executes the cells below.
    app = app.shell

    app.run_cell("ip = get_ipython()")
    app.run_cell("inst = ip.instance()")
    app.run_cell("format = inst.display_formatter.format")
    app.run_cell("import warnings")
    app.run_cell("import IPython")
    app.run_cell("import diofant")
    app.run_cell("from diofant import Float, Rational, Symbol, QQ, "
                 "factorial, sqrt, init_printing, pretty, "
                 "Matrix, sstrrepr")

    # Test IntegerDivisionWrapper
예제 #33
0
import os
import tempfile

import pytest

import diofant
from diofant import Eq, symbols
from diofant.external import import_module
from diofant.tensor import Idx, IndexedBase
from diofant.utilities.autowrap import CodeWrapError, autowrap, ufuncify

__all__ = ()

# Each optional dependency is paired with a skip marker for tests needing it.
numpy = import_module('numpy', min_module_version='1.6.1')
with_numpy = pytest.mark.skipif(numpy is None, reason="Couldn't import numpy.")

Cython = import_module('Cython', min_module_version='0.15.1')
with_cython = pytest.mark.skipif(Cython is None,
                                 reason="Couldn't import Cython.")

f2py = import_module('numpy.f2py', __import__kwargs={'fromlist': ['f2py']})
with_f2py = pytest.mark.skipif(f2py is None, reason="Couldn't run f2py.")

# An importable f2py may still fail to build: probe once, at import time,
# with an 'f95' autowrap of a single symbol.
f2pyworks = False
if f2py:
    try:
        autowrap(symbols('x'), 'f95', 'f2py')
    except (CodeWrapError, ImportError, OSError):
        f2pyworks = False
    else:
        f2pyworks = True
예제 #34
0
import pytest

from diofant.abc import t, x
from diofant.core import Add, Eq, Integer, Mul, Rational, symbols
from diofant.external import import_module
from diofant.functions import cos, sin, sqrt, transpose
from diofant.matrices import (Adjoint, Identity, ImmutableMatrix, Inverse,
                              MatAdd, MatMul, MatPow, Matrix, MatrixExpr,
                              MatrixSymbol, ShapeError, Transpose, ZeroMatrix)
from diofant.matrices.expressions.matexpr import MatrixElement
from diofant.matrices.expressions.slice import MatrixSlice
from diofant.simplify import simplify


# The numpy module, or None when it cannot be imported.
numpy = import_module('numpy')

__all__ = ()

# Integer dimension symbols and matrix symbols shared by the tests below.
n, m, l, k, p = symbols('n m l k p', integer=True)
A = MatrixSymbol('A', n, m)
B = MatrixSymbol('B', m, l)
C = MatrixSymbol('C', n, n)
D = MatrixSymbol('D', n, n)
E = MatrixSymbol('E', m, n)


def test_shape():
    """Products have the expected shape; a mismatched product raises."""
    assert A.shape == (n, m)
    assert (A*B).shape == (n, l)
    with pytest.raises(ShapeError):
        B*A
예제 #35
0
import pytest

from diofant.domains import (CC, FF, FF_gmpy, FF_python, PythonRational,
                             QQ_gmpy, QQ_python, ZZ_gmpy, ZZ_python)
from diofant.external import import_module
from diofant.polys.polyerrors import CoercionFailed

__all__ = ()

# The gmpy2 module, or None if it cannot be imported; gmpy-dependent tests
# below are skipped in that case.
gmpy = import_module('gmpy2')


@pytest.mark.skipif(gmpy is None, reason="no gmpy")
def test_convert():
    """Check conversions between gmpy2 numbers, finite-field elements and
    the gmpy/python ZZ and QQ domains."""
    F3 = FF(3)
    F3_gmpy = FF_gmpy(3)
    F3_python = FF_python(3)

    # gmpy2 integers and integral rationals convert into GF(3); a
    # non-integral rational cannot.
    assert F3.convert(gmpy.mpz(2)) == F3.dtype(2)
    assert F3.convert(gmpy.mpq(2, 1)) == F3.dtype(2)
    pytest.raises(CoercionFailed, lambda: F3.convert(gmpy.mpq(1, 2)))

    assert ZZ_gmpy.convert(F3_python(1)) == ZZ_gmpy.dtype(1)
    assert ZZ_gmpy.convert(F3_gmpy(1)) == ZZ_gmpy.dtype(1)

    # Only integral PythonRational values convert into ZZ.
    assert ZZ_gmpy.convert(PythonRational(2)) == ZZ_gmpy.dtype(2)
    pytest.raises(CoercionFailed,
                  lambda: ZZ_gmpy.convert(PythonRational(2, 3)))

    assert QQ_python.convert(gmpy.mpz(3)) == QQ_python.dtype(3)
    assert QQ_python.convert(gmpy.mpq(2, 3)) == QQ_python.dtype(2, 3)
예제 #36
0
import mpmath
import pytest

from diofant import (symbols, lambdify, sqrt, sin, cos, tan, pi, acos, acosh,
                     Rational, Float, Matrix, Lambda, Piecewise, exp, Integral,
                     oo, I, Abs, Function, true, false, And, Or, Not, sympify, ITE)
from diofant.printing.lambdarepr import LambdaPrinter
from diofant.utilities.lambdify import implemented_function
from diofant.utilities.decorator import conserve_mpmath_dps
from diofant.external import import_module
import diofant


# NOTE(review): alias presumably needed so that code printed by lambdify
# referring to MutableDenseMatrix resolves in this module -- confirm.
MutableDenseMatrix = Matrix

# Optional numeric backends: the module object, or None if unavailable.
numpy = import_module('numpy')
numexpr = import_module('numexpr')

w, x, y, z, a, b = symbols('w,x,y,z,a,b')

# ================= Test different arguments =======================


def test_no_args():
    """lambdify over an empty argument list gives a nullary function."""
    nullary = lambdify([], 1)
    with pytest.raises(TypeError):
        nullary(-1)
    assert nullary() == 1


def test_single_arg():
    """lambdify with a single symbol (not a list) as the argument."""
    f = lambdify(x, 2*x)
예제 #37
0
    def _get_raster_interval(self, func):
        """Use interval arithmetic to adaptively mesh and obtain the plot.

        The plot region is first split into a 32 x 32 grid of interval
        cells.  Each cell is evaluated with ``func``: cells whose result is
        ``(True, True)`` are kept, cells with a False entry are dropped, and
        undecided cells (a None entry) are split into four and re-examined,
        for at most ``self.depth + 1`` rounds.  When ``self.has_equality`` is
        set, cells still undecided after the last round are re-evaluated and
        kept if their second entry is truthy and the first is not False.

        Parameters
        ----------
        func : callable
            Called as ``func(intervalx, intervaly)``; its result is indexed
            with ``[0]`` and ``[1]`` and compared with ``(True, True)``.

        Returns
        -------
        tuple
            ``(plot_list, 'fill')`` where ``plot_list`` holds
            ``[intervalx, intervaly]`` pairs to fill.
        """
        k = self.depth
        np = import_module('numpy')
        # 33 sample points per axis give the initial 32 divisions.
        xsample = np.linspace(self.start_x, self.end_x, 33)
        ysample = np.linspace(self.start_y, self.end_y, 33)

        # Add a small jitter so that there are no false positives for equality.
        # Ex: y==x becomes True for x interval(1, 2) and y interval(1, 2)
        # which will draw a rectangle.
        jitterx = (np.random.rand(len(xsample)) * 2 -
                   1) * (self.end_x - self.start_x) / 2**20
        jittery = (np.random.rand(len(ysample)) * 2 -
                   1) * (self.end_y - self.start_y) / 2**20
        xsample += jitterx
        ysample += jittery

        xinter = [
            interval(x1, x2) for x1, x2 in zip(xsample[:-1], xsample[1:])
        ]
        yinter = [
            interval(y1, y2) for y1, y2 in zip(ysample[:-1], ysample[1:])
        ]
        interval_list = [[x, y] for x in xinter for y in yinter]
        plot_list = []

        def refine_pixels(interval_list):
            """Evaluate every cell once and sort it by the result of ``func``.

            Returns ``(undecided, accepted)``: ``undecided`` holds the four
            quarter cells of each cell with a None entry (and no False one),
            ``accepted`` the cells whose result is ``(True, True)``.
            """
            temp_interval_list = []
            accepted = []
            for intervalx, intervaly in interval_list:
                func_eval = func(intervalx, intervaly)
                if func_eval[1] is False or func_eval[0] is False:
                    # Definitely not part of the plot.
                    continue
                if func_eval == (True, True):
                    accepted.append([intervalx, intervaly])
                elif func_eval[1] is None or func_eval[0] is None:
                    # Undecided: split into four quarters for the next round.
                    avgx = intervalx.mid
                    avgy = intervaly.mid
                    a = interval(intervalx.start, avgx)
                    b = interval(avgx, intervalx.end)
                    c = interval(intervaly.start, avgy)
                    d = interval(avgy, intervaly.end)
                    temp_interval_list.extend([[a, c], [a, d], [b, c], [b, d]])
            return temp_interval_list, accepted

        while k >= 0 and interval_list:
            interval_list, accepted = refine_pixels(interval_list)
            plot_list.extend(accepted)
            k -= 1
        # For an equality, floating point differences mean no cell would ever
        # satisfy the expression exactly, so also plot the cells still
        # undecided after the last round.
        if self.has_equality:
            for intervalx, intervaly in interval_list:
                func_eval = func(intervalx, intervaly)
                if func_eval[1] and func_eval[0] is not False:
                    plot_list.append([intervalx, intervaly])
        return plot_list, 'fill'
예제 #38
0
import pytest

from diofant.abc import t, x
from diofant.core import Add, Eq, Integer, Mul, Rational, symbols
from diofant.external import import_module
from diofant.functions import cos, sin, sqrt, transpose
from diofant.matrices import (Adjoint, Identity, ImmutableMatrix, Inverse,
                              MatAdd, MatMul, MatPow, Matrix, MatrixExpr,
                              MatrixSymbol, ShapeError, Transpose, ZeroMatrix)
from diofant.matrices.expressions.matexpr import MatrixElement
from diofant.matrices.expressions.slice import MatrixSlice
from diofant.simplify import simplify

# numpy module object, or None if numpy is not importable.
numpy = import_module('numpy')

__all__ = ()

# Shared fixtures: integer-valued dimensions and matrix symbols built on them.
n, m, l, k, p = symbols('n m l k p', integer=True)
A = MatrixSymbol('A', n, m)
B = MatrixSymbol('B', m, l)
C = MatrixSymbol('C', n, n)
D = MatrixSymbol('D', n, n)
E = MatrixSymbol('E', m, n)


def test_shape():
    """Shape bookkeeping of MatrixSymbol products."""
    product = A * B
    assert A.shape == (n, m)
    assert product.shape == (n, l)
    pytest.raises(ShapeError, lambda: B * A)

예제 #39
0
                     acosh, cos, exp, false, lambdify, oo, pi, sin, sqrt,
                     symbols, sympify, tan, true)
from diofant.abc import a, b, t, w, x, y, z
from diofant.external import import_module
from diofant.printing.lambdarepr import LambdaPrinter, NumExprPrinter
from diofant.utilities.decorator import conserve_mpmath_dps
from diofant.utilities.lambdify import (MATH_TRANSLATIONS, MPMATH_TRANSLATIONS,
                                        NUMPY_TRANSLATIONS,
                                        implemented_function)


__all__ = ()

# NOTE(review): presumably lets lambdify-generated code that names
# MutableDenseMatrix resolve here -- confirm.
MutableDenseMatrix = Matrix

# Optional backends: module object when importable, otherwise None.
numpy = import_module('numpy')
numexpr = import_module('numexpr')

# ================= Test different arguments =======================


def test_no_args():
    """Calling a zero-argument lambdified function with an argument fails."""
    f0 = lambdify([], 1)
    pytest.raises(TypeError, lambda: f0(-1))
    result = f0()
    assert result == 1


def test_single_arg():
    f = lambdify(x, 2*x)
    assert f(1) == 2