Пример #1
0
def test_walk():
    """Check ``walk`` visit order for BFS and DFS, and the ``return_children`` pairs."""
    v1, v2, v3 = MyVariable(1), MyVariable(2), MyVariable(3)

    inner = MyOp(v1, v2)
    inner.name = "o1"
    outer = MyOp(v3, inner)
    outer.name = "o2"

    def expand(node):
        # Variables without an owner are leaves: nothing to expand into.
        if not node.owner:
            return None
        return node.owner.inputs

    # (walk keyword arguments, expected materialized walk result)
    cases = [
        (dict(bfs=True, return_children=False), [outer, v3, inner, v1, v2]),
        (dict(bfs=False, return_children=False), [outer, inner, v2, v1, v3]),
        (
            dict(bfs=True, return_children=True),
            [
                (outer, [v3, inner]),
                (v3, None),
                (inner, [v1, v2]),
                (v1, None),
                (v2, None),
            ],
        ),
    ]

    for walk_kwargs, expected in cases:
        visited = list(walk([outer], expand, **walk_kwargs))
        assert visited == expected
Пример #2
0
    def _get_ancestors(self, var: TensorVariable, func) -> Set[TensorVariable]:
        """Get all ancestors of a function, doing some accounting for deterministics.

        Walks the graph backwards from ``func``. Every other model variable
        (each entry of ``self.var_list`` except ``var``) that is reached is
        recorded, and the walk does not continue past it. The result is the set
        of model variables reachable from ``func`` without passing through
        another model variable.

        Parameters
        ==========
        var
            The model variable whose ancestors are wanted; it is excluded from
            the set of stopping points.
        func
            The graph node to start walking from (presumably the expression
            defining ``var`` -- confirm against callers).

        Returns
        =======
        Set[TensorVariable]
            The model variables found as described above.

        Raises
        ======
        KeyError
            If ``var`` is not an element of ``self.var_list``.
        """
        # All model variables EXCEPT ``var`` act as stopping points.
        model_vars: Set[TensorVariable] = set(self.var_list)
        model_vars.remove(var)

        blockers: Set[TensorVariable] = set()  # nodes already handled by _expand
        retval: Set[TensorVariable] = set()

        def _expand(node) -> Optional[Iterator[TensorVariable]]:
            if node in blockers:
                return None
            elif node in model_vars:
                # Record the ancestor and do not walk past it.
                blockers.add(node)
                retval.add(node)
                return None
            elif node.owner:
                blockers.add(node)
                return reversed(node.owner.inputs)
            else:
                return None

        # ``walk`` is lazy and only needed for the side effects of ``_expand``;
        # drain it without keeping every visited node alive in a throwaway list.
        deque(walk(deque([func]), _expand, bfs=True), maxlen=0)
        return retval
Пример #3
0
def _str_for_expression(var: Variable, formatting: str) -> str:
    """Render ``var`` as ``f(a1, ..., aN)`` over the random variables it depends on.

    The graph is walked from the inputs of ``var``'s owner. Descent stops at
    nodes produced by a ``RandomVariable`` op, and the names of those nodes, in
    walk order, become the arguments. When ``formatting`` contains ``"latex"``
    each name is escaped and wrapped in ``\\text{...}`` and the arguments are
    joined with ``",~"``; otherwise they are joined with ``", "``.
    """

    def _expand(node):
        # Random variables are the reported leaves, so never descend past them.
        if not node.owner or isinstance(node.owner.op, RandomVariable):
            return None
        return reversed(node.owner.inputs)

    arg_names = []
    for node in walk(nodes=var.owner.inputs, expand=_expand):
        if node.owner and isinstance(node.owner.op, RandomVariable):
            arg_names.append(node.name)

    if "latex" in formatting:
        arguments = ",~".join(r"\text{" + _latex_escape(name) + "}" for name in arg_names)
    else:
        arguments = ", ".join(arg_names)
    return "f(" + arguments + ")"
