def _set_row_mappings(self, Gamma, dir_priors, model):
    """Create maps from Dirichlet priors parameters to rows and slices in the transition matrix.

    These maps are needed when a transition matrix isn't simply comprised of
    Dirichlet prior rows, but--instead--slices of Dirichlet priors.

    Consider the following:

    .. code-block:: python

        with pm.Model():
            d_0_rv = pm.Dirichlet("p_0", np.r_[1, 1])
            d_1_rv = pm.Dirichlet("p_1", np.r_[1, 1])

            p_0_rv = tt.as_tensor([0, 0, 1])
            p_1_rv = tt.zeros(3)
            p_1_rv = tt.set_subtensor(p_1_rv[[0, 2]], d_0_rv)
            p_2_rv = tt.zeros(3)
            p_2_rv = tt.set_subtensor(p_2_rv[[1, 2]], d_1_rv)

            P_tt = tt.stack([p_0_rv, p_1_rv, p_2_rv])

    The transition matrix `P_tt` has Dirichlet priors in only two of its
    three rows, and--even then--they're only present in parts of two rows.

    In this example, we need to know that Dirichlet prior 0, i.e. `d_0_rv`,
    is mapped to row 1, and prior 1 is mapped to row 2.  Furthermore, we need
    to know that prior 0 fills columns 0 and 2 in row 1, and prior 1 fills
    columns 1 and 2 in row 2.

    These mappings allow one to embed Dirichlet priors in larger transition
    matrices with--for instance--fixed transition behavior.

    Parameters
    ----------
    Gamma
        Symbolic graph of the transition matrix to be analyzed.  Must be a
        non-time-varying matrix built by stacking row vectors (a
        `DimShuffle` over a `Join`); otherwise a `TypeError` is raised.
    dir_priors
        The Dirichlet prior terms.  NOTE(review): not referenced in this
        method's body--the transformed priors are read from
        ``self.dir_priors_untrans``; presumably kept for interface
        symmetry with the caller--confirm.
    model
        NOTE(review): not referenced in this method's body--confirm whether
        it is still needed.

    Side effects
    ------------
    Sets ``self.n_rows`` (number of stacked rows in `Gamma`),
    ``self.row_remaps`` (map from prior index ``j`` to transition-matrix row
    index ``i``), and ``self.row_slices`` (map from prior index ``j`` to the
    column indices--or ``slice(None)``--that the prior fills in its row).

    Raises
    ------
    TypeError
        If `Gamma` is not a stack of row vectors, or if a mixed
        Dirichlet/non-Dirichlet row uses a non-constant index.

    """  # noqa: E501

    # Remove unimportant `Op`s from the transition matrix graph; casts and
    # identities would otherwise obscure the node patterns matched below.
    Gamma = pre_greedy_local_optimizer(
        FunctionGraph([], []),
        [
            OpRemove(Elemwise(aes.Cast(aes.float32))),
            OpRemove(Elemwise(aes.Cast(aes.float64))),
            OpRemove(Elemwise(aes.identity)),
        ],
        Gamma,
    )

    # Canonicalize the transition matrix graph.  The transformed Dirichlet
    # priors are included as extra outputs so that the canonicalized copies
    # of those variables can be recovered after optimization.
    fg = FunctionGraph(
        list(graph_inputs([Gamma] + self.dir_priors_untrans)),
        [Gamma] + self.dir_priors_untrans,
        clone=True,
    )
    canonicalize_opt = optdb.query(Query(include=["canonicalize"]))
    canonicalize_opt.optimize(fg)
    Gamma = fg.outputs[0]
    dir_priors_untrans = fg.outputs[1:]
    # Detach the variables from the `FunctionGraph` so they can be used freely.
    fg.disown()

    Gamma_DimShuffle = Gamma.owner

    if not (isinstance(Gamma_DimShuffle.op, DimShuffle)):
        raise TypeError("The transition matrix should be non-time-varying")

    Gamma_Join = Gamma_DimShuffle.inputs[0].owner

    if not (isinstance(Gamma_Join.op, at.basic.Join)):
        raise TypeError(
            "The transition matrix should be comprised of stacked row vectors"
        )

    # `Join.inputs[0]` is the join axis; the row vectors follow it.
    Gamma_rows = Gamma_Join.inputs[1:]

    self.n_rows = len(Gamma_rows)

    # Loop through the rows in the transition matrix's graph and determine
    # how our transformed Dirichlet RVs map to this transition matrix.
    self.row_remaps = {}
    self.row_slices = {}
    for i, dim_row in enumerate(Gamma_rows):
        if not dim_row.owner:
            # A constant row: contains no Dirichlet prior, nothing to map.
            continue

        # By-pass the `DimShuffle`s applied to the `AdvancedIncSubtensor1`
        # `Op`s in which we're actually interested
        gamma_row = dim_row.owner.inputs[0]

        if gamma_row in dir_priors_untrans:
            # This is a row that's simply a `Dirichlet`
            j = dir_priors_untrans.index(gamma_row)
            self.row_remaps[j] = i
            # The prior occupies the entire row.
            self.row_slices[j] = slice(None)

        # NOTE(review): for a set-subtensor row, `inputs[1]` is expected to
        # hold the replacement values (the Dirichlet RV)--confirm against
        # the canonicalized `AdvancedIncSubtensor1` input order.
        if gamma_row.owner.inputs[1] not in dir_priors_untrans:
            continue

        # Parts of a row set by a `*Subtensor*` `Op` using a full
        # `Dirichlet` e.g. `P_row[idx] = dir_rv`
        j = dir_priors_untrans.index(gamma_row.owner.inputs[1])
        untrans_dirich = dir_priors_untrans[j]

        # NOTE(review): the last conjunct is tautological--`untrans_dirich`
        # was just looked up via `gamma_row.owner.inputs[1]`--so only the
        # `AdvancedIncSubtensor1` check constrains this branch; confirm
        # whether a different input index was intended.
        if (
            gamma_row.owner
            and isinstance(gamma_row.owner.op, AdvancedIncSubtensor1)
            and gamma_row.owner.inputs[1] == untrans_dirich
        ):
            self.row_remaps[j] = i

            # `inputs[2]` holds the indices being set; they must be a
            # compile-time constant so the columns can be recorded here.
            rhand_val = gamma_row.owner.inputs[2]
            if not isinstance(rhand_val, TensorConstant):
                # TODO: We could allow more types of `idx` (e.g. slices)
                # Currently, `idx` can't be something like `2:5`
                raise TypeError(
                    "Only array indexing allowed for mixed"
                    " Dirichlet/non-Dirichlet rows"
                )
            self.row_slices[j] = rhand_val.data
