def clone_get_equiv(self, check_integrity=True, attach_feature=True):
    """Clone the graph and get a dict that maps old nodes to new ones.

    Parameters:
        check_integrity: bool
            Whether to check the integrity of the original graph (before the
            clone is assembled) and of the cloned graph. Default is True.
        attach_feature: bool
            Whether to attach the features of the original graph to the
            cloned graph. Default is True.

    Returns:
        e: FunctionGraph
            The cloned fgraph. Every node in the cloned graph is cloned.
        equiv: dict
            A dict that maps old nodes to new nodes.
    """
    equiv = clone_get_equiv(self.inputs, self.outputs)

    if check_integrity:
        self.check_integrity()

    e = FunctionGraph(
        [equiv[i] for i in self.inputs],
        [equiv[o] for o in self.outputs],
        clone=False,
        # Without this the clone's ``update_mapping`` defaults to None and the
        # cloned graph forgets which inputs have updates.
        # NOTE(review): the mapping object is shared, not copied — assumes it
        # is not mutated per-graph.
        update_mapping=self.update_mapping,
    )

    if check_integrity:
        e.check_integrity()

    if attach_feature:
        for feature in self._features:
            # NOTE(review): the *same* feature instances end up attached to both
            # graphs; features that keep per-graph state may not support this.
            e.attach_feature(feature)

    return e, equiv
def replace_rvs_in_graphs(
    graphs: Iterable[TensorVariable],
    replacement_fn: Callable[
        [TensorVariable, Dict[TensorVariable, TensorVariable]],
        Iterable[TensorVariable],
    ],
    initial_replacements: Optional[Dict[TensorVariable, TensorVariable]] = None,
    **kwargs,
) -> Tuple[List[TensorVariable], Dict[TensorVariable, TensorVariable]]:
    """Replace random variables in graphs

    This will *not* recompute test values.

    Parameters
    ==========
    graphs
        The graphs in which random variables are to be replaced. Any iterable
        is accepted; it is materialized into a list before use.
    replacement_fn
        Called as ``replacement_fn(var, replacements)`` for every output of a
        `RandomVariable` node encountered while walking the graphs. It may
        record replacements in the ``replacements`` dict it is given, and it
        returns the variables that should be walked next.
    initial_replacements
        Replacements to seed the ``replacements`` dict with.
    **kwargs
        Forwarded to ``walk_model``.

    Returns
    =======
    Tuple containing the transformed graphs (as a list) and a ``dict`` of the
    replacements that were made.
    """
    # ``graphs`` is traversed several times below (walk, input discovery,
    # cloning, output lookup); a one-shot iterator such as a generator would
    # be exhausted after the first pass.
    graphs = list(graphs)

    replacements: Dict[TensorVariable, TensorVariable] = {}
    if initial_replacements:
        replacements.update(initial_replacements)

    def expand_replace(var):
        # Only outputs of `RandomVariable` nodes are handed to
        # ``replacement_fn``; whatever it returns is walked next.
        new_nodes = []
        if var.owner and isinstance(var.owner.op, RandomVariable):
            new_nodes.extend(replacement_fn(var, replacements))
        return new_nodes

    # The walk is consumed purely for the side effects of ``expand_replace``.
    for _ in walk_model(graphs, expand_fn=expand_replace, **kwargs):
        pass

    if replacements:
        inputs = [i for i in graph_inputs(graphs) if not isinstance(i, Constant)]

        # Seed the memo with identity entries so the variables being replaced
        # are kept (not cloned) and can still be found by ``replace_all``.
        equiv = {k: k for k in replacements}
        equiv = clone_get_equiv(inputs, graphs, False, False, equiv)

        fg = FunctionGraph(
            [equiv[i] for i in inputs],
            [equiv[o] for o in graphs],
            clone=False,
        )

        fg.replace_all(replacements.items(), import_missing=True)

        graphs = list(fg.outputs)

    return graphs, replacements
