Example #1
0
    def clone_get_equiv(self, check_integrity=True, attach_feature=True):
        """Clone this graph and return the clone with an old-to-new mapping.

        Parameters:
            check_integrity: bool
                If True (the default), validate this graph after computing
                the clone mapping and validate the cloned graph once it is
                built.
            attach_feature: bool
                If True (the default), attach each feature of this graph to
                the cloned graph.

        Returns:
            e: FunctionGraph
                The cloned fgraph; every node in it is a fresh clone.
            equiv: dict
                Mapping from each node of this graph to its clone.
        """
        mapping = clone_get_equiv(self.inputs, self.outputs)

        if check_integrity:
            self.check_integrity()

        new_inputs = [mapping[var] for var in self.inputs]
        new_outputs = [mapping[var] for var in self.outputs]
        cloned = FunctionGraph(new_inputs, new_outputs, clone=False)

        if check_integrity:
            cloned.check_integrity()

        if attach_feature:
            # NOTE(review): the very same feature objects are attached to the
            # clone (they are not copied), so both graphs share them.
            for feat in self._features:
                cloned.attach_feature(feat)

        return cloned, mapping
Example #2
0
def replace_rvs_in_graphs(
    graphs: Iterable[TensorVariable],
    replacement_fn: Callable[
        [TensorVariable, Dict[TensorVariable, TensorVariable]],
        Iterable[TensorVariable],
    ],
    initial_replacements: Optional[Dict[TensorVariable, TensorVariable]] = None,
    **kwargs,
) -> Tuple[Iterable[TensorVariable], Dict[TensorVariable, TensorVariable]]:
    """Replace random variables in graphs

    This will *not* recompute test values.

    Parameters
    ==========
    graphs
        The graphs in which random variables are to be replaced. Any iterable
        is accepted. It is materialized into a list once, because it is
        traversed several times.
    replacement_fn
        Called as ``replacement_fn(var, replacements)`` for every variable
        owned by a `RandomVariable` that is encountered while walking
        ``graphs``. It may add entries to ``replacements`` in place. It
        returns the variables that should be walked next.
    initial_replacements
        Optional replacements used to seed ``replacements`` before walking.
        The given ``dict`` is copied, not mutated.
    **kwargs
        Forwarded to ``walk_model``.

    Returns
    =======
    Tuple containing the transformed graphs (as a ``list``) and a ``dict`` of
    the replacements that were made.
    """
    # Materialize once: `graphs` is consumed by `walk_model`, `graph_inputs`,
    # `clone_get_equiv` and the output comprehension below. A one-shot
    # iterator would otherwise be exhausted after the first traversal.
    graphs = list(graphs)

    replacements: Dict[TensorVariable, TensorVariable] = {}
    if initial_replacements:
        replacements.update(initial_replacements)

    def expand_replace(var):
        # Only variables produced by a `RandomVariable` are handed to
        # `replacement_fn`; every other variable contributes no extra nodes.
        if var.owner and isinstance(var.owner.op, RandomVariable):
            return list(replacement_fn(var, replacements))
        return []

    # Iterate the walk to completion only for the side effects that
    # `replacement_fn` has on `replacements`; the yielded values are unused.
    for _ in walk_model(graphs, expand_fn=expand_replace, **kwargs):
        pass

    if replacements:
        inputs = [i for i in graph_inputs(graphs) if not isinstance(i, Constant)]
        # Seed the memo with identity entries so the variables being replaced
        # are kept as-is in the clone (rather than cloned). `replace_all` can
        # then locate them in the new graph.
        equiv = {k: k for k in replacements}
        equiv = clone_get_equiv(inputs, graphs, False, False, equiv)

        fg = FunctionGraph(
            [equiv[i] for i in inputs],
            [equiv[o] for o in graphs],
            clone=False,
        )

        fg.replace_all(replacements.items(), import_missing=True)

        graphs = list(fg.outputs)

    return graphs, replacements
