コード例 #1
0
    def _set_row_mappings(self, Gamma, dir_priors, model):
        """Create maps from Dirichlet priors parameters to rows and slices in the transition matrix.

        These maps are needed when a transition matrix isn't simply comprised
        of Dirichlet prior rows, but--instead--slices of Dirichlet priors.

        Consider the following:

        .. code-block:: python

            with pm.Model():
                d_0_rv = pm.Dirichlet("p_0", np.r_[1, 1])
                d_1_rv = pm.Dirichlet("p_1", np.r_[1, 1])

                p_0_rv = tt.as_tensor([0, 0, 1])
                p_1_rv = tt.zeros(3)
                p_1_rv = tt.set_subtensor(p_1_rv[[0, 2]], d_0_rv)
                p_2_rv = tt.zeros(3)
                p_2_rv = tt.set_subtensor(p_2_rv[[1, 2]], d_1_rv)

                P_tt = tt.stack([p_0_rv, p_1_rv, p_2_rv])

        The transition matrix `P_tt` has Dirichlet priors in only two of its
        three rows, and--even then--they're only present in parts of two rows.

        In this example, we need to know that Dirichlet prior 0, i.e. `d_0_rv`,
        is mapped to row 1, and prior 1 is mapped to row 2.  Furthermore, we
        need to know that prior 0 fills columns 0 and 2 in row 1, and prior 1
        fills columns 1 and 2 in row 2.

        These mappings allow one to embed Dirichlet priors in larger transition
        matrices with--for instance--fixed transition behavior.

        Parameters
        ----------
        Gamma
            The graph of the transition matrix.  After canonicalization it must
            be a `DimShuffle` applied to a `Join` of row vectors.
        dir_priors
            Not used by this method; the untransformed Dirichlet priors are
            read from ``self.dir_priors_untrans`` instead.
        model
            Not used by this method.

        Notes
        -----
        This method sets ``self.n_rows`` (the number of stacked rows),
        ``self.row_remaps`` (prior index -> row index) and ``self.row_slices``
        (prior index -> ``slice(None)`` for a full row, or the constant index
        array for a partially filled row).

        Raises
        ------
        TypeError
            If the canonicalized transition matrix is not a `DimShuffle` of a
            `Join`, or if a partially filled row is indexed by something other
            than a constant array.

        """  # noqa: E501

        # Remove unimportant `Op`s from the transition matrix graph
        Gamma = pre_greedy_local_optimizer(
            FunctionGraph([], []),
            [
                OpRemove(Elemwise(aes.Cast(aes.float32))),
                OpRemove(Elemwise(aes.Cast(aes.float64))),
                OpRemove(Elemwise(aes.identity)),
            ],
            Gamma,
        )

        # Canonicalize the transition matrix graph together with the
        # untransformed priors, so that the priors we look up below are the
        # very same (cloned) variables that appear inside the matrix graph.
        fg = FunctionGraph(
            list(graph_inputs([Gamma] + self.dir_priors_untrans)),
            [Gamma] + self.dir_priors_untrans,
            clone=True,
        )
        canonicalize_opt = optdb.query(Query(include=["canonicalize"]))
        canonicalize_opt.optimize(fg)
        Gamma = fg.outputs[0]
        dir_priors_untrans = fg.outputs[1:]
        fg.disown()

        Gamma_DimShuffle = Gamma.owner

        # A root variable (no owner) cannot be a `DimShuffle` either; report
        # it with the same `TypeError` instead of failing on `None.op`.
        if Gamma_DimShuffle is None or not isinstance(
            Gamma_DimShuffle.op, DimShuffle
        ):
            raise TypeError("The transition matrix should be non-time-varying")

        Gamma_Join = Gamma_DimShuffle.inputs[0].owner

        if Gamma_Join is None or not isinstance(Gamma_Join.op, at.basic.Join):
            raise TypeError(
                "The transition matrix should be comprised of stacked row vectors"
            )

        # The first input of a `Join` is the axis; the rest are the rows
        Gamma_rows = Gamma_Join.inputs[1:]

        self.n_rows = len(Gamma_rows)

        # Loop through the rows in the transition matrix's graph and determine
        # how our transformed Dirichlet RVs map to this transition matrix.
        self.row_remaps = {}
        self.row_slices = {}
        for i, dim_row in enumerate(Gamma_rows):
            if not dim_row.owner:
                continue

            # By-pass the `DimShuffle`s applied to the `AdvancedIncSubtensor1`
            # `Op`s in which we're actually interested
            gamma_row = dim_row.owner.inputs[0]

            if gamma_row in dir_priors_untrans:
                # This is a row that's simply a `Dirichlet`
                j = dir_priors_untrans.index(gamma_row)
                self.row_remaps[j] = i
                self.row_slices[j] = slice(None)
                # The row is fully handled; don't go on to inspect its owner,
                # which may be missing or have fewer than two inputs.
                continue

            row_node = gamma_row.owner

            # Only rows built by `AdvancedIncSubtensor1` can embed a prior
            if row_node is None or not isinstance(
                row_node.op, AdvancedIncSubtensor1
            ):
                continue

            # Parts of a row set by a `*Subtensor*` `Op` using a full
            # `Dirichlet` e.g. `P_row[idx] = dir_rv`
            set_val = row_node.inputs[1]

            if set_val not in dir_priors_untrans:
                continue

            j = dir_priors_untrans.index(set_val)
            self.row_remaps[j] = i

            rhand_val = row_node.inputs[2]
            if not isinstance(rhand_val, TensorConstant):
                # TODO: We could allow more types of `idx` (e.g. slices)
                # Currently, `idx` can't be something like `2:5`
                raise TypeError("Only array indexing allowed for mixed"
                                " Dirichlet/non-Dirichlet rows")
            self.row_slices[j] = rhand_val.data
