Exemplo n.º 1
0
def test_complexfunctions():
    """
    Check aesara_code on expressions that involve complex values.

    The symbols x and y are first printed with dtype 'complex128'; those
    results are the building blocks of the expected graphs, which are
    compared with the printed expressions through the theq helper.
    """
    from sympy import conjugate
    from aesara.tensor import as_tensor_variable as atv
    from aesara.tensor import complex as cplx

    xt = aesara_code(x, dtypes={x: 'complex128'})
    yt = aesara_code(y, dtypes={y: 'complex128'})

    # Product with a complex conjugate.
    printed_conj = aesara_code(y * conjugate(x))
    expected_conj = yt * (xt.conj())
    assert theq(printed_conj, expected_conj)

    # Product with a complex Python literal.
    printed_scaled = aesara_code((1 + 2j) * x)
    expected_scaled = xt * (atv(1.0) + atv(2.0) * cplx(0, 1))
    assert theq(printed_scaled, expected_scaled)
Exemplo n.º 2
0
def test_global_cache():
    """
    Test use of the global cache.

    Calling aesara_code twice on the same expression without passing a
    cache must return the identical object. The global cache is emptied
    for the duration of the test and put back exactly as it was afterwards.
    """
    from sympy.printing.aesaracode import global_cache

    # Snapshot the current contents so they can be put back afterwards.
    backup = dict(global_cache)
    try:
        # Temporarily empty global cache
        global_cache.clear()

        for s in [x, X, f_t]:
            st = aesara_code(s)
            assert aesara_code(s) is st

    finally:
        # Restore global cache. Clear first so that entries created by this
        # test do not leak into the restored cache; update() alone would
        # only layer the saved entries on top of them.
        global_cache.clear()
        global_cache.update(backup)
Exemplo n.º 3
0
def test_cache_types_distinct():
    """
    Symbol-like objects of different types (Symbol, MatrixSymbol,
    AppliedUndef) that share a name must not collide in the cache: each one
    gets its own printed object, and each can be retrieved again afterwards.
    """
    same_named = [sy.Symbol('f_t'), sy.MatrixSymbol('f_t', 4, 4), f_t]

    shared_cache = {}  # One cache used for every call below
    results = {}  # Maps each SymPy object to its printed counterpart

    for obj in same_named:
        printed_obj = aesara_code_(obj, cache=shared_cache)
        # A result already seen would mean two objects shared a cache entry.
        assert printed_obj not in results.values()
        results[obj] = printed_obj

    # Every printed object must be a distinct Python object.
    distinct_ids = {id(value) for value in results.values()}
    assert len(distinct_ids) == len(same_named)

    # Printing again with the same cache must hand back the stored objects.
    for obj, printed_obj in results.items():
        assert aesara_code(obj, cache=shared_cache) is printed_obj
Exemplo n.º 4
0
def aesara_code_(expr, **kwargs):
    """
    Call aesara_code with a fresh, empty cache unless the caller passes one.

    All other keyword arguments are forwarded unchanged.
    """
    if 'cache' not in kwargs:
        # No cache supplied by the caller: use a new empty dict.
        kwargs['cache'] = {}
    return aesara_code(expr, **kwargs)