def test_walk():
    """Check ``walk`` ordering for BFS/DFS traversal, with and without children."""
    r1, r2, r3 = MyVariable(1), MyVariable(2), MyVariable(3)
    o1 = MyOp(r1, r2)
    o1.name = "o1"
    o2 = MyOp(r3, o1)
    o2.name = "o2"

    def expand(r):
        # Leaves (no owner) must produce ``None``: the children case below
        # asserts ``(leaf, None)`` pairs.
        return r.owner.inputs if r.owner else None

    cases = [
        # Breadth-first, nodes only.
        (dict(bfs=True, return_children=False), [o2, r3, o1, r1, r2]),
        # Depth-first, nodes only.
        (dict(bfs=False, return_children=False), [o2, o1, r2, r1, r3]),
        # Breadth-first, each node paired with what ``expand`` returned for it.
        (
            dict(bfs=True, return_children=True),
            [
                (o2, [r3, o1]),
                (r3, None),
                (o1, [r1, r2]),
                (r1, None),
                (r2, None),
            ],
        ),
    ]
    for walk_kwargs, expected in cases:
        assert list(walk([o2], expand, **walk_kwargs)) == expected
def _get_ancestors(self, var: TensorVariable, func) -> Set[TensorVariable]:
    """Get all ancestors of a function, doing some accounting for deterministics.

    Walks backwards (breadth-first) from ``func`` through each node's
    ``owner.inputs`` and collects every model variable reached. The walk does
    not continue past a model variable, so only the nearest model variables on
    each path are returned.

    Parameters
    ----------
    var
        Model variable excluded from the result. It is removed from the set of
        candidate variables (``self.var_list``) before walking.
    func
        Starting node of the backwards walk (presumably a graph output
        variable -- TODO confirm against callers).

    Returns
    -------
    Set[TensorVariable]
        Model variables other than ``var`` reachable from ``func`` without
        passing through another model variable.

    Raises
    ------
    KeyError
        If ``var`` is not in ``self.var_list``.
    """
    # All of the variables in the model EXCEPT ``var``.  Named ``model_vars``
    # rather than ``vars`` so the builtin is not shadowed.
    model_vars = set(self.var_list)
    model_vars.remove(var)
    blockers = set()  # type: Set[TensorVariable]
    retval = set()  # type: Set[TensorVariable]

    def _expand(node) -> Optional[Iterator[TensorVariable]]:
        if node in blockers:
            return None
        elif node in model_vars:
            # Record the model variable, but do not walk into its ancestors.
            blockers.add(node)
            retval.add(node)
            return None
        elif node.owner:
            blockers.add(node)
            return reversed(node.owner.inputs)
        else:
            return None

    # ``walk`` is lazy and the real work happens in ``_expand``'s side effects,
    # so drain it without materialising a list of every visited node.
    for _ in walk(deque([func]), _expand, bfs=True):
        pass
    return retval
def _str_for_expression(var: Variable, formatting: str) -> str:
    r"""Render ``var`` as a function of the random variables it depends on.

    Produces a string like ``f(a1, ..., aN)`` whose arguments are the names of
    the ``RandomVariable`` outputs reached by walking back from
    ``var.owner.inputs``. The walk does not continue past a ``RandomVariable``.

    If ``formatting`` contains ``"latex"``, each name is escaped with
    ``_latex_escape``, wrapped in ``\text{...}``, and the names are joined with
    ``",~"``. Otherwise they are joined with ``", "``.

    Assumes ``var.owner`` is not None, since its ``inputs`` are read directly.
    """

    def _is_rv(node) -> bool:
        return bool(node.owner) and isinstance(node.owner.op, RandomVariable)

    def _expand(node):
        # Walk through deterministic ops only; stop at RVs and at leaves.
        if node.owner and not _is_rv(node):
            return reversed(node.owner.inputs)
        return None

    rv_names = [
        node.name
        for node in walk(nodes=var.owner.inputs, expand=_expand)
        if _is_rv(node)
    ]

    if "latex" in formatting:
        arguments = ",~".join(r"\text{" + _latex_escape(name) + "}" for name in rv_names)
    else:
        arguments = ", ".join(rv_names)
    return "f(" + arguments + ")"
