示例#1
0
def optimize_graph(
    fgraph: Union[Variable, FunctionGraph],
    include: Sequence[str] = ("canonicalize",),
    custom_opt=None,
    clone: bool = False,
    **kwargs
) -> Union[Variable, FunctionGraph]:
    """Easily optimize a graph.

    Parameters
    ==========
    fgraph:
        A ``FunctionGraph`` or ``Variable`` to be optimized.
    include:
        String names of the optimizations to be applied.  The default
        optimization is ``"canonicalize"``.
    custom_opt:
        A custom ``Optimization`` to also be applied.  It is run after the
        optimizations selected by `include`.
    clone:
        Whether or not to clone the input graph before optimizing.  This is
        only used when `fgraph` is a ``Variable`` (it is forwarded to the
        ``FunctionGraph`` built around it); a ``FunctionGraph`` argument is
        handed to the optimizers as-is.
    **kwargs:
        Keyword arguments passed to the ``aesara.graph.optdb.OptimizationQuery`` object.

    Returns
    =======
    The ``FunctionGraph`` that was given, after optimization, or--when a
    ``Variable`` was given--the first output of the ``FunctionGraph`` that
    was built around it.
    """
    # Imported lazily, as in the original implementation.
    from aesara.compile import optdb

    # A bare `Variable` is wrapped in a temporary `FunctionGraph`; remember
    # that so the caller gets a `Variable` back.
    return_only_out = not isinstance(fgraph, FunctionGraph)
    if return_only_out:
        fgraph = FunctionGraph(outputs=[fgraph], clone=clone)

    query_opt = optdb.query(OptimizationQuery(include=include, **kwargs))
    query_opt.optimize(fgraph)

    if custom_opt:
        custom_opt.optimize(fgraph)

    return fgraph.outputs[0] if return_only_out else fgraph
示例#2
0
    def _set_row_mappings(self, Gamma, dir_priors, model):
        """Create maps from Dirichlet priors parameters to rows and slices in the transition matrix.

        These maps are needed when a transition matrix isn't simply comprised
        of Dirichlet prior rows, but--instead--slices of Dirichlet priors.

        Consider the following:

        .. code-block:: python

            with pm.Model():
                d_0_rv = pm.Dirichlet("p_0", np.r_[1, 1])
                d_1_rv = pm.Dirichlet("p_1", np.r_[1, 1])

                p_0_rv = tt.as_tensor([0, 0, 1])
                p_1_rv = tt.zeros(3)
                p_1_rv = tt.set_subtensor(p_1_rv[[0, 2]], d_0_rv)
                p_2_rv = tt.zeros(3)
                p_2_rv = tt.set_subtensor(p_2_rv[[1, 2]], d_1_rv)

                P_tt = tt.stack([p_0_rv, p_1_rv, p_2_rv])

        The transition matrix `P_tt` has Dirichlet priors in only two of its
        three rows, and--even then--they're only present in parts of two rows.

        In this example, we need to know that Dirichlet prior 0, i.e. `d_0_rv`,
        is mapped to row 1, and prior 1 is mapped to row 2.  Furthermore, we
        need to know that prior 0 fills columns 0 and 2 in row 1, and prior 1
        fills columns 1 and 2 in row 2.

        These mappings allow one to embed Dirichlet priors in larger transition
        matrices with--for instance--fixed transition behavior.

        Parameters
        ----------
        Gamma
            The transition matrix graph to inspect.
        dir_priors
            Not used by this method; ``self.dir_priors_untrans`` is read
            instead.
        model
            Not used by this method.

        Raises
        ------
        TypeError
            If the canonicalized `Gamma` is not the output of a `DimShuffle`
            applied to a `Join`, or if a partially Dirichlet row is indexed
            by something other than a constant.

        Notes
        -----
        The results are stored on the instance: ``self.n_rows`` is the number
        of stacked rows, ``self.row_remaps`` maps the position of a prior in
        ``self.dir_priors_untrans`` to a row index, and ``self.row_slices``
        maps the same position to the columns the prior fills in that row
        (``slice(None)`` for a whole row).

        """  # noqa: E501

        # Remove unimportant `Op`s from the transition matrix graph
        Gamma = pre_greedy_local_optimizer(
            FunctionGraph([], []),
            [
                OpRemove(Elemwise(aes.Cast(aes.float32))),
                OpRemove(Elemwise(aes.Cast(aes.float64))),
                OpRemove(Elemwise(aes.identity)),
            ],
            Gamma,
        )

        # Canonicalize the transition matrix graph
        fg = FunctionGraph(
            list(graph_inputs([Gamma] + self.dir_priors_untrans)),
            [Gamma] + self.dir_priors_untrans,
            clone=True,
        )
        canonicalize_opt = optdb.query(Query(include=["canonicalize"]))
        canonicalize_opt.optimize(fg)
        Gamma = fg.outputs[0]
        dir_priors_untrans = fg.outputs[1:]
        fg.disown()

        Gamma_DimShuffle = Gamma.owner

        # A variable without an owner cannot be a `DimShuffle` output; report
        # it with the same `TypeError` instead of failing on `None.op`.
        if Gamma_DimShuffle is None or not isinstance(
            Gamma_DimShuffle.op, DimShuffle
        ):
            raise TypeError("The transition matrix should be non-time-varying")

        Gamma_Join = Gamma_DimShuffle.inputs[0].owner

        if Gamma_Join is None or not isinstance(Gamma_Join.op, at.basic.Join):
            raise TypeError(
                "The transition matrix should be comprised of stacked row vectors"
            )

        # The first input of the `Join` is skipped; the rest are the rows
        Gamma_rows = Gamma_Join.inputs[1:]

        self.n_rows = len(Gamma_rows)

        # Loop through the rows in the transition matrix's graph and determine
        # how our transformed Dirichlet RVs map to this transition matrix.
        self.row_remaps = {}
        self.row_slices = {}
        for i, dim_row in enumerate(Gamma_rows):
            if dim_row.owner is None:
                continue

            # By-pass the `DimShuffle`s applied to the `AdvancedIncSubtensor1`
            # `Op`s in which we're actually interested
            gamma_row = dim_row.owner.inputs[0]

            if gamma_row in dir_priors_untrans:
                # This is a row that's simply a `Dirichlet`
                j = dir_priors_untrans.index(gamma_row)
                self.row_remaps[j] = i
                self.row_slices[j] = slice(None)

            row_node = gamma_row.owner

            # Rows that aren't produced by an `Op` with at least two inputs
            # can't be `P_row[idx] = dir_rv` assignments, so skip them before
            # looking at `inputs[1]`.
            if row_node is None or len(row_node.inputs) < 2:
                continue

            set_value = row_node.inputs[1]

            if set_value not in dir_priors_untrans:
                continue

            # Parts of a row set by a `*Subtensor*` `Op` using a full
            # `Dirichlet` e.g. `P_row[idx] = dir_rv`
            j = dir_priors_untrans.index(set_value)

            if not isinstance(row_node.op, AdvancedIncSubtensor1):
                continue

            self.row_remaps[j] = i

            idx_var = row_node.inputs[2]
            if not isinstance(idx_var, TensorConstant):
                # TODO: We could allow more types of `idx` (e.g. slices)
                # Currently, `idx` can't be something like `2:5`
                raise TypeError(
                    "Only array indexing allowed for mixed"
                    " Dirichlet/non-Dirichlet rows"
                )
            self.row_slices[j] = idx_var.data