def test_jax_BatchedDot():
    """Check that ``BatchedDot`` transpiles to JAX for rank-3 and rank-2 inputs."""
    # tensor3 . tensor3
    x = tensor3("a")
    x.tag.test_value = (
        np.linspace(-1, 1, 10 * 5 * 3).astype(config.floatX).reshape((10, 5, 3))
    )
    y = tensor3("b")
    y.tag.test_value = (
        np.linspace(1, -1, 10 * 3 * 2).astype(config.floatX).reshape((10, 3, 2))
    )
    batched_out = aet_blas.BatchedDot()(x, y)
    batched_fg = FunctionGraph([x, y], [batched_out])
    compare_jax_and_py(batched_fg, [get_test_value(inp) for inp in batched_fg.inputs])

    # A dimension mismatch should raise a TypeError for compatibility
    bad_inputs = [get_test_value(x)[:-1], get_test_value(y)]
    query = Query(include=[None], exclude=["cxx_only", "BlasOpt"])
    jax_fn = function(
        batched_fg.inputs, batched_fg.outputs, mode=Mode(JAXLinker(), query)
    )
    with pytest.raises(TypeError):
        jax_fn(*bad_inputs)

    # matrix . matrix
    x = matrix("a")
    x.tag.test_value = np.linspace(-1, 1, 5 * 3).astype(config.floatX).reshape((5, 3))
    y = matrix("b")
    y.tag.test_value = np.linspace(1, -1, 5 * 3).astype(config.floatX).reshape((5, 3))
    mat_out = aet_blas.BatchedDot()(x, y)
    mat_fg = FunctionGraph([x, y], [mat_out])
    compare_jax_and_py(mat_fg, [get_test_value(inp) for inp in mat_fg.inputs])
def compare_jax_and_py(
    fgraph,
    inputs,
    assert_fn=None,
    must_be_device_array=True,
):
    """Compile ``fgraph`` with both the JAX and Python linkers and compare results.

    In the tests below computational graphs are defined in Aesara.  These
    graphs are then passed to this function, which compiles them with both
    JAX and the pure-Python linker, runs each with the same ``inputs``, and
    checks that the results agree.

    Parameters
    ----------
    fgraph: FunctionGraph
        Aesara function Graph object
    inputs: iter
        Inputs for function graph
    assert_fn: func, opt
        Assert function used to check for equality between python and jax.
        If not provided uses np.testing.assert_allclose
    must_be_device_array: Bool
        Checks for instance of jax.interpreters.xla.DeviceArray.  For
        testing purposes, finding this device array indicates the result
        was actually computed by JAX.

    Returns
    -------
    jax_res

    """
    checker = (
        assert_fn
        if assert_fn is not None
        else partial(np.testing.assert_allclose, rtol=1e-4)
    )

    query = Query(include=[None], exclude=["cxx_only", "BlasOpt"])
    mode_jax = Mode(JAXLinker(), query)
    mode_py = Mode("py", query)

    fn_jax = function(fgraph.inputs, fgraph.outputs, mode=mode_jax)
    jax_res = fn_jax(*inputs)

    if must_be_device_array:
        # Normalize single-output results to a list so one loop covers both cases.
        candidates = jax_res if isinstance(jax_res, list) else [jax_res]
        for res in candidates:
            assert isinstance(res, jax.interpreters.xla.DeviceArray)

    fn_py = function(fgraph.inputs, fgraph.outputs, mode=mode_py)
    py_res = fn_py(*inputs)

    if len(fgraph.outputs) > 1:
        for jax_out, py_out in zip(jax_res, py_res):
            checker(jax_out, py_out)
    else:
        checker(jax_res, py_res)

    return jax_res
def test_inplace(self):
    """Make sure that in-place optimizations are *not* performed on the output of a ``BroadcastTo``."""
    dvec = aet.vector("d")
    base = aet.zeros((5,))
    filled = aet.set_subtensor(base[np.r_[0, 1, 3]], dvec)
    bcast = broadcast_to(filled, (5,))
    view = bcast[np.r_[0, 1, 3]]
    expr = aet.set_subtensor(view, np.r_[0, 0, 0])

    # Compile with the in-place optimizations enabled.
    fn = function([dvec], expr, mode=Mode("py", Query(include=["inplace"])))

    out_node = fn.maker.fgraph.outputs[0].owner
    assert isinstance(out_node.op, AdvancedIncSubtensor1)
    assert isinstance(out_node.inputs[0].owner.op, BroadcastTo)
    # The increment must not have been made in-place on the broadcasted value.
    assert out_node.op.inplace is False