Example #3
0
def test_clone_get_equiv():
    """`clone_get_equiv` must honour pre-seeded `memo` entries."""
    x = vector("x")
    y = vector("y")
    z = vector("z")
    prod = x * y
    prod_node = prod.owner
    out = prod + 1.0

    # Pre-seed the memo so that `prod` is mapped to `z` in the clone.
    memo = {prod: z}
    clone_get_equiv(
        [x, y], [out], copy_inputs=False, copy_orphans=False, memo=memo
    )

    # Both inputs receive entries in the (mutated) memo.
    for inp in (x, y):
        assert inp in memo
    assert memo[prod] is z
    # Every output of `prod`'s node already had a replacement in the memo, so
    # the node itself never needs cloning (unless some other replacement
    # re-introduces `prod.owner`).
    assert prod_node not in memo
    assert equal_computations([memo[out]], [z + 1.0])
Example #4
0
File: fg.py Project: mgorny/aesara
    def clone_get_equiv(
        self,
        check_integrity: bool = True,
        attach_feature: bool = True,
        **kwargs
    ) -> Tuple[
        "FunctionGraph",
        Dict[Union[Apply, Variable, "Op"], Union[Apply, Variable, "Op"]],
    ]:
        """Return a clone of this graph together with its old-to-new mapping.

        Parameters
        ----------
        check_integrity
            If ``True``, run ``check_integrity`` on the cloned graph.
        attach_feature
            If ``True``, attach a ``clone()`` of each of `self`'s features to
            the cloned graph.
        kwargs
            Extra keyword arguments forwarded verbatim to `clone_get_equiv`.

        Returns
        -------
        e
            The cloned `FunctionGraph`. Every node in it is cloned, unless
            ``kwargs`` instruct `clone_get_equiv` otherwise.
        equiv
            A ``dict`` mapping each old node to its new counterpart.
        """
        mapping = clone_get_equiv(self.inputs, self.outputs, **kwargs)

        def _lookup(var):
            return cast(Variable, mapping[var])

        # NOTE(review): the same `update_mapping` object is handed to the
        # clone; it is neither copied nor translated through `mapping`.
        cloned = FunctionGraph(
            list(map(_lookup, self.inputs)),
            list(map(_lookup, self.outputs)),
            clone=False,
            update_mapping=self.update_mapping,
        )

        if check_integrity:
            cloned.check_integrity()

        if attach_feature:
            for feat in self._features:
                cloned.attach_feature(feat.clone())

        return cloned, mapping
Example #5
0
    def clone_get_equiv(
        self,
        check_integrity: bool = True,
        attach_feature: bool = True
    ) -> "Tuple[FunctionGraph, Dict[Union[Apply, Variable], Union[Apply, Variable]]]":
        """Clone the graph and return a ``dict`` that maps old nodes to new nodes.

        Parameters
        ----------
        check_integrity
            If ``True``, check the integrity of this graph and of the clone.
        attach_feature
            If ``True``, attach this graph's features to the cloned graph.
            The same feature objects are attached; they are not copied.

        Returns
        -------
        e
            Cloned fgraph. Every node in cloned graph is cloned.
        equiv
            A ``dict`` that maps old nodes to the new nodes.
        """
        # The return annotation is a string: the method returns a 2-tuple
        # (the previous `Union[...]` annotation was wrong). Quoting it avoids
        # evaluating typing names at class-definition time.
        equiv = clone_get_equiv(self.inputs, self.outputs)

        if check_integrity:
            self.check_integrity()
        e = FunctionGraph(
            [equiv[i] for i in self.inputs],
            [equiv[o] for o in self.outputs],
            clone=False,
        )
        if check_integrity:
            e.check_integrity()

        if attach_feature:
            for feature in self._features:
                e.attach_feature(feature)
        return e, equiv