def test_clone_get_equiv():
    x, y, z = vector("x"), vector("y"), vector("z")

    prod = x * y
    prod_node = prod.owner
    out = prod + 1.0

    # Pre-seed the memo: `prod` must map to `z` rather than to a fresh clone.
    memo = {prod: z}
    clone_get_equiv([x, y], [out], copy_inputs=False, copy_orphans=False, memo=memo)

    # The memo is filled in place, so the inputs now have entries too.
    for inp in (x, y):
        assert inp in memo

    # The pre-seeded entry is left as it was.
    assert memo[prod] is z

    # All the outputs of `prod` already had replacements/clones in the map, so
    # there is no need to re-clone its node (unless another replacement/clone
    # re-introduces `prod.owner` somehow).
    assert prod_node not in memo

    assert equal_computations([memo[out]], [z + 1.0])
def clone_get_equiv(
    self, check_integrity: bool = True, attach_feature: bool = True, **kwargs
) -> Tuple[
    "FunctionGraph",
    Dict[Union[Apply, Variable, "Op"], Union[Apply, Variable, "Op"]],
]:
    """Return a copy of this graph together with the old-to-new mapping.

    Parameters
    ----------
    check_integrity
        If ``True``, run ``check_integrity`` on the newly built graph.
    attach_feature
        If ``True``, attach a ``clone()`` of each of `self`'s features to the
        new graph.
    **kwargs
        Forwarded unchanged to the module-level ``clone_get_equiv``.

    Returns
    -------
    e
        The new `FunctionGraph`, built (with ``clone=False``) from the cloned
        inputs and outputs and sharing `self`'s ``update_mapping``.
    equiv
        The ``dict`` returned by ``clone_get_equiv`` that maps the original
        nodes to their clones.
    """
    equiv = clone_get_equiv(self.inputs, self.outputs, **kwargs)

    cloned_inputs = [cast(Variable, equiv[var]) for var in self.inputs]
    cloned_outputs = [cast(Variable, equiv[var]) for var in self.outputs]
    e = FunctionGraph(
        cloned_inputs,
        cloned_outputs,
        clone=False,
        update_mapping=self.update_mapping,
    )

    if check_integrity:
        e.check_integrity()

    if attach_feature:
        # Each feature is cloned so the two graphs do not share instances.
        for feat in self._features:
            e.attach_feature(feat.clone())

    return e, equiv
def clone_get_equiv(
    self, check_integrity: bool = True, attach_feature: bool = True
) -> Tuple["FunctionGraph", Dict[Variable, Variable]]:
    """Clone the graph and return a ``dict`` that maps old nodes to new nodes.

    Parameters
    ----------
    check_integrity
        Whether to check the integrity of the original graph (before the clone
        is assembled) and of the cloned graph.
    attach_feature
        Whether to attach the features of the original graph to the cloned
        graph.

    Returns
    -------
    e
        Cloned fgraph. Every node in the cloned graph is cloned.
    equiv
        A ``dict`` that maps old nodes to the new nodes.
    """
    equiv = clone_get_equiv(self.inputs, self.outputs)

    if check_integrity:
        self.check_integrity()

    e = FunctionGraph(
        [equiv[i] for i in self.inputs],
        [equiv[o] for o in self.outputs],
        clone=False,
        # Without this the clone's ``update_mapping`` defaults to None and the
        # cloned graph forgets which inputs have updates.
        # NOTE(review): the mapping object is shared, not copied — assumes it
        # is not mutated per-graph.
        update_mapping=self.update_mapping,
    )

    if check_integrity:
        e.check_integrity()

    if attach_feature:
        for feature in self._features:
            # NOTE(review): the *same* feature instances end up attached to both
            # graphs; features that keep per-graph state may not support this.
            e.attach_feature(feature)

    return e, equiv