def register_linker(name, linker):
    """Add a `Linker` which can be referred to by `name` in `Mode`.

    Raises ``ValueError`` if `name` is already registered, so existing
    linkers can't be silently overwritten.
    """
    if name in predefined_linkers:
        raise ValueError(f"Linker name already taken: {name}")
    predefined_linkers[name] = linker


# If a string is passed as the optimizer argument in the constructor
# for Mode, it will be used as the key to retrieve the real optimizer
# in this dictionary

# Without a C compiler, drop anything tagged "cxx_only" from every query.
exclude = []
if not config.cxx:
    exclude = ["cxx_only"]

# No optimizations at all (beyond the exclusion filter above).
OPT_NONE = Query(include=[], exclude=exclude)
# Even if multiple merge optimizer call will be there, this shouldn't
# impact performance.
OPT_MERGE = Query(include=["merge"], exclude=exclude)
OPT_FAST_RUN = Query(include=["fast_run"], exclude=exclude)
# Like OPT_FAST_RUN but restricted to optimizations also tagged "stable".
OPT_FAST_RUN_STABLE = OPT_FAST_RUN.requiring("stable")
# We need fast_compile_gpu here.  As on the GPU, we don't have all
# operation that exist in fast_compile, but have some that get
# introduced in fast_run, we want those optimization to also run in
# fast_compile+gpu.  We can't tag them just as 'gpu', as this would
# exclude them if we exclude 'gpu'.
OPT_FAST_COMPILE = Query(include=["fast_compile", "fast_compile_gpu"], exclude=exclude)
OPT_STABILIZE = Query(include=["fast_run"], exclude=exclude)
# Only run the early (stabilizing) portion of the fast_run pipeline.
OPT_STABILIZE.position_cutoff = 1.5000001

OPT_NONE.name = "OPT_NONE"
def _set_row_mappings(self, Gamma, dir_priors, model):
    """Create maps from Dirichlet priors parameters to rows and slices in the transition matrix.

    These maps are needed when a transition matrix isn't simply comprised
    of Dirichlet prior rows, but--instead--slices of Dirichlet priors.

    Consider the following:

    .. code-block:: python

        with pm.Model():
            d_0_rv = pm.Dirichlet("p_0", np.r_[1, 1])
            d_1_rv = pm.Dirichlet("p_1", np.r_[1, 1])

            p_0_rv = tt.as_tensor([0, 0, 1])
            p_1_rv = tt.zeros(3)
            p_1_rv = tt.set_subtensor(p_0_rv[[0, 2]], d_0_rv)
            p_2_rv = tt.zeros(3)
            p_2_rv = tt.set_subtensor(p_1_rv[[1, 2]], d_1_rv)

            P_tt = tt.stack([p_0_rv, p_1_rv, p_2_rv])

    The transition matrix `P_tt` has Dirichlet priors in only two of its
    three rows, and--even then--they're only present in parts of two rows.

    In this example, we need to know that Dirichlet prior 0, i.e. `d_0_rv`,
    is mapped to row 1, and prior 1 is mapped to row 2.  Furthermore, we
    need to know that prior 0 fills columns 0 and 2 in row 1, and prior 1
    fills columns 1 and 2 in row 2.

    These mappings allow one to embed Dirichlet priors in larger transition
    matrices with--for instance--fixed transition behavior.

    Side effects: sets ``self.n_rows``, ``self.row_remaps`` (prior index ->
    row index), and ``self.row_slices`` (prior index -> columns filled by
    that prior).  NOTE(review): the `dir_priors` and `model` parameters are
    not used in this body -- presumably kept for interface compatibility;
    confirm against callers.
    """  # noqa: E501

    # Remove unimportant `Op`s from the transition matrix graph
    Gamma = pre_greedy_local_optimizer(
        FunctionGraph([], []),
        [
            OpRemove(Elemwise(aes.Cast(aes.float32))),
            OpRemove(Elemwise(aes.Cast(aes.float64))),
            OpRemove(Elemwise(aes.identity)),
        ],
        Gamma,
    )

    # Canonicalize the transition matrix graph
    fg = FunctionGraph(
        list(graph_inputs([Gamma] + self.dir_priors_untrans)),
        [Gamma] + self.dir_priors_untrans,
        clone=True,
    )
    canonicalize_opt = optdb.query(Query(include=["canonicalize"]))
    canonicalize_opt.optimize(fg)
    Gamma = fg.outputs[0]
    dir_priors_untrans = fg.outputs[1:]
    # Detach the graph from the `FunctionGraph` so it can be used freely below.
    fg.disown()

    Gamma_DimShuffle = Gamma.owner

    if not (isinstance(Gamma_DimShuffle.op, DimShuffle)):
        raise TypeError("The transition matrix should be non-time-varying")

    Gamma_Join = Gamma_DimShuffle.inputs[0].owner

    if not (isinstance(Gamma_Join.op, at.basic.Join)):
        raise TypeError(
            "The transition matrix should be comprised of stacked row vectors"
        )

    # `Join.inputs[0]` is the join axis; the actual row tensors follow it.
    Gamma_rows = Gamma_Join.inputs[1:]

    self.n_rows = len(Gamma_rows)

    # Loop through the rows in the transition matrix's graph and determine
    # how our transformed Dirichlet RVs map to this transition matrix.
    self.row_remaps = {}
    self.row_slices = {}
    for i, dim_row in enumerate(Gamma_rows):
        if not dim_row.owner:
            continue

        # By-pass the `DimShuffle`s applied to the `AdvancedIncSubtensor1`
        # `Op`s in which we're actually interested
        gamma_row = dim_row.owner.inputs[0]

        if gamma_row in dir_priors_untrans:
            # This is a row that's simply a `Dirichlet`
            j = dir_priors_untrans.index(gamma_row)
            self.row_remaps[j] = i
            # The prior fills the entire row.
            self.row_slices[j] = slice(None)

        if gamma_row.owner.inputs[1] not in dir_priors_untrans:
            continue

        # Parts of a row set by a `*Subtensor*` `Op` using a full
        # `Dirichlet` e.g. `P_row[idx] = dir_rv`
        j = dir_priors_untrans.index(gamma_row.owner.inputs[1])
        untrans_dirich = dir_priors_untrans[j]

        if (
            gamma_row.owner
            and isinstance(gamma_row.owner.op, AdvancedIncSubtensor1)
            and gamma_row.owner.inputs[1] == untrans_dirich
        ):
            self.row_remaps[j] = i

            # NOTE(review): for `AdvancedIncSubtensor1` this appears to be
            # the index input (the columns being set), not the right-hand
            # value -- the name `rhand_val` looks misleading; confirm
            # against the `AdvancedIncSubtensor1.make_node` input order.
            rhand_val = gamma_row.owner.inputs[2]
            if not isinstance(rhand_val, TensorConstant):
                # TODO: We could allow more types of `idx` (e.g. slices)
                # Currently, `idx` can't be something like `2:5`
                raise TypeError("Only array indexing allowed for mixed"
                                " Dirichlet/non-Dirichlet rows")
            self.row_slices[j] = rhand_val.data
    multivariate_normal,
    normal,
    poisson,
    uniform,
)
from aesara.tensor.random.op import RandomVariable
from aesara.tensor.random.opt import (
    lift_rv_shapes,
    local_dimshuffle_rv_lift,
    local_subtensor_rv_lift,
)
from aesara.tensor.subtensor import AdvancedSubtensor, AdvancedSubtensor1, Subtensor
from aesara.tensor.type import iscalar, vector