def map_variables(replacer, graphs, additional_inputs=None):
    """Construct new graphs based on 'graphs' with some variables replaced
    according to 'replacer'.

    :param replacer: function that takes a variable and returns its
        replacement.
    :param graphs: an iterable of graphs in which to replace variables
    :param additional_inputs: an iterable of graph inputs not used in any
        of 'graphs' but possibly used in the graphs returned by `replacer`
    :return: the new graphs, in the same order as 'graphs'

    Example:

    .. code-block:: python

        tag = "replaceme"

        a = aesara.tensor.type.scalar("a")
        b = aesara.tensor.type.scalar("b")
        c = aesara.tensor.type.scalar("c")

        ab = a + b
        ab.tag.replacement = a * b

        u = ab + c
        v, = map_variables(
            lambda graph: getattr(graph.tag, "replacement", graph), [u]
        )

        # v is now equal to a * b + c
    """
    if additional_inputs is None:
        additional_inputs = []

    # wrap replacer to avoid replacing things we just put there.
    graphs_seen = set()

    def wrapped_replacer(graph):
        if graph in graphs_seen:
            # A replacement we produced earlier; leave it as-is so that
            # replacements are not themselves re-replaced.
            return graph
        else:
            new_graph = replacer(graph)
            graphs_seen.add(new_graph)
            return new_graph

    graphs = list(graphs)
    inputs_ = list(set(list(graph_inputs(graphs)) + list(additional_inputs)))

    # perform any desired replacement of input variables.  these
    # aren't replaced by the local optimizer approach because they are
    # not outputs of any Apply node.
    new_inputs = [wrapped_replacer(i) for i in inputs_]
    replacements = [
        (input_, new_input)
        for input_, new_input in zip(inputs_, new_inputs)
        if new_input is not input_
    ]
    graphs = clone_replace(graphs, share_inputs=True, replace=replacements)
    # Recompute the inputs: the replacement graphs may reference inputs
    # that the original graphs did not.
    inputs_ = list(set(list(graph_inputs(graphs)) + list(additional_inputs)))

    fg = FunctionGraph(inputs_, graphs, clone=False)

    # Track nodes already processed so the optimizer, which revisits the
    # graph, does not transform the same node (or our replacements) twice.
    nodes_seen = set()

    @local_optimizer(None)
    def local_transform(fgraph, node):
        if node in nodes_seen:
            return False

        # importing Scan into module scope would be circular
        from aesara.compile.builders import OpFromGraph
        from aesara.scan.op import Scan

        if isinstance(node.op, (Scan, OpFromGraph)):
            # recurse on the inner graph
            (
                new_inner_inputs,
                new_outer_inputs,
                new_inner_outputs,
            ) = _map_variables_inner(
                wrapped_replacer,
                inner_inputs=node.op.inputs,
                outer_inputs=node.inputs,
                inner_outputs=node.op.outputs,
                containing_op=node.op,
            )
            # reinstantiate the op
            if isinstance(node.op, Scan):
                new_op = Scan(
                    new_inner_inputs,
                    new_inner_outputs,
                    node.op.info,
                    node.op.mode,
                    # FIXME: infer this someday?
                    typeConstructor=None,
                )
            elif isinstance(node.op, OpFromGraph):
                new_op = OpFromGraph(
                    new_inner_inputs, new_inner_outputs, **node.op.kwargs
                )
            # make a new node to replace the old one
            new_node = new_op.make_node(*new_outer_inputs)
            nodes_seen.add(new_node)
            return new_node.outputs
        else:
            nodes_seen.add(node)
            replacements = [wrapped_replacer(o) for o in node.outputs]

            # Add inputs to replacement graphs as inputs to this `fgraph`
            for i in graph_inputs(replacements):
                fgraph.add_input(i)

            return replacements

    # "out_to_in" visits outputs before their inputs, so outer replacements
    # are applied before the subgraphs they contain are examined.
    topo_transform = TopoOptimizer(local_transform, "out_to_in")
    topo_transform.optimize(fg)

    new_graphs = fg.outputs
    # Detach the outputs from the `FunctionGraph` before returning them.
    fg.disown()
    return new_graphs