def optimize_graph(
    fgraph: Union[Variable, FunctionGraph],
    include: Sequence[str] = ("canonicalize",),
    custom_opt=None,
    clone: bool = False,
    **kwargs
) -> Union[Variable, FunctionGraph]:
    """Easily optimize a graph.

    Parameters
    ----------
    fgraph:
        A ``FunctionGraph`` or ``Variable`` to be optimized.  A
        ``FunctionGraph`` is optimized in place and returned; a ``Variable``
        is wrapped in a new single-output ``FunctionGraph`` whose optimized
        output is returned.
    include:
        String names of the optimizations to be applied.  The default
        optimization is ``"canonicalize"``.
    custom_opt:
        A custom ``Optimization`` to also be applied, after the queried
        optimizations.  Its ``optimize`` method is called on the graph.
    clone:
        Whether or not to clone the input graph before optimizing.  Only
        used when `fgraph` is a ``Variable``; a ``FunctionGraph`` argument
        is always optimized in place.
    **kwargs:
        Keyword arguments passed to the
        ``aesara.graph.optdb.OptimizationQuery`` object.

    Returns
    -------
    The optimized ``FunctionGraph`` when one was given, otherwise the
    optimized output ``Variable``.
    """
    # Function-scope import, presumably to avoid an import cycle with
    # `aesara.compile` -- NOTE(review): confirm.
    from aesara.compile import optdb

    # Plain variables are wrapped in a one-output graph; remember to unwrap
    # the result at the end.
    return_only_out = not isinstance(fgraph, FunctionGraph)
    if return_only_out:
        fgraph = FunctionGraph(outputs=[fgraph], clone=clone)

    query_opt = optdb.query(OptimizationQuery(include=include, **kwargs))
    query_opt.optimize(fgraph)

    if custom_opt is not None:
        custom_opt.optimize(fgraph)

    if return_only_out:
        return fgraph.outputs[0]
    return fgraph
def _set_row_mappings(self, Gamma, dir_priors, model):
    """Create maps from Dirichlet priors parameters to rows and slices in the transition matrix.

    These maps are needed when a transition matrix isn't simply comprised
    of Dirichlet prior rows, but--instead--slices of Dirichlet priors.

    Consider the following:

    .. code-block:: python

        with pm.Model():
            d_0_rv = pm.Dirichlet("p_0", np.r_[1, 1])
            d_1_rv = pm.Dirichlet("p_1", np.r_[1, 1])

            p_0_rv = at.as_tensor([0, 0, 1])
            p_1_rv = at.zeros(3)
            p_1_rv = at.set_subtensor(p_1_rv[[0, 2]], d_0_rv)
            p_2_rv = at.zeros(3)
            p_2_rv = at.set_subtensor(p_2_rv[[1, 2]], d_1_rv)

            P_tt = at.stack([p_0_rv, p_1_rv, p_2_rv])

    The transition matrix `P_tt` has Dirichlet priors in only two of its
    three rows, and--even then--they're only present in parts of two rows.

    In this example, we need to know that Dirichlet prior 0, i.e. `d_0_rv`,
    is mapped to row 1, and prior 1 is mapped to row 2.  Furthermore, we
    need to know that prior 0 fills columns 0 and 2 in row 1, and prior 1
    fills columns 1 and 2 in row 2.

    These mappings allow one to embed Dirichlet priors in larger transition
    matrices with--for instance--fixed transition behavior.

    Parameters
    ----------
    Gamma
        The transition matrix graph.
    dir_priors
        Not used by this method; the priors are read from
        ``self.dir_priors_untrans`` instead.
    model
        Not used by this method.

    Sets ``self.n_rows`` (number of stacked rows), ``self.row_remaps``
    (prior index -> row index) and ``self.row_slices`` (prior index ->
    ``slice(None)`` for a whole-row prior, or the constant index data used
    to place the prior within its row).  Rows that contain no Dirichlet
    prior are skipped.

    Raises
    ------
    TypeError
        If the canonicalized matrix is not a `DimShuffle` of a `Join` of
        rows, or if a prior is placed into a row with a non-constant index.
    """  # noqa: E501
    # Remove unimportant `Op`s from the transition matrix graph
    Gamma = pre_greedy_local_optimizer(
        FunctionGraph([], []),
        [
            OpRemove(Elemwise(aes.Cast(aes.float32))),
            OpRemove(Elemwise(aes.Cast(aes.float64))),
            OpRemove(Elemwise(aes.identity)),
        ],
        Gamma,
    )

    # Canonicalize the transition matrix graph together with the priors so
    # that, after cloning, the matrix graph refers to the very same prior
    # variables we compare against below.
    fg = FunctionGraph(
        list(graph_inputs([Gamma] + self.dir_priors_untrans)),
        [Gamma] + self.dir_priors_untrans,
        clone=True,
    )
    canonicalize_opt = optdb.query(Query(include=["canonicalize"]))
    canonicalize_opt.optimize(fg)
    Gamma = fg.outputs[0]
    dir_priors_untrans = fg.outputs[1:]
    fg.disown()

    Gamma_DimShuffle = Gamma.owner

    if Gamma_DimShuffle is None or not isinstance(
        Gamma_DimShuffle.op, DimShuffle
    ):
        raise TypeError("The transition matrix should be non-time-varying")

    Gamma_Join = Gamma_DimShuffle.inputs[0].owner

    if Gamma_Join is None or not isinstance(Gamma_Join.op, at.basic.Join):
        raise TypeError(
            "The transition matrix should be comprised of stacked row vectors"
        )

    # The first `Join` input is the join axis; the remaining inputs are rows
    Gamma_rows = Gamma_Join.inputs[1:]

    self.n_rows = len(Gamma_rows)

    # Loop through the rows in the transition matrix's graph and determine
    # how our transformed Dirichlet RVs map to this transition matrix.
    self.row_remaps = {}
    self.row_slices = {}
    for i, dim_row in enumerate(Gamma_rows):
        if not dim_row.owner:
            continue

        # By-pass the `DimShuffle`s applied to the `AdvancedIncSubtensor1`
        # `Op`s in which we're actually interested
        gamma_row = dim_row.owner.inputs[0]

        if gamma_row in dir_priors_untrans:
            # This is a row that's simply a `Dirichlet`; it's fully handled
            j = dir_priors_untrans.index(gamma_row)
            self.row_remaps[j] = i
            self.row_slices[j] = slice(None)
            continue

        # Rows that aren't produced by a node with a "values" input (e.g.
        # plain inputs or single-input `Op`s) can't embed a prior.
        row_node = gamma_row.owner
        if row_node is None or len(row_node.inputs) < 2:
            continue

        row_vals = row_node.inputs[1]
        if row_vals not in dir_priors_untrans:
            continue

        # Parts of a row set by a `*Subtensor*` `Op` using a full
        # `Dirichlet` e.g. `P_row[idx] = dir_rv`; `AdvancedIncSubtensor1`
        # takes its inputs as `(x, y, idx)`.
        if not isinstance(row_node.op, AdvancedIncSubtensor1):
            continue

        j = dir_priors_untrans.index(row_vals)
        self.row_remaps[j] = i

        idx = row_node.inputs[2]
        if not isinstance(idx, TensorConstant):
            # TODO: We could allow more types of `idx` (e.g. slices)
            # Currently, `idx` can't be something like `2:5`
            raise TypeError("Only array indexing allowed for mixed"
                            " Dirichlet/non-Dirichlet rows")
        self.row_slices[j] = idx.data