# Python-linker mode that applies only the in-place rewrite for random variables.
inplace_mode = Mode("py", Query(include=["random_make_inplace"], exclude=[]))
# Python-linker mode with no optimizations at all.
no_mode = Mode("py", Query(include=[], exclude=[]))


def test_inplace_optimization():
    # Before compilation the `RandomVariable` must not be in-place.
    out = normal(0, 1)

    assert out.owner.op.inplace is False

    # Compiling under `inplace_mode` should trigger the in-place rewrite.
    f = function(
        [],
        out,
        mode=inplace_mode,
    )
from aesara.tensor.type import (
    dscalar,
    dvector,
    iscalar,
    ivector,
    lscalar,
    matrix,
    scalar,
    tensor,
    tensor3,
    vector,
)

# Skip this entire test module when JAX is not installed.
jax = pytest.importorskip("jax")

# Shared modes: same optimization query compiled with the JAX linker and
# with the pure-Python linker, so results can be compared.
opts = Query(include=[None], exclude=["cxx_only", "BlasOpt"])
jax_mode = Mode(JAXLinker(), opts)
py_mode = Mode("py", opts)


@pytest.fixture(scope="module", autouse=True)
def set_aesara_flags():
    # Disable the C backend and ignore test values for every test in this module.
    with config.change_flags(cxx="", compute_test_value="ignore"):
        yield


def compare_jax_and_py(
    fgraph,
    inputs,
    assert_fn=None,
    must_be_device_array=True,
def set_aesara_flags():
    """Run tests under the Python-linker mode with test values computed eagerly."""
    empty_query = Query(include=[None], exclude=[])
    with config.change_flags(
        mode=Mode("py", empty_query), compute_test_value="warn"
    ):
        yield