Пример #4
0
    def get_parent_names(self, var: TensorVariable) -> Set[VarName]:
        """Return the names of the nearest named ancestors of ``var``.

        The graph is walked backwards from the inputs of ``var``'s owner.
        Unnamed variables are expanded through their owner's inputs; every named
        variable reached is collected and the walk does not continue past it.

        Parameters
        ==========
        var
            The variable whose parents are collected.

        Returns
        =======
        Set[VarName]
            ``get_var_name`` of each named variable found; an empty set when
            ``var`` has no owner.
        """
        if var.owner is None or var.owner.inputs is None:
            return set()

        def _expand(x):
            # Named variables are terminal: they are reported but not walked
            # past. Returning no children (instead of re-enqueueing ``x``
            # itself) keeps termination independent of ``walk``'s de-duplication.
            if x.name:
                return []
            if isinstance(x.owner, Apply):
                return reversed(x.owner.inputs)
            return []

        parents = {
            get_var_name(x)
            for x in walk(nodes=var.owner.inputs, expand=_expand)
            if x.name
        }

        return parents
Пример #5
0
    def get_parent_names(self, var: TensorVariable) -> Set[VarName]:
        """Return the names of the model variables that ``var`` depends on.

        The graph is walked backwards from the inputs of ``var``'s owner. A
        variable whose name is in ``self._all_var_names`` is collected and the
        walk does not continue past it. Every other variable, named or not, is
        expanded through its owner's inputs.

        Parameters
        ==========
        var
            The variable whose parents are collected.

        Returns
        =======
        Set[VarName]
            ``get_var_name`` of each model variable found; an empty set when
            ``var`` has no owner.
        """
        if var.owner is None or var.owner.inputs is None:
            return set()

        def _is_model_var(x) -> bool:
            return bool(x.name) and x.name in self._all_var_names

        def _expand(x):
            # Stop only at named *model* variables. Stopping at any named node
            # would let a named intermediate (not a model variable) hide the
            # model variables behind it, and they would never be reported.
            if _is_model_var(x):
                return []
            if isinstance(x.owner, Apply):
                return reversed(x.owner.inputs)
            return []

        parents = {
            get_var_name(x)
            for x in walk(nodes=var.owner.inputs, expand=_expand)
            # Only consider nodes that are in the named model variables.
            if _is_model_var(x)
        }

        return parents
Пример #6
0
def walk_model(
    graphs: Iterable[TensorVariable],
    walk_past_rvs: bool = False,
    stop_at_vars: Optional[Set[TensorVariable]] = None,
    expand_fn: Callable[
        [TensorVariable], Optional[Iterable[TensorVariable]]
    ] = lambda var: [],
) -> Generator[TensorVariable, None, None]:
    """Walk model graphs and yield their nodes.

    By default, these walks will not go past ``RandomVariable`` nodes.

    Parameters
    ==========
    graphs
        The graphs to walk.
    walk_past_rvs
        If ``True``, the walk will not terminate at ``RandomVariable``s.
    stop_at_vars
        A collection of variables at which the walk will terminate. These
        variables are still yielded, but their inputs are not traversed.
    expand_fn
        A function that returns additional variable(s) to be traversed from a
        given variable. Any iterable is accepted, and ``None`` means "no extra
        variables". The returned object is copied, never mutated.
    """
    if stop_at_vars is None:
        stop_at_vars = set()

    def expand(var):
        # Copy into a fresh list: ``expand_fn`` may return a tuple, a generator
        # or a list the caller still owns, and the owner's inputs are appended
        # below.
        extra = expand_fn(var)
        new_vars = [] if extra is None else list(extra)

        if (
            var.owner
            and (walk_past_rvs or not isinstance(var.owner.op, RandomVariable))
            and (var not in stop_at_vars)
        ):
            # Reversed, presumably so that the depth-first walk (which visits
            # the last-added child first) reaches the inputs in their original
            # order.
            new_vars.extend(reversed(var.owner.inputs))

        return new_vars

    # ``False`` is ``walk``'s breadth-first flag (presumably ``bfs``), i.e. a
    # depth-first walk.
    yield from walk(graphs, expand, False)