Example #6
0
def rvs_to_value_vars(
    graphs: Iterable[TensorVariable],
    apply_transforms: bool = False,
    initial_replacements: Optional[Dict[TensorVariable, TensorVariable]] = None,
    **kwargs,
) -> Tuple[Iterable[TensorVariable], Dict[TensorVariable, TensorVariable]]:
    """Clone and replace random variables in graphs with their value variables.

    This will *not* recompute test values in the resulting graphs.

    Parameters
    ==========
    graphs
        The graphs in which to perform the replacements. Any iterable is
        accepted. It is materialized into a list once, because it is
        traversed more than once.
    apply_transforms
        If ``True``, apply each value variable's transform (its ``backward``
        method) before substituting it for the random variable.
    initial_replacements
        A ``dict`` containing the initial replacements to be made. Keys and
        values that belong to the cloned part of the graph are translated to
        their clones; all others are used unchanged.
    **kwargs
        Forwarded to `replace_rvs_in_graphs`.

    Returns
    =======
    The result of `replace_rvs_in_graphs` on the cloned graphs: the
    transformed graphs and a ``dict`` of the replacements that were made.
    """

    # Avoid circular dependency
    from pymc.distributions import NoDistribution

    def transform_replacements(var, replacements):
        rv_var, rv_value_var = extract_rv_and_value_vars(var)

        if rv_value_var is None:
            # If RandomVariable does not have a value_var and corresponds to
            # a NoDistribution, we allow further replacements in upstream graph
            if isinstance(rv_var.owner.op, NoDistribution):
                return rv_var.owner.inputs

            warnings.warn(
                f"No value variable found for {rv_var}; "
                "the random variable will not be replaced."
            )
            return []

        transform = getattr(rv_value_var.tag, "transform", None)

        if transform is None or not apply_transforms:
            replacements[var] = rv_value_var
            # In case the value variable is itself a graph, we walk it for
            # potential replacements
            return [rv_value_var]

        # Substitute the transform's `backward` applied to the value variable
        trans_rv_value = transform.backward(rv_value_var, *rv_var.owner.inputs)
        replacements[var] = trans_rv_value

        # Walk the transformed variable and make replacements
        return [trans_rv_value]

    # Materialize once: `graphs` is traversed by `graph_inputs`,
    # `clone_get_equiv` and the comprehension below. A one-shot iterator
    # would otherwise be exhausted after the first pass.
    graphs = list(graphs)

    # Clone original graphs
    inputs = [i for i in graph_inputs(graphs) if not isinstance(i, Constant)]
    equiv = clone_get_equiv(inputs, graphs, False, False, {})
    graphs = [equiv[n] for n in graphs]

    if initial_replacements:
        # Re-express the caller's replacements in terms of the cloned nodes
        initial_replacements = {
            equiv.get(k, k): equiv.get(v, v) for k, v in initial_replacements.items()
        }

    return replace_rvs_in_graphs(graphs, transform_replacements, initial_replacements, **kwargs)
Example #7
0
    def __init__(
        self,
        inputs: Optional[List[Variable]] = None,
        outputs: Optional[List[Variable]] = None,
        features: Optional[List[Feature]] = None,
        clone: bool = True,
        update_mapping: Optional[Dict[Variable, Variable]] = None,
        memo: Optional[Dict[Variable, Variable]] = None,
        copy_inputs: bool = True,
        copy_orphans: bool = True,
    ):
        """
        Create a `FunctionGraph` which operates on the subgraph between the
        `inputs` and `outputs`.

        Parameters
        ----------
        inputs
            Input variables of the graph. When ``None``, every graph input of
            `outputs` that is not a `Constant` is used.
        outputs
            Output variables of the graph. Required.
        features
            A list of features to be added to the `FunctionGraph`. They are
            attached in order, before the built-in `ReplaceValidate`.
        clone
            If ``True``, the graph will be cloned.
        update_mapping
            Mapping between the `inputs` with updates and the `outputs`
            corresponding to their updates.
        memo
            See ``clone_get_equiv``. Only used when `clone` is ``True``.
        copy_inputs
            See ``clone_get_equiv``. Only used when `clone` is ``True``.
        copy_orphans
            See ``clone_get_equiv``. Only used when `clone` is ``True``.

        Raises
        ------
        ValueError
            If `outputs` is ``None``, or if one of the `inputs` is owned by an
            existing node.
        """
        if outputs is None:
            raise ValueError("No outputs specified")

        if inputs is None:
            inputs = [
                i for i in graph_inputs(outputs)
                if not isinstance(i, Constant)
            ]

        if clone:
            memo = clone_get_equiv(
                inputs,
                outputs,
                copy_inputs=copy_inputs,
                copy_orphans=copy_orphans,
                memo=memo,
            )
            outputs = [memo[o] for o in outputs]
            inputs = [memo[i] for i in inputs]

        self.execute_callbacks_time = 0
        self.execute_callbacks_times = {}

        if features is None:
            features = []

        self._features = []

        # All apply nodes in the subgraph defined by inputs and
        # outputs are cached in this field
        self.apply_nodes = set()

        # Ditto for variable nodes.
        # It must contain all fgraph.inputs and all apply_nodes
        # outputs even if they aren't used in the graph.
        self.variables = set()

        # Create `inputs` *before* any feature is attached. Previously it was
        # assigned only after the features (including `ReplaceValidate`), so
        # anything inspecting the graph while being attached found no
        # `inputs` attribute.
        self.inputs = []
        self.outputs = list(outputs)
        self.clients = {}

        for f in features:
            self.attach_feature(f)

        self.attach_feature(ReplaceValidate())

        for in_var in inputs:
            if in_var.owner is not None:
                raise ValueError("One of the provided inputs is the output of "
                                 "an already existing node. "
                                 "If that is okay, either discard that "
                                 "input's owner or use graph.clone.")

            self.add_input(in_var, check=False)

        for output in outputs:
            self.import_var(output, reason="init")
        for i, output in enumerate(outputs):
            self.clients[output].append(("output", i))

        self.profile = None
        self.update_mapping = update_mapping
