def test_complexfunctions():
    """Complex-valued expressions print to the expected Aesara graphs."""
    from sympy import conjugate
    from aesara.tensor import as_tensor_variable as atv
    from aesara.tensor import complex as cplx

    # Print every occurrence of x and y with the same complex dtypes, so the
    # symbols inside the expressions match xt/yt in dtype (and, via the
    # cache, presumably are the very same variables). Without this, the
    # expressions would be built from default-dtype variables and the
    # comparison would only pass because debugprint output omits dtypes.
    dtypes = {x: 'complex128', y: 'complex128'}
    xt = aesara_code(x, dtypes=dtypes)
    yt = aesara_code(y, dtypes=dtypes)

    assert theq(aesara_code(y * conjugate(x), dtypes=dtypes), yt * xt.conj())
    assert theq(aesara_code((1 + 2j) * x, dtypes=dtypes),
                xt * (atv(1.0) + atv(2.0) * cplx(0, 1)))
def test_global_cache():
    """ Test use of the global cache. """
    from sympy.printing.aesaracode import global_cache

    backup = dict(global_cache)
    try:
        # Temporarily empty global cache
        global_cache.clear()

        for s in [x, X, f_t]:
            st = aesara_code(s)
            # Printing the same object a second time must hit the cache.
            assert aesara_code(s) is st

    finally:
        # Restore the global cache to exactly its prior contents. A bare
        # update() would keep entries created during this test and leak
        # them into later tests.
        global_cache.clear()
        global_cache.update(backup)
def test_cache_types_distinct():
    """
    Symbol-like objects of different kinds (Symbol, MatrixSymbol,
    AppliedUndef) that share the same name must not collide in the cache.
    """
    candidates = [sy.Symbol('f_t'), sy.MatrixSymbol('f_t', 4, 4), f_t]
    shared_cache = {}  # one cache reused for every candidate
    results = {}

    for sym in candidates:
        var = aesara_code_(sym, cache=shared_cache)
        # A collision would hand back a variable that was already printed.
        assert var not in results.values()
        results[sym] = var

    # Every printed variable must be a distinct object.
    distinct_ids = {id(var) for var in results.values()}
    assert len(distinct_ids) == len(candidates)

    # Printing again with the same cache must return the stored objects.
    for sym, var in results.items():
        assert aesara_code(sym, cache=shared_cache) is var
def aesara_code_(expr, **kwargs):
    """
    Call aesara_code, supplying a brand-new empty cache unless the caller
    passed one explicitly.

    Each call therefore gets its own cache and shares no printed variables
    with other calls, unless a ``cache`` keyword argument is given.
    """
    if 'cache' not in kwargs:
        kwargs['cache'] = {}
    return aesara_code(expr, **kwargs)