コード例 #2
0
def map_variables(replacer, graphs, additional_inputs=None):
    """Build new graphs from `graphs` with variables swapped out by `replacer`.

    `replacer` is applied to the graph inputs first (they are not the output
    of any `Apply` node, so the local optimizer below never visits them) and
    then, via a topological optimizer pass, to the outputs of every node.
    `Scan` and `OpFromGraph` nodes are rebuilt with their inner graphs
    processed by `_map_variables_inner`.

    :param replacer: function that takes a variable and returns its
         replacement.
    :param graphs: an iterable of graphs in which to replace variables
    :param additional_inputs: an iterable of graph inputs not used in any
         of 'graphs' but possibly used in the graphs returned by `replacer`
    :return: the new graphs, in the same order as 'graphs'

    Example:

    .. code-block:: python

        tag = "replaceme"

        a = aesara.tensor.type.scalar("a")
        b = aesara.tensor.type.scalar("b")
        c = aesara.tensor.type.scalar("c")

        ab = a + b
        ab.tag.replacement = a * b

        u = ab + c
        v, = map_variables(
            lambda graph: getattr(graph.tag, "replacement", graph),
            [u])

        # v is now equal to a * b + c
    """
    extra_inputs = [] if additional_inputs is None else additional_inputs

    # Results of `replacer` are remembered so that a variable we just
    # introduced is never itself passed back through `replacer`.
    already_replaced = set()

    def wrapped_replacer(graph):
        if graph in already_replaced:
            return graph
        substitute = replacer(graph)
        already_replaced.add(substitute)
        return substitute

    def collect_inputs(outputs):
        # De-duplicated union of the graphs' own inputs and the extra ones
        return list(set(list(graph_inputs(outputs)) + list(extra_inputs)))

    graphs = list(graphs)
    root_inputs = collect_inputs(graphs)

    # Replace input variables up front: they are not outputs of any `Apply`
    # node, so the local-optimizer pass further down would never reach them.
    input_swaps = []
    for old_input in root_inputs:
        new_input = wrapped_replacer(old_input)
        if new_input is not old_input:
            input_swaps.append((old_input, new_input))

    graphs = clone_replace(graphs, share_inputs=True, replace=input_swaps)
    root_inputs = collect_inputs(graphs)

    fg = FunctionGraph(root_inputs, graphs, clone=False)

    # Nodes that were already processed (or that we created ourselves)
    nodes_seen = set()

    @local_optimizer(None)
    def local_transform(fgraph, node):
        if node in nodes_seen:
            return False

        # importing Scan into module scope would be circular
        from aesara.compile.builders import OpFromGraph
        from aesara.scan.op import Scan

        op = node.op

        if not isinstance(op, (Scan, OpFromGraph)):
            # Ordinary node: run the replacer over each of its outputs
            nodes_seen.add(node)
            new_outputs = [wrapped_replacer(out) for out in node.outputs]

            # Add inputs to replacement graphs as inputs to this `fgraph`
            for new_root in graph_inputs(new_outputs):
                fgraph.add_input(new_root)

            return new_outputs

        # Node with an inner graph: recurse into it
        inner_ins, outer_ins, inner_outs = _map_variables_inner(
            wrapped_replacer,
            inner_inputs=op.inputs,
            outer_inputs=node.inputs,
            inner_outputs=op.outputs,
            containing_op=op,
        )

        # Re-instantiate the op around the transformed inner graph
        if isinstance(op, Scan):
            rebuilt_op = Scan(
                inner_ins,
                inner_outs,
                op.info,
                op.mode,
                # FIXME: infer this someday?
                typeConstructor=None,
            )
        else:
            rebuilt_op = OpFromGraph(inner_ins, inner_outs, **op.kwargs)

        # Make a new node to replace the old one, and remember it so it is
        # not transformed again.
        rebuilt_node = rebuilt_op.make_node(*outer_ins)
        nodes_seen.add(rebuilt_node)
        return rebuilt_node.outputs

    TopoOptimizer(local_transform, "out_to_in").optimize(fg)

    result = fg.outputs
    fg.disown()
    return result