Example #8
0
File: fg.py Project: mgorny/aesara
    def __init__(
        self,
        inputs: Optional[Sequence[Variable]] = None,
        outputs: Optional[Sequence[Variable]] = None,
        features: Optional[Sequence[Feature]] = None,
        clone: bool = True,
        update_mapping: Optional[Dict[Variable, Variable]] = None,
        **clone_kwds,
    ):
        """
        Build a `FunctionGraph` over the subgraph lying between `inputs` and
        `outputs`.

        Parameters
        ----------
        inputs
            Input variables of the graph. Defaults to every graph input of
            `outputs` that is not an `AtomicVariable`.
        outputs
            Output variables of the graph. Must be provided.
        features
            Features to attach, in order, before the built-in
            `ReplaceValidate` feature.
        clone
            When ``True``, `inputs` and `outputs` are replaced by clones first.
        update_mapping
            Mapping between the `inputs` with updates and the `outputs`
            corresponding to their updates.
        clone_kwds
            Extra keywords forwarded to `clone_get_equiv` when `clone` is
            ``True``.

        Raises
        ------
        ValueError
            If `outputs` is missing, or one of the `inputs` already has an
            owner.
        """
        if outputs is None:
            raise ValueError("No outputs specified")

        if inputs is None:
            inputs = [
                var
                for var in graph_inputs(outputs)
                if not isinstance(var, AtomicVariable)
            ]

        if clone:
            equiv = clone_get_equiv(inputs, outputs, **clone_kwds)
            outputs = [cast(Variable, equiv[out]) for out in outputs]
            inputs = [cast(Variable, equiv[inp]) for inp in inputs]

        self.execute_callbacks_time: float = 0.0
        self.execute_callbacks_times: Dict[Feature, float] = {}

        self._features: List[Feature] = []

        # Every `Apply` node of the subgraph spanned by `inputs` and `outputs`.
        self.apply_nodes: Set[Apply] = set()

        # Inputs, outputs and every intermediate variable linking them, plus
        # any unused outputs of the nodes held in `self.apply_nodes`.
        self.variables: Set[Variable] = set()

        self.inputs: List[Variable] = []
        self.outputs: List[Variable] = []
        self.clients: Dict[Variable, List[ClientType]] = {}

        for feature in (() if features is None else features):
            self.attach_feature(feature)

        self.attach_feature(ReplaceValidate())

        # Inputs are registered one at a time; the first owned input aborts
        # construction.
        for in_var in inputs:
            if in_var.owner is not None:
                raise ValueError("One of the provided inputs is the output of "
                                 "an already existing node. "
                                 "If that is okay, either discard that "
                                 "input's owner or use graph.clone.")
            self.add_input(in_var, check=False)

        for out in outputs:
            self.add_output(out, reason="init")

        self.profile = None
        self.update_mapping = update_mapping