def get_parent_names(self, var: TensorVariable) -> Set[VarName]:
    """Return the names of the nearest named variables ``var`` depends on.

    Walks backwards from ``var.owner.inputs`` through unnamed intermediate
    variables. On each path the walk stops at the first variable that has a
    ``name``, and that variable's name (via ``get_var_name``) is collected.

    Returns an empty set when ``var`` has no owner.
    """
    if var.owner is None or var.owner.inputs is None:
        return set()

    def _expand(x):
        if x.name:
            # A named variable is a parent, so do not walk past it. Returning
            # no children stops the walk explicitly. The previous ``[x]``
            # relied on ``walk`` skipping an already-visited node and would
            # loop forever with a walk that does not de-duplicate.
            return []
        if isinstance(x.owner, Apply):
            return reversed(x.owner.inputs)
        return []

    return {
        get_var_name(x)
        for x in walk(nodes=var.owner.inputs, expand=_expand)
        if x.name
    }
def get_parent_names(self, var: TensorVariable) -> Set[VarName]:
    """Return the names of the nearest named model variables ``var`` depends on.

    Walks backwards from ``var.owner.inputs`` through unnamed intermediate
    variables. On each path the walk stops at the first variable that has a
    ``name``. Only variables whose name is in ``self._all_var_names`` are
    reported (via ``get_var_name``). Named variables outside that collection
    still stop the walk but are not included.

    Returns an empty set when ``var`` has no owner.
    """
    if var.owner is None or var.owner.inputs is None:
        return set()

    def _expand(x):
        if x.name:
            # A named variable ends this path. Returning no children stops the
            # walk explicitly. The previous ``[x]`` relied on ``walk`` skipping
            # an already-visited node and would loop forever with a walk that
            # does not de-duplicate.
            return []
        if isinstance(x.owner, Apply):
            return reversed(x.owner.inputs)
        return []

    return {
        get_var_name(x)
        for x in walk(nodes=var.owner.inputs, expand=_expand)
        # Only consider nodes that are in the named model variables.
        if x.name and x.name in self._all_var_names
    }
def walk_model(
    graphs: Iterable[TensorVariable],
    walk_past_rvs: bool = False,
    stop_at_vars: Optional[Set[TensorVariable]] = None,
    expand_fn: Callable[[TensorVariable], Iterable[TensorVariable]] = lambda var: [],
) -> Generator[TensorVariable, None, None]:
    """Walk model graphs and yield their nodes.

    By default, these walks will not go past ``RandomVariable`` nodes. The walk
    is depth-first (``walk`` is called with ``bfs=False``).

    Parameters
    ==========
    graphs
        The graphs to walk.
    walk_past_rvs
        If ``True``, the walk will not terminate at ``RandomVariable``s.
    stop_at_vars
        A collection of variables at which the walk will terminate: their
        owners' inputs are not traversed. Any container supporting ``in`` works.
    expand_fn
        A function that returns the next variable(s) to be traversed, in
        addition to the owner's inputs. It may return any iterable (or
        ``None`` for "nothing extra"). The returned object is copied, never
        mutated.
    """
    if stop_at_vars is None:
        stop_at_vars = set()

    def expand(var):
        extra = expand_fn(var)
        # Copy into a fresh list. ``expand_fn`` may return a tuple or generator
        # (no ``extend``), or a list it does not own (e.g.
        # ``var.owner.inputs``), which extending in place would corrupt.
        new_vars = [] if extra is None else list(extra)
        if (
            var.owner
            and (walk_past_rvs or not isinstance(var.owner.op, RandomVariable))
            and (var not in stop_at_vars)
        ):
            new_vars.extend(reversed(var.owner.inputs))
        return new_vars

    yield from walk(graphs, expand, bfs=False)