def test_jax_BatchedDot():
    """Exercise `BatchedDot` with tensor3 and matrix operands.

    Each graph is handed to `compare_jax_and_py` together with the test
    values attached to its inputs.  In between, the tensor3 graph is compiled
    with a `JAXLinker` mode and called with a first operand that is one entry
    shorter along its leading axis, which is expected to raise `TypeError`.
    """
    # tensor3 . tensor3
    lhs = tensor3("a")
    lhs_flat = np.linspace(-1, 1, 10 * 5 * 3).astype(config.floatX)
    lhs.tag.test_value = lhs_flat.reshape((10, 5, 3))

    rhs = tensor3("b")
    rhs_flat = np.linspace(1, -1, 10 * 3 * 2).astype(config.floatX)
    rhs.tag.test_value = rhs_flat.reshape((10, 3, 2))

    batched = aet_blas.BatchedDot()(lhs, rhs)
    fgraph = FunctionGraph([lhs, rhs], [batched])
    compare_jax_and_py(fgraph, [get_test_value(var) for var in fgraph.inputs])

    # A dimension mismatch should raise a TypeError for compatibility
    mismatched = [get_test_value(lhs)[:-1], get_test_value(rhs)]
    query = OptimizationQuery(include=[None], exclude=["cxx_only", "BlasOpt"])
    jax_only_mode = Mode(JAXLinker(), query)
    compiled = function(fgraph.inputs, fgraph.outputs, mode=jax_only_mode)
    with pytest.raises(TypeError):
        compiled(*mismatched)

    # matrix . matrix
    lhs = matrix("a")
    lhs_flat = np.linspace(-1, 1, 5 * 3).astype(config.floatX)
    lhs.tag.test_value = lhs_flat.reshape((5, 3))

    rhs = matrix("b")
    rhs_flat = np.linspace(1, -1, 5 * 3).astype(config.floatX)
    rhs.tag.test_value = rhs_flat.reshape((5, 3))

    batched = aet_blas.BatchedDot()(lhs, rhs)
    fgraph = FunctionGraph([lhs, rhs], [batched])
    compare_jax_and_py(fgraph, [get_test_value(var) for var in fgraph.inputs])
def compare_jax_and_py(
    fgraph,
    inputs,
    assert_fn=None,
    must_be_device_array=True,
):
    """Run a graph through the JAX and Python linkers and compare the results.

    The graph is compiled twice with the same optimization query
    (``cxx_only`` and ``BlasOpt`` excluded): once with ``JAXLinker`` and once
    with the ``"py"`` linker.  Both compiled functions are called on `inputs`
    and their outputs are passed to `assert_fn`.

    Parameters
    ----------
    fgraph: FunctionGraph
        Aesara function graph whose inputs and outputs are compiled.
    inputs: iter
        Values unpacked as positional arguments to both compiled functions.
    assert_fn: func, opt
        Callable invoked as ``assert_fn(jax_value, py_value)``.  When omitted,
        ``np.testing.assert_allclose`` with ``rtol=1e-4`` is used.
    must_be_device_array: Bool
        When true, assert that the JAX result (or every element of it, when
        it is a list) is a ``jax.interpreters.xla.DeviceArray``.

    Returns
    -------
    jax_res
        The value returned by the JAX-compiled function.

    """
    check = assert_fn
    if check is None:
        check = partial(np.testing.assert_allclose, rtol=1e-4)

    query = Query(include=[None], exclude=["cxx_only", "BlasOpt"])
    jax_linker_mode = Mode(JAXLinker(), query)
    python_mode = Mode("py", query)

    jax_compiled = function(fgraph.inputs, fgraph.outputs, mode=jax_linker_mode)
    jax_res = jax_compiled(*inputs)

    if must_be_device_array:
        # Treat a bare result as a one-element collection so a single loop
        # covers both the list and the non-list case.
        produced = jax_res if isinstance(jax_res, list) else [jax_res]
        for value in produced:
            assert isinstance(value, jax.interpreters.xla.DeviceArray)

    py_compiled = function(fgraph.inputs, fgraph.outputs, mode=python_mode)
    py_res = py_compiled(*inputs)

    if len(fgraph.outputs) <= 1:
        check(jax_res, py_res)
        return jax_res

    # Several outputs: compare them pairwise, in output order.
    for jax_value, py_value in zip(jax_res, py_res):
        check(jax_value, py_value)
    return jax_res
# Linker registry for `Mode`.
#
# `predefined_linkers` maps the string names accepted as `Mode`'s linker
# argument ("py", "c", "c|py", "c|py_nogc", "vm", "cvm", "vm_nogc",
# "cvm_nogc", "jax") to linker instances.  `register_linker(name, linker)`
# adds a new entry and raises `ValueError` when `name` is already a key, so
# an existing entry is never replaced through that function.
_logger = logging.getLogger("aesara.compile.mode") # If a string is passed as the linker argument in the constructor for # Mode, it will be used as the key to retrieve the real linker in this # dictionary predefined_linkers = { "py": PerformLinker(), # Use allow_gc Aesara flag "c": CLinker(), # Don't support gc. so don't check allow_gc "c|py": OpWiseCLinker(), # Use allow_gc Aesara flag "c|py_nogc": OpWiseCLinker(allow_gc=False), "vm": VMLinker(use_cloop=False), # Use allow_gc Aesara flag "cvm": VMLinker(use_cloop=True), # Use allow_gc Aesara flag "vm_nogc": VMLinker(allow_gc=False, use_cloop=False), "cvm_nogc": VMLinker(allow_gc=False, use_cloop=True), "jax": JAXLinker(), } def register_linker(name, linker): """Add a `Linker` which can be referred to by `name` in `Mode`.""" if name in predefined_linkers: raise ValueError(f"Linker name already taken: {name}") predefined_linkers[name] = linker # If a string is passed as the optimizer argument in the constructor # for Mode, it will be used as the key to retrieve the real optimizer # in this dictionary exclude = [] if not config.cxx:
dscalar, dvector, iscalar, ivector, lscalar, matrix, scalar, tensor, tensor3, vector, ) jax = pytest.importorskip("jax") opts = OptimizationQuery(include=[None], exclude=["cxx_only", "BlasOpt"]) jax_mode = Mode(JAXLinker(), opts) py_mode = Mode("py", opts) @pytest.fixture(scope="module", autouse=True) def set_aesara_flags(): with config.change_flags(cxx="", compute_test_value="ignore"): yield def compare_jax_and_py( fgraph: FunctionGraph, test_inputs: iter, assert_fn: Optional[callable] = None, must_be_device_array: bool = True, ):