def rvs_to_value_vars(
    graphs: Iterable[TensorVariable],
    apply_transforms: bool = False,
    initial_replacements: Optional[Dict[TensorVariable, TensorVariable]] = None,
    **kwargs,
) -> Tuple[List[TensorVariable], Dict[TensorVariable, TensorVariable]]:
    """Clone and replace random variables in graphs with their value variables.

    This will *not* recompute test values in the resulting graphs.

    Parameters
    ==========
    graphs
        The graphs in which to perform the replacements. Any iterable is
        accepted; it is materialized into a list before use.
    apply_transforms
        If ``True``, apply each value variable's transform.
    initial_replacements
        A ``dict`` containing the initial replacements to be made. Keys and
        values that belong to the original graphs are mapped to their clones;
        anything else is used as-is.
    **kwargs
        Forwarded to ``replace_rvs_in_graphs``.

    Returns
    =======
    The result of ``replace_rvs_in_graphs`` applied to the cloned graphs: the
    transformed graphs and a ``dict`` of the replacements that were made.
    """
    # Avoid circular dependency
    from pymc.distributions import NoDistribution

    def transform_replacements(var, replacements):
        rv_var, rv_value_var = extract_rv_and_value_vars(var)

        if rv_value_var is None:
            # If RandomVariable does not have a value_var and corresponds to
            # a NoDistribution, we allow further replacements in upstream graph
            if isinstance(rv_var.owner.op, NoDistribution):
                return rv_var.owner.inputs

            warnings.warn(
                f"No value variable found for {rv_var}; "
                "the random variable will not be replaced."
            )
            return []

        transform = getattr(rv_value_var.tag, "transform", None)

        if transform is None or not apply_transforms:
            replacements[var] = rv_value_var
            # In case the value variable is itself a graph, we walk it for
            # potential replacements
            return [rv_value_var]

        trans_rv_value = transform.backward(rv_value_var, *rv_var.owner.inputs)
        replacements[var] = trans_rv_value

        # Walk the transformed variable and make replacements
        return [trans_rv_value]

    # ``graphs`` is read three times below (input discovery, cloning, output
    # lookup); materialize it so a one-shot iterator is not exhausted early.
    graphs = list(graphs)

    # Clone original graphs
    inputs = [i for i in graph_inputs(graphs) if not isinstance(i, Constant)]
    equiv = clone_get_equiv(inputs, graphs, False, False, {})
    graphs = [equiv[n] for n in graphs]

    if initial_replacements:
        # Translate the user-supplied replacements into the cloned graphs.
        initial_replacements = {
            equiv.get(k, k): equiv.get(v, v) for k, v in initial_replacements.items()
        }

    return replace_rvs_in_graphs(
        graphs, transform_replacements, initial_replacements, **kwargs
    )
def __init__(
    self,
    inputs: Optional[List[Variable]] = None,
    outputs: Optional[List[Variable]] = None,
    features: Optional[List[Feature]] = None,
    clone: bool = True,
    update_mapping: Optional[Dict[Variable, Variable]] = None,
    memo: Optional[Dict[Variable, Variable]] = None,
    copy_inputs: bool = True,
    copy_orphans: bool = True,
):
    """
    Create a `FunctionGraph` which operates on the subgraph between the
    `inputs` and `outputs`.

    Parameters
    ----------
    inputs
        Input variables of the graph. When ``None``, the non-`Constant`
        variables returned by ``graph_inputs(outputs)`` are used.
    outputs
        Output variables of the graph. Required.
    clone
        If ``True``, the graph will be cloned.
    features
        A list of features to be added to the `FunctionGraph`.
    update_mapping
        Mapping between the `inputs` with updates and the `outputs`
        corresponding to their updates.
    memo
        See ``clone_get_equiv``. Only used when `clone` is ``True``.
    copy_inputs
        See ``clone_get_equiv``. Only used when `clone` is ``True``.
    copy_orphans
        See ``clone_get_equiv``. Only used when `clone` is ``True``.

    Raises
    ------
    ValueError
        If `outputs` is ``None`` or if one of the inputs has an ``owner``.
    """
    if outputs is None:
        raise ValueError("No outputs specified")

    if inputs is None:
        inputs = [i for i in graph_inputs(outputs) if not isinstance(i, Constant)]

    if clone:
        memo = clone_get_equiv(
            inputs,
            outputs,
            copy_inputs=copy_inputs,
            copy_orphans=copy_orphans,
            memo=memo,
        )
        outputs = [memo[o] for o in outputs]
        inputs = [memo[i] for i in inputs]

    self.execute_callbacks_time = 0
    self.execute_callbacks_times = {}

    if features is None:
        features = []

    self._features = []

    # All apply nodes in the subgraph defined by inputs and
    # outputs are cached in this field
    self.apply_nodes = set()

    # Ditto for variable nodes.
    # It must contain all fgraph.inputs and all apply_nodes
    # outputs even if they aren't used in the graph.
    self.variables = set()

    # ``inputs`` is created here, before any feature is attached, so the
    # attribute already exists if a feature's attach hook reads it; it is
    # filled by ``add_input`` below.
    self.inputs = []
    self.outputs = list(outputs)
    self.clients = {}

    for f in features:
        self.attach_feature(f)

    self.attach_feature(ReplaceValidate())

    for in_var in inputs:
        if in_var.owner is not None:
            raise ValueError(
                "One of the provided inputs is the output of "
                "an already existing node. "
                "If that is okay, either discard that "
                "input's owner or use graph.clone."
            )
        self.add_input(in_var, check=False)

    for output in outputs:
        self.import_var(output, reason="init")
    for i, output in enumerate(outputs):
        self.clients[output].append(("output", i))

    self.profile = None
    self.update_mapping = update_mapping
def __init__(
    self,
    inputs: Optional[Sequence[Variable]] = None,
    outputs: Optional[Sequence[Variable]] = None,
    features: Optional[Sequence[Feature]] = None,
    clone: bool = True,
    update_mapping: Optional[Dict[Variable, Variable]] = None,
    **clone_kwds,
):
    """
    Build a `FunctionGraph` over the subgraph spanning `inputs` to `outputs`.

    Parameters
    ----------
    inputs
        The graph's input variables. When omitted, every variable returned by
        ``graph_inputs(outputs)`` that is not an `AtomicVariable` is used.
    outputs
        The graph's output variables; required.
    features
        Features to attach to the graph; they are attached before the built-in
        ``ReplaceValidate`` feature.
    clone
        When ``True``, the subgraph is cloned first and the clones are used in
        place of the given `inputs` and `outputs`.
    update_mapping
        Mapping between the `inputs` with updates and the `outputs`
        corresponding to their updates. Stored as-is on the instance.
    clone_kwds
        Keywords passed to `clone_get_equiv` when `clone` is ``True``.

    Raises
    ------
    ValueError
        If `outputs` is ``None`` or if one of the inputs has an ``owner``.
    """
    if outputs is None:
        raise ValueError("No outputs specified")

    if inputs is None:
        inputs = [
            var for var in graph_inputs(outputs) if not isinstance(var, AtomicVariable)
        ]

    if clone:
        clone_map = clone_get_equiv(inputs, outputs, **clone_kwds)
        outputs = [cast(Variable, clone_map[out]) for out in outputs]
        inputs = [cast(Variable, clone_map[inp]) for inp in inputs]

    self.execute_callbacks_time: float = 0.0
    self.execute_callbacks_times: Dict[Feature, float] = {}

    # The containers below are created before any feature is attached.
    self._features: List[Feature] = []
    # Every Apply node in the subgraph between the inputs and the outputs.
    self.apply_nodes: Set[Apply] = set()
    # Inputs, outputs and all intermediate variables connecting them, plus
    # any otherwise-irrelevant outputs of the nodes in ``apply_nodes``.
    self.variables: Set[Variable] = set()
    self.inputs: List[Variable] = []
    self.outputs: List[Variable] = []
    self.clients: Dict[Variable, List[ClientType]] = {}

    for feature in () if features is None else features:
        self.attach_feature(feature)
    self.attach_feature(ReplaceValidate())

    for inp in inputs:
        if inp.owner is not None:
            raise ValueError(
                "One of the provided inputs is the output of "
                "an already existing node. "
                "If that is okay, either discard that "
                "input's owner or use graph.clone."
            )
        self.add_input(inp, check=False)

    for out in outputs:
        self.add_output(out, reason="init")

    self.profile = None
    self.update_mapping = update_mapping